synaptic_models/
structured_output.rs1use std::marker::PhantomData;
2use std::sync::Arc;
3
4use async_trait::async_trait;
5use serde::de::DeserializeOwned;
6use synaptic_core::{ChatModel, ChatRequest, ChatResponse, ChatStream, Message, SynapseError};
7
/// Wrapper around another [`ChatModel`] that carries a schema description and a
/// target type `T` into which response text can be deserialized.
pub struct StructuredOutputChatModel<T> {
    /// The wrapped model; `chat` and `stream_chat` calls are forwarded to it.
    inner: Arc<dyn ChatModel>,
    /// Schema text embedded in the system instruction that `chat` prepends to a request.
    schema_description: String,
    /// Associates the wrapper with `T` without storing a value of that type.
    _marker: PhantomData<T>,
}
17
18impl<T: DeserializeOwned + Send + Sync + 'static> StructuredOutputChatModel<T> {
19 pub fn new(inner: Arc<dyn ChatModel>, schema_description: impl Into<String>) -> Self {
24 Self {
25 inner,
26 schema_description: schema_description.into(),
27 _marker: PhantomData,
28 }
29 }
30
31 pub fn parse_response(&self, response: &ChatResponse) -> Result<T, SynapseError> {
33 let text = response.message.content();
34 let json_str = extract_json(text);
36 serde_json::from_str::<T>(json_str)
37 .map_err(|e| SynapseError::Parsing(format!("failed to parse structured output: {e}")))
38 }
39
40 pub async fn generate(&self, request: ChatRequest) -> Result<(T, ChatResponse), SynapseError> {
42 let response = self.chat(request).await?;
43 let parsed = self.parse_response(&response)?;
44 Ok((parsed, response))
45 }
46}
47
/// Extracts the JSON payload from model output that may be wrapped in a
/// markdown code fence.
///
/// Resolution order:
/// 1. The input is trimmed.
/// 2. If the trimmed text already looks like a bare JSON object or array
///    (`{…}` / `[…]`), it is returned untouched. Fences that appear inside JSON
///    string values must not be mistaken for a wrapper around the payload.
/// 3. Otherwise the first closed fence opened with "```json" is preferred,
///    then the first closed fence of any kind. A leading language tag line
///    (for example `JSON` or `javascript`) is dropped from the fenced body.
/// 4. If no closed fence exists, the trimmed text is returned as is.
///
/// The returned slice always borrows from `text`; nothing is allocated.
fn extract_json(text: &str) -> &str {
    const FENCE: &str = "```";
    const JSON_FENCE: &str = "```json";

    let trimmed = text.trim();

    if looks_like_bare_json(trimmed) {
        return trimmed;
    }

    // Prefer an explicitly json-tagged fence over an arbitrary one.
    for opener in [JSON_FENCE, FENCE] {
        if let Some(body) = fenced_body(trimmed, opener, FENCE) {
            return strip_language_tag(body);
        }
    }

    trimmed
}

/// Returns `true` when `text` starts and ends like a JSON object or array.
fn looks_like_bare_json(text: &str) -> bool {
    (text.starts_with('{') && text.ends_with('}'))
        || (text.starts_with('[') && text.ends_with(']'))
}

/// Returns the text between the first `opener` and the next `closer`, if both exist.
fn fenced_body<'a>(text: &'a str, opener: &str, closer: &str) -> Option<&'a str> {
    let body_start = text.find(opener)? + opener.len();
    let body_len = text[body_start..].find(closer)?;
    Some(&text[body_start..body_start + body_len])
}

/// Trims a fenced body and drops a leading language tag line when one is present.
///
/// The first line counts as a tag only if it is a single identifier-like token
/// and more content follows it. Such a body could never be valid JSON as a
/// whole, so stripping the tag cannot damage a payload that would have parsed.
fn strip_language_tag(body: &str) -> &str {
    if let Some((first_line, rest)) = body.split_once('\n') {
        let tag = first_line.trim();
        let rest = rest.trim();
        let is_tag = !tag.is_empty()
            && tag
                .chars()
                .all(|c| c.is_ascii_alphanumeric() || matches!(c, '_' | '-' | '+' | '.' | '#'));
        if is_tag && !rest.is_empty() {
            return rest;
        }
    }
    body.trim()
}
67
68#[async_trait]
69impl<T: DeserializeOwned + Send + Sync + 'static> ChatModel for StructuredOutputChatModel<T> {
70 async fn chat(&self, mut request: ChatRequest) -> Result<ChatResponse, SynapseError> {
71 let instruction = format!(
73 "You MUST respond with valid JSON matching this schema:\n{}\n\nDo not include any text outside the JSON object. Do not use markdown code blocks.",
74 self.schema_description
75 );
76
77 request.messages.insert(0, Message::system(instruction));
79
80 self.inner.chat(request).await
81 }
82
83 fn stream_chat(&self, request: ChatRequest) -> ChatStream<'_> {
84 self.inner.stream_chat(request)
86 }
87}
88
#[cfg(test)]
mod tests {
    use super::*;

    /// The payload every fixture below is expected to reduce to.
    const EXPECTED: &str = r#"{"a": 1}"#;

    /// Plain JSON passes through unchanged.
    #[test]
    fn extract_json_plain() {
        let extracted = extract_json(EXPECTED);
        assert_eq!(extracted, EXPECTED);
    }

    /// A fence tagged `json` is unwrapped.
    #[test]
    fn extract_json_code_block() {
        let fenced = "```json\n{\"a\": 1}\n```";
        let extracted = extract_json(fenced);
        assert_eq!(extracted, EXPECTED);
    }

    /// A fence without a language tag is unwrapped.
    #[test]
    fn extract_json_plain_code_block() {
        let fenced = "```\n{\"a\": 1}\n```";
        let extracted = extract_json(fenced);
        assert_eq!(extracted, EXPECTED);
    }

    /// Leading and trailing whitespace is trimmed away.
    #[test]
    fn extract_json_with_surrounding_whitespace() {
        let padded = "  {\"a\": 1}  ";
        let extracted = extract_json(padded);
        assert_eq!(extracted, EXPECTED);
    }
}