vortex_scalar/
struct_.rs

1use std::fmt::{Display, Formatter};
2use std::hash::{Hash, Hasher};
3use std::ops::Deref;
4use std::sync::Arc;
5
6use itertools::Itertools;
7use vortex_dtype::{DType, FieldName, StructDType};
8use vortex_error::{
9    VortexError, VortexExpect, VortexResult, vortex_bail, vortex_err, vortex_panic,
10};
11
12use crate::{InnerScalarValue, Scalar, ScalarValue};
13
/// A borrowed view over a scalar whose dtype is a struct.
///
/// Instances are built by `try_new` (and the `TryFrom<&Scalar>` impl that forwards to it),
/// which rejects every dtype other than `DType::Struct`.
pub struct StructScalar<'a> {
    /// Dtype of the viewed scalar; `try_new` only accepts `DType::Struct`.
    dtype: &'a DType,
    /// Field values as returned by `ScalarValue::as_list`; `None` is reported by `is_null`
    /// as a null struct.
    fields: Option<&'a Arc<[ScalarValue]>>,
}
18
19impl Display for StructScalar<'_> {
20    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
21        match &self.fields {
22            None => write!(f, "null"),
23            Some(fields) => {
24                write!(f, "{{")?;
25                let formatted_fields = self
26                    .struct_dtype()
27                    .names()
28                    .iter()
29                    .zip_eq(self.struct_dtype().fields())
30                    .zip_eq(fields.iter())
31                    .map(|((name, dtype), value)| {
32                        let val = Scalar::new(dtype, value.clone());
33                        format!("{name}: {val}")
34                    })
35                    .format(", ");
36                write!(f, "{}", formatted_fields)?;
37                write!(f, "}}")
38            }
39        }
40    }
41}
42
43impl PartialEq for StructScalar<'_> {
44    fn eq(&self, other: &Self) -> bool {
45        if !self.dtype.eq_ignore_nullability(other.dtype) {
46            return false;
47        }
48        self.fields() == other.fields()
49    }
50}
51
52impl Eq for StructScalar<'_> {}
53
54/// Ord is not implemented since it's undefined for different field DTypes
55impl PartialOrd for StructScalar<'_> {
56    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
57        if !self.dtype.eq_ignore_nullability(other.dtype) {
58            return None;
59        }
60        self.fields().partial_cmp(&other.fields())
61    }
62}
63
64impl Hash for StructScalar<'_> {
65    fn hash<H: Hasher>(&self, state: &mut H) {
66        self.dtype.hash(state);
67        self.fields().hash(state);
68    }
69}
70
71impl<'a> StructScalar<'a> {
72    pub(crate) fn try_new(dtype: &'a DType, value: &'a ScalarValue) -> VortexResult<Self> {
73        if !matches!(dtype, DType::Struct(..)) {
74            vortex_bail!("Expected struct scalar, found {}", dtype)
75        }
76        Ok(Self {
77            dtype,
78            fields: value.as_list()?,
79        })
80    }
81
    /// Returns the dtype this struct scalar was created with.
    #[inline]
    pub fn dtype(&self) -> &'a DType {
        self.dtype
    }
86
87    #[inline]
88    pub fn struct_dtype(&self) -> &Arc<StructDType> {
89        let DType::Struct(sdtype, ..) = self.dtype else {
90            vortex_panic!("StructScalar always has struct dtype");
91        };
92        sdtype
93    }
94
    /// Returns `true` when the struct scalar holds no field values, i.e. the struct itself
    /// is null.
    pub fn is_null(&self) -> bool {
        self.fields.is_none()
    }
