tx2_link/
schema.rs

1use crate::error::{LinkError, Result};
2use crate::protocol::{ComponentId, FieldId, FieldType};
3use ahash::AHashMap;
4use serde::{Deserialize, Serialize};
5use std::sync::{Arc, RwLock};
6
7pub type SchemaVersion = u32;
8
9#[derive(Debug, Clone, Serialize, Deserialize)]
10pub struct ComponentSchema {
11    pub component_id: ComponentId,
12    pub version: SchemaVersion,
13    pub fields: Vec<FieldSchema>,
14    pub description: Option<String>,
15}
16
17impl ComponentSchema {
18    pub fn new(component_id: ComponentId, version: SchemaVersion) -> Self {
19        Self {
20            component_id,
21            version,
22            fields: Vec::new(),
23            description: None,
24        }
25    }
26
27    pub fn with_field(mut self, field: FieldSchema) -> Self {
28        self.fields.push(field);
29        self
30    }
31
32    pub fn with_description(mut self, description: String) -> Self {
33        self.description = Some(description);
34        self
35    }
36
37    pub fn get_field(&self, field_id: &str) -> Option<&FieldSchema> {
38        self.fields.iter().find(|f| f.field_id == field_id)
39    }
40
41    pub fn validate_field(&self, field_id: &str, field_type: &FieldType) -> bool {
42        if let Some(schema) = self.get_field(field_id) {
43            &schema.field_type == field_type
44        } else {
45            false
46        }
47    }
48}
49
50#[derive(Debug, Clone, Serialize, Deserialize)]
51pub struct FieldSchema {
52    pub field_id: FieldId,
53    pub field_type: FieldType,
54    pub optional: bool,
55    pub default_value: Option<String>,
56    pub description: Option<String>,
57}
58
59impl FieldSchema {
60    pub fn new(field_id: FieldId, field_type: FieldType) -> Self {
61        Self {
62            field_id,
63            field_type,
64            optional: false,
65            default_value: None,
66            description: None,
67        }
68    }
69
70    pub fn optional(mut self) -> Self {
71        self.optional = true;
72        self
73    }
74
75    pub fn with_default(mut self, default: String) -> Self {
76        self.default_value = Some(default);
77        self
78    }
79
80    pub fn with_description(mut self, description: String) -> Self {
81        self.description = Some(description);
82        self
83    }
84}
85
86pub struct SchemaRegistry {
87    schemas: Arc<RwLock<AHashMap<ComponentId, ComponentSchema>>>,
88    version_history: Arc<RwLock<AHashMap<ComponentId, Vec<SchemaVersion>>>>,
89    current_version: SchemaVersion,
90}
91
92impl SchemaRegistry {
93    pub fn new() -> Self {
94        Self {
95            schemas: Arc::new(RwLock::new(AHashMap::new())),
96            version_history: Arc::new(RwLock::new(AHashMap::new())),
97            current_version: 1,
98        }
99    }
100
101    pub fn register(&self, schema: ComponentSchema) -> Result<()> {
102        let mut schemas = self.schemas.write()
103            .map_err(|e| LinkError::Unknown(format!("Lock poisoned: {}", e)))?;
104
105        let mut version_history = self.version_history.write()
106            .map_err(|e| LinkError::Unknown(format!("Lock poisoned: {}", e)))?;
107
108        let component_id = schema.component_id.clone();
109        let version = schema.version;
110
111        if let Some(existing) = schemas.get(&component_id) {
112            if existing.version >= version {
113                return Err(LinkError::Unknown(
114                    format!("Schema version {} already exists or is newer for component {}", version, component_id)
115                ));
116            }
117        }
118
119        version_history.entry(component_id.clone())
120            .or_insert_with(Vec::new)
121            .push(version);
122
123        schemas.insert(component_id, schema);
124
125        Ok(())
126    }
127
128    pub fn get(&self, component_id: &str) -> Result<ComponentSchema> {
129        let schemas = self.schemas.read()
130            .map_err(|e| LinkError::Unknown(format!("Lock poisoned: {}", e)))?;
131
132        schemas.get(component_id)
133            .cloned()
134            .ok_or_else(|| LinkError::SchemaNotFound(component_id.to_string()))
135    }
136
137    pub fn get_version(&self, component_id: &str, version: SchemaVersion) -> Result<ComponentSchema> {
138        let schema = self.get(component_id)?;
139
140        if schema.version == version {
141            Ok(schema)
142        } else {
143            Err(LinkError::SchemaMismatch {
144                expected: version.to_string(),
145                actual: schema.version.to_string(),
146            })
147        }
148    }
149
150    pub fn has(&self, component_id: &str) -> bool {
151        self.schemas.read()
152            .map(|schemas| schemas.contains_key(component_id))
153            .unwrap_or(false)
154    }
155
156    pub fn get_all(&self) -> Result<Vec<ComponentSchema>> {
157        let schemas = self.schemas.read()
158            .map_err(|e| LinkError::Unknown(format!("Lock poisoned: {}", e)))?;
159
160        Ok(schemas.values().cloned().collect())
161    }
162
163    pub fn get_version_history(&self, component_id: &str) -> Result<Vec<SchemaVersion>> {
164        let history = self.version_history.read()
165            .map_err(|e| LinkError::Unknown(format!("Lock poisoned: {}", e)))?;
166
167        Ok(history.get(component_id)
168            .cloned()
169            .unwrap_or_default())
170    }
171
172    pub fn validate_compatibility(&self, old_version: SchemaVersion, new_version: SchemaVersion) -> bool {
173        new_version >= old_version
174    }
175
176    pub fn clear(&self) -> Result<()> {
177        let mut schemas = self.schemas.write()
178            .map_err(|e| LinkError::Unknown(format!("Lock poisoned: {}", e)))?;
179
180        let mut version_history = self.version_history.write()
181            .map_err(|e| LinkError::Unknown(format!("Lock poisoned: {}", e)))?;
182
183        schemas.clear();
184        version_history.clear();
185
186        Ok(())
187    }
188
189    pub fn get_current_version(&self) -> SchemaVersion {
190        self.current_version
191    }
192
193    pub fn set_current_version(&mut self, version: SchemaVersion) {
194        self.current_version = version;
195    }
196}
197
198impl Default for SchemaRegistry {
199    fn default() -> Self {
200        Self::new()
201    }
202}
203
204impl Clone for SchemaRegistry {
205    fn clone(&self) -> Self {
206        Self {
207            schemas: Arc::clone(&self.schemas),
208            version_history: Arc::clone(&self.version_history),
209            current_version: self.current_version,
210        }
211    }
212}
213
214pub struct SchemaValidator {
215    registry: SchemaRegistry,
216}
217
218impl SchemaValidator {
219    pub fn new(registry: SchemaRegistry) -> Self {
220        Self { registry }
221    }
222
223    pub fn validate_component(&self, component_id: &str, fields: &AHashMap<FieldId, FieldType>) -> Result<()> {
224        let schema = self.registry.get(component_id)?;
225
226        for field_schema in &schema.fields {
227            if !field_schema.optional {
228                if !fields.contains_key(&field_schema.field_id) {
229                    return Err(LinkError::InvalidMessage(
230                        format!("Required field '{}' missing in component '{}'", field_schema.field_id, component_id)
231                    ));
232                }
233            }
234
235            if let Some(field_type) = fields.get(&field_schema.field_id) {
236                if field_type != &field_schema.field_type {
237                    return Err(LinkError::InvalidMessage(
238                        format!("Field '{}' has wrong type in component '{}'", field_schema.field_id, component_id)
239                    ));
240                }
241            }
242        }
243
244        Ok(())
245    }
246
247    pub fn get_registry(&self) -> &SchemaRegistry {
248        &self.registry
249    }
250}
251
#[cfg(test)]
mod tests {
    use super::*;

