vortex_scalar/
struct_.rs

1use std::fmt::{Display, Formatter};
2use std::hash::{Hash, Hasher};
3use std::ops::Deref;
4use std::sync::Arc;
5
6use itertools::Itertools;
7use vortex_dtype::{DType, FieldName, FieldNames, StructFields};
8use vortex_error::{
9    VortexError, VortexExpect, VortexResult, vortex_bail, vortex_err, vortex_panic,
10};
11
12use crate::{InnerScalarValue, Scalar, ScalarValue};
13
14pub struct StructScalar<'a> {
15    dtype: &'a DType,
16    fields: Option<&'a Arc<[ScalarValue]>>,
17}
18
19impl Display for StructScalar<'_> {
20    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
21        match &self.fields {
22            None => write!(f, "null"),
23            Some(fields) => {
24                write!(f, "{{")?;
25                let formatted_fields = self
26                    .names()
27                    .iter()
28                    .zip_eq(self.struct_fields().fields())
29                    .zip_eq(fields.iter())
30                    .map(|((name, dtype), value)| {
31                        let val = Scalar::new(dtype, value.clone());
32                        format!("{name}: {val}")
33                    })
34                    .format(", ");
35                write!(f, "{formatted_fields}")?;
36                write!(f, "}}")
37            }
38        }
39    }
40}
41
42impl PartialEq for StructScalar<'_> {
43    fn eq(&self, other: &Self) -> bool {
44        if !self.dtype.eq_ignore_nullability(other.dtype) {
45            return false;
46        }
47        self.fields() == other.fields()
48    }
49}
50
51impl Eq for StructScalar<'_> {}
52
53/// Ord is not implemented since it's undefined for different field DTypes
54impl PartialOrd for StructScalar<'_> {
55    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
56        if !self.dtype.eq_ignore_nullability(other.dtype) {
57            return None;
58        }
59        self.fields().partial_cmp(&other.fields())
60    }
61}
62
63impl Hash for StructScalar<'_> {
64    fn hash<H: Hasher>(&self, state: &mut H) {
65        self.dtype.hash(state);
66        self.fields().hash(state);
67    }
68}
69
70impl<'a> StructScalar<'a> {
71    pub(crate) fn try_new(dtype: &'a DType, value: &'a ScalarValue) -> VortexResult<Self> {
72        if !matches!(dtype, DType::Struct(..)) {
73            vortex_bail!("Expected struct scalar, found {}", dtype)
74        }
75        Ok(Self {
76            dtype,
77            fields: value.as_list()?,
78        })
79    }
80
81    #[inline]
82    pub fn dtype(&self) -> &'a DType {
83        self.dtype
84    }
85
86    #[inline]
87    pub fn struct_fields(&self) -> &Arc<StructFields> {
88        let DType::Struct(sdtype, ..) = self.dtype else {
89            vortex_panic!("StructScalar always has struct dtype");
90        };
91        sdtype
92    }
93
94    pub fn names(&self) -> &FieldNames {
95        self.struct_fields().names()
96    }
97
98    pub fn is_null(&self) -> bool {
99        self.fields.is_none()
100    }
101
102    pub fn field(&self, name: impl AsRef<str>) -> VortexResult<Scalar> {
103        let DType::Struct(st, _) = self.dtype() else {
104            unreachable!()
105        };
106        let idx = st.find(name)?;
107        self.field_by_idx(idx)
108    }
109
110    pub fn field_by_idx(&self, idx: usize) -> VortexResult<Scalar> {
111        let fields = self
112            .fields
113            .vortex_expect("Can't take field out of null struct scalar");
114        let DType::Struct(st, _) = self.dtype() else {
115            unreachable!()
116        };
117
118        Ok(Scalar {
119            dtype: st.field_by_index(idx)?,
120            value: fields[idx].clone(),
121        })
122    }
123
124    /// Returns the fields of the struct scalar, or None if the scalar is null.
125    pub fn fields(&self) -> Option<Vec<Scalar>> {
126        let fields = self.fields?;
127        Some(
128            (0..fields.len())
129                .map(|index| {
130                    self.field_by_idx(index)
131                        .vortex_expect("never out of bounds")
132                })
133                .collect::<Vec<_>>(),
134        )
135    }
136
137    pub(crate) fn field_values(&self) -> Option<&[ScalarValue]> {
138        self.fields.map(Arc::deref)
139    }
140
141    pub fn cast(&self, dtype: &DType) -> VortexResult<Scalar> {
142        let DType::Struct(st, _) = dtype else {
143            vortex_bail!("Can only cast struct to another struct")
144        };
145        let DType::Struct(own_st, _) = self.dtype() else {
146            unreachable!()
147        };
148
149        if st.fields().len() != own_st.fields().len() {
150            vortex_bail!(
151                "Cannot cast between structs with different number of fields: {} and {}",
152                own_st.fields().len(),
153                st.fields().len()
154            );
155        }
156
157        if let Some(fs) = self.field_values() {
158            let fields = fs
159                .iter()
160                .enumerate()
161                .map(|(i, f)| {
162                    Scalar {
163                        dtype: own_st.field_by_index(i)?,
164                        value: f.clone(),
165                    }
166                    .cast(&st.field_by_index(i)?)
167                    .map(|s| s.value)
168                })
169                .collect::<VortexResult<Vec<_>>>()?;
170            Ok(Scalar {
171                dtype: dtype.clone(),
172                value: ScalarValue(InnerScalarValue::List(fields.into())),
173            })
174        } else {
175            Ok(Scalar::null(dtype.clone()))
176        }
177    }
178
179    pub fn project(&self, projection: &[FieldName]) -> VortexResult<Scalar> {
180        let struct_dtype = self
181            .dtype
182            .as_struct()
183            .ok_or_else(|| vortex_err!("Not a struct dtype"))?;
184        let projected_dtype = struct_dtype.project(projection)?;
185        let new_fields = if let Some(fs) = self.field_values() {
186            ScalarValue(InnerScalarValue::List(
187                projection
188                    .iter()
189                    .map(|name| {
190                        struct_dtype
191                            .find(name)
192                            .vortex_expect("DType has been successfully projected already")
193                    })
194                    .map(|i| fs[i].clone())
195                    .collect(),
196            ))
197        } else {
198            ScalarValue(InnerScalarValue::Null)
199        };
200        Ok(Scalar::new(
201            DType::Struct(Arc::new(projected_dtype), self.dtype().nullability()),
202            new_fields,
203        ))
204    }
205}
206
207impl Scalar {
208    pub fn struct_(dtype: DType, children: Vec<Scalar>) -> Self {
209        Self {
210            dtype,
211            value: ScalarValue(InnerScalarValue::List(
212                children
213                    .into_iter()
214                    .map(|x| x.into_value())
215                    .collect_vec()
216                    .into(),
217            )),
218        }
219    }
220}
221
222impl<'a> TryFrom<&'a Scalar> for StructScalar<'a> {
223    type Error = VortexError;
224
225    fn try_from(value: &'a Scalar) -> Result<Self, Self::Error> {
226        Self::try_new(value.dtype(), &value.value)
227    }
228}
229
#[cfg(test)]
mod tests {
    use vortex_dtype::PType::I32;
    use vortex_dtype::{DType, Nullability, StructFields};

