1use async_trait::async_trait;
7use serde::de::DeserializeOwned;
8use serde_json::Value as JsonValue;
9use serdes_ai_tools::{ObjectJsonSchema, ToolDefinition};
10use std::marker::PhantomData;
11
12use crate::error::OutputParseError;
13use crate::mode::OutputMode;
14use crate::schema::OutputSchema;
15
/// Default name of the tool through which the final output is delivered.
pub const DEFAULT_OUTPUT_TOOL_NAME: &str = "final_result";

/// Default description attached to the output tool.
pub const DEFAULT_OUTPUT_TOOL_DESCRIPTION: &str =
    "The final response which ends this conversation";
21
22#[derive(Debug, Clone)]
47pub struct StructuredOutputSchema<T> {
48 pub tool_name: String,
50 pub tool_description: String,
52 pub schema: ObjectJsonSchema,
54 pub strict: Option<bool>,
56 mode: OutputMode,
58 _phantom: PhantomData<T>,
59}
60
61impl<T: DeserializeOwned + Send + Sync> StructuredOutputSchema<T> {
62 #[must_use]
64 pub fn new(schema: ObjectJsonSchema) -> Self {
65 Self {
66 tool_name: DEFAULT_OUTPUT_TOOL_NAME.to_string(),
67 tool_description: DEFAULT_OUTPUT_TOOL_DESCRIPTION.to_string(),
68 schema,
69 strict: None,
70 mode: OutputMode::Tool,
71 _phantom: PhantomData,
72 }
73 }
74
75 #[must_use]
77 pub fn with_tool_name(mut self, name: impl Into<String>) -> Self {
78 self.tool_name = name.into();
79 self
80 }
81
82 #[must_use]
84 pub fn with_description(mut self, desc: impl Into<String>) -> Self {
85 self.tool_description = desc.into();
86 self
87 }
88
89 #[must_use]
91 pub fn with_strict(mut self, strict: bool) -> Self {
92 self.strict = Some(strict);
93 self
94 }
95
96 #[must_use]
98 pub fn with_mode(mut self, mode: OutputMode) -> Self {
99 self.mode = mode;
100 self
101 }
102}
103
104#[async_trait]
105impl<T: DeserializeOwned + Send + Sync> OutputSchema<T> for StructuredOutputSchema<T> {
106 fn mode(&self) -> OutputMode {
107 self.mode
108 }
109
110 fn tool_definitions(&self) -> Vec<ToolDefinition> {
111 vec![ToolDefinition::new(&self.tool_name, &self.tool_description)
112 .with_parameters(self.schema.clone())
113 .with_strict(self.strict.unwrap_or(false))]
114 }
115
116 fn json_schema(&self) -> Option<ObjectJsonSchema> {
117 Some(self.schema.clone())
118 }
119
120 fn parse_text(&self, text: &str) -> Result<T, OutputParseError> {
121 let json_str = extract_json(text)?;
123 serde_json::from_str(&json_str).map_err(OutputParseError::JsonParse)
124 }
125
126 fn parse_tool_call(&self, name: &str, args: &JsonValue) -> Result<T, OutputParseError> {
127 if name != self.tool_name {
128 return Err(OutputParseError::unexpected_tool(&self.tool_name, name));
129 }
130 serde_json::from_value(args.clone()).map_err(OutputParseError::JsonParse)
131 }
132
133 fn parse_native(&self, value: &JsonValue) -> Result<T, OutputParseError> {
134 serde_json::from_value(value.clone()).map_err(OutputParseError::JsonParse)
135 }
136}
137
138pub fn extract_json(text: &str) -> Result<String, OutputParseError> {
146 let text = text.trim();
147
148 if let Some(rest) = text.strip_prefix("```json") {
150 if let Some(end) = rest.find("```") {
151 return Ok(rest[..end].trim().to_string());
152 }
153 }
154
155 if let Some(rest) = text.strip_prefix("```") {
157 let rest = if let Some(newline) = rest.find('\n') {
159 &rest[newline + 1..]
160 } else {
161 rest
162 };
163 if let Some(end) = rest.find("```") {
164 return Ok(rest[..end].trim().to_string());
165 }
166 }
167
168 if let Some(start) = text.find('{') {
170 if let Some(end) = text.rfind('}') {
171 if end > start {
172 let candidate = &text[start..=end];
173 if serde_json::from_str::<JsonValue>(candidate).is_ok() {
175 return Ok(candidate.to_string());
176 }
177 }
178 }
179 }
180
181 if let Some(start) = text.find('[') {
183 if let Some(end) = text.rfind(']') {
184 if end > start {
185 let candidate = &text[start..=end];
186 if serde_json::from_str::<JsonValue>(candidate).is_ok() {
188 return Ok(candidate.to_string());
189 }
190 }
191 }
192 }
193
194 if serde_json::from_str::<JsonValue>(text).is_ok() {
196 return Ok(text.to_string());
197 }
198
199 Err(OutputParseError::NoJsonFound)
200}
201
202#[derive(Debug, Clone, Default)]
204pub struct AnyJsonSchema {
205 tool_name: String,
206 tool_description: String,
207}
208
209impl AnyJsonSchema {
210 #[must_use]
212 pub fn new() -> Self {
213 Self {
214 tool_name: DEFAULT_OUTPUT_TOOL_NAME.to_string(),
215 tool_description: DEFAULT_OUTPUT_TOOL_DESCRIPTION.to_string(),
216 }
217 }
218
219 #[must_use]
221 pub fn with_tool_name(mut self, name: impl Into<String>) -> Self {
222 self.tool_name = name.into();
223 self
224 }
225}
226
227#[async_trait]
228impl OutputSchema<JsonValue> for AnyJsonSchema {
229 fn mode(&self) -> OutputMode {
230 OutputMode::Tool
231 }
232
233 fn tool_definitions(&self) -> Vec<ToolDefinition> {
234 vec![ToolDefinition::new(&self.tool_name, &self.tool_description)]
235 }
236
237 fn parse_text(&self, text: &str) -> Result<JsonValue, OutputParseError> {
238 let json_str = extract_json(text)?;
239 serde_json::from_str(&json_str).map_err(OutputParseError::JsonParse)
240 }
241
242 fn parse_tool_call(
243 &self,
244 _name: &str,
245 args: &JsonValue,
246 ) -> Result<JsonValue, OutputParseError> {
247 Ok(args.clone())
248 }
249
250 fn parse_native(&self, value: &JsonValue) -> Result<JsonValue, OutputParseError> {
251 Ok(value.clone())
252 }
253}
254
#[cfg(test)]
mod tests {
    use super::*;
    use serde::Deserialize;
    use serdes_ai_tools::PropertySchema;

