1use std::cmp::Ordering;
5use std::fmt::Display;
6use std::fmt::Formatter;
7use std::hash::Hash;
8use std::hash::Hasher;
9use std::ops::Deref;
10use std::sync::Arc;
11
12use itertools::Itertools;
13use vortex_dtype::DType;
14use vortex_dtype::FieldName;
15use vortex_dtype::FieldNames;
16use vortex_dtype::StructFields;
17use vortex_error::VortexError;
18use vortex_error::VortexExpect;
19use vortex_error::VortexResult;
20use vortex_error::vortex_bail;
21use vortex_error::vortex_err;
22use vortex_error::vortex_panic;
23
24use crate::InnerScalarValue;
25use crate::Scalar;
26use crate::ScalarValue;
27
/// A typed view over a struct-valued scalar.
///
/// The view borrows both the dtype and the field values; it owns no data itself.
#[derive(Debug, Clone)]
pub struct StructScalar<'a> {
    /// The dtype of the scalar. `try_new` only accepts `DType::Struct` here.
    dtype: &'a DType,
    /// The per-field values, or `None` when the struct scalar itself is null.
    fields: Option<&'a Arc<[ScalarValue]>>,
}
37
38impl Display for StructScalar<'_> {
39 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
40 match &self.fields {
41 None => write!(f, "null"),
42 Some(fields) => {
43 write!(f, "{{")?;
44 let formatted_fields = self
45 .names()
46 .iter()
47 .zip_eq(self.struct_fields().fields())
48 .zip_eq(fields.iter())
49 .map(|((name, dtype), value)| {
50 let val = Scalar::new(dtype, value.clone());
51 format!("{name}: {val}")
52 })
53 .format(", ");
54 write!(f, "{formatted_fields}")?;
55 write!(f, "}}")
56 }
57 }
58 }
59}
60
61impl PartialEq for StructScalar<'_> {
62 fn eq(&self, other: &Self) -> bool {
63 if !self.dtype.eq_ignore_nullability(other.dtype) {
64 return false;
65 }
66
67 match (self.fields(), other.fields()) {
68 (Some(lhs), Some(rhs)) => lhs.zip(rhs).all(|(l_s, r_s)| l_s == r_s),
69 (None, None) => true,
70 (Some(_), None) | (None, Some(_)) => false,
71 }
72 }
73}
74
// Marker impl declaring `StructScalar` equality to be a full equivalence relation.
// NOTE(review): this presumes equality of the field scalars is itself reflexive — confirm.
impl Eq for StructScalar<'_> {}
76
77impl PartialOrd for StructScalar<'_> {
79 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
80 if !self.dtype.eq_ignore_nullability(other.dtype) {
81 return None;
82 }
83
84 match (self.fields(), other.fields()) {
85 (Some(lhs), Some(rhs)) => {
86 for (l_s, r_s) in lhs.zip(rhs) {
87 match l_s.partial_cmp(&r_s)? {
88 Ordering::Equal => continue,
89 Ordering::Less => return Some(Ordering::Less),
90 Ordering::Greater => return Some(Ordering::Greater),
91 }
92 }
93 }
94 (None, None) => return Some(Ordering::Equal),
95 (Some(_), None) => return Some(Ordering::Greater),
96 (None, Some(_)) => return Some(Ordering::Less),
97 }
98
99 Some(Ordering::Equal)
100 }
101}
102
103impl Hash for StructScalar<'_> {
104 fn hash<H: Hasher>(&self, state: &mut H) {
105 self.dtype.hash(state);
106 if let Some(fields) = self.fields() {
107 for f in fields {
108 f.hash(state);
109 }
110 }
111 }
112}
113
114impl<'a> StructScalar<'a> {
115 #[inline]
116 pub(crate) fn try_new(dtype: &'a DType, value: &'a ScalarValue) -> VortexResult<Self> {
117 if !matches!(dtype, DType::Struct(..)) {
118 vortex_bail!("Expected struct scalar, found {}", dtype)
119 }
120
121 Ok(Self {
122 dtype,
123 fields: value.as_list()?,
124 })
125 }
126
127 #[inline]
129 pub fn dtype(&self) -> &'a DType {
130 self.dtype
131 }
132
133 #[inline]
135 pub fn struct_fields(&self) -> &StructFields {
136 self.dtype
137 .as_struct_fields_opt()
138 .vortex_expect("StructScalar always has struct dtype")
139 }
140
141 pub fn names(&self) -> &FieldNames {
143 self.struct_fields().names()
144 }
145
146 pub fn is_null(&self) -> bool {
148 self.fields.is_none()
149 }
150
151 pub fn field(&self, name: impl AsRef<str>) -> Option<Scalar> {
155 let idx = self.struct_fields().find(name)?;
156 self.field_by_idx(idx)
157 }
158
159 pub fn field_by_idx(&self, idx: usize) -> Option<Scalar> {
167 let fields = self
168 .fields
169 .vortex_expect("Can't take field out of null struct scalar");
170 Some(Scalar::new(
171 self.struct_fields().field_by_index(idx)?,
172 fields[idx].clone(),
173 ))
174 }
175
176 pub fn fields(&self) -> Option<impl ExactSizeIterator<Item = Scalar>> {
178 let fields = self.fields?;
179 Some(
180 fields
181 .iter()
182 .zip(self.struct_fields().fields())
183 .map(|(v, dtype)| Scalar::new(dtype, v.clone())),
184 )
185 }
186
187 pub(crate) fn field_values(&self) -> Option<&[ScalarValue]> {
188 self.fields.map(Arc::deref)
189 }
190
191 pub fn cast(&self, dtype: &DType) -> VortexResult<Scalar> {
197 let DType::Struct(st, _) = dtype else {
198 vortex_bail!(
199 "Cannot cast struct to {}: struct can only be cast to struct",
200 dtype
201 )
202 };
203 let own_st = self.struct_fields();
204
205 if st.fields().len() != own_st.fields().len() {
206 vortex_bail!(
207 "Cannot cast between structs with different number of fields: {} and {}",
208 own_st.fields().len(),
209 st.fields().len()
210 );
211 }
212
213 if let Some(fs) = self.field_values() {
214 let fields = fs
215 .iter()
216 .enumerate()
217 .map(|(i, f)| {
218 Scalar::new(
219 own_st
220 .field_by_index(i)
221 .vortex_expect("Iterating over scalar fields"),
222 f.clone(),
223 )
224 .cast(
225 &st.field_by_index(i)
226 .vortex_expect("Iterating over scalar fields"),
227 )
228 .map(|s| s.into_value())
229 })
230 .collect::<VortexResult<Vec<_>>>()?;
231 Ok(Scalar::new(
232 dtype.clone(),
233 ScalarValue(InnerScalarValue::List(fields.into())),
234 ))
235 } else {
236 Ok(Scalar::null(dtype.clone()))
237 }
238 }
239
240 pub fn project(&self, projection: &[FieldName]) -> VortexResult<Scalar> {
246 let struct_dtype = self
247 .dtype
248 .as_struct_fields_opt()
249 .ok_or_else(|| vortex_err!("Not a struct dtype"))?;
250 let projected_dtype = struct_dtype.project(projection)?;
251 let new_fields = if let Some(fs) = self.field_values() {
252 ScalarValue(InnerScalarValue::List(
253 projection
254 .iter()
255 .map(|name| {
256 struct_dtype
257 .find(name)
258 .vortex_expect("DType has been successfully projected already")
259 })
260 .map(|i| fs[i].clone())
261 .collect(),
262 ))
263 } else {
264 ScalarValue(InnerScalarValue::Null)
265 };
266 Ok(Scalar::new(
267 DType::Struct(projected_dtype, self.dtype().nullability()),
268 new_fields,
269 ))
270 }
271}
272
273impl Scalar {
274 pub fn struct_(dtype: DType, children: Vec<Scalar>) -> Self {
276 let DType::Struct(struct_fields, _) = &dtype else {
277 vortex_panic!("Expected struct dtype, found {}", dtype);
278 };
279
280 let field_dtypes = struct_fields.fields();
281 if children.len() != field_dtypes.len() {
282 vortex_panic!(
283 "Struct has {} fields but {} children were provided",
284 field_dtypes.len(),
285 children.len()
286 );
287 }
288
289 for (idx, (child, expected_dtype)) in children.iter().zip(field_dtypes).enumerate() {
290 if child.dtype() != &expected_dtype {
291 vortex_panic!(
292 "Field {} expected dtype {} but got {}",
293 idx,
294 expected_dtype,
295 child.dtype()
296 );
297 }
298 }
299
300 let mut value_children = Vec::with_capacity(children.len());
301 value_children.extend(children.into_iter().map(|x| x.into_value()));
302
303 Self::new(
304 dtype,
305 ScalarValue(InnerScalarValue::List(value_children.into())),
306 )
307 }
308}
309
310impl<'a> TryFrom<&'a Scalar> for StructScalar<'a> {
311 type Error = VortexError;
312
313 fn try_from(value: &'a Scalar) -> Result<Self, Self::Error> {
314 Self::try_new(value.dtype(), value.value())
315 }
316}
317
#[cfg(test)]
mod tests {
    use std::collections::hash_map::DefaultHasher;

