Skip to main content

shape_runtime/type_schema/
schema.rs

1//! Core TypeSchema struct and methods
2//!
3//! This module defines the TypeSchema structure that describes the memory layout
4//! of a declared type, with computed field offsets for JIT optimization.
5
6use super::SchemaId;
7use super::enum_support::{EnumInfo, EnumVariantInfo};
8use super::field_types::{FieldDef, FieldType, semantic_to_field_type};
9use arrow_schema::{DataType, Schema as ArrowSchema};
10use sha2::{Digest, Sha256};
11use std::collections::HashMap;
12
/// Schema describing the memory layout of a declared type
///
/// Holds the ordered field definitions together with a name → index lookup
/// table, the total data size, and optional metadata for intersection and
/// enum types.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct TypeSchema {
    /// Unique schema identifier
    pub id: SchemaId,
    /// Type name (e.g., "Candle", "Trade")
    pub name: String,
    /// Field definitions with computed offsets
    pub fields: Vec<FieldDef>,
    /// Field lookup by name: maps a field name to its position in `fields`
    pub(crate) field_map: HashMap<String, usize>,
    /// Total size of the object data in bytes (excluding header)
    pub data_size: usize,
    /// Component type names (for intersection types, tracks which types were
    /// merged). The per-field mapping used for decomposition lives in
    /// `field_sources`, not here.
    pub component_types: Option<Vec<String>>,
    /// Maps each field to its source component type (for decomposition)
    pub(crate) field_sources: HashMap<String, String>,
    /// Enum-specific information (if this is an enum type)
    pub enum_info: Option<EnumInfo>,
    /// Content hash (SHA-256) derived from structural definition.
    /// Computed lazily and cached. Skipped during serialization since it is derived.
    #[serde(skip)]
    pub content_hash: Option<[u8; 32]>,
}
38
39impl TypeSchema {
40    /// Create a new type schema with the given fields
41    pub fn new(name: impl Into<String>, field_defs: Vec<(String, FieldType)>) -> Self {
42        let id = super::next_schema_id();
43        let name = name.into();
44
45        let mut fields = Vec::with_capacity(field_defs.len());
46        let mut field_map = HashMap::with_capacity(field_defs.len());
47        let mut offset = 0;
48
49        for (index, (field_name, field_type)) in field_defs.into_iter().enumerate() {
50            // Align offset to field's alignment requirement
51            let alignment = field_type.alignment();
52            offset = (offset + alignment - 1) & !(alignment - 1);
53
54            let field = FieldDef::new(&field_name, field_type.clone(), offset, index as u16);
55            field_map.insert(field_name, index);
56            offset += field_type.size();
57            fields.push(field);
58        }
59
60        // Round up total size to 8-byte alignment
61        let data_size = (offset + 7) & !7;
62
63        Self {
64            id,
65            name,
66            fields,
67            field_map,
68            data_size,
69            component_types: None,
70            field_sources: HashMap::new(),
71            enum_info: None,
72            content_hash: None,
73        }
74    }
75
76    /// Get field definition by name
77    pub fn get_field(&self, name: &str) -> Option<&FieldDef> {
78        self.field_map.get(name).map(|&idx| &self.fields[idx])
79    }
80
81    /// Get field offset by name (returns None if field doesn't exist)
82    pub fn field_offset(&self, name: &str) -> Option<usize> {
83        self.get_field(name).map(|f| f.offset)
84    }
85
86    /// Get field index by name
87    pub fn field_index(&self, name: &str) -> Option<u16> {
88        self.get_field(name).map(|f| f.index)
89    }
90
91    /// Get field by index
92    pub fn field_by_index(&self, index: u16) -> Option<&FieldDef> {
93        self.fields.get(index as usize)
94    }
95
    /// Number of fields in this schema.
    ///
    /// This is the length of `fields`, so for schemas built by `new_enum` it
    /// includes the `__variant` discriminator and every `__payload_N` slot.
    pub fn field_count(&self) -> usize {
        self.fields.len()
    }
100
    /// Check if schema has a field with the given name.
    ///
    /// Only the name lookup table is consulted; `fields` itself is not scanned.
    pub fn has_field(&self, name: &str) -> bool {
        self.field_map.contains_key(name)
    }
105
106    /// Iterator over field names
107    pub fn field_names(&self) -> impl Iterator<Item = &str> {
108        self.fields.iter().map(|f| f.name.as_str())
109    }
110
    /// Check if this schema is for an enum type.
    ///
    /// True exactly when `enum_info` is populated.
    pub fn is_enum(&self) -> bool {
        self.enum_info.is_some()
    }
115
    /// Get enum info if this is an enum type.
    ///
    /// Returns a borrow of `enum_info`, or `None` when it is not set.
    pub fn get_enum_info(&self) -> Option<&EnumInfo> {
        self.enum_info.as_ref()
    }
120
121    /// Get variant ID by name (for enum types)
122    pub fn variant_id(&self, variant_name: &str) -> Option<u16> {
123        self.enum_info.as_ref()?.variant_id(variant_name)
124    }
125
126    /// Create an enum schema with variant information
127    ///
128    /// Layout:
129    /// - Field 0: __variant (I64) - variant discriminator at offset 0
130    /// - Field 1+: __payload_N (Any) - payload fields at offset 8, 16, etc.
131    pub fn new_enum(name: impl Into<String>, variants: Vec<EnumVariantInfo>) -> Self {
132        let id = super::next_schema_id();
133        let name = name.into();
134        let enum_info = EnumInfo::new(variants);
135        let max_payload = enum_info.max_payload_fields();
136
137        // Build fields: __variant + __payload_0..N
138        let mut fields = Vec::with_capacity(1 + max_payload as usize);
139        let mut field_map = HashMap::with_capacity(1 + max_payload as usize);
140
141        // Variant discriminator at offset 0
142        fields.push(FieldDef::new("__variant", FieldType::I64, 0, 0));
143        field_map.insert("__variant".to_string(), 0);
144
145        // Payload fields at offsets 8, 16, etc.
146        for i in 0..max_payload {
147            let field_name = format!("__payload_{}", i);
148            let offset = 8 + (i as usize * 8);
149            fields.push(FieldDef::new(&field_name, FieldType::Any, offset, i + 1));
150            field_map.insert(field_name, i as usize + 1);
151        }
152
153        let data_size = 8 + (max_payload as usize * 8);
154
155        Self {
156            id,
157            name,
158            fields,
159            field_map,
160            data_size,
161            component_types: None,
162            field_sources: HashMap::new(),
163            enum_info: Some(enum_info),
164            content_hash: None,
165        }
166    }
167
168    /// Compute the content hash (SHA-256) from the structural definition.
169    ///
170    /// The hash is derived deterministically from:
171    /// - The type name
172    /// - Fields sorted by name, each contributing field name + field type string
173    /// - Enum variant info (if present), sorted by variant name
174    ///
175    /// For recursive type references (`Object("Foo")`), only the type name is
176    /// hashed to avoid infinite recursion.
177    pub fn compute_content_hash(&self) -> [u8; 32] {
178        let mut hasher = Sha256::new();
179
180        // Hash the type name
181        hasher.update(b"name:");
182        hasher.update(self.name.as_bytes());
183
184        // Hash fields in deterministic order (sorted by name)
185        let mut sorted_fields: Vec<&FieldDef> = self.fields.iter().collect();
186        sorted_fields.sort_by(|a, b| a.name.cmp(&b.name));
187
188        hasher.update(b"|fields:");
189        for field in &sorted_fields {
190            hasher.update(b"(");
191            hasher.update(field.name.as_bytes());
192            hasher.update(b":");
193            hasher.update(field.field_type.to_string().as_bytes());
194            hasher.update(b")");
195        }
196
197        // Hash enum variant info if present
198        if let Some(enum_info) = &self.enum_info {
199            let mut sorted_variants: Vec<&super::enum_support::EnumVariantInfo> =
200                enum_info.variants.iter().collect();
201            sorted_variants.sort_by(|a, b| a.name.cmp(&b.name));
202
203            hasher.update(b"|variants:");
204            for variant in &sorted_variants {
205                hasher.update(b"(");
206                hasher.update(variant.name.as_bytes());
207                hasher.update(b":");
208                hasher.update(variant.payload_fields.to_string().as_bytes());
209                hasher.update(b")");
210            }
211        }
212
213        let result = hasher.finalize();
214        let mut hash = [0u8; 32];
215        hash.copy_from_slice(&result);
216        hash
217    }
218
219    /// Return the cached content hash, computing and caching it if needed.
220    pub fn content_hash(&mut self) -> [u8; 32] {
221        if let Some(hash) = self.content_hash {
222            return hash;
223        }
224        let hash = self.compute_content_hash();
225        self.content_hash = Some(hash);
226        hash
227    }
228
229    /// Bind this TypeSchema to an Arrow schema, producing a TypeBinding.
230    ///
231    /// Validates that every field in the TypeSchema has a compatible column in the
232    /// Arrow schema. Returns a mapping from TypeSchema field index → Arrow column index.
233    pub fn bind_to_arrow_schema(
234        &self,
235        arrow_schema: &ArrowSchema,
236    ) -> Result<TypeBinding, TypeBindingError> {
237        let mut field_to_column = Vec::with_capacity(self.fields.len());
238
239        for field in &self.fields {
240            // Skip internal enum fields
241            if field.name.starts_with("__") {
242                field_to_column.push(0); // placeholder
243                continue;
244            }
245
246            let col_name = field.wire_name();
247            let col_idx =
248                arrow_schema
249                    .index_of(col_name)
250                    .map_err(|_| TypeBindingError::MissingColumn {
251                        field_name: col_name.to_string(),
252                        type_name: self.name.clone(),
253                    })?;
254
255            let arrow_field = &arrow_schema.fields()[col_idx];
256            if !is_compatible(&field.field_type, arrow_field.data_type()) {
257                return Err(TypeBindingError::TypeMismatch {
258                    field_name: field.name.clone(),
259                    expected: format!("{:?}", field.field_type),
260                    actual: format!("{:?}", arrow_field.data_type()),
261                });
262            }
263
264            field_to_column.push(col_idx);
265        }
266
267        Ok(TypeBinding {
268            schema_name: self.name.clone(),
269            field_to_column,
270        })
271    }
272
273    /// Create a type schema from a canonical type (for evolved types)
274    ///
275    /// This converts the semantic CanonicalType representation into a JIT-ready
276    /// TypeSchema with proper field offsets and types.
277    pub fn from_canonical(canonical: &crate::type_system::environment::CanonicalType) -> Self {
278        let id = super::next_schema_id();
279        let name = canonical.name.clone();
280
281        let mut fields = Vec::with_capacity(canonical.fields.len());
282        let mut field_map = HashMap::with_capacity(canonical.fields.len());
283
284        for (index, cf) in canonical.fields.iter().enumerate() {
285            // Convert SemanticType to FieldType
286            let field_type = semantic_to_field_type(&cf.field_type, cf.optional);
287
288            let field = FieldDef::new(&cf.name, field_type, cf.offset, index as u16);
289            field_map.insert(cf.name.clone(), index);
290            fields.push(field);
291        }
292
293        Self {
294            id,
295            name,
296            fields,
297            field_map,
298            data_size: canonical.data_size,
299            component_types: None,
300            field_sources: HashMap::new(),
301            enum_info: None,
302            content_hash: None,
303        }
304    }
305}
306
/// Resolved mapping from TypeSchema field indices to Arrow column indices.
///
/// Gives constant-time field→column resolution when DataTable columns are
/// accessed through a typed view.
#[derive(Debug, Clone)]
pub struct TypeBinding {
    /// Name of the type this binding was produced for.
    pub schema_name: String,
    /// Entry `i` is the Arrow column index bound to TypeSchema field `i`.
    pub field_to_column: Vec<usize>,
}

