1use std::hash::{Hash, Hasher};
2use std::ops::Deref;
3use std::sync::Arc;
4
5use itertools::Itertools;
6use vortex_dtype::{DType, FieldName, StructDType};
7use vortex_error::{
8 VortexError, VortexExpect, VortexResult, vortex_bail, vortex_err, vortex_panic,
9};
10
11use crate::{InnerScalarValue, Scalar, ScalarValue};
12
13pub struct StructScalar<'a> {
14 dtype: &'a DType,
15 fields: Option<&'a Arc<[ScalarValue]>>,
16}
17
18impl PartialEq for StructScalar<'_> {
19 fn eq(&self, other: &Self) -> bool {
20 if !self.dtype.eq_ignore_nullability(other.dtype) {
21 return false;
22 }
23 self.fields() == other.fields()
24 }
25}
26
27impl Eq for StructScalar<'_> {}
28
29impl PartialOrd for StructScalar<'_> {
31 fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
32 if !self.dtype.eq_ignore_nullability(other.dtype) {
33 return None;
34 }
35 self.fields().partial_cmp(&other.fields())
36 }
37}
38
39impl Hash for StructScalar<'_> {
40 fn hash<H: Hasher>(&self, state: &mut H) {
41 self.dtype.hash(state);
42 self.fields().hash(state);
43 }
44}
45
46impl<'a> StructScalar<'a> {
47 pub(crate) fn try_new(dtype: &'a DType, value: &'a ScalarValue) -> VortexResult<Self> {
48 if !matches!(dtype, DType::Struct(..)) {
49 vortex_bail!("Expected struct scalar, found {}", dtype)
50 }
51 Ok(Self {
52 dtype,
53 fields: value.as_list()?,
54 })
55 }
56
57 #[inline]
58 pub fn dtype(&self) -> &'a DType {
59 self.dtype
60 }
61
62 #[inline]
63 pub fn struct_dtype(&self) -> &Arc<StructDType> {
64 let DType::Struct(sdtype, ..) = self.dtype else {
65 vortex_panic!("StructScalar always has struct dtype");
66 };
67 sdtype
68 }
69
70 pub fn is_null(&self) -> bool {
71 self.fields.is_none()
72 }
73
74 pub fn field(&self, name: impl AsRef<str>) -> VortexResult<Scalar> {
75 let DType::Struct(st, _) = self.dtype() else {
76 unreachable!()
77 };
78 let idx = st.find(name)?;
79 self.field_by_idx(idx)
80 }
81
82 pub fn field_by_idx(&self, idx: usize) -> VortexResult<Scalar> {
83 let fields = self
84 .fields
85 .vortex_expect("Can't take field out of null struct scalar");
86 let DType::Struct(st, _) = self.dtype() else {
87 unreachable!()
88 };
89
90 Ok(Scalar {
91 dtype: st.field_by_index(idx)?,
92 value: fields[idx].clone(),
93 })
94 }
95
96 pub fn fields(&self) -> Option<Vec<Scalar>> {
98 let fields = self.fields?;
99 Some(
100 (0..fields.len())
101 .map(|index| {
102 self.field_by_idx(index)
103 .vortex_expect("never out of bounds")
104 })
105 .collect::<Vec<_>>(),
106 )
107 }
108
109 pub(crate) fn field_values(&self) -> Option<&[ScalarValue]> {
110 self.fields.map(Arc::deref)
111 }
112
113 pub fn cast(&self, dtype: &DType) -> VortexResult<Scalar> {
114 let DType::Struct(st, _) = dtype else {
115 vortex_bail!("Can only cast struct to another struct")
116 };
117 let DType::Struct(own_st, _) = self.dtype() else {
118 unreachable!()
119 };
120
121 if st.fields().len() != own_st.fields().len() {
122 vortex_bail!(
123 "Cannot cast between structs with different number of fields: {} and {}",
124 own_st.fields().len(),
125 st.fields().len()
126 );
127 }
128
129 if let Some(fs) = self.field_values() {
130 let fields = fs
131 .iter()
132 .enumerate()
133 .map(|(i, f)| {
134 Scalar {
135 dtype: own_st.field_by_index(i)?,
136 value: f.clone(),
137 }
138 .cast(&st.field_by_index(i)?)
139 .map(|s| s.value)
140 })
141 .collect::<VortexResult<Vec<_>>>()?;
142 Ok(Scalar {
143 dtype: dtype.clone(),
144 value: ScalarValue(InnerScalarValue::List(fields.into())),
145 })
146 } else {
147 Ok(Scalar::null(dtype.clone()))
148 }
149 }
150
151 pub fn project(&self, projection: &[FieldName]) -> VortexResult<Scalar> {
152 let struct_dtype = self
153 .dtype
154 .as_struct()
155 .ok_or_else(|| vortex_err!("Not a struct dtype"))?;
156 let projected_dtype = struct_dtype.project(projection)?;
157 let new_fields = if let Some(fs) = self.field_values() {
158 ScalarValue(InnerScalarValue::List(
159 projection
160 .iter()
161 .map(|name| {
162 struct_dtype
163 .find(name)
164 .vortex_expect("DType has been successfully projected already")
165 })
166 .map(|i| fs[i].clone())
167 .collect(),
168 ))
169 } else {
170 ScalarValue(InnerScalarValue::Null)
171 };
172 Ok(Scalar::new(
173 DType::Struct(Arc::new(projected_dtype), self.dtype().nullability()),
174 new_fields,
175 ))
176 }
177}
178
179impl Scalar {
180 pub fn struct_(dtype: DType, children: Vec<Scalar>) -> Self {
181 Self {
182 dtype,
183 value: ScalarValue(InnerScalarValue::List(
184 children
185 .into_iter()
186 .map(|x| x.into_value())
187 .collect_vec()
188 .into(),
189 )),
190 }
191 }
192}
193
194impl<'a> TryFrom<&'a Scalar> for StructScalar<'a> {
195 type Error = VortexError;
196
197 fn try_from(value: &'a Scalar) -> Result<Self, Self::Error> {
198 Self::try_new(value.dtype(), &value.value)
199 }
200}
201
#[cfg(test)]
mod tests {
    use vortex_dtype::PType::I32;
    use vortex_dtype::{DType, Nullability, StructDType};

