Skip to main content

zeph_tools/
registry.rs

1// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
4use std::borrow::Cow;
5use std::fmt::Write;
6
/// How the model is told to invoke a tool in the rendered tool prompt.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum InvocationHint {
    /// The model emits a fenced block (```{tag}\n...\n```) whose info string is this tag.
    FencedBlock(&'static str),
    /// The model emits a structured `ToolCall` JSON payload.
    ToolCall,
}
14
15#[derive(Debug, Clone)]
16pub struct ToolDef {
17    pub id: Cow<'static, str>,
18    pub description: Cow<'static, str>,
19    pub schema: schemars::Schema,
20    pub invocation: InvocationHint,
21    /// Raw output schema from an MCP server, if present.
22    ///
23    /// DO NOT convert to `schemars::Schema` — lossy; see #2931 critique P0-1.
24    pub output_schema: Option<serde_json::Value>,
25}
26
27#[derive(Debug, Default)]
28pub struct ToolRegistry {
29    tools: Vec<ToolDef>,
30}
31
32impl ToolRegistry {
33    #[must_use]
34    pub fn from_definitions(tools: Vec<ToolDef>) -> Self {
35        Self { tools }
36    }
37
38    #[must_use]
39    pub fn tools(&self) -> &[ToolDef] {
40        &self.tools
41    }
42
43    #[must_use]
44    pub fn find(&self, id: &str) -> Option<&ToolDef> {
45        self.tools.iter().find(|t| t.id.as_ref() == id)
46    }
47
48    /// Format tools for prompt, excluding tools fully denied by policy.
49    #[must_use]
50    pub fn format_for_prompt_filtered(
51        &self,
52        policy: &crate::permissions::PermissionPolicy,
53    ) -> String {
54        let mut out = String::from("<tools>\n");
55        for tool in &self.tools {
56            if policy.is_fully_denied(&tool.id) {
57                continue;
58            }
59            format_tool(&mut out, tool);
60        }
61        out.push_str("</tools>");
62        out
63    }
64}
65
66fn format_tool(out: &mut String, tool: &ToolDef) {
67    let _ = writeln!(out, "## {}", tool.id);
68    let _ = writeln!(out, "{}", tool.description);
69    match tool.invocation {
70        InvocationHint::FencedBlock(tag) => {
71            let _ = writeln!(out, "Invocation: use ```{tag} fenced block");
72        }
73        InvocationHint::ToolCall => {
74            let _ = writeln!(
75                out,
76                "Invocation: use tool_call with {{\"tool_id\": \"{}\", \"params\": {{...}}}}",
77                tool.id
78            );
79        }
80    }
81    format_schema_params(out, &tool.schema);
82    out.push('\n');
83}
84
85/// Extract the primary type when schemars renders `Option<T>` as `"type": ["T", "null"]`
86/// or `"anyOf": [{"type": "T"}, {"type": "null"}]`.
87fn extract_non_null_type(obj: &serde_json::Map<String, serde_json::Value>) -> Option<&str> {
88    if let Some(arr) = obj.get("type").and_then(|v| v.as_array()) {
89        return arr.iter().filter_map(|v| v.as_str()).find(|t| *t != "null");
90    }
91    obj.get("anyOf")?
92        .as_array()?
93        .iter()
94        .filter_map(|v| v.as_object())
95        .filter_map(|o| o.get("type")?.as_str())
96        .find(|t| *t != "null")
97}
98
99fn format_schema_params(out: &mut String, schema: &schemars::Schema) {
100    let Some(obj) = schema.as_object() else {
101        return;
102    };
103    let Some(serde_json::Value::Object(props)) = obj.get("properties") else {
104        return;
105    };
106    if props.is_empty() {
107        return;
108    }
109
110    let required: Vec<&str> = obj
111        .get("required")
112        .and_then(|v| v.as_array())
113        .map(|arr| arr.iter().filter_map(|v| v.as_str()).collect())
114        .unwrap_or_default();
115
116    let _ = writeln!(out, "Parameters:");
117    for (name, prop) in props {
118        let prop_obj = prop.as_object();
119        let ty = prop_obj
120            .and_then(|o| {
121                o.get("type")
122                    .and_then(|v| v.as_str())
123                    .or_else(|| extract_non_null_type(o))
124            })
125            .unwrap_or("string");
126        let desc = prop_obj
127            .and_then(|o| o.get("description"))
128            .and_then(|v| v.as_str())
129            .unwrap_or("");
130        let req = if required.contains(&name.as_str()) {
131            "required"
132        } else {
133            "optional"
134        };
135        let _ = writeln!(out, "  - {name}: {desc} ({ty}, {req})");
136    }
137}
138
#[cfg(test)]
mod tests {
    use super::*;
    use crate::file::ReadParams;
    use crate::permissions::{PermissionAction, PermissionPolicy, PermissionRule};
    use crate::shell::BashParams;
    use std::collections::HashMap;

