swink_agent/agent/
structured_output.rs1use serde_json::Value;
2use std::sync::Arc;
3use tokio_util::sync::CancellationToken;
4
5use crate::error::AgentError;
6use crate::tool::{AgentToolResult, ToolFuture};
7use crate::types::{AgentMessage, ContentBlock, LlmMessage};
8use crate::util::now_timestamp;
9
10use super::Agent;
11
12impl Agent {
13 pub async fn structured_output(
15 &mut self,
16 prompt: String,
17 schema: Value,
18 ) -> Result<Value, AgentError> {
19 let tool = Arc::new(StructuredOutputTool {
20 schema: schema.clone(),
21 });
22
23 self.state.tools.push(tool);
24 let result = self
25 .run_structured_output_attempts(prompt, schema.clone())
26 .await;
27 self.remove_structured_output_tool();
28 result
29 }
30
31 pub fn structured_output_sync(
37 &mut self,
38 prompt: String,
39 schema: Value,
40 ) -> Result<Value, AgentError> {
41 let rt = super::invoke::new_blocking_runtime()?;
42 rt.block_on(self.structured_output(prompt, schema))
43 }
44
45 pub async fn structured_output_typed<T: serde::de::DeserializeOwned>(
47 &mut self,
48 prompt: String,
49 schema: Value,
50 ) -> Result<T, AgentError> {
51 let value = self.structured_output(prompt, schema).await?;
52 serde_json::from_value(value).map_err(|e| AgentError::StructuredOutputFailed {
53 attempts: 1,
54 last_error: format!("deserialization failed: {e}"),
55 })
56 }
57
58 pub fn structured_output_typed_sync<T: serde::de::DeserializeOwned>(
64 &mut self,
65 prompt: String,
66 schema: Value,
67 ) -> Result<T, AgentError> {
68 let rt = super::invoke::new_blocking_runtime()?;
69 rt.block_on(self.structured_output_typed(prompt, schema))
70 }
71
72 async fn run_structured_output_attempts(
73 &mut self,
74 prompt: String,
75 schema: Value,
76 ) -> Result<Value, AgentError> {
77 let mut last_error = String::new();
78 let max_retries = self.structured_output_max_retries;
79
80 for attempt in 0..=max_retries {
81 let result = if attempt == 0 {
82 let user_msg = AgentMessage::Llm(LlmMessage::User(crate::types::UserMessage {
83 content: vec![ContentBlock::Text {
84 text: prompt.clone(),
85 }],
86 timestamp: now_timestamp(),
87 cache_hint: None,
88 }));
89 self.prompt_async(vec![user_msg]).await?
90 } else {
91 self.continue_async().await?
92 };
93
94 match extract_structured_output(&result, &schema) {
95 Ok(value) => return Ok(value),
96 Err(e) => {
97 last_error.clone_from(&e);
98 if attempt < max_retries {
99 let feedback = AgentMessage::Llm(LlmMessage::ToolResult(
100 crate::types::ToolResultMessage {
101 tool_call_id: find_structured_output_call_id(&result)
102 .unwrap_or_default(),
103 content: vec![ContentBlock::Text {
104 text: format!(
105 "Validation failed: {e}. Please try again with valid \
106 output."
107 ),
108 }],
109 is_error: true,
110 timestamp: now_timestamp(),
111 details: serde_json::Value::Null,
112 cache_hint: None,
113 },
114 ));
115 self.state.messages.push(feedback);
116 }
117 }
118 }
119 }
120
121 Err(AgentError::StructuredOutputFailed {
122 attempts: max_retries + 1,
123 last_error,
124 })
125 }
126
127 fn remove_structured_output_tool(&mut self) {
128 self.state
129 .tools
130 .retain(|t| t.name() != "__structured_output");
131 }
132}
133
134struct StructuredOutputTool {
136 schema: Value,
137}
138
139#[allow(clippy::unnecessary_literal_bound)]
140impl crate::tool::AgentTool for StructuredOutputTool {
141 fn name(&self) -> &str {
142 "__structured_output"
143 }
144
145 fn label(&self) -> &str {
146 "Structured Output"
147 }
148
149 fn description(&self) -> &str {
150 "Return structured data matching the required JSON schema. Call this tool with the \
151 requested data as the arguments."
152 }
153
154 fn parameters_schema(&self) -> &Value {
155 &self.schema
156 }
157
158 fn execute(
159 &self,
160 _tool_call_id: &str,
161 params: Value,
162 _cancellation_token: CancellationToken,
163 _on_update: Option<Box<dyn Fn(AgentToolResult) + Send + Sync>>,
164 _state: std::sync::Arc<std::sync::RwLock<crate::SessionState>>,
165 _credential: Option<crate::credential::ResolvedCredential>,
166 ) -> ToolFuture<'_> {
167 Box::pin(async move {
168 AgentToolResult::text(serde_json::to_string(¶ms).unwrap_or_default())
169 })
170 }
171}
172
173fn extract_structured_output(
174 result: &crate::types::AgentResult,
175 schema: &Value,
176) -> Result<Value, String> {
177 for msg in &result.messages {
178 if let AgentMessage::Llm(LlmMessage::Assistant(assistant)) = msg {
179 for block in &assistant.content {
180 if let ContentBlock::ToolCall {
181 name, arguments, ..
182 } = block
183 && name == "__structured_output"
184 {
185 let validation = crate::tool::validate_tool_arguments(schema, arguments);
186 match validation {
187 Ok(()) => return Ok(arguments.clone()),
188 Err(errors) => return Err(errors.join("; ")),
189 }
190 }
191 }
192 }
193 }
194 Err("no __structured_output tool call found in response".to_string())
195}
196
197fn find_structured_output_call_id(result: &crate::types::AgentResult) -> Option<String> {
198 for msg in &result.messages {
199 if let AgentMessage::Llm(LlmMessage::Assistant(assistant)) = msg {
200 for block in &assistant.content {
201 if let ContentBlock::ToolCall { name, id, .. } = block
202 && name == "__structured_output"
203 {
204 return Some(id.clone());
205 }
206 }
207 }
208 }
209 None
210}