serdes_ai_toolsets/
abstract_toolset.rs1use async_trait::async_trait;
7use serde::{Deserialize, Serialize};
8use serde_json::Value as JsonValue;
9use serdes_ai_tools::{RunContext, ToolDefinition, ToolError, ToolReturn};
10use std::collections::HashMap;
11
12#[derive(Debug, Clone)]
17pub struct ToolsetTool {
18 pub toolset_id: Option<String>,
20 pub tool_def: ToolDefinition,
22 pub max_retries: u32,
24}
25
26impl ToolsetTool {
27 #[must_use]
29 pub fn new(tool_def: ToolDefinition) -> Self {
30 Self {
31 toolset_id: None,
32 tool_def,
33 max_retries: 3,
34 }
35 }
36
37 #[must_use]
39 pub fn with_toolset_id(mut self, id: impl Into<String>) -> Self {
40 self.toolset_id = Some(id.into());
41 self
42 }
43
44 #[must_use]
46 pub fn with_max_retries(mut self, retries: u32) -> Self {
47 self.max_retries = retries;
48 self
49 }
50
51 #[must_use]
53 pub fn name(&self) -> &str {
54 &self.tool_def.name
55 }
56
57 #[must_use]
59 pub fn description(&self) -> &str {
60 &self.tool_def.description
61 }
62}
63
64#[derive(Debug, Clone, Serialize, Deserialize)]
66pub struct ToolsetInfo {
67 #[serde(skip_serializing_if = "Option::is_none")]
69 pub id: Option<String>,
70 pub type_name: String,
72 pub tool_count: usize,
74 pub tool_names: Vec<String>,
76}
77
78#[async_trait]
89pub trait AbstractToolset<Deps = ()>: Send + Sync {
90 fn id(&self) -> Option<&str>;
92
93 fn label(&self) -> String {
95 let mut label = self.type_name().to_string();
96 if let Some(id) = self.id() {
97 label.push_str(&format!(" '{}'", id));
98 }
99 label
100 }
101
102 fn type_name(&self) -> &'static str {
104 std::any::type_name::<Self>()
105 }
106
107 fn tool_name_conflict_hint(&self) -> String {
109 format!(
110 "Rename the tool or use PrefixedToolset to avoid conflicts in {}.",
111 self.label()
112 )
113 }
114
115 async fn get_tools(
119 &self,
120 ctx: &RunContext<Deps>,
121 ) -> Result<HashMap<String, ToolsetTool>, ToolError>;
122
123 async fn call_tool(
125 &self,
126 name: &str,
127 args: JsonValue,
128 ctx: &RunContext<Deps>,
129 tool: &ToolsetTool,
130 ) -> Result<ToolReturn, ToolError>;
131
132 async fn enter(&self) -> Result<(), ToolError> {
136 Ok(())
137 }
138
139 async fn exit(&self) -> Result<(), ToolError> {
143 Ok(())
144 }
145}
146
147pub type BoxedToolset<Deps> = Box<dyn AbstractToolset<Deps>>;
149
150pub type ToolsetResult<T> = Result<T, ToolError>;
152
#[cfg(test)]
mod tests {
    use super::*;

    /// The builder methods store their overrides and `name()` reads through
    /// to the wrapped definition.
    #[test]
    fn test_toolset_tool() {
        let built = ToolsetTool::new(ToolDefinition::new("test", "Test tool"))
            .with_toolset_id("my_toolset")
            .with_max_retries(5);

        assert_eq!(built.name(), "test");
        assert_eq!(built.toolset_id.as_deref(), Some("my_toolset"));
        assert_eq!(built.max_retries, 5);
    }

    /// A `ToolsetInfo` keeps its id and tool count across a JSON round trip.
    #[test]
    fn test_toolset_info_serde() {
        let original = ToolsetInfo {
            id: Some(String::from("test_id")),
            type_name: String::from("TestToolset"),
            tool_count: 3,
            tool_names: ["a", "b", "c"].iter().map(|s| s.to_string()).collect(),
        };

        let encoded = serde_json::to_string(&original).unwrap();
        let decoded: ToolsetInfo = serde_json::from_str(&encoded).unwrap();

        assert_eq!(original.id, decoded.id);
        assert_eq!(original.tool_count, decoded.tool_count);
    }
}