steer_core/tools/
static_tool.rs1use std::error::Error as StdError;
2use std::sync::Arc;
3
4use async_trait::async_trait;
5use schemars::JsonSchema;
6use serde::de::DeserializeOwned;
7use tokio_util::sync::CancellationToken;
8
9use crate::app::domain::types::{SessionId, ToolCallId};
10use crate::config::model::ModelId;
11use steer_tools::error::ToolExecutionError;
12use steer_tools::result::ToolResult;
13use steer_tools::{ToolSchema, ToolSpec};
14
15use super::capability::Capabilities;
16use super::services::ToolServices;
17
18#[derive(Debug, Clone)]
19pub struct StaticToolContext {
20 pub tool_call_id: ToolCallId,
21 pub session_id: SessionId,
22 pub invoking_model: Option<ModelId>,
23 pub cancellation_token: CancellationToken,
24 pub services: Arc<ToolServices>,
25}
26
27impl StaticToolContext {
28 pub fn is_cancelled(&self) -> bool {
29 self.cancellation_token.is_cancelled()
30 }
31}
32
33#[derive(Debug, thiserror::Error)]
34pub enum StaticToolError<E: StdError + Send + Sync + 'static> {
35 #[error("Invalid parameters: {0}")]
36 InvalidParams(String),
37
38 #[error("{0}")]
39 Execution(E),
40
41 #[error("Missing capability: {0}")]
42 MissingCapability(String),
43
44 #[error("Cancelled")]
45 Cancelled,
46
47 #[error("Timed out")]
48 Timeout,
49}
50
51impl<E: StdError + Send + Sync + 'static> StaticToolError<E> {
52 pub fn invalid_params(msg: impl Into<String>) -> Self {
53 Self::InvalidParams(msg.into())
54 }
55
56 pub fn execution(error: E) -> Self {
57 Self::Execution(error)
58 }
59
60 pub fn missing_capability(cap: &str) -> Self {
61 Self::MissingCapability(cap.to_string())
62 }
63
64 pub fn map_execution<F, E2>(self, f: F) -> StaticToolError<E2>
65 where
66 F: FnOnce(E) -> E2,
67 E2: StdError + Send + Sync + 'static,
68 {
69 match self {
70 StaticToolError::InvalidParams(msg) => StaticToolError::InvalidParams(msg),
71 StaticToolError::Execution(err) => StaticToolError::Execution(f(err)),
72 StaticToolError::MissingCapability(cap) => StaticToolError::MissingCapability(cap),
73 StaticToolError::Cancelled => StaticToolError::Cancelled,
74 StaticToolError::Timeout => StaticToolError::Timeout,
75 }
76 }
77}
78
79#[async_trait]
80pub trait StaticTool: Send + Sync + 'static {
81 type Params: DeserializeOwned + JsonSchema + Send;
82 type Output: Into<ToolResult> + Send;
83 type Spec: ToolSpec<Params = Self::Params, Result = Self::Output>;
84
85 const DESCRIPTION: &'static str;
86 const REQUIRES_APPROVAL: bool;
87 const REQUIRED_CAPABILITIES: Capabilities;
88
89 async fn execute(
90 &self,
91 params: Self::Params,
92 ctx: &StaticToolContext,
93 ) -> Result<Self::Output, StaticToolError<<Self::Spec as ToolSpec>::Error>>;
94
95 fn schema() -> ToolSchema
96 where
97 Self: Sized,
98 {
99 let settings = schemars::generate::SchemaSettings::draft07().with(|s| {
100 s.inline_subschemas = true;
101 });
102 let schema_gen = settings.into_generator();
103 let input_schema = schema_gen.into_root_schema_for::<Self::Params>();
104
105 ToolSchema {
106 name: Self::Spec::NAME.to_string(),
107 display_name: Self::Spec::DISPLAY_NAME.to_string(),
108 description: Self::DESCRIPTION.to_string(),
109 input_schema: input_schema.into(),
110 }
111 }
112}
113
114#[async_trait]
115pub trait StaticToolErased: Send + Sync {
116 fn name(&self) -> &'static str;
117 fn display_name(&self) -> &'static str;
118 fn description(&self) -> &'static str;
119 fn requires_approval(&self) -> bool;
120 fn required_capabilities(&self) -> Capabilities;
121 fn schema(&self) -> ToolSchema;
122
123 async fn execute_erased(
124 &self,
125 params: serde_json::Value,
126 ctx: &StaticToolContext,
127 ) -> Result<ToolResult, StaticToolError<ToolExecutionError>>;
128}
129
130#[async_trait]
131impl<T> StaticToolErased for T
132where
133 T: StaticTool,
134{
135 fn name(&self) -> &'static str {
136 T::Spec::NAME
137 }
138
139 fn display_name(&self) -> &'static str {
140 T::Spec::DISPLAY_NAME
141 }
142
143 fn description(&self) -> &'static str {
144 T::DESCRIPTION
145 }
146
147 fn requires_approval(&self) -> bool {
148 T::REQUIRES_APPROVAL
149 }
150
151 fn required_capabilities(&self) -> Capabilities {
152 T::REQUIRED_CAPABILITIES
153 }
154
155 fn schema(&self) -> ToolSchema {
156 T::schema()
157 }
158
159 async fn execute_erased(
160 &self,
161 params: serde_json::Value,
162 ctx: &StaticToolContext,
163 ) -> Result<ToolResult, StaticToolError<ToolExecutionError>> {
164 let typed_params: T::Params = serde_json::from_value(params)
165 .map_err(|e| StaticToolError::invalid_params(e.to_string()))?;
166
167 if ctx.is_cancelled() {
168 return Err(StaticToolError::Cancelled);
169 }
170
171 let result = self
172 .execute(typed_params, ctx)
173 .await
174 .map_err(|e| e.map_execution(T::Spec::execution_error))?;
175 Ok(result.into())
176 }
177}