vortex_scalar/
struct_.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::cmp::Ordering;
5use std::fmt::{Display, Formatter};
6use std::hash::{Hash, Hasher};
7use std::ops::Deref;
8use std::sync::Arc;
9
10use itertools::Itertools;
11use vortex_dtype::{DType, FieldName, FieldNames, StructFields};
12use vortex_error::{
13    VortexError, VortexExpect, VortexResult, vortex_bail, vortex_err, vortex_panic,
14};
15
16use crate::{InnerScalarValue, Scalar, ScalarValue};
17
18/// A scalar value representing a struct with named fields.
19///
20/// This type provides a view into a struct scalar value, which can contain
21/// named fields with different types, or be null.
22#[derive(Debug, Clone)]
23pub struct StructScalar<'a> {
24    dtype: &'a DType,
25    fields: Option<&'a Arc<[ScalarValue]>>,
26}
27
28impl Display for StructScalar<'_> {
29    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
30        match &self.fields {
31            None => write!(f, "null"),
32            Some(fields) => {
33                write!(f, "{{")?;
34                let formatted_fields = self
35                    .names()
36                    .iter()
37                    .zip_eq(self.struct_fields().fields())
38                    .zip_eq(fields.iter())
39                    .map(|((name, dtype), value)| {
40                        let val = Scalar::new(dtype, value.clone());
41                        format!("{name}: {val}")
42                    })
43                    .format(", ");
44                write!(f, "{formatted_fields}")?;
45                write!(f, "}}")
46            }
47        }
48    }
49}
50
51impl PartialEq for StructScalar<'_> {
52    fn eq(&self, other: &Self) -> bool {
53        if !self.dtype.eq_ignore_nullability(other.dtype) {
54            return false;
55        }
56
57        match (self.fields(), other.fields()) {
58            (Some(lhs), Some(rhs)) => lhs.zip(rhs).all(|(l_s, r_s)| l_s == r_s),
59            (None, None) => true,
60            (Some(_), None) | (None, Some(_)) => false,
61        }
62    }
63}
64
65impl Eq for StructScalar<'_> {}
66
67/// Ord is not implemented since it's undefined for different field DTypes
68impl PartialOrd for StructScalar<'_> {
69    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
70        if !self.dtype.eq_ignore_nullability(other.dtype) {
71            return None;
72        }
73
74        match (self.fields(), other.fields()) {
75            (Some(lhs), Some(rhs)) => {
76                for (l_s, r_s) in lhs.zip(rhs) {
77                    match l_s.partial_cmp(&r_s)? {
78                        Ordering::Equal => continue,
79                        Ordering::Less => return Some(Ordering::Less),
80                        Ordering::Greater => return Some(Ordering::Greater),
81                    }
82                }
83            }
84            (None, None) => return Some(Ordering::Equal),
85            (Some(_), None) => return Some(Ordering::Greater),
86            (None, Some(_)) => return Some(Ordering::Less),
87        }
88
89        Some(Ordering::Equal)
90    }
91}
92
93impl Hash for StructScalar<'_> {
94    fn hash<H: Hasher>(&self, state: &mut H) {
95        self.dtype.hash(state);
96        if let Some(fields) = self.fields() {
97            for f in fields {
98                f.hash(state);
99            }
100        }
101    }
102}
103
104impl<'a> StructScalar<'a> {
105    #[inline]
106    pub(crate) fn try_new(dtype: &'a DType, value: &'a ScalarValue) -> VortexResult<Self> {
107        if !matches!(dtype, DType::Struct(..)) {
108            vortex_bail!("Expected struct scalar, found {}", dtype)
109        }
110
111        Ok(Self {
112            dtype,
113            fields: value.as_list()?,
114        })
115    }
116
117    /// Returns the data type of this struct scalar.
118    #[inline]
119    pub fn dtype(&self) -> &'a DType {
120        self.dtype
121    }
122
123    /// Returns the struct field definitions.
124    #[inline]
125    pub fn struct_fields(&self) -> &StructFields {
126        self.dtype
127            .as_struct_fields_opt()
128            .vortex_expect("StructScalar always has struct dtype")
129    }
130
131    /// Returns the field names of the struct.
132    pub fn names(&self) -> &FieldNames {
133        self.struct_fields().names()
134    }
135
136    /// Returns true if the struct is null.
137    pub fn is_null(&self) -> bool {
138        self.fields.is_none()
139    }
140
141    /// Returns the field with the given name as a scalar.
142    ///
143    /// Returns None if the field doesn't exist.
144    pub fn field(&self, name: impl AsRef<str>) -> Option<Scalar> {
145        let idx = self.struct_fields().find(name)?;
146        self.field_by_idx(idx)
147    }
148
149    /// Returns the field at the given index as a scalar.
150    ///
151    /// Returns None if the index is out of bounds.
152    ///
153    /// # Panics
154    ///
155    /// Panics if the struct is null.
156    pub fn field_by_idx(&self, idx: usize) -> Option<Scalar> {
157        let fields = self
158            .fields
159            .vortex_expect("Can't take field out of null struct scalar");
160        Some(Scalar::new(
161            self.struct_fields().field_by_index(idx)?,
162            fields[idx].clone(),
163        ))
164    }
165
166    /// Returns the fields of the struct scalar, or None if the scalar is null.
167    pub fn fields(&self) -> Option<impl ExactSizeIterator<Item = Scalar>> {
168        let fields = self.fields?;
169        Some(
170            fields
171                .iter()
172                .zip(self.struct_fields().fields())
173                .map(|(v, dtype)| Scalar::new(dtype, v.clone())),
174        )
175    }
176
177    pub(crate) fn field_values(&self) -> Option<&[ScalarValue]> {
178        self.fields.map(Arc::deref)
179    }
180
181    /// Casts this struct scalar to another struct type.
182    ///
183    /// # Errors
184    ///
185    /// Returns an error if the target type is not a struct or if the number of fields don't match.
186    pub fn cast(&self, dtype: &DType) -> VortexResult<Scalar> {
187        let DType::Struct(st, _) = dtype else {
188            vortex_bail!(
189                "Cannot cast struct to {}: struct can only be cast to struct",
190                dtype
191            )
192        };
193        let own_st = self.struct_fields();
194
195        if st.fields().len() != own_st.fields().len() {
196            vortex_bail!(
197                "Cannot cast between structs with different number of fields: {} and {}",
198                own_st.fields().len(),
199                st.fields().len()
200            );
201        }
202
203        if let Some(fs) = self.field_values() {
204            let fields = fs
205                .iter()
206                .enumerate()
207                .map(|(i, f)| {
208                    Scalar::new(
209                        own_st
210                            .field_by_index(i)
211                            .vortex_expect("Iterating over scalar fields"),
212                        f.clone(),
213                    )
214                    .cast(
215                        &st.field_by_index(i)
216                            .vortex_expect("Iterating over scalar fields"),
217                    )
218                    .map(|s| s.into_value())
219                })
220                .collect::<VortexResult<Vec<_>>>()?;
221            Ok(Scalar::new(
222                dtype.clone(),
223                ScalarValue(InnerScalarValue::List(fields.into())),
224            ))
225        } else {
226            Ok(Scalar::null(dtype.clone()))
227        }
228    }
229
230    /// Projects this struct scalar to include only the specified fields.
231    ///
232    /// # Errors
233    ///
234    /// Returns an error if the struct cannot be projected or if a field is not found.
235    pub fn project(&self, projection: &[FieldName]) -> VortexResult<Scalar> {
236        let struct_dtype = self
237            .dtype
238            .as_struct_fields_opt()
239            .ok_or_else(|| vortex_err!("Not a struct dtype"))?;
240        let projected_dtype = struct_dtype.project(projection)?;
241        let new_fields = if let Some(fs) = self.field_values() {
242            ScalarValue(InnerScalarValue::List(
243                projection
244                    .iter()
245                    .map(|name| {
246                        struct_dtype
247                            .find(name)
248                            .vortex_expect("DType has been successfully projected already")
249                    })
250                    .map(|i| fs[i].clone())
251                    .collect(),
252            ))
253        } else {
254            ScalarValue(InnerScalarValue::Null)
255        };
256        Ok(Scalar::new(
257            DType::Struct(projected_dtype, self.dtype().nullability()),
258            new_fields,
259        ))
260    }
261}
262
263impl Scalar {
264    /// Creates a new struct scalar with the given fields.
265    pub fn struct_(dtype: DType, children: Vec<Scalar>) -> Self {
266        let DType::Struct(struct_fields, _) = &dtype else {
267            vortex_panic!("Expected struct dtype, found {}", dtype);
268        };
269
270        let field_dtypes = struct_fields.fields();
271        if children.len() != field_dtypes.len() {
272            vortex_panic!(
273                "Struct has {} fields but {} children were provided",
274                field_dtypes.len(),
275                children.len()
276            );
277        }
278
279        for (idx, (child, expected_dtype)) in children.iter().zip(field_dtypes).enumerate() {
280            if child.dtype() != &expected_dtype {
281                vortex_panic!(
282                    "Field {} expected dtype {} but got {}",
283                    idx,
284                    expected_dtype,
285                    child.dtype()
286                );
287            }
288        }
289
290        let mut value_children = Vec::with_capacity(children.len());
291        value_children.extend(children.into_iter().map(|x| x.into_value()));
292
293        Self::new(
294            dtype,
295            ScalarValue(InnerScalarValue::List(value_children.into())),
296        )
297    }
298}
299
300impl<'a> TryFrom<&'a Scalar> for StructScalar<'a> {
301    type Error = VortexError;
302
303    fn try_from(value: &'a Scalar) -> Result<Self, Self::Error> {
304        Self::try_new(value.dtype(), value.value())
305    }
306}
307
#[cfg(test)]
mod tests {
    use vortex_dtype::PType::I32;
    use vortex_dtype::{DType, Nullability, StructFields};

