1use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9use std::sync::Arc;
10use tokio::sync::RwLock;
11
12#[derive(Debug, Clone, Serialize, Deserialize)]
14#[serde(tag = "type")]
15pub enum OutputSchema {
16 #[serde(rename = "text")]
18 Text,
19
20 #[serde(rename = "json_object")]
22 JsonObject,
23
24 #[serde(rename = "json_schema")]
26 JsonSchema {
27 schema: serde_json::Value,
29 name: String,
31 #[serde(default, skip_serializing_if = "Option::is_none")]
33 description: Option<String>,
34 },
35}
36
37impl OutputSchema {
38 pub fn json_schema(name: impl Into<String>, schema: serde_json::Value) -> Self {
40 OutputSchema::JsonSchema {
41 schema,
42 name: name.into(),
43 description: None,
44 }
45 }
46
47 pub fn json_schema_with_description(
49 name: impl Into<String>,
50 schema: serde_json::Value,
51 description: impl Into<String>,
52 ) -> Self {
53 OutputSchema::JsonSchema {
54 schema,
55 name: name.into(),
56 description: Some(description.into()),
57 }
58 }
59
60 pub fn schema_value(&self) -> Option<&serde_json::Value> {
62 match self {
63 OutputSchema::JsonSchema { schema, .. } => Some(schema),
64 _ => None,
65 }
66 }
67
68 pub fn to_response_format(&self) -> crate::reasoning::inference::ResponseFormat {
70 match self {
71 OutputSchema::Text => crate::reasoning::inference::ResponseFormat::Text,
72 OutputSchema::JsonObject => crate::reasoning::inference::ResponseFormat::JsonObject,
73 OutputSchema::JsonSchema { schema, name, .. } => {
74 crate::reasoning::inference::ResponseFormat::JsonSchema {
75 schema: schema.clone(),
76 name: Some(name.clone()),
77 }
78 }
79 }
80 }
81}
82
/// One registered schema version held by [`SchemaRegistry`].
#[derive(Debug, Clone)]
struct SchemaEntry {
    /// The raw JSON Schema document exactly as passed to `register`.
    schema: serde_json::Value,
    /// Validator compiled from `schema` at registration time; behind `Arc` so
    /// lookups can hand out cheap clones of it.
    validator: Arc<jsonschema::Validator>,
    /// Schema name (same as the map key's name), used when building an
    /// `OutputSchema` from this entry.
    name: String,
    /// Optional description supplied at registration.
    description: Option<String>,
}
95
/// Map key identifying a single (name, version) pair in [`SchemaRegistry`].
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
struct SchemaKey {
    /// Schema name.
    name: String,
    /// Version string as passed to `register` (stored verbatim).
    version: String,
}
102
103#[derive(Clone)]
109pub struct SchemaRegistry {
110 schemas: Arc<RwLock<HashMap<SchemaKey, SchemaEntry>>>,
111 latest_versions: Arc<RwLock<HashMap<String, String>>>,
113}
114
115impl Default for SchemaRegistry {
116 fn default() -> Self {
117 Self::new()
118 }
119}
120
121impl SchemaRegistry {
122 pub fn new() -> Self {
124 Self {
125 schemas: Arc::new(RwLock::new(HashMap::new())),
126 latest_versions: Arc::new(RwLock::new(HashMap::new())),
127 }
128 }
129
130 pub async fn register(
135 &self,
136 name: impl Into<String>,
137 version: impl Into<String>,
138 schema: serde_json::Value,
139 description: Option<String>,
140 ) -> Result<(), SchemaRegistryError> {
141 let name = name.into();
142 let version = version.into();
143
144 let validator =
145 jsonschema::validator_for(&schema).map_err(|e| SchemaRegistryError::InvalidSchema {
146 name: name.clone(),
147 reason: e.to_string(),
148 })?;
149
150 let key = SchemaKey {
151 name: name.clone(),
152 version: version.clone(),
153 };
154 let entry = SchemaEntry {
155 schema,
156 validator: Arc::new(validator),
157 name: name.clone(),
158 description,
159 };
160
161 self.schemas.write().await.insert(key, entry);
162 self.latest_versions.write().await.insert(name, version);
163
164 Ok(())
165 }
166
167 pub async fn get_validator(
169 &self,
170 name: &str,
171 version: &str,
172 ) -> Option<Arc<jsonschema::Validator>> {
173 let key = SchemaKey {
174 name: name.into(),
175 version: version.into(),
176 };
177 self.schemas
178 .read()
179 .await
180 .get(&key)
181 .map(|e| Arc::clone(&e.validator))
182 }
183
184 pub async fn get_latest_validator(&self, name: &str) -> Option<Arc<jsonschema::Validator>> {
186 let version = self.latest_versions.read().await.get(name).cloned()?;
187 self.get_validator(name, &version).await
188 }
189
190 pub async fn get_schema(&self, name: &str, version: &str) -> Option<serde_json::Value> {
192 let key = SchemaKey {
193 name: name.into(),
194 version: version.into(),
195 };
196 self.schemas
197 .read()
198 .await
199 .get(&key)
200 .map(|e| e.schema.clone())
201 }
202
203 pub async fn get_output_schema(&self, name: &str) -> Option<OutputSchema> {
205 let version = self.latest_versions.read().await.get(name).cloned()?;
206 let key = SchemaKey {
207 name: name.into(),
208 version,
209 };
210 let schemas = self.schemas.read().await;
211 let entry = schemas.get(&key)?;
212 Some(OutputSchema::JsonSchema {
213 schema: entry.schema.clone(),
214 name: entry.name.clone(),
215 description: entry.description.clone(),
216 })
217 }
218
219 pub async fn list_schemas(&self) -> Vec<(String, String)> {
221 self.latest_versions
222 .read()
223 .await
224 .iter()
225 .map(|(name, version)| (name.clone(), version.clone()))
226 .collect()
227 }
228
229 pub async fn remove(&self, name: &str, version: &str) -> bool {
231 let key = SchemaKey {
232 name: name.into(),
233 version: version.into(),
234 };
235 let removed = self.schemas.write().await.remove(&key).is_some();
236 if removed {
237 let mut latest = self.latest_versions.write().await;
239 if latest.get(name).is_some_and(|v| v == version) {
240 let schemas = self.schemas.read().await;
242 let next_version = schemas
243 .keys()
244 .filter(|k| k.name == name)
245 .map(|k| k.version.clone())
246 .max();
247 match next_version {
248 Some(v) => {
249 latest.insert(name.into(), v);
250 }
251 None => {
252 latest.remove(name);
253 }
254 }
255 }
256 }
257 removed
258 }
259}
260
/// Errors returned by [`SchemaRegistry`] operations.
#[derive(Debug, thiserror::Error)]
pub enum SchemaRegistryError {
    /// A validator could not be built from the supplied schema.
    /// `reason` is the underlying compiler's error message.
    #[error("Invalid schema '{name}': {reason}")]
    InvalidSchema { name: String, reason: String },
}
267
#[cfg(test)]
mod tests {
    use super::*;