    /// Two required `F64` fields (`x`, `y`) on component `Position`, version 1.
    fn position_v1() -> ComponentSchema {
        ComponentSchema::new("Position".to_string(), 1)
            .with_field(FieldSchema::new("x".to_string(), FieldType::F64))
            .with_field(FieldSchema::new("y".to_string(), FieldType::F64))
    }

    #[test]
    fn test_schema_registry() {
        let registry = SchemaRegistry::new();

        let schema = position_v1().with_description("2D position component".to_string());

        registry.register(schema.clone()).unwrap();

        let retrieved = registry.get("Position").unwrap();
        assert_eq!(retrieved.component_id, "Position");
        assert_eq!(retrieved.fields.len(), 2);
        assert_eq!(retrieved.description.as_deref(), Some("2D position component"));
    }

    #[test]
    fn test_schema_versioning() {
        let registry = SchemaRegistry::new();

        registry.register(position_v1()).unwrap();

        let schema_v2 = ComponentSchema::new("Position".to_string(), 2)
            .with_field(FieldSchema::new("x".to_string(), FieldType::F64))
            .with_field(FieldSchema::new("y".to_string(), FieldType::F64))
            .with_field(FieldSchema::new("z".to_string(), FieldType::F64).optional());

        registry.register(schema_v2).unwrap();

        let history = registry.get_version_history("Position").unwrap();
        assert_eq!(history, vec![1, 2]);
        assert_eq!(registry.get("Position").unwrap().version, 2);
    }

