Skip to main content

synaptic_models/
structured_output.rs

1use std::marker::PhantomData;
2use std::sync::Arc;
3
4use async_trait::async_trait;
5use serde::de::DeserializeOwned;
6use synaptic_core::{ChatModel, ChatRequest, ChatResponse, ChatStream, Message, SynapticError};
7
8/// Wraps a ChatModel to produce structured JSON output.
9///
10/// Injects a system prompt instructing the model to respond with valid JSON
11/// matching a given schema description, then parses the response.
12pub struct StructuredOutputChatModel<T> {
13    inner: Arc<dyn ChatModel>,
14    schema_description: String,
15    _marker: PhantomData<T>,
16}
17
18impl<T: DeserializeOwned + Send + Sync + 'static> StructuredOutputChatModel<T> {
19    /// Create a new StructuredOutputChatModel.
20    ///
21    /// `schema_description` should describe the expected JSON shape, e.g.:
22    /// `{"name": "string", "age": "number", "tags": ["string"]}`
23    pub fn new(inner: Arc<dyn ChatModel>, schema_description: impl Into<String>) -> Self {
24        Self {
25            inner,
26            schema_description: schema_description.into(),
27            _marker: PhantomData,
28        }
29    }
30
31    /// Parse the model's text response as JSON into type T.
32    pub fn parse_response(&self, response: &ChatResponse) -> Result<T, SynapticError> {
33        let text = response.message.content();
34        // Try to extract JSON from the response -- handle markdown code blocks
35        let json_str = extract_json(text);
36        serde_json::from_str::<T>(json_str)
37            .map_err(|e| SynapticError::Parsing(format!("failed to parse structured output: {e}")))
38    }
39
40    /// Call the model and parse the response as T.
41    pub async fn generate(&self, request: ChatRequest) -> Result<(T, ChatResponse), SynapticError> {
42        let response = self.chat(request).await?;
43        let parsed = self.parse_response(&response)?;
44        Ok((parsed, response))
45    }
46}
47
/// Extract the JSON payload from a model reply, tolerating markdown code fences.
///
/// The input is trimmed, then the first of the following that applies is returned
/// (always with surrounding whitespace removed):
///
/// 1. the body of the first triple-backtick fence opened with a `json` tag,
/// 2. the body of the first triple-backtick fence of any kind,
/// 3. the whole trimmed input.
///
/// A fence only counts when its closing triple backtick is present. A language
/// tag on the opening fence line (`json`, `JSON`, `json5`, `javascript`, ...) is
/// not part of the payload and is dropped.
fn extract_json(text: &str) -> &str {
    const JSON_FENCE: &str = "```json";
    const FENCE: &str = "```";

    let trimmed = text.trim();
    trimmed
        .find(JSON_FENCE)
        .and_then(|open| fenced_body(trimmed, open))
        // An unterminated `json` fence falls through to the first fence of any kind.
        .or_else(|| trimmed.find(FENCE).and_then(|open| fenced_body(trimmed, open)))
        .unwrap_or(trimmed)
}

/// Return the payload of the fence whose opening backticks start at byte `open`,
/// or `None` when no closing fence follows it.
fn fenced_body(text: &str, open: usize) -> Option<&str> {
    const FENCE: &str = "```";

    // `open` comes from `str::find` and the fence is ASCII, so this slice
    // always lands on a character boundary.
    let after_open = &text[open + FENCE.len()..];
    let close = after_open.find(FENCE)?;
    Some(strip_info_string(&after_open[..close]))
}

/// Remove a leading language tag from the text found between two fences.
///
/// Two forms are recognised:
/// * a tag alone on the opening line (letters first, then letters, digits or
///   `_ - + . #`), provided something non-blank follows it, so that a lone
///   scalar such as `true` on its own line is kept as the payload;
/// * a `json` prefix (any ASCII case) glued directly to the payload.
fn strip_info_string(inner: &str) -> &str {
    if let Some((first_line, rest)) = inner.split_once('\n') {
        let tag = first_line.trim();
        let body = rest.trim();
        let looks_like_tag = tag.starts_with(|c: char| c.is_ascii_alphabetic())
            && tag
                .chars()
                .all(|c| c.is_ascii_alphanumeric() || matches!(c, '_' | '-' | '+' | '.' | '#'));
        if looks_like_tag && !body.is_empty() {
            return body;
        }
    }

    match inner.get(..4) {
        Some(prefix) if prefix.eq_ignore_ascii_case("json") => inner[4..].trim(),
        _ => inner.trim(),
    }
}
67
68#[async_trait]
69impl<T: DeserializeOwned + Send + Sync + 'static> ChatModel for StructuredOutputChatModel<T> {
70    async fn chat(&self, mut request: ChatRequest) -> Result<ChatResponse, SynapticError> {
71        // Inject system message with schema instructions
72        let instruction = format!(
73            "You MUST respond with valid JSON matching this schema:\n{}\n\nDo not include any text outside the JSON object. Do not use markdown code blocks.",
74            self.schema_description
75        );
76
77        // Prepend system message
78        request.messages.insert(0, Message::system(instruction));
79
80        self.inner.chat(request).await
81    }
82
83    fn stream_chat(&self, request: ChatRequest) -> ChatStream<'_> {
84        // Streaming delegates to inner (structured output parsing happens after collection)
85        self.inner.stream_chat(request)
86    }
87}
88
#[cfg(test)]
mod tests {
    use super::*;

    /// The bare JSON document every case below is expected to yield.
    const EXPECTED: &str = r#"{"a": 1}"#;

    /// Input that is already bare JSON comes back unchanged.
    #[test]
    fn extract_json_plain() {
        let extracted = extract_json(EXPECTED);
        assert_eq!(extracted, EXPECTED);
    }

    /// A fence tagged `json` is unwrapped.
    #[test]
    fn extract_json_code_block() {
        let fenced = format!("```json\n{EXPECTED}\n```");
        assert_eq!(extract_json(&fenced), EXPECTED);
    }

    /// A fence without a language tag is unwrapped.
    #[test]
    fn extract_json_plain_code_block() {
        let fenced = format!("```\n{EXPECTED}\n```");
        assert_eq!(extract_json(&fenced), EXPECTED);
    }

    /// Leading and trailing whitespace around bare JSON is removed.
    #[test]
    fn extract_json_with_surrounding_whitespace() {
        let padded = format!("  {EXPECTED}  ");
        assert_eq!(extract_json(&padded), EXPECTED);
    }
}
114}