vortex_scalar/
struct_.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::{Display, Formatter};
5use std::hash::{Hash, Hasher};
6use std::ops::Deref;
7use std::sync::Arc;
8
9use itertools::Itertools;
10use vortex_dtype::{DType, FieldName, FieldNames, StructFields};
11use vortex_error::{
12    VortexError, VortexExpect, VortexResult, vortex_bail, vortex_err, vortex_panic,
13};
14
15use crate::{InnerScalarValue, Scalar, ScalarValue};
16
/// A scalar value representing a struct with named fields.
///
/// This type provides a view into a struct scalar value, which can contain
/// named fields with different types, or be null.
///
/// Both members are borrowed for `'a`; the view owns neither the dtype nor the values.
#[derive(Debug)]
pub struct StructScalar<'a> {
    /// Data type of the scalar. `try_new` rejects anything that is not `DType::Struct`.
    dtype: &'a DType,
    /// Field values, indexed in the same order as the dtype's fields, or `None` when the
    /// struct itself is null.
    fields: Option<&'a Arc<[ScalarValue]>>,
}
26
27impl Display for StructScalar<'_> {
28    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
29        match &self.fields {
30            None => write!(f, "null"),
31            Some(fields) => {
32                write!(f, "{{")?;
33                let formatted_fields = self
34                    .names()
35                    .iter()
36                    .zip_eq(self.struct_fields().fields())
37                    .zip_eq(fields.iter())
38                    .map(|((name, dtype), value)| {
39                        let val = Scalar::new(dtype, value.clone());
40                        format!("{name}: {val}")
41                    })
42                    .format(", ");
43                write!(f, "{formatted_fields}")?;
44                write!(f, "}}")
45            }
46        }
47    }
48}
49
50impl PartialEq for StructScalar<'_> {
51    fn eq(&self, other: &Self) -> bool {
52        if !self.dtype.eq_ignore_nullability(other.dtype) {
53            return false;
54        }
55        self.fields() == other.fields()
56    }
57}
58
59impl Eq for StructScalar<'_> {}
60
61/// Ord is not implemented since it's undefined for different field DTypes
62impl PartialOrd for StructScalar<'_> {
63    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
64        if !self.dtype.eq_ignore_nullability(other.dtype) {
65            return None;
66        }
67        self.fields().partial_cmp(&other.fields())
68    }
69}
70
71impl Hash for StructScalar<'_> {
72    fn hash<H: Hasher>(&self, state: &mut H) {
73        self.dtype.hash(state);
74        self.fields().hash(state);
75    }
76}
77
78impl<'a> StructScalar<'a> {
79    pub(crate) fn try_new(dtype: &'a DType, value: &'a ScalarValue) -> VortexResult<Self> {
80        if !matches!(dtype, DType::Struct(..)) {
81            vortex_bail!("Expected struct scalar, found {}", dtype)
82        }
83        Ok(Self {
84            dtype,
85            fields: value.as_list()?,
86        })
87    }
88
    /// Returns the data type of this struct scalar.
    ///
    /// The reference carries the view's `'a` lifetime rather than the lifetime of `&self`.
    #[inline]
    pub fn dtype(&self) -> &'a DType {
        self.dtype
    }
