steer_core/tools/
builtin_tool.rs1use std::error::Error as StdError;
2use std::sync::Arc;
3
4use async_trait::async_trait;
5use schemars::JsonSchema;
6use serde::de::DeserializeOwned;
7use tokio_util::sync::CancellationToken;
8
9use crate::app::domain::types::{SessionId, ToolCallId};
10use crate::config::model::ModelId;
11use steer_tools::error::ToolExecutionError;
12use steer_tools::result::ToolResult;
13use steer_tools::{ToolSchema, ToolSpec};
14
15use super::capability::Capabilities;
16use super::services::ToolServices;
17
18#[derive(Debug, Clone)]
19pub struct BuiltinToolContext {
20 pub tool_call_id: ToolCallId,
21 pub session_id: SessionId,
22 pub invoking_model: Option<ModelId>,
23 pub cancellation_token: CancellationToken,
24 pub services: Arc<ToolServices>,
25}
26
27impl BuiltinToolContext {
28 pub fn is_cancelled(&self) -> bool {
29 self.cancellation_token.is_cancelled()
30 }
31}
32
33#[derive(Debug, thiserror::Error)]
34pub enum BuiltinToolError<E: StdError + Send + Sync + 'static> {
35 #[error("Invalid parameters: {0}")]
36 InvalidParams(String),
37
38 #[error("{0}")]
39 Execution(E),
40
41 #[error("Missing capability: {0}")]
42 MissingCapability(String),
43
44 #[error("Cancelled")]
45 Cancelled,
46
47 #[error("Timed out")]
48 Timeout,
49}
50
51impl<E: StdError + Send + Sync + 'static> BuiltinToolError<E> {
52 pub fn invalid_params(msg: impl Into<String>) -> Self {
53 Self::InvalidParams(msg.into())
54 }
55
56 pub fn execution(error: E) -> Self {
57 Self::Execution(error)
58 }
59
60 pub fn missing_capability(cap: &str) -> Self {
61 Self::MissingCapability(cap.to_string())
62 }
63
64 pub fn map_execution<F, E2>(self, f: F) -> BuiltinToolError<E2>
65 where
66 F: FnOnce(E) -> E2,
67 E2: StdError + Send + Sync + 'static,
68 {
69 match self {
70 BuiltinToolError::InvalidParams(msg) => BuiltinToolError::InvalidParams(msg),
71 BuiltinToolError::Execution(err) => BuiltinToolError::Execution(f(err)),
72 BuiltinToolError::MissingCapability(cap) => BuiltinToolError::MissingCapability(cap),
73 BuiltinToolError::Cancelled => BuiltinToolError::Cancelled,
74 BuiltinToolError::Timeout => BuiltinToolError::Timeout,
75 }
76 }
77}
78
79fn build_tool_schema<P, S>(description: String) -> ToolSchema
80where
81 P: JsonSchema,
82 S: ToolSpec<Params = P>,
83{
84 let settings = schemars::generate::SchemaSettings::draft07().with(|s| {
85 s.inline_subschemas = true;
86 });
87 let schema_gen = settings.into_generator();
88 let input_schema = schema_gen.into_root_schema_for::<P>();
89
90 ToolSchema {
91 name: S::NAME.to_string(),
92 display_name: S::DISPLAY_NAME.to_string(),
93 description,
94 input_schema: input_schema.into(),
95 }
96}
97
98pub fn schema_with_description<P, S>(description: impl Into<String>) -> ToolSchema
99where
100 P: JsonSchema,
101 S: ToolSpec<Params = P>,
102{
103 build_tool_schema::<P, S>(description.into())
104}
105
106#[async_trait]
107pub trait BuiltinTool: Send + Sync + 'static {
108 type Params: DeserializeOwned + JsonSchema + Send;
109 type Output: Into<ToolResult> + Send;
110 type Spec: ToolSpec<Params = Self::Params, Result = Self::Output>;
111
112 const DESCRIPTION: &'static str;
113 const REQUIRES_APPROVAL: bool;
114 const REQUIRED_CAPABILITIES: Capabilities;
115
116 async fn execute(
117 &self,
118 params: Self::Params,
119 ctx: &BuiltinToolContext,
120 ) -> Result<Self::Output, BuiltinToolError<<Self::Spec as ToolSpec>::Error>>;
121
122 fn schema() -> ToolSchema
123 where
124 Self: Sized,
125 {
126 build_tool_schema::<Self::Params, Self::Spec>(Self::DESCRIPTION.to_string())
127 }
128}
129
130#[async_trait]
131pub trait BuiltinToolErased: Send + Sync {
132 fn name(&self) -> &'static str;
133 fn requires_approval(&self) -> bool;
134 fn required_capabilities(&self) -> Capabilities;
135 fn schema(&self) -> ToolSchema;
136
137 async fn execute_erased(
138 &self,
139 params: serde_json::Value,
140 ctx: &BuiltinToolContext,
141 ) -> Result<ToolResult, BuiltinToolError<ToolExecutionError>>;
142}
143
144#[async_trait]
145impl<T> BuiltinToolErased for T
146where
147 T: BuiltinTool,
148{
149 fn name(&self) -> &'static str {
150 T::Spec::NAME
151 }
152
153 fn requires_approval(&self) -> bool {
154 T::REQUIRES_APPROVAL
155 }
156
157 fn required_capabilities(&self) -> Capabilities {
158 T::REQUIRED_CAPABILITIES
159 }
160
161 fn schema(&self) -> ToolSchema {
162 T::schema()
163 }
164
165 async fn execute_erased(
166 &self,
167 params: serde_json::Value,
168 ctx: &BuiltinToolContext,
169 ) -> Result<ToolResult, BuiltinToolError<ToolExecutionError>> {
170 let typed_params: T::Params = serde_json::from_value(params)
171 .map_err(|e| BuiltinToolError::invalid_params(e.to_string()))?;
172
173 if ctx.is_cancelled() {
174 return Err(BuiltinToolError::Cancelled);
175 }
176
177 let result = self
178 .execute(typed_params, ctx)
179 .await
180 .map_err(|e| e.map_execution(T::Spec::execution_error))?;
181 Ok(result.into())
182 }
183}