Skip to main content

steer_core/tools/
static_tool.rs

1use std::error::Error as StdError;
2use std::sync::Arc;
3
4use async_trait::async_trait;
5use schemars::JsonSchema;
6use serde::de::DeserializeOwned;
7use tokio_util::sync::CancellationToken;
8
9use crate::app::domain::types::{SessionId, ToolCallId};
10use crate::config::model::ModelId;
11use steer_tools::error::ToolExecutionError;
12use steer_tools::result::ToolResult;
13use steer_tools::{ToolSchema, ToolSpec};
14
15use super::capability::Capabilities;
16use super::services::ToolServices;
17
/// Per-invocation context passed to a [`StaticTool`] when it executes.
///
/// `Clone` is derived; the `services` handle is behind an `Arc`, so clones
/// share the same services.
#[derive(Debug, Clone)]
pub struct StaticToolContext {
    /// Identifier of the tool call being executed.
    pub tool_call_id: ToolCallId,
    /// Identifier of the session this invocation belongs to.
    pub session_id: SessionId,
    /// Model associated with this invocation, if any
    /// (NOTE(review): presumably `None` when the call was not model-initiated — confirm).
    pub invoking_model: Option<ModelId>,
    /// Cancellation signal for this call; read via `is_cancelled`, and the
    /// blanket `StaticToolErased::execute_erased` refuses to start a tool once
    /// it is cancelled.
    pub cancellation_token: CancellationToken,
    /// Shared services available to the tool.
    pub services: Arc<ToolServices>,
}
26
27impl StaticToolContext {
28    pub fn is_cancelled(&self) -> bool {
29        self.cancellation_token.is_cancelled()
30    }
31}
32
33#[derive(Debug, thiserror::Error)]
34pub enum StaticToolError<E: StdError + Send + Sync + 'static> {
35    #[error("Invalid parameters: {0}")]
36    InvalidParams(String),
37
38    #[error("{0}")]
39    Execution(E),
40
41    #[error("Missing capability: {0}")]
42    MissingCapability(String),
43
44    #[error("Cancelled")]
45    Cancelled,
46
47    #[error("Timed out")]
48    Timeout,
49}
50
51impl<E: StdError + Send + Sync + 'static> StaticToolError<E> {
52    pub fn invalid_params(msg: impl Into<String>) -> Self {
53        Self::InvalidParams(msg.into())
54    }
55
56    pub fn execution(error: E) -> Self {
57        Self::Execution(error)
58    }
59
60    pub fn missing_capability(cap: &str) -> Self {
61        Self::MissingCapability(cap.to_string())
62    }
63
64    pub fn map_execution<F, E2>(self, f: F) -> StaticToolError<E2>
65    where
66        F: FnOnce(E) -> E2,
67        E2: StdError + Send + Sync + 'static,
68    {
69        match self {
70            StaticToolError::InvalidParams(msg) => StaticToolError::InvalidParams(msg),
71            StaticToolError::Execution(err) => StaticToolError::Execution(f(err)),
72            StaticToolError::MissingCapability(cap) => StaticToolError::MissingCapability(cap),
73            StaticToolError::Cancelled => StaticToolError::Cancelled,
74            StaticToolError::Timeout => StaticToolError::Timeout,
75        }
76    }
77}
78
79#[async_trait]
80pub trait StaticTool: Send + Sync + 'static {
81    type Params: DeserializeOwned + JsonSchema + Send;
82    type Output: Into<ToolResult> + Send;
83    type Spec: ToolSpec<Params = Self::Params, Result = Self::Output>;
84
85    const DESCRIPTION: &'static str;
86    const REQUIRES_APPROVAL: bool;
87    const REQUIRED_CAPABILITIES: Capabilities;
88
89    async fn execute(
90        &self,
91        params: Self::Params,
92        ctx: &StaticToolContext,
93    ) -> Result<Self::Output, StaticToolError<<Self::Spec as ToolSpec>::Error>>;
94
95    fn schema() -> ToolSchema
96    where
97        Self: Sized,
98    {
99        let settings = schemars::generate::SchemaSettings::draft07().with(|s| {
100            s.inline_subschemas = true;
101        });
102        let schema_gen = settings.into_generator();
103        let input_schema = schema_gen.into_root_schema_for::<Self::Params>();
104
105        ToolSchema {
106            name: Self::Spec::NAME.to_string(),
107            display_name: Self::Spec::DISPLAY_NAME.to_string(),
108            description: Self::DESCRIPTION.to_string(),
109            input_schema: input_schema.into(),
110        }
111    }
112}
113
/// Type-erased, JSON-facing view of a tool.
///
/// Every [`StaticTool`] implements this through a blanket impl, so tools with
/// different `Params`/`Output` types share one interface
/// (NOTE(review): presumably so they can be stored as `dyn StaticToolErased` — confirm).
#[async_trait]
pub trait StaticToolErased: Send + Sync {
    /// Tool name (the spec's `NAME` for `StaticTool` types).
    fn name(&self) -> &'static str;
    /// Display name (the spec's `DISPLAY_NAME` for `StaticTool` types).
    fn display_name(&self) -> &'static str;
    /// Tool description (`StaticTool::DESCRIPTION` for `StaticTool` types).
    fn description(&self) -> &'static str;
    /// Whether invoking the tool requires approval.
    fn requires_approval(&self) -> bool;
    /// Capabilities the tool needs.
    fn required_capabilities(&self) -> Capabilities;
    /// Complete schema: name, display name, description, and input schema.
    fn schema(&self) -> ToolSchema;

    /// Runs the tool from raw JSON parameters.
    ///
    /// # Errors
    /// For `StaticTool` types (via the blanket impl): `InvalidParams` if
    /// `params` does not deserialize, `Cancelled` if `ctx` is already
    /// cancelled before the tool starts, and otherwise whatever the tool
    /// returns, with its execution error converted to [`ToolExecutionError`].
    async fn execute_erased(
        &self,
        params: serde_json::Value,
        ctx: &StaticToolContext,
    ) -> Result<ToolResult, StaticToolError<ToolExecutionError>>;
}
129
130#[async_trait]
131impl<T> StaticToolErased for T
132where
133    T: StaticTool,
134{
135    fn name(&self) -> &'static str {
136        T::Spec::NAME
137    }
138
139    fn display_name(&self) -> &'static str {
140        T::Spec::DISPLAY_NAME
141    }
142
143    fn description(&self) -> &'static str {
144        T::DESCRIPTION
145    }
146
147    fn requires_approval(&self) -> bool {
148        T::REQUIRES_APPROVAL
149    }
150
151    fn required_capabilities(&self) -> Capabilities {
152        T::REQUIRED_CAPABILITIES
153    }
154
155    fn schema(&self) -> ToolSchema {
156        T::schema()
157    }
158
159    async fn execute_erased(
160        &self,
161        params: serde_json::Value,
162        ctx: &StaticToolContext,
163    ) -> Result<ToolResult, StaticToolError<ToolExecutionError>> {
164        let typed_params: T::Params = serde_json::from_value(params)
165            .map_err(|e| StaticToolError::invalid_params(e.to_string()))?;
166
167        if ctx.is_cancelled() {
168            return Err(StaticToolError::Cancelled);
169        }
170
171        let result = self
172            .execute(typed_params, ctx)
173            .await
174            .map_err(|e| e.map_execution(T::Spec::execution_error))?;
175        Ok(result.into())
176    }
177}