1use super::SchemaId;
7use super::enum_support::{EnumInfo, EnumVariantInfo};
8use super::field_types::{FieldDef, FieldType, semantic_to_field_type};
9use arrow_schema::{DataType, Schema as ArrowSchema};
10use sha2::{Digest, Sha256};
11use std::collections::HashMap;
12
13#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
15pub struct TypeSchema {
16 pub id: SchemaId,
18 pub name: String,
20 pub fields: Vec<FieldDef>,
22 pub(crate) field_map: HashMap<String, usize>,
24 pub data_size: usize,
26 pub component_types: Option<Vec<String>>,
29 pub(crate) field_sources: HashMap<String, String>,
31 pub enum_info: Option<EnumInfo>,
33 #[serde(skip)]
36 pub content_hash: Option<[u8; 32]>,
37}
38
39impl TypeSchema {
40 pub fn new(name: impl Into<String>, field_defs: Vec<(String, FieldType)>) -> Self {
42 let id = super::next_schema_id();
43 let name = name.into();
44
45 let mut fields = Vec::with_capacity(field_defs.len());
46 let mut field_map = HashMap::with_capacity(field_defs.len());
47 let mut offset = 0;
48
49 for (index, (field_name, field_type)) in field_defs.into_iter().enumerate() {
50 let alignment = field_type.alignment();
52 offset = (offset + alignment - 1) & !(alignment - 1);
53
54 let field = FieldDef::new(&field_name, field_type.clone(), offset, index as u16);
55 field_map.insert(field_name, index);
56 offset += field_type.size();
57 fields.push(field);
58 }
59
60 let data_size = (offset + 7) & !7;
62
63 Self {
64 id,
65 name,
66 fields,
67 field_map,
68 data_size,
69 component_types: None,
70 field_sources: HashMap::new(),
71 enum_info: None,
72 content_hash: None,
73 }
74 }
75
76 pub fn get_field(&self, name: &str) -> Option<&FieldDef> {
78 self.field_map.get(name).map(|&idx| &self.fields[idx])
79 }
80
81 pub fn field_offset(&self, name: &str) -> Option<usize> {
83 self.get_field(name).map(|f| f.offset)
84 }
85
86 pub fn field_index(&self, name: &str) -> Option<u16> {
88 self.get_field(name).map(|f| f.index)
89 }
90
91 pub fn field_by_index(&self, index: u16) -> Option<&FieldDef> {
93 self.fields.get(index as usize)
94 }
95
96 pub fn field_count(&self) -> usize {
98 self.fields.len()
99 }
100
101 pub fn has_field(&self, name: &str) -> bool {
103 self.field_map.contains_key(name)
104 }
105
106 pub fn field_names(&self) -> impl Iterator<Item = &str> {
108 self.fields.iter().map(|f| f.name.as_str())
109 }
110
111 pub fn is_enum(&self) -> bool {
113 self.enum_info.is_some()
114 }
115
116 pub fn get_enum_info(&self) -> Option<&EnumInfo> {
118 self.enum_info.as_ref()
119 }
120
121 pub fn variant_id(&self, variant_name: &str) -> Option<u16> {
123 self.enum_info.as_ref()?.variant_id(variant_name)
124 }
125
126 pub fn new_enum(name: impl Into<String>, variants: Vec<EnumVariantInfo>) -> Self {
132 let id = super::next_schema_id();
133 let name = name.into();
134 let enum_info = EnumInfo::new(variants);
135 let max_payload = enum_info.max_payload_fields();
136
137 let mut fields = Vec::with_capacity(1 + max_payload as usize);
139 let mut field_map = HashMap::with_capacity(1 + max_payload as usize);
140
141 fields.push(FieldDef::new("__variant", FieldType::I64, 0, 0));
143 field_map.insert("__variant".to_string(), 0);
144
145 for i in 0..max_payload {
147 let field_name = format!("__payload_{}", i);
148 let offset = 8 + (i as usize * 8);
149 fields.push(FieldDef::new(&field_name, FieldType::Any, offset, i + 1));
150 field_map.insert(field_name, i as usize + 1);
151 }
152
153 let data_size = 8 + (max_payload as usize * 8);
154
155 Self {
156 id,
157 name,
158 fields,
159 field_map,
160 data_size,
161 component_types: None,
162 field_sources: HashMap::new(),
163 enum_info: Some(enum_info),
164 content_hash: None,
165 }
166 }
167
168 pub fn compute_content_hash(&self) -> [u8; 32] {
178 let mut hasher = Sha256::new();
179
180 hasher.update(b"name:");
182 hasher.update(self.name.as_bytes());
183
184 let mut sorted_fields: Vec<&FieldDef> = self.fields.iter().collect();
186 sorted_fields.sort_by(|a, b| a.name.cmp(&b.name));
187
188 hasher.update(b"|fields:");
189 for field in &sorted_fields {
190 hasher.update(b"(");
191 hasher.update(field.name.as_bytes());
192 hasher.update(b":");
193 hasher.update(field.field_type.to_string().as_bytes());
194 hasher.update(b")");
195 }
196
197 if let Some(enum_info) = &self.enum_info {
199 let mut sorted_variants: Vec<&super::enum_support::EnumVariantInfo> =
200 enum_info.variants.iter().collect();
201 sorted_variants.sort_by(|a, b| a.name.cmp(&b.name));
202
203 hasher.update(b"|variants:");
204 for variant in &sorted_variants {
205 hasher.update(b"(");
206 hasher.update(variant.name.as_bytes());
207 hasher.update(b":");
208 hasher.update(variant.payload_fields.to_string().as_bytes());
209 hasher.update(b")");
210 }
211 }
212
213 let result = hasher.finalize();
214 let mut hash = [0u8; 32];
215 hash.copy_from_slice(&result);
216 hash
217 }
218
219 pub fn content_hash(&mut self) -> [u8; 32] {
221 if let Some(hash) = self.content_hash {
222 return hash;
223 }
224 let hash = self.compute_content_hash();
225 self.content_hash = Some(hash);
226 hash
227 }
228
229 pub fn bind_to_arrow_schema(
234 &self,
235 arrow_schema: &ArrowSchema,
236 ) -> Result<TypeBinding, TypeBindingError> {
237 let mut field_to_column = Vec::with_capacity(self.fields.len());
238
239 for field in &self.fields {
240 if field.name.starts_with("__") {
242 field_to_column.push(0); continue;
244 }
245
246 let col_name = field.wire_name();
247 let col_idx =
248 arrow_schema
249 .index_of(col_name)
250 .map_err(|_| TypeBindingError::MissingColumn {
251 field_name: col_name.to_string(),
252 type_name: self.name.clone(),
253 })?;
254
255 let arrow_field = &arrow_schema.fields()[col_idx];
256 if !is_compatible(&field.field_type, arrow_field.data_type()) {
257 return Err(TypeBindingError::TypeMismatch {
258 field_name: field.name.clone(),
259 expected: format!("{:?}", field.field_type),
260 actual: format!("{:?}", arrow_field.data_type()),
261 });
262 }
263
264 field_to_column.push(col_idx);
265 }
266
267 Ok(TypeBinding {
268 schema_name: self.name.clone(),
269 field_to_column,
270 })
271 }
272
273 pub fn from_canonical(canonical: &crate::type_system::environment::CanonicalType) -> Self {
278 let id = super::next_schema_id();
279 let name = canonical.name.clone();
280
281 let mut fields = Vec::with_capacity(canonical.fields.len());
282 let mut field_map = HashMap::with_capacity(canonical.fields.len());
283
284 for (index, cf) in canonical.fields.iter().enumerate() {
285 let field_type = semantic_to_field_type(&cf.field_type, cf.optional);
287
288 let field = FieldDef::new(&cf.name, field_type, cf.offset, index as u16);
289 field_map.insert(cf.name.clone(), index);
290 fields.push(field);
291 }
292
293 Self {
294 id,
295 name,
296 fields,
297 field_map,
298 data_size: canonical.data_size,
299 component_types: None,
300 field_sources: HashMap::new(),
301 enum_info: None,
302 content_hash: None,
303 }
304 }
305}
306
/// Result of binding a `TypeSchema` to an Arrow schema: for each schema field,
/// the index of the column it resolved to.
#[derive(Debug, Clone)]
pub struct TypeBinding {
    /// Name of the schema this binding was produced for.
    pub schema_name: String,
    /// Column index for each field, indexed by the field's position.
    pub field_to_column: Vec<usize>,
}

