// vortex_scalar/struct_.rs

1use std::hash::{Hash, Hasher};
2use std::ops::Deref;
3use std::sync::Arc;
4
5use itertools::Itertools;
6use vortex_dtype::{DType, FieldName, StructDType};
7use vortex_error::{
8    VortexError, VortexExpect, VortexResult, vortex_bail, vortex_err, vortex_panic,
9};
10
11use crate::{InnerScalarValue, Scalar, ScalarValue};
12
13pub struct StructScalar<'a> {
14    dtype: &'a DType,
15    fields: Option<&'a Arc<[ScalarValue]>>,
16}
17
18impl PartialEq for StructScalar<'_> {
19    fn eq(&self, other: &Self) -> bool {
20        if !self.dtype.eq_ignore_nullability(other.dtype) {
21            return false;
22        }
23        self.fields() == other.fields()
24    }
25}
26
27impl Eq for StructScalar<'_> {}
28
29/// Ord is not implemented since it's undefined for different field DTypes
30impl PartialOrd for StructScalar<'_> {
31    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
32        if !self.dtype.eq_ignore_nullability(other.dtype) {
33            return None;
34        }
35        self.fields().partial_cmp(&other.fields())
36    }
37}
38
39impl Hash for StructScalar<'_> {
40    fn hash<H: Hasher>(&self, state: &mut H) {
41        self.dtype.hash(state);
42        self.fields().hash(state);
43    }
44}
45
46impl<'a> StructScalar<'a> {
47    pub(crate) fn try_new(dtype: &'a DType, value: &'a ScalarValue) -> VortexResult<Self> {
48        if !matches!(dtype, DType::Struct(..)) {
49            vortex_bail!("Expected struct scalar, found {}", dtype)
50        }
51        Ok(Self {
52            dtype,
53            fields: value.as_list()?,
54        })
55    }
56
57    #[inline]
58    pub fn dtype(&self) -> &'a DType {
59        self.dtype
60    }
61
62    #[inline]
63    pub fn struct_dtype(&self) -> &Arc<StructDType> {
64        let DType::Struct(sdtype, ..) = self.dtype else {
65            vortex_panic!("StructScalar always has struct dtype");
66        };
67        sdtype
68    }
69
70    pub fn is_null(&self) -> bool {
71        self.fields.is_none()
72    }
73
74    pub fn field(&self, name: impl AsRef<str>) -> VortexResult<Scalar> {
75        let DType::Struct(st, _) = self.dtype() else {
76            unreachable!()
77        };
78        let idx = st.find(name)?;
79        self.field_by_idx(idx)
80    }
81
82    pub fn field_by_idx(&self, idx: usize) -> VortexResult<Scalar> {
83        let fields = self
84            .fields
85            .vortex_expect("Can't take field out of null struct scalar");
86        let DType::Struct(st, _) = self.dtype() else {
87            unreachable!()
88        };
89
90        Ok(Scalar {
91            dtype: st.field_by_index(idx)?,
92            value: fields[idx].clone(),
93        })
94    }
95
96    /// Returns the fields of the struct scalar, or None if the scalar is null.
97    pub fn fields(&self) -> Option<Vec<Scalar>> {
98        let fields = self.fields?;
99        Some(
100            (0..fields.len())
101                .map(|index| {
102                    self.field_by_idx(index)
103                        .vortex_expect("never out of bounds")
104                })
105                .collect::<Vec<_>>(),
106        )
107    }
108
109    pub(crate) fn field_values(&self) -> Option<&[ScalarValue]> {
110        self.fields.map(Arc::deref)
111    }
112
113    pub fn cast(&self, dtype: &DType) -> VortexResult<Scalar> {
114        let DType::Struct(st, _) = dtype else {
115            vortex_bail!("Can only cast struct to another struct")
116        };
117        let DType::Struct(own_st, _) = self.dtype() else {
118            unreachable!()
119        };
120
121        if st.fields().len() != own_st.fields().len() {
122            vortex_bail!(
123                "Cannot cast between structs with different number of fields: {} and {}",
124                own_st.fields().len(),
125                st.fields().len()
126            );
127        }
128
129        if let Some(fs) = self.field_values() {
130            let fields = fs
131                .iter()
132                .enumerate()
133                .map(|(i, f)| {
134                    Scalar {
135                        dtype: own_st.field_by_index(i)?,
136                        value: f.clone(),
137                    }
138                    .cast(&st.field_by_index(i)?)
139                    .map(|s| s.value)
140                })
141                .collect::<VortexResult<Vec<_>>>()?;
142            Ok(Scalar {
143                dtype: dtype.clone(),
144                value: ScalarValue(InnerScalarValue::List(fields.into())),
145            })
146        } else {
147            Ok(Scalar::null(dtype.clone()))
148        }
149    }
150
151    pub fn project(&self, projection: &[FieldName]) -> VortexResult<Scalar> {
152        let struct_dtype = self
153            .dtype
154            .as_struct()
155            .ok_or_else(|| vortex_err!("Not a struct dtype"))?;
156        let projected_dtype = struct_dtype.project(projection)?;
157        let new_fields = if let Some(fs) = self.field_values() {
158            ScalarValue(InnerScalarValue::List(
159                projection
160                    .iter()
161                    .map(|name| {
162                        struct_dtype
163                            .find(name)
164                            .vortex_expect("DType has been successfully projected already")
165                    })
166                    .map(|i| fs[i].clone())
167                    .collect(),
168            ))
169        } else {
170            ScalarValue(InnerScalarValue::Null)
171        };
172        Ok(Scalar::new(
173            DType::Struct(Arc::new(projected_dtype), self.dtype().nullability()),
174            new_fields,
175        ))
176    }
177}
178
179impl Scalar {
180    pub fn struct_(dtype: DType, children: Vec<Scalar>) -> Self {
181        Self {
182            dtype,
183            value: ScalarValue(InnerScalarValue::List(
184                children
185                    .into_iter()
186                    .map(|x| x.into_value())
187                    .collect_vec()
188                    .into(),
189            )),
190        }
191    }
192}
193
194impl<'a> TryFrom<&'a Scalar> for StructScalar<'a> {
195    type Error = VortexError;
196
197    fn try_from(value: &'a Scalar) -> Result<Self, Self::Error> {
198        Self::try_new(value.dtype(), &value.value)
199    }
200}
201
#[cfg(test)]
mod tests {
    use vortex_dtype::PType::I32;
    use vortex_dtype::{DType, Nullability, StructDType};

