Skip to main content

uira_core/protocol/
tools.rs

1//! Tool-related types for the protocol
2
3use serde::{Deserialize, Serialize};
4use serde_json::Value;
5
6/// JSON Schema for a tool's input
7#[derive(Debug, Clone, Serialize, Deserialize)]
8pub struct JsonSchema {
9    #[serde(rename = "type")]
10    pub schema_type: String,
11    #[serde(skip_serializing_if = "Option::is_none")]
12    pub description: Option<String>,
13    #[serde(skip_serializing_if = "Option::is_none")]
14    pub properties: Option<Value>,
15    #[serde(skip_serializing_if = "Option::is_none")]
16    pub required: Option<Vec<String>>,
17    #[serde(
18        skip_serializing_if = "Option::is_none",
19        rename = "additionalProperties"
20    )]
21    pub additional_properties: Option<bool>,
22    #[serde(flatten)]
23    pub extra: std::collections::HashMap<String, Value>,
24}
25
26impl JsonSchema {
27    pub fn object() -> Self {
28        Self {
29            schema_type: "object".to_string(),
30            description: None,
31            properties: Some(serde_json::json!({})),
32            required: None,
33            additional_properties: Some(false),
34            extra: std::collections::HashMap::new(),
35        }
36    }
37
38    pub fn string() -> Self {
39        Self {
40            schema_type: "string".to_string(),
41            description: None,
42            properties: None,
43            required: None,
44            additional_properties: None,
45            extra: std::collections::HashMap::new(),
46        }
47    }
48
49    pub fn number() -> Self {
50        Self {
51            schema_type: "number".to_string(),
52            description: None,
53            properties: None,
54            required: None,
55            additional_properties: None,
56            extra: std::collections::HashMap::new(),
57        }
58    }
59
60    pub fn boolean() -> Self {
61        Self {
62            schema_type: "boolean".to_string(),
63            description: None,
64            properties: None,
65            required: None,
66            additional_properties: None,
67            extra: std::collections::HashMap::new(),
68        }
69    }
70
71    pub fn array(items: JsonSchema) -> Self {
72        let mut extra = std::collections::HashMap::new();
73        extra.insert("items".to_string(), serde_json::to_value(items).unwrap());
74        Self {
75            schema_type: "array".to_string(),
76            description: None,
77            properties: None,
78            required: None,
79            additional_properties: None,
80            extra,
81        }
82    }
83
84    pub fn description(mut self, description: impl Into<String>) -> Self {
85        self.description = Some(description.into());
86        self
87    }
88
89    pub fn property(mut self, name: &str, schema: JsonSchema) -> Self {
90        let props = self.properties.get_or_insert(serde_json::json!({}));
91        if let Some(obj) = props.as_object_mut() {
92            obj.insert(name.to_string(), serde_json::to_value(schema).unwrap());
93        }
94        self
95    }
96
97    pub fn with_properties(mut self, properties: Value) -> Self {
98        self.properties = Some(properties);
99        self
100    }
101
102    pub fn required(mut self, fields: &[&str]) -> Self {
103        self.required = Some(fields.iter().map(|s| s.to_string()).collect());
104        self
105    }
106
107    pub fn with_required(mut self, required: Vec<String>) -> Self {
108        self.required = Some(required);
109        self
110    }
111}
112
113/// Tool specification for model API
114#[derive(Debug, Clone, Serialize, Deserialize)]
115pub struct ToolSpec {
116    pub name: String,
117    pub description: String,
118    pub input_schema: JsonSchema,
119    #[serde(skip_serializing_if = "Option::is_none")]
120    pub cache_control: Option<CacheControl>,
121}
122
123impl ToolSpec {
124    pub fn new(
125        name: impl Into<String>,
126        description: impl Into<String>,
127        schema: JsonSchema,
128    ) -> Self {
129        Self {
130            name: name.into(),
131            description: description.into(),
132            input_schema: schema,
133            cache_control: None,
134        }
135    }
136
137    pub fn with_cache(mut self) -> Self {
138        self.cache_control = Some(CacheControl::ephemeral());
139        self
140    }
141}
142
143/// Cache control for prompt caching
144#[derive(Debug, Clone, Serialize, Deserialize)]
145pub struct CacheControl {
146    #[serde(rename = "type")]
147    pub control_type: String,
148}
149
150impl CacheControl {
151    pub fn ephemeral() -> Self {
152        Self {
153            control_type: "ephemeral".to_string(),
154        }
155    }
156}
157
158/// Approval requirement for a tool
159#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
160#[serde(tag = "type", rename_all = "snake_case")]
161pub enum ApprovalRequirement {
162    /// Skip approval, optionally bypass sandbox
163    Skip {
164        #[serde(default)]
165        bypass_sandbox: bool,
166    },
167    /// Requires user approval
168    NeedsApproval { reason: String },
169    /// Always forbidden
170    Forbidden { reason: String },
171}
172
173impl ApprovalRequirement {
174    pub fn skip() -> Self {
175        Self::Skip {
176            bypass_sandbox: false,
177        }
178    }
179
180    pub fn skip_bypass_sandbox() -> Self {
181        Self::Skip {
182            bypass_sandbox: true,
183        }
184    }
185
186    pub fn needs_approval(reason: impl Into<String>) -> Self {
187        Self::NeedsApproval {
188            reason: reason.into(),
189        }
190    }
191
192    pub fn forbidden(reason: impl Into<String>) -> Self {
193        Self::Forbidden {
194            reason: reason.into(),
195        }
196    }
197}
198
199/// Sandbox preference for a tool
200#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
201#[serde(rename_all = "snake_case")]
202pub enum SandboxPreference {
203    /// Let the system decide based on policy
204    #[default]
205    Auto,
206    /// Require sandboxing
207    Require,
208    /// Forbid sandboxing (e.g., for tools that need full access)
209    Forbid,
210}
211
212/// Result of tool execution
213#[derive(Debug, Clone, Serialize, Deserialize)]
214pub struct ToolResult {
215    pub tool_use_id: String,
216    pub output: ToolOutput,
217    #[serde(default)]
218    pub is_error: bool,
219}
220
221impl ToolResult {
222    pub fn success(tool_use_id: impl Into<String>, output: ToolOutput) -> Self {
223        Self {
224            tool_use_id: tool_use_id.into(),
225            output,
226            is_error: false,
227        }
228    }
229
230    pub fn error(tool_use_id: impl Into<String>, message: impl Into<String>) -> Self {
231        Self {
232            tool_use_id: tool_use_id.into(),
233            output: ToolOutput::text(message),
234            is_error: true,
235        }
236    }
237}
238
239/// Output from a tool
240#[derive(Debug, Clone, Serialize, Deserialize)]
241pub struct ToolOutput {
242    pub content: Vec<ToolOutputContent>,
243}
244
245impl ToolOutput {
246    pub fn text(text: impl Into<String>) -> Self {
247        Self {
248            content: vec![ToolOutputContent::Text { text: text.into() }],
249        }
250    }
251
252    pub fn json(value: Value) -> Self {
253        Self {
254            content: vec![ToolOutputContent::Text {
255                text: serde_json::to_string_pretty(&value).unwrap_or_else(|_| value.to_string()),
256            }],
257        }
258    }
259
260    pub fn image(media_type: impl Into<String>, data: impl Into<String>) -> Self {
261        Self {
262            content: vec![ToolOutputContent::Image {
263                source: crate::ImageSource::Base64 {
264                    media_type: media_type.into(),
265                    data: data.into(),
266                },
267            }],
268        }
269    }
270
271    pub fn as_text(&self) -> Option<&str> {
272        self.content.first().and_then(|c| {
273            if let ToolOutputContent::Text { text } = c {
274                Some(text.as_str())
275            } else {
276                None
277            }
278        })
279    }
280
281    pub fn as_json(&self) -> Option<Value> {
282        self.as_text()
283            .and_then(|text| serde_json::from_str(text).ok())
284    }
285}
286
287/// Content type for tool output
288#[derive(Debug, Clone, Serialize, Deserialize)]
289#[serde(tag = "type", rename_all = "snake_case")]
290pub enum ToolOutputContent {
291    Text { text: String },
292    Image { source: crate::ImageSource },
293}
294
295/// Approval request sent to the user
296#[derive(Debug, Clone, Serialize, Deserialize)]
297pub struct ApprovalRequest {
298    pub id: String,
299    pub tool_name: String,
300    pub tool_input: Value,
301    pub reason: String,
302    #[serde(skip_serializing_if = "Option::is_none")]
303    pub suggested_action: Option<SuggestedAction>,
304}
305
306/// Suggested action for approval request
307#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
308#[serde(rename_all = "snake_case")]
309pub enum SuggestedAction {
310    Approve,
311    Deny,
312    ApproveOnce,
313    ApproveAll,
314}
315
316/// User's decision on an approval request
317#[derive(Debug, Clone, Serialize, Deserialize)]
318#[serde(tag = "decision", rename_all = "snake_case")]
319pub enum ReviewDecision {
320    Approve,
321    Deny { reason: Option<String> },
322    ApproveOnce,
323    ApproveAll,
324    Edit { new_input: Value },
325}
326
327impl ReviewDecision {
328    pub fn is_approved(&self) -> bool {
329        matches!(
330            self,
331            Self::Approve | Self::ApproveOnce | Self::ApproveAll | Self::Edit { .. }
332        )
333    }
334
335    pub fn is_denied(&self) -> bool {
336        matches!(self, Self::Deny { .. })
337    }
338}
339
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn test_json_schema_builder() {
        let schema = JsonSchema::object()
            .description("A test schema")
            .with_properties(json!({
                "name": {"type": "string"}
            }))
            .with_required(vec!["name".to_string()]);

        assert_eq!(schema.schema_type, "object");
        assert_eq!(schema.description.as_deref(), Some("A test schema"));
        assert_eq!(schema.properties, Some(json!({"name": {"type": "string"}})));
        assert_eq!(schema.required, Some(vec!["name".to_string()]));
    }

    /// Pins the exact JSON sent to the model API, including the
    /// `additionalProperties` rename and `items` stored via `extra`.
    #[test]
    fn test_json_schema_wire_format() {
        let schema = JsonSchema::object()
            .property("path", JsonSchema::string().description("File path"))
            .property("tags", JsonSchema::array(JsonSchema::string()))
            .required(&["path"]);

        assert_eq!(
            serde_json::to_value(&schema).unwrap(),
            json!({
                "type": "object",
                "properties": {
                    "path": {"type": "string", "description": "File path"},
                    "tags": {"type": "array", "items": {"type": "string"}}
                },
                "required": ["path"],
                "additionalProperties": false
            })
        );
    }

    /// Keywords without a dedicated field must survive a round trip via `extra`.
    #[test]
    fn test_json_schema_unknown_keywords_round_trip() {
        let schema: JsonSchema =
            serde_json::from_value(json!({"type": "string", "enum": ["a", "b"]})).unwrap();

        assert_eq!(schema.schema_type, "string");
        assert!(schema.description.is_none());
        assert_eq!(schema.extra.get("enum"), Some(&json!(["a", "b"])));
        assert_eq!(
            serde_json::to_value(&schema).unwrap(),
            json!({"type": "string", "enum": ["a", "b"]})
        );
    }

    #[test]
    fn test_tool_spec() {
        let spec = ToolSpec::new("read_file", "Read a file from disk", JsonSchema::object());
        assert_eq!(spec.name, "read_file");
        assert!(spec.cache_control.is_none());

        // `cache_control` is skipped entirely while unset.
        let plain = serde_json::to_value(&spec).unwrap();
        assert!(plain.get("cache_control").is_none());
        assert_eq!(plain["input_schema"]["type"], "object");

        let cached = serde_json::to_value(spec.with_cache()).unwrap();
        assert_eq!(cached["cache_control"], json!({"type": "ephemeral"}));
    }

    #[test]
    fn test_approval_requirement() {
        let skip = ApprovalRequirement::skip();
        assert!(matches!(
            skip,
            ApprovalRequirement::Skip {
                bypass_sandbox: false
            }
        ));
        assert_eq!(
            ApprovalRequirement::skip_bypass_sandbox(),
            ApprovalRequirement::Skip {
                bypass_sandbox: true
            }
        );

        let needs = ApprovalRequirement::needs_approval("Writes to disk");
        assert!(matches!(needs, ApprovalRequirement::NeedsApproval { .. }));
    }

    #[test]
    fn test_approval_requirement_serde() {
        assert_eq!(
            serde_json::to_value(ApprovalRequirement::needs_approval("Writes to disk")).unwrap(),
            json!({"type": "needs_approval", "reason": "Writes to disk"})
        );
        assert_eq!(
            serde_json::to_value(ApprovalRequirement::forbidden("nope")).unwrap(),
            json!({"type": "forbidden", "reason": "nope"})
        );

        // `bypass_sandbox` is `#[serde(default)]`, so it may be omitted.
        let parsed: ApprovalRequirement =
            serde_json::from_value(json!({"type": "skip"})).unwrap();
        assert_eq!(parsed, ApprovalRequirement::skip());
    }

    #[test]
    fn test_sandbox_preference() {
        assert_eq!(SandboxPreference::default(), SandboxPreference::Auto);
        assert_eq!(
            serde_json::to_value(SandboxPreference::Forbid).unwrap(),
            json!("forbid")
        );
    }

    #[test]
    fn test_tool_result() {
        let result = ToolResult::success("tc_123", ToolOutput::text("Done!"));
        assert!(!result.is_error);
        assert_eq!(result.output.as_text(), Some("Done!"));

        let error = ToolResult::error("tc_456", "File not found");
        assert!(error.is_error);
        assert_eq!(error.output.as_text(), Some("File not found"));
    }

    #[test]
    fn test_tool_result_is_error_defaults_to_false() {
        let parsed: ToolResult = serde_json::from_value(json!({
            "tool_use_id": "tc_1",
            "output": {"content": [{"type": "text", "text": "hi"}]}
        }))
        .unwrap();

        assert!(!parsed.is_error);
        assert_eq!(parsed.output.as_text(), Some("hi"));
    }

    #[test]
    fn test_tool_output_json_round_trip() {
        let value = json!({"files": ["a.rs", "b.rs"], "count": 2});
        assert_eq!(ToolOutput::json(value.clone()).as_json(), Some(value));
        assert_eq!(ToolOutput::text("not json").as_json(), None);
    }

    #[test]
    fn test_tool_output_as_text_reads_only_first_item() {
        assert_eq!(ToolOutput { content: Vec::new() }.as_text(), None);
        assert_eq!(ToolOutput::image("image/png", "AAAA").as_text(), None);
    }

    #[test]
    fn test_review_decision() {
        assert!(ReviewDecision::Approve.is_approved());
        assert!(ReviewDecision::ApproveOnce.is_approved());
        assert!(ReviewDecision::Edit {
            new_input: json!({})
        }
        .is_approved());
        assert!(!ReviewDecision::Deny { reason: None }.is_approved());
        assert!(ReviewDecision::Deny { reason: None }.is_denied());
        assert!(!ReviewDecision::ApproveAll.is_denied());
    }

    #[test]
    fn test_review_decision_serde() {
        let once: ReviewDecision =
            serde_json::from_value(json!({"decision": "approve_once"})).unwrap();
        assert!(matches!(once, ReviewDecision::ApproveOnce));

        // A missing optional `reason` deserializes as `None`.
        let deny: ReviewDecision = serde_json::from_value(json!({"decision": "deny"})).unwrap();
        assert!(matches!(deny, ReviewDecision::Deny { reason: None }));
    }
}