1use std::fmt::{Display, Formatter};
5use std::hash::{Hash, Hasher};
6use std::ops::Deref;
7use std::sync::Arc;
8
9use itertools::Itertools;
10use vortex_dtype::{DType, FieldName, FieldNames, StructFields};
11use vortex_error::{
12 VortexError, VortexExpect, VortexResult, vortex_bail, vortex_err, vortex_panic,
13};
14
15use crate::{InnerScalarValue, Scalar, ScalarValue};
16
17pub struct StructScalar<'a> {
22 dtype: &'a DType,
23 fields: Option<&'a Arc<[ScalarValue]>>,
24}
25
26impl Display for StructScalar<'_> {
27 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
28 match &self.fields {
29 None => write!(f, "null"),
30 Some(fields) => {
31 write!(f, "{{")?;
32 let formatted_fields = self
33 .names()
34 .iter()
35 .zip_eq(self.struct_fields().fields())
36 .zip_eq(fields.iter())
37 .map(|((name, dtype), value)| {
38 let val = Scalar::new(dtype, value.clone());
39 format!("{name}: {val}")
40 })
41 .format(", ");
42 write!(f, "{formatted_fields}")?;
43 write!(f, "}}")
44 }
45 }
46 }
47}
48
49impl PartialEq for StructScalar<'_> {
50 fn eq(&self, other: &Self) -> bool {
51 if !self.dtype.eq_ignore_nullability(other.dtype) {
52 return false;
53 }
54 self.fields() == other.fields()
55 }
56}
57
58impl Eq for StructScalar<'_> {}
59
60impl PartialOrd for StructScalar<'_> {
62 fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
63 if !self.dtype.eq_ignore_nullability(other.dtype) {
64 return None;
65 }
66 self.fields().partial_cmp(&other.fields())
67 }
68}
69
70impl Hash for StructScalar<'_> {
71 fn hash<H: Hasher>(&self, state: &mut H) {
72 self.dtype.hash(state);
73 self.fields().hash(state);
74 }
75}
76
77impl<'a> StructScalar<'a> {
78 pub(crate) fn try_new(dtype: &'a DType, value: &'a ScalarValue) -> VortexResult<Self> {
79 if !matches!(dtype, DType::Struct(..)) {
80 vortex_bail!("Expected struct scalar, found {}", dtype)
81 }
82 Ok(Self {
83 dtype,
84 fields: value.as_list()?,
85 })
86 }
87
88 #[inline]
90 pub fn dtype(&self) -> &'a DType {
91 self.dtype
92 }
93
94 #[inline]
96 pub fn struct_fields(&self) -> &StructFields {
97 self.dtype
98 .as_struct()
99 .vortex_expect("StructScalar always has struct dtype")
100 }
101
102 pub fn names(&self) -> &FieldNames {
104 self.struct_fields().names()
105 }
106
107 pub fn is_null(&self) -> bool {
109 self.fields.is_none()
110 }
111
112 pub fn field(&self, name: impl AsRef<str>) -> Option<Scalar> {
116 let idx = self.struct_fields().find(name)?;
117 self.field_by_idx(idx)
118 }
119
120 pub fn field_by_idx(&self, idx: usize) -> Option<Scalar> {
128 let fields = self
129 .fields
130 .vortex_expect("Can't take field out of null struct scalar");
131 Some(Scalar {
132 dtype: self.struct_fields().field_by_index(idx)?,
133 value: fields[idx].clone(),
134 })
135 }
136
137 pub fn fields(&self) -> Option<Vec<Scalar>> {
139 let fields = self.fields?;
140 Some(
141 (0..fields.len())
142 .map(|index| {
143 self.field_by_idx(index)
144 .vortex_expect("never out of bounds")
145 })
146 .collect::<Vec<_>>(),
147 )
148 }
149
150 pub(crate) fn field_values(&self) -> Option<&[ScalarValue]> {
151 self.fields.map(Arc::deref)
152 }
153
154 pub fn cast(&self, dtype: &DType) -> VortexResult<Scalar> {
160 let DType::Struct(st, _) = dtype else {
161 vortex_bail!("Can only cast struct to another struct")
162 };
163 let own_st = self.struct_fields();
164
165 if st.fields().len() != own_st.fields().len() {
166 vortex_bail!(
167 "Cannot cast between structs with different number of fields: {} and {}",
168 own_st.fields().len(),
169 st.fields().len()
170 );
171 }
172
173 if let Some(fs) = self.field_values() {
174 let fields = fs
175 .iter()
176 .enumerate()
177 .map(|(i, f)| {
178 Scalar {
179 dtype: own_st
180 .field_by_index(i)
181 .vortex_expect("Iterating over scalar fields"),
182 value: f.clone(),
183 }
184 .cast(
185 &st.field_by_index(i)
186 .vortex_expect("Iterating over scalar fields"),
187 )
188 .map(|s| s.value)
189 })
190 .collect::<VortexResult<Vec<_>>>()?;
191 Ok(Scalar {
192 dtype: dtype.clone(),
193 value: ScalarValue(InnerScalarValue::List(fields.into())),
194 })
195 } else {
196 Ok(Scalar::null(dtype.clone()))
197 }
198 }
199
200 pub fn project(&self, projection: &[FieldName]) -> VortexResult<Scalar> {
206 let struct_dtype = self
207 .dtype
208 .as_struct()
209 .ok_or_else(|| vortex_err!("Not a struct dtype"))?;
210 let projected_dtype = struct_dtype.project(projection)?;
211 let new_fields = if let Some(fs) = self.field_values() {
212 ScalarValue(InnerScalarValue::List(
213 projection
214 .iter()
215 .map(|name| {
216 struct_dtype
217 .find(name)
218 .vortex_expect("DType has been successfully projected already")
219 })
220 .map(|i| fs[i].clone())
221 .collect(),
222 ))
223 } else {
224 ScalarValue(InnerScalarValue::Null)
225 };
226 Ok(Scalar::new(
227 DType::Struct(projected_dtype, self.dtype().nullability()),
228 new_fields,
229 ))
230 }
231}
232
233impl Scalar {
234 pub fn struct_(dtype: DType, children: Vec<Scalar>) -> Self {
236 let DType::Struct(struct_fields, _) = &dtype else {
237 vortex_panic!("Expected struct dtype, found {}", dtype);
238 };
239
240 let field_dtypes = struct_fields.fields();
241 if children.len() != field_dtypes.len() {
242 vortex_panic!(
243 "Struct has {} fields but {} children were provided",
244 field_dtypes.len(),
245 children.len()
246 );
247 }
248
249 for (idx, (child, expected_dtype)) in children.iter().zip(field_dtypes).enumerate() {
250 if child.dtype() != &expected_dtype {
251 vortex_panic!(
252 "Field {} expected dtype {} but got {}",
253 idx,
254 expected_dtype,
255 child.dtype()
256 );
257 }
258 }
259
260 Self {
261 dtype,
262 value: ScalarValue(InnerScalarValue::List(
263 children
264 .into_iter()
265 .map(|x| x.into_value())
266 .collect_vec()
267 .into(),
268 )),
269 }
270 }
271}
272
273impl<'a> TryFrom<&'a Scalar> for StructScalar<'a> {
274 type Error = VortexError;
275
276 fn try_from(value: &'a Scalar) -> Result<Self, Self::Error> {
277 Self::try_new(value.dtype(), &value.value)
278 }
279}
280
#[cfg(test)]
mod tests {
    use vortex_dtype::PType::I32;
    use vortex_dtype::{DType, Nullability, StructFields};