    fn pair(name: &str, version: &str) -> (String, String) {
        (name.to_string(), version.to_string())
    }

    #[test]
    fn test_output_schema_text() {
        let schema = OutputSchema::Text;
        assert!(schema.schema_value().is_none());
        assert_eq!(
            serde_json::to_value(&schema).unwrap(),
            serde_json::json!({"type": "text"})
        );
    }

    #[test]
    fn test_output_schema_json_schema() {
        let value = serde_json::json!({"type": "object"});
        let schema = OutputSchema::json_schema("test", value.clone());
        assert_eq!(schema.schema_value(), Some(&value));

        // A `None` description must be omitted from the serialized form.
        let json = serde_json::to_value(&schema).unwrap();
        assert_eq!(json["type"], "json_schema");
        assert_eq!(json["name"], "test");
        assert!(json.get("description").is_none());
    }

    #[test]
    fn test_output_schema_serde_roundtrip() {
        let value = serde_json::json!({
            "type": "object",
            "properties": {"value": {"type": "string"}}
        });
        let schema =
            OutputSchema::json_schema_with_description("Result", value.clone(), "A simple result");

        let json = serde_json::to_string(&schema).unwrap();
        let restored: OutputSchema = serde_json::from_str(&json).unwrap();
        match restored {
            OutputSchema::JsonSchema {
                schema,
                name,
                description,
            } => {
                assert_eq!(schema, value);
                assert_eq!(name, "Result");
                assert_eq!(description.as_deref(), Some("A simple result"));
            }
            other => panic!("Expected JsonSchema variant, got {other:?}"),
        }
    }

