// steer_core/tools/static_tool.rs

1use std::error::Error as StdError;
2use std::sync::Arc;
3
4use async_trait::async_trait;
5use schemars::JsonSchema;
6use serde::de::DeserializeOwned;
7use tokio_util::sync::CancellationToken;
8
9use crate::app::domain::types::{SessionId, ToolCallId};
10use steer_tools::error::ToolExecutionError;
11use steer_tools::result::ToolResult;
12use steer_tools::{ToolSchema, ToolSpec};
13
14use super::capability::Capabilities;
15use super::services::ToolServices;
16
17#[derive(Debug, Clone)]
18pub struct StaticToolContext {
19    pub tool_call_id: ToolCallId,
20    pub session_id: SessionId,
21    pub cancellation_token: CancellationToken,
22    pub services: Arc<ToolServices>,
23}
24
25impl StaticToolContext {
26    pub fn is_cancelled(&self) -> bool {
27        self.cancellation_token.is_cancelled()
28    }
29}
30
31#[derive(Debug, thiserror::Error)]
32pub enum StaticToolError<E: StdError + Send + Sync + 'static> {
33    #[error("Invalid parameters: {0}")]
34    InvalidParams(String),
35
36    #[error("{0}")]
37    Execution(E),
38
39    #[error("Missing capability: {0}")]
40    MissingCapability(String),
41
42    #[error("Cancelled")]
43    Cancelled,
44
45    #[error("Timed out")]
46    Timeout,
47}
48
49impl<E: StdError + Send + Sync + 'static> StaticToolError<E> {
50    pub fn invalid_params(msg: impl Into<String>) -> Self {
51        Self::InvalidParams(msg.into())
52    }
53
54    pub fn execution(error: E) -> Self {
55        Self::Execution(error)
56    }
57
58    pub fn missing_capability(cap: &str) -> Self {
59        Self::MissingCapability(cap.to_string())
60    }
61
62    pub fn map_execution<F, E2>(self, f: F) -> StaticToolError<E2>
63    where
64        F: FnOnce(E) -> E2,
65        E2: StdError + Send + Sync + 'static,
66    {
67        match self {
68            StaticToolError::InvalidParams(msg) => StaticToolError::InvalidParams(msg),
69            StaticToolError::Execution(err) => StaticToolError::Execution(f(err)),
70            StaticToolError::MissingCapability(cap) => StaticToolError::MissingCapability(cap),
71            StaticToolError::Cancelled => StaticToolError::Cancelled,
72            StaticToolError::Timeout => StaticToolError::Timeout,
73        }
74    }
75}
76
77#[async_trait]
78pub trait StaticTool: Send + Sync + 'static {
79    type Params: DeserializeOwned + JsonSchema + Send;
80    type Output: Into<ToolResult> + Send;
81    type Spec: ToolSpec<Params = Self::Params, Result = Self::Output>;
82
83    const DESCRIPTION: &'static str;
84    const REQUIRES_APPROVAL: bool;
85    const REQUIRED_CAPABILITIES: Capabilities;
86
87    async fn execute(
88        &self,
89        params: Self::Params,
90        ctx: &StaticToolContext,
91    ) -> Result<Self::Output, StaticToolError<<Self::Spec as ToolSpec>::Error>>;
92
93    fn schema() -> ToolSchema
94    where
95        Self: Sized,
96    {
97        let settings = schemars::generate::SchemaSettings::draft07().with(|s| {
98            s.inline_subschemas = true;
99        });
100        let schema_gen = settings.into_generator();
101        let input_schema = schema_gen.into_root_schema_for::<Self::Params>();
102
103        ToolSchema {
104            name: Self::Spec::NAME.to_string(),
105            display_name: Self::Spec::DISPLAY_NAME.to_string(),
106            description: Self::DESCRIPTION.to_string(),
107            input_schema: input_schema.into(),
108        }
109    }
110}
111
112#[async_trait]
113pub trait StaticToolErased: Send + Sync {
114    fn name(&self) -> &'static str;
115    fn display_name(&self) -> &'static str;
116    fn description(&self) -> &'static str;
117    fn requires_approval(&self) -> bool;
118    fn required_capabilities(&self) -> Capabilities;
119    fn schema(&self) -> ToolSchema;
120
121    async fn execute_erased(
122        &self,
123        params: serde_json::Value,
124        ctx: &StaticToolContext,
125    ) -> Result<ToolResult, StaticToolError<ToolExecutionError>>;
126}
127
128#[async_trait]
129impl<T> StaticToolErased for T
130where
131    T: StaticTool,
132{
133    fn name(&self) -> &'static str {
134        T::Spec::NAME
135    }
136
137    fn display_name(&self) -> &'static str {
138        T::Spec::DISPLAY_NAME
139    }
140
141    fn description(&self) -> &'static str {
142        T::DESCRIPTION
143    }
144
145    fn requires_approval(&self) -> bool {
146        T::REQUIRES_APPROVAL
147    }
148
149    fn required_capabilities(&self) -> Capabilities {
150        T::REQUIRED_CAPABILITIES
151    }
152
153    fn schema(&self) -> ToolSchema {
154        T::schema()
155    }
156
157    async fn execute_erased(
158        &self,
159        params: serde_json::Value,
160        ctx: &StaticToolContext,
161    ) -> Result<ToolResult, StaticToolError<ToolExecutionError>> {
162        let typed_params: T::Params = serde_json::from_value(params)
163            .map_err(|e| StaticToolError::invalid_params(e.to_string()))?;
164
165        if ctx.is_cancelled() {
166            return Err(StaticToolError::Cancelled);
167        }
168
169        let result = self
170            .execute(typed_params, ctx)
171            .await
172            .map_err(|e| e.map_execution(T::Spec::execution_error))?;
173        Ok(result.into())
174    }
175}