1use std::cmp::Ordering;
5use std::fmt::{Display, Formatter};
6use std::hash::{Hash, Hasher};
7use std::ops::Deref;
8use std::sync::Arc;
9
10use itertools::Itertools;
11use vortex_dtype::{DType, FieldName, FieldNames, StructFields};
12use vortex_error::{
13 VortexError, VortexExpect, VortexResult, vortex_bail, vortex_err, vortex_panic,
14};
15
16use crate::{InnerScalarValue, Scalar, ScalarValue};
17
/// A borrowed view over a struct-typed [`Scalar`].
///
/// Pairs the struct [`DType`] with the scalar's field values. A null struct scalar
/// has no field values at all.
#[derive(Debug, Clone)]
pub struct StructScalar<'a> {
    /// The scalar's dtype; construction via `try_new` only accepts `DType::Struct`.
    dtype: &'a DType,
    /// Field values, positionally aligned with the dtype's fields, or `None` when the
    /// struct scalar is null.
    fields: Option<&'a Arc<[ScalarValue]>>,
}
27
28impl Display for StructScalar<'_> {
29 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
30 match &self.fields {
31 None => write!(f, "null"),
32 Some(fields) => {
33 write!(f, "{{")?;
34 let formatted_fields = self
35 .names()
36 .iter()
37 .zip_eq(self.struct_fields().fields())
38 .zip_eq(fields.iter())
39 .map(|((name, dtype), value)| {
40 let val = Scalar::new(dtype, value.clone());
41 format!("{name}: {val}")
42 })
43 .format(", ");
44 write!(f, "{formatted_fields}")?;
45 write!(f, "}}")
46 }
47 }
48 }
49}
50
51impl PartialEq for StructScalar<'_> {
52 fn eq(&self, other: &Self) -> bool {
53 if !self.dtype.eq_ignore_nullability(other.dtype) {
54 return false;
55 }
56
57 match (self.fields(), other.fields()) {
58 (Some(lhs), Some(rhs)) => lhs.zip(rhs).all(|(l_s, r_s)| l_s == r_s),
59 (None, None) => true,
60 (Some(_), None) | (None, Some(_)) => false,
61 }
62 }
63}
64
65impl Eq for StructScalar<'_> {}
66
67impl PartialOrd for StructScalar<'_> {
69 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
70 if !self.dtype.eq_ignore_nullability(other.dtype) {
71 return None;
72 }
73
74 match (self.fields(), other.fields()) {
75 (Some(lhs), Some(rhs)) => {
76 for (l_s, r_s) in lhs.zip(rhs) {
77 match l_s.partial_cmp(&r_s)? {
78 Ordering::Equal => continue,
79 Ordering::Less => return Some(Ordering::Less),
80 Ordering::Greater => return Some(Ordering::Greater),
81 }
82 }
83 }
84 (None, None) => return Some(Ordering::Equal),
85 (Some(_), None) => return Some(Ordering::Greater),
86 (None, Some(_)) => return Some(Ordering::Less),
87 }
88
89 Some(Ordering::Equal)
90 }
91}
92
93impl Hash for StructScalar<'_> {
94 fn hash<H: Hasher>(&self, state: &mut H) {
95 self.dtype.hash(state);
96 if let Some(fields) = self.fields() {
97 for f in fields {
98 f.hash(state);
99 }
100 }
101 }
102}
103
104impl<'a> StructScalar<'a> {
105 #[inline]
106 pub(crate) fn try_new(dtype: &'a DType, value: &'a ScalarValue) -> VortexResult<Self> {
107 if !matches!(dtype, DType::Struct(..)) {
108 vortex_bail!("Expected struct scalar, found {}", dtype)
109 }
110
111 Ok(Self {
112 dtype,
113 fields: value.as_list()?,
114 })
115 }
116
117 #[inline]
119 pub fn dtype(&self) -> &'a DType {
120 self.dtype
121 }
122
123 #[inline]
125 pub fn struct_fields(&self) -> &StructFields {
126 self.dtype
127 .as_struct_fields_opt()
128 .vortex_expect("StructScalar always has struct dtype")
129 }
130
131 pub fn names(&self) -> &FieldNames {
133 self.struct_fields().names()
134 }
135
136 pub fn is_null(&self) -> bool {
138 self.fields.is_none()
139 }
140
141 pub fn field(&self, name: impl AsRef<str>) -> Option<Scalar> {
145 let idx = self.struct_fields().find(name)?;
146 self.field_by_idx(idx)
147 }
148
149 pub fn field_by_idx(&self, idx: usize) -> Option<Scalar> {
157 let fields = self
158 .fields
159 .vortex_expect("Can't take field out of null struct scalar");
160 Some(Scalar::new(
161 self.struct_fields().field_by_index(idx)?,
162 fields[idx].clone(),
163 ))
164 }
165
166 pub fn fields(&self) -> Option<impl ExactSizeIterator<Item = Scalar>> {
168 let fields = self.fields?;
169 Some(
170 fields
171 .iter()
172 .zip(self.struct_fields().fields())
173 .map(|(v, dtype)| Scalar::new(dtype, v.clone())),
174 )
175 }
176
177 pub(crate) fn field_values(&self) -> Option<&[ScalarValue]> {
178 self.fields.map(Arc::deref)
179 }
180
181 pub fn cast(&self, dtype: &DType) -> VortexResult<Scalar> {
187 let DType::Struct(st, _) = dtype else {
188 vortex_bail!(
189 "Cannot cast struct to {}: struct can only be cast to struct",
190 dtype
191 )
192 };
193 let own_st = self.struct_fields();
194
195 if st.fields().len() != own_st.fields().len() {
196 vortex_bail!(
197 "Cannot cast between structs with different number of fields: {} and {}",
198 own_st.fields().len(),
199 st.fields().len()
200 );
201 }
202
203 if let Some(fs) = self.field_values() {
204 let fields = fs
205 .iter()
206 .enumerate()
207 .map(|(i, f)| {
208 Scalar::new(
209 own_st
210 .field_by_index(i)
211 .vortex_expect("Iterating over scalar fields"),
212 f.clone(),
213 )
214 .cast(
215 &st.field_by_index(i)
216 .vortex_expect("Iterating over scalar fields"),
217 )
218 .map(|s| s.into_value())
219 })
220 .collect::<VortexResult<Vec<_>>>()?;
221 Ok(Scalar::new(
222 dtype.clone(),
223 ScalarValue(InnerScalarValue::List(fields.into())),
224 ))
225 } else {
226 Ok(Scalar::null(dtype.clone()))
227 }
228 }
229
230 pub fn project(&self, projection: &[FieldName]) -> VortexResult<Scalar> {
236 let struct_dtype = self
237 .dtype
238 .as_struct_fields_opt()
239 .ok_or_else(|| vortex_err!("Not a struct dtype"))?;
240 let projected_dtype = struct_dtype.project(projection)?;
241 let new_fields = if let Some(fs) = self.field_values() {
242 ScalarValue(InnerScalarValue::List(
243 projection
244 .iter()
245 .map(|name| {
246 struct_dtype
247 .find(name)
248 .vortex_expect("DType has been successfully projected already")
249 })
250 .map(|i| fs[i].clone())
251 .collect(),
252 ))
253 } else {
254 ScalarValue(InnerScalarValue::Null)
255 };
256 Ok(Scalar::new(
257 DType::Struct(projected_dtype, self.dtype().nullability()),
258 new_fields,
259 ))
260 }
261}
262
263impl Scalar {
264 pub fn struct_(dtype: DType, children: Vec<Scalar>) -> Self {
266 let DType::Struct(struct_fields, _) = &dtype else {
267 vortex_panic!("Expected struct dtype, found {}", dtype);
268 };
269
270 let field_dtypes = struct_fields.fields();
271 if children.len() != field_dtypes.len() {
272 vortex_panic!(
273 "Struct has {} fields but {} children were provided",
274 field_dtypes.len(),
275 children.len()
276 );
277 }
278
279 for (idx, (child, expected_dtype)) in children.iter().zip(field_dtypes).enumerate() {
280 if child.dtype() != &expected_dtype {
281 vortex_panic!(
282 "Field {} expected dtype {} but got {}",
283 idx,
284 expected_dtype,
285 child.dtype()
286 );
287 }
288 }
289
290 let mut value_children = Vec::with_capacity(children.len());
291 value_children.extend(children.into_iter().map(|x| x.into_value()));
292
293 Self::new(
294 dtype,
295 ScalarValue(InnerScalarValue::List(value_children.into())),
296 )
297 }
298}
299
300impl<'a> TryFrom<&'a Scalar> for StructScalar<'a> {
301 type Error = VortexError;
302
303 fn try_from(value: &'a Scalar) -> Result<Self, Self::Error> {
304 Self::try_new(value.dtype(), value.value())
305 }
306}
307
#[cfg(test)]
mod tests {
    use vortex_dtype::PType::I32;
    use vortex_dtype::{DType, Nullability, StructFields};