94
95    /// Returns the struct field definitions.
96    #[inline]
97    pub fn struct_fields(&self) -> &StructFields {
98        self.dtype
99            .as_struct()
100            .vortex_expect("StructScalar always has struct dtype")
101    }
102
103    /// Returns the field names of the struct.
104    pub fn names(&self) -> &FieldNames {
105        self.struct_fields().names()
106    }
107
108    /// Returns true if the struct is null.
109    pub fn is_null(&self) -> bool {
110        self.fields.is_none()
111    }
112
113    /// Returns the field with the given name as a scalar.
114    ///
115    /// Returns None if the field doesn't exist.
116    pub fn field(&self, name: impl AsRef<str>) -> Option<Scalar> {
117        let idx = self.struct_fields().find(name)?;
118        self.field_by_idx(idx)
119    }
120
121    /// Returns the field at the given index as a scalar.
122    ///
123    /// Returns None if the index is out of bounds.
124    ///
125    /// # Panics
126    ///
127    /// Panics if the struct is null.
128    pub fn field_by_idx(&self, idx: usize) -> Option<Scalar> {
129        let fields = self
130            .fields
131            .vortex_expect("Can't take field out of null struct scalar");
132        Some(Scalar::new(
133            self.struct_fields().field_by_index(idx)?,
134            fields[idx].clone(),
135        ))
136    }
137
138    /// Returns the fields of the struct scalar, or None if the scalar is null.
139    pub fn fields(&self) -> Option<Vec<Scalar>> {
140        let fields = self.fields?;
141        Some(
142            (0..fields.len())
143                .map(|index| {
144                    self.field_by_idx(index)
145                        .vortex_expect("never out of bounds")
146                })
147                .collect::<Vec<_>>(),
148        )
149    }
150
151    pub(crate) fn field_values(&self) -> Option<&[ScalarValue]> {
152        self.fields.map(Arc::deref)
153    }
154
155    /// Casts this struct scalar to another struct type.
156    ///
157    /// # Errors
158    ///
159    /// Returns an error if the target type is not a struct or if the number of fields don't match.
160    pub fn cast(&self, dtype: &DType) -> VortexResult<Scalar> {
161        let DType::Struct(st, _) = dtype else {
162            vortex_bail!(
163                "Cannot cast struct to {}: struct can only be cast to struct",
164                dtype
165            )
166        };
167        let own_st = self.struct_fields();
168
169        if st.fields().len() != own_st.fields().len() {
170            vortex_bail!(
171                "Cannot cast between structs with different number of fields: {} and {}",
172                own_st.fields().len(),
173                st.fields().len()
174            );
175        }
176
177        if let Some(fs) = self.field_values() {
178            let fields = fs
179                .iter()
180                .enumerate()
181                .map(|(i, f)| {
182                    Scalar::new(
183                        own_st
184                            .field_by_index(i)
185                            .vortex_expect("Iterating over scalar fields"),
186                        f.clone(),
187                    )
188                    .cast(
189                        &st.field_by_index(i)
190                            .vortex_expect("Iterating over scalar fields"),
191                    )
192                    .map(|s| s.value)
193                })
194                .collect::<VortexResult<Vec<_>>>()?;
195            Ok(Scalar::new(
196                dtype.clone(),
197                ScalarValue(InnerScalarValue::List(fields.into())),
198            ))
199        } else {
200            Ok(Scalar::null(dtype.clone()))
201        }
202    }
203
204    /// Projects this struct scalar to include only the specified fields.
205    ///
206    /// # Errors
207    ///
208    /// Returns an error if the struct cannot be projected or if a field is not found.
209    pub fn project(&self, projection: &[FieldName]) -> VortexResult<Scalar> {
210        let struct_dtype = self
211            .dtype
212            .as_struct()
213            .ok_or_else(|| vortex_err!("Not a struct dtype"))?;
214        let projected_dtype = struct_dtype.project(projection)?;
215        let new_fields = if let Some(fs) = self.field_values() {
216            ScalarValue(InnerScalarValue::List(
217                projection
218                    .iter()
219                    .map(|name| {
220                        struct_dtype
221                            .find(name)
222                            .vortex_expect("DType has been successfully projected already")
223                    })
224                    .map(|i| fs[i].clone())
225                    .collect(),
226            ))
227        } else {
228            ScalarValue(InnerScalarValue::Null)
229        };
230        Ok(Scalar::new(
231            DType::Struct(projected_dtype, self.dtype().nullability()),
232            new_fields,
233        ))
234    }
235}
236
237impl Scalar {
238    /// Creates a new struct scalar with the given fields.
239    pub fn struct_(dtype: DType, children: Vec<Scalar>) -> Self {
240        let DType::Struct(struct_fields, _) = &dtype else {
241            vortex_panic!("Expected struct dtype, found {}", dtype);
242        };
243
244        let field_dtypes = struct_fields.fields();
245        if children.len() != field_dtypes.len() {
246            vortex_panic!(
247                "Struct has {} fields but {} children were provided",
248                field_dtypes.len(),
249                children.len()
250            );
251        }
252
253        for (idx, (child, expected_dtype)) in children.iter().zip(field_dtypes).enumerate() {
254            if child.dtype() != &expected_dtype {
255                vortex_panic!(
256                    "Field {} expected dtype {} but got {}",
257                    idx,
258                    expected_dtype,
259                    child.dtype()
260                );
261            }
262        }
263
264        Self::new(
265            dtype,
266            ScalarValue(InnerScalarValue::List(
267                children
268                    .into_iter()
269                    .map(|x| x.into_value())
270                    .collect_vec()
271                    .into(),
272            )),
273        )
274    }
275}
276
277impl<'a> TryFrom<&'a Scalar> for StructScalar<'a> {
278    type Error = VortexError;
279
280    fn try_from(value: &'a Scalar) -> Result<Self, Self::Error> {
281        Self::try_new(value.dtype(), &value.value)
282    }
283}
284
#[cfg(test)]
mod tests {
    use vortex_dtype::PType::I32;
    use vortex_dtype::{DType, Nullability, StructFields};

