1use std::fmt::{Display, Formatter};
5use std::hash::{Hash, Hasher};
6use std::ops::Deref;
7use std::sync::Arc;
8
9use itertools::Itertools;
10use vortex_dtype::{DType, FieldName, FieldNames, StructFields};
11use vortex_error::{
12 VortexError, VortexExpect, VortexResult, vortex_bail, vortex_err, vortex_panic,
13};
14
15use crate::{InnerScalarValue, Scalar, ScalarValue};
16
/// A borrowed, typed view over a struct-typed [`Scalar`].
///
/// `dtype` is always a [`DType::Struct`] (checked in `try_new`). `fields` is `None` when the
/// struct scalar itself is null; otherwise it holds one value per field, positionally aligned
/// with the fields of `dtype`.
#[derive(Debug)]
pub struct StructScalar<'a> {
    /// The struct dtype this view interprets `fields` with.
    dtype: &'a DType,
    /// Per-field values, or `None` if the struct scalar is null.
    fields: Option<&'a Arc<[ScalarValue]>>,
}
26
27impl Display for StructScalar<'_> {
28 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
29 match &self.fields {
30 None => write!(f, "null"),
31 Some(fields) => {
32 write!(f, "{{")?;
33 let formatted_fields = self
34 .names()
35 .iter()
36 .zip_eq(self.struct_fields().fields())
37 .zip_eq(fields.iter())
38 .map(|((name, dtype), value)| {
39 let val = Scalar::new(dtype, value.clone());
40 format!("{name}: {val}")
41 })
42 .format(", ");
43 write!(f, "{formatted_fields}")?;
44 write!(f, "}}")
45 }
46 }
47 }
48}
49
50impl PartialEq for StructScalar<'_> {
51 fn eq(&self, other: &Self) -> bool {
52 if !self.dtype.eq_ignore_nullability(other.dtype) {
53 return false;
54 }
55 self.fields() == other.fields()
56 }
57}
58
59impl Eq for StructScalar<'_> {}
60
61impl PartialOrd for StructScalar<'_> {
63 fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
64 if !self.dtype.eq_ignore_nullability(other.dtype) {
65 return None;
66 }
67 self.fields().partial_cmp(&other.fields())
68 }
69}
70
71impl Hash for StructScalar<'_> {
72 fn hash<H: Hasher>(&self, state: &mut H) {
73 self.dtype.hash(state);
74 self.fields().hash(state);
75 }
76}
77
78impl<'a> StructScalar<'a> {
79 pub(crate) fn try_new(dtype: &'a DType, value: &'a ScalarValue) -> VortexResult<Self> {
80 if !matches!(dtype, DType::Struct(..)) {
81 vortex_bail!("Expected struct scalar, found {}", dtype)
82 }
83 Ok(Self {
84 dtype,
85 fields: value.as_list()?,
86 })
87 }
88
89 #[inline]
91 pub fn dtype(&self) -> &'a DType {
92 self.dtype
93 }
94
95 #[inline]
97 pub fn struct_fields(&self) -> &StructFields {
98 self.dtype
99 .as_struct()
100 .vortex_expect("StructScalar always has struct dtype")
101 }
102
103 pub fn names(&self) -> &FieldNames {
105 self.struct_fields().names()
106 }
107
108 pub fn is_null(&self) -> bool {
110 self.fields.is_none()
111 }
112
113 pub fn field(&self, name: impl AsRef<str>) -> Option<Scalar> {
117 let idx = self.struct_fields().find(name)?;
118 self.field_by_idx(idx)
119 }
120
121 pub fn field_by_idx(&self, idx: usize) -> Option<Scalar> {
129 let fields = self
130 .fields
131 .vortex_expect("Can't take field out of null struct scalar");
132 Some(Scalar::new(
133 self.struct_fields().field_by_index(idx)?,
134 fields[idx].clone(),
135 ))
136 }
137
138 pub fn fields(&self) -> Option<Vec<Scalar>> {
140 let fields = self.fields?;
141 Some(
142 (0..fields.len())
143 .map(|index| {
144 self.field_by_idx(index)
145 .vortex_expect("never out of bounds")
146 })
147 .collect::<Vec<_>>(),
148 )
149 }
150
151 pub(crate) fn field_values(&self) -> Option<&[ScalarValue]> {
152 self.fields.map(Arc::deref)
153 }
154
155 pub fn cast(&self, dtype: &DType) -> VortexResult<Scalar> {
161 let DType::Struct(st, _) = dtype else {
162 vortex_bail!(
163 "Cannot cast struct to {}: struct can only be cast to struct",
164 dtype
165 )
166 };
167 let own_st = self.struct_fields();
168
169 if st.fields().len() != own_st.fields().len() {
170 vortex_bail!(
171 "Cannot cast between structs with different number of fields: {} and {}",
172 own_st.fields().len(),
173 st.fields().len()
174 );
175 }
176
177 if let Some(fs) = self.field_values() {
178 let fields = fs
179 .iter()
180 .enumerate()
181 .map(|(i, f)| {
182 Scalar::new(
183 own_st
184 .field_by_index(i)
185 .vortex_expect("Iterating over scalar fields"),
186 f.clone(),
187 )
188 .cast(
189 &st.field_by_index(i)
190 .vortex_expect("Iterating over scalar fields"),
191 )
192 .map(|s| s.value)
193 })
194 .collect::<VortexResult<Vec<_>>>()?;
195 Ok(Scalar::new(
196 dtype.clone(),
197 ScalarValue(InnerScalarValue::List(fields.into())),
198 ))
199 } else {
200 Ok(Scalar::null(dtype.clone()))
201 }
202 }
203
204 pub fn project(&self, projection: &[FieldName]) -> VortexResult<Scalar> {
210 let struct_dtype = self
211 .dtype
212 .as_struct()
213 .ok_or_else(|| vortex_err!("Not a struct dtype"))?;
214 let projected_dtype = struct_dtype.project(projection)?;
215 let new_fields = if let Some(fs) = self.field_values() {
216 ScalarValue(InnerScalarValue::List(
217 projection
218 .iter()
219 .map(|name| {
220 struct_dtype
221 .find(name)
222 .vortex_expect("DType has been successfully projected already")
223 })
224 .map(|i| fs[i].clone())
225 .collect(),
226 ))
227 } else {
228 ScalarValue(InnerScalarValue::Null)
229 };
230 Ok(Scalar::new(
231 DType::Struct(projected_dtype, self.dtype().nullability()),
232 new_fields,
233 ))
234 }
235}
236
237impl Scalar {
238 pub fn struct_(dtype: DType, children: Vec<Scalar>) -> Self {
240 let DType::Struct(struct_fields, _) = &dtype else {
241 vortex_panic!("Expected struct dtype, found {}", dtype);
242 };
243
244 let field_dtypes = struct_fields.fields();
245 if children.len() != field_dtypes.len() {
246 vortex_panic!(
247 "Struct has {} fields but {} children were provided",
248 field_dtypes.len(),
249 children.len()
250 );
251 }
252
253 for (idx, (child, expected_dtype)) in children.iter().zip(field_dtypes).enumerate() {
254 if child.dtype() != &expected_dtype {
255 vortex_panic!(
256 "Field {} expected dtype {} but got {}",
257 idx,
258 expected_dtype,
259 child.dtype()
260 );
261 }
262 }
263
264 Self::new(
265 dtype,
266 ScalarValue(InnerScalarValue::List(
267 children
268 .into_iter()
269 .map(|x| x.into_value())
270 .collect_vec()
271 .into(),
272 )),
273 )
274 }
275}
276
277impl<'a> TryFrom<&'a Scalar> for StructScalar<'a> {
278 type Error = VortexError;
279
280 fn try_from(value: &'a Scalar) -> Result<Self, Self::Error> {
281 Self::try_new(value.dtype(), &value.value)
282 }
283}
284
#[cfg(test)]
mod tests {
    use std::collections::hash_map::DefaultHasher;
    use std::hash::{Hash, Hasher};