    use super::*;

    /// Non-nullable `i32` scalar.
    fn int(v: i32) -> Scalar {
        Scalar::primitive::<i32>(v, Nullability::NonNullable)
    }

    /// Non-nullable UTF-8 scalar.
    fn text(s: &'static str) -> Scalar {
        Scalar::utf8(s, Nullability::NonNullable)
    }

    /// Returns `(i32 dtype, utf8 dtype, nullable struct {a: i32, b: utf8})`.
    fn setup_types() -> (DType, DType, DType) {
        let int_dt = DType::Primitive(I32, Nullability::NonNullable);
        let str_dt = DType::Utf8(Nullability::NonNullable);
        let struct_dt = DType::Struct(
            StructFields::new(["a", "b"].into(), vec![int_dt.clone(), str_dt.clone()]),
            Nullability::Nullable,
        );
        (int_dt, str_dt, struct_dt)
    }

    /// Builds a struct scalar `{a: a, b: b}` of the given two-field dtype.
    fn pair(dtype: DType, a: i32, b: &'static str) -> Scalar {
        Scalar::struct_(dtype, vec![int(a), text(b)])
    }

    #[test]
    #[should_panic]
    fn test_struct_scalar_null() {
        let (_, _, dtype) = setup_types();
        let null_struct = Scalar::null(dtype);

        null_struct
            .as_struct()
            .field_by_idx(0)
            .expect("a null struct has no fields");
    }