    use super::*;

    /// Non-nullable `i32` scalar.
    fn int(value: i32) -> Scalar {
        Scalar::primitive::<i32>(value, Nullability::NonNullable)
    }

    /// Non-nullable UTF-8 scalar.
    fn text(value: &'static str) -> Scalar {
        Scalar::utf8(value, Nullability::NonNullable)
    }

    /// Non-nullable `i32` dtype.
    fn i32_dtype() -> DType {
        DType::Primitive(I32, Nullability::NonNullable)
    }

    /// Returns the dtypes of field `a` (i32) and field `b` (utf8), followed by
    /// the nullable struct dtype `{a, b}` built from them.
    fn setup_types() -> (DType, DType, DType) {
        let int_dtype = i32_dtype();
        let str_dtype = DType::Utf8(Nullability::NonNullable);

        let struct_dtype = DType::Struct(
            StructFields::new(
                ["a", "b"].into(),
                vec![int_dtype.clone(), str_dtype.clone()],
            ),
            Nullability::Nullable,
        );

        (int_dtype, str_dtype, struct_dtype)
    }

    /// Builds an `{a, b}` struct scalar with the given field values.
    fn ab_struct(a: i32, b: &'static str) -> Scalar {
        let (_, _, struct_dtype) = setup_types();
        Scalar::struct_(struct_dtype, vec![int(a), text(b)])
    }

