Skip to main content

shape_runtime/type_schema/
intersection.rs

1//! Intersection type support for type schemas
2//!
3//! This module provides functionality for merging multiple type schemas
4//! into intersection types (A + B), with field collision detection.
5
6use super::SchemaError;
7use super::field_types::{FieldDef, FieldType};
8use super::schema::TypeSchema;
9use std::collections::HashMap;
10
11impl TypeSchema {
12    /// Create an intersection type schema by merging multiple schemas.
13    /// Returns an error if any field names collide.
14    pub fn from_intersection(
15        name: impl Into<String>,
16        schemas: &[&TypeSchema],
17    ) -> Result<Self, SchemaError> {
18        let name = name.into();
19
20        // Check for field collisions
21        let mut seen_fields: HashMap<&str, &str> = HashMap::new();
22        for schema in schemas {
23            for field in &schema.fields {
24                if let Some(existing_type) = seen_fields.get(field.name.as_str()) {
25                    return Err(SchemaError::FieldCollision {
26                        field_name: field.name.clone(),
27                        type1: existing_type.to_string(),
28                        type2: schema.name.clone(),
29                    });
30                }
31                seen_fields.insert(&field.name, &schema.name);
32            }
33        }
34
35        // Collect all fields with their source types
36        let mut all_fields: Vec<(String, FieldType)> = Vec::new();
37        let mut field_sources: HashMap<String, String> = HashMap::new();
38        let mut component_types: Vec<String> = Vec::new();
39
40        for schema in schemas {
41            component_types.push(schema.name.clone());
42            for field in &schema.fields {
43                all_fields.push((field.name.clone(), field.field_type.clone()));
44                field_sources.insert(field.name.clone(), schema.name.clone());
45            }
46        }
47
48        // Build the merged schema via the ambient registry (B1.7: always
49        // available, no legacy-counter fallback).
50        let id = super::current_registry().allocate_id();
51        let mut fields = Vec::with_capacity(all_fields.len());
52        let mut field_map = HashMap::with_capacity(all_fields.len());
53        let mut offset = 0;
54
55        for (index, (field_name, field_type)) in all_fields.into_iter().enumerate() {
56            let alignment = field_type.alignment();
57            offset = (offset + alignment - 1) & !(alignment - 1);
58
59            let field = FieldDef::new(&field_name, field_type.clone(), offset, index as u16);
60            field_map.insert(field_name, index);
61            offset += field_type.size();
62            fields.push(field);
63        }
64
65        let data_size = (offset + 7) & !7;
66
67        Ok(Self {
68            id,
69            name,
70            fields,
71            field_map,
72            data_size,
73            component_types: Some(component_types),
74            field_sources,
75            enum_info: None,
76            content_hash: None,
77        })
78    }
79
80    /// Check if this schema is an intersection type
81    pub fn is_intersection(&self) -> bool {
82        self.component_types.is_some()
83    }
84
85    /// Get the component types if this is an intersection
86    pub fn get_component_types(&self) -> Option<&[String]> {
87        self.component_types.as_deref()
88    }
89
90    /// Get the source type for a field (for decomposition)
91    pub fn field_source(&self, field_name: &str) -> Option<&str> {
92        self.field_sources.get(field_name).map(|s| s.as_str())
93    }
94
95    /// Get fields belonging to a specific component type (for decomposition)
96    pub fn fields_for_component(&self, component_name: &str) -> Vec<&FieldDef> {
97        self.fields
98            .iter()
99            .filter(|f| self.field_sources.get(&f.name).map(|s| s.as_str()) == Some(component_name))
100            .collect()
101    }
102}
103
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for `TypeSchema::new` that accepts borrowed field names.
    fn schema(name: &'static str, fields: Vec<(&str, FieldType)>) -> TypeSchema {
        let owned: Vec<(String, FieldType)> = fields
            .into_iter()
            .map(|(field, ty)| (field.to_string(), ty))
            .collect();
        TypeSchema::new(name, owned)
    }

    #[test]
    fn test_intersection_merge_success() {
        // type AB = { x: number } + { y: string }
        let a = schema("A", vec![("x", FieldType::F64)]);
        let b = schema("B", vec![("y", FieldType::String)]);

        let merged = TypeSchema::from_intersection("AB", &[&a, &b])
            .expect("Should merge without collision");

        assert_eq!(merged.name, "AB");
        assert_eq!(merged.field_count(), 2);
        for field in ["x", "y"] {
            assert!(merged.has_field(field));
        }
        assert!(merged.is_intersection());

        // Component tracking, in merge order.
        assert_eq!(merged.get_component_types().unwrap(), &["A", "B"]);

        // Per-field origin, used for decomposition.
        assert_eq!(merged.field_source("x"), Some("A"));
        assert_eq!(merged.field_source("y"), Some("B"));
    }

    #[test]
    fn test_intersection_field_collision() {
        // Both sides declare `x`, with different types.
        let a = schema("A", vec![("x", FieldType::F64)]);
        let b = schema("B", vec![("x", FieldType::String)]);

        match TypeSchema::from_intersection("AB", &[&a, &b]) {
            Err(SchemaError::FieldCollision {
                field_name,
                type1,
                type2,
            }) => {
                assert_eq!(
                    (field_name.as_str(), type1.as_str(), type2.as_str()),
                    ("x", "A", "B")
                );
            }
            _ => panic!("Expected FieldCollision error"),
        }
    }

    #[test]
    fn test_intersection_three_types() {
        // type ABC = { a: number } + { b: string } + { c: bool }
        let parts = [
            schema("A", vec![("a", FieldType::F64)]),
            schema("B", vec![("b", FieldType::String)]),
            schema("C", vec![("c", FieldType::Bool)]),
        ];
        let refs: Vec<&TypeSchema> = parts.iter().collect();

        let merged =
            TypeSchema::from_intersection("ABC", &refs).expect("Should merge three types");

        assert_eq!(merged.field_count(), 3);
        for field in ["a", "b", "c"] {
            assert!(merged.has_field(field));
        }
        assert_eq!(merged.get_component_types().unwrap(), &["A", "B", "C"]);
    }

    #[test]
    fn test_intersection_fields_for_component() {
        // type AB = { x: number, y: number } + { z: string }
        let a = schema("A", vec![("x", FieldType::F64), ("y", FieldType::F64)]);
        let b = schema("B", vec![("z", FieldType::String)]);

        let merged = TypeSchema::from_intersection("AB", &[&a, &b]).unwrap();

        // Decomposition: each component gets back exactly its own fields.
        let a_fields = merged.fields_for_component("A");
        assert_eq!(a_fields.len(), 2);
        for expected in ["x", "y"] {
            assert!(a_fields.iter().any(|f| f.name == expected));
        }

        let b_fields = merged.fields_for_component("B");
        assert_eq!(b_fields.len(), 1);
        assert!(b_fields.iter().any(|f| f.name == "z"));
    }

    #[test]
    fn test_intersection_data_size() {
        let a = schema("A", vec![("a1", FieldType::F64), ("a2", FieldType::I64)]);
        let b = schema("B", vec![("b1", FieldType::Bool)]);

        let merged = TypeSchema::from_intersection("AB", &[&a, &b]).unwrap();

        // Three fields of 8 bytes each = 24 bytes.
        assert_eq!(merged.data_size, 24);

        // Offsets are laid out back to back across components.
        for (field, offset) in [("a1", 0), ("a2", 8), ("b1", 16)] {
            assert_eq!(merged.field_offset(field), Some(offset));
        }
    }
}