// spice_framework/toolkit.rs

use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::fmt::Write as _;
use std::path::Path;

5/// A parameter definition for a tool.
6#[derive(Debug, Clone, Serialize, Deserialize)]
7pub struct ParamDef {
8    pub name: String,
9    #[serde(rename = "type")]
10    pub param_type: String,
11    pub description: String,
12    pub required: bool,
13}
14
15/// A tool definition parsed from a markdown file with YAML frontmatter.
16#[derive(Debug, Clone, Serialize, Deserialize)]
17pub struct ToolDef {
18    pub name: String,
19    pub description: String,
20    pub parameters: Vec<ParamDef>,
21}
22
23impl ToolDef {
24    /// Parse a tool definition from markdown content with YAML-like frontmatter.
25    ///
26    /// Expected format:
27    /// ```text
28    /// ---
29    /// name: toolName
30    /// description: What the tool does
31    /// parameters:
32    ///   - name: paramName
33    ///     type: string
34    ///     description: What the param is
35    ///     required: true
36    /// ---
37    /// Optional body (ignored for now)
38    /// ```
39    pub fn from_markdown(content: &str) -> Result<Self, String> {
40        let content = content.trim();
41        if !content.starts_with("---") {
42            return Err("Tool markdown must start with --- frontmatter delimiter".into());
43        }
44
45        let after_first = &content[3..];
46        let end_idx = after_first
47            .find("---")
48            .ok_or("Missing closing --- frontmatter delimiter")?;
49        let frontmatter = after_first[..end_idx].trim();
50
51        parse_tool_frontmatter(frontmatter)
52    }
53
54    /// Load a tool definition from a markdown file.
55    pub fn from_file(path: &Path) -> Result<Self, String> {
56        let content =
57            std::fs::read_to_string(path).map_err(|e| format!("Failed to read {}: {}", path.display(), e))?;
58        Self::from_markdown(&content)
59    }
60
61    /// Convert to OpenAI-compatible tool JSON.
62    pub fn to_openai_json(&self) -> serde_json::Value {
63        let mut properties = serde_json::Map::new();
64        let mut required = Vec::new();
65
66        for param in &self.parameters {
67            properties.insert(
68                param.name.clone(),
69                serde_json::json!({
70                    "type": param.param_type,
71                    "description": param.description
72                }),
73            );
74            if param.required {
75                required.push(serde_json::Value::String(param.name.clone()));
76            }
77        }
78
79        serde_json::json!({
80            "type": "function",
81            "function": {
82                "name": self.name,
83                "description": self.description,
84                "parameters": {
85                    "type": "object",
86                    "properties": properties,
87                    "required": required
88                }
89            }
90        })
91    }
92}
93
94/// A collection of tool definitions.
95#[derive(Debug, Clone)]
96pub struct Toolkit {
97    pub tools: Vec<ToolDef>,
98    tools_by_name: HashMap<String, usize>,
99}
100
101impl Toolkit {
102    /// Create a toolkit from a list of tool definitions.
103    pub fn new(tools: Vec<ToolDef>) -> Self {
104        let tools_by_name = tools
105            .iter()
106            .enumerate()
107            .map(|(i, t)| (t.name.clone(), i))
108            .collect();
109        Self { tools, tools_by_name }
110    }
111
112    /// Load all `.md` files from a directory as tool definitions.
113    pub fn from_dir(dir: &Path) -> Result<Self, String> {
114        let mut tools = Vec::new();
115        let entries = std::fs::read_dir(dir)
116            .map_err(|e| format!("Failed to read directory {}: {}", dir.display(), e))?;
117
118        let mut paths: Vec<_> = entries
119            .filter_map(|e| e.ok())
120            .map(|e| e.path())
121            .filter(|p| p.extension().map_or(false, |ext| ext == "md"))
122            .collect();
123        paths.sort();
124
125        for path in paths {
126            tools.push(ToolDef::from_file(&path)?);
127        }
128
129        Ok(Self::new(tools))
130    }
131
132    /// Get a tool by name.
133    pub fn get(&self, name: &str) -> Option<&ToolDef> {
134        self.tools_by_name.get(name).map(|&i| &self.tools[i])
135    }
136
137    /// Get all tool names.
138    pub fn tool_names(&self) -> Vec<String> {
139        self.tools.iter().map(|t| t.name.clone()).collect()
140    }
141
142    /// Convert all tools to OpenAI-compatible JSON array.
143    pub fn to_openai_json(&self) -> Vec<serde_json::Value> {
144        self.tools.iter().map(|t| t.to_openai_json()).collect()
145    }
146
147    /// Generate a human-readable tool listing for embedding in prompts.
148    pub fn to_prompt_listing(&self) -> String {
149        let mut out = String::new();
150        for tool in &self.tools {
151            out.push_str(&format!("### {}\n", tool.name));
152            out.push_str(&format!("{}\n", tool.description));
153            if !tool.parameters.is_empty() {
154                out.push_str("Parameters:\n");
155                for p in &tool.parameters {
156                    let req = if p.required { " (required)" } else { "" };
157                    out.push_str(&format!(
158                        "  - `{}` ({}): {}{}\n",
159                        p.name, p.param_type, p.description, req
160                    ));
161                }
162            }
163            out.push('\n');
164        }
165        out
166    }
167}
168
/// Prompt text with `{{name}}`-style placeholders (notably `{{tools}}`)
/// that are filled in when the template is rendered.
#[derive(Clone, Debug)]
pub struct PromptTemplate {
    /// Raw template text, placeholders included.
    pub template: String,
}