    use super::*;

    /// Returns `(i32 dtype, utf8 dtype, nullable struct {a: i32, b: utf8})`.
    fn setup_types() -> (DType, DType, DType) {
        let f0_dt = DType::Primitive(I32, Nullability::NonNullable);
        let f1_dt = DType::Utf8(Nullability::NonNullable);

        let dtype = DType::Struct(
            StructFields::new(
                vec!["a".into(), "b".into()].into(),
                vec![f0_dt.clone(), f1_dt.clone()],
            ),
            Nullability::Nullable,
        );

        (f0_dt, f1_dt, dtype)
    }

    // Accessing a field of a null struct must panic with the dedicated message, not with
    // some unrelated failure.
    #[test]
    #[should_panic(expected = "Can't take field out of null struct scalar")]
    fn test_struct_scalar_null() {
        let (_, _, dtype) = setup_types();

        let scalar = Scalar::null(dtype);

        scalar.as_struct().field_by_idx(0).unwrap();
    }

    // Fields come back with the dtype declared by the struct and compare equal to scalars
    // that differ only in nullability.
    #[test]
    fn test_struct_scalar_non_null() {
        let (f0_dt, f1_dt, dtype) = setup_types();

        let f0_val = Scalar::primitive::<i32>(1, Nullability::NonNullable);
        let f1_val = Scalar::utf8("hello", Nullability::NonNullable);

        let f0_val_null = Scalar::primitive::<i32>(1, Nullability::Nullable);
        let f1_val_null = Scalar::utf8("hello", Nullability::Nullable);

        let scalar = Scalar::struct_(dtype, vec![f0_val, f1_val]);

        let scalar_f0 = scalar.as_struct().field_by_idx(0);
        assert!(scalar_f0.is_some());
        let scalar_f0 = scalar_f0.unwrap();
        assert_eq!(scalar_f0, f0_val_null);
        assert_eq!(scalar_f0.dtype(), &f0_dt);

        let scalar_f1 = scalar.as_struct().field_by_idx(1);
        assert!(scalar_f1.is_some());
        let scalar_f1 = scalar_f1.unwrap();
        assert_eq!(scalar_f1, f1_val_null);
        assert_eq!(scalar_f1.dtype(), &f1_dt);
    }

