1use std::sync::Arc;
2
3use async_trait::async_trait;
4use futures::future::BoxFuture;
5use schemars::JsonSchema;
6use serde::de::DeserializeOwned;
7use serde::{Deserialize, Serialize};
8use serde_json::Value;
9use thiserror::Error;
10
11use crate::messages::{ModelMessage, UserContent};
12use crate::model::Model;
13use crate::usage::RunUsage;
14
15#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)]
16pub enum ToolKind {
17 Function,
18 Output,
19 External,
20 Unapproved,
21}
22
/// Description of a tool as advertised to the model.
///
/// Derived serde impls (de)serialize the fields in declaration order, so
/// reordering fields changes the serialized layout.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ToolDefinition {
    /// Name of the tool.
    pub name: String,
    /// Optional human-readable description of what the tool does.
    pub description: Option<String>,
    /// JSON Schema of the arguments object; `FunctionTool::new` generates it
    /// with `schemars` from the tool's argument type.
    pub parameters_json_schema: Value,
    /// Category of the tool; `ToolDefinition::new` sets `ToolKind::Function`.
    pub kind: ToolKind,
    /// NOTE(review): presumably requests that calls to this tool are not run
    /// concurrently with other tool calls; enforced elsewhere — confirm.
    pub sequential: bool,
    /// Arbitrary extra JSON metadata attached to the definition.
    pub metadata: Option<Value>,
    /// Optional per-tool timeout. NOTE(review): unit presumably seconds and
    /// enforcement happens elsewhere — confirm.
    pub timeout: Option<f64>,
}
33
34impl ToolDefinition {
35 pub fn new(
36 name: impl Into<String>,
37 description: Option<String>,
38 parameters_json_schema: Value,
39 ) -> Self {
40 Self {
41 name: name.into(),
42 description,
43 parameters_json_schema,
44 kind: ToolKind::Function,
45 sequential: false,
46 metadata: None,
47 timeout: None,
48 }
49 }
50
51 pub fn with_kind(mut self, kind: ToolKind) -> Self {
52 self.kind = kind;
53 self
54 }
55
56 pub fn with_metadata(mut self, metadata: Value) -> Self {
57 self.metadata = Some(metadata);
58 self
59 }
60
61 pub fn with_sequential(mut self, sequential: bool) -> Self {
62 self.sequential = sequential;
63 self
64 }
65
66 pub fn with_timeout(mut self, timeout: f64) -> Self {
67 self.timeout = Some(timeout);
68 self
69 }
70}
71
72#[derive(Clone)]
73pub struct RunContext<Deps> {
74 pub run_id: String,
75 pub deps: Arc<Deps>,
76 pub model: Arc<dyn Model>,
77 pub usage: RunUsage,
78 pub prompt: Option<Arc<Vec<UserContent>>>,
79 pub messages: Arc<Vec<ModelMessage>>,
80 pub tool_call_id: Option<String>,
81 pub tool_name: Option<String>,
82}
83
84impl<Deps> RunContext<Deps> {
85 pub fn for_tool_call(&self, tool_call_id: String, tool_name: String) -> Self {
86 Self {
87 run_id: self.run_id.clone(),
88 deps: Arc::clone(&self.deps),
89 model: Arc::clone(&self.model),
90 usage: self.usage.clone(),
91 prompt: self.prompt.clone(),
92 messages: Arc::clone(&self.messages),
93 tool_call_id: Some(tool_call_id),
94 tool_name: Some(tool_name),
95 }
96 }
97}
98
/// A single callable tool.
///
/// Implementors must be `Send + Sync`; `async_trait` boxes the `call` future.
#[async_trait]
pub trait Tool<Deps>: Send + Sync {
    /// Returns the tool's definition by value.
    fn definition(&self) -> ToolDefinition;

    /// Invokes the tool with an owned context and raw JSON arguments,
    /// returning the tool's result as JSON.
    async fn call(&self, ctx: RunContext<Deps>, args: Value) -> Result<Value, ToolError>;
}
105
/// A collection of tools that can be listed and invoked by name.
///
/// Implementors must be `Send + Sync`. `enter` and `exit` default to no-ops
/// returning `Ok(())`, and `name` defaults to `"toolset"`.
#[async_trait]
pub trait Toolset<Deps>: Send + Sync {
    /// Returns the definitions of the tools this toolset offers; the result
    /// may depend on `ctx`.
    async fn list_tools(&self, ctx: &RunContext<Deps>) -> Result<Vec<ToolDefinition>, ToolError>;

