Skip to main content

zeph_tools/
registry.rs

1// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
4use std::borrow::Cow;
5use std::fmt::Write;
6
/// How the model is told to invoke a tool in the rendered tool prompt.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub enum InvocationHint {
    /// Tool invoked by emitting, in the LLM response, a fenced block whose
    /// opening fence is tagged with this string (e.g. `bash`).
    FencedBlock(&'static str),
    /// Tool invoked via a structured `ToolCall` JSON payload.
    ToolCall,
}
14
15#[derive(Debug, Clone)]
16pub struct ToolDef {
17    pub id: Cow<'static, str>,
18    pub description: Cow<'static, str>,
19    pub schema: schemars::Schema,
20    pub invocation: InvocationHint,
21}
22
23#[derive(Debug, Default)]
24pub struct ToolRegistry {
25    tools: Vec<ToolDef>,
26}
27
28impl ToolRegistry {
29    #[must_use]
30    pub fn from_definitions(tools: Vec<ToolDef>) -> Self {
31        Self { tools }
32    }
33
34    #[must_use]
35    pub fn tools(&self) -> &[ToolDef] {
36        &self.tools
37    }
38
39    #[must_use]
40    pub fn find(&self, id: &str) -> Option<&ToolDef> {
41        self.tools.iter().find(|t| t.id.as_ref() == id)
42    }
43
44    /// Format tools for prompt, excluding tools fully denied by policy.
45    #[must_use]
46    pub fn format_for_prompt_filtered(
47        &self,
48        policy: &crate::permissions::PermissionPolicy,
49    ) -> String {
50        let mut out = String::from("<tools>\n");
51        for tool in &self.tools {
52            if policy.is_fully_denied(&tool.id) {
53                continue;
54            }
55            format_tool(&mut out, tool);
56        }
57        out.push_str("</tools>");
58        out
59    }
60}
61
62fn format_tool(out: &mut String, tool: &ToolDef) {
63    let _ = writeln!(out, "## {}", tool.id);
64    let _ = writeln!(out, "{}", tool.description);
65    match tool.invocation {
66        InvocationHint::FencedBlock(tag) => {
67            let _ = writeln!(out, "Invocation: use ```{tag} fenced block");
68        }
69        InvocationHint::ToolCall => {
70            let _ = writeln!(
71                out,
72                "Invocation: use tool_call with {{\"tool_id\": \"{}\", \"params\": {{...}}}}",
73                tool.id
74            );
75        }
76    }
77    format_schema_params(out, &tool.schema);
78    out.push('\n');
79}
80
81/// Extract the primary type when schemars renders `Option<T>` as `"type": ["T", "null"]`
82/// or `"anyOf": [{"type": "T"}, {"type": "null"}]`.
83fn extract_non_null_type(obj: &serde_json::Map<String, serde_json::Value>) -> Option<&str> {
84    if let Some(arr) = obj.get("type").and_then(|v| v.as_array()) {
85        return arr.iter().filter_map(|v| v.as_str()).find(|t| *t != "null");
86    }
87    obj.get("anyOf")?
88        .as_array()?
89        .iter()
90        .filter_map(|v| v.as_object())
91        .filter_map(|o| o.get("type")?.as_str())
92        .find(|t| *t != "null")
93}
94
95fn format_schema_params(out: &mut String, schema: &schemars::Schema) {
96    let Some(obj) = schema.as_object() else {
97        return;
98    };
99    let Some(serde_json::Value::Object(props)) = obj.get("properties") else {
100        return;
101    };
102    if props.is_empty() {
103        return;
104    }
105
106    let required: Vec<&str> = obj
107        .get("required")
108        .and_then(|v| v.as_array())
109        .map(|arr| arr.iter().filter_map(|v| v.as_str()).collect())
110        .unwrap_or_default();
111
112    let _ = writeln!(out, "Parameters:");
113    for (name, prop) in props {
114        let prop_obj = prop.as_object();
115        let ty = prop_obj
116            .and_then(|o| {
117                o.get("type")
118                    .and_then(|v| v.as_str())
119                    .or_else(|| extract_non_null_type(o))
120            })
121            .unwrap_or("string");
122        let desc = prop_obj
123            .and_then(|o| o.get("description"))
124            .and_then(|v| v.as_str())
125            .unwrap_or("");
126        let req = if required.contains(&name.as_str()) {
127            "required"
128        } else {
129            "optional"
130        };
131        let _ = writeln!(out, "  - {name}: {desc} ({ty}, {req})");
132    }
133}
134
#[cfg(test)]
mod tests {
    use super::*;
    use crate::file::ReadParams;
    use crate::permissions::{PermissionAction, PermissionPolicy, PermissionRule};
    use crate::shell::BashParams;
    use std::collections::HashMap;