impl TypeBinding {
    /// Arrow column index bound to the TypeSchema field at `field_index`.
    ///
    /// Returns `None` when `field_index` is past the end of the mapping.
    pub fn column_index(&self, field_index: usize) -> Option<usize> {
        let column = *self.field_to_column.get(field_index)?;
        Some(column)
    }
}
325
/// Error during type binding to Arrow schema.
///
/// Produced by `TypeSchema::bind_to_arrow_schema` when a schema field cannot
/// be matched to a usable Arrow column.
#[derive(Debug, Clone, PartialEq, Eq, thiserror::Error)]
pub enum TypeBindingError {
    /// Arrow schema is missing a column required by the TypeSchema.
    #[error("Type '{type_name}' requires column '{field_name}' which is not in the DataTable")]
    MissingColumn {
        /// Name of the column that was looked up and not found.
        field_name: String,
        /// Name of the TypeSchema that requires the column.
        type_name: String,
    },
    /// Arrow column type is incompatible with the TypeSchema field type.
    #[error("Column '{field_name}' has type {actual} but expected {expected}")]
    TypeMismatch {
        /// Name identifying the field/column whose types disagree.
        field_name: String,
        /// Debug rendering of the TypeSchema field type.
        expected: String,
        /// Debug rendering of the Arrow column's data type.
        actual: String,
    },
}
343
344/// Check if a Shape FieldType is compatible with an Arrow DataType.
345fn is_compatible(field_type: &FieldType, arrow_type: &DataType) -> bool {
346    match (field_type, arrow_type) {
347        (FieldType::F64, DataType::Float64) => true,
348        (FieldType::F64, DataType::Float32) => true, // widening is ok
349        (FieldType::F64, DataType::Int64) => true,   // numeric promotion
350        (FieldType::I64, DataType::Int64) => true,
351        (FieldType::I64, DataType::Int32) => true, // widening is ok
352        (FieldType::Bool, DataType::Boolean) => true,
353        (FieldType::String, DataType::Utf8) => true,
354        (FieldType::String, DataType::LargeUtf8) => true,
355        (FieldType::Timestamp, DataType::Timestamp(_, _)) => true,
356        (FieldType::Timestamp, DataType::Int64) => true, // timestamps are i64 internally
357        (FieldType::Decimal, DataType::Float64) => true, // Decimal stored as f64
358        (FieldType::Decimal, DataType::Int64) => true,   // numeric promotion
359        (FieldType::Any, _) => true,                     // Any matches everything
360        _ => false,
361    }
362}
363
#[cfg(test)]
mod tests {
    use super::*;
    use arrow_schema::{Field, TimeUnit};