    #[test]
    #[should_panic]
    fn test_struct_scalar_null() {
        let (_, _, struct_dtype) = setup_types();
        let null_struct = Scalar::null(struct_dtype);

        // Reading a field out of a null struct must panic.
        null_struct.as_struct().field_by_idx(0).unwrap();
    }

    #[test]
    fn test_struct_scalar_non_null() {
        let (int_dtype, str_dtype, _) = setup_types();

        let nullable_int = Scalar::primitive::<i32>(1, Nullability::Nullable);
        let nullable_text = Scalar::utf8("hello", Nullability::Nullable);

        let scalar = ab_struct(1, "hello");

        let first = scalar.as_struct().field_by_idx(0);
        assert!(first.is_some());
        let first = first.unwrap();
        assert_eq!(first, nullable_int);
        assert_eq!(first.dtype(), &int_dtype);

        let second = scalar.as_struct().field_by_idx(1);
        assert!(second.is_some());
        let second = second.unwrap();
        assert_eq!(second, nullable_text);
        assert_eq!(second.dtype(), &str_dtype);
    }

    #[test]
    #[should_panic(expected = "Expected struct dtype")]
    fn test_struct_scalar_wrong_dtype() {
        Scalar::struct_(i32_dtype(), vec![int(1)]);
    }