    /// Two representative tools: `bash` (fenced-block invocation) and
    /// `read` (structured tool-call invocation).
    fn sample_tools() -> Vec<ToolDef> {
        let bash = ToolDef {
            id: "bash".into(),
            description: "Execute a shell command".into(),
            schema: schemars::schema_for!(BashParams),
            invocation: InvocationHint::FencedBlock("bash"),
        };
        let read = ToolDef {
            id: "read".into(),
            description: "Read file contents".into(),
            schema: schemars::schema_for!(ReadParams),
            invocation: InvocationHint::ToolCall,
        };
        vec![bash, read]
    }

    fn sample_registry() -> ToolRegistry {
        ToolRegistry::from_definitions(sample_tools())
    }

    /// Renders the sample registry under `policy`.
    fn render(policy: &PermissionPolicy) -> String {
        sample_registry().format_for_prompt_filtered(policy)
    }

    /// Renders the sample registry with the default policy.
    fn render_default() -> String {
        render(&PermissionPolicy::default())
    }

    fn rule(pattern: &str, action: PermissionAction) -> PermissionRule {
        PermissionRule {
            pattern: pattern.to_owned(),
            action,
        }
    }

    /// Builds a policy whose only entry applies `rules` to the `bash` tool.
    fn bash_policy(rules: Vec<PermissionRule>) -> PermissionPolicy {
        let mut by_tool = HashMap::new();
        by_tool.insert("bash".to_owned(), rules);
        PermissionPolicy::new(by_tool)
    }

    #[test]
    fn from_definitions_stores_tools() {
        assert_eq!(sample_registry().tools().len(), 2);
    }

    #[test]
    fn default_registry_is_empty() {
        assert!(ToolRegistry::default().tools().is_empty());
    }

    #[test]
    fn find_existing_tool() {
        let reg = sample_registry();
        for id in ["bash", "read"] {
            assert!(reg.find(id).is_some(), "expected to find {id}");
        }
    }

    #[test]
    fn find_nonexistent_returns_none() {
        assert!(sample_registry().find("nonexistent").is_none());
    }

    #[test]
    fn format_for_prompt_contains_tools() {
        let prompt = render_default();
        for needle in ["<tools>", "</tools>", "## bash", "## read"] {
            assert!(prompt.contains(needle), "missing {needle:?}");
        }
    }

    #[test]
    fn format_for_prompt_shows_invocation_fenced() {
        assert!(render_default().contains("Invocation: use ```bash fenced block"));
    }

    #[test]
    fn format_for_prompt_shows_invocation_tool_call() {
        let prompt = render_default();
        assert!(prompt.contains("Invocation: use tool_call"));
        assert!(prompt.contains("\"tool_id\": \"read\""));
    }

    #[test]
    fn format_for_prompt_shows_param_info() {
        let prompt = render_default();
        for needle in ["command:", "required", "string"] {
            assert!(prompt.contains(needle), "missing {needle:?}");
        }
    }

    #[test]
    fn format_for_prompt_shows_optional_params() {
        let prompt = render_default();
        assert!(prompt.contains("offset:"));
        assert!(prompt.contains("optional"));
        assert!(
            prompt.contains("(integer, optional)"),
            "Option<u32> should render as integer, not string: {prompt}"
        );
    }

    #[test]
    fn format_filtered_excludes_fully_denied() {
        let prompt = render(&bash_policy(vec![rule("*", PermissionAction::Deny)]));
        assert!(!prompt.contains("## bash"));
        assert!(prompt.contains("## read"));
    }

    #[test]
    fn format_filtered_includes_mixed_rules() {
        let policy = bash_policy(vec![
            rule("echo *", PermissionAction::Allow),
            rule("*", PermissionAction::Deny),
        ]);
        assert!(render(&policy).contains("## bash"));
    }

    #[test]
    fn format_filtered_no_rules_includes_all() {
        let prompt = render(&PermissionPolicy::default());
        assert!(prompt.contains("## bash"));
        assert!(prompt.contains("## read"));
    }
}