vortex_scalar/
struct_.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::cmp::Ordering;
5use std::fmt::Display;
6use std::fmt::Formatter;
7use std::hash::Hash;
8use std::hash::Hasher;
9use std::ops::Deref;
10use std::sync::Arc;
11
12use itertools::Itertools;
13use vortex_dtype::DType;
14use vortex_dtype::FieldName;
15use vortex_dtype::FieldNames;
16use vortex_dtype::StructFields;
17use vortex_error::VortexError;
18use vortex_error::VortexExpect;
19use vortex_error::VortexResult;
20use vortex_error::vortex_bail;
21use vortex_error::vortex_err;
22use vortex_error::vortex_panic;
23
24use crate::InnerScalarValue;
25use crate::Scalar;
26use crate::ScalarValue;
27
28/// A scalar value representing a struct with named fields.
29///
30/// This type provides a view into a struct scalar value, which can contain
31/// named fields with different types, or be null.
32#[derive(Debug, Clone)]
33pub struct StructScalar<'a> {
34    dtype: &'a DType,
35    fields: Option<&'a Arc<[ScalarValue]>>,
36}
37
38impl Display for StructScalar<'_> {
39    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
40        match &self.fields {
41            None => write!(f, "null"),
42            Some(fields) => {
43                write!(f, "{{")?;
44                let formatted_fields = self
45                    .names()
46                    .iter()
47                    .zip_eq(self.struct_fields().fields())
48                    .zip_eq(fields.iter())
49                    .map(|((name, dtype), value)| {
50                        let val = Scalar::new(dtype, value.clone());
51                        format!("{name}: {val}")
52                    })
53                    .format(", ");
54                write!(f, "{formatted_fields}")?;
55                write!(f, "}}")
56            }
57        }
58    }
59}
60
61impl PartialEq for StructScalar<'_> {
62    fn eq(&self, other: &Self) -> bool {
63        if !self.dtype.eq_ignore_nullability(other.dtype) {
64            return false;
65        }
66
67        match (self.fields(), other.fields()) {
68            (Some(lhs), Some(rhs)) => lhs.zip(rhs).all(|(l_s, r_s)| l_s == r_s),
69            (None, None) => true,
70            (Some(_), None) | (None, Some(_)) => false,
71        }
72    }
73}
74
75impl Eq for StructScalar<'_> {}
76
77/// Ord is not implemented since it's undefined for different field DTypes
78impl PartialOrd for StructScalar<'_> {
79    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
80        if !self.dtype.eq_ignore_nullability(other.dtype) {
81            return None;
82        }
83
84        match (self.fields(), other.fields()) {
85            (Some(lhs), Some(rhs)) => {
86                for (l_s, r_s) in lhs.zip(rhs) {
87                    match l_s.partial_cmp(&r_s)? {
88                        Ordering::Equal => continue,
89                        Ordering::Less => return Some(Ordering::Less),
90                        Ordering::Greater => return Some(Ordering::Greater),
91                    }
92                }
93            }
94            (None, None) => return Some(Ordering::Equal),
95            (Some(_), None) => return Some(Ordering::Greater),
96            (None, Some(_)) => return Some(Ordering::Less),
97        }
98
99        Some(Ordering::Equal)
100    }
101}
102
103impl Hash for StructScalar<'_> {
104    fn hash<H: Hasher>(&self, state: &mut H) {
105        self.dtype.hash(state);
106        if let Some(fields) = self.fields() {
107            for f in fields {
108                f.hash(state);
109            }
110        }
111    }
112}
113
114impl<'a> StructScalar<'a> {
115    #[inline]
116    pub(crate) fn try_new(dtype: &'a DType, value: &'a ScalarValue) -> VortexResult<Self> {
117        if !matches!(dtype, DType::Struct(..)) {
118            vortex_bail!("Expected struct scalar, found {}", dtype)
119        }
120
121        Ok(Self {
122            dtype,
123            fields: value.as_list()?,
124        })
125    }
126
127    /// Returns the data type of this struct scalar.
128    #[inline]
129    pub fn dtype(&self) -> &'a DType {
130        self.dtype
131    }
132
133    /// Returns the struct field definitions.
134    #[inline]
135    pub fn struct_fields(&self) -> &StructFields {
136        self.dtype
137            .as_struct_fields_opt()
138            .vortex_expect("StructScalar always has struct dtype")
139    }
140
141    /// Returns the field names of the struct.
142    pub fn names(&self) -> &FieldNames {
143        self.struct_fields().names()
144    }
145
146    /// Returns true if the struct is null.
147    pub fn is_null(&self) -> bool {
148        self.fields.is_none()
149    }
150
151    /// Returns the field with the given name as a scalar.
152    ///
153    /// Returns None if the field doesn't exist.
154    pub fn field(&self, name: impl AsRef<str>) -> Option<Scalar> {
155        let idx = self.struct_fields().find(name)?;
156        self.field_by_idx(idx)
157    }
158
159    /// Returns the field at the given index as a scalar.
160    ///
161    /// Returns None if the index is out of bounds.
162    ///
163    /// # Panics
164    ///
165    /// Panics if the struct is null.
166    pub fn field_by_idx(&self, idx: usize) -> Option<Scalar> {
167        let fields = self
168            .fields
169            .vortex_expect("Can't take field out of null struct scalar");
170        Some(Scalar::new(
171            self.struct_fields().field_by_index(idx)?,
172            fields[idx].clone(),
173        ))
174    }
175
176    /// Returns the fields of the struct scalar, or None if the scalar is null.
177    pub fn fields(&self) -> Option<impl ExactSizeIterator<Item = Scalar>> {
178        let fields = self.fields?;
179        Some(
180            fields
181                .iter()
182                .zip(self.struct_fields().fields())
183                .map(|(v, dtype)| Scalar::new(dtype, v.clone())),
184        )
185    }
186
187    pub(crate) fn field_values(&self) -> Option<&[ScalarValue]> {
188        self.fields.map(Arc::deref)
189    }
190
191    /// Casts this struct scalar to another struct type.
192    ///
193    /// # Errors
194    ///
195    /// Returns an error if the target type is not a struct or if the number of fields don't match.
196    pub fn cast(&self, dtype: &DType) -> VortexResult<Scalar> {
197        let DType::Struct(st, _) = dtype else {
198            vortex_bail!(
199                "Cannot cast struct to {}: struct can only be cast to struct",
200                dtype
201            )
202        };
203        let own_st = self.struct_fields();
204
205        if st.fields().len() != own_st.fields().len() {
206            vortex_bail!(
207                "Cannot cast between structs with different number of fields: {} and {}",
208                own_st.fields().len(),
209                st.fields().len()
210            );
211        }
212
213        if let Some(fs) = self.field_values() {
214            let fields = fs
215                .iter()
216                .enumerate()
217                .map(|(i, f)| {
218                    Scalar::new(
219                        own_st
220                            .field_by_index(i)
221                            .vortex_expect("Iterating over scalar fields"),
222                        f.clone(),
223                    )
224                    .cast(
225                        &st.field_by_index(i)
226                            .vortex_expect("Iterating over scalar fields"),
227                    )
228                    .map(|s| s.into_value())
229                })
230                .collect::<VortexResult<Vec<_>>>()?;
231            Ok(Scalar::new(
232                dtype.clone(),
233                ScalarValue(InnerScalarValue::List(fields.into())),
234            ))
235        } else {
236            Ok(Scalar::null(dtype.clone()))
237        }
238    }
239
240    /// Projects this struct scalar to include only the specified fields.
241    ///
242    /// # Errors
243    ///
244    /// Returns an error if the struct cannot be projected or if a field is not found.
245    pub fn project(&self, projection: &[FieldName]) -> VortexResult<Scalar> {
246        let struct_dtype = self
247            .dtype
248            .as_struct_fields_opt()
249            .ok_or_else(|| vortex_err!("Not a struct dtype"))?;
250        let projected_dtype = struct_dtype.project(projection)?;
251        let new_fields = if let Some(fs) = self.field_values() {
252            ScalarValue(InnerScalarValue::List(
253                projection
254                    .iter()
255                    .map(|name| {
256                        struct_dtype
257                            .find(name)
258                            .vortex_expect("DType has been successfully projected already")
259                    })
260                    .map(|i| fs[i].clone())
261                    .collect(),
262            ))
263        } else {
264            ScalarValue(InnerScalarValue::Null)
265        };
266        Ok(Scalar::new(
267            DType::Struct(projected_dtype, self.dtype().nullability()),
268            new_fields,
269        ))
270    }
271}
272
273impl Scalar {
274    /// Creates a new struct scalar with the given fields.
275    pub fn struct_(dtype: DType, children: Vec<Scalar>) -> Self {
276        let DType::Struct(struct_fields, _) = &dtype else {
277            vortex_panic!("Expected struct dtype, found {}", dtype);
278        };
279
280        let field_dtypes = struct_fields.fields();
281        if children.len() != field_dtypes.len() {
282            vortex_panic!(
283                "Struct has {} fields but {} children were provided",
284                field_dtypes.len(),
285                children.len()
286            );
287        }
288
289        for (idx, (child, expected_dtype)) in children.iter().zip(field_dtypes).enumerate() {
290            if child.dtype() != &expected_dtype {
291                vortex_panic!(
292                    "Field {} expected dtype {} but got {}",
293                    idx,
294                    expected_dtype,
295                    child.dtype()
296                );
297            }
298        }
299
300        let mut value_children = Vec::with_capacity(children.len());
301        value_children.extend(children.into_iter().map(|x| x.into_value()));
302
303        Self::new(
304            dtype,
305            ScalarValue(InnerScalarValue::List(value_children.into())),
306        )
307    }
308}
309
310impl<'a> TryFrom<&'a Scalar> for StructScalar<'a> {
311    type Error = VortexError;
312
313    fn try_from(value: &'a Scalar) -> Result<Self, Self::Error> {
314        Self::try_new(value.dtype(), value.value())
315    }
316}
317
#[cfg(test)]
mod tests {
    use vortex_dtype::DType;
    use vortex_dtype::Nullability;
    use vortex_dtype::PType::I32;
    use vortex_dtype::StructFields;