    /// Invokes the tool called `name` with raw JSON `args` and returns its
    /// JSON result.
    async fn call_tool(
        &self,
        ctx: &RunContext<Deps>,
        name: &str,
        args: Value,
    ) -> Result<Value, ToolError>;

    /// Setup hook; the default does nothing.
    /// NOTE(review): presumably called by the runtime before the toolset is
    /// used — confirm.
    async fn enter(&self) -> Result<(), ToolError> {
        Ok(())
    }

    /// Teardown hook; the default does nothing.
    /// NOTE(review): presumably called by the runtime after the toolset is
    /// used — confirm.
    async fn exit(&self) -> Result<(), ToolError> {
        Ok(())
    }

    /// Human-readable name of the toolset; defaults to `"toolset"`.
    fn name(&self) -> &str {
        "toolset"
    }
}
129
130pub struct FunctionTool<Deps> {
131 definition: ToolDefinition,
132 handler: Arc<ToolHandler<Deps>>,
133}
134
/// Type-erased async tool handler.
///
/// Takes an owned context plus raw JSON arguments and returns a boxed
/// `'static` future (`futures`' `BoxFuture`) resolving to the JSON result.
/// The `Send + Sync` bounds let the handler be shared through an `Arc`.
type ToolHandler<Deps> =
    dyn Fn(RunContext<Deps>, Value) -> BoxFuture<'static, Result<Value, ToolError>> + Send + Sync;
137
138impl<Deps> FunctionTool<Deps>
139where
140 Deps: Send + Sync + 'static,
141{
142 pub fn new<Args, Output, Func, Fut>(
143 name: impl Into<String>,
144 description: impl Into<String>,
145 func: Func,
146 ) -> Result<Self, ToolError>
147 where
148 Args: DeserializeOwned + JsonSchema + Send + 'static,
149 Output: Serialize + Send + 'static,
150 Func: Fn(RunContext<Deps>, Args) -> Fut + Send + Sync + 'static,
151 Fut: std::future::Future<Output = Result<Output, ToolError>> + Send + 'static,
152 {
153 let name = name.into();
154 let description = Some(description.into());
155 let schema = schemars::schema_for!(Args);
156 let parameters_json_schema = serde_json::to_value(&schema).map_err(ToolError::Serde)?;
157
158 let definition = ToolDefinition::new(name, description, parameters_json_schema);
159 let func = Arc::new(func);
160 let handler = Arc::new(move |ctx: RunContext<Deps>, args: Value| {
161 let parsed = serde_json::from_value(args).map_err(ToolError::InvalidArgs);
162 let func = Arc::clone(&func);
163 let fut = async move {
164 let parsed = parsed?;
165 let output = func(ctx, parsed).await?;
166 let value = serde_json::to_value(output).map_err(ToolError::Serde)?;
167 Ok(value)
168 };
169 Box::pin(fut) as BoxFuture<'static, Result<Value, ToolError>>
170 });
171
172 Ok(Self {
173 definition,
174 handler,
175 })
176 }
177
178 pub fn with_kind(mut self, kind: ToolKind) -> Self {
179 self.definition.kind = kind;
180 self
181 }
182
183 pub fn with_sequential(mut self, sequential: bool) -> Self {
184 self.definition.sequential = sequential;
185 self
186 }
187
188 pub fn with_timeout(mut self, timeout: f64) -> Self {
189 self.definition.timeout = Some(timeout);
190 self
191 }
192}
193
194#[async_trait]
195impl<Deps> Tool<Deps> for FunctionTool<Deps>
196where
197 Deps: Send + Sync + 'static,
198{
199 fn definition(&self) -> ToolDefinition {
200 self.definition.clone()
201 }
202
203 async fn call(&self, ctx: RunContext<Deps>, args: Value) -> Result<Value, ToolError> {
204 (self.handler)(ctx, args).await
205 }
206}
207
208#[derive(Debug, Error)]
209pub enum ToolError {
210 #[error("invalid tool arguments: {0}")]
211 InvalidArgs(serde_json::Error),
212 #[error("tool execution failed: {0}")]
213 Execution(String),
214 #[error("serialization error: {0}")]
215 Serde(serde_json::Error),
216 #[error("toolset error: {0}")]
217 Toolset(String),
218}