98
99    pub fn field(&self, name: impl AsRef<str>) -> VortexResult<Scalar> {
100        let DType::Struct(st, _) = self.dtype() else {
101            unreachable!()
102        };
103        let idx = st.find(name)?;
104        self.field_by_idx(idx)
105    }
106
107    pub fn field_by_idx(&self, idx: usize) -> VortexResult<Scalar> {
108        let fields = self
109            .fields
110            .vortex_expect("Can't take field out of null struct scalar");
111        let DType::Struct(st, _) = self.dtype() else {
112            unreachable!()
113        };
114
115        Ok(Scalar {
116            dtype: st.field_by_index(idx)?,
117            value: fields[idx].clone(),
118        })
119    }
120
121    /// Returns the fields of the struct scalar, or None if the scalar is null.
122    pub fn fields(&self) -> Option<Vec<Scalar>> {
123        let fields = self.fields?;
124        Some(
125            (0..fields.len())
126                .map(|index| {
127                    self.field_by_idx(index)
128                        .vortex_expect("never out of bounds")
129                })
130                .collect::<Vec<_>>(),
131        )
132    }
133
134    pub(crate) fn field_values(&self) -> Option<&[ScalarValue]> {
135        self.fields.map(Arc::deref)
136    }
137
138    pub fn cast(&self, dtype: &DType) -> VortexResult<Scalar> {
139        let DType::Struct(st, _) = dtype else {
140            vortex_bail!("Can only cast struct to another struct")
141        };
142        let DType::Struct(own_st, _) = self.dtype() else {
143            unreachable!()
144        };
145
146        if st.fields().len() != own_st.fields().len() {
147            vortex_bail!(
148                "Cannot cast between structs with different number of fields: {} and {}",
149                own_st.fields().len(),
150                st.fields().len()
151            );
152        }
153
154        if let Some(fs) = self.field_values() {
155            let fields = fs
156                .iter()
157                .enumerate()
158                .map(|(i, f)| {
159                    Scalar {
160                        dtype: own_st.field_by_index(i)?,
161                        value: f.clone(),
162                    }
163                    .cast(&st.field_by_index(i)?)
164                    .map(|s| s.value)
165                })
166                .collect::<VortexResult<Vec<_>>>()?;
167            Ok(Scalar {
168                dtype: dtype.clone(),
169                value: ScalarValue(InnerScalarValue::List(fields.into())),
170            })
171        } else {
172            Ok(Scalar::null(dtype.clone()))
173        }
174    }
175
176    pub fn project(&self, projection: &[FieldName]) -> VortexResult<Scalar> {
177        let struct_dtype = self
178            .dtype
179            .as_struct()
180            .ok_or_else(|| vortex_err!("Not a struct dtype"))?;
181        let projected_dtype = struct_dtype.project(projection)?;
182        let new_fields = if let Some(fs) = self.field_values() {
183            ScalarValue(InnerScalarValue::List(
184                projection
185                    .iter()
186                    .map(|name| {
187                        struct_dtype
188                            .find(name)
189                            .vortex_expect("DType has been successfully projected already")
190                    })
191                    .map(|i| fs[i].clone())
192                    .collect(),
193            ))
194        } else {
195            ScalarValue(InnerScalarValue::Null)
196        };
197        Ok(Scalar::new(
198            DType::Struct(Arc::new(projected_dtype), self.dtype().nullability()),
199            new_fields,
200        ))
201    }
202}
203
204impl Scalar {
205    pub fn struct_(dtype: DType, children: Vec<Scalar>) -> Self {
206        Self {
207            dtype,
208            value: ScalarValue(InnerScalarValue::List(
209                children
210                    .into_iter()
211                    .map(|x| x.into_value())
212                    .collect_vec()
213                    .into(),
214            )),
215        }
216    }
217}
218
219impl<'a> TryFrom<&'a Scalar> for StructScalar<'a> {
220    type Error = VortexError;
221
222    fn try_from(value: &'a Scalar) -> Result<Self, Self::Error> {
223        Self::try_new(value.dtype(), &value.value)
224    }
225}
226
#[cfg(test)]
mod tests {
    use vortex_dtype::PType::I32;
    use vortex_dtype::{DType, Nullability, StructDType};

    use super::*;

    /// Returns `(i32 field dtype, utf8 field dtype, nullable struct dtype {a, b})`.
    fn setup_types() -> (DType, DType, DType) {
        let int_dtype = DType::Primitive(I32, Nullability::NonNullable);
        let utf8_dtype = DType::Utf8(Nullability::NonNullable);

        let struct_dtype = DType::Struct(
            Arc::new(StructDType::new(
                vec!["a".into(), "b".into()].into(),
                vec![int_dtype.clone(), utf8_dtype.clone()],
            )),
            Nullability::Nullable,
        );

        (int_dtype, utf8_dtype, struct_dtype)
    }

    /// Taking a field out of a null struct scalar must panic.
    #[test]
    #[should_panic]
    fn test_struct_scalar_null() {
        let (_, _, struct_dtype) = setup_types();
        let null_struct = Scalar::null(struct_dtype);

        null_struct.as_struct().field_by_idx(0).unwrap();
    }

    /// Each field of a non-null struct comes back with its declared dtype and a value
    /// equal to the nullable counterpart of the child it was built from.
    #[test]
    fn test_struct_scalar_non_null() {
        let (int_dtype, utf8_dtype, struct_dtype) = setup_types();

        let children = vec![
            Scalar::primitive::<i32>(1, Nullability::NonNullable),
            Scalar::utf8("hello", Nullability::NonNullable),
        ];
        let expected = [
            (
                Scalar::primitive::<i32>(1, Nullability::Nullable),
                int_dtype,
            ),
            (Scalar::utf8("hello", Nullability::Nullable), utf8_dtype),
        ];

        let scalar = Scalar::struct_(struct_dtype, children);

        for (idx, (expected_value, expected_dtype)) in expected.iter().enumerate() {
            let field = scalar.as_struct().field_by_idx(idx);
            assert!(field.is_ok());
            let field = field.unwrap();
            assert_eq!(&field, expected_value);
            assert_eq!(field.dtype(), expected_dtype);
        }
    }
}