1use crate::error::ToolExecutionError;
2use schemars::JsonSchema;
3use serde::{Deserialize, Serialize, de::DeserializeOwned};
4use serde_json::Value;
5use std::error::Error as StdError;
6
7#[derive(Debug, Clone, Serialize, Deserialize)]
8#[serde(transparent)]
9pub struct InputSchema(Value);
10
11impl InputSchema {
12 pub fn new(schema: Value) -> Self {
13 Self(schema)
14 }
15
16 pub fn as_value(&self) -> &Value {
17 &self.0
18 }
19
20 pub fn into_value(self) -> Value {
21 self.0
22 }
23
24 pub fn summary(&self) -> InputSchemaSummary {
25 InputSchemaSummary::from_value(&self.0)
26 }
27
28 pub fn object(properties: serde_json::Map<String, Value>, required: Vec<String>) -> Self {
29 let mut schema = serde_json::Map::new();
30 schema.insert("type".to_string(), Value::String("object".to_string()));
31 schema.insert("properties".to_string(), Value::Object(properties));
32 if !required.is_empty() {
33 let required_values = required.into_iter().map(Value::String).collect::<Vec<_>>();
34 schema.insert("required".to_string(), Value::Array(required_values));
35 }
36 Self(Value::Object(schema))
37 }
38
39 pub fn empty_object() -> Self {
40 Self::object(Default::default(), Vec::new())
41 }
42}
43
44impl From<Value> for InputSchema {
45 fn from(schema: Value) -> Self {
46 Self(schema)
47 }
48}
49
50impl From<schemars::Schema> for InputSchema {
51 fn from(schema: schemars::Schema) -> Self {
52 let schema_value = serde_json::to_value(&schema).unwrap_or(serde_json::Value::Null);
53 Self(ensure_object_properties(schema_value))
54 }
55}
56
57fn ensure_object_properties(schema: Value) -> Value {
58 let mut schema = schema;
59 if let Value::Object(obj) = &mut schema {
60 let is_object = obj
61 .get("type")
62 .and_then(|v| v.as_str())
63 .is_some_and(|t| t == "object");
64 if is_object && !obj.contains_key("properties") {
65 obj.insert(
66 "properties".to_string(),
67 Value::Object(serde_json::Map::new()),
68 );
69 }
70 }
71 schema
72}
73
74#[derive(Debug, Clone)]
75pub struct InputSchemaSummary {
76 pub properties: serde_json::Map<String, Value>,
77 pub required: Vec<String>,
78 pub schema_type: String,
79}
80
81impl InputSchemaSummary {
82 fn from_value(schema: &Value) -> Self {
83 let mut properties = serde_json::Map::new();
84 let mut required = std::collections::BTreeSet::new();
85
86 Self::merge_schema(schema, &mut properties, &mut required);
87
88 let mut schema_type = schema
89 .as_object()
90 .and_then(|obj| obj.get("type"))
91 .and_then(|v| v.as_str())
92 .unwrap_or_default()
93 .to_string();
94
95 if schema_type.is_empty() && (!properties.is_empty() || !required.is_empty()) {
96 schema_type = "object".to_string();
97 }
98
99 Self {
100 properties,
101 required: required.into_iter().collect(),
102 schema_type,
103 }
104 }
105
106 fn merge_schema(
107 schema: &Value,
108 properties: &mut serde_json::Map<String, Value>,
109 required: &mut std::collections::BTreeSet<String>,
110 ) {
111 let Some(obj) = schema.as_object() else {
112 return;
113 };
114
115 if let Some(prop_obj) = obj.get("properties").and_then(|v| v.as_object()) {
116 for (key, value) in prop_obj {
117 merge_property(properties, key, value);
118 }
119 }
120
121 if let Some(req) = obj.get("required").and_then(|v| v.as_array()) {
122 for item in req {
123 if let Some(name) = item.as_str() {
124 required.insert(name.to_string());
125 }
126 }
127 }
128
129 if let Some(all_of) = obj.get("allOf").and_then(|v| v.as_array()) {
130 for sub in all_of {
131 let summary = InputSchemaSummary::from_value(sub);
132 for (key, value) in summary.properties {
133 merge_property(properties, &key, &value);
134 }
135 required.extend(summary.required);
136 }
137 }
138
139 if let Some(one_of) = obj.get("oneOf").and_then(|v| v.as_array()) {
140 Self::merge_one_of(one_of, properties, required);
141 }
142
143 if let Some(any_of) = obj.get("anyOf").and_then(|v| v.as_array()) {
144 Self::merge_one_of(any_of, properties, required);
145 }
146 }
147
148 fn merge_one_of(
149 subschemas: &[Value],
150 properties: &mut serde_json::Map<String, Value>,
151 required: &mut std::collections::BTreeSet<String>,
152 ) {
153 let mut intersection: Option<std::collections::BTreeSet<String>> = None;
154
155 for sub in subschemas {
156 let summary = InputSchemaSummary::from_value(sub);
157 for (key, value) in summary.properties {
158 merge_property(properties, &key, &value);
159 }
160
161 let required_set: std::collections::BTreeSet<String> =
162 summary.required.into_iter().collect();
163
164 intersection = match intersection.take() {
165 None => Some(required_set),
166 Some(existing) => Some(
167 existing
168 .intersection(&required_set)
169 .cloned()
170 .collect::<std::collections::BTreeSet<String>>(),
171 ),
172 };
173 }
174
175 if let Some(required_set) = intersection {
176 required.extend(required_set);
177 }
178 }
179}
180
181fn merge_property(properties: &mut serde_json::Map<String, Value>, key: &str, value: &Value) {
182 match properties.get_mut(key) {
183 None => {
184 properties.insert(key.to_string(), value.clone());
185 }
186 Some(existing) => {
187 if existing == value {
188 return;
189 }
190 let existing_values = extract_enum_values(existing);
191 let incoming_values = extract_enum_values(value);
192 if incoming_values.is_empty() && existing_values.is_empty() {
193 return;
194 }
195
196 let mut combined = existing_values;
197 for item in incoming_values {
198 if !combined.contains(&item) {
199 combined.push(item);
200 }
201 }
202
203 if combined.is_empty() {
204 return;
205 }
206
207 if let Some(obj) = existing.as_object_mut() {
208 obj.remove("const");
209 obj.insert("enum".to_string(), Value::Array(combined));
210 }
211 }
212 }
213}
214
215fn extract_enum_values(value: &Value) -> Vec<Value> {
216 let Some(obj) = value.as_object() else {
217 return Vec::new();
218 };
219
220 if let Some(enum_values) = obj.get("enum").and_then(|v| v.as_array()) {
221 return enum_values.clone();
222 }
223
224 if let Some(const_value) = obj.get("const") {
225 return vec![const_value.clone()];
226 }
227
228 Vec::new()
229}
230
#[cfg(test)]
mod tests {
    use super::InputSchema;
    use schemars::schema_for;