    use vortex_dtype::PType::{I32, I64};
    use vortex_dtype::{DType, Nullability, PType, StructFields};

    use super::*;

    /// Returns `(a_dtype, b_dtype, struct_dtype)` for the nullable struct `{a: i32, b: utf8}`.
    fn setup_types() -> (DType, DType, DType) {
        let a_dtype = DType::Primitive(I32, Nullability::NonNullable);
        let b_dtype = DType::Utf8(Nullability::NonNullable);
        let fields = StructFields::new(
            vec!["a".into(), "b".into()].into(),
            vec![a_dtype.clone(), b_dtype.clone()],
        );
        (a_dtype, b_dtype, DType::Struct(fields, Nullability::Nullable))
    }

    /// Non-nullable `i32` scalar.
    fn i32_scalar(v: i32) -> Scalar {
        Scalar::primitive::<i32>(v, Nullability::NonNullable)
    }

    /// Non-nullable UTF-8 scalar.
    fn utf8_scalar(v: &'static str) -> Scalar {
        Scalar::utf8(v, Nullability::NonNullable)
    }

    /// A `{a: i32, b: utf8}` struct scalar built from `setup_types`.
    fn sample(a: i32, b: &'static str) -> Scalar {
        let (_, _, dtype) = setup_types();
        Scalar::struct_(dtype, vec![i32_scalar(a), utf8_scalar(b)])
    }

