strands_agents/tools/
structured_output.rs

1//! Structured output tool for validating and returning typed responses.
2//!
3//! Provides utilities for converting Rust types to tool specifications
4//! and forcing the model to output structured JSON data.
5
6use schemars::JsonSchema;
7use serde::de::DeserializeOwned;
8use serde_json::Value;
9
10use crate::types::tools::ToolSpec;
11
12/// Converts a type implementing JsonSchema to a ToolSpec.
13pub fn schema_to_tool_spec<T: JsonSchema>(name: &str, description: &str) -> ToolSpec {
14    let schema = schemars::schema_for!(T);
15    let mut json_schema = serde_json::to_value(schema).unwrap_or_default();
16
17    json_schema = flatten_schema(&json_schema);
18
19    ToolSpec::new(name, description).with_input_schema(json_schema)
20}
21
22/// Creates a structured output tool spec from a type.
23pub fn structured_output_spec<T: JsonSchema>() -> ToolSpec {
24    let name = std::any::type_name::<T>()
25        .split("::")
26        .last()
27        .unwrap_or("StructuredOutput")
28        .to_string();
29
30    let description = "IMPORTANT: This StructuredOutputTool should only be invoked as the last and final tool \
31         before returning the completed result to the caller.".to_string();
32
33    schema_to_tool_spec::<T>(&name, &description)
34}
35
36/// Result of structured output parsing.
37#[derive(Debug)]
38pub struct StructuredOutputResult<T> {
39    /// The parsed value.
40    pub value: T,
41    /// The raw JSON value.
42    pub raw_json: Value,
43}
44
45impl<T: DeserializeOwned> StructuredOutputResult<T> {
46    /// Parses a structured output from JSON.
47    pub fn from_json(json: Value) -> Result<Self, serde_json::Error> {
48        let value: T = serde_json::from_value(json.clone())?;
49        Ok(Self { value, raw_json: json })
50    }
51
52    /// Parses a structured output from a JSON string.
53    pub fn from_str(s: &str) -> Result<Self, serde_json::Error> {
54        let json: Value = serde_json::from_str(s)?;
55        Self::from_json(json)
56    }
57}
58
59/// Flattens a JSON schema by resolving $ref references.
60pub fn flatten_schema(schema: &Value) -> Value {
61    let mut result = schema.clone();
62
63    let defs_opt = result
64        .as_object_mut()
65        .and_then(|obj| obj.remove("$defs").or_else(|| obj.remove("definitions")));
66
67    if let Some(defs) = defs_opt {
68        resolve_refs(&mut result, &defs);
69    }
70
71    result
72}
73
74fn resolve_refs(value: &mut Value, defs: &Value) {
75    match value {
76        Value::Object(obj) => {
77            if let Some(ref_val) = obj.remove("$ref") {
78                if let Some(ref_str) = ref_val.as_str() {
79                    let ref_name = ref_str.split('/').last().unwrap_or("");
80                    if let Some(def) = defs.get(ref_name) {
81                        let mut resolved = def.clone();
82                        resolve_refs(&mut resolved, defs);
83                        *value = resolved;
84                        return;
85                    }
86                }
87            }
88
89            for (_, v) in obj.iter_mut() {
90                resolve_refs(v, defs);
91            }
92        }
93        Value::Array(arr) => {
94            for item in arr.iter_mut() {
95                resolve_refs(item, defs);
96            }
97        }
98        _ => {}
99    }
100}
101
102/// Processes a schema to handle optional fields properly.
103pub fn process_schema_for_optional_fields(schema: &mut Value, required_fields: &[String]) {
104    if let Some(obj) = schema.as_object_mut() {
105        if let Some(Value::Object(properties)) = obj.get_mut("properties") {
106            for (prop_name, prop_value) in properties.iter_mut() {
107                let is_required = required_fields.contains(prop_name);
108                process_property(prop_value, is_required);
109            }
110        }
111    }
112}
113
114fn process_property(prop: &mut Value, is_required: bool) {
115    if let Some(obj) = prop.as_object_mut() {
116        if let Some(any_of) = obj.remove("anyOf") {
117            if let Some(any_of_arr) = any_of.as_array() {
118                let mut null_type = false;
119                let mut non_null_type: Option<Value> = None;
120
121                for option in any_of_arr {
122                    if option.get("type") == Some(&Value::String("null".to_string())) {
123                        null_type = true;
124                    } else {
125                        non_null_type = Some(option.clone());
126                    }
127                }
128
129                if null_type && non_null_type.is_some() {
130                    let non_null = non_null_type.unwrap();
131                    if let Some(non_null_obj) = non_null.as_object() {
132                        for (k, v) in non_null_obj {
133                            obj.insert(k.clone(), v.clone());
134                        }
135                    }
136
137                    if let Some(type_val) = obj.get_mut("type") {
138                        if let Some(type_str) = type_val.as_str() {
139                            *type_val = Value::Array(vec![
140                                Value::String(type_str.to_string()),
141                                Value::String("null".to_string()),
142                            ]);
143                        }
144                    } else {
145                        obj.insert(
146                            "type".to_string(),
147                            Value::Array(vec![
148                                Value::String("object".to_string()),
149                                Value::String("null".to_string()),
150                            ]),
151                        );
152                    }
153                }
154            }
155        } else if !is_required {
156            if let Some(type_val) = obj.get_mut("type") {
157                if let Some(type_str) = type_val.as_str() {
158                    if type_str != "null" {
159                        *type_val = Value::Array(vec![
160                            Value::String(type_str.to_string()),
161                            Value::String("null".to_string()),
162                        ]);
163                    }
164                }
165            }
166        }
167
168        let nested_required: Vec<String> = obj
169            .get("required")
170            .and_then(|r| r.as_array())
171            .map(|arr| {
172                arr.iter()
173                    .filter_map(|v| v.as_str().map(|s| s.to_string()))
174                    .collect()
175            })
176            .unwrap_or_default();
177
178        if let Some(Value::Object(nested_props)) = obj.get_mut("properties") {
179            for (prop_name, prop_value) in nested_props.iter_mut() {
180                let is_req = nested_required.contains(prop_name);
181                process_property(prop_value, is_req);
182            }
183        }
184    }
185}
186
187/// Extracts required fields from a schema.
188pub fn get_required_fields(schema: &Value) -> Vec<String> {
189    schema
190        .get("required")
191        .and_then(|r| r.as_array())
192        .map(|arr| {
193            arr.iter()
194                .filter_map(|v| v.as_str().map(|s| s.to_string()))
195                .collect()
196        })
197        .unwrap_or_default()
198}
199
200/// Validates that a value conforms to a schema (basic validation).
201pub fn validate_against_schema(value: &Value, schema: &Value) -> Result<(), String> {
202    if let Some(schema_obj) = schema.as_object() {
203        if let Some(type_val) = schema_obj.get("type") {
204            let types: Vec<&str> = match type_val {
205                Value::String(s) => vec![s.as_str()],
206                Value::Array(arr) => arr.iter().filter_map(|v| v.as_str()).collect(),
207                _ => vec![],
208            };
209
210            let value_type = match value {
211                Value::Null => "null",
212                Value::Bool(_) => "boolean",
213                Value::Number(n) if n.is_i64() || n.is_u64() => "integer",
214                Value::Number(_) => "number",
215                Value::String(_) => "string",
216                Value::Array(_) => "array",
217                Value::Object(_) => "object",
218            };
219
220            let type_matches = types.iter().any(|t| {
221                *t == value_type || (*t == "number" && value_type == "integer")
222            });
223
224            if !type_matches && !types.is_empty() {
225                return Err(format!(
226                    "Expected type {:?}, got {}",
227                    types, value_type
228                ));
229            }
230        }
231
232        if let Some(Value::Object(properties)) = schema_obj.get("properties") {
233            if let Some(value_obj) = value.as_object() {
234                let required = get_required_fields(schema);
235
236                for req_field in &required {
237                    if !value_obj.contains_key(req_field) {
238                        return Err(format!("Missing required field: {}", req_field));
239                    }
240                }
241
242                for (prop_name, prop_schema) in properties {
243                    if let Some(prop_value) = value_obj.get(prop_name) {
244                        validate_against_schema(prop_value, prop_schema)?;
245                    }
246                }
247            }
248        }
249    }
250
251    Ok(())
252}
253
254/// Creates a structured output tool that wraps a given output type.
255pub struct StructuredOutputTool<T: JsonSchema + DeserializeOwned> {
256    spec: ToolSpec,
257    _phantom: std::marker::PhantomData<T>,
258}
259
260impl<T: JsonSchema + DeserializeOwned> StructuredOutputTool<T> {
261    /// Creates a new structured output tool.
262    pub fn new() -> Self {
263        let spec = structured_output_spec::<T>();
264        Self {
265            spec,
266            _phantom: std::marker::PhantomData,
267        }
268    }
269
270    /// Creates a new structured output tool with a custom name and description.
271    pub fn with_name_description(name: &str, description: &str) -> Self {
272        let spec = schema_to_tool_spec::<T>(name, description);
273        Self {
274            spec,
275            _phantom: std::marker::PhantomData,
276        }
277    }
278
279    /// Returns the tool specification.
280    pub fn spec(&self) -> &ToolSpec {
281        &self.spec
282    }
283
284    /// Parses the tool input into the output type.
285    pub fn parse(&self, input: &Value) -> Result<T, serde_json::Error> {
286        serde_json::from_value(input.clone())
287    }
288}
289
290impl<T: JsonSchema + DeserializeOwned> Default for StructuredOutputTool<T> {
291    fn default() -> Self {
292        Self::new()
293    }
294}
295
296/// A type-erased structured output tool that can be registered with the registry.
297pub struct StructuredOutputAgentTool {
298    spec: ToolSpec,
299}
300
301impl StructuredOutputAgentTool {
302    /// Creates a new structured output agent tool from a type.
303    pub fn from_type<T: JsonSchema + DeserializeOwned>() -> Self {
304        Self {
305            spec: structured_output_spec::<T>(),
306        }
307    }
308
309    /// Creates a new structured output agent tool from a spec.
310    pub fn from_spec(spec: ToolSpec) -> Self {
311        Self { spec }
312    }
313}
314
315#[async_trait::async_trait]
316impl super::AgentTool for StructuredOutputAgentTool {
317    fn name(&self) -> &str {
318        &self.spec.name
319    }
320
321    fn description(&self) -> &str {
322        &self.spec.description
323    }
324
325    fn tool_spec(&self) -> ToolSpec {
326        self.spec.clone()
327    }
328
329    fn tool_type(&self) -> &str {
330        "structured_output"
331    }
332
333    async fn invoke(
334        &self,
335        input: Value,
336        _context: &super::ToolContext,
337    ) -> std::result::Result<super::ToolResult2, String> {
338
339        Ok(super::ToolResult2::success_json(input))
340    }
341}
342
343/// Per-invocation context for structured output execution.
344#[derive(Debug, Default, Clone)]
345pub struct StructuredOutputContext {
346    /// Stored results by tool use ID.
347    results: std::collections::HashMap<String, Value>,
348    /// Expected structured output tool name.
349    expected_tool_name: Option<String>,
350    /// The tool specification for registration.
351    tool_spec: Option<ToolSpec>,
352    /// Whether structured output is enabled.
353    is_enabled: bool,
354    /// Whether forced mode is active.
355    pub forced_mode: bool,
356    /// Whether force was attempted.
357    pub force_attempted: bool,
358    /// Whether to stop the event loop.
359    pub stop_loop: bool,
360}
361
362impl StructuredOutputContext {
363    /// Creates a new structured output context.
364    pub fn new() -> Self {
365        Self::default()
366    }
367
368    /// Creates a new structured output context with a specific output type.
369    pub fn with_type<T: JsonSchema + DeserializeOwned>() -> Self {
370        let spec = structured_output_spec::<T>();
371        let name = spec.name.clone();
372
373        Self {
374            results: std::collections::HashMap::new(),
375            expected_tool_name: Some(name),
376            tool_spec: Some(spec),
377            is_enabled: true,
378            forced_mode: false,
379            force_attempted: false,
380            stop_loop: false,
381        }
382    }
383
384    /// Creates a new structured output context with a specific tool name and spec.
385    pub fn with_tool_name(name: impl Into<String>, spec: Option<ToolSpec>) -> Self {
386        Self {
387            results: std::collections::HashMap::new(),
388            expected_tool_name: Some(name.into()),
389            tool_spec: spec,
390            is_enabled: true,
391            forced_mode: false,
392            force_attempted: false,
393            stop_loop: false,
394        }
395    }
396
397    /// Returns the tool specification if available.
398    pub fn get_tool_spec(&self) -> Option<&ToolSpec> {
399        self.tool_spec.as_ref()
400    }
401
402    /// Registers the structured output tool with the given registry.
403    ///
404    /// Returns true if a tool was registered, false otherwise.
405    pub fn register_tool(&self, registry: &mut super::ToolRegistry) -> bool {
406        if let Some(ref spec) = self.tool_spec {
407            let tool = StructuredOutputAgentTool::from_spec(spec.clone());
408            if registry.register_dynamic(tool).is_ok() {
409                tracing::debug!("Registered structured output tool: {}", spec.name);
410                return true;
411            }
412        }
413        false
414    }
415
416    /// Removes the structured output tool from the given registry.
417    pub fn cleanup(&self, registry: &mut super::ToolRegistry) {
418        if let Some(ref name) = self.expected_tool_name {
419            if registry.remove_dynamic(name) {
420                tracing::debug!("Cleaned up structured output tool: {}", name);
421            }
422        }
423    }
424
425    /// Check if structured output is enabled for this context.
426    pub fn is_enabled(&self) -> bool {
427        self.is_enabled
428    }
429
430    /// Get the expected tool name.
431    pub fn expected_tool_name(&self) -> Option<&str> {
432        self.expected_tool_name.as_deref()
433    }
434
435    /// Store a validated structured output result.
436    pub fn store_result(&mut self, tool_use_id: &str, result: Value) {
437        self.results.insert(tool_use_id.to_string(), result);
438    }
439
440    /// Retrieve a stored structured output result.
441    pub fn get_result(&self, tool_use_id: &str) -> Option<&Value> {
442        self.results.get(tool_use_id)
443    }
444
445    /// Mark this context as being in forced structured output mode.
446    pub fn set_forced_mode(&mut self) {
447        if !self.is_enabled {
448            return;
449        }
450        self.forced_mode = true;
451        self.force_attempted = true;
452    }
453
454    /// Check if any tool uses are for the structured output tool.
455    pub fn has_structured_output_tool(&self, tool_names: &[String]) -> bool {
456        if let Some(expected) = &self.expected_tool_name {
457            tool_names.iter().any(|name| name == expected)
458        } else {
459            false
460        }
461    }
462
463    /// Extract and remove structured output result from stored results.
464    pub fn extract_result(&mut self, tool_use_ids: &[String]) -> Option<Value> {
465        for id in tool_use_ids {
466            if let Some(result) = self.results.remove(id) {
467                return Some(result);
468            }
469        }
470        None
471    }
472}
473
#[cfg(test)]
mod tests {
    use super::*;
    use schemars::JsonSchema;
    use serde::{Deserialize, Serialize};
    use serde_json::json;