    /// The dispatch-agent parameter schema must expose `prompt` and `target` as required
    /// properties. `target` must be a union whose variants carry `session` values
    /// `"new"` and `"resume"`.
    #[test]
    fn dispatch_agent_schema_includes_target() {
        let schema = schema_for!(crate::tools::dispatch_agent::DispatchAgentParams);
        let input_schema: InputSchema = schema.into();
        let summary = input_schema.summary();

        assert!(summary.properties.contains_key("prompt"));
        assert!(summary.properties.contains_key("target"));
        assert!(summary.required.contains(&"prompt".to_string()));
        assert!(summary.required.contains(&"target".to_string()));
        assert_eq!(summary.schema_type, "object");

        let root = input_schema.as_value();
        let target_schema = root
            .get("properties")
            .and_then(|v| v.get("target"))
            .expect("target schema should exist");
        let target_schema = resolve_ref(root, target_schema);
        let variants = target_schema
            .get("oneOf")
            .or_else(|| target_schema.get("anyOf"))
            .or_else(|| target_schema.get("allOf"))
            .and_then(|v| v.as_array())
            .expect("target should be a tagged union");

        let mut session_values = Vec::new();
        for variant in variants {
            // Variants may themselves be `$ref`s into the definitions; resolve first.
            let variant = resolve_ref(root, variant);
            let session_prop = variant
                .get("properties")
                .and_then(|v| v.get("session"))
                .unwrap_or(&serde_json::Value::Null);
            session_values.extend(super::extract_enum_values(session_prop));
        }

        assert!(session_values.contains(&serde_json::Value::String("new".to_string())));
        assert!(session_values.contains(&serde_json::Value::String("resume".to_string())));
    }

    /// Resolves a document-local `$ref` (`#/...`) in `schema` against `root`.
    ///
    /// Uses JSON Pointer semantics, including `~0`/`~1` escapes. Returns `schema`
    /// unchanged when:
    /// - it has no string `$ref`,
    /// - the reference is not document-local, or
    /// - the pointer does not resolve.
    fn resolve_ref<'a>(
        root: &'a serde_json::Value,
        schema: &'a serde_json::Value,
    ) -> &'a serde_json::Value {
        schema
            .get("$ref")
            .and_then(serde_json::Value::as_str)
            .and_then(|reference| reference.strip_prefix('#'))
            .and_then(|pointer| root.pointer(pointer))
            .unwrap_or(schema)
    }
}
292
293#[derive(Debug, Clone, Serialize, Deserialize)]
294pub struct ToolSchema {
295 pub name: String,
296 #[serde(default)]
297 pub display_name: String,
298 pub description: String,
299 pub input_schema: InputSchema,
300}
301
302pub trait ToolSpec {
303 type Params: DeserializeOwned + JsonSchema + Send;
304 type Result: Into<crate::result::ToolResult> + Send;
305 type Error: StdError + Send + Sync + 'static;
306
307 const NAME: &'static str;
308 const DISPLAY_NAME: &'static str;
309
310 fn execution_error(error: Self::Error) -> ToolExecutionError;
311}
312
313#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
314pub struct ToolCall {
315 pub name: String,
316 pub parameters: Value,
317 pub id: String,
318}