    use super::*;

    /// Returns `(i32 dtype, utf8 dtype, nullable struct dtype {a: i32, b: utf8})`.
    /// Both field dtypes are non-nullable.
    fn setup_types() -> (DType, DType, DType) {
        let int_dtype = DType::Primitive(I32, Nullability::NonNullable);
        let str_dtype = DType::Utf8(Nullability::NonNullable);

        let layout = StructFields::new(
            vec!["a".into(), "b".into()].into(),
            vec![int_dtype.clone(), str_dtype.clone()],
        );
        let struct_dtype = DType::Struct(layout, Nullability::Nullable);

        (int_dtype, str_dtype, struct_dtype)
    }

    /// Reading a field out of a null struct scalar must panic.
    #[test]
    #[should_panic]
    fn test_struct_scalar_null() {
        let (_, _, struct_dtype) = setup_types();
        let null_struct = Scalar::null(struct_dtype);

        let view = null_struct.as_struct();
        view.field_by_idx(0).unwrap();
    }

    /// Fields of a non-null struct are returned with the declared field dtype
    /// and compare equal to the same value regardless of nullability.
    #[test]
    fn test_struct_scalar_non_null() {
        let (int_dtype, str_dtype, struct_dtype) = setup_types();

        let scalar = Scalar::struct_(
            struct_dtype,
            vec![
                Scalar::primitive::<i32>(1, Nullability::NonNullable),
                Scalar::utf8("hello", Nullability::NonNullable),
            ],
        );

        // (expected value as a nullable scalar, expected field dtype), by position.
        let expectations = vec![
            (Scalar::primitive::<i32>(1, Nullability::Nullable), int_dtype),
            (Scalar::utf8("hello", Nullability::Nullable), str_dtype),
        ];

        for (idx, (expected_value, expected_dtype)) in expectations.into_iter().enumerate() {
            let found = scalar.as_struct().field_by_idx(idx);
            assert!(found.is_some());
            let found = found.unwrap();
            assert_eq!(found, expected_value);
            assert_eq!(found.dtype(), &expected_dtype);
        }
    }

    /// A non-struct dtype is rejected by the constructor.
    #[test]
    #[should_panic(expected = "Expected struct dtype")]
    fn test_struct_scalar_wrong_dtype() {
        let not_a_struct = DType::Primitive(I32, Nullability::NonNullable);
        let child = Scalar::primitive::<i32>(1, Nullability::NonNullable);

        Scalar::struct_(not_a_struct, vec![child]);
    }

    /// Supplying fewer children than fields is rejected.
    #[test]
    #[should_panic(expected = "Struct has 2 fields but 1 children were provided")]
    fn test_struct_scalar_wrong_child_count() {
        let (_, _, struct_dtype) = setup_types();
        let only_child = Scalar::primitive::<i32>(1, Nullability::NonNullable);

        Scalar::struct_(struct_dtype, vec![only_child]);
    }

    /// A child whose dtype differs from the field dtype is rejected.
    #[test]
    #[should_panic(expected = "Field 0 expected dtype i32 but got utf8")]
    fn test_struct_scalar_wrong_child_dtype() {
        let (_, _, struct_dtype) = setup_types();
        let mistyped_first = Scalar::utf8("wrong", Nullability::NonNullable);
        let valid_second = Scalar::utf8("hello", Nullability::NonNullable);

        Scalar::struct_(struct_dtype, vec![mistyped_first, valid_second]);
    }
}