175impl PromptTemplate {
176    /// Create from a template string.
177    pub fn new(template: impl Into<String>) -> Self {
178        Self {
179            template: template.into(),
180        }
181    }
182
183    /// Load from a file.
184    pub fn from_file(path: &Path) -> Result<Self, String> {
185        let content =
186            std::fs::read_to_string(path).map_err(|e| format!("Failed to read {}: {}", path.display(), e))?;
187        Ok(Self::new(content))
188    }
189
190    /// Render the template, replacing `{{tools}}` with the toolkit listing.
191    pub fn render(&self, toolkit: &Toolkit) -> String {
192        self.template
193            .replace("{{tools}}", &toolkit.to_prompt_listing())
194    }
195
196    /// Render with a custom set of variable replacements.
197    pub fn render_with(&self, vars: &HashMap<String, String>) -> String {
198        let mut result = self.template.clone();
199        for (key, value) in vars {
200            result = result.replace(&format!("{{{{{}}}}}", key), value);
201        }
202        result
203    }
204}
205
206/// Simple YAML-like frontmatter parser (no external YAML dependency).
207fn parse_tool_frontmatter(frontmatter: &str) -> Result<ToolDef, String> {
208    let mut name = String::new();
209    let mut description = String::new();
210    let mut parameters = Vec::new();
211
212    let mut in_parameters = false;
213    let mut current_param: Option<ParamBuilder> = None;
214
215    for line in frontmatter.lines() {
216        let trimmed = line.trim();
217        if trimmed.is_empty() {
218            continue;
219        }
220
221        // Top-level keys
222        if !line.starts_with(' ') && !line.starts_with('\t') {
223            // Flush any pending param
224            if let Some(pb) = current_param.take() {
225                parameters.push(pb.build()?);
226            }
227
228            if let Some(val) = trimmed.strip_prefix("name:") {
229                name = val.trim().to_string();
230                in_parameters = false;
231            } else if let Some(val) = trimmed.strip_prefix("description:") {
232                description = val.trim().to_string();
233                in_parameters = false;
234            } else if trimmed == "parameters:" {
235                in_parameters = true;
236            }
237            continue;
238        }
239
240        if !in_parameters {
241            continue;
242        }
243
244        // Parameter list items
245        let stripped = trimmed.trim_start_matches('-').trim();
246        if trimmed.starts_with('-') {
247            // New parameter entry
248            if let Some(pb) = current_param.take() {
249                parameters.push(pb.build()?);
250            }
251            let mut pb = ParamBuilder::default();
252            if let Some(val) = stripped.strip_prefix("name:") {
253                pb.name = Some(val.trim().to_string());
254            }
255            current_param = Some(pb);
256        } else if let Some(ref mut pb) = current_param {
257            // Continuation of current parameter
258            if let Some(val) = stripped.strip_prefix("name:") {
259                pb.name = Some(val.trim().to_string());
260            } else if let Some(val) = stripped.strip_prefix("type:") {
261                pb.param_type = Some(val.trim().to_string());
262            } else if let Some(val) = stripped.strip_prefix("description:") {
263                pb.description = Some(val.trim().to_string());
264            } else if let Some(val) = stripped.strip_prefix("required:") {
265                pb.required = Some(val.trim() == "true");
266            }
267        }
268    }
269
270    // Flush last param
271    if let Some(pb) = current_param.take() {
272        parameters.push(pb.build()?);
273    }
274
275    if name.is_empty() {
276        return Err("Tool frontmatter missing 'name' field".into());
277    }
278    if description.is_empty() {
279        return Err("Tool frontmatter missing 'description' field".into());
280    }
281
282    Ok(ToolDef {
283        name,
284        description,
285        parameters,
286    })
287}
288
/// Accumulates the fields of one `parameters:` list entry while frontmatter
/// is scanned. Each field stays `None` until its key has been seen.
#[derive(Default)]
struct ParamBuilder {
    /// Value of the entry's `name:` key.
    name: Option<String>,
    /// Value of the entry's `description:` key.
    description: Option<String>,
    /// Value of the entry's `type:` key.
    param_type: Option<String>,
    /// Parsed value of the entry's `required:` key.
    required: Option<bool>,
}

