1use schemars::JsonSchema;
7use serde::de::DeserializeOwned;
8use serde_json::Value;
9
10use crate::types::tools::ToolSpec;
11
12pub fn schema_to_tool_spec<T: JsonSchema>(name: &str, description: &str) -> ToolSpec {
14 let schema = schemars::schema_for!(T);
15 let mut json_schema = serde_json::to_value(schema).unwrap_or_default();
16
17 json_schema = flatten_schema(&json_schema);
18
19 ToolSpec::new(name, description).with_input_schema(json_schema)
20}
21
22pub fn structured_output_spec<T: JsonSchema>() -> ToolSpec {
24 let name = std::any::type_name::<T>()
25 .split("::")
26 .last()
27 .unwrap_or("StructuredOutput")
28 .to_string();
29
30 let description = "IMPORTANT: This StructuredOutputTool should only be invoked as the last and final tool \
31 before returning the completed result to the caller.".to_string();
32
33 schema_to_tool_spec::<T>(&name, &description)
34}
35
36#[derive(Debug)]
38pub struct StructuredOutputResult<T> {
39 pub value: T,
41 pub raw_json: Value,
43}
44
45impl<T: DeserializeOwned> StructuredOutputResult<T> {
46 pub fn from_json(json: Value) -> Result<Self, serde_json::Error> {
48 let value: T = serde_json::from_value(json.clone())?;
49 Ok(Self { value, raw_json: json })
50 }
51
52 pub fn from_str(s: &str) -> Result<Self, serde_json::Error> {
54 let json: Value = serde_json::from_str(s)?;
55 Self::from_json(json)
56 }
57}
58
59pub fn flatten_schema(schema: &Value) -> Value {
61 let mut result = schema.clone();
62
63 let defs_opt = result
64 .as_object_mut()
65 .and_then(|obj| obj.remove("$defs").or_else(|| obj.remove("definitions")));
66
67 if let Some(defs) = defs_opt {
68 resolve_refs(&mut result, &defs);
69 }
70
71 result
72}
73
74fn resolve_refs(value: &mut Value, defs: &Value) {
75 match value {
76 Value::Object(obj) => {
77 if let Some(ref_val) = obj.remove("$ref") {
78 if let Some(ref_str) = ref_val.as_str() {
79 let ref_name = ref_str.split('/').last().unwrap_or("");
80 if let Some(def) = defs.get(ref_name) {
81 let mut resolved = def.clone();
82 resolve_refs(&mut resolved, defs);
83 *value = resolved;
84 return;
85 }
86 }
87 }
88
89 for (_, v) in obj.iter_mut() {
90 resolve_refs(v, defs);
91 }
92 }
93 Value::Array(arr) => {
94 for item in arr.iter_mut() {
95 resolve_refs(item, defs);
96 }
97 }
98 _ => {}
99 }
100}
101
102pub fn process_schema_for_optional_fields(schema: &mut Value, required_fields: &[String]) {
104 if let Some(obj) = schema.as_object_mut() {
105 if let Some(Value::Object(properties)) = obj.get_mut("properties") {
106 for (prop_name, prop_value) in properties.iter_mut() {
107 let is_required = required_fields.contains(prop_name);
108 process_property(prop_value, is_required);
109 }
110 }
111 }
112}
113
114fn process_property(prop: &mut Value, is_required: bool) {
115 if let Some(obj) = prop.as_object_mut() {
116 if let Some(any_of) = obj.remove("anyOf") {
117 if let Some(any_of_arr) = any_of.as_array() {
118 let mut null_type = false;
119 let mut non_null_type: Option<Value> = None;
120
121 for option in any_of_arr {
122 if option.get("type") == Some(&Value::String("null".to_string())) {
123 null_type = true;
124 } else {
125 non_null_type = Some(option.clone());
126 }
127 }
128
129 if null_type && non_null_type.is_some() {
130 let non_null = non_null_type.unwrap();
131 if let Some(non_null_obj) = non_null.as_object() {
132 for (k, v) in non_null_obj {
133 obj.insert(k.clone(), v.clone());
134 }
135 }
136
137 if let Some(type_val) = obj.get_mut("type") {
138 if let Some(type_str) = type_val.as_str() {
139 *type_val = Value::Array(vec![
140 Value::String(type_str.to_string()),
141 Value::String("null".to_string()),
142 ]);
143 }
144 } else {
145 obj.insert(
146 "type".to_string(),
147 Value::Array(vec![
148 Value::String("object".to_string()),
149 Value::String("null".to_string()),
150 ]),
151 );
152 }
153 }
154 }
155 } else if !is_required {
156 if let Some(type_val) = obj.get_mut("type") {
157 if let Some(type_str) = type_val.as_str() {
158 if type_str != "null" {
159 *type_val = Value::Array(vec![
160 Value::String(type_str.to_string()),
161 Value::String("null".to_string()),
162 ]);
163 }
164 }
165 }
166 }
167
168 let nested_required: Vec<String> = obj
169 .get("required")
170 .and_then(|r| r.as_array())
171 .map(|arr| {
172 arr.iter()
173 .filter_map(|v| v.as_str().map(|s| s.to_string()))
174 .collect()
175 })
176 .unwrap_or_default();
177
178 if let Some(Value::Object(nested_props)) = obj.get_mut("properties") {
179 for (prop_name, prop_value) in nested_props.iter_mut() {
180 let is_req = nested_required.contains(prop_name);
181 process_property(prop_value, is_req);
182 }
183 }
184 }
185}
186
187pub fn get_required_fields(schema: &Value) -> Vec<String> {
189 schema
190 .get("required")
191 .and_then(|r| r.as_array())
192 .map(|arr| {
193 arr.iter()
194 .filter_map(|v| v.as_str().map(|s| s.to_string()))
195 .collect()
196 })
197 .unwrap_or_default()
198}
199
200pub fn validate_against_schema(value: &Value, schema: &Value) -> Result<(), String> {
202 if let Some(schema_obj) = schema.as_object() {
203 if let Some(type_val) = schema_obj.get("type") {
204 let types: Vec<&str> = match type_val {
205 Value::String(s) => vec![s.as_str()],
206 Value::Array(arr) => arr.iter().filter_map(|v| v.as_str()).collect(),
207 _ => vec![],
208 };
209
210 let value_type = match value {
211 Value::Null => "null",
212 Value::Bool(_) => "boolean",
213 Value::Number(n) if n.is_i64() || n.is_u64() => "integer",
214 Value::Number(_) => "number",
215 Value::String(_) => "string",
216 Value::Array(_) => "array",
217 Value::Object(_) => "object",
218 };
219
220 let type_matches = types.iter().any(|t| {
221 *t == value_type || (*t == "number" && value_type == "integer")
222 });
223
224 if !type_matches && !types.is_empty() {
225 return Err(format!(
226 "Expected type {:?}, got {}",
227 types, value_type
228 ));
229 }
230 }
231
232 if let Some(Value::Object(properties)) = schema_obj.get("properties") {
233 if let Some(value_obj) = value.as_object() {
234 let required = get_required_fields(schema);
235
236 for req_field in &required {
237 if !value_obj.contains_key(req_field) {
238 return Err(format!("Missing required field: {}", req_field));
239 }
240 }
241
242 for (prop_name, prop_schema) in properties {
243 if let Some(prop_value) = value_obj.get(prop_name) {
244 validate_against_schema(prop_value, prop_schema)?;
245 }
246 }
247 }
248 }
249 }
250
251 Ok(())
252}
253
254pub struct StructuredOutputTool<T: JsonSchema + DeserializeOwned> {
256 spec: ToolSpec,
257 _phantom: std::marker::PhantomData<T>,
258}
259
260impl<T: JsonSchema + DeserializeOwned> StructuredOutputTool<T> {
261 pub fn new() -> Self {
263 let spec = structured_output_spec::<T>();
264 Self {
265 spec,
266 _phantom: std::marker::PhantomData,
267 }
268 }
269
270 pub fn with_name_description(name: &str, description: &str) -> Self {
272 let spec = schema_to_tool_spec::<T>(name, description);
273 Self {
274 spec,
275 _phantom: std::marker::PhantomData,
276 }
277 }
278
279 pub fn spec(&self) -> &ToolSpec {
281 &self.spec
282 }
283
284 pub fn parse(&self, input: &Value) -> Result<T, serde_json::Error> {
286 serde_json::from_value(input.clone())
287 }
288}
289
290impl<T: JsonSchema + DeserializeOwned> Default for StructuredOutputTool<T> {
291 fn default() -> Self {
292 Self::new()
293 }
294}
295
296pub struct StructuredOutputAgentTool {
298 spec: ToolSpec,
299}
300
301impl StructuredOutputAgentTool {
302 pub fn from_type<T: JsonSchema + DeserializeOwned>() -> Self {
304 Self {
305 spec: structured_output_spec::<T>(),
306 }
307 }
308
309 pub fn from_spec(spec: ToolSpec) -> Self {
311 Self { spec }
312 }
313}
314
315#[async_trait::async_trait]
316impl super::AgentTool for StructuredOutputAgentTool {
317 fn name(&self) -> &str {
318 &self.spec.name
319 }
320
321 fn description(&self) -> &str {
322 &self.spec.description
323 }
324
325 fn tool_spec(&self) -> ToolSpec {
326 self.spec.clone()
327 }
328
329 fn tool_type(&self) -> &str {
330 "structured_output"
331 }
332
333 async fn invoke(
334 &self,
335 input: Value,
336 _context: &super::ToolContext,
337 ) -> std::result::Result<super::ToolResult2, String> {
338
339 Ok(super::ToolResult2::success_json(input))
340 }
341}
342
343#[derive(Debug, Default, Clone)]
345pub struct StructuredOutputContext {
346 results: std::collections::HashMap<String, Value>,
348 expected_tool_name: Option<String>,
350 tool_spec: Option<ToolSpec>,
352 is_enabled: bool,
354 pub forced_mode: bool,
356 pub force_attempted: bool,
358 pub stop_loop: bool,
360}
361
362impl StructuredOutputContext {
363 pub fn new() -> Self {
365 Self::default()
366 }
367
368 pub fn with_type<T: JsonSchema + DeserializeOwned>() -> Self {
370 let spec = structured_output_spec::<T>();
371 let name = spec.name.clone();
372
373 Self {
374 results: std::collections::HashMap::new(),
375 expected_tool_name: Some(name),
376 tool_spec: Some(spec),
377 is_enabled: true,
378 forced_mode: false,
379 force_attempted: false,
380 stop_loop: false,
381 }
382 }
383
384 pub fn with_tool_name(name: impl Into<String>, spec: Option<ToolSpec>) -> Self {
386 Self {
387 results: std::collections::HashMap::new(),
388 expected_tool_name: Some(name.into()),
389 tool_spec: spec,
390 is_enabled: true,
391 forced_mode: false,
392 force_attempted: false,
393 stop_loop: false,
394 }
395 }
396
397 pub fn get_tool_spec(&self) -> Option<&ToolSpec> {
399 self.tool_spec.as_ref()
400 }
401
402 pub fn register_tool(&self, registry: &mut super::ToolRegistry) -> bool {
406 if let Some(ref spec) = self.tool_spec {
407 let tool = StructuredOutputAgentTool::from_spec(spec.clone());
408 if registry.register_dynamic(tool).is_ok() {
409 tracing::debug!("Registered structured output tool: {}", spec.name);
410 return true;
411 }
412 }
413 false
414 }
415
416 pub fn cleanup(&self, registry: &mut super::ToolRegistry) {
418 if let Some(ref name) = self.expected_tool_name {
419 if registry.remove_dynamic(name) {
420 tracing::debug!("Cleaned up structured output tool: {}", name);
421 }
422 }
423 }
424
425 pub fn is_enabled(&self) -> bool {
427 self.is_enabled
428 }
429
430 pub fn expected_tool_name(&self) -> Option<&str> {
432 self.expected_tool_name.as_deref()
433 }
434
435 pub fn store_result(&mut self, tool_use_id: &str, result: Value) {
437 self.results.insert(tool_use_id.to_string(), result);
438 }
439
440 pub fn get_result(&self, tool_use_id: &str) -> Option<&Value> {
442 self.results.get(tool_use_id)
443 }
444
445 pub fn set_forced_mode(&mut self) {
447 if !self.is_enabled {
448 return;
449 }
450 self.forced_mode = true;
451 self.force_attempted = true;
452 }
453
454 pub fn has_structured_output_tool(&self, tool_names: &[String]) -> bool {
456 if let Some(expected) = &self.expected_tool_name {
457 tool_names.iter().any(|name| name == expected)
458 } else {
459 false
460 }
461 }
462
463 pub fn extract_result(&mut self, tool_use_ids: &[String]) -> Option<Value> {
465 for id in tool_use_ids {
466 if let Some(result) = self.results.remove(id) {
467 return Some(result);
468 }
469 }
470 None
471 }
472}
473
#[cfg(test)]
mod tests {
    use super::*;
    use schemars::JsonSchema;
    use serde::{Deserialize, Serialize};