    #[derive(Debug, Deserialize, PartialEq)]
    struct Person {
        name: String,
        age: u32,
    }

    /// Builds the `Person` a test expects to get back.
    fn person(name: &str, age: u32) -> Person {
        Person {
            name: name.to_string(),
            age,
        }
    }

    /// Object schema with a required string `name` and a required integer `age`.
    fn person_schema() -> ObjectJsonSchema {
        ObjectJsonSchema::new()
            .with_property("name", PropertySchema::string("Name").build(), true)
            .with_property("age", PropertySchema::integer("Age").build(), true)
    }

    /// `StructuredOutputSchema<Person>` with default settings.
    fn person_output() -> StructuredOutputSchema<Person> {
        StructuredOutputSchema::new(person_schema())
    }

    #[test]
    fn test_structured_schema_new() {
        let schema = person_output();
        assert_eq!(schema.tool_name, DEFAULT_OUTPUT_TOOL_NAME);
        assert_eq!(schema.mode(), OutputMode::Tool);
    }

    #[test]
    fn test_structured_schema_with_tool_name() {
        let schema = person_output().with_tool_name("submit_person");
        assert_eq!(schema.tool_name, "submit_person");
    }

    #[test]
    fn test_structured_schema_tool_definitions() {
        let defs = person_output()
            .with_tool_name("result")
            .with_description("Submit the person")
            .tool_definitions();

        assert_eq!(defs.len(), 1);
        let def = &defs[0];
        assert_eq!(def.name, "result");
        assert_eq!(def.description, "Submit the person");
    }

    #[test]
    fn test_structured_schema_parse_tool_call() {
        let args = serde_json::json!({"name": "Alice", "age": 30});
        let parsed = person_output().parse_tool_call("final_result", &args).unwrap();
        assert_eq!(parsed, person("Alice", 30));
    }

    #[test]
    fn test_structured_schema_parse_tool_call_wrong_name() {
        let args = serde_json::json!({"name": "Alice", "age": 30});
        assert!(person_output().parse_tool_call("wrong_tool", &args).is_err());
    }

    #[test]
    fn test_structured_schema_parse_native() {
        let value = serde_json::json!({"name": "Bob", "age": 25});
        let parsed = person_output().parse_native(&value).unwrap();
        assert_eq!(parsed, person("Bob", 25));
    }

    #[test]
    fn test_structured_schema_parse_text_raw_json() {
        let parsed = person_output()
            .parse_text(r#"{"name": "Charlie", "age": 35}"#)
            .unwrap();
        assert_eq!(parsed, person("Charlie", 35));
    }

    #[test]
    fn test_structured_schema_parse_text_markdown() {
        let text = "Here is the result:\n```json\n{\"name\": \"Diana\", \"age\": 28}\n```\nDone!";
        let parsed = person_output().parse_text(text).unwrap();
        assert_eq!(parsed, person("Diana", 28));
    }

    #[test]
    fn test_extract_json_code_block() {
        let extracted = extract_json("```json\n{\"key\": \"value\"}\n```").unwrap();
        assert_eq!(extracted, r#"{"key": "value"}"#);
    }

    #[test]
    fn test_extract_json_plain_code_block() {
        let extracted = extract_json("```\n{\"key\": \"value\"}\n```").unwrap();
        assert_eq!(extracted, r#"{"key": "value"}"#);
    }

    #[test]
    fn test_extract_json_embedded() {
        let extracted =
            extract_json(r#"The result is: {"x": 1, "y": 2} and that's it."#).unwrap();
        assert_eq!(extracted, r#"{"x": 1, "y": 2}"#);
    }

    #[test]
    fn test_extract_json_array() {
        let extracted = extract_json("Here are the items: [1, 2, 3]").unwrap();
        assert_eq!(extracted, "[1, 2, 3]");
    }

    #[test]
    fn test_extract_json_not_found() {
        assert!(extract_json("This is just plain text with no JSON.").is_err());
    }

    #[test]
    fn test_any_json_schema() {
        let value = serde_json::json!({"anything": [1, 2, 3]});
        assert_eq!(AnyJsonSchema::new().parse_native(&value).unwrap(), value);
    }
}