    #[test]
    #[should_panic(expected = "Struct has 2 fields but 1 children were provided")]
    fn test_struct_scalar_wrong_child_count() {
        let (_, _, struct_dtype) = setup_types();

        Scalar::struct_(struct_dtype, vec![int(1)]);
    }

    #[test]
    #[should_panic(expected = "Field 0 expected dtype i32 but got utf8")]
    fn test_struct_scalar_wrong_child_dtype() {
        let (_, _, struct_dtype) = setup_types();

        Scalar::struct_(struct_dtype, vec![text("wrong"), text("hello")]);
    }

    #[test]
    fn test_struct_field_by_name() {
        let scalar = ab_struct(42, "world");

        // Existing fields are found by name.
        let by_a = scalar.as_struct().field("a");
        assert!(by_a.is_some());
        assert_eq!(
            by_a.unwrap().as_primitive().typed_value::<i32>().unwrap(),
            42
        );

        let by_b = scalar.as_struct().field("b");
        assert!(by_b.is_some());
        assert_eq!(by_b.unwrap().as_utf8().value().unwrap(), "world".into());

        // An unknown name yields nothing.
        let by_c = scalar.as_struct().field("c");
        assert!(by_c.is_none());
    }

    #[test]
    fn test_struct_fields() {
        let scalar = ab_struct(100, "test");

        let children = scalar.as_struct().fields().unwrap().collect::<Vec<_>>();
        assert_eq!(children.len(), 2);
        assert_eq!(
            children[0].as_primitive().typed_value::<i32>().unwrap(),
            100
        );
        assert_eq!(children[1].as_utf8().value().unwrap(), "test".into());
    }

    #[test]
    fn test_struct_null_fields() {
        let (_, _, struct_dtype) = setup_types();
        let null_struct = Scalar::null(struct_dtype);

        assert!(null_struct.as_struct().is_null());
        assert!(null_struct.as_struct().fields().is_none());
        assert!(null_struct.as_struct().field_values().is_none());
    }

    #[test]
    fn test_struct_cast_to_struct() {
        // Source struct: two non-nullable i32 fields.
        let source_dtype = DType::Struct(
            StructFields::new(["x", "y"].into(), vec![i32_dtype(), i32_dtype()]),
            Nullability::NonNullable,
        );

        // Target struct: same names, widened to i64.
        let i64_dtype = DType::Primitive(vortex_dtype::PType::I64, Nullability::NonNullable);
        let target_dtype = DType::Struct(
            StructFields::new(["x", "y"].into(), vec![i64_dtype.clone(), i64_dtype]),
            Nullability::NonNullable,
        );

        let source = Scalar::struct_(source_dtype, vec![int(42), int(123)]);

        let casted = source.as_struct().cast(&target_dtype).unwrap();
        assert_eq!(casted.dtype(), &target_dtype);

        let children = casted.as_struct().fields().unwrap().collect::<Vec<_>>();
        assert_eq!(children[0].as_primitive().typed_value::<i64>().unwrap(), 42);
        assert_eq!(
            children[1].as_primitive().typed_value::<i64>().unwrap(),
            123
        );
    }