    /// Flat fixture type with two required fields.
    #[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
    struct TestOutput {
        name: String,
        count: i32,
    }

    /// Fixture that references another type and carries an optional field.
    #[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
    struct NestedOutput {
        inner: InnerType,
        optional_field: Option<String>,
    }

    /// Fixture referenced from [`NestedOutput`].
    #[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
    struct InnerType {
        value: String,
    }

    /// The generated spec keeps the given name and exposes object properties.
    #[test]
    fn test_schema_to_tool_spec() {
        let spec = schema_to_tool_spec::<TestOutput>("test_output", "A test output type");

        assert_eq!(spec.name, "test_output");
        assert!(spec.input_schema.json.get("properties").is_some());
    }

    /// A matching JSON document deserializes into the typed value.
    #[test]
    fn test_structured_output_result() {
        let document = serde_json::json!({
            "name": "test",
            "count": 42
        });

        let parsed = StructuredOutputResult::<TestOutput>::from_json(document).unwrap();

        assert_eq!(parsed.value.name, "test");
        assert_eq!(parsed.value.count, 42);
    }

    /// A `$ref` into `$defs` is replaced by the referenced definition.
    #[test]
    fn test_flatten_schema() {
        let with_refs = serde_json::json!({
            "type": "object",
            "properties": {
                "inner": { "$ref": "#/$defs/InnerType" }
            },
            "$defs": {
                "InnerType": {
                    "type": "object",
                    "properties": {
                        "value": { "type": "string" }
                    }
                }
            }
        });

        let flat = flatten_schema(&with_refs);
        let inner_schema = flat
            .get("properties")
            .and_then(|props| props.get("inner"))
            .unwrap();

        assert!(inner_schema.get("properties").is_some());
    }