    // `Scalar::struct_` rejects a non-struct dtype.
    #[test]
    #[should_panic(expected = "Expected struct dtype")]
    fn test_struct_scalar_wrong_dtype() {
        let dtype = DType::Primitive(I32, Nullability::NonNullable);
        let scalar = Scalar::primitive::<i32>(1, Nullability::NonNullable);

        Scalar::struct_(dtype, vec![scalar]);
    }

    // `Scalar::struct_` rejects a child count that differs from the field count.
    #[test]
    #[should_panic(expected = "Struct has 2 fields but 1 children were provided")]
    fn test_struct_scalar_wrong_child_count() {
        let (_, _, dtype) = setup_types();
        let f0_val = Scalar::primitive::<i32>(1, Nullability::NonNullable);

        Scalar::struct_(dtype, vec![f0_val]);
    }

    // `Scalar::struct_` rejects a child whose dtype differs from the declared field dtype.
    #[test]
    #[should_panic(expected = "Field 0 expected dtype i32 but got utf8")]
    fn test_struct_scalar_wrong_child_dtype() {
        let (_, _, dtype) = setup_types();
        let f0_val = Scalar::utf8("wrong", Nullability::NonNullable);
        let f1_val = Scalar::utf8("hello", Nullability::NonNullable);

        Scalar::struct_(dtype, vec![f0_val, f1_val]);
    }

    // Lookup by name returns the matching field, or None for an unknown name.
    #[test]
    fn test_struct_field_by_name() {
        let (_, _, dtype) = setup_types();
        let f0_val = Scalar::primitive::<i32>(42, Nullability::NonNullable);
        let f1_val = Scalar::utf8("world", Nullability::NonNullable);

        let scalar = Scalar::struct_(dtype, vec![f0_val, f1_val]);

        // Get field by name
        let field_a = scalar.as_struct().field("a");
        assert!(field_a.is_some());
        assert_eq!(
            field_a
                .unwrap()
                .as_primitive()
                .typed_value::<i32>()
                .unwrap(),
            42
        );

        let field_b = scalar.as_struct().field("b");
        assert!(field_b.is_some());
        assert_eq!(field_b.unwrap().as_utf8().value().unwrap(), "world".into());

        // Non-existent field
        let field_c = scalar.as_struct().field("c");
        assert!(field_c.is_none());
    }

    // `fields()` materializes every field in declaration order.
    #[test]
    fn test_struct_fields() {
        let (_, _, dtype) = setup_types();
        let f0_val = Scalar::primitive::<i32>(100, Nullability::NonNullable);
        let f1_val = Scalar::utf8("test", Nullability::NonNullable);

        let scalar = Scalar::struct_(dtype, vec![f0_val, f1_val]);

        let fields = scalar.as_struct().fields().unwrap();
        assert_eq!(fields.len(), 2);
        assert_eq!(fields[0].as_primitive().typed_value::<i32>().unwrap(), 100);
        assert_eq!(fields[1].as_utf8().value().unwrap(), "test".into());
    }

    // A null struct reports itself as null and exposes no fields.
    #[test]
    fn test_struct_null_fields() {
        let (_, _, dtype) = setup_types();
        let null_scalar = Scalar::null(dtype);

        assert!(null_scalar.as_struct().is_null());
        assert!(null_scalar.as_struct().fields().is_none());
        assert!(null_scalar.as_struct().field_values().is_none());
    }