    #[test]
    fn test_register_rejects_same_or_older_version() {
        let registry = SchemaRegistry::new();
        registry
            .register(ComponentSchema::new("Position".to_string(), 2))
            .unwrap();

        assert!(registry
            .register(ComponentSchema::new("Position".to_string(), 2))
            .is_err());
        assert!(registry
            .register(ComponentSchema::new("Position".to_string(), 1))
            .is_err());

        // Rejected registrations must leave both the schema and the history untouched.
        assert_eq!(registry.get_version_history("Position").unwrap(), vec![2]);
        assert_eq!(registry.get("Position").unwrap().version, 2);
    }

    #[test]
    fn test_missing_schema_and_version_mismatch() {
        let registry = SchemaRegistry::new();
        assert!(!registry.has("Position"));
        assert!(matches!(
            registry.get("Position"),
            Err(LinkError::SchemaNotFound(_))
        ));
        assert!(registry.get_version_history("Position").unwrap().is_empty());

        registry.register(position_v1()).unwrap();
        assert!(registry.has("Position"));
        assert!(registry.get_version("Position", 1).is_ok());
        assert!(matches!(
            registry.get_version("Position", 2),
            Err(LinkError::SchemaMismatch { .. })
        ));
    }

    #[test]
    fn test_clones_share_registered_schemas() {
        let registry = SchemaRegistry::new();
        let handle = registry.clone();

        handle.register(position_v1()).unwrap();
        assert!(registry.has("Position"));
        assert_eq!(registry.get_all().unwrap().len(), 1);

        registry.clear().unwrap();
        assert!(!handle.has("Position"));
        assert!(handle.get_version_history("Position").unwrap().is_empty());
    }

    #[test]
    fn test_current_version_and_compatibility() {
        let mut registry = SchemaRegistry::new();
        assert_eq!(registry.get_current_version(), 1);

        registry.set_current_version(3);
        assert_eq!(registry.get_current_version(), 3);

        assert!(registry.validate_compatibility(1, 2));
        assert!(registry.validate_compatibility(2, 2));
        assert!(!registry.validate_compatibility(2, 1));
    }

    #[test]
    fn test_component_schema_field_lookup() {
        let schema = position_v1();

        assert!(schema.get_field("x").is_some());
        assert!(schema.get_field("z").is_none());
        assert!(schema.validate_field("x", &FieldType::F64));
        assert!(!schema.validate_field("z", &FieldType::F64));
    }

    #[test]
    fn test_schema_validation() {
        let registry = SchemaRegistry::new();

        registry.register(position_v1()).unwrap();

        let validator = SchemaValidator::new(registry);

        let mut fields = AHashMap::new();
        fields.insert("x".to_string(), FieldType::F64);
        fields.insert("y".to_string(), FieldType::F64);

        assert!(validator.validate_component("Position", &fields).is_ok());

        let mut invalid_fields = AHashMap::new();
        invalid_fields.insert("x".to_string(), FieldType::F64);

        assert!(validator.validate_component("Position", &invalid_fields).is_err());
    }

    #[test]
    fn test_optional_fields_may_be_omitted() {
        let registry = SchemaRegistry::new();
        registry
            .register(
                position_v1()
                    .with_field(FieldSchema::new("z".to_string(), FieldType::F64).optional()),
            )
            .unwrap();

        let validator = SchemaValidator::new(registry);

        let mut fields = AHashMap::new();
        fields.insert("x".to_string(), FieldType::F64);
        fields.insert("y".to_string(), FieldType::F64);

        assert!(validator.validate_component("Position", &fields).is_ok());
        assert!(validator.validate_component("Unknown", &fields).is_err());
    }
}
318}