spice_framework/
toolkit.rs1use serde::{Deserialize, Serialize};
2use std::collections::HashMap;
3use std::path::Path;
4
5#[derive(Debug, Clone, Serialize, Deserialize)]
7pub struct ParamDef {
8 pub name: String,
9 #[serde(rename = "type")]
10 pub param_type: String,
11 pub description: String,
12 pub required: bool,
13}
14
15#[derive(Debug, Clone, Serialize, Deserialize)]
17pub struct ToolDef {
18 pub name: String,
19 pub description: String,
20 pub parameters: Vec<ParamDef>,
21}
22
23impl ToolDef {
24 pub fn from_markdown(content: &str) -> Result<Self, String> {
40 let content = content.trim();
41 if !content.starts_with("---") {
42 return Err("Tool markdown must start with --- frontmatter delimiter".into());
43 }
44
45 let after_first = &content[3..];
46 let end_idx = after_first
47 .find("---")
48 .ok_or("Missing closing --- frontmatter delimiter")?;
49 let frontmatter = after_first[..end_idx].trim();
50
51 parse_tool_frontmatter(frontmatter)
52 }
53
54 pub fn from_file(path: &Path) -> Result<Self, String> {
56 let content =
57 std::fs::read_to_string(path).map_err(|e| format!("Failed to read {}: {}", path.display(), e))?;
58 Self::from_markdown(&content)
59 }
60
61 pub fn to_openai_json(&self) -> serde_json::Value {
63 let mut properties = serde_json::Map::new();
64 let mut required = Vec::new();
65
66 for param in &self.parameters {
67 properties.insert(
68 param.name.clone(),
69 serde_json::json!({
70 "type": param.param_type,
71 "description": param.description
72 }),
73 );
74 if param.required {
75 required.push(serde_json::Value::String(param.name.clone()));
76 }
77 }
78
79 serde_json::json!({
80 "type": "function",
81 "function": {
82 "name": self.name,
83 "description": self.description,
84 "parameters": {
85 "type": "object",
86 "properties": properties,
87 "required": required
88 }
89 }
90 })
91 }
92}
93
94#[derive(Debug, Clone)]
96pub struct Toolkit {
97 pub tools: Vec<ToolDef>,
98 tools_by_name: HashMap<String, usize>,
99}
100
101impl Toolkit {
102 pub fn new(tools: Vec<ToolDef>) -> Self {
104 let tools_by_name = tools
105 .iter()
106 .enumerate()
107 .map(|(i, t)| (t.name.clone(), i))
108 .collect();
109 Self { tools, tools_by_name }
110 }
111
112 pub fn from_dir(dir: &Path) -> Result<Self, String> {
114 let mut tools = Vec::new();
115 let entries = std::fs::read_dir(dir)
116 .map_err(|e| format!("Failed to read directory {}: {}", dir.display(), e))?;
117
118 let mut paths: Vec<_> = entries
119 .filter_map(|e| e.ok())
120 .map(|e| e.path())
121 .filter(|p| p.extension().map_or(false, |ext| ext == "md"))
122 .collect();
123 paths.sort();
124
125 for path in paths {
126 tools.push(ToolDef::from_file(&path)?);
127 }
128
129 Ok(Self::new(tools))
130 }
131
132 pub fn get(&self, name: &str) -> Option<&ToolDef> {
134 self.tools_by_name.get(name).map(|&i| &self.tools[i])
135 }
136
137 pub fn tool_names(&self) -> Vec<String> {
139 self.tools.iter().map(|t| t.name.clone()).collect()
140 }
141
142 pub fn to_openai_json(&self) -> Vec<serde_json::Value> {
144 self.tools.iter().map(|t| t.to_openai_json()).collect()
145 }
146
147 pub fn to_prompt_listing(&self) -> String {
149 let mut out = String::new();
150 for tool in &self.tools {
151 out.push_str(&format!("### {}\n", tool.name));
152 out.push_str(&format!("{}\n", tool.description));
153 if !tool.parameters.is_empty() {
154 out.push_str("Parameters:\n");
155 for p in &tool.parameters {
156 let req = if p.required { " (required)" } else { "" };
157 out.push_str(&format!(
158 " - `{}` ({}): {}{}\n",
159 p.name, p.param_type, p.description, req
160 ));
161 }
162 }
163 out.push('\n');
164 }
165 out
166 }
167}
168
/// A text template containing `{{placeholder}}` markers to be filled in.
#[derive(Clone, Debug)]
pub struct PromptTemplate {
    /// The raw, unrendered template text.
    pub template: String,
}
174
175impl PromptTemplate {
176 pub fn new(template: impl Into<String>) -> Self {
178 Self {
179 template: template.into(),
180 }
181 }
182
183 pub fn from_file(path: &Path) -> Result<Self, String> {
185 let content =
186 std::fs::read_to_string(path).map_err(|e| format!("Failed to read {}: {}", path.display(), e))?;
187 Ok(Self::new(content))
188 }
189
190 pub fn render(&self, toolkit: &Toolkit) -> String {
192 self.template
193 .replace("{{tools}}", &toolkit.to_prompt_listing())
194 }
195
196 pub fn render_with(&self, vars: &HashMap<String, String>) -> String {
198 let mut result = self.template.clone();
199 for (key, value) in vars {
200 result = result.replace(&format!("{{{{{}}}}}", key), value);
201 }
202 result
203 }
204}
205
206fn parse_tool_frontmatter(frontmatter: &str) -> Result<ToolDef, String> {
208 let mut name = String::new();
209 let mut description = String::new();
210 let mut parameters = Vec::new();
211
212 let mut in_parameters = false;
213 let mut current_param: Option<ParamBuilder> = None;
214
215 for line in frontmatter.lines() {
216 let trimmed = line.trim();
217 if trimmed.is_empty() {
218 continue;
219 }
220
221 if !line.starts_with(' ') && !line.starts_with('\t') {
223 if let Some(pb) = current_param.take() {
225 parameters.push(pb.build()?);
226 }
227
228 if let Some(val) = trimmed.strip_prefix("name:") {
229 name = val.trim().to_string();
230 in_parameters = false;
231 } else if let Some(val) = trimmed.strip_prefix("description:") {
232 description = val.trim().to_string();
233 in_parameters = false;
234 } else if trimmed == "parameters:" {
235 in_parameters = true;
236 }
237 continue;
238 }
239
240 if !in_parameters {
241 continue;
242 }
243
244 let stripped = trimmed.trim_start_matches('-').trim();
246 if trimmed.starts_with('-') {
247 if let Some(pb) = current_param.take() {
249 parameters.push(pb.build()?);
250 }
251 let mut pb = ParamBuilder::default();
252 if let Some(val) = stripped.strip_prefix("name:") {
253 pb.name = Some(val.trim().to_string());
254 }
255 current_param = Some(pb);
256 } else if let Some(ref mut pb) = current_param {
257 if let Some(val) = stripped.strip_prefix("name:") {
259 pb.name = Some(val.trim().to_string());
260 } else if let Some(val) = stripped.strip_prefix("type:") {
261 pb.param_type = Some(val.trim().to_string());
262 } else if let Some(val) = stripped.strip_prefix("description:") {
263 pb.description = Some(val.trim().to_string());
264 } else if let Some(val) = stripped.strip_prefix("required:") {
265 pb.required = Some(val.trim() == "true");
266 }
267 }
268 }
269
270 if let Some(pb) = current_param.take() {
272 parameters.push(pb.build()?);
273 }
274
275 if name.is_empty() {
276 return Err("Tool frontmatter missing 'name' field".into());
277 }
278 if description.is_empty() {
279 return Err("Tool frontmatter missing 'description' field".into());
280 }
281
282 Ok(ToolDef {
283 name,
284 description,
285 parameters,
286 })
287}
288
/// Collects the fields of one parameter while its frontmatter lines are read.
///
/// Each field stays `None` until its key has been seen.
#[derive(Default)]
struct ParamBuilder {
    name: Option<String>,
    param_type: Option<String>,
    description: Option<String>,
    required: Option<bool>,
}
296
297impl ParamBuilder {
298 fn build(self) -> Result<ParamDef, String> {
299 Ok(ParamDef {
300 name: self.name.ok_or("Parameter missing 'name'")?,
301 param_type: self.param_type.unwrap_or_else(|| "string".into()),
302 description: self.description.unwrap_or_default(),
303 required: self.required.unwrap_or(false),
304 })
305 }
306}
307
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn parse_tool_from_markdown() {
        let markdown = r#"---
name: getWeather
description: Get current weather for a location
parameters:
  - name: location
    type: string
    description: The city name
    required: true
---
# getWeather
Extra docs here.
"#;
        let parsed = ToolDef::from_markdown(markdown).unwrap();
        assert_eq!(parsed.name, "getWeather");
        assert_eq!(parsed.parameters.len(), 1);

        let location = &parsed.parameters[0];
        assert_eq!(location.name, "location");
        assert!(location.required);
    }

    #[test]
    fn prompt_template_renders_tools() {
        let only_tool = ToolDef {
            name: "myTool".into(),
            description: "Does stuff".into(),
            parameters: Vec::new(),
        };
        let toolkit = Toolkit::new(vec![only_tool]);

        let template = PromptTemplate::new("You have these tools:\n{{tools}}\nUse them wisely.");
        let output = template.render(&toolkit);

        assert!(output.contains("myTool"));
        assert!(output.contains("Does stuff"));
        assert!(!output.contains("{{tools}}"));
    }
}