    // Casting between structs casts each field to the target field dtype.
    #[test]
    fn test_struct_cast_to_struct() {
        // Create source struct
        let source_fields = StructFields::new(
            vec!["x".into(), "y".into()].into(),
            vec![
                DType::Primitive(I32, Nullability::NonNullable),
                DType::Primitive(I32, Nullability::NonNullable),
            ],
        );
        let source_dtype = DType::Struct(source_fields, Nullability::NonNullable);

        // Create target struct with different field types
        let target_fields = StructFields::new(
            vec!["x".into(), "y".into()].into(),
            vec![
                DType::Primitive(vortex_dtype::PType::I64, Nullability::NonNullable),
                DType::Primitive(vortex_dtype::PType::I64, Nullability::NonNullable),
            ],
        );
        let target_dtype = DType::Struct(target_fields, Nullability::NonNullable);

        let f0 = Scalar::primitive::<i32>(42, Nullability::NonNullable);
        let f1 = Scalar::primitive::<i32>(123, Nullability::NonNullable);
        let source_scalar = Scalar::struct_(source_dtype, vec![f0, f1]);

        // Cast to target type
        let result = source_scalar.as_struct().cast(&target_dtype).unwrap();
        assert_eq!(result.dtype(), &target_dtype);

        let fields = result.as_struct().fields().unwrap();
        assert_eq!(fields[0].as_primitive().typed_value::<i64>().unwrap(), 42);
        assert_eq!(fields[1].as_primitive().typed_value::<i64>().unwrap(), 123);
    }

    // Casting to a struct with a different number of fields is an error.
    #[test]
    fn test_struct_cast_mismatched_fields() {
        let source_fields = StructFields::new(
            vec!["a".into()].into(),
            vec![DType::Primitive(I32, Nullability::NonNullable)],
        );
        let source_dtype = DType::Struct(source_fields, Nullability::NonNullable);

        let target_fields = StructFields::new(
            vec!["a".into(), "b".into()].into(),
            vec![
                DType::Primitive(I32, Nullability::NonNullable),
                DType::Primitive(I32, Nullability::NonNullable),
            ],
        );
        let target_dtype = DType::Struct(target_fields, Nullability::NonNullable);

        let scalar = Scalar::struct_(
            source_dtype,
            vec![Scalar::primitive::<i32>(1, Nullability::NonNullable)],
        );

        let result = scalar.as_struct().cast(&target_dtype);
        assert!(result.is_err());
    }

    // Casting a struct to a non-struct dtype is an error.
    #[test]
    fn test_struct_cast_to_non_struct() {
        let (_, _, dtype) = setup_types();
        let scalar = Scalar::struct_(
            dtype,
            vec![
                Scalar::primitive::<i32>(1, Nullability::NonNullable),
                Scalar::utf8("test", Nullability::NonNullable),
            ],
        );

        let result = scalar
            .as_struct()
            .cast(&DType::Primitive(I32, Nullability::NonNullable));
        assert!(result.is_err());
    }

    // Projection keeps only the requested fields and their values.
    #[test]
    fn test_struct_project() {
        let (_, _, dtype) = setup_types();
        let f0_val = Scalar::primitive::<i32>(42, Nullability::NonNullable);
        let f1_val = Scalar::utf8("hello", Nullability::NonNullable);

        let scalar = Scalar::struct_(dtype, vec![f0_val, f1_val]);

        // Project to only field "b"
        let projected = scalar.as_struct().project(&["b".into()]).unwrap();
        let projected_struct = projected.as_struct();

        assert_eq!(projected_struct.names().len(), 1);
        assert_eq!(projected_struct.names()[0].as_ref(), "b");

        let fields = projected_struct.fields().unwrap();
        assert_eq!(fields.len(), 1);
        assert_eq!(fields[0].as_utf8().value().unwrap().as_str(), "hello");
    }

    // Projecting a null struct yields a null struct.
    #[test]
    fn test_struct_project_null() {
        let (_, _, dtype) = setup_types();
        let null_scalar = Scalar::null(dtype);

        let projected = null_scalar.as_struct().project(&["a".into()]).unwrap();
        assert!(projected.as_struct().is_null());
    }

