steer_core/tools/
static_tool.rs1use std::error::Error as StdError;
2use std::sync::Arc;
3
4use async_trait::async_trait;
5use schemars::JsonSchema;
6use serde::de::DeserializeOwned;
7use tokio_util::sync::CancellationToken;
8
9use crate::app::domain::types::{SessionId, ToolCallId};
10use steer_tools::error::ToolExecutionError;
11use steer_tools::result::ToolResult;
12use steer_tools::{ToolSchema, ToolSpec};
13
14use super::capability::Capabilities;
15use super::services::ToolServices;
16
17#[derive(Debug, Clone)]
18pub struct StaticToolContext {
19 pub tool_call_id: ToolCallId,
20 pub session_id: SessionId,
21 pub cancellation_token: CancellationToken,
22 pub services: Arc<ToolServices>,
23}
24
25impl StaticToolContext {
26 pub fn is_cancelled(&self) -> bool {
27 self.cancellation_token.is_cancelled()
28 }
29}
30
31#[derive(Debug, thiserror::Error)]
32pub enum StaticToolError<E: StdError + Send + Sync + 'static> {
33 #[error("Invalid parameters: {0}")]
34 InvalidParams(String),
35
36 #[error("{0}")]
37 Execution(E),
38
39 #[error("Missing capability: {0}")]
40 MissingCapability(String),
41
42 #[error("Cancelled")]
43 Cancelled,
44
45 #[error("Timed out")]
46 Timeout,
47}
48
49impl<E: StdError + Send + Sync + 'static> StaticToolError<E> {
50 pub fn invalid_params(msg: impl Into<String>) -> Self {
51 Self::InvalidParams(msg.into())
52 }
53
54 pub fn execution(error: E) -> Self {
55 Self::Execution(error)
56 }
57
58 pub fn missing_capability(cap: &str) -> Self {
59 Self::MissingCapability(cap.to_string())
60 }
61
62 pub fn map_execution<F, E2>(self, f: F) -> StaticToolError<E2>
63 where
64 F: FnOnce(E) -> E2,
65 E2: StdError + Send + Sync + 'static,
66 {
67 match self {
68 StaticToolError::InvalidParams(msg) => StaticToolError::InvalidParams(msg),
69 StaticToolError::Execution(err) => StaticToolError::Execution(f(err)),
70 StaticToolError::MissingCapability(cap) => StaticToolError::MissingCapability(cap),
71 StaticToolError::Cancelled => StaticToolError::Cancelled,
72 StaticToolError::Timeout => StaticToolError::Timeout,
73 }
74 }
75}
76
77#[async_trait]
78pub trait StaticTool: Send + Sync + 'static {
79 type Params: DeserializeOwned + JsonSchema + Send;
80 type Output: Into<ToolResult> + Send;
81 type Spec: ToolSpec<Params = Self::Params, Result = Self::Output>;
82
83 const DESCRIPTION: &'static str;
84 const REQUIRES_APPROVAL: bool;
85 const REQUIRED_CAPABILITIES: Capabilities;
86
87 async fn execute(
88 &self,
89 params: Self::Params,
90 ctx: &StaticToolContext,
91 ) -> Result<Self::Output, StaticToolError<<Self::Spec as ToolSpec>::Error>>;
92
93 fn schema() -> ToolSchema
94 where
95 Self: Sized,
96 {
97 let settings = schemars::generate::SchemaSettings::draft07().with(|s| {
98 s.inline_subschemas = true;
99 });
100 let schema_gen = settings.into_generator();
101 let input_schema = schema_gen.into_root_schema_for::<Self::Params>();
102
103 ToolSchema {
104 name: Self::Spec::NAME.to_string(),
105 display_name: Self::Spec::DISPLAY_NAME.to_string(),
106 description: Self::DESCRIPTION.to_string(),
107 input_schema: input_schema.into(),
108 }
109 }
110}
111
112#[async_trait]
113pub trait StaticToolErased: Send + Sync {
114 fn name(&self) -> &'static str;
115 fn display_name(&self) -> &'static str;
116 fn description(&self) -> &'static str;
117 fn requires_approval(&self) -> bool;
118 fn required_capabilities(&self) -> Capabilities;
119 fn schema(&self) -> ToolSchema;
120
121 async fn execute_erased(
122 &self,
123 params: serde_json::Value,
124 ctx: &StaticToolContext,
125 ) -> Result<ToolResult, StaticToolError<ToolExecutionError>>;
126}
127
128#[async_trait]
129impl<T> StaticToolErased for T
130where
131 T: StaticTool,
132{
133 fn name(&self) -> &'static str {
134 T::Spec::NAME
135 }
136
137 fn display_name(&self) -> &'static str {
138 T::Spec::DISPLAY_NAME
139 }
140
141 fn description(&self) -> &'static str {
142 T::DESCRIPTION
143 }
144
145 fn requires_approval(&self) -> bool {
146 T::REQUIRES_APPROVAL
147 }
148
149 fn required_capabilities(&self) -> Capabilities {
150 T::REQUIRED_CAPABILITIES
151 }
152
153 fn schema(&self) -> ToolSchema {
154 T::schema()
155 }
156
157 async fn execute_erased(
158 &self,
159 params: serde_json::Value,
160 ctx: &StaticToolContext,
161 ) -> Result<ToolResult, StaticToolError<ToolExecutionError>> {
162 let typed_params: T::Params = serde_json::from_value(params)
163 .map_err(|e| StaticToolError::invalid_params(e.to_string()))?;
164
165 if ctx.is_cancelled() {
166 return Err(StaticToolError::Cancelled);
167 }
168
169 let result = self
170 .execute(typed_params, ctx)
171 .await
172 .map_err(|e| e.map_execution(T::Spec::execution_error))?;
173 Ok(result.into())
174 }
175}