    /// Non-nullable struct dtype with the given field names, every field a non-nullable `ptype`.
    fn uniform_struct(names: &[&'static str], ptype: PType) -> DType {
        DType::Struct(
            StructFields::new(
                names
                    .iter()
                    .map(|&n| FieldName::from(n))
                    .collect::<Vec<_>>()
                    .into(),
                vec![DType::Primitive(ptype, Nullability::NonNullable); names.len()],
            ),
            Nullability::NonNullable,
        )
    }

    /// Hash of a struct scalar under the std `DefaultHasher`.
    fn hash_of(value: &StructScalar<'_>) -> u64 {
        let mut hasher = DefaultHasher::new();
        value.hash(&mut hasher);
        hasher.finish()
    }

    #[test]
    #[should_panic]
    fn test_struct_scalar_null() {
        let (_, _, dtype) = setup_types();
        Scalar::null(dtype).as_struct().field_by_idx(0).unwrap();
    }

    #[test]
    fn test_struct_scalar_non_null() {
        let (a_dtype, b_dtype, _) = setup_types();
        let scalar = sample(1, "hello");
        let view = scalar.as_struct();

        let first = view.field_by_idx(0).expect("field 0 exists");
        assert_eq!(first, Scalar::primitive::<i32>(1, Nullability::Nullable));
        assert_eq!(first.dtype(), &a_dtype);

        let second = view.field_by_idx(1).expect("field 1 exists");
        assert_eq!(second, Scalar::utf8("hello", Nullability::Nullable));
        assert_eq!(second.dtype(), &b_dtype);
    }

    #[test]
    #[should_panic(expected = "Expected struct dtype")]
    fn test_struct_scalar_wrong_dtype() {
        let not_a_struct = DType::Primitive(I32, Nullability::NonNullable);
        Scalar::struct_(not_a_struct, vec![i32_scalar(1)]);
    }

    #[test]
    #[should_panic(expected = "Struct has 2 fields but 1 children were provided")]
    fn test_struct_scalar_wrong_child_count() {
        let (_, _, dtype) = setup_types();
        Scalar::struct_(dtype, vec![i32_scalar(1)]);
    }

    #[test]
    #[should_panic(expected = "Field 0 expected dtype i32 but got utf8")]
    fn test_struct_scalar_wrong_child_dtype() {
        let (_, _, dtype) = setup_types();
        Scalar::struct_(dtype, vec![utf8_scalar("wrong"), utf8_scalar("hello")]);
    }

    #[test]
    fn test_struct_field_by_name() {
        let scalar = sample(42, "world");
        let view = scalar.as_struct();

        let a = view.field("a").expect("field a exists");
        assert_eq!(a.as_primitive().typed_value::<i32>().unwrap(), 42);

        let b = view.field("b").expect("field b exists");
        assert_eq!(b.as_utf8().value().unwrap(), "world".into());

        assert!(view.field("c").is_none());
    }

    #[test]
    fn test_struct_fields() {
        let scalar = sample(100, "test");
        let all = scalar.as_struct().fields().unwrap();

        assert_eq!(all.len(), 2);
        assert_eq!(all[0].as_primitive().typed_value::<i32>().unwrap(), 100);
        assert_eq!(all[1].as_utf8().value().unwrap(), "test".into());
    }

    #[test]
    fn test_struct_null_fields() {
        let (_, _, dtype) = setup_types();
        let null_scalar = Scalar::null(dtype);
        let view = null_scalar.as_struct();

        assert!(view.is_null());
        assert!(view.fields().is_none());
        assert!(view.field_values().is_none());
    }

    #[test]
    fn test_struct_cast_to_struct() {
        let source_dtype = uniform_struct(&["x", "y"], I32);
        let target_dtype = uniform_struct(&["x", "y"], I64);

        let source = Scalar::struct_(source_dtype, vec![i32_scalar(42), i32_scalar(123)]);
        let cast = source.as_struct().cast(&target_dtype).unwrap();
        assert_eq!(cast.dtype(), &target_dtype);

        let cast_fields = cast.as_struct().fields().unwrap();
        assert_eq!(cast_fields[0].as_primitive().typed_value::<i64>().unwrap(), 42);
        assert_eq!(cast_fields[1].as_primitive().typed_value::<i64>().unwrap(), 123);
    }

    #[test]
    fn test_struct_cast_mismatched_fields() {
        let source_dtype = uniform_struct(&["a"], I32);
        let target_dtype = uniform_struct(&["a", "b"], I32);

        let scalar = Scalar::struct_(source_dtype, vec![i32_scalar(1)]);
        assert!(scalar.as_struct().cast(&target_dtype).is_err());
    }

    #[test]
    fn test_struct_cast_to_non_struct() {
        let scalar = sample(1, "test");
        let target = DType::Primitive(I32, Nullability::NonNullable);
        assert!(scalar.as_struct().cast(&target).is_err());
    }

    #[test]
    fn test_struct_project() {
        let scalar = sample(42, "hello");
        let projected = scalar.as_struct().project(&["b".into()]).unwrap();
        let view = projected.as_struct();

        assert_eq!(view.names().len(), 1);
        assert_eq!(view.names()[0].as_ref(), "b");

        let kept = view.fields().unwrap();
        assert_eq!(kept.len(), 1);
        assert_eq!(kept[0].as_utf8().value().unwrap().as_str(), "hello");
    }

    #[test]
    fn test_struct_project_null() {
        let (_, _, dtype) = setup_types();
        let null_scalar = Scalar::null(dtype);

        let projected = null_scalar.as_struct().project(&["a".into()]).unwrap();
        assert!(projected.as_struct().is_null());
    }

    #[test]
    fn test_struct_equality() {
        let left = sample(1, "test");
        let same = sample(1, "test");
        let different = sample(2, "test");

        assert_eq!(left.as_struct(), same.as_struct());
        assert_ne!(left.as_struct(), different.as_struct());
    }

    #[test]
    fn test_struct_partial_ord() {
        let lower = sample(1, "a");
        let higher = sample(2, "b");
        assert!(lower.as_struct() < higher.as_struct());

        let single_c = DType::Struct(
            StructFields::new(
                vec!["c".into()].into(),
                vec![DType::Primitive(I32, Nullability::NonNullable)],
            ),
            Nullability::NonNullable,
        );
        let other = Scalar::struct_(single_c, vec![i32_scalar(1)]);
        assert_eq!(lower.as_struct().partial_cmp(&other.as_struct()), None);
    }

    #[test]
    fn test_struct_hash() {
        let first = sample(1, "test");
        let second = sample(1, "test");

        assert_eq!(hash_of(&first.as_struct()), hash_of(&second.as_struct()));
    }

    #[test]
    fn test_struct_try_new_non_struct_dtype() {
        let dtype = DType::Primitive(I32, Nullability::NonNullable);
        let value = ScalarValue(InnerScalarValue::Primitive(crate::PValue::I32(42)));

        assert!(StructScalar::try_new(&dtype, &value).is_err());
    }

    #[test]
    fn test_struct_field_out_of_bounds() {
        let scalar = sample(1, "test");
        assert!(scalar.as_struct().field_by_idx(10).is_none());
    }
}