serdes_ai_output/
schema.rs1use async_trait::async_trait;
7use serde_json::Value as JsonValue;
8use serdes_ai_tools::{ObjectJsonSchema, ToolDefinition};
9
10use crate::error::OutputParseError;
11use crate::mode::OutputMode;
12
13#[async_trait]
22pub trait OutputSchema<T: Send>: Send + Sync {
23 fn mode(&self) -> OutputMode;
25
26 fn tool_definitions(&self) -> Vec<ToolDefinition> {
30 vec![]
31 }
32
33 fn json_schema(&self) -> Option<ObjectJsonSchema> {
37 None
38 }
39
40 fn supports_mode(&self, mode: OutputMode) -> bool {
42 match mode {
43 OutputMode::Text => true, OutputMode::Tool => !self.tool_definitions().is_empty(),
45 OutputMode::Native | OutputMode::Prompted => self.json_schema().is_some(),
46 }
47 }
48
49 fn parse_text(&self, text: &str) -> Result<T, OutputParseError>;
51
52 fn parse_tool_call(&self, name: &str, args: &JsonValue) -> Result<T, OutputParseError>;
54
55 fn parse_native(&self, value: &JsonValue) -> Result<T, OutputParseError>;
57
58 fn parse(
60 &self,
61 mode: OutputMode,
62 text: Option<&str>,
63 tool_name: Option<&str>,
64 args: Option<&JsonValue>,
65 ) -> Result<T, OutputParseError> {
66 match mode {
67 OutputMode::Text => {
68 let text = text.ok_or_else(|| OutputParseError::custom("No text output"))?;
69 self.parse_text(text)
70 }
71 OutputMode::Tool => {
72 let name = tool_name.ok_or_else(|| OutputParseError::custom("No tool call"))?;
73 let args = args.ok_or_else(|| OutputParseError::custom("No tool arguments"))?;
74 self.parse_tool_call(name, args)
75 }
76 OutputMode::Native | OutputMode::Prompted => {
77 if let (Some(name), Some(args)) = (tool_name, args) {
79 return self.parse_tool_call(name, args);
80 }
81 if let Some(args) = args {
82 return self.parse_native(args);
83 }
84 if let Some(text) = text {
85 return self.parse_text(text);
86 }
87 Err(OutputParseError::custom("No output to parse"))
88 }
89 }
90 }
91}
92
93pub type BoxedOutputSchema<T> = Box<dyn OutputSchema<T>>;
95
/// Pairs an inner value `S` with a phantom output type `T`.
///
/// No `T` is ever stored; `T` only tags the wrapper at the type level. The
/// marker is `PhantomData<fn() -> T>` rather than `PhantomData<T>` so that the
/// wrapper's `Send`/`Sync` depend on `S` alone. [`OutputSchema`] requires its
/// implementors to be `Send + Sync` but only requires `T: Send`. A plain
/// `PhantomData<T>` would make the wrapper `Sync` only when `T: Sync`, which is
/// stricter than the trait. The `fn() -> T` form is still covariant in `T`, and
/// it does not make drop-check treat the wrapper as owning a `T`.
// NOTE(review): any `OutputSchema` impl for this wrapper lives outside this
// section; confirm it does not rely on `T`'s auto traits.
pub struct OutputSchemaWrapper<S, T> {
    inner: S,
    _phantom: std::marker::PhantomData<fn() -> T>,
}

