vortex_scalar/
struct_.rs

1use std::fmt::{Display, Formatter};
2use std::hash::{Hash, Hasher};
3use std::ops::Deref;
4use std::sync::Arc;
5
6use itertools::Itertools;
7use vortex_dtype::{DType, FieldName, FieldNames, StructFields};
8use vortex_error::{VortexError, VortexExpect, VortexResult, vortex_bail, vortex_err};
9
10use crate::{InnerScalarValue, Scalar, ScalarValue};
11
/// A borrowed, typed view of a struct-valued [`Scalar`].
///
/// Pairs the scalar's struct [`DType`] with its per-field values. The values are stored in the
/// same order as the dtype's fields, and are absent when the struct scalar itself is null.
pub struct StructScalar<'a> {
    /// The scalar's dtype; `try_new` only accepts [`DType::Struct`].
    dtype: &'a DType,
    /// The field values in field order, or `None` when the struct scalar is null.
    fields: Option<&'a Arc<[ScalarValue]>>,
}
16
17impl Display for StructScalar<'_> {
18    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
19        match &self.fields {
20            None => write!(f, "null"),
21            Some(fields) => {
22                write!(f, "{{")?;
23                let formatted_fields = self
24                    .names()
25                    .iter()
26                    .zip_eq(self.struct_fields().fields())
27                    .zip_eq(fields.iter())
28                    .map(|((name, dtype), value)| {
29                        let val = Scalar::new(dtype, value.clone());
30                        format!("{name}: {val}")
31                    })
32                    .format(", ");
33                write!(f, "{formatted_fields}")?;
34                write!(f, "}}")
35            }
36        }
37    }
38}
39
40impl PartialEq for StructScalar<'_> {
41    fn eq(&self, other: &Self) -> bool {
42        if !self.dtype.eq_ignore_nullability(other.dtype) {
43            return false;
44        }
45        self.fields() == other.fields()
46    }
47}
48
49impl Eq for StructScalar<'_> {}
50
51/// Ord is not implemented since it's undefined for different field DTypes
52impl PartialOrd for StructScalar<'_> {
53    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
54        if !self.dtype.eq_ignore_nullability(other.dtype) {
55            return None;
56        }
57        self.fields().partial_cmp(&other.fields())
58    }
59}
60
61impl Hash for StructScalar<'_> {
62    fn hash<H: Hasher>(&self, state: &mut H) {
63        self.dtype.hash(state);
64        self.fields().hash(state);
65    }
66}
67
68impl<'a> StructScalar<'a> {
69    pub(crate) fn try_new(dtype: &'a DType, value: &'a ScalarValue) -> VortexResult<Self> {
70        if !matches!(dtype, DType::Struct(..)) {
71            vortex_bail!("Expected struct scalar, found {}", dtype)
72        }
73        Ok(Self {
74            dtype,
75            fields: value.as_list()?,
76        })
77    }
78
79    #[inline]
80    pub fn dtype(&self) -> &'a DType {
81        self.dtype
82    }
83
84    #[inline]
85    pub fn struct_fields(&self) -> &StructFields {
86        self.dtype
87            .as_struct()
88            .vortex_expect("StructScalar always has struct dtype")
89    }
90
91    pub fn names(&self) -> &FieldNames {
92        self.struct_fields().names()
93    }
94
95    pub fn is_null(&self) -> bool {
96        self.fields.is_none()
97    }
98
99    pub fn field(&self, name: impl AsRef<str>) -> Option<Scalar> {
100        let idx = self.struct_fields().find(name)?;
101        self.field_by_idx(idx)
102    }
103
104    pub fn field_by_idx(&self, idx: usize) -> Option<Scalar> {
105        let fields = self
106            .fields
107            .vortex_expect("Can't take field out of null struct scalar");
108        Some(Scalar {
109            dtype: self.struct_fields().field_by_index(idx)?,
110            value: fields[idx].clone(),
111        })
112    }
113
114    /// Returns the fields of the struct scalar, or None if the scalar is null.
115    pub fn fields(&self) -> Option<Vec<Scalar>> {
116        let fields = self.fields?;
117        Some(
118            (0..fields.len())
119                .map(|index| {
120                    self.field_by_idx(index)
121                        .vortex_expect("never out of bounds")
122                })
123                .collect::<Vec<_>>(),
124        )
125    }
126
127    pub(crate) fn field_values(&self) -> Option<&[ScalarValue]> {
128        self.fields.map(Arc::deref)
129    }
130
131    pub fn cast(&self, dtype: &DType) -> VortexResult<Scalar> {
132        let DType::Struct(st, _) = dtype else {
133            vortex_bail!("Can only cast struct to another struct")
134        };
135        let own_st = self.struct_fields();
136
137        if st.fields().len() != own_st.fields().len() {
138            vortex_bail!(
139                "Cannot cast between structs with different number of fields: {} and {}",
140                own_st.fields().len(),
141                st.fields().len()
142            );
143        }
144
145        if let Some(fs) = self.field_values() {
146            let fields = fs
147                .iter()
148                .enumerate()
149                .map(|(i, f)| {
150                    Scalar {
151                        dtype: own_st
152                            .field_by_index(i)
153                            .vortex_expect("Iterating over scalar fields"),
154                        value: f.clone(),
155                    }
156                    .cast(
157                        &st.field_by_index(i)
158                            .vortex_expect("Iterating over scalar fields"),
159                    )
160                    .map(|s| s.value)
161                })
162                .collect::<VortexResult<Vec<_>>>()?;
163            Ok(Scalar {
164                dtype: dtype.clone(),
165                value: ScalarValue(InnerScalarValue::List(fields.into())),
166            })
167        } else {
168            Ok(Scalar::null(dtype.clone()))
169        }
170    }
171
172    pub fn project(&self, projection: &[FieldName]) -> VortexResult<Scalar> {
173        let struct_dtype = self
174            .dtype
175            .as_struct()
176            .ok_or_else(|| vortex_err!("Not a struct dtype"))?;
177        let projected_dtype = struct_dtype.project(projection)?;
178        let new_fields = if let Some(fs) = self.field_values() {
179            ScalarValue(InnerScalarValue::List(
180                projection
181                    .iter()
182                    .map(|name| {
183                        struct_dtype
184                            .find(name)
185                            .vortex_expect("DType has been successfully projected already")
186                    })
187                    .map(|i| fs[i].clone())
188                    .collect(),
189            ))
190        } else {
191            ScalarValue(InnerScalarValue::Null)
192        };
193        Ok(Scalar::new(
194            DType::Struct(projected_dtype, self.dtype().nullability()),
195            new_fields,
196        ))
197    }
198}
199
200impl Scalar {
201    pub fn struct_(dtype: DType, children: Vec<Scalar>) -> Self {
202        Self {
203            dtype,
204            value: ScalarValue(InnerScalarValue::List(
205                children
206                    .into_iter()
207                    .map(|x| x.into_value())
208                    .collect_vec()
209                    .into(),
210            )),
211        }
212    }
213}
214
215impl<'a> TryFrom<&'a Scalar> for StructScalar<'a> {
216    type Error = VortexError;
217
218    fn try_from(value: &'a Scalar) -> Result<Self, Self::Error> {
219        Self::try_new(value.dtype(), &value.value)
220    }
221}
222
#[cfg(test)]
mod tests {
    use vortex_dtype::PType::I32;
    use vortex_dtype::{DType, Nullability, StructFields};

