Skip to main content

shape_runtime/type_schema/
intersection.rs

1//! Intersection type support for type schemas
2//!
3//! This module provides functionality for merging multiple type schemas
4//! into intersection types (A + B), with field collision detection.
5
6use super::SchemaError;
7use super::field_types::{FieldDef, FieldType};
8use super::schema::TypeSchema;
9use std::collections::HashMap;
10
11impl TypeSchema {
12    /// Create an intersection type schema by merging multiple schemas.
13    /// Returns an error if any field names collide.
14    pub fn from_intersection(
15        name: impl Into<String>,
16        schemas: &[&TypeSchema],
17    ) -> Result<Self, SchemaError> {
18        let name = name.into();
19
20        // Check for field collisions
21        let mut seen_fields: HashMap<&str, &str> = HashMap::new();
22        for schema in schemas {
23            for field in &schema.fields {
24                if let Some(existing_type) = seen_fields.get(field.name.as_str()) {
25                    return Err(SchemaError::FieldCollision {
26                        field_name: field.name.clone(),
27                        type1: existing_type.to_string(),
28                        type2: schema.name.clone(),
29                    });
30                }
31                seen_fields.insert(&field.name, &schema.name);
32            }
33        }
34
35        // Collect all fields with their source types
36        let mut all_fields: Vec<(String, FieldType)> = Vec::new();
37        let mut field_sources: HashMap<String, String> = HashMap::new();
38        let mut component_types: Vec<String> = Vec::new();
39
40        for schema in schemas {
41            component_types.push(schema.name.clone());
42            for field in &schema.fields {
43                all_fields.push((field.name.clone(), field.field_type.clone()));
44                field_sources.insert(field.name.clone(), schema.name.clone());
45            }
46        }
47
48        // Build the merged schema
49        let id = super::next_schema_id();
50        let mut fields = Vec::with_capacity(all_fields.len());
51        let mut field_map = HashMap::with_capacity(all_fields.len());
52        let mut offset = 0;
53
54        for (index, (field_name, field_type)) in all_fields.into_iter().enumerate() {
55            let alignment = field_type.alignment();
56            offset = (offset + alignment - 1) & !(alignment - 1);
57
58            let field = FieldDef::new(&field_name, field_type.clone(), offset, index as u16);
59            field_map.insert(field_name, index);
60            offset += field_type.size();
61            fields.push(field);
62        }
63
64        let data_size = (offset + 7) & !7;
65
66        Ok(Self {
67            id,
68            name,
69            fields,
70            field_map,
71            data_size,
72            component_types: Some(component_types),
73            field_sources,
74            enum_info: None,
75            content_hash: None,
76        })
77    }
78
79    /// Check if this schema is an intersection type
80    pub fn is_intersection(&self) -> bool {
81        self.component_types.is_some()
82    }
83
84    /// Get the component types if this is an intersection
85    pub fn get_component_types(&self) -> Option<&[String]> {
86        self.component_types.as_deref()
87    }
88
89    /// Get the source type for a field (for decomposition)
90    pub fn field_source(&self, field_name: &str) -> Option<&str> {
91        self.field_sources.get(field_name).map(|s| s.as_str())
92    }
93
94    /// Get fields belonging to a specific component type (for decomposition)
95    pub fn fields_for_component(&self, component_name: &str) -> Vec<&FieldDef> {
96        self.fields
97            .iter()
98            .filter(|f| self.field_sources.get(&f.name).map(|s| s.as_str()) == Some(component_name))
99            .collect()
100    }
101}
102
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_intersection_merge_success() {
        // type A = { x: number }
        let schema_a = TypeSchema::new("A", vec![("x".to_string(), FieldType::F64)]);

        // type B = { y: string }
        let schema_b = TypeSchema::new("B", vec![("y".to_string(), FieldType::String)]);

        // type AB = A + B
        let merged = TypeSchema::from_intersection("AB", &[&schema_a, &schema_b])
            .expect("Should merge without collision");

        assert_eq!(merged.name, "AB");
        assert_eq!(merged.field_count(), 2);
        assert!(merged.has_field("x"));
        assert!(merged.has_field("y"));
        assert!(merged.is_intersection());

        // Check component tracking
        let components = merged.get_component_types().unwrap();
        assert_eq!(components, &["A", "B"]);

        // Check field sources for decomposition
        assert_eq!(merged.field_source("x"), Some("A"));
        assert_eq!(merged.field_source("y"), Some("B"));
    }

    #[test]
    fn test_intersection_field_collision() {
        // type A = { x: number }
        let schema_a = TypeSchema::new("A", vec![("x".to_string(), FieldType::F64)]);

        // type B = { x: string }  // Same field name, different type
        let schema_b = TypeSchema::new("B", vec![("x".to_string(), FieldType::String)]);

        // type AB = A + B should fail with collision error
        match TypeSchema::from_intersection("AB", &[&schema_a, &schema_b]) {
            Err(SchemaError::FieldCollision {
                field_name,
                type1,
                type2,
            }) => {
                assert_eq!(field_name, "x");
                assert_eq!(type1, "A");
                assert_eq!(type2, "B");
            }
            _ => panic!("Expected FieldCollision error"),
        }
    }

    #[test]
    fn test_intersection_repeated_component_collides() {
        // type AA = A + A: every field of A collides with itself.
        let schema_a = TypeSchema::new("A", vec![("x".to_string(), FieldType::F64)]);

        match TypeSchema::from_intersection("AA", &[&schema_a, &schema_a]) {
            Err(SchemaError::FieldCollision {
                field_name,
                type1,
                type2,
            }) => {
                assert_eq!(field_name, "x");
                assert_eq!(type1, "A");
                assert_eq!(type2, "A");
            }
            _ => panic!("Expected FieldCollision error"),
        }
    }

    #[test]
    fn test_intersection_of_nothing_is_empty() {
        let merged = TypeSchema::from_intersection("Empty", &[])
            .expect("An empty intersection should be valid");

        assert_eq!(merged.field_count(), 0);
        assert_eq!(merged.data_size, 0);
        assert!(merged.is_intersection());
        assert!(merged.get_component_types().unwrap().is_empty());
    }

    #[test]
    fn test_intersection_three_types() {
        // type A = { a: number }
        let schema_a = TypeSchema::new("A", vec![("a".to_string(), FieldType::F64)]);
        // type B = { b: string }
        let schema_b = TypeSchema::new("B", vec![("b".to_string(), FieldType::String)]);
        // type C = { c: bool }
        let schema_c = TypeSchema::new("C", vec![("c".to_string(), FieldType::Bool)]);

        // type ABC = A + B + C
        let merged = TypeSchema::from_intersection("ABC", &[&schema_a, &schema_b, &schema_c])
            .expect("Should merge three types");

        assert_eq!(merged.field_count(), 3);
        assert!(merged.has_field("a"));
        assert!(merged.has_field("b"));
        assert!(merged.has_field("c"));

        let components = merged.get_component_types().unwrap();
        assert_eq!(components, &["A", "B", "C"]);
    }

    #[test]
    fn test_intersection_fields_for_component() {
        // type A = { x: number, y: number }
        let schema_a = TypeSchema::new(
            "A",
            vec![
                ("x".to_string(), FieldType::F64),
                ("y".to_string(), FieldType::F64),
            ],
        );

        // type B = { z: string }
        let schema_b = TypeSchema::new("B", vec![("z".to_string(), FieldType::String)]);

        let merged = TypeSchema::from_intersection("AB", &[&schema_a, &schema_b]).unwrap();

        // Test decomposition - get fields belonging to each component
        let a_fields = merged.fields_for_component("A");
        assert_eq!(a_fields.len(), 2);
        assert!(a_fields.iter().any(|f| f.name == "x"));
        assert!(a_fields.iter().any(|f| f.name == "y"));

        let b_fields = merged.fields_for_component("B");
        assert_eq!(b_fields.len(), 1);
        assert!(b_fields.iter().any(|f| f.name == "z"));
    }

    #[test]
    fn test_intersection_unknown_lookups() {
        let schema_a = TypeSchema::new("A", vec![("x".to_string(), FieldType::F64)]);
        let merged = TypeSchema::from_intersection("OnlyA", &[&schema_a]).unwrap();

        // Unknown fields have no source; unknown components contribute no fields.
        assert_eq!(merged.field_source("missing"), None);
        assert!(merged.fields_for_component("Missing").is_empty());
    }

    #[test]
    fn test_intersection_data_size() {
        let schema_a = TypeSchema::new(
            "A",
            vec![
                ("a1".to_string(), FieldType::F64),
                ("a2".to_string(), FieldType::I64),
            ],
        );

        let schema_b = TypeSchema::new("B", vec![("b1".to_string(), FieldType::Bool)]);

        let merged = TypeSchema::from_intersection("AB", &[&schema_a, &schema_b]).unwrap();

        // 3 fields * 8 bytes each = 24 bytes
        assert_eq!(merged.data_size, 24);

        // Check offsets are computed correctly
        assert_eq!(merged.field_offset("a1"), Some(0));
        assert_eq!(merged.field_offset("a2"), Some(8));
        assert_eq!(merged.field_offset("b1"), Some(16));
    }
}