    #[test]
    fn test_output_schema_to_response_format() {
        let text = OutputSchema::Text;
        assert!(matches!(
            text.to_response_format(),
            crate::reasoning::inference::ResponseFormat::Text
        ));

        let json_obj = OutputSchema::JsonObject;
        assert!(matches!(
            json_obj.to_response_format(),
            crate::reasoning::inference::ResponseFormat::JsonObject
        ));

        let schema = OutputSchema::json_schema("test", serde_json::json!({"type": "object"}));
        match schema.to_response_format() {
            crate::reasoning::inference::ResponseFormat::JsonSchema { name, .. } => {
                assert_eq!(name.as_deref(), Some("test"));
            }
            _ => panic!("Expected ResponseFormat::JsonSchema"),
        }
    }

    #[tokio::test]
    async fn test_schema_registry_register_and_get() {
        let registry = SchemaRegistry::new();

        let schema = serde_json::json!({
            "type": "object",
            "properties": {
                "name": {"type": "string"}
            },
            "required": ["name"]
        });

        registry
            .register("test_schema", "1.0.0", schema.clone(), None)
            .await
            .unwrap();

        let validator = registry
            .get_validator("test_schema", "1.0.0")
            .await
            .expect("validator should be registered");
        assert!(validator.is_valid(&serde_json::json!({"name": "Ada"})));
        assert!(!validator.is_valid(&serde_json::json!({})));

        let latest = registry.get_latest_validator("test_schema").await;
        assert!(latest.is_some());

        let raw = registry.get_schema("test_schema", "1.0.0").await;
        assert_eq!(raw, Some(schema));
    }

    #[tokio::test]
    async fn test_schema_registry_versioning() {
        let registry = SchemaRegistry::new();

        let v1 = serde_json::json!({"type": "object", "properties": {"a": {"type": "string"}}});
        let v2 = serde_json::json!({
            "type": "object",
            "properties": {"a": {"type": "string"}, "b": {"type": "number"}}
        });

        registry
            .register("schema", "1.0.0", v1.clone(), None)
            .await
            .unwrap();
        registry
            .register("schema", "2.0.0", v2.clone(), None)
            .await
            .unwrap();

        assert_eq!(registry.get_schema("schema", "1.0.0").await, Some(v1));
        assert_eq!(registry.get_schema("schema", "2.0.0").await, Some(v2));

        assert!(registry.get_validator("schema", "1.0.0").await.is_some());
        assert!(registry.get_validator("schema", "2.0.0").await.is_some());

        // The most recent registration is reported as latest.
        assert_eq!(registry.list_schemas().await, vec![pair("schema", "2.0.0")]);
        assert!(registry.get_latest_validator("schema").await.is_some());
    }

    #[tokio::test]
    async fn test_schema_registry_reregister_replaces_entry() {
        let registry = SchemaRegistry::new();
        let first = serde_json::json!({"type": "object"});
        let second = serde_json::json!({"type": "string"});

        registry
            .register("dup", "1.0.0", first, None)
            .await
            .unwrap();
        registry
            .register("dup", "1.0.0", second.clone(), None)
            .await
            .unwrap();

        assert_eq!(registry.get_schema("dup", "1.0.0").await, Some(second));
        assert_eq!(registry.list_schemas().await, vec![pair("dup", "1.0.0")]);
    }

