smolagents_rs/tools/
tool_traits.rs

1//! This module contains the traits for tools that can be used in an agent.
2
3use anyhow::Result;
4use schemars::gen::SchemaSettings;
5use schemars::schema::RootSchema;
6use schemars::JsonSchema;
7use serde::de::DeserializeOwned;
8use serde::Serialize;
9use serde_json::json;
10use std::fmt::Debug;
11
12use crate::models::openai::FunctionCall;
13
14/// A trait for parameters that can be used in a tool. This defines the arguments that can be passed to the tool.
15pub trait Parameters: DeserializeOwned + JsonSchema {}
16
17/// A trait for tools that can be used in an agent.
18pub trait Tool: Debug {
19    type Params: Parameters;
20    /// The name of the tool.
21    fn name(&self) -> &'static str;
22    /// The description of the tool.
23    fn description(&self) -> &'static str;
24    /// The function to call when the tool is used.
25    fn forward(&self, arguments: Self::Params) -> Result<String>;
26}
27
28#[derive(serde::Serialize, serde::Deserialize, Debug)]
29pub enum ToolType {
30    #[serde(rename = "function")]
31    Function,
32}
33
34/// A struct that contains information about a tool. This is used to serialize the tool for the API.
35#[derive(Serialize, Debug)]
36pub struct ToolInfo {
37    #[serde(rename = "type")]
38    tool_type: ToolType,
39    pub function: ToolFunctionInfo,
40}
41/// This struct contains information about the function to call when the tool is used.
42#[derive(Serialize, Debug)]
43pub struct ToolFunctionInfo {
44    pub name: &'static str,
45    pub description: &'static str,
46    pub parameters: RootSchema,
47}
48
49impl ToolInfo {
50    pub fn new<P: Parameters, T: AnyTool>(tool: &T) -> Self {
51        let mut settings = SchemaSettings::draft07();
52        settings.inline_subschemas = true;
53        let generator = settings.into_generator();
54
55        let parameters = generator.into_root_schema_for::<P>();
56
57        Self {
58            tool_type: ToolType::Function,
59            function: ToolFunctionInfo {
60                name: tool.name(),
61                description: tool.description(),
62                parameters,
63            },
64        }
65    }
66
67    pub fn get_parameter_names(&self) -> Vec<String> {
68        if let Some(schema) = &self.function.parameters.schema.object {
69            return schema.properties.keys().cloned().collect();
70        }
71        Vec::new()
72    }
73}
74
75pub fn get_json_schema(tool: &ToolInfo) -> serde_json::Value {
76    json!(tool)
77}
78
79pub trait ToolGroup: Debug {
80    fn call(&self, arguments: &FunctionCall) -> Result<String>;
81    fn tool_info(&self) -> Vec<ToolInfo>;
82}
83
84impl ToolGroup for Vec<Box<dyn AnyTool>> {
85    fn call(&self, arguments: &FunctionCall) -> Result<String> {
86        let tool = self.iter().find(|tool| tool.name() == arguments.name);
87        if let Some(tool) = tool {
88            let p = arguments.arguments.clone();
89            return tool.forward_json(p);
90        }
91        Err(anyhow::anyhow!("Tool not found"))
92    }
93    fn tool_info(&self) -> Vec<ToolInfo> {
94        self.iter().map(|tool| tool.tool_info()).collect()
95    }
96}
97
98pub trait AnyTool: Debug {
99    fn name(&self) -> &'static str;
100    fn description(&self) -> &'static str;
101    fn forward_json(&self, json_args: serde_json::Value) -> Result<String>;
102    fn tool_info(&self) -> ToolInfo;
103    fn clone_box(&self) -> Box<dyn AnyTool>;
104}
105
106impl<T: Tool + Clone + 'static> AnyTool for T {
107    fn name(&self) -> &'static str {
108        Tool::name(self)
109    }
110
111    fn description(&self) -> &'static str {
112        Tool::description(self)
113    }
114
115    fn forward_json(&self, json_args: serde_json::Value) -> Result<String> {
116        let params = serde_json::from_value::<T::Params>(json_args)?;
117        Tool::forward(self, params)
118    }
119
120    fn tool_info(&self) -> ToolInfo {
121        ToolInfo::new::<T::Params, T>(self)
122    }
123
124    fn clone_box(&self) -> Box<dyn AnyTool> {
125        Box::new(self.clone())
126    }
127}
128