    use vortex_dtype::DType;
    use vortex_dtype::Nullability;
    use vortex_dtype::PType::I32;
    use vortex_dtype::StructFields;

    use super::*;

    /// The non-nullable `i32` dtype used by most fixtures.
    fn i32_dtype() -> DType {
        DType::Primitive(I32, Nullability::NonNullable)
    }

    /// A non-nullable `i32` scalar.
    fn int(value: i32) -> Scalar {
        Scalar::primitive::<i32>(value, Nullability::NonNullable)
    }

    /// A non-nullable utf8 scalar.
    fn text(value: &'static str) -> Scalar {
        Scalar::utf8(value, Nullability::NonNullable)
    }

    /// Returns `(i32 dtype, utf8 dtype, nullable struct {a: i32, b: utf8})`.
    fn setup_types() -> (DType, DType, DType) {
        let int_field = i32_dtype();
        let text_field = DType::Utf8(Nullability::NonNullable);

        let struct_dtype = DType::Struct(
            StructFields::new(
                ["a", "b"].into(),
                vec![int_field.clone(), text_field.clone()],
            ),
            Nullability::Nullable,
        );

        (int_field, text_field, struct_dtype)
    }

    /// A non-null `{a, b}` struct scalar of the `setup_types` struct dtype.
    fn sample_struct(a: i32, b: &'static str) -> Scalar {
        let (_, _, dtype) = setup_types();
        Scalar::struct_(dtype, vec![int(a), text(b)])
    }