    use super::*;

    /// Returns the non-nullable field dtypes `i32` and `utf8`, plus a nullable struct dtype
    /// `{a: i32, b: utf8}` built from them.
    fn setup_types() -> (DType, DType, DType) {
        let int_dtype = DType::Primitive(I32, Nullability::NonNullable);
        let str_dtype = DType::Utf8(Nullability::NonNullable);

        let fields = StructFields::new(
            vec!["a".into(), "b".into()].into(),
            vec![int_dtype.clone(), str_dtype.clone()],
        );
        let struct_dtype = DType::Struct(Arc::new(fields), Nullability::Nullable);

        (int_dtype, str_dtype, struct_dtype)
    }

    #[test]
    #[should_panic]
    fn test_struct_scalar_null() {
        let (_, _, struct_dtype) = setup_types();

        // Pulling a field out of a null struct must not succeed.
        Scalar::null(struct_dtype)
            .as_struct()
            .field_by_idx(0)
            .unwrap();
    }

    #[test]
    fn test_struct_scalar_non_null() {
        let (int_dtype, str_dtype, struct_dtype) = setup_types();

        let scalar = Scalar::struct_(
            struct_dtype,
            vec![
                Scalar::primitive::<i32>(1, Nullability::NonNullable),
                Scalar::utf8("hello", Nullability::NonNullable),
            ],
        );

        // Each extracted field must equal its nullable counterpart value and carry the
        // struct's declared (non-nullable) field dtype.
        let expected = [
            (Scalar::primitive::<i32>(1, Nullability::Nullable), int_dtype),
            (Scalar::utf8("hello", Nullability::Nullable), str_dtype),
        ];
        for (idx, (expected_value, expected_dtype)) in expected.into_iter().enumerate() {
            let field = scalar.as_struct().field_by_idx(idx);
            assert!(field.is_ok());
            let field = field.unwrap();
            assert_eq!(field, expected_value);
            assert_eq!(field.dtype(), &expected_dtype);
        }
    }
}