1use serde::{Deserialize, Serialize};
2use serde_json::Value;
3use std::sync::Arc;
4
5#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
6#[serde(rename_all = "snake_case")]
7pub enum ToolEffect {
8 Read,
9 Write,
10 Delete,
11 Search,
12 Execute,
13 Fetch,
14 Patch,
15}
16
17#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
18#[serde(rename_all = "snake_case")]
19pub enum ToolDomain {
20 Workspace,
21 Web,
22 Shell,
23 Browser,
24 Planning,
25 Memory,
26 Collaboration,
27 Integration,
28}
29
/// Declarative capability metadata attached to a [`ToolSchema`].
///
/// Every field is optional on the wire. A missing field deserializes to its
/// `Default` value (empty list / `false`). Empty lists and `false` flags are
/// omitted when serializing, so an all-default value serializes as `{}`.
///
/// These are plain declarations. Nothing in this type cross-checks the boolean
/// flags against `effects` / `domains`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)]
pub struct ToolCapabilities {
    /// Declared effects. The `effect` builder skips duplicates, but deserialized
    /// input is taken as-is (not deduplicated).
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub effects: Vec<ToolEffect>,
    /// Declared domains. The `domain` builder skips duplicates, but deserialized
    /// input is taken as-is (not deduplicated).
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub domains: Vec<ToolDomain>,
    /// Declares that the tool reads the workspace.
    #[serde(default, skip_serializing_if = "is_false")]
    pub reads_workspace: bool,
    /// Declares that the tool writes to the workspace.
    #[serde(default, skip_serializing_if = "is_false")]
    pub writes_workspace: bool,
    /// Declares that the tool uses the network.
    #[serde(default, skip_serializing_if = "is_false")]
    pub network_access: bool,
    /// Declares that the tool is destructive.
    #[serde(default, skip_serializing_if = "is_false")]
    pub destructive: bool,
    /// Declares that the tool's results require verification.
    /// NOTE(review): who performs the verification is decided by consumers, not here.
    #[serde(default, skip_serializing_if = "is_false")]
    pub requires_verification: bool,
    /// Hint that the tool is preferred for discovery.
    #[serde(default, skip_serializing_if = "is_false")]
    pub preferred_for_discovery: bool,
    /// Hint that the tool is preferred for validation.
    #[serde(default, skip_serializing_if = "is_false")]
    pub preferred_for_validation: bool,
}
51
/// Description of a tool: its name, human-readable description, input schema,
/// and optional capability metadata.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct ToolSchema {
    pub name: String,
    pub description: String,
    /// Schema for the tool's input, kept as raw JSON.
    /// NOTE(review): presumably a JSON Schema object (the tests use
    /// `{"type": "object"}`) — confirm with producers of this value.
    pub input_schema: Value,
    /// Capability metadata. It defaults to empty when the field is absent
    /// (legacy payloads). It is omitted from serialized output when
    /// `ToolCapabilities::is_empty` returns `true`.
    #[serde(default, skip_serializing_if = "ToolCapabilities::is_empty")]
    pub capabilities: ToolCapabilities,
}
60
/// Serde `skip_serializing_if` predicate that reports whether a flag is unset.
///
/// Returning `true` for `false` makes serde leave unset boolean flags out of the
/// serialized output. The `&bool` parameter is the signature serde requires.
fn is_false(value: &bool) -> bool {
    !value
}
64
65impl ToolCapabilities {
66 pub fn new() -> Self {
67 Self::default()
68 }
69
70 pub fn effect(mut self, effect: ToolEffect) -> Self {
71 if !self.effects.contains(&effect) {
72 self.effects.push(effect);
73 }
74 self
75 }
76
77 pub fn domain(mut self, domain: ToolDomain) -> Self {
78 if !self.domains.contains(&domain) {
79 self.domains.push(domain);
80 }
81 self
82 }
83
84 pub fn reads_workspace(mut self) -> Self {
85 self.reads_workspace = true;
86 self
87 }
88
89 pub fn writes_workspace(mut self) -> Self {
90 self.writes_workspace = true;
91 self
92 }
93
94 pub fn network_access(mut self) -> Self {
95 self.network_access = true;
96 self
97 }
98
99 pub fn destructive(mut self) -> Self {
100 self.destructive = true;
101 self
102 }
103
104 pub fn requires_verification(mut self) -> Self {
105 self.requires_verification = true;
106 self
107 }
108
109 pub fn preferred_for_discovery(mut self) -> Self {
110 self.preferred_for_discovery = true;
111 self
112 }
113
114 pub fn preferred_for_validation(mut self) -> Self {
115 self.preferred_for_validation = true;
116 self
117 }
118
119 pub fn is_empty(&self) -> bool {
120 self.effects.is_empty()
121 && self.domains.is_empty()
122 && !self.reads_workspace
123 && !self.writes_workspace
124 && !self.network_access
125 && !self.destructive
126 && !self.requires_verification
127 && !self.preferred_for_discovery
128 && !self.preferred_for_validation
129 }
130}
131
132impl ToolSchema {
133 pub fn new(
134 name: impl Into<String>,
135 description: impl Into<String>,
136 input_schema: Value,
137 ) -> Self {
138 Self {
139 name: name.into(),
140 description: description.into(),
141 input_schema,
142 capabilities: ToolCapabilities::default(),
143 }
144 }
145
146 pub fn with_capabilities(mut self, capabilities: ToolCapabilities) -> Self {
147 self.capabilities = capabilities;
148 self
149 }
150}
151
152#[derive(Debug, Clone, Serialize, Deserialize)]
153pub struct ToolResult {
154 pub output: String,
155 #[serde(default)]
156 pub metadata: Value,
157}
158
159#[derive(Debug, Clone)]
160pub struct ToolProgressEvent {
161 pub event_type: String,
162 pub properties: Value,
163}
164
165impl ToolProgressEvent {
166 pub fn new(event_type: impl Into<String>, properties: Value) -> Self {
167 Self {
168 event_type: event_type.into(),
169 properties,
170 }
171 }
172}
173
/// Destination for [`ToolProgressEvent`]s.
///
/// The `Send + Sync` supertraits let a single sink be shared across threads
/// through [`SharedToolProgressSink`].
pub trait ToolProgressSink: Send + Sync {
    /// Delivers one event to the sink.
    ///
    /// The method takes `&self`, so an implementation that records events needs
    /// interior mutability (for example a mutex or a channel sender).
    fn publish(&self, event: ToolProgressEvent);
}

