Skip to main content

swink_agent/agent/
structured_output.rs

1use serde_json::Value;
2use std::sync::Arc;
3use tokio_util::sync::CancellationToken;
4
5use crate::error::AgentError;
6use crate::tool::{AgentToolResult, ToolFuture};
7use crate::types::{AgentMessage, ContentBlock, LlmMessage};
8use crate::util::now_timestamp;
9
10use super::Agent;
11
12impl Agent {
13    /// Run a structured output extraction loop.
14    pub async fn structured_output(
15        &mut self,
16        prompt: String,
17        schema: Value,
18    ) -> Result<Value, AgentError> {
19        let tool = Arc::new(StructuredOutputTool {
20            schema: schema.clone(),
21        });
22
23        self.state.tools.push(tool);
24        let result = self
25            .run_structured_output_attempts(prompt, schema.clone())
26            .await;
27        self.remove_structured_output_tool();
28        result
29    }
30
31    /// Run a structured output extraction loop, blocking the current thread.
32    ///
33    /// # Errors
34    ///
35    /// Returns [`AgentError::SyncInAsyncContext`] if called from within a Tokio runtime.
36    pub fn structured_output_sync(
37        &mut self,
38        prompt: String,
39        schema: Value,
40    ) -> Result<Value, AgentError> {
41        let rt = super::invoke::new_blocking_runtime()?;
42        rt.block_on(self.structured_output(prompt, schema))
43    }
44
45    /// Run structured output extraction and deserialize into a typed result.
46    pub async fn structured_output_typed<T: serde::de::DeserializeOwned>(
47        &mut self,
48        prompt: String,
49        schema: Value,
50    ) -> Result<T, AgentError> {
51        let value = self.structured_output(prompt, schema).await?;
52        serde_json::from_value(value).map_err(|e| AgentError::StructuredOutputFailed {
53            attempts: 1,
54            last_error: format!("deserialization failed: {e}"),
55        })
56    }
57
58    /// Run structured output extraction, deserialize into a typed result, blocking.
59    ///
60    /// # Errors
61    ///
62    /// Returns [`AgentError::SyncInAsyncContext`] if called from within a Tokio runtime.
63    pub fn structured_output_typed_sync<T: serde::de::DeserializeOwned>(
64        &mut self,
65        prompt: String,
66        schema: Value,
67    ) -> Result<T, AgentError> {
68        let rt = super::invoke::new_blocking_runtime()?;
69        rt.block_on(self.structured_output_typed(prompt, schema))
70    }
71
72    async fn run_structured_output_attempts(
73        &mut self,
74        prompt: String,
75        schema: Value,
76    ) -> Result<Value, AgentError> {
77        let mut last_error = String::new();
78        let max_retries = self.structured_output_max_retries;
79
80        for attempt in 0..=max_retries {
81            let result = if attempt == 0 {
82                let user_msg = AgentMessage::Llm(LlmMessage::User(crate::types::UserMessage {
83                    content: vec![ContentBlock::Text {
84                        text: prompt.clone(),
85                    }],
86                    timestamp: now_timestamp(),
87                    cache_hint: None,
88                }));
89                self.prompt_async(vec![user_msg]).await?
90            } else {
91                self.continue_async().await?
92            };
93
94            match extract_structured_output(&result, &schema) {
95                Ok(value) => return Ok(value),
96                Err(e) => {
97                    last_error.clone_from(&e);
98                    if attempt < max_retries {
99                        let feedback = AgentMessage::Llm(LlmMessage::ToolResult(
100                            crate::types::ToolResultMessage {
101                                tool_call_id: find_structured_output_call_id(&result)
102                                    .unwrap_or_default(),
103                                content: vec![ContentBlock::Text {
104                                    text: format!(
105                                        "Validation failed: {e}. Please try again with valid \
106                                         output."
107                                    ),
108                                }],
109                                is_error: true,
110                                timestamp: now_timestamp(),
111                                details: serde_json::Value::Null,
112                                cache_hint: None,
113                            },
114                        ));
115                        self.state.messages.push(feedback);
116                    }
117                }
118            }
119        }
120
121        Err(AgentError::StructuredOutputFailed {
122            attempts: max_retries + 1,
123            last_error,
124        })
125    }
126
127    fn remove_structured_output_tool(&mut self) {
128        self.state
129            .tools
130            .retain(|t| t.name() != "__structured_output");
131    }
132}
133
134/// Synthetic tool used for structured output extraction.
135struct StructuredOutputTool {
136    schema: Value,
137}
138
139#[allow(clippy::unnecessary_literal_bound)]
140impl crate::tool::AgentTool for StructuredOutputTool {
141    fn name(&self) -> &str {
142        "__structured_output"
143    }
144
145    fn label(&self) -> &str {
146        "Structured Output"
147    }
148
149    fn description(&self) -> &str {
150        "Return structured data matching the required JSON schema. Call this tool with the \
151         requested data as the arguments."
152    }
153
154    fn parameters_schema(&self) -> &Value {
155        &self.schema
156    }
157
158    fn execute(
159        &self,
160        _tool_call_id: &str,
161        params: Value,
162        _cancellation_token: CancellationToken,
163        _on_update: Option<Box<dyn Fn(AgentToolResult) + Send + Sync>>,
164        _state: std::sync::Arc<std::sync::RwLock<crate::SessionState>>,
165        _credential: Option<crate::credential::ResolvedCredential>,
166    ) -> ToolFuture<'_> {
167        Box::pin(async move {
168            AgentToolResult::text(serde_json::to_string(&params).unwrap_or_default())
169        })
170    }
171}
172
173fn extract_structured_output(
174    result: &crate::types::AgentResult,
175    schema: &Value,
176) -> Result<Value, String> {
177    for msg in &result.messages {
178        if let AgentMessage::Llm(LlmMessage::Assistant(assistant)) = msg {
179            for block in &assistant.content {
180                if let ContentBlock::ToolCall {
181                    name, arguments, ..
182                } = block
183                    && name == "__structured_output"
184                {
185                    let validation = crate::tool::validate_tool_arguments(schema, arguments);
186                    match validation {
187                        Ok(()) => return Ok(arguments.clone()),
188                        Err(errors) => return Err(errors.join("; ")),
189                    }
190                }
191            }
192        }
193    }
194    Err("no __structured_output tool call found in response".to_string())
195}
196
197fn find_structured_output_call_id(result: &crate::types::AgentResult) -> Option<String> {
198    for msg in &result.messages {
199        if let AgentMessage::Llm(LlmMessage::Assistant(assistant)) = msg {
200            for block in &assistant.content {
201                if let ContentBlock::ToolCall { name, id, .. } = block
202                    && name == "__structured_output"
203                {
204                    return Some(id.clone());
205                }
206            }
207        }
208    }
209    None
210}