Skip to main content

rustic_ai/
tools.rs

1use std::sync::Arc;
2
3use async_trait::async_trait;
4use futures::future::BoxFuture;
5use schemars::JsonSchema;
6use serde::de::DeserializeOwned;
7use serde::{Deserialize, Serialize};
8use serde_json::Value;
9use thiserror::Error;
10
11use crate::messages::{ModelMessage, UserContent};
12use crate::model::Model;
13use crate::usage::RunUsage;
14
15#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)]
16pub enum ToolKind {
17    Function,
18    Output,
19    External,
20    Unapproved,
21}
22
23#[derive(Clone, Debug, Serialize, Deserialize)]
24pub struct ToolDefinition {
25    pub name: String,
26    pub description: Option<String>,
27    pub parameters_json_schema: Value,
28    pub kind: ToolKind,
29    pub sequential: bool,
30    pub metadata: Option<Value>,
31    pub timeout: Option<f64>,
32}
33
34impl ToolDefinition {
35    pub fn new(
36        name: impl Into<String>,
37        description: Option<String>,
38        parameters_json_schema: Value,
39    ) -> Self {
40        Self {
41            name: name.into(),
42            description,
43            parameters_json_schema,
44            kind: ToolKind::Function,
45            sequential: false,
46            metadata: None,
47            timeout: None,
48        }
49    }
50
51    pub fn with_kind(mut self, kind: ToolKind) -> Self {
52        self.kind = kind;
53        self
54    }
55
56    pub fn with_metadata(mut self, metadata: Value) -> Self {
57        self.metadata = Some(metadata);
58        self
59    }
60
61    pub fn with_sequential(mut self, sequential: bool) -> Self {
62        self.sequential = sequential;
63        self
64    }
65
66    pub fn with_timeout(mut self, timeout: f64) -> Self {
67        self.timeout = Some(timeout);
68        self
69    }
70}
71
72#[derive(Clone)]
73pub struct RunContext<Deps> {
74    pub run_id: String,
75    pub deps: Arc<Deps>,
76    pub model: Arc<dyn Model>,
77    pub usage: RunUsage,
78    pub prompt: Option<Arc<Vec<UserContent>>>,
79    pub messages: Arc<Vec<ModelMessage>>,
80    pub tool_call_id: Option<String>,
81    pub tool_name: Option<String>,
82}
83
84impl<Deps> RunContext<Deps> {
85    pub fn for_tool_call(&self, tool_call_id: String, tool_name: String) -> Self {
86        Self {
87            run_id: self.run_id.clone(),
88            deps: Arc::clone(&self.deps),
89            model: Arc::clone(&self.model),
90            usage: self.usage.clone(),
91            prompt: self.prompt.clone(),
92            messages: Arc::clone(&self.messages),
93            tool_call_id: Some(tool_call_id),
94            tool_name: Some(tool_name),
95        }
96    }
97}
98
99#[async_trait]
100pub trait Tool<Deps>: Send + Sync {
101    fn definition(&self) -> ToolDefinition;
102
103    async fn call(&self, ctx: RunContext<Deps>, args: Value) -> Result<Value, ToolError>;
104}
105
106#[async_trait]
107pub trait Toolset<Deps>: Send + Sync {
108    async fn list_tools(&self, ctx: &RunContext<Deps>) -> Result<Vec<ToolDefinition>, ToolError>;
109
110    async fn call_tool(
111        &self,
112        ctx: &RunContext<Deps>,
113        name: &str,
114        args: Value,
115    ) -> Result<Value, ToolError>;
116
117    async fn enter(&self) -> Result<(), ToolError> {
118        Ok(())
119    }
120
121    async fn exit(&self) -> Result<(), ToolError> {
122        Ok(())
123    }
124
125    fn name(&self) -> &str {
126        "toolset"
127    }
128}
129
130pub struct FunctionTool<Deps> {
131    definition: ToolDefinition,
132    handler: Arc<ToolHandler<Deps>>,
133}
134
135type ToolHandler<Deps> =
136    dyn Fn(RunContext<Deps>, Value) -> BoxFuture<'static, Result<Value, ToolError>> + Send + Sync;
137
138impl<Deps> FunctionTool<Deps>
139where
140    Deps: Send + Sync + 'static,
141{
142    pub fn new<Args, Output, Func, Fut>(
143        name: impl Into<String>,
144        description: impl Into<String>,
145        func: Func,
146    ) -> Result<Self, ToolError>
147    where
148        Args: DeserializeOwned + JsonSchema + Send + 'static,
149        Output: Serialize + Send + 'static,
150        Func: Fn(RunContext<Deps>, Args) -> Fut + Send + Sync + 'static,
151        Fut: std::future::Future<Output = Result<Output, ToolError>> + Send + 'static,
152    {
153        let name = name.into();
154        let description = Some(description.into());
155        let schema = schemars::schema_for!(Args);
156        let parameters_json_schema = serde_json::to_value(&schema).map_err(ToolError::Serde)?;
157
158        let definition = ToolDefinition::new(name, description, parameters_json_schema);
159        let func = Arc::new(func);
160        let handler = Arc::new(move |ctx: RunContext<Deps>, args: Value| {
161            let parsed = serde_json::from_value(args).map_err(ToolError::InvalidArgs);
162            let func = Arc::clone(&func);
163            let fut = async move {
164                let parsed = parsed?;
165                let output = func(ctx, parsed).await?;
166                let value = serde_json::to_value(output).map_err(ToolError::Serde)?;
167                Ok(value)
168            };
169            Box::pin(fut) as BoxFuture<'static, Result<Value, ToolError>>
170        });
171
172        Ok(Self {
173            definition,
174            handler,
175        })
176    }
177
178    pub fn with_kind(mut self, kind: ToolKind) -> Self {
179        self.definition.kind = kind;
180        self
181    }
182
183    pub fn with_sequential(mut self, sequential: bool) -> Self {
184        self.definition.sequential = sequential;
185        self
186    }
187
188    pub fn with_timeout(mut self, timeout: f64) -> Self {
189        self.definition.timeout = Some(timeout);
190        self
191    }
192}
193
194#[async_trait]
195impl<Deps> Tool<Deps> for FunctionTool<Deps>
196where
197    Deps: Send + Sync + 'static,
198{
199    fn definition(&self) -> ToolDefinition {
200        self.definition.clone()
201    }
202
203    async fn call(&self, ctx: RunContext<Deps>, args: Value) -> Result<Value, ToolError> {
204        (self.handler)(ctx, args).await
205    }
206}
207
208#[derive(Debug, Error)]
209pub enum ToolError {
210    #[error("invalid tool arguments: {0}")]
211    InvalidArgs(serde_json::Error),
212    #[error("tool execution failed: {0}")]
213    Execution(String),
214    #[error("serialization error: {0}")]
215    Serde(serde_json::Error),
216    #[error("toolset error: {0}")]
217    Toolset(String),
218}