    /// Turn borrowed `(name, type)` pairs into the owned list `TypeSchema::new` takes.
    fn owned_fields(pairs: &[(&str, FieldType)]) -> Vec<(String, FieldType)> {
        pairs
            .iter()
            .map(|(name, field_type)| ((*name).to_owned(), field_type.clone()))
            .collect()
    }

    // ----------------------------------------------------------------------
    // Struct schema tests
    // ----------------------------------------------------------------------

    #[test]
    fn test_type_schema_creation() {
        let layout = owned_fields(&[
            ("a", FieldType::F64),
            ("b", FieldType::I64),
            ("c", FieldType::String),
        ]);
        let schema = TypeSchema::new("TestType", layout);

        assert_eq!(schema.name, "TestType");
        assert_eq!(schema.field_count(), 3);
        // 3 * 8 bytes
        assert_eq!(schema.data_size, 24);
    }

    #[test]
    fn test_field_offsets() {
        let layout = owned_fields(&[
            ("first", FieldType::F64),
            ("second", FieldType::I64),
            ("third", FieldType::Bool),
        ]);
        let schema = TypeSchema::new("OffsetTest", layout);

        let expectations = [
            ("first", Some(0)),
            ("second", Some(8)),
            ("third", Some(16)),
            ("nonexistent", None),
        ];
        for (field_name, expected_offset) in expectations {
            assert_eq!(
                schema.field_offset(field_name),
                expected_offset,
                "offset of field {}",
                field_name
            );
        }
    }