    // Equality compares field values for structs of the same dtype.
    #[test]
    fn test_struct_equality() {
        let (_, _, dtype) = setup_types();

        let scalar1 = Scalar::struct_(
            dtype.clone(),
            vec![
                Scalar::primitive::<i32>(1, Nullability::NonNullable),
                Scalar::utf8("test", Nullability::NonNullable),
            ],
        );

        let scalar2 = Scalar::struct_(
            dtype.clone(),
            vec![
                Scalar::primitive::<i32>(1, Nullability::NonNullable),
                Scalar::utf8("test", Nullability::NonNullable),
            ],
        );

        let scalar3 = Scalar::struct_(
            dtype,
            vec![
                Scalar::primitive::<i32>(2, Nullability::NonNullable),
                Scalar::utf8("test", Nullability::NonNullable),
            ],
        );

        assert_eq!(scalar1.as_struct(), scalar2.as_struct());
        assert_ne!(scalar1.as_struct(), scalar3.as_struct());
    }

    // Ordering is defined for matching dtypes and undefined (None) otherwise.
    #[test]
    fn test_struct_partial_ord() {
        let (_, _, dtype) = setup_types();

        let scalar1 = Scalar::struct_(
            dtype.clone(),
            vec![
                Scalar::primitive::<i32>(1, Nullability::NonNullable),
                Scalar::utf8("a", Nullability::NonNullable),
            ],
        );

        let scalar2 = Scalar::struct_(
            dtype,
            vec![
                Scalar::primitive::<i32>(2, Nullability::NonNullable),
                Scalar::utf8("b", Nullability::NonNullable),
            ],
        );

        // Structs with same dtype can be compared
        assert!(scalar1.as_struct() < scalar2.as_struct());

        // Different struct types cannot be compared
        let other_dtype = DType::Struct(
            StructFields::new(
                vec!["c".into()].into(),
                vec![DType::Primitive(I32, Nullability::NonNullable)],
            ),
            Nullability::NonNullable,
        );
        let scalar3 = Scalar::struct_(
            other_dtype,
            vec![Scalar::primitive::<i32>(1, Nullability::NonNullable)],
        );

        assert_eq!(scalar1.as_struct().partial_cmp(&scalar3.as_struct()), None);
    }

    // Equal struct scalars must produce equal hashes.
    #[test]
    fn test_struct_hash() {
        use std::collections::hash_map::DefaultHasher;
        use std::hash::{Hash, Hasher};

        let (_, _, dtype) = setup_types();

        let scalar1 = Scalar::struct_(
            dtype.clone(),
            vec![
                Scalar::primitive::<i32>(1, Nullability::NonNullable),
                Scalar::utf8("test", Nullability::NonNullable),
            ],
        );

        let scalar2 = Scalar::struct_(
            dtype,
            vec![
                Scalar::primitive::<i32>(1, Nullability::NonNullable),
                Scalar::utf8("test", Nullability::NonNullable),
            ],
        );

        let mut hasher1 = DefaultHasher::new();
        scalar1.as_struct().hash(&mut hasher1);
        let hash1 = hasher1.finish();

        let mut hasher2 = DefaultHasher::new();
        scalar2.as_struct().hash(&mut hasher2);
        let hash2 = hasher2.finish();

        assert_eq!(hash1, hash2);
    }

    // `try_new` refuses to build a struct view over a non-struct dtype.
    #[test]
    fn test_struct_try_new_non_struct_dtype() {
        let dtype = DType::Primitive(I32, Nullability::NonNullable);
        let value = ScalarValue(InnerScalarValue::Primitive(crate::PValue::I32(42)));

        let result = StructScalar::try_new(&dtype, &value);
        assert!(result.is_err());
    }

    // An index past the last field yields None rather than panicking.
    #[test]
    fn test_struct_field_out_of_bounds() {
        let (_, _, dtype) = setup_types();
        let scalar = Scalar::struct_(
            dtype,
            vec![
                Scalar::primitive::<i32>(1, Nullability::NonNullable),
                Scalar::utf8("test", Nullability::NonNullable),
            ],
        );

        // Try to access field beyond bounds
        let field = scalar.as_struct().field_by_idx(10);
        assert!(field.is_none());
    }
}