    #[test]
    fn test_struct_scalar_non_null() {
        let (f0_dt, f1_dt, dtype) = setup_types();
        let scalar = pair(dtype, 1, "hello");

        let expectations = [
            (Scalar::primitive::<i32>(1, Nullability::Nullable), f0_dt),
            (Scalar::utf8("hello", Nullability::Nullable), f1_dt),
        ];
        for (idx, (expected_value, expected_dtype)) in expectations.into_iter().enumerate() {
            let actual = scalar
                .as_struct()
                .field_by_idx(idx)
                .expect("field index is in range");
            // The value matches the nullable expectation, yet the dtype stays the declared one.
            assert_eq!(actual, expected_value);
            assert_eq!(actual.dtype(), &expected_dtype);
        }
    }

    #[test]
    #[should_panic(expected = "Expected struct dtype")]
    fn test_struct_scalar_wrong_dtype() {
        Scalar::struct_(DType::Primitive(I32, Nullability::NonNullable), vec![int(1)]);
    }

    #[test]
    #[should_panic(expected = "Struct has 2 fields but 1 children were provided")]
    fn test_struct_scalar_wrong_child_count() {
        let (_, _, dtype) = setup_types();
        Scalar::struct_(dtype, vec![int(1)]);
    }

    #[test]
    #[should_panic(expected = "Field 0 expected dtype i32 but got utf8")]
    fn test_struct_scalar_wrong_child_dtype() {
        let (_, _, dtype) = setup_types();
        Scalar::struct_(dtype, vec![text("wrong"), text("hello")]);
    }

    #[test]
    fn test_struct_field_by_name() {
        let (_, _, dtype) = setup_types();
        let scalar = pair(dtype, 42, "world");
        let view = scalar.as_struct();

        let a = view.field("a").expect("field `a` exists");
        assert_eq!(a.as_primitive().typed_value::<i32>().unwrap(), 42);

        let b = view.field("b").expect("field `b` exists");
        assert_eq!(b.as_utf8().value().unwrap(), "world".into());

        // Unknown names yield `None` rather than panicking.
        assert!(view.field("c").is_none());
    }

    #[test]
    fn test_struct_fields() {
        let (_, _, dtype) = setup_types();
        let scalar = pair(dtype, 100, "test");

        let collected: Vec<Scalar> = scalar
            .as_struct()
            .fields()
            .expect("non-null struct has fields")
            .collect();
        assert_eq!(collected.len(), 2);
        assert_eq!(collected[0].as_primitive().typed_value::<i32>().unwrap(), 100);
        assert_eq!(collected[1].as_utf8().value().unwrap(), "test".into());
    }

    #[test]
    fn test_struct_null_fields() {
        let (_, _, dtype) = setup_types();
        let null_scalar = Scalar::null(dtype);
        let view = null_scalar.as_struct();

        assert!(view.is_null());
        assert!(view.fields().is_none());
        assert!(view.field_values().is_none());
    }