impl TypeBinding {
    /// Returns the column index bound to the field at `field_index`, or `None`
    /// when `field_index` is out of range.
    pub fn column_index(&self, field_index: usize) -> Option<usize> {
        let column = *self.field_to_column.get(field_index)?;
        Some(column)
    }
}
325
326#[derive(Debug, Clone, PartialEq, Eq, thiserror::Error)]
328pub enum TypeBindingError {
329 #[error("Type '{type_name}' requires column '{field_name}' which is not in the DataTable")]
331 MissingColumn {
332 field_name: String,
333 type_name: String,
334 },
335 #[error("Column '{field_name}' has type {actual} but expected {expected}")]
337 TypeMismatch {
338 field_name: String,
339 expected: String,
340 actual: String,
341 },
342}
343
344fn is_compatible(field_type: &FieldType, arrow_type: &DataType) -> bool {
346 match (field_type, arrow_type) {
347 (FieldType::F64, DataType::Float64) => true,
348 (FieldType::F64, DataType::Float32) => true, (FieldType::F64, DataType::Int64) => true, (FieldType::I64, DataType::Int64) => true,
351 (FieldType::I64, DataType::Int32) => true, (FieldType::Bool, DataType::Boolean) => true,
353 (FieldType::String, DataType::Utf8) => true,
354 (FieldType::String, DataType::LargeUtf8) => true,
355 (FieldType::Timestamp, DataType::Timestamp(_, _)) => true,
356 (FieldType::Timestamp, DataType::Int64) => true, (FieldType::Decimal, DataType::Float64) => true, (FieldType::Decimal, DataType::Int64) => true, (FieldType::Any, _) => true, _ => false,
361 }
362}
363
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for an owned `(name, type)` pair as taken by `TypeSchema::new`.
    fn named(name: &str, field_type: FieldType) -> (String, FieldType) {
        (name.to_string(), field_type)
    }

    /// A plain schema records its name, field count, and total size.
    #[test]
    fn test_type_schema_creation() {
        let defs = vec![
            named("a", FieldType::F64),
            named("b", FieldType::I64),
            named("c", FieldType::String),
        ];
        let built = TypeSchema::new("TestType", defs);

        assert_eq!(built.name, "TestType");
        assert_eq!(built.field_count(), 3);
        assert_eq!(built.data_size, 24);
    }

    /// Offsets follow declaration order; unknown names yield `None`.
    #[test]
    fn test_field_offsets() {
        let defs = vec![
            named("first", FieldType::F64),
            named("second", FieldType::I64),
            named("third", FieldType::Bool),
        ];
        let built = TypeSchema::new("OffsetTest", defs);

        assert_eq!(built.field_offset("first"), Some(0));
        assert_eq!(built.field_offset("second"), Some(8));
        assert_eq!(built.field_offset("third"), Some(16));
        assert_eq!(built.field_offset("nonexistent"), None);
    }

    /// Field indices match declaration positions.
    #[test]
    fn test_field_index() {
        let defs = vec![
            named("a", FieldType::F64),
            named("b", FieldType::F64),
            named("c", FieldType::F64),
        ];
        let built = TypeSchema::new("IndexTest", defs);

        assert_eq!(built.field_index("a"), Some(0));
        assert_eq!(built.field_index("b"), Some(1));
        assert_eq!(built.field_index("c"), Some(2));
    }

    /// Every constructed schema gets a distinct id.
    #[test]
    fn test_unique_schema_ids() {
        let first = TypeSchema::new("Type1", vec![]);
        let second = TypeSchema::new("Type2", vec![]);
        let third = TypeSchema::new("Type3", vec![]);

        assert_ne!(first.id, second.id);
        assert_ne!(second.id, third.id);
        assert_ne!(first.id, third.id);
    }

    /// An enum schema exposes its variants through the enum metadata.
    #[test]
    fn test_enum_schema_creation() {
        let variants = vec![
            EnumVariantInfo::new("Some", 0, 1),
            EnumVariantInfo::new("None", 1, 0),
        ];
        let built = TypeSchema::new_enum("Option", variants);

        assert_eq!(built.name, "Option");
        assert!(built.is_enum());

        let info = built.get_enum_info().unwrap();
        assert_eq!(info.variants.len(), 2);
        assert_eq!(info.variant_id("Some"), Some(0));
        assert_eq!(info.variant_id("None"), Some(1));
        assert_eq!(info.max_payload_fields(), 1);
    }

    /// One payload slot: variant tag at offset 0, payload at offset 8.
    #[test]
    fn test_enum_schema_layout() {
        let variants = vec![
            EnumVariantInfo::new("Ok", 0, 1),
            EnumVariantInfo::new("Err", 1, 1),
        ];
        let built = TypeSchema::new_enum("Result", variants);

        assert_eq!(built.data_size, 16);
        assert_eq!(built.field_count(), 2);

        assert_eq!(built.field_offset("__variant"), Some(0));
        assert_eq!(built.field_offset("__payload_0"), Some(8));
    }

    /// The widest variant determines how many payload slots are laid out.
    #[test]
    fn test_enum_schema_multiple_payloads() {
        let variants = vec![
            EnumVariantInfo::new("Circle", 0, 1),
            EnumVariantInfo::new("Rectangle", 1, 2),
            EnumVariantInfo::new("Point", 2, 0),
        ];
        let built = TypeSchema::new_enum("Shape", variants);

        assert_eq!(built.data_size, 24);
        assert_eq!(built.field_count(), 3);

        assert_eq!(built.field_offset("__variant"), Some(0));
        assert_eq!(built.field_offset("__payload_0"), Some(8));
        assert_eq!(built.field_offset("__payload_1"), Some(16));
    }

    /// Variants can be found by id and by name; unknown keys yield `None`.
    #[test]
    fn test_enum_variant_lookup() {
        let variants = vec![
            EnumVariantInfo::new("Pending", 0, 0),
            EnumVariantInfo::new("Running", 1, 1),
            EnumVariantInfo::new("Complete", 2, 1),
            EnumVariantInfo::new("Failed", 3, 1),
        ];
        let built = TypeSchema::new_enum("Status", variants);
        let info = built.get_enum_info().unwrap();

        let by_id = info.variant_by_id(1).unwrap();
        assert_eq!(by_id.name, "Running");
        assert_eq!(by_id.payload_fields, 1);

        let by_name = info.variant_by_name("Complete").unwrap();
        assert_eq!(by_name.id, 2);

        assert!(info.variant_by_id(99).is_none());
        assert!(info.variant_by_name("Unknown").is_none());
    }

    /// Fields bind to the matching columns even when the table has extra
    /// columns and a different column order.
    #[test]
    fn test_bind_to_arrow_schema_success() {
        use arrow_schema::{Field, Schema as ArrowSchema};

        let candle = TypeSchema::new(
            "Candle",
            vec![
                named("open", FieldType::F64),
                named("close", FieldType::F64),
                named("volume", FieldType::I64),
            ],
        );

        let columns = vec![
            Field::new("date", DataType::Utf8, false),
            Field::new("open", DataType::Float64, false),
            Field::new("close", DataType::Float64, false),
            Field::new("volume", DataType::Int64, false),
        ];
        let table = ArrowSchema::new(columns);

        let bound = candle.bind_to_arrow_schema(&table).unwrap();
        assert_eq!(bound.schema_name, "Candle");
        assert_eq!(bound.column_index(0), Some(1));
        assert_eq!(bound.column_index(1), Some(2));
        assert_eq!(bound.column_index(2), Some(3));
    }

    /// A field without a matching column is reported as `MissingColumn`.
    #[test]
    fn test_bind_missing_column() {
        use arrow_schema::{Field, Schema as ArrowSchema};

        let candle = TypeSchema::new(
            "Candle",
            vec![
                named("open", FieldType::F64),
                named("missing_field", FieldType::F64),
            ],
        );

        let table = ArrowSchema::new(vec![Field::new("open", DataType::Float64, false)]);

        let failure = candle.bind_to_arrow_schema(&table).unwrap_err();
        assert!(matches!(failure, TypeBindingError::MissingColumn { .. }));
    }

    /// A column with an incompatible Arrow type is reported as `TypeMismatch`.
    #[test]
    fn test_bind_type_mismatch() {
        use arrow_schema::{Field, Schema as ArrowSchema};

        let numeric = TypeSchema::new("Test", vec![named("name", FieldType::F64)]);

        let table = ArrowSchema::new(vec![Field::new("name", DataType::Utf8, false)]);

        let failure = numeric.bind_to_arrow_schema(&table).unwrap_err();
        assert!(matches!(failure, TypeBindingError::TypeMismatch { .. }));
    }

    /// Widening and catch-all pairings are accepted by the binder.
    #[test]
    fn test_bind_compatible_types() {
        use arrow_schema::{Field, Schema as ArrowSchema, TimeUnit};

        let wide = TypeSchema::new(
            "Wide",
            vec![
                named("f32_as_f64", FieldType::F64),
                named("i32_as_i64", FieldType::I64),
                named("ts", FieldType::Timestamp),
                named("any_field", FieldType::Any),
            ],
        );

        let columns = vec![
            Field::new("f32_as_f64", DataType::Float32, false),
            Field::new("i32_as_i64", DataType::Int32, false),
            Field::new(
                "ts",
                DataType::Timestamp(TimeUnit::Microsecond, None),
                false,
            ),
            Field::new("any_field", DataType::Boolean, false),
        ];
        let table = ArrowSchema::new(columns);

        let outcome = wide.bind_to_arrow_schema(&table);
        assert!(outcome.is_ok());
    }
}