vortex_scalar/
struct_.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::{Display, Formatter};
5use std::hash::{Hash, Hasher};
6use std::ops::Deref;
7use std::sync::Arc;
8
9use itertools::Itertools;
10use vortex_dtype::{DType, FieldName, FieldNames, StructFields};
11use vortex_error::{
12    VortexError, VortexExpect, VortexResult, vortex_bail, vortex_err, vortex_panic,
13};
14
15use crate::{InnerScalarValue, Scalar, ScalarValue};
16
17/// A scalar value representing a struct with named fields.
18///
19/// This type provides a view into a struct scalar value, which can contain
20/// named fields with different types, or be null.
21pub struct StructScalar<'a> {
22    dtype: &'a DType,
23    fields: Option<&'a Arc<[ScalarValue]>>,
24}
25
26impl Display for StructScalar<'_> {
27    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
28        match &self.fields {
29            None => write!(f, "null"),
30            Some(fields) => {
31                write!(f, "{{")?;
32                let formatted_fields = self
33                    .names()
34                    .iter()
35                    .zip_eq(self.struct_fields().fields())
36                    .zip_eq(fields.iter())
37                    .map(|((name, dtype), value)| {
38                        let val = Scalar::new(dtype, value.clone());
39                        format!("{name}: {val}")
40                    })
41                    .format(", ");
42                write!(f, "{formatted_fields}")?;
43                write!(f, "}}")
44            }
45        }
46    }
47}
48
49impl PartialEq for StructScalar<'_> {
50    fn eq(&self, other: &Self) -> bool {
51        if !self.dtype.eq_ignore_nullability(other.dtype) {
52            return false;
53        }
54        self.fields() == other.fields()
55    }
56}
57
58impl Eq for StructScalar<'_> {}
59
60/// Ord is not implemented since it's undefined for different field DTypes
61impl PartialOrd for StructScalar<'_> {
62    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
63        if !self.dtype.eq_ignore_nullability(other.dtype) {
64            return None;
65        }
66        self.fields().partial_cmp(&other.fields())
67    }
68}
69
70impl Hash for StructScalar<'_> {
71    fn hash<H: Hasher>(&self, state: &mut H) {
72        self.dtype.hash(state);
73        self.fields().hash(state);
74    }
75}
76
77impl<'a> StructScalar<'a> {
78    pub(crate) fn try_new(dtype: &'a DType, value: &'a ScalarValue) -> VortexResult<Self> {
79        if !matches!(dtype, DType::Struct(..)) {
80            vortex_bail!("Expected struct scalar, found {}", dtype)
81        }
82        Ok(Self {
83            dtype,
84            fields: value.as_list()?,
85        })
86    }
87
88    /// Returns the data type of this struct scalar.
89    #[inline]
90    pub fn dtype(&self) -> &'a DType {
91        self.dtype
92    }
93
94    /// Returns the struct field definitions.
95    #[inline]
96    pub fn struct_fields(&self) -> &StructFields {
97        self.dtype
98            .as_struct()
99            .vortex_expect("StructScalar always has struct dtype")
100    }
101
102    /// Returns the field names of the struct.
103    pub fn names(&self) -> &FieldNames {
104        self.struct_fields().names()
105    }
106
107    /// Returns true if the struct is null.
108    pub fn is_null(&self) -> bool {
109        self.fields.is_none()
110    }
111
112    /// Returns the field with the given name as a scalar.
113    ///
114    /// Returns None if the field doesn't exist.
115    pub fn field(&self, name: impl AsRef<str>) -> Option<Scalar> {
116        let idx = self.struct_fields().find(name)?;
117        self.field_by_idx(idx)
118    }
119
120    /// Returns the field at the given index as a scalar.
121    ///
122    /// Returns None if the index is out of bounds.
123    ///
124    /// # Panics
125    ///
126    /// Panics if the struct is null.
127    pub fn field_by_idx(&self, idx: usize) -> Option<Scalar> {
128        let fields = self
129            .fields
130            .vortex_expect("Can't take field out of null struct scalar");
131        Some(Scalar {
132            dtype: self.struct_fields().field_by_index(idx)?,
133            value: fields[idx].clone(),
134        })
135    }
136
137    /// Returns the fields of the struct scalar, or None if the scalar is null.
138    pub fn fields(&self) -> Option<Vec<Scalar>> {
139        let fields = self.fields?;
140        Some(
141            (0..fields.len())
142                .map(|index| {
143                    self.field_by_idx(index)
144                        .vortex_expect("never out of bounds")
145                })
146                .collect::<Vec<_>>(),
147        )
148    }
149
150    pub(crate) fn field_values(&self) -> Option<&[ScalarValue]> {
151        self.fields.map(Arc::deref)
152    }
153
154    /// Casts this struct scalar to another struct type.
155    ///
156    /// # Errors
157    ///
158    /// Returns an error if the target type is not a struct or if the number of fields don't match.
159    pub fn cast(&self, dtype: &DType) -> VortexResult<Scalar> {
160        let DType::Struct(st, _) = dtype else {
161            vortex_bail!("Can only cast struct to another struct")
162        };
163        let own_st = self.struct_fields();
164
165        if st.fields().len() != own_st.fields().len() {
166            vortex_bail!(
167                "Cannot cast between structs with different number of fields: {} and {}",
168                own_st.fields().len(),
169                st.fields().len()
170            );
171        }
172
173        if let Some(fs) = self.field_values() {
174            let fields = fs
175                .iter()
176                .enumerate()
177                .map(|(i, f)| {
178                    Scalar {
179                        dtype: own_st
180                            .field_by_index(i)
181                            .vortex_expect("Iterating over scalar fields"),
182                        value: f.clone(),
183                    }
184                    .cast(
185                        &st.field_by_index(i)
186                            .vortex_expect("Iterating over scalar fields"),
187                    )
188                    .map(|s| s.value)
189                })
190                .collect::<VortexResult<Vec<_>>>()?;
191            Ok(Scalar {
192                dtype: dtype.clone(),
193                value: ScalarValue(InnerScalarValue::List(fields.into())),
194            })
195        } else {
196            Ok(Scalar::null(dtype.clone()))
197        }
198    }
199
200    /// Projects this struct scalar to include only the specified fields.
201    ///
202    /// # Errors
203    ///
204    /// Returns an error if the struct cannot be projected or if a field is not found.
205    pub fn project(&self, projection: &[FieldName]) -> VortexResult<Scalar> {
206        let struct_dtype = self
207            .dtype
208            .as_struct()
209            .ok_or_else(|| vortex_err!("Not a struct dtype"))?;
210        let projected_dtype = struct_dtype.project(projection)?;
211        let new_fields = if let Some(fs) = self.field_values() {
212            ScalarValue(InnerScalarValue::List(
213                projection
214                    .iter()
215                    .map(|name| {
216                        struct_dtype
217                            .find(name)
218                            .vortex_expect("DType has been successfully projected already")
219                    })
220                    .map(|i| fs[i].clone())
221                    .collect(),
222            ))
223        } else {
224            ScalarValue(InnerScalarValue::Null)
225        };
226        Ok(Scalar::new(
227            DType::Struct(projected_dtype, self.dtype().nullability()),
228            new_fields,
229        ))
230    }
231}
232
233impl Scalar {
234    /// Creates a new struct scalar with the given fields.
235    pub fn struct_(dtype: DType, children: Vec<Scalar>) -> Self {
236        let DType::Struct(struct_fields, _) = &dtype else {
237            vortex_panic!("Expected struct dtype, found {}", dtype);
238        };
239
240        let field_dtypes = struct_fields.fields();
241        if children.len() != field_dtypes.len() {
242            vortex_panic!(
243                "Struct has {} fields but {} children were provided",
244                field_dtypes.len(),
245                children.len()
246            );
247        }
248
249        for (idx, (child, expected_dtype)) in children.iter().zip(field_dtypes).enumerate() {
250            if child.dtype() != &expected_dtype {
251                vortex_panic!(
252                    "Field {} expected dtype {} but got {}",
253                    idx,
254                    expected_dtype,
255                    child.dtype()
256                );
257            }
258        }
259
260        Self {
261            dtype,
262            value: ScalarValue(InnerScalarValue::List(
263                children
264                    .into_iter()
265                    .map(|x| x.into_value())
266                    .collect_vec()
267                    .into(),
268            )),
269        }
270    }
271}
272
273impl<'a> TryFrom<&'a Scalar> for StructScalar<'a> {
274    type Error = VortexError;
275
276    fn try_from(value: &'a Scalar) -> Result<Self, Self::Error> {
277        Self::try_new(value.dtype(), &value.value)
278    }
279}
280
#[cfg(test)]
mod tests {
    use vortex_dtype::PType::I32;
    use vortex_dtype::{DType, Nullability, StructFields};

