Skip to main content

zeph_tools/
registry.rs

1use std::fmt::Write;
2
/// How the model is told to invoke a tool in the generated tools prompt.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum InvocationHint {
    /// Tool invoked via a fenced code block in the LLM response.
    ///
    /// The payload is the fence's language tag: `FencedBlock("bash")` is advertised
    /// to the model as a ```` ```bash ```` fenced block.
    FencedBlock(&'static str),
    /// Tool invoked via structured `ToolCall` JSON; the prompt advertises the shape
    /// `{"tool_id": "<id>", "params": {...}}`.
    ToolCall,
}
10
11#[derive(Debug, Clone)]
12pub struct ToolDef {
13    pub id: &'static str,
14    pub description: &'static str,
15    pub schema: schemars::Schema,
16    pub invocation: InvocationHint,
17}
18
19#[derive(Debug, Default)]
20pub struct ToolRegistry {
21    tools: Vec<ToolDef>,
22}
23
24impl ToolRegistry {
25    #[must_use]
26    pub fn from_definitions(tools: Vec<ToolDef>) -> Self {
27        Self { tools }
28    }
29
30    #[must_use]
31    pub fn tools(&self) -> &[ToolDef] {
32        &self.tools
33    }
34
35    #[must_use]
36    pub fn find(&self, id: &str) -> Option<&ToolDef> {
37        self.tools.iter().find(|t| t.id == id)
38    }
39
40    /// Format tools for prompt, excluding tools fully denied by policy.
41    #[must_use]
42    pub fn format_for_prompt_filtered(
43        &self,
44        policy: &crate::permissions::PermissionPolicy,
45    ) -> String {
46        let mut out = String::from("<tools>\n");
47        for tool in &self.tools {
48            if policy.is_fully_denied(tool.id) {
49                continue;
50            }
51            format_tool(&mut out, tool);
52        }
53        out.push_str("</tools>");
54        out
55    }
56}
57
58fn format_tool(out: &mut String, tool: &ToolDef) {
59    let _ = writeln!(out, "## {}", tool.id);
60    let _ = writeln!(out, "{}", tool.description);
61    match tool.invocation {
62        InvocationHint::FencedBlock(tag) => {
63            let _ = writeln!(out, "Invocation: use ```{tag} fenced block");
64        }
65        InvocationHint::ToolCall => {
66            let _ = writeln!(
67                out,
68                "Invocation: use tool_call with {{\"tool_id\": \"{}\", \"params\": {{...}}}}",
69                tool.id
70            );
71        }
72    }
73    format_schema_params(out, &tool.schema);
74    out.push('\n');
75}
76
77/// Extract the primary type when schemars renders `Option<T>` as `"type": ["T", "null"]`
78/// or `"anyOf": [{"type": "T"}, {"type": "null"}]`.
79fn extract_non_null_type(obj: &serde_json::Map<String, serde_json::Value>) -> Option<&str> {
80    if let Some(arr) = obj.get("type").and_then(|v| v.as_array()) {
81        return arr.iter().filter_map(|v| v.as_str()).find(|t| *t != "null");
82    }
83    obj.get("anyOf")?
84        .as_array()?
85        .iter()
86        .filter_map(|v| v.as_object())
87        .filter_map(|o| o.get("type")?.as_str())
88        .find(|t| *t != "null")
89}
90
91fn format_schema_params(out: &mut String, schema: &schemars::Schema) {
92    let Some(obj) = schema.as_object() else {
93        return;
94    };
95    let Some(serde_json::Value::Object(props)) = obj.get("properties") else {
96        return;
97    };
98    if props.is_empty() {
99        return;
100    }
101
102    let required: Vec<&str> = obj
103        .get("required")
104        .and_then(|v| v.as_array())
105        .map(|arr| arr.iter().filter_map(|v| v.as_str()).collect())
106        .unwrap_or_default();
107
108    let _ = writeln!(out, "Parameters:");
109    for (name, prop) in props {
110        let prop_obj = prop.as_object();
111        let ty = prop_obj
112            .and_then(|o| {
113                o.get("type")
114                    .and_then(|v| v.as_str())
115                    .or_else(|| extract_non_null_type(o))
116            })
117            .unwrap_or("string");
118        let desc = prop_obj
119            .and_then(|o| o.get("description"))
120            .and_then(|v| v.as_str())
121            .unwrap_or("");
122        let req = if required.contains(&name.as_str()) {
123            "required"
124        } else {
125            "optional"
126        };
127        let _ = writeln!(out, "  - {name}: {desc} ({ty}, {req})");
128    }
129}
130
#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    use super::*;
    use crate::file::ReadParams;
    use crate::permissions::{PermissionAction, PermissionPolicy, PermissionRule};
    use crate::shell::BashParams;

    /// Two fixtures covering both invocation styles: `bash` is invoked through a
    /// fenced block, `read` through a structured tool call.
    fn sample_tools() -> Vec<ToolDef> {
        let bash = ToolDef {
            id: "bash",
            description: "Execute a shell command",
            schema: schemars::schema_for!(BashParams),
            invocation: InvocationHint::FencedBlock("bash"),
        };
        let read = ToolDef {
            id: "read",
            description: "Read file contents",
            schema: schemars::schema_for!(ReadParams),
            invocation: InvocationHint::ToolCall,
        };
        vec![bash, read]
    }

    fn sample_registry() -> ToolRegistry {
        ToolRegistry::from_definitions(sample_tools())
    }

    /// Renders the sample registry's prompt under `policy`.
    fn prompt_under(policy: &PermissionPolicy) -> String {
        sample_registry().format_for_prompt_filtered(policy)
    }

    /// Renders the sample registry's prompt under the default (rule-less) policy.
    fn unrestricted_prompt() -> String {
        prompt_under(&PermissionPolicy::default())
    }

    fn rule(pattern: &str, action: PermissionAction) -> PermissionRule {
        PermissionRule {
            pattern: pattern.to_owned(),
            action,
        }
    }

    /// Builds a policy whose only entry applies `rules` to the `bash` tool.
    fn bash_policy(rules: Vec<PermissionRule>) -> PermissionPolicy {
        PermissionPolicy::new(HashMap::from([("bash".to_owned(), rules)]))
    }

    #[test]
    fn from_definitions_stores_tools() {
        assert_eq!(sample_registry().tools().len(), 2);
    }

    #[test]
    fn default_registry_is_empty() {
        assert!(ToolRegistry::default().tools().is_empty());
    }

    #[test]
    fn find_existing_tool() {
        let reg = sample_registry();
        for id in ["bash", "read"] {
            assert!(reg.find(id).is_some(), "expected to find tool {id}");
        }
    }

    #[test]
    fn find_nonexistent_returns_none() {
        assert!(sample_registry().find("nonexistent").is_none());
    }

    #[test]
    fn format_for_prompt_contains_tools() {
        let prompt = unrestricted_prompt();
        for needle in ["<tools>", "</tools>", "## bash", "## read"] {
            assert!(prompt.contains(needle), "missing {needle:?} in: {prompt}");
        }
    }

    #[test]
    fn format_for_prompt_shows_invocation_fenced() {
        assert!(unrestricted_prompt().contains("Invocation: use ```bash fenced block"));
    }

    #[test]
    fn format_for_prompt_shows_invocation_tool_call() {
        let prompt = unrestricted_prompt();
        assert!(prompt.contains("Invocation: use tool_call"));
        assert!(prompt.contains("\"tool_id\": \"read\""));
    }

    #[test]
    fn format_for_prompt_shows_param_info() {
        let prompt = unrestricted_prompt();
        for needle in ["command:", "required", "string"] {
            assert!(prompt.contains(needle), "missing {needle:?} in: {prompt}");
        }
    }

    #[test]
    fn format_for_prompt_shows_optional_params() {
        let prompt = unrestricted_prompt();
        assert!(prompt.contains("offset:"));
        assert!(prompt.contains("optional"));
        assert!(
            prompt.contains("(integer, optional)"),
            "Option<u32> should render as integer, not string: {prompt}"
        );
    }

    #[test]
    fn format_filtered_excludes_fully_denied() {
        let policy = bash_policy(vec![rule("*", PermissionAction::Deny)]);
        let prompt = prompt_under(&policy);
        assert!(!prompt.contains("## bash"));
        assert!(prompt.contains("## read"));
    }

    #[test]
    fn format_filtered_includes_mixed_rules() {
        let policy = bash_policy(vec![
            rule("echo *", PermissionAction::Allow),
            rule("*", PermissionAction::Deny),
        ]);
        assert!(prompt_under(&policy).contains("## bash"));
    }

    #[test]
    fn format_filtered_no_rules_includes_all() {
        let prompt = unrestricted_prompt();
        assert!(prompt.contains("## bash"));
        assert!(prompt.contains("## read"));
    }
}