sayr_engine/
tool.rs

1use std::collections::HashMap;
2use std::sync::Arc;
3
4use async_trait::async_trait;
5use serde::{Deserialize, Serialize};
6use serde_json::Value;
7
8use crate::error::{AgnoError, Result};
9
10#[async_trait]
11pub trait Tool: Send + Sync {
12    fn name(&self) -> &str;
13    fn description(&self) -> &str;
14
15    /// Optionally return a JSON Schema-like object describing expected arguments.
16    fn parameters(&self) -> Option<Value> {
17        None
18    }
19    async fn call(&self, input: Value) -> Result<Value>;
20}
21
22/// Static description of a tool that can be embedded in prompts.
23#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
24pub struct ToolDescription {
25    pub name: String,
26    pub description: String,
27    pub parameters: Option<Value>,
28}
29
30#[derive(Default, Clone)]
31pub struct ToolRegistry {
32    tools: HashMap<String, Arc<dyn Tool>>,
33}
34
35impl ToolRegistry {
36    pub fn new() -> Self {
37        Self {
38            tools: HashMap::new(),
39        }
40    }
41
42    pub fn register<T: Tool + 'static>(&mut self, tool: T) {
43        self.tools.insert(tool.name().to_string(), Arc::new(tool));
44    }
45
46    pub fn names(&self) -> Vec<String> {
47        self.tools.keys().cloned().collect()
48    }
49
50    /// Get a tool by name
51    pub fn get(&self, name: &str) -> Option<Arc<dyn Tool>> {
52        self.tools.get(name).cloned()
53    }
54
55    pub fn describe(&self) -> Vec<ToolDescription> {
56        let mut descriptions: Vec<ToolDescription> = self
57            .tools
58            .values()
59            .map(|tool| ToolDescription {
60                name: tool.name().to_string(),
61                description: tool.description().to_string(),
62                parameters: tool.parameters(),
63            })
64            .collect();
65
66        descriptions.sort_by(|a, b| a.name.cmp(&b.name));
67        descriptions
68    }
69
70    pub async fn call(&self, name: &str, input: Value) -> Result<Value> {
71        let tool = self
72            .tools
73            .get(name)
74            .ok_or_else(|| AgnoError::ToolNotFound(name.to_string()))?;
75        tool.call(input)
76            .await
77            .map_err(|source| AgnoError::ToolInvocation {
78                name: name.to_string(),
79                source: Box::new(source),
80            })
81    }
82}
83
#[cfg(test)]
mod tests {
    use super::*;
    use async_trait::async_trait;

    /// Tool that returns its input unchanged and advertises a parameter schema.
    struct Echo;

    #[async_trait]
    impl Tool for Echo {
        fn name(&self) -> &str {
            "echo"
        }

        fn description(&self) -> &str {
            "Echoes whatever input is provided"
        }

        fn parameters(&self) -> Option<Value> {
            Some(serde_json::json!({
                "type": "object",
                "properties": {"text": {"type": "string"}},
                "required": ["text"],
            }))
        }

        async fn call(&self, input: Value) -> Result<Value> {
            Ok(input)
        }
    }

    /// Tool that always fails; used to exercise error wrapping in `ToolRegistry::call`.
    struct Failing;

    #[async_trait]
    impl Tool for Failing {
        fn name(&self) -> &str {
            "failing"
        }

        fn description(&self) -> &str {
            "Always returns an error"
        }

        async fn call(&self, _input: Value) -> Result<Value> {
            Err(AgnoError::ToolNotFound("inner".to_string()))
        }
    }

    #[tokio::test]
    async fn describes_registered_tools_with_parameters() {
        let mut registry = ToolRegistry::new();
        registry.register(Echo);

        let descriptions = registry.describe();
        assert_eq!(descriptions.len(), 1);
        let desc = &descriptions[0];
        assert_eq!(desc.name, "echo");
        assert!(desc
            .parameters
            .as_ref()
            .unwrap()
            .get("properties")
            .is_some());

        let output = registry
            .call("echo", serde_json::json!({"text":"hi"}))
            .await
            .unwrap();
        assert_eq!(output["text"], "hi");
    }

    #[tokio::test]
    async fn describes_tools_in_deterministic_order() {
        struct Second;

        #[async_trait]
        impl Tool for Second {
            fn name(&self) -> &str {
                "second"
            }

            fn description(&self) -> &str {
                "Second tool"
            }

            async fn call(&self, input: Value) -> Result<Value> {
                Ok(input)
            }
        }

        let mut registry = ToolRegistry::new();
        registry.register(Echo);
        registry.register(Second);

        let descriptions = registry.describe();
        let names: Vec<String> = descriptions.into_iter().map(|d| d.name).collect();
        assert_eq!(names, vec!["echo", "second"]);
    }

    #[test]
    fn get_returns_only_registered_tools() {
        let mut registry = ToolRegistry::new();
        registry.register(Echo);

        assert_eq!(
            registry.get("echo").map(|tool| tool.name().to_string()),
            Some("echo".to_string())
        );
        assert!(registry.get("missing").is_none());
    }

    #[tokio::test]
    async fn call_reports_unknown_tool() {
        let registry = ToolRegistry::new();

        let err = registry
            .call("missing", serde_json::json!({}))
            .await
            .unwrap_err();
        assert!(matches!(err, AgnoError::ToolNotFound(ref name) if name == "missing"));
    }

    #[tokio::test]
    async fn call_wraps_tool_errors_with_tool_name() {
        let mut registry = ToolRegistry::new();
        registry.register(Failing);

        let err = registry
            .call("failing", serde_json::json!({}))
            .await
            .unwrap_err();
        assert!(matches!(err, AgnoError::ToolInvocation { ref name, .. } if name == "failing"));
    }
}
164}