297impl ParamBuilder {
298    fn build(self) -> Result<ParamDef, String> {
299        Ok(ParamDef {
300            name: self.name.ok_or("Parameter missing 'name'")?,
301            param_type: self.param_type.unwrap_or_else(|| "string".into()),
302            description: self.description.unwrap_or_default(),
303            required: self.required.unwrap_or(false),
304        })
305    }
306}
307
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a parameter whose description is derived from its name.
    fn param(name: &str, param_type: &str, required: bool) -> ParamDef {
        ParamDef {
            name: name.into(),
            param_type: param_type.into(),
            description: format!("{} param", name),
            required,
        }
    }

    /// Builds a tool definition directly, bypassing markdown parsing.
    fn tool(name: &str, description: &str, parameters: Vec<ParamDef>) -> ToolDef {
        ToolDef {
            name: name.into(),
            description: description.into(),
            parameters,
        }
    }

    #[test]
    fn parse_tool_from_markdown() {
        let md = r#"---
name: getWeather
description: Get current weather for a location
parameters:
  - name: location
    type: string
    description: The city name
    required: true
---
# getWeather
Extra docs here.
"#;
        let tool = ToolDef::from_markdown(md).unwrap();
        assert_eq!(tool.name, "getWeather");
        assert_eq!(tool.description, "Get current weather for a location");
        assert_eq!(tool.parameters.len(), 1);
        assert_eq!(tool.parameters[0].name, "location");
        assert_eq!(tool.parameters[0].param_type, "string");
        assert_eq!(tool.parameters[0].description, "The city name");
        assert!(tool.parameters[0].required);
    }

    #[test]
    fn omitted_parameter_fields_use_defaults() {
        let md = "---\nname: t\ndescription: d\nparameters:\n  - name: q\n---\n";
        let tool = ToolDef::from_markdown(md).unwrap();
        let q = &tool.parameters[0];
        assert_eq!(q.param_type, "string");
        assert_eq!(q.description, "");
        assert!(!q.required);
    }

    #[test]
    fn from_markdown_rejects_missing_delimiters() {
        assert!(ToolDef::from_markdown("name: x\ndescription: y").is_err());
        assert!(ToolDef::from_markdown("---\nname: x\ndescription: y\n").is_err());
    }

    #[test]
    fn from_markdown_requires_name_and_description() {
        assert!(ToolDef::from_markdown("---\ndescription: y\n---").is_err());
        assert!(ToolDef::from_markdown("---\nname: x\n---").is_err());
    }

    #[test]
    fn openai_json_lists_required_params() {
        let json = tool(
            "search",
            "Search",
            vec![param("query", "string", true), param("limit", "integer", false)],
        )
        .to_openai_json();
        assert_eq!(json["type"], "function");
        assert_eq!(json["function"]["name"], "search");
        assert_eq!(json["function"]["parameters"]["type"], "object");
        assert_eq!(
            json["function"]["parameters"]["properties"]["limit"]["type"],
            "integer"
        );
        assert_eq!(
            json["function"]["parameters"]["required"],
            serde_json::json!(["query"])
        );
    }

    #[test]
    fn toolkit_get_by_name() {
        let toolkit = Toolkit::new(vec![tool("a", "A", vec![]), tool("b", "B", vec![])]);
        assert_eq!(toolkit.get("b").map(|t| t.description.as_str()), Some("B"));
        assert!(toolkit.get("missing").is_none());
        assert_eq!(toolkit.tool_names(), vec!["a", "b"]);
        assert_eq!(toolkit.to_openai_json().len(), 2);
    }

    #[test]
    fn prompt_template_renders_tools() {
        let toolkit = Toolkit::new(vec![tool("myTool", "Does stuff", vec![])]);
        let tpl = PromptTemplate::new("You have these tools:\n{{tools}}\nUse them wisely.");
        let rendered = tpl.render(&toolkit);
        assert!(rendered.contains("myTool"));
        assert!(rendered.contains("Does stuff"));
        assert!(!rendered.contains("{{tools}}"));
    }

    #[test]
    fn render_with_substitutes_known_vars() {
        let mut vars = HashMap::new();
        vars.insert("user".to_string(), "Ada".to_string());
        let tpl = PromptTemplate::new("Hi {{user}}, {{user}}! {{unknown}}");
        assert_eq!(tpl.render_with(&vars), "Hi Ada, Ada! {{unknown}}");
    }
}