Skip to main content

symbi_runtime/reasoning/
output_schema.rs

1//! Output schema registry and management
2//!
3//! Provides `OutputSchema` for declaring expected response formats and
4//! `SchemaRegistry` for storing versioned, pre-compiled schema validators
5//! that can be provided to LLM API calls.
6
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9use std::sync::Arc;
10use tokio::sync::RwLock;
11
12/// Describes the expected output format for an inference call.
13#[derive(Debug, Clone, Serialize, Deserialize)]
14#[serde(tag = "type")]
15pub enum OutputSchema {
16    /// Free-form text, no validation.
17    #[serde(rename = "text")]
18    Text,
19
20    /// JSON object, validated for well-formedness only.
21    #[serde(rename = "json_object")]
22    JsonObject,
23
24    /// JSON conforming to an explicit JSON Schema.
25    #[serde(rename = "json_schema")]
26    JsonSchema {
27        /// The raw JSON Schema value.
28        schema: serde_json::Value,
29        /// Human-readable name for logging and API calls.
30        name: String,
31        /// Optional description.
32        #[serde(default, skip_serializing_if = "Option::is_none")]
33        description: Option<String>,
34    },
35}
36
37impl OutputSchema {
38    /// Create a JSON Schema output from a raw schema value.
39    pub fn json_schema(name: impl Into<String>, schema: serde_json::Value) -> Self {
40        OutputSchema::JsonSchema {
41            schema,
42            name: name.into(),
43            description: None,
44        }
45    }
46
47    /// Create a JSON Schema output with description.
48    pub fn json_schema_with_description(
49        name: impl Into<String>,
50        schema: serde_json::Value,
51        description: impl Into<String>,
52    ) -> Self {
53        OutputSchema::JsonSchema {
54            schema,
55            name: name.into(),
56            description: Some(description.into()),
57        }
58    }
59
60    /// Get the JSON Schema value if this is a schema variant.
61    pub fn schema_value(&self) -> Option<&serde_json::Value> {
62        match self {
63            OutputSchema::JsonSchema { schema, .. } => Some(schema),
64            _ => None,
65        }
66    }
67
68    /// Convert to the InferenceOptions ResponseFormat.
69    pub fn to_response_format(&self) -> crate::reasoning::inference::ResponseFormat {
70        match self {
71            OutputSchema::Text => crate::reasoning::inference::ResponseFormat::Text,
72            OutputSchema::JsonObject => crate::reasoning::inference::ResponseFormat::JsonObject,
73            OutputSchema::JsonSchema { schema, name, .. } => {
74                crate::reasoning::inference::ResponseFormat::JsonSchema {
75                    schema: schema.clone(),
76                    name: Some(name.clone()),
77                }
78            }
79        }
80    }
81}
82
83/// A versioned entry in the schema registry.
84#[derive(Debug, Clone)]
85struct SchemaEntry {
86    /// The raw schema.
87    schema: serde_json::Value,
88    /// Pre-compiled validator for fast validation.
89    validator: Arc<jsonschema::Validator>,
90    /// Human-readable name.
91    name: String,
92    /// Optional description.
93    description: Option<String>,
94}
95
/// Map key identifying exactly one schema version: the `(name, version)` pair.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
struct SchemaKey {
    /// Schema name.
    name: String,
    /// Version string, exactly as supplied at registration.
    version: String,
}
102
103/// Thread-safe registry of versioned, pre-compiled JSON Schema validators.
104///
105/// Schemas are registered once, compiled into validators, and reused across
106/// many validation calls. This amortizes the cost of schema compilation
107/// (typically 10-100μs) over the lifetime of the application.
108#[derive(Clone)]
109pub struct SchemaRegistry {
110    schemas: Arc<RwLock<HashMap<SchemaKey, SchemaEntry>>>,
111    /// Tracks the latest version for each schema name.
112    latest_versions: Arc<RwLock<HashMap<String, String>>>,
113}
114
115impl Default for SchemaRegistry {
116    fn default() -> Self {
117        Self::new()
118    }
119}
120
121impl SchemaRegistry {
122    /// Create a new empty registry.
123    pub fn new() -> Self {
124        Self {
125            schemas: Arc::new(RwLock::new(HashMap::new())),
126            latest_versions: Arc::new(RwLock::new(HashMap::new())),
127        }
128    }
129
130    /// Register a schema with a name and version.
131    ///
132    /// The schema is compiled into a validator at registration time.
133    /// Returns an error if the schema is invalid.
134    pub async fn register(
135        &self,
136        name: impl Into<String>,
137        version: impl Into<String>,
138        schema: serde_json::Value,
139        description: Option<String>,
140    ) -> Result<(), SchemaRegistryError> {
141        let name = name.into();
142        let version = version.into();
143
144        let validator =
145            jsonschema::validator_for(&schema).map_err(|e| SchemaRegistryError::InvalidSchema {
146                name: name.clone(),
147                reason: e.to_string(),
148            })?;
149
150        let key = SchemaKey {
151            name: name.clone(),
152            version: version.clone(),
153        };
154        let entry = SchemaEntry {
155            schema,
156            validator: Arc::new(validator),
157            name: name.clone(),
158            description,
159        };
160
161        self.schemas.write().await.insert(key, entry);
162        self.latest_versions.write().await.insert(name, version);
163
164        Ok(())
165    }
166
167    /// Get the pre-compiled validator for a specific schema version.
168    pub async fn get_validator(
169        &self,
170        name: &str,
171        version: &str,
172    ) -> Option<Arc<jsonschema::Validator>> {
173        let key = SchemaKey {
174            name: name.into(),
175            version: version.into(),
176        };
177        self.schemas
178            .read()
179            .await
180            .get(&key)
181            .map(|e| Arc::clone(&e.validator))
182    }
183
184    /// Get the pre-compiled validator for the latest version of a schema.
185    pub async fn get_latest_validator(&self, name: &str) -> Option<Arc<jsonschema::Validator>> {
186        let version = self.latest_versions.read().await.get(name).cloned()?;
187        self.get_validator(name, &version).await
188    }
189
190    /// Get the raw schema value for a specific version.
191    pub async fn get_schema(&self, name: &str, version: &str) -> Option<serde_json::Value> {
192        let key = SchemaKey {
193            name: name.into(),
194            version: version.into(),
195        };
196        self.schemas
197            .read()
198            .await
199            .get(&key)
200            .map(|e| e.schema.clone())
201    }
202
203    /// Get the schema as an OutputSchema for the latest version.
204    pub async fn get_output_schema(&self, name: &str) -> Option<OutputSchema> {
205        let version = self.latest_versions.read().await.get(name).cloned()?;
206        let key = SchemaKey {
207            name: name.into(),
208            version,
209        };
210        let schemas = self.schemas.read().await;
211        let entry = schemas.get(&key)?;
212        Some(OutputSchema::JsonSchema {
213            schema: entry.schema.clone(),
214            name: entry.name.clone(),
215            description: entry.description.clone(),
216        })
217    }
218
219    /// List all registered schema names with their latest versions.
220    pub async fn list_schemas(&self) -> Vec<(String, String)> {
221        self.latest_versions
222            .read()
223            .await
224            .iter()
225            .map(|(name, version)| (name.clone(), version.clone()))
226            .collect()
227    }
228
229    /// Remove a schema version from the registry.
230    pub async fn remove(&self, name: &str, version: &str) -> bool {
231        let key = SchemaKey {
232            name: name.into(),
233            version: version.into(),
234        };
235        let removed = self.schemas.write().await.remove(&key).is_some();
236        if removed {
237            // If this was the latest version, find the next latest
238            let mut latest = self.latest_versions.write().await;
239            if latest.get(name).is_some_and(|v| v == version) {
240                // Find another version for this name
241                let schemas = self.schemas.read().await;
242                let next_version = schemas
243                    .keys()
244                    .filter(|k| k.name == name)
245                    .map(|k| k.version.clone())
246                    .max();
247                match next_version {
248                    Some(v) => {
249                        latest.insert(name.into(), v);
250                    }
251                    None => {
252                        latest.remove(name);
253                    }
254                }
255            }
256        }
257        removed
258    }
259}
260
261/// Errors from the schema registry.
262#[derive(Debug, thiserror::Error)]
263pub enum SchemaRegistryError {
264    #[error("Invalid schema '{name}': {reason}")]
265    InvalidSchema { name: String, reason: String },
266}
267
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_output_schema_text() {
        let schema = OutputSchema::Text;
        assert!(schema.schema_value().is_none());
    }

    #[test]
    fn test_output_schema_json_schema() {
        let raw = serde_json::json!({"type": "object"});
        let schema = OutputSchema::json_schema("test", raw.clone());
        assert_eq!(schema.schema_value(), Some(&raw));
    }

    #[test]
    fn test_output_schema_serde_roundtrip() {
        let raw = serde_json::json!({
            "type": "object",
            "properties": {"value": {"type": "string"}}
        });
        let schema =
            OutputSchema::json_schema_with_description("Result", raw.clone(), "A simple result");

        // Pin the internally tagged wire format.
        let json = serde_json::to_value(&schema).unwrap();
        assert_eq!(json["type"], "json_schema");
        assert_eq!(json["name"], "Result");
        assert_eq!(json["description"], "A simple result");

        let restored: OutputSchema = serde_json::from_value(json).unwrap();
        match restored {
            OutputSchema::JsonSchema {
                schema,
                name,
                description,
            } => {
                assert_eq!(schema, raw);
                assert_eq!(name, "Result");
                assert_eq!(description.as_deref(), Some("A simple result"));
            }
            _ => panic!("Expected JsonSchema variant"),
        }
    }

    #[test]
    fn test_output_schema_wire_tags() {
        assert_eq!(
            serde_json::to_value(OutputSchema::Text).unwrap(),
            serde_json::json!({"type": "text"})
        );
        assert_eq!(
            serde_json::to_value(OutputSchema::JsonObject).unwrap(),
            serde_json::json!({"type": "json_object"})
        );

        // A missing description is omitted rather than serialized as null.
        let json =
            serde_json::to_value(OutputSchema::json_schema("n", serde_json::json!({}))).unwrap();
        assert!(json.get("description").is_none());
    }

    #[test]
    fn test_output_schema_to_response_format() {
        let text = OutputSchema::Text;
        assert!(matches!(
            text.to_response_format(),
            crate::reasoning::inference::ResponseFormat::Text
        ));

        let json_obj = OutputSchema::JsonObject;
        assert!(matches!(
            json_obj.to_response_format(),
            crate::reasoning::inference::ResponseFormat::JsonObject
        ));

        let schema = OutputSchema::json_schema("test", serde_json::json!({"type": "object"}));
        let crate::reasoning::inference::ResponseFormat::JsonSchema { name, .. } =
            schema.to_response_format()
        else {
            panic!("Expected JsonSchema response format");
        };
        assert_eq!(name.as_deref(), Some("test"));
    }

    #[tokio::test]
    async fn test_schema_registry_register_and_get() {
        let registry = SchemaRegistry::new();

        let schema = serde_json::json!({
            "type": "object",
            "properties": {
                "name": {"type": "string"}
            },
            "required": ["name"]
        });

        registry
            .register("test_schema", "1.0.0", schema.clone(), None)
            .await
            .unwrap();

        // Get specific version, and check it really enforces the schema.
        let validator = registry
            .get_validator("test_schema", "1.0.0")
            .await
            .expect("registered version must be retrievable");
        assert!(validator.is_valid(&serde_json::json!({"name": "x"})));
        assert!(!validator.is_valid(&serde_json::json!({})));

        // Get latest version
        let latest = registry.get_latest_validator("test_schema").await;
        assert!(latest.is_some());

        // Get raw schema
        let raw = registry.get_schema("test_schema", "1.0.0").await;
        assert_eq!(raw, Some(schema));
    }

    #[tokio::test]
    async fn test_schema_registry_versioning() {
        let registry = SchemaRegistry::new();

        let v1 = serde_json::json!({"type": "object", "properties": {"a": {"type": "string"}}});
        let v2 = serde_json::json!({"type": "object", "properties": {"a": {"type": "string"}, "b": {"type": "number"}}});

        registry
            .register("schema", "1.0.0", v1.clone(), None)
            .await
            .unwrap();
        registry
            .register("schema", "2.0.0", v2.clone(), None)
            .await
            .unwrap();

        // Latest should be v2: this instance satisfies v1 but violates v2's `b: number`.
        let instance = serde_json::json!({"a": "x", "b": "not a number"});
        let latest = registry.get_latest_validator("schema").await.unwrap();
        assert!(!latest.is_valid(&instance));
        assert!(registry
            .list_schemas()
            .await
            .contains(&("schema".to_string(), "2.0.0".to_string())));

        // Both versions accessible, each with its own rules.
        let v1_validator = registry.get_validator("schema", "1.0.0").await.unwrap();
        assert!(v1_validator.is_valid(&instance));
        assert!(registry.get_validator("schema", "2.0.0").await.is_some());
        assert_eq!(registry.get_schema("schema", "1.0.0").await, Some(v1));
        assert_eq!(registry.get_schema("schema", "2.0.0").await, Some(v2));
    }

    #[tokio::test]
    async fn test_schema_registry_invalid_schema() {
        let registry = SchemaRegistry::new();

        // A schema with an invalid type should fail
        let invalid = serde_json::json!({"type": "not_a_real_type"});
        let result = registry.register("bad", "1.0.0", invalid, None).await;
        assert!(matches!(
            result,
            Err(SchemaRegistryError::InvalidSchema { ref name, .. }) if name == "bad"
        ));

        // A failed registration must leave the registry untouched.
        assert!(registry.get_latest_validator("bad").await.is_none());
        assert!(registry.list_schemas().await.is_empty());
    }

    #[tokio::test]
    async fn test_schema_registry_list() {
        let registry = SchemaRegistry::new();

        registry
            .register("a", "1.0.0", serde_json::json!({"type": "object"}), None)
            .await
            .unwrap();
        registry
            .register("b", "1.0.0", serde_json::json!({"type": "string"}), None)
            .await
            .unwrap();

        let mut schemas = registry.list_schemas().await;
        schemas.sort();
        assert_eq!(
            schemas,
            vec![
                ("a".to_string(), "1.0.0".to_string()),
                ("b".to_string(), "1.0.0".to_string()),
            ]
        );
    }

    #[tokio::test]
    async fn test_schema_registry_remove() {
        let registry = SchemaRegistry::new();

        registry
            .register(
                "rm_test",
                "1.0.0",
                serde_json::json!({"type": "object"}),
                None,
            )
            .await
            .unwrap();
        registry
            .register(
                "rm_test",
                "2.0.0",
                serde_json::json!({"type": "object"}),
                None,
            )
            .await
            .unwrap();

        assert!(registry.remove("rm_test", "2.0.0").await);
        // Removing the same version again reports that nothing was removed.
        assert!(!registry.remove("rm_test", "2.0.0").await);

        // Latest should now fall back to 1.0.0
        assert!(registry.get_validator("rm_test", "1.0.0").await.is_some());
        assert!(registry.get_validator("rm_test", "2.0.0").await.is_none());
        assert!(registry.get_latest_validator("rm_test").await.is_some());
        assert_eq!(
            registry.list_schemas().await,
            vec![("rm_test".to_string(), "1.0.0".to_string())]
        );

        // Removing the last version forgets the name entirely.
        assert!(registry.remove("rm_test", "1.0.0").await);
        assert!(registry.get_latest_validator("rm_test").await.is_none());
        assert!(registry.get_output_schema("rm_test").await.is_none());
        assert!(registry.list_schemas().await.is_empty());
    }

    #[tokio::test]
    async fn test_schema_registry_get_output_schema() {
        let registry = SchemaRegistry::new();

        registry
            .register(
                "output",
                "1.0.0",
                serde_json::json!({"type": "object"}),
                Some("Test output".into()),
            )
            .await
            .unwrap();

        let output = registry.get_output_schema("output").await;
        assert!(output.is_some());
        match output.unwrap() {
            OutputSchema::JsonSchema {
                schema,
                name,
                description,
            } => {
                assert_eq!(schema, serde_json::json!({"type": "object"}));
                assert_eq!(name, "output");
                assert_eq!(description.as_deref(), Some("Test output"));
            }
            _ => panic!("Expected JsonSchema variant"),
        }
    }

    #[tokio::test]
    async fn test_schema_registry_nonexistent() {
        let registry = SchemaRegistry::new();
        assert!(registry.get_validator("nope", "1.0.0").await.is_none());
        assert!(registry.get_latest_validator("nope").await.is_none());
        assert!(registry.get_output_schema("nope").await.is_none());
        assert!(registry.get_schema("nope", "1.0.0").await.is_none());
        assert!(!registry.remove("nope", "1.0.0").await);
    }
}