    #[test]
    fn test_field_index() {
        let layout = owned_fields(&[
            ("a", FieldType::F64),
            ("b", FieldType::F64),
            ("c", FieldType::F64),
        ]);
        let schema = TypeSchema::new("IndexTest", layout);

        for (expected_index, field_name) in ["a", "b", "c"].into_iter().enumerate() {
            assert_eq!(
                schema.field_index(field_name),
                Some(expected_index as u16),
                "index of field {}",
                field_name
            );
        }
    }

    #[test]
    fn test_unique_schema_ids() {
        let first = TypeSchema::new("Type1", vec![]);
        let second = TypeSchema::new("Type2", vec![]);
        let third = TypeSchema::new("Type3", vec![]);

        // Every pair of schemas must have distinct ids.
        assert_ne!(first.id, second.id);
        assert_ne!(second.id, third.id);
        assert_ne!(first.id, third.id);
    }

    // ----------------------------------------------------------------------
    // Enum schema tests
    // ----------------------------------------------------------------------

    #[test]
    fn test_enum_schema_creation() {
        let variants = vec![
            EnumVariantInfo::new("Some", 0, 1),
            EnumVariantInfo::new("None", 1, 0),
        ];
        let schema = TypeSchema::new_enum("Option", variants);

        assert_eq!(schema.name, "Option");
        assert!(schema.is_enum());

        // Variant metadata is carried on the schema.
        let info = schema.get_enum_info().unwrap();
        assert_eq!(info.variants.len(), 2);
        assert_eq!(info.variant_id("Some"), Some(0));
        assert_eq!(info.variant_id("None"), Some(1));
        assert_eq!(info.max_payload_fields(), 1);
    }