    use super::*;

    /// Returns `(dtype of "a", dtype of "b", nullable struct dtype {a: i32, b: utf8})`.
    fn setup_types() -> (DType, DType, DType) {
        let a_dtype = DType::Primitive(I32, Nullability::NonNullable);
        let b_dtype = DType::Utf8(Nullability::NonNullable);

        let struct_dtype = DType::Struct(
            StructFields::new(
                vec!["a".into(), "b".into()].into(),
                vec![a_dtype.clone(), b_dtype.clone()],
            ),
            Nullability::Nullable,
        );

        (a_dtype, b_dtype, struct_dtype)
    }

    #[test]
    #[should_panic]
    fn test_struct_scalar_null() {
        let (_, _, struct_dtype) = setup_types();

        // Reading a field out of a null struct must panic.
        Scalar::null(struct_dtype)
            .as_struct()
            .field_by_idx(0)
            .unwrap();
    }

    #[test]
    fn test_struct_scalar_non_null() {
        let (a_dtype, b_dtype, struct_dtype) = setup_types();

        let a_value = Scalar::primitive::<i32>(1, Nullability::NonNullable);
        let b_value = Scalar::utf8("hello", Nullability::NonNullable);

        // Equality ignores nullability, so these match the non-nullable fields.
        let a_nullable = Scalar::primitive::<i32>(1, Nullability::Nullable);
        let b_nullable = Scalar::utf8("hello", Nullability::Nullable);

        let scalar = Scalar::struct_(struct_dtype, vec![a_value, b_value]);

        let first = scalar
            .as_struct()
            .field_by_idx(0)
            .expect("field 0 exists");
        assert_eq!(first, a_nullable);
        assert_eq!(first.dtype(), &a_dtype);

        let second = scalar
            .as_struct()
            .field_by_idx(1)
            .expect("field 1 exists");
        assert_eq!(second, b_nullable);
        assert_eq!(second.dtype(), &b_dtype);
    }

    #[test]
    #[should_panic(expected = "Expected struct dtype")]
    fn test_struct_scalar_wrong_dtype() {
        let not_a_struct = DType::Primitive(I32, Nullability::NonNullable);
        let child = Scalar::primitive::<i32>(1, Nullability::NonNullable);

        Scalar::struct_(not_a_struct, vec![child]);
    }

    #[test]
    #[should_panic(expected = "Struct has 2 fields but 1 children were provided")]
    fn test_struct_scalar_wrong_child_count() {
        let (_, _, struct_dtype) = setup_types();
        let only_child = Scalar::primitive::<i32>(1, Nullability::NonNullable);

        Scalar::struct_(struct_dtype, vec![only_child]);
    }

    #[test]
    #[should_panic(expected = "Field 0 expected dtype i32 but got utf8")]
    fn test_struct_scalar_wrong_child_dtype() {
        let (_, _, struct_dtype) = setup_types();
        let mismatched = Scalar::utf8("wrong", Nullability::NonNullable);
        let matching = Scalar::utf8("hello", Nullability::NonNullable);

        Scalar::struct_(struct_dtype, vec![mismatched, matching]);
    }
}