    use super::*;

    /// Non-nullable `i32` dtype shared by the fixtures below.
    fn i32_dtype() -> DType {
        DType::Primitive(I32, Nullability::NonNullable)
    }

    /// Returns the dtypes of fields `a` (i32) and `b` (utf8), followed by the nullable struct
    /// dtype `{a, b}` built from them.
    fn setup_types() -> (DType, DType, DType) {
        let f0_dt = i32_dtype();
        let f1_dt = DType::Utf8(Nullability::NonNullable);

        let dtype = DType::Struct(
            StructFields::new(["a", "b"].into(), vec![f0_dt.clone(), f1_dt.clone()]),
            Nullability::Nullable,
        );

        (f0_dt, f1_dt, dtype)
    }

    /// Non-nullable `i32` scalar.
    fn int(value: i32) -> Scalar {
        Scalar::primitive::<i32>(value, Nullability::NonNullable)
    }

    /// Non-nullable utf8 scalar.
    fn text(value: &'static str) -> Scalar {
        Scalar::utf8(value, Nullability::NonNullable)
    }

    /// Struct scalar `{a, b}` using the dtype from `setup_types`.
    fn sample_struct(a: i32, b: &'static str) -> Scalar {
        let (_, _, dtype) = setup_types();
        Scalar::struct_(dtype, vec![int(a), text(b)])
    }