    /// Hashes a struct view with the default hasher.
    fn hash_of(value: &StructScalar<'_>) -> u64 {
        let mut hasher = DefaultHasher::new();
        value.hash(&mut hasher);
        hasher.finish()
    }

    #[test]
    #[should_panic]
    fn test_struct_scalar_null() {
        let null_struct = Scalar::null(setup_types().2);

        null_struct.as_struct().field_by_idx(0).unwrap();
    }

    #[test]
    fn test_struct_scalar_non_null() {
        let (int_field, text_field, _) = setup_types();
        let scalar = sample_struct(1, "hello");
        let view = scalar.as_struct();

        // Field scalars compare equal to their nullable counterparts.
        let first = view.field_by_idx(0).unwrap();
        assert_eq!(first, Scalar::primitive::<i32>(1, Nullability::Nullable));
        assert_eq!(first.dtype(), &int_field);

        let second = view.field_by_idx(1).unwrap();
        assert_eq!(second, Scalar::utf8("hello", Nullability::Nullable));
        assert_eq!(second.dtype(), &text_field);
    }

    #[test]
    #[should_panic(expected = "Expected struct dtype")]
    fn test_struct_scalar_wrong_dtype() {
        Scalar::struct_(i32_dtype(), vec![int(1)]);
    }

    #[test]
    #[should_panic(expected = "Struct has 2 fields but 1 children were provided")]
    fn test_struct_scalar_wrong_child_count() {
        Scalar::struct_(setup_types().2, vec![int(1)]);
    }

    #[test]
    #[should_panic(expected = "Field 0 expected dtype i32 but got utf8")]
    fn test_struct_scalar_wrong_child_dtype() {
        Scalar::struct_(setup_types().2, vec![text("wrong"), text("hello")]);
    }

    #[test]
    fn test_struct_field_by_name() {
        let scalar = sample_struct(42, "world");
        let view = scalar.as_struct();

        let a = view.field("a").unwrap();
        assert_eq!(a.as_primitive().typed_value::<i32>().unwrap(), 42);

        let b = view.field("b").unwrap();
        assert_eq!(b.as_utf8().value().unwrap(), "world".into());

        // Unknown names yield `None`.
        assert!(view.field("c").is_none());
    }

    #[test]
    fn test_struct_fields() {
        let scalar = sample_struct(100, "test");

        let fields: Vec<_> = scalar.as_struct().fields().unwrap().collect();
        assert_eq!(fields.len(), 2);
        assert_eq!(fields[0].as_primitive().typed_value::<i32>().unwrap(), 100);
        assert_eq!(fields[1].as_utf8().value().unwrap(), "test".into());
    }