    /// A complete object validates; one missing a required field does not.
    #[test]
    fn test_validate_against_schema() {
        let schema = serde_json::json!({
            "type": "object",
            "properties": {
                "name": { "type": "string" },
                "count": { "type": "integer" }
            },
            "required": ["name"]
        });

        let complete = serde_json::json!({
            "name": "test",
            "count": 42
        });
        assert!(validate_against_schema(&complete, &schema).is_ok());

        let missing_name = serde_json::json!({
            "count": 42
        });
        assert!(validate_against_schema(&missing_name, &schema).is_err());
    }

    /// The typed tool is named after its type and parses matching input.
    #[test]
    fn test_structured_output_tool() {
        let tool = StructuredOutputTool::<TestOutput>::new();
        assert!(tool.spec().name.contains("TestOutput"));

        let payload = serde_json::json!({
            "name": "test",
            "count": 42
        });

        let output = tool.parse(&payload).unwrap();
        assert_eq!(output.name, "test");
        assert_eq!(output.count, 42);
    }

    /// Nested types end up inlined in the generated tool schema.
    #[test]
    fn test_nested_type_flattening() {
        let spec = schema_to_tool_spec::<NestedOutput>("nested", "Nested output");

        let inner_schema = spec
            .input_schema
            .json
            .get("properties")
            .unwrap()
            .get("inner")
            .unwrap();

        assert!(inner_schema.get("properties").is_some());
    }
}