Skip to main content

steer_core/tools/
builtin_tool.rs

1use std::error::Error as StdError;
2use std::sync::Arc;
3
4use async_trait::async_trait;
5use schemars::JsonSchema;
6use serde::de::DeserializeOwned;
7use tokio_util::sync::CancellationToken;
8
9use crate::app::domain::types::{SessionId, ToolCallId};
10use crate::config::model::ModelId;
11use steer_tools::error::ToolExecutionError;
12use steer_tools::result::ToolResult;
13use steer_tools::{ToolSchema, ToolSpec};
14
15use super::capability::Capabilities;
16use super::services::ToolServices;
17
/// Per-call context handed to a [`BuiltinTool`] when it executes.
#[derive(Debug, Clone)]
pub struct BuiltinToolContext {
    /// Identifier of the tool call being executed.
    pub tool_call_id: ToolCallId,
    /// Session the tool call belongs to.
    pub session_id: SessionId,
    /// Presumably the model that issued the tool call; `None` when not available.
    pub invoking_model: Option<ModelId>,
    /// Cancellation signal for this call, queried by [`BuiltinToolContext::is_cancelled`].
    /// `execute_erased` checks it before running the tool; long-running tools
    /// presumably need to poll it themselves. NOTE(review): confirm with tool authors.
    pub cancellation_token: CancellationToken,
    /// Shared tool services; held behind an [`Arc`] so clones of the context share them.
    pub services: Arc<ToolServices>,
}
26
27impl BuiltinToolContext {
28    pub fn is_cancelled(&self) -> bool {
29        self.cancellation_token.is_cancelled()
30    }
31}
32
33#[derive(Debug, thiserror::Error)]
34pub enum BuiltinToolError<E: StdError + Send + Sync + 'static> {
35    #[error("Invalid parameters: {0}")]
36    InvalidParams(String),
37
38    #[error("{0}")]
39    Execution(E),
40
41    #[error("Missing capability: {0}")]
42    MissingCapability(String),
43
44    #[error("Cancelled")]
45    Cancelled,
46
47    #[error("Timed out")]
48    Timeout,
49}
50
51impl<E: StdError + Send + Sync + 'static> BuiltinToolError<E> {
52    pub fn invalid_params(msg: impl Into<String>) -> Self {
53        Self::InvalidParams(msg.into())
54    }
55
56    pub fn execution(error: E) -> Self {
57        Self::Execution(error)
58    }
59
60    pub fn missing_capability(cap: &str) -> Self {
61        Self::MissingCapability(cap.to_string())
62    }
63
64    pub fn map_execution<F, E2>(self, f: F) -> BuiltinToolError<E2>
65    where
66        F: FnOnce(E) -> E2,
67        E2: StdError + Send + Sync + 'static,
68    {
69        match self {
70            BuiltinToolError::InvalidParams(msg) => BuiltinToolError::InvalidParams(msg),
71            BuiltinToolError::Execution(err) => BuiltinToolError::Execution(f(err)),
72            BuiltinToolError::MissingCapability(cap) => BuiltinToolError::MissingCapability(cap),
73            BuiltinToolError::Cancelled => BuiltinToolError::Cancelled,
74            BuiltinToolError::Timeout => BuiltinToolError::Timeout,
75        }
76    }
77}
78
79fn build_tool_schema<P, S>(description: String) -> ToolSchema
80where
81    P: JsonSchema,
82    S: ToolSpec<Params = P>,
83{
84    let settings = schemars::generate::SchemaSettings::draft07().with(|s| {
85        s.inline_subschemas = true;
86    });
87    let schema_gen = settings.into_generator();
88    let input_schema = schema_gen.into_root_schema_for::<P>();
89
90    ToolSchema {
91        name: S::NAME.to_string(),
92        display_name: S::DISPLAY_NAME.to_string(),
93        description,
94        input_schema: input_schema.into(),
95    }
96}
97
98pub fn schema_with_description<P, S>(description: impl Into<String>) -> ToolSchema
99where
100    P: JsonSchema,
101    S: ToolSpec<Params = P>,
102{
103    build_tool_schema::<P, S>(description.into())
104}
105
106#[async_trait]
107pub trait BuiltinTool: Send + Sync + 'static {
108    type Params: DeserializeOwned + JsonSchema + Send;
109    type Output: Into<ToolResult> + Send;
110    type Spec: ToolSpec<Params = Self::Params, Result = Self::Output>;
111
112    const DESCRIPTION: &'static str;
113    const REQUIRES_APPROVAL: bool;
114    const REQUIRED_CAPABILITIES: Capabilities;
115
116    async fn execute(
117        &self,
118        params: Self::Params,
119        ctx: &BuiltinToolContext,
120    ) -> Result<Self::Output, BuiltinToolError<<Self::Spec as ToolSpec>::Error>>;
121
122    fn schema() -> ToolSchema
123    where
124        Self: Sized,
125    {
126        build_tool_schema::<Self::Params, Self::Spec>(Self::DESCRIPTION.to_string())
127    }
128}
129
/// Type-erased form of [`BuiltinTool`].
///
/// It has no associated types or constants and takes parameters as a raw
/// [`serde_json::Value`], so tools with different `Params`/`Output` types can be
/// handled uniformly (presumably as `dyn BuiltinToolErased`). Every `BuiltinTool`
/// gets this trait through the blanket impl in this module.
#[async_trait]
pub trait BuiltinToolErased: Send + Sync {
    /// The tool's name (for the blanket impl, the spec's `NAME`).
    fn name(&self) -> &'static str;
    /// Whether invoking the tool requires approval.
    fn requires_approval(&self) -> bool;
    /// Capabilities the tool declares it needs.
    fn required_capabilities(&self) -> Capabilities;
    /// The tool's schema, including the JSON Schema for its input.
    fn schema(&self) -> ToolSchema;

    /// Executes the tool with untyped JSON `params`.
    ///
    /// The execution error type is fixed to [`ToolExecutionError`], so callers see a
    /// single error type regardless of the concrete tool.
    async fn execute_erased(
        &self,
        params: serde_json::Value,
        ctx: &BuiltinToolContext,
    ) -> Result<ToolResult, BuiltinToolError<ToolExecutionError>>;
}
143
144#[async_trait]
145impl<T> BuiltinToolErased for T
146where
147    T: BuiltinTool,
148{
149    fn name(&self) -> &'static str {
150        T::Spec::NAME
151    }
152
153    fn requires_approval(&self) -> bool {
154        T::REQUIRES_APPROVAL
155    }
156
157    fn required_capabilities(&self) -> Capabilities {
158        T::REQUIRED_CAPABILITIES
159    }
160
161    fn schema(&self) -> ToolSchema {
162        T::schema()
163    }
164
165    async fn execute_erased(
166        &self,
167        params: serde_json::Value,
168        ctx: &BuiltinToolContext,
169    ) -> Result<ToolResult, BuiltinToolError<ToolExecutionError>> {
170        let typed_params: T::Params = serde_json::from_value(params)
171            .map_err(|e| BuiltinToolError::invalid_params(e.to_string()))?;
172
173        if ctx.is_cancelled() {
174            return Err(BuiltinToolError::Cancelled);
175        }
176
177        let result = self
178            .execute(typed_params, ctx)
179            .await
180            .map_err(|e| e.map_execution(T::Spec::execution_error))?;
181        Ok(result.into())
182    }
183}