    #[test]
    #[should_panic]
    fn test_struct_scalar_null() {
        let (_, _, dtype) = setup_types();
        let null_struct = Scalar::null(dtype);

        // Taking a field out of a null struct must panic.
        null_struct.as_struct().field_by_idx(0).unwrap();
    }

    #[test]
    fn test_struct_scalar_non_null() {
        let (f0_dt, f1_dt, _) = setup_types();
        let scalar = sample_struct(1, "hello");
        let view = scalar.as_struct();

        // The expected values are nullable on purpose: the comparison must still succeed
        // against the non-nullable fields.
        let expected_f0 = Scalar::primitive::<i32>(1, Nullability::Nullable);
        let expected_f1 = Scalar::utf8("hello", Nullability::Nullable);

        let actual_f0 = view.field_by_idx(0).unwrap();
        assert_eq!(actual_f0, expected_f0);
        assert_eq!(actual_f0.dtype(), &f0_dt);

        let actual_f1 = view.field_by_idx(1).unwrap();
        assert_eq!(actual_f1, expected_f1);
        assert_eq!(actual_f1.dtype(), &f1_dt);
    }

    #[test]
    #[should_panic(expected = "Expected struct dtype")]
    fn test_struct_scalar_wrong_dtype() {
        Scalar::struct_(i32_dtype(), vec![int(1)]);
    }

