shape_runtime/type_schema/
intersection.rs1use super::SchemaError;
7use super::field_types::{FieldDef, FieldType};
8use super::schema::TypeSchema;
9use std::collections::HashMap;
10
11impl TypeSchema {
12 pub fn from_intersection(
15 name: impl Into<String>,
16 schemas: &[&TypeSchema],
17 ) -> Result<Self, SchemaError> {
18 let name = name.into();
19
20 let mut seen_fields: HashMap<&str, &str> = HashMap::new();
22 for schema in schemas {
23 for field in &schema.fields {
24 if let Some(existing_type) = seen_fields.get(field.name.as_str()) {
25 return Err(SchemaError::FieldCollision {
26 field_name: field.name.clone(),
27 type1: existing_type.to_string(),
28 type2: schema.name.clone(),
29 });
30 }
31 seen_fields.insert(&field.name, &schema.name);
32 }
33 }
34
35 let mut all_fields: Vec<(String, FieldType)> = Vec::new();
37 let mut field_sources: HashMap<String, String> = HashMap::new();
38 let mut component_types: Vec<String> = Vec::new();
39
40 for schema in schemas {
41 component_types.push(schema.name.clone());
42 for field in &schema.fields {
43 all_fields.push((field.name.clone(), field.field_type.clone()));
44 field_sources.insert(field.name.clone(), schema.name.clone());
45 }
46 }
47
48 let id = super::next_schema_id();
50 let mut fields = Vec::with_capacity(all_fields.len());
51 let mut field_map = HashMap::with_capacity(all_fields.len());
52 let mut offset = 0;
53
54 for (index, (field_name, field_type)) in all_fields.into_iter().enumerate() {
55 let alignment = field_type.alignment();
56 offset = (offset + alignment - 1) & !(alignment - 1);
57
58 let field = FieldDef::new(&field_name, field_type.clone(), offset, index as u16);
59 field_map.insert(field_name, index);
60 offset += field_type.size();
61 fields.push(field);
62 }
63
64 let data_size = (offset + 7) & !7;
65
66 Ok(Self {
67 id,
68 name,
69 fields,
70 field_map,
71 data_size,
72 component_types: Some(component_types),
73 field_sources,
74 enum_info: None,
75 content_hash: None,
76 })
77 }
78
79 pub fn is_intersection(&self) -> bool {
81 self.component_types.is_some()
82 }
83
84 pub fn get_component_types(&self) -> Option<&[String]> {
86 self.component_types.as_deref()
87 }
88
89 pub fn field_source(&self, field_name: &str) -> Option<&str> {
91 self.field_sources.get(field_name).map(|s| s.as_str())
92 }
93
94 pub fn fields_for_component(&self, component_name: &str) -> Vec<&FieldDef> {
96 self.fields
97 .iter()
98 .filter(|f| self.field_sources.get(&f.name).map(|s| s.as_str()) == Some(component_name))
99 .collect()
100 }
101}
102
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_intersection_merge_success() {
        let schema_a = TypeSchema::new("A", vec![("x".to_string(), FieldType::F64)]);

        let schema_b = TypeSchema::new("B", vec![("y".to_string(), FieldType::String)]);

        let merged = TypeSchema::from_intersection("AB", &[&schema_a, &schema_b])
            .expect("Should merge without collision");

        assert_eq!(merged.name, "AB");
        assert_eq!(merged.field_count(), 2);
        assert!(merged.has_field("x"));
        assert!(merged.has_field("y"));
        assert!(merged.is_intersection());

        let components = merged.get_component_types().unwrap();
        assert_eq!(components, &["A", "B"]);

        assert_eq!(merged.field_source("x"), Some("A"));
        assert_eq!(merged.field_source("y"), Some("B"));
    }

    #[test]
    fn test_intersection_field_collision() {
        let schema_a = TypeSchema::new("A", vec![("x".to_string(), FieldType::F64)]);

        let schema_b = TypeSchema::new("B", vec![("x".to_string(), FieldType::String)]);

        let result = TypeSchema::from_intersection("AB", &[&schema_a, &schema_b]);
        assert!(result.is_err());

        match result {
            Err(SchemaError::FieldCollision {
                field_name,
                type1,
                type2,
            }) => {
                assert_eq!(field_name, "x");
                assert_eq!(type1, "A");
                assert_eq!(type2, "B");
            }
            _ => panic!("Expected FieldCollision error"),
        }
    }

    /// A collision between non-adjacent components must name the component that
    /// declared the field first, not the one immediately before the redeclaration.
    #[test]
    fn test_intersection_collision_reports_first_declarer() {
        let schema_a = TypeSchema::new("A", vec![("x".to_string(), FieldType::F64)]);
        let schema_b = TypeSchema::new("B", vec![("y".to_string(), FieldType::F64)]);
        let schema_c = TypeSchema::new("C", vec![("x".to_string(), FieldType::Bool)]);

        match TypeSchema::from_intersection("ABC", &[&schema_a, &schema_b, &schema_c]) {
            Err(SchemaError::FieldCollision {
                field_name,
                type1,
                type2,
            }) => {
                assert_eq!(field_name, "x");
                assert_eq!(type1, "A");
                assert_eq!(type2, "C");
            }
            _ => panic!("Expected FieldCollision error"),
        }
    }

    #[test]
    fn test_intersection_three_types() {
        let schema_a = TypeSchema::new("A", vec![("a".to_string(), FieldType::F64)]);
        let schema_b = TypeSchema::new("B", vec![("b".to_string(), FieldType::String)]);
        let schema_c = TypeSchema::new("C", vec![("c".to_string(), FieldType::Bool)]);

        let merged = TypeSchema::from_intersection("ABC", &[&schema_a, &schema_b, &schema_c])
            .expect("Should merge three types");

        assert_eq!(merged.field_count(), 3);
        assert!(merged.has_field("a"));
        assert!(merged.has_field("b"));
        assert!(merged.has_field("c"));

        let components = merged.get_component_types().unwrap();
        assert_eq!(components, &["A", "B", "C"]);
    }

    /// Intersecting zero schemas is allowed and yields an empty, zero-sized schema
    /// that is still marked as an intersection.
    #[test]
    fn test_intersection_of_no_schemas_is_empty() {
        let merged =
            TypeSchema::from_intersection("Empty", &[]).expect("Empty intersection is valid");

        assert_eq!(merged.field_count(), 0);
        assert_eq!(merged.data_size, 0);
        assert!(merged.is_intersection());
        assert!(merged.get_component_types().unwrap().is_empty());
    }

    #[test]
    fn test_intersection_fields_for_component() {
        let schema_a = TypeSchema::new(
            "A",
            vec![
                ("x".to_string(), FieldType::F64),
                ("y".to_string(), FieldType::F64),
            ],
        );

        let schema_b = TypeSchema::new("B", vec![("z".to_string(), FieldType::String)]);

        let merged = TypeSchema::from_intersection("AB", &[&schema_a, &schema_b]).unwrap();

        let a_fields = merged.fields_for_component("A");
        assert_eq!(a_fields.len(), 2);
        assert!(a_fields.iter().any(|f| f.name == "x"));
        assert!(a_fields.iter().any(|f| f.name == "y"));

        let b_fields = merged.fields_for_component("B");
        assert_eq!(b_fields.len(), 1);
        assert!(b_fields.iter().any(|f| f.name == "z"));
    }

    /// Lookups for names that are not part of the intersection return nothing.
    #[test]
    fn test_intersection_unknown_names() {
        let schema_a = TypeSchema::new("A", vec![("x".to_string(), FieldType::F64)]);

        let merged = TypeSchema::from_intersection("OnlyA", &[&schema_a]).unwrap();

        assert_eq!(merged.field_source("missing"), None);
        assert!(merged.fields_for_component("Z").is_empty());
    }

    #[test]
    fn test_intersection_data_size() {
        let schema_a = TypeSchema::new(
            "A",
            vec![
                ("a1".to_string(), FieldType::F64),
                ("a2".to_string(), FieldType::I64),
            ],
        );

        let schema_b = TypeSchema::new("B", vec![("b1".to_string(), FieldType::Bool)]);

        let merged = TypeSchema::from_intersection("AB", &[&schema_a, &schema_b]).unwrap();

        assert_eq!(merged.data_size, 24);

        assert_eq!(merged.field_offset("a1"), Some(0));
        assert_eq!(merged.field_offset("a2"), Some(8));
        assert_eq!(merged.field_offset("b1"), Some(16));
    }
}