/// Reference-counted, type-erased handle to a [`ToolProgressSink`].
pub type SharedToolProgressSink = Arc<dyn ToolProgressSink>;
179
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn tool_schema_deserializes_legacy_payload_without_capabilities() {
        let actual: ToolSchema = serde_json::from_value(serde_json::json!({
            "name": "read",
            "description": "Read file contents",
            "input_schema": {
                "type": "object"
            }
        }))
        .unwrap();

        let expected = ToolSchema::new(
            "read",
            "Read file contents",
            serde_json::json!({
                "type": "object"
            }),
        );

        assert_eq!(actual, expected);
    }

    #[test]
    fn tool_schema_serialization_omits_empty_capabilities() {
        let actual = serde_json::to_value(ToolSchema::new(
            "read",
            "Read file contents",
            serde_json::json!({
                "type": "object"
            }),
        ))
        .unwrap();

        let expected = serde_json::json!({
            "name": "read",
            "description": "Read file contents",
            "input_schema": {
                "type": "object"
            }
        });

        assert_eq!(actual, expected);
    }

    #[test]
    fn tool_schema_round_trips_capabilities() {
        let payload = serde_json::json!({
            "name": "write",
            "description": "Write file contents",
            "input_schema": {
                "type": "object"
            },
            "capabilities": {
                "effects": ["write"],
                "domains": ["workspace"],
                "writes_workspace": true,
                "requires_verification": true
            }
        });

        let actual: ToolSchema = serde_json::from_value(payload.clone()).unwrap();

        let expected = ToolSchema::new(
            "write",
            "Write file contents",
            serde_json::json!({
                "type": "object"
            }),
        )
        .with_capabilities(
            ToolCapabilities::new()
                .effect(ToolEffect::Write)
                .domain(ToolDomain::Workspace)
                .writes_workspace()
                .requires_verification(),
        );

        assert_eq!(actual, expected);
        // Serialize back as well. Without this step the test only covers
        // deserialization, not a round trip.
        assert_eq!(serde_json::to_value(&actual).unwrap(), payload);
    }

    #[test]
    fn capability_builders_skip_duplicate_effects_and_domains() {
        let capabilities = ToolCapabilities::new()
            .effect(ToolEffect::Read)
            .effect(ToolEffect::Search)
            .effect(ToolEffect::Read)
            .domain(ToolDomain::Workspace)
            .domain(ToolDomain::Workspace);

        assert_eq!(
            capabilities.effects,
            vec![ToolEffect::Read, ToolEffect::Search]
        );
        assert_eq!(capabilities.domains, vec![ToolDomain::Workspace]);
    }

    #[test]
    fn capabilities_is_empty_accounts_for_every_field() {
        assert!(ToolCapabilities::new().is_empty());
        assert!(ToolCapabilities::default().is_empty());

        let non_empty = [
            ToolCapabilities::new().effect(ToolEffect::Read),
            ToolCapabilities::new().domain(ToolDomain::Web),
            ToolCapabilities::new().reads_workspace(),
            ToolCapabilities::new().writes_workspace(),
            ToolCapabilities::new().network_access(),
            ToolCapabilities::new().destructive(),
            ToolCapabilities::new().requires_verification(),
            ToolCapabilities::new().preferred_for_discovery(),
            ToolCapabilities::new().preferred_for_validation(),
        ];
        for capabilities in non_empty {
            assert!(!capabilities.is_empty(), "{:?}", capabilities);
        }
    }
}