    #[test]
    #[should_panic(expected = "Struct has 2 fields but 1 children were provided")]
    fn test_struct_scalar_wrong_child_count() {
        let (_, _, dtype) = setup_types();

        Scalar::struct_(dtype, vec![int(1)]);
    }

    #[test]
    #[should_panic(expected = "Field 0 expected dtype i32 but got utf8")]
    fn test_struct_scalar_wrong_child_dtype() {
        let (_, _, dtype) = setup_types();

        Scalar::struct_(dtype, vec![text("wrong"), text("hello")]);
    }

    #[test]
    fn test_struct_field_by_name() {
        let scalar = sample_struct(42, "world");
        let view = scalar.as_struct();

        // Existing fields are found by name.
        let field_a = view.field("a");
        assert!(field_a.is_some());
        assert_eq!(
            field_a
                .unwrap()
                .as_primitive()
                .typed_value::<i32>()
                .unwrap(),
            42
        );

        let field_b = view.field("b");
        assert!(field_b.is_some());
        assert_eq!(field_b.unwrap().as_utf8().value().unwrap(), "world".into());

        // An unknown name yields `None`.
        assert!(view.field("c").is_none());
    }

    #[test]
    fn test_struct_fields() {
        let scalar = sample_struct(100, "test");
        let view = scalar.as_struct();

        let fields: Vec<_> = view.fields().unwrap().collect();
        assert_eq!(fields.len(), 2);
        assert_eq!(fields[0].as_primitive().typed_value::<i32>().unwrap(), 100);
        assert_eq!(fields[1].as_utf8().value().unwrap(), "test".into());
    }

    #[test]
    fn test_struct_null_fields() {
        let (_, _, dtype) = setup_types();
        let null_scalar = Scalar::null(dtype);
        let view = null_scalar.as_struct();

        assert!(view.is_null());
        assert!(view.fields().is_none());
        assert!(view.field_values().is_none());
    }

    #[test]
    fn test_struct_cast_to_struct() {
        // Non-nullable struct `{x, y}` whose two fields share the given primitive type.
        let xy_struct = |ptype: vortex_dtype::PType| {
            DType::Struct(
                StructFields::new(
                    ["x", "y"].into(),
                    vec![DType::Primitive(ptype, Nullability::NonNullable); 2],
                ),
                Nullability::NonNullable,
            )
        };
        let source_dtype = xy_struct(I32);
        let target_dtype = xy_struct(vortex_dtype::PType::I64);

        let source_scalar = Scalar::struct_(source_dtype, vec![int(42), int(123)]);

        // Widen both fields from i32 to i64.
        let result = source_scalar.as_struct().cast(&target_dtype).unwrap();
        assert_eq!(result.dtype(), &target_dtype);

        let fields: Vec<_> = result.as_struct().fields().unwrap().collect();
        assert_eq!(fields[0].as_primitive().typed_value::<i64>().unwrap(), 42);
        assert_eq!(fields[1].as_primitive().typed_value::<i64>().unwrap(), 123);
    }

