smolagents_rs/tools/
tool_traits.rs1use anyhow::Result;
4use schemars::gen::SchemaSettings;
5use schemars::schema::RootSchema;
6use schemars::JsonSchema;
7use serde::de::DeserializeOwned;
8use serde::Serialize;
9use serde_json::json;
10use std::fmt::Debug;
11
12use crate::models::openai::FunctionCall;
13
14pub trait Parameters: DeserializeOwned + JsonSchema {}
16
17pub trait Tool: Debug {
19 type Params: Parameters;
20 fn name(&self) -> &'static str;
22 fn description(&self) -> &'static str;
24 fn forward(&self, arguments: Self::Params) -> Result<String>;
26}
27
28#[derive(serde::Serialize, serde::Deserialize, Debug)]
29pub enum ToolType {
30 #[serde(rename = "function")]
31 Function,
32}
33
34#[derive(Serialize, Debug)]
36pub struct ToolInfo {
37 #[serde(rename = "type")]
38 tool_type: ToolType,
39 pub function: ToolFunctionInfo,
40}
41#[derive(Serialize, Debug)]
43pub struct ToolFunctionInfo {
44 pub name: &'static str,
45 pub description: &'static str,
46 pub parameters: RootSchema,
47}
48
49impl ToolInfo {
50 pub fn new<P: Parameters, T: AnyTool>(tool: &T) -> Self {
51 let mut settings = SchemaSettings::draft07();
52 settings.inline_subschemas = true;
53 let generator = settings.into_generator();
54
55 let parameters = generator.into_root_schema_for::<P>();
56
57 Self {
58 tool_type: ToolType::Function,
59 function: ToolFunctionInfo {
60 name: tool.name(),
61 description: tool.description(),
62 parameters,
63 },
64 }
65 }
66
67 pub fn get_parameter_names(&self) -> Vec<String> {
68 if let Some(schema) = &self.function.parameters.schema.object {
69 return schema.properties.keys().cloned().collect();
70 }
71 Vec::new()
72 }
73}
74
75pub fn get_json_schema(tool: &ToolInfo) -> serde_json::Value {
76 json!(tool)
77}
78
79pub trait ToolGroup: Debug {
80 fn call(&self, arguments: &FunctionCall) -> Result<String>;
81 fn tool_info(&self) -> Vec<ToolInfo>;
82}
83
84impl ToolGroup for Vec<Box<dyn AnyTool>> {
85 fn call(&self, arguments: &FunctionCall) -> Result<String> {
86 let tool = self.iter().find(|tool| tool.name() == arguments.name);
87 if let Some(tool) = tool {
88 let p = arguments.arguments.clone();
89 return tool.forward_json(p);
90 }
91 Err(anyhow::anyhow!("Tool not found"))
92 }
93 fn tool_info(&self) -> Vec<ToolInfo> {
94 self.iter().map(|tool| tool.tool_info()).collect()
95 }
96}
97
98pub trait AnyTool: Debug {
99 fn name(&self) -> &'static str;
100 fn description(&self) -> &'static str;
101 fn forward_json(&self, json_args: serde_json::Value) -> Result<String>;
102 fn tool_info(&self) -> ToolInfo;
103 fn clone_box(&self) -> Box<dyn AnyTool>;
104}
105
106impl<T: Tool + Clone + 'static> AnyTool for T {
107 fn name(&self) -> &'static str {
108 Tool::name(self)
109 }
110
111 fn description(&self) -> &'static str {
112 Tool::description(self)
113 }
114
115 fn forward_json(&self, json_args: serde_json::Value) -> Result<String> {
116 let params = serde_json::from_value::<T::Params>(json_args)?;
117 Tool::forward(self, params)
118 }
119
120 fn tool_info(&self) -> ToolInfo {
121 ToolInfo::new::<T::Params, T>(self)
122 }
123
124 fn clone_box(&self) -> Box<dyn AnyTool> {
125 Box::new(self.clone())
126 }
127}
128