    use super::*;

    /// Returns the dtypes of the two fields (`a`: i32, `b`: utf8, both
    /// non-nullable) followed by the nullable struct dtype built from them.
    fn setup_types() -> (DType, DType, DType) {
        let f0_dt = DType::Primitive(I32, Nullability::NonNullable);
        let f1_dt = DType::Utf8(Nullability::NonNullable);

        let struct_dt = StructDType::new(
            vec!["a".into(), "b".into()].into(),
            vec![f0_dt.clone(), f1_dt.clone()],
        );
        let dtype = DType::Struct(Arc::new(struct_dt), Nullability::Nullable);

        (f0_dt, f1_dt, dtype)
    }

    /// Reading a field out of a null struct scalar must not succeed.
    #[test]
    #[should_panic]
    fn test_struct_scalar_null() {
        let (_, _, dtype) = setup_types();
        let scalar = Scalar::null(dtype);

        scalar.as_struct().field_by_idx(0).unwrap();
    }

    /// Fields read back from a non-null struct carry the field dtype from the
    /// struct dtype and the value that was put in.
    #[test]
    fn test_struct_scalar_non_null() {
        let (f0_dt, f1_dt, dtype) = setup_types();

        let f0_val = Scalar::primitive::<i32>(1, Nullability::NonNullable);
        let f1_val = Scalar::utf8("hello", Nullability::NonNullable);

        // Nullable twins of the inputs: the equality assertions below compare
        // against these, so they only hold if scalar equality disregards
        // nullability.
        let f0_val_null = Scalar::primitive::<i32>(1, Nullability::Nullable);
        let f1_val_null = Scalar::utf8("hello", Nullability::Nullable);

        let scalar = Scalar::struct_(dtype, vec![f0_val, f1_val]);

        // `unwrap` fails the test on an error, just like a separate `is_ok` check.
        let scalar_f0 = scalar.as_struct().field_by_idx(0).unwrap();
        assert_eq!(scalar_f0, f0_val_null);
        assert_eq!(scalar_f0.dtype(), &f0_dt);

        let scalar_f1 = scalar.as_struct().field_by_idx(1).unwrap();
        assert_eq!(scalar_f1, f1_val_null);
        assert_eq!(scalar_f1.dtype(), &f1_dt);
    }
}