// Hand-written instead of `#[derive(Debug)]`: the derive would add a
// `T: Debug` bound even though no `T` value exists to format.
impl<S: std::fmt::Debug, T> std::fmt::Debug for OutputSchemaWrapper<S, T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("OutputSchemaWrapper")
            .field("inner", &self.inner)
            .finish()
    }
}
102
103impl<S, T> OutputSchemaWrapper<S, T> {
104 pub fn new(inner: S) -> Self {
106 Self {
107 inner,
108 _phantom: std::marker::PhantomData,
109 }
110 }
111
112 pub fn inner(&self) -> &S {
114 &self.inner
115 }
116}
117
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal schema used to exercise the trait's default methods.
    ///
    /// `parse_tool_call` and `parse_native` tag their results
    /// (`tool:<name>:<value>` / `native:<value>`) so the dispatch tests can
    /// tell which parser `OutputSchema::parse` selected.
    struct MockSchema;

    #[async_trait]
    impl OutputSchema<String> for MockSchema {
        fn mode(&self) -> OutputMode {
            OutputMode::Text
        }

        fn parse_text(&self, text: &str) -> Result<String, OutputParseError> {
            Ok(text.to_string())
        }

        fn parse_tool_call(
            &self,
            name: &str,
            args: &JsonValue,
        ) -> Result<String, OutputParseError> {
            args.as_str()
                .map(|s| format!("tool:{name}:{s}"))
                .ok_or(OutputParseError::NotJson)
        }

        fn parse_native(&self, value: &JsonValue) -> Result<String, OutputParseError> {
            value
                .as_str()
                .map(|s| format!("native:{s}"))
                .ok_or(OutputParseError::NotJson)
        }
    }

    /// Builds a JSON string value.
    fn json_str(s: &str) -> JsonValue {
        JsonValue::String(s.to_string())
    }

    #[test]
    fn test_mock_schema_parse_text() {
        let schema = MockSchema;
        let result = schema.parse_text("hello").unwrap();
        assert_eq!(result, "hello");
    }

    #[test]
    fn test_mock_schema_supports_mode() {
        let schema = MockSchema;
        assert!(schema.supports_mode(OutputMode::Text));
        assert!(!schema.supports_mode(OutputMode::Tool));
        assert!(!schema.supports_mode(OutputMode::Native));
        assert!(!schema.supports_mode(OutputMode::Prompted));
    }

    #[test]
    fn test_parse_dispatch() {
        let schema = MockSchema;

        let result = schema
            .parse(OutputMode::Text, Some("hello"), None, None)
            .unwrap();
        assert_eq!(result, "hello");

        let result = schema.parse(OutputMode::Text, None, None, None);
        assert!(result.is_err());
    }

    #[test]
    fn test_parse_tool_mode() {
        let schema = MockSchema;
        let args = json_str("x");

        let result = schema
            .parse(OutputMode::Tool, None, Some("final"), Some(&args))
            .unwrap();
        assert_eq!(result, "tool:final:x");

        // Tool mode never falls back to text.
        assert!(schema
            .parse(OutputMode::Tool, Some("text"), None, Some(&args))
            .is_err());
        assert!(schema
            .parse(OutputMode::Tool, None, Some("final"), None)
            .is_err());
        // Errors from the concrete parser are propagated.
        assert!(schema
            .parse(OutputMode::Tool, None, Some("final"), Some(&JsonValue::Null))
            .is_err());
    }

    /// Checks the Native/Prompted precedence: tool call, then native, then text.
    fn check_structured_precedence(mode: fn() -> OutputMode) {
        let schema = MockSchema;
        let args = json_str("x");

        // Tool name + args: treated as a tool call, text ignored.
        let result = schema
            .parse(mode(), Some("t"), Some("final"), Some(&args))
            .unwrap();
        assert_eq!(result, "tool:final:x");

        // Args without a tool name: parsed natively, text ignored.
        let result = schema.parse(mode(), Some("t"), None, Some(&args)).unwrap();
        assert_eq!(result, "native:x");

        // No args: falls back to text even if a tool name is present.
        let result = schema.parse(mode(), Some("t"), Some("final"), None).unwrap();
        assert_eq!(result, "t");

        // Nothing usable.
        assert!(schema.parse(mode(), None, Some("final"), None).is_err());
        assert!(schema.parse(mode(), None, None, None).is_err());
    }

    #[test]
    fn test_parse_native_precedence() {
        check_structured_precedence(|| OutputMode::Native);
    }

    #[test]
    fn test_parse_prompted_precedence() {
        check_structured_precedence(|| OutputMode::Prompted);
    }
}