    #[test]
    fn test_struct_null_fields() {
        let null_scalar = Scalar::null(setup_types().2);
        let view = null_scalar.as_struct();

        assert!(view.is_null());
        assert!(view.fields().is_none());
        assert!(view.field_values().is_none());
    }

    #[test]
    fn test_struct_cast_to_struct() {
        // Non-nullable struct `{x, y}` whose two fields share one dtype.
        let pair_of = |field: DType| {
            DType::Struct(
                StructFields::new(["x", "y"].into(), vec![field.clone(), field]),
                Nullability::NonNullable,
            )
        };
        let source_dtype = pair_of(i32_dtype());
        let target_dtype = pair_of(DType::Primitive(
            vortex_dtype::PType::I64,
            Nullability::NonNullable,
        ));

        let source_scalar = Scalar::struct_(source_dtype, vec![int(42), int(123)]);

        let result = source_scalar.as_struct().cast(&target_dtype).unwrap();
        assert_eq!(result.dtype(), &target_dtype);

        let fields: Vec<_> = result.as_struct().fields().unwrap().collect();
        assert_eq!(fields[0].as_primitive().typed_value::<i64>().unwrap(), 42);
        assert_eq!(fields[1].as_primitive().typed_value::<i64>().unwrap(), 123);
    }

    #[test]
    fn test_struct_cast_mismatched_fields() {
        let source_dtype = DType::Struct(
            StructFields::new(["a"].into(), vec![i32_dtype()]),
            Nullability::NonNullable,
        );
        let target_dtype = DType::Struct(
            StructFields::new(["a", "b"].into(), vec![i32_dtype(), i32_dtype()]),
            Nullability::NonNullable,
        );

        let scalar = Scalar::struct_(source_dtype, vec![int(1)]);

        assert!(scalar.as_struct().cast(&target_dtype).is_err());
    }

    #[test]
    fn test_struct_cast_to_non_struct() {
        let scalar = sample_struct(1, "test");

        assert!(scalar.as_struct().cast(&i32_dtype()).is_err());
    }

    #[test]
    fn test_struct_project() {
        let scalar = sample_struct(42, "hello");

        let projected = scalar.as_struct().project(&["b".into()]).unwrap();
        let view = projected.as_struct();

        assert_eq!(view.names().len(), 1);
        assert_eq!(view.names()[0].as_ref(), "b");

        let fields: Vec<_> = view.fields().unwrap().collect();
        assert_eq!(fields.len(), 1);
        assert_eq!(fields[0].as_utf8().value().unwrap().as_str(), "hello");
    }

    #[test]
    fn test_struct_project_null() {
        let null_scalar = Scalar::null(setup_types().2);

        let projected = null_scalar.as_struct().project(&["a".into()]).unwrap();
        assert!(projected.as_struct().is_null());
    }

    #[test]
    fn test_struct_equality() {
        let base = sample_struct(1, "test");
        let same = sample_struct(1, "test");
        let different = sample_struct(2, "test");

        assert_eq!(base.as_struct(), same.as_struct());
        assert_ne!(base.as_struct(), different.as_struct());
    }

    #[test]
    fn test_struct_partial_ord() {
        let smaller = sample_struct(1, "a");
        let larger = sample_struct(2, "b");

        assert!(smaller.as_struct() < larger.as_struct());

        // Structs with a different dtype are incomparable.
        let other_dtype = DType::Struct(
            StructFields::new(["c"].into(), vec![i32_dtype()]),
            Nullability::NonNullable,
        );
        let unrelated = Scalar::struct_(other_dtype, vec![int(1)]);

        assert_eq!(
            smaller.as_struct().partial_cmp(&unrelated.as_struct()),
            None
        );
    }

    #[test]
    fn test_struct_hash() {
        let left = sample_struct(1, "test");
        let right = sample_struct(1, "test");

        assert_eq!(hash_of(&left.as_struct()), hash_of(&right.as_struct()));
    }

    #[test]
    fn test_struct_try_new_non_struct_dtype() {
        let dtype = i32_dtype();
        let value = ScalarValue(InnerScalarValue::Primitive(crate::PValue::I32(42)));

        assert!(StructScalar::try_new(&dtype, &value).is_err());
    }

    #[test]
    fn test_struct_field_out_of_bounds() {
        let scalar = sample_struct(1, "test");

        assert!(scalar.as_struct().field_by_idx(10).is_none());
    }
}