    use super::*;

    /// Builds a nullable `{a: i32, b: utf8}` struct dtype whose two fields are non-nullable.
    ///
    /// Returns `(dtype of a, dtype of b, struct dtype)`.
    fn setup_types() -> (DType, DType, DType) {
        let a_dtype = DType::Primitive(I32, Nullability::NonNullable);
        let b_dtype = DType::Utf8(Nullability::NonNullable);

        let struct_dtype = DType::Struct(
            StructFields::new(
                vec!["a".into(), "b".into()].into(),
                vec![a_dtype.clone(), b_dtype.clone()],
            ),
            Nullability::Nullable,
        );

        (a_dtype, b_dtype, struct_dtype)
    }

    #[test]
    #[should_panic]
    fn test_struct_scalar_null() {
        let (_, _, struct_dtype) = setup_types();

        // Reading a field out of a null struct scalar is a programming error and must panic.
        let null_struct = Scalar::null(struct_dtype);
        null_struct.as_struct().field_by_idx(0).unwrap();
    }

    #[test]
    fn test_struct_scalar_non_null() {
        let (a_dtype, b_dtype, struct_dtype) = setup_types();

        let scalar = Scalar::struct_(
            struct_dtype,
            vec![
                Scalar::primitive::<i32>(1, Nullability::NonNullable),
                Scalar::utf8("hello", Nullability::NonNullable),
            ],
        );
        let view = scalar.as_struct();

        // Field scalars take their dtype from the struct's field definitions. They compare
        // equal to same-valued scalars that differ only in nullability.
        let a = view.field_by_idx(0).expect("field 0 exists");
        assert_eq!(a, Scalar::primitive::<i32>(1, Nullability::Nullable));
        assert_eq!(a.dtype(), &a_dtype);

        let b = view.field_by_idx(1).expect("field 1 exists");
        assert_eq!(b, Scalar::utf8("hello", Nullability::Nullable));
        assert_eq!(b.dtype(), &b_dtype);
    }
}