    use super::*;

    /// Returns `(i32 dtype, utf8 dtype, nullable struct dtype {a: i32, b: utf8})`.
    fn setup_types() -> (DType, DType, DType) {
        let int_dtype = DType::Primitive(I32, Nullability::NonNullable);
        let str_dtype = DType::Utf8(Nullability::NonNullable);

        let struct_dtype = StructDType::new(
            vec!["a".into(), "b".into()].into(),
            vec![int_dtype.clone(), str_dtype.clone()],
        );
        let dtype = DType::Struct(Arc::new(struct_dtype), Nullability::Nullable);

        (int_dtype, str_dtype, dtype)
    }

    /// Taking a field out of a null struct scalar and unwrapping it must panic.
    #[test]
    #[should_panic]
    fn test_struct_scalar_null() {
        let (_, _, dtype) = setup_types();
        let null_struct = Scalar::null(dtype);

        null_struct.as_struct().field_by_idx(0).unwrap();
    }

    /// Each field of a non-null struct scalar comes back with the field's dtype
    /// and compares equal to the nullable form of the value it was built from.
    #[test]
    fn test_struct_scalar_non_null() {
        let (f0_dt, f1_dt, dtype) = setup_types();

        let children = vec![
            Scalar::primitive::<i32>(1, Nullability::NonNullable),
            Scalar::utf8("hello", Nullability::NonNullable),
        ];
        let scalar = Scalar::struct_(dtype, children);

        let expected = [
            (Scalar::primitive::<i32>(1, Nullability::Nullable), f0_dt),
            (Scalar::utf8("hello", Nullability::Nullable), f1_dt),
        ];

        for (idx, (want, want_dtype)) in expected.iter().enumerate() {
            let got = scalar.as_struct().field_by_idx(idx);
            assert!(got.is_ok());
            let got = got.unwrap();
            assert_eq!(&got, want);
            assert_eq!(got.dtype(), want_dtype);
        }
    }
}