    #[test]
    fn test_struct_cast_to_struct() {
        let i32_dt = DType::Primitive(I32, Nullability::NonNullable);
        let i64_dt = DType::Primitive(vortex_dtype::PType::I64, Nullability::NonNullable);
        let source_dtype = DType::Struct(
            StructFields::new(["x", "y"].into(), vec![i32_dt.clone(), i32_dt]),
            Nullability::NonNullable,
        );
        let target_dtype = DType::Struct(
            StructFields::new(["x", "y"].into(), vec![i64_dt.clone(), i64_dt]),
            Nullability::NonNullable,
        );

        let source = Scalar::struct_(source_dtype, vec![int(42), int(123)]);
        let widened = source.as_struct().cast(&target_dtype).unwrap();
        assert_eq!(widened.dtype(), &target_dtype);

        let values: Vec<Scalar> = widened.as_struct().fields().unwrap().collect();
        assert_eq!(values[0].as_primitive().typed_value::<i64>().unwrap(), 42);
        assert_eq!(values[1].as_primitive().typed_value::<i64>().unwrap(), 123);
    }

    #[test]
    fn test_struct_cast_mismatched_fields() {
        let i32_dt = DType::Primitive(I32, Nullability::NonNullable);
        let one_field = DType::Struct(
            StructFields::new(["a"].into(), vec![i32_dt.clone()]),
            Nullability::NonNullable,
        );
        let two_fields = DType::Struct(
            StructFields::new(["a", "b"].into(), vec![i32_dt.clone(), i32_dt]),
            Nullability::NonNullable,
        );

        let scalar = Scalar::struct_(one_field, vec![int(1)]);
        assert!(scalar.as_struct().cast(&two_fields).is_err());
    }

    #[test]
    fn test_struct_cast_to_non_struct() {
        let (_, _, dtype) = setup_types();
        let scalar = pair(dtype, 1, "test");

        let target = DType::Primitive(I32, Nullability::NonNullable);
        assert!(scalar.as_struct().cast(&target).is_err());
    }

    #[test]
    fn test_struct_project() {
        let (_, _, dtype) = setup_types();
        let scalar = pair(dtype, 42, "hello");

        let projected = scalar.as_struct().project(&["b".into()]).unwrap();
        let view = projected.as_struct();

        let names = view.names();
        assert_eq!(names.len(), 1);
        assert_eq!(names[0].as_ref(), "b");

        let values: Vec<Scalar> = view.fields().unwrap().collect();
        assert_eq!(values.len(), 1);
        assert_eq!(values[0].as_utf8().value().unwrap().as_str(), "hello");
    }

    #[test]
    fn test_struct_project_null() {
        let (_, _, dtype) = setup_types();
        let projected = Scalar::null(dtype)
            .as_struct()
            .project(&["a".into()])
            .unwrap();

        assert!(projected.as_struct().is_null());
    }

    #[test]
    fn test_struct_equality() {
        let (_, _, dtype) = setup_types();
        let first = pair(dtype.clone(), 1, "test");
        let same = pair(dtype.clone(), 1, "test");
        let different = pair(dtype, 2, "test");

        assert_eq!(first.as_struct(), same.as_struct());
        assert_ne!(first.as_struct(), different.as_struct());
    }

    #[test]
    fn test_struct_partial_ord() {
        let (_, _, dtype) = setup_types();
        let smaller = pair(dtype.clone(), 1, "a");
        let larger = pair(dtype, 2, "b");
        assert!(smaller.as_struct() < larger.as_struct());

        // Structs with a different field layout cannot be ordered against each other.
        let other_dtype = DType::Struct(
            StructFields::new(
                ["c"].into(),
                vec![DType::Primitive(I32, Nullability::NonNullable)],
            ),
            Nullability::NonNullable,
        );
        let unrelated = Scalar::struct_(other_dtype, vec![int(1)]);
        assert_eq!(smaller.as_struct().partial_cmp(&unrelated.as_struct()), None);
    }

    #[test]
    fn test_struct_hash() {
        fn digest(scalar: &Scalar) -> u64 {
            use std::collections::hash_map::DefaultHasher;
            use std::hash::{Hash, Hasher};

            let mut hasher = DefaultHasher::new();
            scalar.as_struct().hash(&mut hasher);
            hasher.finish()
        }

        let (_, _, dtype) = setup_types();
        let lhs = pair(dtype.clone(), 1, "test");
        let rhs = pair(dtype, 1, "test");

        assert_eq!(digest(&lhs), digest(&rhs));
    }

    #[test]
    fn test_struct_try_new_non_struct_dtype() {
        let dtype = DType::Primitive(I32, Nullability::NonNullable);
        let value = ScalarValue(InnerScalarValue::Primitive(crate::PValue::I32(42)));

        assert!(StructScalar::try_new(&dtype, &value).is_err());
    }

    #[test]
    fn test_struct_field_out_of_bounds() {
        let (_, _, dtype) = setup_types();
        let scalar = pair(dtype, 1, "test");

        // Index 10 is past the two declared fields, so the lookup yields `None`.
        assert!(scalar.as_struct().field_by_idx(10).is_none());
    }
}
679}