    #[tokio::test]
    async fn test_schema_registry_invalid_schema() {
        let registry = SchemaRegistry::new();

        let invalid = serde_json::json!({"type": "not_a_real_type"});
        let result = registry.register("bad", "1.0.0", invalid, None).await;
        match result {
            Err(SchemaRegistryError::InvalidSchema { name, .. }) => assert_eq!(name, "bad"),
            Ok(()) => panic!("invalid schema was accepted"),
        }

        // A rejected schema must leave no trace in the registry.
        assert!(registry.get_validator("bad", "1.0.0").await.is_none());
        assert!(registry.list_schemas().await.is_empty());
    }

    #[tokio::test]
    async fn test_schema_registry_list() {
        let registry = SchemaRegistry::new();

        registry
            .register("a", "1.0.0", serde_json::json!({"type": "object"}), None)
            .await
            .unwrap();
        registry
            .register("b", "1.0.0", serde_json::json!({"type": "string"}), None)
            .await
            .unwrap();

        // `list_schemas` order is unspecified; sort before comparing.
        let mut schemas = registry.list_schemas().await;
        schemas.sort();
        assert_eq!(schemas, vec![pair("a", "1.0.0"), pair("b", "1.0.0")]);
    }

    #[tokio::test]
    async fn test_schema_registry_remove() {
        let registry = SchemaRegistry::new();

        registry
            .register(
                "rm_test",
                "1.0.0",
                serde_json::json!({"type": "object"}),
                None,
            )
            .await
            .unwrap();
        registry
            .register(
                "rm_test",
                "2.0.0",
                serde_json::json!({"type": "object"}),
                None,
            )
            .await
            .unwrap();

        assert!(registry.remove("rm_test", "2.0.0").await);
        assert!(registry.get_validator("rm_test", "1.0.0").await.is_some());
        assert!(registry.get_validator("rm_test", "2.0.0").await.is_none());

        // Removing the latest version falls back to the remaining one.
        assert_eq!(registry.list_schemas().await, vec![pair("rm_test", "1.0.0")]);
        assert!(registry.get_latest_validator("rm_test").await.is_some());

        // Removing an already-removed version reports `false`.
        assert!(!registry.remove("rm_test", "2.0.0").await);

        // Removing the last version forgets the name entirely.
        assert!(registry.remove("rm_test", "1.0.0").await);
        assert!(registry.get_latest_validator("rm_test").await.is_none());
        assert!(registry.get_output_schema("rm_test").await.is_none());
        assert!(registry.list_schemas().await.is_empty());
    }

    #[tokio::test]
    async fn test_schema_registry_remove_non_latest_keeps_latest() {
        let registry = SchemaRegistry::new();

        for version in ["1.0.0", "2.0.0"] {
            registry
                .register("keep", version, serde_json::json!({"type": "object"}), None)
                .await
                .unwrap();
        }

        assert!(registry.remove("keep", "1.0.0").await);
        assert_eq!(registry.list_schemas().await, vec![pair("keep", "2.0.0")]);
    }

    #[tokio::test]
    async fn test_schema_registry_get_output_schema() {
        let registry = SchemaRegistry::new();

        registry
            .register(
                "output",
                "1.0.0",
                serde_json::json!({"type": "object"}),
                Some("Test output".into()),
            )
            .await
            .unwrap();

        let output = registry.get_output_schema("output").await;
        match output {
            Some(OutputSchema::JsonSchema {
                schema,
                name,
                description,
            }) => {
                assert_eq!(schema, serde_json::json!({"type": "object"}));
                assert_eq!(name, "output");
                assert_eq!(description.as_deref(), Some("Test output"));
            }
            other => panic!("Expected Some(JsonSchema), got {other:?}"),
        }
    }

    #[tokio::test]
    async fn test_schema_registry_nonexistent() {
        let registry = SchemaRegistry::new();
        assert!(registry.get_validator("nope", "1.0.0").await.is_none());
        assert!(registry.get_latest_validator("nope").await.is_none());
        assert!(registry.get_schema("nope", "1.0.0").await.is_none());
        assert!(registry.get_output_schema("nope").await.is_none());
        assert!(!registry.remove("nope", "1.0.0").await);
    }
}