    #[test]
    fn test_struct_cast_mismatched_fields() {
        let source_dtype = DType::Struct(
            StructFields::new(["a"].into(), vec![i32_dtype()]),
            Nullability::NonNullable,
        );
        let target_dtype = DType::Struct(
            StructFields::new(["a", "b"].into(), vec![i32_dtype(), i32_dtype()]),
            Nullability::NonNullable,
        );

        let scalar = Scalar::struct_(source_dtype, vec![int(1)]);

        // One field cannot be cast to two.
        assert!(scalar.as_struct().cast(&target_dtype).is_err());
    }

    #[test]
    fn test_struct_cast_to_non_struct() {
        let scalar = sample_struct(1, "test");

        let result = scalar.as_struct().cast(&i32_dtype());
        assert!(result.is_err());
    }

    #[test]
    fn test_struct_project() {
        let scalar = sample_struct(42, "hello");

        // Keep only field "b".
        let projected = scalar.as_struct().project(&["b".into()]).unwrap();
        let projected_struct = projected.as_struct();

        assert_eq!(projected_struct.names().len(), 1);
        assert_eq!(projected_struct.names()[0].as_ref(), "b");

        let fields: Vec<_> = projected_struct.fields().unwrap().collect();
        assert_eq!(fields.len(), 1);
        assert_eq!(fields[0].as_utf8().value().unwrap().as_str(), "hello");
    }

    #[test]
    fn test_struct_project_null() {
        let (_, _, dtype) = setup_types();
        let null_scalar = Scalar::null(dtype);

        let projected = null_scalar.as_struct().project(&["a".into()]).unwrap();
        assert!(projected.as_struct().is_null());
    }

    #[test]
    fn test_struct_equality() {
        let first = sample_struct(1, "test");
        let same_as_first = sample_struct(1, "test");
        let different = sample_struct(2, "test");

        assert_eq!(first.as_struct(), same_as_first.as_struct());
        assert_ne!(first.as_struct(), different.as_struct());
    }

    #[test]
    fn test_struct_partial_ord() {
        let smaller = sample_struct(1, "a");
        let larger = sample_struct(2, "b");

        // Structs with the same dtype can be compared.
        assert!(smaller.as_struct() < larger.as_struct());

        // Structs of different types are incomparable.
        let other_dtype = DType::Struct(
            StructFields::new(["c"].into(), vec![i32_dtype()]),
            Nullability::NonNullable,
        );
        let unrelated = Scalar::struct_(other_dtype, vec![int(1)]);

        assert_eq!(smaller.as_struct().partial_cmp(&unrelated.as_struct()), None);
    }

    #[test]
    fn test_struct_hash() {
        use std::collections::hash_map::DefaultHasher;
        use std::hash::Hash;
        use std::hash::Hasher;

        // Hash of the struct view of `scalar` under a fresh default hasher.
        let digest = |scalar: &Scalar| {
            let mut hasher = DefaultHasher::new();
            scalar.as_struct().hash(&mut hasher);
            hasher.finish()
        };

        let first = sample_struct(1, "test");
        let second = sample_struct(1, "test");

        assert_eq!(digest(&first), digest(&second));
    }

    #[test]
    fn test_struct_try_new_non_struct_dtype() {
        let dtype = i32_dtype();
        let value = ScalarValue(InnerScalarValue::Primitive(crate::PValue::I32(42)));

        assert!(StructScalar::try_new(&dtype, &value).is_err());
    }

    #[test]
    fn test_struct_field_out_of_bounds() {
        let scalar = sample_struct(1, "test");

        // Index far beyond the two available fields.
        let field = scalar.as_struct().field_by_idx(10);
        assert!(field.is_none());
    }
}