1use serde::de::DeserializeOwned;
7use serdes_ai_tools::ObjectJsonSchema;
8use std::marker::PhantomData;
9
10use crate::mode::OutputMode;
11use crate::schema::{BoxedOutputSchema, OutputSchema};
12use crate::structured::StructuredOutputSchema;
13use crate::text::TextOutputSchema;
14
15pub enum OutputSpec<T> {
20 Text(TextOutputSchema),
22 Structured(StructuredOutputSchema<T>),
24 Custom(BoxedOutputSchema<T>),
26}
27
28impl<T> std::fmt::Debug for OutputSpec<T> {
29 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
30 match self {
31 OutputSpec::Text(s) => f.debug_tuple("Text").field(s).finish(),
32 OutputSpec::Structured(_) => f.debug_tuple("Structured").field(&"...").finish(),
33 OutputSpec::Custom(_) => f
34 .debug_tuple("Custom")
35 .field(&"<dyn OutputSchema>")
36 .finish(),
37 }
38 }
39}
40
41impl OutputSpec<String> {
42 #[must_use]
44 pub fn text() -> Self {
45 OutputSpec::Text(TextOutputSchema::new())
46 }
47
48 #[must_use]
50 pub fn text_with_schema(schema: TextOutputSchema) -> Self {
51 OutputSpec::Text(schema)
52 }
53}
54
55impl<T: DeserializeOwned + Send + Sync + 'static> OutputSpec<T> {
56 #[must_use]
58 pub fn structured(schema: ObjectJsonSchema) -> Self {
59 OutputSpec::Structured(StructuredOutputSchema::new(schema))
60 }
61
62 #[must_use]
64 pub fn structured_with(schema: StructuredOutputSchema<T>) -> Self {
65 OutputSpec::Structured(schema)
66 }
67
68 pub fn custom<S: OutputSchema<T> + 'static>(schema: S) -> Self {
70 OutputSpec::Custom(Box::new(schema))
71 }
72
73 #[must_use]
75 pub fn mode(&self) -> OutputMode {
76 match self {
77 OutputSpec::Text(s) => s.mode(),
78 OutputSpec::Structured(s) => s.mode(),
79 OutputSpec::Custom(s) => s.mode(),
80 }
81 }
82
83 #[must_use]
85 pub fn tool_definitions(&self) -> Vec<serdes_ai_tools::ToolDefinition> {
86 match self {
87 OutputSpec::Text(s) => s.tool_definitions(),
88 OutputSpec::Structured(s) => s.tool_definitions(),
89 OutputSpec::Custom(s) => s.tool_definitions(),
90 }
91 }
92
93 #[must_use]
95 pub fn json_schema(&self) -> Option<ObjectJsonSchema> {
96 match self {
97 OutputSpec::Text(s) => s.json_schema(),
98 OutputSpec::Structured(s) => s.json_schema(),
99 OutputSpec::Custom(s) => s.json_schema(),
100 }
101 }
102}
103
104impl Default for OutputSpec<String> {
105 fn default() -> Self {
106 OutputSpec::text()
107 }
108}
109
/// Fluent entry point for constructing an [`OutputSpec`].
///
/// The builder stores no runtime data. Its type parameter `T` only decides
/// which finishing methods are available:
/// - `text*` methods when `T` is `String`;
/// - `structured*` methods when `T: DeserializeOwned + Send + Sync + 'static`.
pub struct OutputSpecBuilder<T> {
    _phantom: PhantomData<T>,
}

// Hand-written instead of `#[derive(Debug)]`: the derive would require
// `T: Debug` even though no `T` value is ever stored. The output matches what
// the derive would produce.
impl<T> std::fmt::Debug for OutputSpecBuilder<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("OutputSpecBuilder")
            .field("_phantom", &self._phantom)
            .finish()
    }
}

impl<T> OutputSpecBuilder<T> {
    /// Creates a new builder.
    #[must_use]
    pub fn new() -> Self {
        Self {
            _phantom: PhantomData,
        }
    }
}