    /// Flat output type used by most tests.
    #[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
    struct TestOutput {
        name: String,
        count: i32,
    }

    /// Output type with a nested struct and an optional field, exercising `$ref` inlining.
    #[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
    struct NestedOutput {
        inner: InnerType,
        optional_field: Option<String>,
    }

    #[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
    struct InnerType {
        value: String,
    }

    /// A payload that deserializes into `TestOutput`.
    fn sample_payload() -> serde_json::Value {
        json!({ "name": "test", "count": 42 })
    }

    #[test]
    fn test_schema_to_tool_spec() {
        let spec = schema_to_tool_spec::<TestOutput>("test_output", "A test output type");

        assert_eq!(spec.name, "test_output");
        assert!(spec.input_schema.json.get("properties").is_some());
    }

    #[test]
    fn test_structured_output_result() {
        let parsed = StructuredOutputResult::<TestOutput>::from_json(sample_payload())
            .expect("payload matches TestOutput");

        assert_eq!(parsed.value.name, "test");
        assert_eq!(parsed.value.count, 42);
    }

    #[test]
    fn test_flatten_schema() {
        let schema = json!({
            "type": "object",
            "properties": { "inner": { "$ref": "#/$defs/InnerType" } },
            "$defs": {
                "InnerType": {
                    "type": "object",
                    "properties": { "value": { "type": "string" } }
                }
            }
        });

        let flattened = flatten_schema(&schema);

        assert!(flattened["properties"]["inner"].get("properties").is_some());
    }

    #[test]
    fn test_validate_against_schema() {
        let schema = json!({
            "type": "object",
            "properties": {
                "name": { "type": "string" },
                "count": { "type": "integer" }
            },
            "required": ["name"]
        });

        assert!(validate_against_schema(&sample_payload(), &schema).is_ok());
        assert!(validate_against_schema(&json!({ "count": 42 }), &schema).is_err());
    }

    #[test]
    fn test_structured_output_tool() {
        let tool = StructuredOutputTool::<TestOutput>::new();
        assert!(tool.spec().name.contains("TestOutput"));

        let parsed = tool.parse(&sample_payload()).expect("payload matches TestOutput");
        assert_eq!(parsed.name, "test");
        assert_eq!(parsed.count, 42);
    }

    #[test]
    fn test_nested_type_flattening() {
        let spec = schema_to_tool_spec::<NestedOutput>("nested", "Nested output");
        let inner_prop = &spec.input_schema.json["properties"]["inner"];

        assert!(inner_prop.get("properties").is_some());
    }
}