    #[test]
    fn test_enum_schema_layout() {
        let variants = vec![
            EnumVariantInfo::new("Ok", 0, 1),
            EnumVariantInfo::new("Err", 1, 1),
        ];
        let schema = TypeSchema::new_enum("Result", variants);

        // __variant (8 bytes) + __payload_0 (8 bytes) = 16 bytes
        assert_eq!(schema.data_size, 16);
        assert_eq!(schema.field_count(), 2);

        assert_eq!(schema.field_offset("__variant"), Some(0));
        assert_eq!(schema.field_offset("__payload_0"), Some(8));
    }

    #[test]
    fn test_enum_schema_multiple_payloads() {
        // Variants carry different payload counts.
        let variants = vec![
            EnumVariantInfo::new("Circle", 0, 1),    // radius only
            EnumVariantInfo::new("Rectangle", 1, 2), // width, height
            EnumVariantInfo::new("Point", 2, 0),     // no payload
        ];
        let schema = TypeSchema::new_enum("Shape", variants);

        // The layout is sized for the largest payload (2 fields):
        // __variant (8) + __payload_0 (8) + __payload_1 (8) = 24 bytes
        assert_eq!(schema.data_size, 24);
        assert_eq!(schema.field_count(), 3);

        let expectations = [("__variant", 0), ("__payload_0", 8), ("__payload_1", 16)];
        for (field_name, expected_offset) in expectations {
            assert_eq!(
                schema.field_offset(field_name),
                Some(expected_offset),
                "offset of field {}",
                field_name
            );
        }
    }