impl<T> Default for OutputSpecBuilder<T> {
    fn default() -> Self {
        Self::new()
    }
}
131
132impl OutputSpecBuilder<String> {
133 #[must_use]
135 pub fn text(self) -> OutputSpec<String> {
136 OutputSpec::text()
137 }
138
139 #[must_use]
141 pub fn text_constrained(
142 self,
143 min_length: Option<usize>,
144 max_length: Option<usize>,
145 ) -> OutputSpec<String> {
146 let mut schema = TextOutputSchema::new();
147 if let Some(min) = min_length {
148 schema = schema.with_min_length(min);
149 }
150 if let Some(max) = max_length {
151 schema = schema.with_max_length(max);
152 }
153 OutputSpec::Text(schema)
154 }
155}
156
157impl<T: DeserializeOwned + Send + Sync + 'static> OutputSpecBuilder<T> {
158 #[must_use]
160 pub fn structured(self, schema: ObjectJsonSchema) -> OutputSpec<T> {
161 OutputSpec::structured(schema)
162 }
163
164 #[must_use]
166 pub fn structured_with_tool(
167 self,
168 schema: ObjectJsonSchema,
169 tool_name: impl Into<String>,
170 ) -> OutputSpec<T> {
171 OutputSpec::Structured(StructuredOutputSchema::new(schema).with_tool_name(tool_name))
172 }
173}
174
175pub trait IntoOutputSpec<T> {
177 fn into_output_spec(self) -> OutputSpec<T>;
179}
180
181impl<T> IntoOutputSpec<T> for OutputSpec<T> {
182 fn into_output_spec(self) -> OutputSpec<T> {
183 self
184 }
185}
186
187impl IntoOutputSpec<String> for TextOutputSchema {
188 fn into_output_spec(self) -> OutputSpec<String> {
189 OutputSpec::Text(self)
190 }
191}
192
193impl<T: DeserializeOwned + Send + Sync + 'static> IntoOutputSpec<T> for StructuredOutputSchema<T> {
194 fn into_output_spec(self) -> OutputSpec<T> {
195 OutputSpec::Structured(self)
196 }
197}
198
#[cfg(test)]
mod tests {
    use super::*;
    use serde::Deserialize;
    use serdes_ai_tools::PropertySchema;

    /// Minimal target type for the structured-output tests.
    #[derive(Debug, Deserialize)]
    struct TestStruct {
        // Never read directly; the field only exists so the type has a shape
        // matching `name_schema()`.
        #[allow(dead_code)]
        name: String,
    }

    /// Schema with a single required string property `name`, shared by the
    /// structured-output tests.
    fn name_schema() -> ObjectJsonSchema {
        let name_prop = PropertySchema::string("Name").build();
        ObjectJsonSchema::new().with_property("name", name_prop, true)
    }

    #[test]
    fn test_output_spec_text() {
        let text_spec = OutputSpec::<String>::text();
        assert_eq!(text_spec.mode(), OutputMode::Text);
        assert!(text_spec.tool_definitions().is_empty());
    }

    #[test]
    fn test_output_spec_structured() {
        let structured_spec = OutputSpec::<TestStruct>::structured(name_schema());
        assert_eq!(structured_spec.mode(), OutputMode::Tool);
        assert_eq!(structured_spec.tool_definitions().len(), 1);
    }

    #[test]
    fn test_output_spec_default() {
        let default_spec: OutputSpec<String> = Default::default();
        assert_eq!(default_spec.mode(), OutputMode::Text);
    }

    #[test]
    fn test_builder_text() {
        let built = OutputSpecBuilder::<String>::new().text();
        assert_eq!(built.mode(), OutputMode::Text);
    }

    #[test]
    fn test_builder_text_constrained() {
        let built = OutputSpecBuilder::<String>::new().text_constrained(Some(10), Some(100));
        assert_eq!(built.mode(), OutputMode::Text);
    }

    #[test]
    fn test_builder_structured() {
        let built = OutputSpecBuilder::<TestStruct>::new().structured(name_schema());
        assert_eq!(built.mode(), OutputMode::Tool);
    }

    #[test]
    fn test_into_output_spec() {
        let converted: OutputSpec<String> = TextOutputSchema::new().into_output_spec();
        assert_eq!(converted.mode(), OutputMode::Text);
    }
}