    /// Two fixture tools: `bash` (fenced-block invocation) and `read` (tool-call invocation).
    fn sample_tools() -> Vec<ToolDef> {
        let bash = ToolDef {
            id: "bash".into(),
            description: "Execute a shell command".into(),
            schema: schemars::schema_for!(BashParams),
            invocation: InvocationHint::FencedBlock("bash"),
            output_schema: None,
        };
        let read = ToolDef {
            id: "read".into(),
            description: "Read file contents".into(),
            schema: schemars::schema_for!(ReadParams),
            invocation: InvocationHint::ToolCall,
            output_schema: None,
        };
        vec![bash, read]
    }

    /// A registry populated with [`sample_tools`].
    fn sample_registry() -> ToolRegistry {
        ToolRegistry::from_definitions(sample_tools())
    }

    /// Renders the sample registry's prompt under `policy`.
    fn render(policy: &PermissionPolicy) -> String {
        sample_registry().format_for_prompt_filtered(policy)
    }

    /// Renders the sample registry's prompt under the default policy.
    fn render_unrestricted() -> String {
        render(&PermissionPolicy::default())
    }

    /// Builds a policy whose only rules are `rules`, all attached to the `bash` tool.
    fn bash_policy(rules: Vec<PermissionRule>) -> PermissionPolicy {
        let mut by_tool = HashMap::new();
        by_tool.insert("bash".to_owned(), rules);
        PermissionPolicy::new(by_tool)
    }

    #[test]
    fn from_definitions_stores_tools() {
        assert_eq!(sample_registry().tools().len(), 2);
    }

    #[test]
    fn default_registry_is_empty() {
        assert!(ToolRegistry::default().tools().is_empty());
    }

    #[test]
    fn find_existing_tool() {
        let reg = sample_registry();
        for id in ["bash", "read"] {
            assert!(reg.find(id).is_some());
        }
    }

    #[test]
    fn find_nonexistent_returns_none() {
        assert!(sample_registry().find("nonexistent").is_none());
    }

    #[test]
    fn format_for_prompt_contains_tools() {
        let prompt = render_unrestricted();
        for needle in ["<tools>", "</tools>", "## bash", "## read"] {
            assert!(prompt.contains(needle));
        }
    }

    #[test]
    fn format_for_prompt_shows_invocation_fenced() {
        assert!(render_unrestricted().contains("Invocation: use ```bash fenced block"));
    }

    #[test]
    fn format_for_prompt_shows_invocation_tool_call() {
        let prompt = render_unrestricted();
        for needle in ["Invocation: use tool_call", "\"tool_id\": \"read\""] {
            assert!(prompt.contains(needle));
        }
    }

    #[test]
    fn format_for_prompt_shows_param_info() {
        let prompt = render_unrestricted();
        for needle in ["command:", "required", "string"] {
            assert!(prompt.contains(needle));
        }
    }

    #[test]
    fn format_for_prompt_shows_optional_params() {
        let prompt = render_unrestricted();
        for needle in ["offset:", "optional"] {
            assert!(prompt.contains(needle));
        }
        assert!(
            prompt.contains("(integer, optional)"),
            "Option<u32> should render as integer, not string: {prompt}"
        );
    }

    #[test]
    fn format_filtered_excludes_fully_denied() {
        let policy = bash_policy(vec![PermissionRule {
            pattern: "*".to_owned(),
            action: PermissionAction::Deny,
        }]);
        let prompt = render(&policy);
        assert!(!prompt.contains("## bash"));
        assert!(prompt.contains("## read"));
    }

    #[test]
    fn format_filtered_includes_mixed_rules() {
        let policy = bash_policy(vec![
            PermissionRule {
                pattern: "echo *".to_owned(),
                action: PermissionAction::Allow,
            },
            PermissionRule {
                pattern: "*".to_owned(),
                action: PermissionAction::Deny,
            },
        ]);
        assert!(render(&policy).contains("## bash"));
    }

    #[test]
    fn format_filtered_no_rules_includes_all() {
        let prompt = render(&PermissionPolicy::default());
        for needle in ["## bash", "## read"] {
            assert!(prompt.contains(needle));
        }
    }
}