    #[test]
    fn test_struct_cast_mismatched_fields() {
        let source_dtype = DType::Struct(
            StructFields::new(["a"].into(), vec![i32_dtype()]),
            Nullability::NonNullable,
        );
        let target_dtype = DType::Struct(
            StructFields::new(["a", "b"].into(), vec![i32_dtype(), i32_dtype()]),
            Nullability::NonNullable,
        );

        let source = Scalar::struct_(source_dtype, vec![int(1)]);

        // One field cannot be cast onto two.
        let outcome = source.as_struct().cast(&target_dtype);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_struct_cast_to_non_struct() {
        let scalar = ab_struct(1, "test");

        let outcome = scalar.as_struct().cast(&i32_dtype());
        assert!(outcome.is_err());
    }

    #[test]
    fn test_struct_project() {
        let scalar = ab_struct(42, "hello");

        // Keep only field "b".
        let projected = scalar.as_struct().project(&["b".into()]).unwrap();
        let view = projected.as_struct();

        assert_eq!(view.names().len(), 1);
        assert_eq!(view.names()[0].as_ref(), "b");

        let children = view.fields().unwrap().collect::<Vec<_>>();
        assert_eq!(children.len(), 1);
        assert_eq!(children[0].as_utf8().value().unwrap().as_str(), "hello");
    }

    #[test]
    fn test_struct_project_null() {
        let (_, _, struct_dtype) = setup_types();
        let null_struct = Scalar::null(struct_dtype);

        let projected = null_struct.as_struct().project(&["a".into()]).unwrap();
        assert!(projected.as_struct().is_null());
    }

    #[test]
    fn test_struct_equality() {
        let first = ab_struct(1, "test");
        let same = ab_struct(1, "test");
        let different = ab_struct(2, "test");

        assert_eq!(first.as_struct(), same.as_struct());
        assert_ne!(first.as_struct(), different.as_struct());
    }

    #[test]
    fn test_struct_partial_ord() {
        let smaller = ab_struct(1, "a");
        let larger = ab_struct(2, "b");

        // Structs of the same dtype are ordered field by field.
        assert!(smaller.as_struct() < larger.as_struct());

        // Structs of different dtypes are incomparable.
        let unrelated_dtype = DType::Struct(
            StructFields::new(["c"].into(), vec![i32_dtype()]),
            Nullability::NonNullable,
        );
        let unrelated = Scalar::struct_(unrelated_dtype, vec![int(1)]);

        assert_eq!(smaller.as_struct().partial_cmp(&unrelated.as_struct()), None);
    }

    #[test]
    fn test_struct_hash() {
        use std::collections::hash_map::DefaultHasher;
        use std::hash::{Hash, Hasher};

        /// Hashes a struct view with a fresh `DefaultHasher`.
        fn hash_of(value: &StructScalar<'_>) -> u64 {
            let mut hasher = DefaultHasher::new();
            value.hash(&mut hasher);
            hasher.finish()
        }

        let left = ab_struct(1, "test");
        let right = ab_struct(1, "test");

        assert_eq!(hash_of(&left.as_struct()), hash_of(&right.as_struct()));
    }

    #[test]
    fn test_struct_try_new_non_struct_dtype() {
        let non_struct = i32_dtype();
        let value = ScalarValue(InnerScalarValue::Primitive(crate::PValue::I32(42)));

        let outcome = StructScalar::try_new(&non_struct, &value);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_struct_field_out_of_bounds() {
        let scalar = ab_struct(1, "test");

        // An index past the last field yields nothing.
        let missing = scalar.as_struct().field_by_idx(10);
        assert!(missing.is_none());
    }
}