    #[test]
    fn test_enum_variant_lookup() {
        let variants = vec![
            EnumVariantInfo::new("Pending", 0, 0),
            EnumVariantInfo::new("Running", 1, 1),
            EnumVariantInfo::new("Complete", 2, 1),
            EnumVariantInfo::new("Failed", 3, 1),
        ];
        let schema = TypeSchema::new_enum("Status", variants);
        let info = schema.get_enum_info().unwrap();

        // By id.
        let by_id = info.variant_by_id(1).unwrap();
        assert_eq!(by_id.name, "Running");
        assert_eq!(by_id.payload_fields, 1);

        // By name.
        let by_name = info.variant_by_name("Complete").unwrap();
        assert_eq!(by_name.id, 2);

        // Unknown id / name.
        assert!(info.variant_by_id(99).is_none());
        assert!(info.variant_by_name("Unknown").is_none());
    }

    // ----------------------------------------------------------------------
    // TypeBinding tests
    // ----------------------------------------------------------------------

    #[test]
    fn test_bind_to_arrow_schema_success() {
        let type_schema = TypeSchema::new(
            "Candle",
            owned_fields(&[
                ("open", FieldType::F64),
                ("close", FieldType::F64),
                ("volume", FieldType::I64),
            ]),
        );

        let arrow_schema = ArrowSchema::new(vec![
            Field::new("date", DataType::Utf8, false),
            Field::new("open", DataType::Float64, false),
            Field::new("close", DataType::Float64, false),
            Field::new("volume", DataType::Int64, false),
        ]);

        let binding = type_schema.bind_to_arrow_schema(&arrow_schema).unwrap();
        assert_eq!(binding.schema_name, "Candle");

        // "date" occupies Arrow column 0, so "open"/"close"/"volume"
        // (fields 0, 1, 2) are bound to columns 1, 2, 3.
        for field_index in 0..3 {
            assert_eq!(binding.column_index(field_index), Some(field_index + 1));
        }
    }

    #[test]
    fn test_bind_missing_column() {
        let type_schema = TypeSchema::new(
            "Candle",
            owned_fields(&[
                ("open", FieldType::F64),
                ("missing_field", FieldType::F64),
            ]),
        );

        let arrow_schema = ArrowSchema::new(vec![Field::new("open", DataType::Float64, false)]);

        let outcome = type_schema.bind_to_arrow_schema(&arrow_schema);
        assert!(matches!(
            outcome.unwrap_err(),
            TypeBindingError::MissingColumn { .. }
        ));
    }

    #[test]
    fn test_bind_type_mismatch() {
        let type_schema = TypeSchema::new("Test", owned_fields(&[("name", FieldType::F64)]));

        // The column is a string, not a Float64.
        let arrow_schema = ArrowSchema::new(vec![Field::new("name", DataType::Utf8, false)]);

        let outcome = type_schema.bind_to_arrow_schema(&arrow_schema);
        assert!(matches!(
            outcome.unwrap_err(),
            TypeBindingError::TypeMismatch { .. }
        ));
    }

    #[test]
    fn test_bind_compatible_types() {
        // Exercises the widening and promotion rules.
        let type_schema = TypeSchema::new(
            "Wide",
            owned_fields(&[
                ("f32_as_f64", FieldType::F64),
                ("i32_as_i64", FieldType::I64),
                ("ts", FieldType::Timestamp),
                ("any_field", FieldType::Any),
            ]),
        );

        let timestamp_type = DataType::Timestamp(TimeUnit::Microsecond, None);
        let arrow_schema = ArrowSchema::new(vec![
            Field::new("f32_as_f64", DataType::Float32, false),
            Field::new("i32_as_i64", DataType::Int32, false),
            Field::new("ts", timestamp_type, false),
            Field::new("any_field", DataType::Boolean, false),
        ]);

        let binding = type_schema.bind_to_arrow_schema(&arrow_schema);
        assert!(binding.is_ok());
    }
}