shape_runtime/type_schema/
intersection.rs1use super::SchemaError;
7use super::field_types::{FieldDef, FieldType};
8use super::schema::TypeSchema;
9use std::collections::HashMap;
10
11impl TypeSchema {
12 pub fn from_intersection(
15 name: impl Into<String>,
16 schemas: &[&TypeSchema],
17 ) -> Result<Self, SchemaError> {
18 let name = name.into();
19
20 let mut seen_fields: HashMap<&str, &str> = HashMap::new();
22 for schema in schemas {
23 for field in &schema.fields {
24 if let Some(existing_type) = seen_fields.get(field.name.as_str()) {
25 return Err(SchemaError::FieldCollision {
26 field_name: field.name.clone(),
27 type1: existing_type.to_string(),
28 type2: schema.name.clone(),
29 });
30 }
31 seen_fields.insert(&field.name, &schema.name);
32 }
33 }
34
35 let mut all_fields: Vec<(String, FieldType)> = Vec::new();
37 let mut field_sources: HashMap<String, String> = HashMap::new();
38 let mut component_types: Vec<String> = Vec::new();
39
40 for schema in schemas {
41 component_types.push(schema.name.clone());
42 for field in &schema.fields {
43 all_fields.push((field.name.clone(), field.field_type.clone()));
44 field_sources.insert(field.name.clone(), schema.name.clone());
45 }
46 }
47
48 let id = super::current_registry().allocate_id();
51 let mut fields = Vec::with_capacity(all_fields.len());
52 let mut field_map = HashMap::with_capacity(all_fields.len());
53 let mut offset = 0;
54
55 for (index, (field_name, field_type)) in all_fields.into_iter().enumerate() {
56 let alignment = field_type.alignment();
57 offset = (offset + alignment - 1) & !(alignment - 1);
58
59 let field = FieldDef::new(&field_name, field_type.clone(), offset, index as u16);
60 field_map.insert(field_name, index);
61 offset += field_type.size();
62 fields.push(field);
63 }
64
65 let data_size = (offset + 7) & !7;
66
67 Ok(Self {
68 id,
69 name,
70 fields,
71 field_map,
72 data_size,
73 component_types: Some(component_types),
74 field_sources,
75 enum_info: None,
76 content_hash: None,
77 })
78 }
79
80 pub fn is_intersection(&self) -> bool {
82 self.component_types.is_some()
83 }
84
85 pub fn get_component_types(&self) -> Option<&[String]> {
87 self.component_types.as_deref()
88 }
89
90 pub fn field_source(&self, field_name: &str) -> Option<&str> {
92 self.field_sources.get(field_name).map(|s| s.as_str())
93 }
94
95 pub fn fields_for_component(&self, component_name: &str) -> Vec<&FieldDef> {
97 self.fields
98 .iter()
99 .filter(|f| self.field_sources.get(&f.name).map(|s| s.as_str()) == Some(component_name))
100 .collect()
101 }
102}
103
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a schema named `schema_name` that declares exactly one field.
    fn single(schema_name: &'static str, field_name: &str, field_type: FieldType) -> TypeSchema {
        TypeSchema::new(schema_name, vec![(field_name.to_string(), field_type)])
    }

    #[test]
    fn test_intersection_merge_success() {
        let left = single("A", "x", FieldType::F64);
        let right = single("B", "y", FieldType::String);

        let merged = TypeSchema::from_intersection("AB", &[&left, &right])
            .expect("Should merge without collision");

        assert_eq!(merged.name, "AB");
        assert_eq!(merged.field_count(), 2);
        assert!(["x", "y"].iter().all(|&field| merged.has_field(field)));
        assert!(merged.is_intersection());

        assert_eq!(merged.get_component_types().unwrap(), &["A", "B"]);

        // Every merged field remembers which component declared it.
        assert_eq!(merged.field_source("x"), Some("A"));
        assert_eq!(merged.field_source("y"), Some("B"));
    }

    #[test]
    fn test_intersection_field_collision() {
        let left = single("A", "x", FieldType::F64);
        let right = single("B", "x", FieldType::String);

        let outcome = TypeSchema::from_intersection("AB", &[&left, &right]);
        assert!(outcome.is_err());

        if let Err(SchemaError::FieldCollision {
            field_name,
            type1,
            type2,
        }) = outcome
        {
            assert_eq!(field_name, "x");
            // First declarer is reported as `type1`, the redeclarer as `type2`.
            assert_eq!(type1, "A");
            assert_eq!(type2, "B");
        } else {
            panic!("Expected FieldCollision error");
        }
    }

    #[test]
    fn test_intersection_three_types() {
        let parts = [
            single("A", "a", FieldType::F64),
            single("B", "b", FieldType::String),
            single("C", "c", FieldType::Bool),
        ];
        let refs: Vec<&TypeSchema> = parts.iter().collect();

        let merged =
            TypeSchema::from_intersection("ABC", &refs).expect("Should merge three types");

        assert_eq!(merged.field_count(), 3);
        for &field in &["a", "b", "c"] {
            assert!(merged.has_field(field));
        }

        assert_eq!(merged.get_component_types().unwrap(), &["A", "B", "C"]);
    }

    #[test]
    fn test_intersection_fields_for_component() {
        let left = TypeSchema::new(
            "A",
            vec![
                ("x".to_string(), FieldType::F64),
                ("y".to_string(), FieldType::F64),
            ],
        );
        let right = single("B", "z", FieldType::String);

        let merged = TypeSchema::from_intersection("AB", &[&left, &right]).unwrap();

        let from_a = merged.fields_for_component("A");
        assert_eq!(from_a.len(), 2);
        assert!(["x", "y"]
            .iter()
            .all(|wanted| from_a.iter().any(|f| f.name == *wanted)));

        let from_b = merged.fields_for_component("B");
        assert_eq!(from_b.len(), 1);
        assert!(from_b.iter().any(|f| f.name == "z"));
    }

    #[test]
    fn test_intersection_data_size() {
        let left = TypeSchema::new(
            "A",
            vec![
                ("a1".to_string(), FieldType::F64),
                ("a2".to_string(), FieldType::I64),
            ],
        );
        let right = TypeSchema::new("B", vec![("b1".to_string(), FieldType::Bool)]);

        let merged = TypeSchema::from_intersection("AB", &[&left, &right]).unwrap();

        // 8 + 8 + 1 bytes, padded up to the next multiple of 8.
        assert_eq!(merged.data_size, 24);

        let expected_offsets = [("a1", 0), ("a2", 8), ("b1", 16)];
        for &(field, offset) in &expected_offsets {
            assert_eq!(merged.field_offset(field), Some(offset));
        }
    }
}