1use std::fmt::{Display, Formatter};
5use std::hash::Hash;
6use std::ops::Deref;
7use std::sync::Arc;
8
9use itertools::Itertools as _;
10use vortex_buffer::{BufferString, ByteBuffer};
11use vortex_dtype::half::f16;
12use vortex_dtype::{DType, Nullability};
13use vortex_error::{
14 VortexError, VortexExpect as _, VortexResult, vortex_bail, vortex_err, vortex_panic,
15};
16
17use crate::{InnerScalarValue, Scalar, ScalarValue};
18
/// A borrowed, typed view of a list-valued [`Scalar`].
///
/// Obtained via `ListScalar::try_from(&scalar)`, which fails unless the scalar's dtype is
/// `DType::List`.
#[derive(Debug)]
pub struct ListScalar<'a> {
    /// The scalar's full list dtype, including the list's own nullability.
    dtype: &'a DType,
    /// The element dtype, borrowed from inside `dtype`.
    element_dtype: &'a Arc<DType>,
    /// The element values, or `None` if the list itself is null.
    elements: Option<Arc<[ScalarValue]>>,
}
29
30impl Display for ListScalar<'_> {
31 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
32 match &self.elements {
33 None => write!(f, "null"),
34 Some(elems) => {
35 write!(
36 f,
37 "[{}]",
38 elems
39 .iter()
40 .map(|e| Scalar::new(self.element_dtype().clone(), e.clone()))
41 .format(", ")
42 )
43 }
44 }
45 }
46}
47
48impl PartialEq for ListScalar<'_> {
49 fn eq(&self, other: &Self) -> bool {
50 self.dtype.eq_ignore_nullability(other.dtype) && self.elements() == other.elements()
51 }
52}
53
// Marker impl: declares the `PartialEq` relation for `ListScalar` (dtype compared ignoring
// nullability, then elements compared) to be a full equivalence relation.
impl Eq for ListScalar<'_> {}
55
56impl PartialOrd for ListScalar<'_> {
58 fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
59 if !self
60 .element_dtype
61 .eq_ignore_nullability(other.element_dtype)
62 {
63 return None;
64 }
65 self.elements().partial_cmp(&other.elements())
66 }
67}
68
69impl Hash for ListScalar<'_> {
70 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
71 self.dtype.hash(state);
72 self.elements().hash(state);
73 }
74}
75
76impl<'a> ListScalar<'a> {
77 #[inline]
79 pub fn dtype(&self) -> &'a DType {
80 self.dtype
81 }
82
83 #[inline]
87 pub fn len(&self) -> usize {
88 self.elements.as_ref().map(|e| e.len()).unwrap_or(0)
89 }
90
91 #[inline]
93 pub fn is_empty(&self) -> bool {
94 match self.elements.as_ref() {
95 None => true,
96 Some(l) => l.is_empty(),
97 }
98 }
99
100 #[inline]
102 pub fn is_null(&self) -> bool {
103 self.elements.is_none()
104 }
105
106 pub fn element_dtype(&self) -> &DType {
108 let DType::List(element_type, _) = self.dtype() else {
109 unreachable!();
110 };
111 (*element_type).deref()
112 }
113
114 pub fn element(&self, idx: usize) -> Option<Scalar> {
118 self.elements
119 .as_ref()
120 .and_then(|l| l.get(idx))
121 .map(|value| Scalar::new(self.element_dtype().clone(), value.clone()))
122 }
123
124 pub fn elements(&self) -> Option<Vec<Scalar>> {
128 self.elements.as_ref().map(|elems| {
129 elems
130 .iter()
131 .map(|e| Scalar::new(self.element_dtype().clone(), e.clone()))
132 .collect_vec()
133 })
134 }
135
136 pub(crate) fn cast(&self, dtype: &DType) -> VortexResult<Scalar> {
137 let DType::List(element_dtype, ..) = dtype else {
138 vortex_bail!(
139 "Cannot cast {} to {}: list can only be cast to list",
140 self.dtype(),
141 dtype
142 )
143 };
144
145 Ok(Scalar::new(
146 dtype.clone(),
147 ScalarValue(InnerScalarValue::List(
148 self.elements
149 .as_ref()
150 .vortex_expect("nullness handled in Scalar::cast")
151 .iter()
152 .map(|element| {
153 Scalar::new(DType::clone(self.element_dtype), element.clone())
154 .cast(element_dtype)
155 .map(|x| x.value().clone())
156 })
157 .process_results(|iter| iter.collect())?,
158 )),
159 ))
160 }
161}
162
163impl Scalar {
164 pub fn list(
170 element_dtype: impl Into<Arc<DType>>,
171 children: Vec<Scalar>,
172 nullability: Nullability,
173 ) -> Self {
174 let element_dtype = element_dtype.into();
175 for child in &children {
176 if child.dtype() != &*element_dtype {
177 vortex_panic!(
178 "tried to create list of {} with values of type {}",
179 element_dtype,
180 child.dtype()
181 );
182 }
183 }
184 Self::new(
185 DType::List(element_dtype, nullability),
186 ScalarValue(InnerScalarValue::List(
187 children.into_iter().map(|x| x.value).collect(),
188 )),
189 )
190 }
191
192 pub fn list_empty(element_dtype: Arc<DType>, nullability: Nullability) -> Self {
194 Self::new(
195 DType::List(element_dtype, nullability),
196 ScalarValue(InnerScalarValue::Null),
197 )
198 }
199}
200
201impl<'a> TryFrom<&'a Scalar> for ListScalar<'a> {
202 type Error = VortexError;
203
204 fn try_from(value: &'a Scalar) -> Result<Self, Self::Error> {
205 let DType::List(element_dtype, ..) = value.dtype() else {
206 vortex_bail!("Expected list scalar, found {}", value.dtype())
207 };
208
209 Ok(Self {
210 dtype: value.dtype(),
211 element_dtype,
212 elements: value.value.as_list()?.cloned(),
213 })
214 }
215}
216
217impl<'a, T> TryFrom<&'a Scalar> for Vec<T>
218where
219 T: for<'b> TryFrom<&'b Scalar, Error = VortexError>,
220{
221 type Error = VortexError;
222
223 fn try_from(value: &'a Scalar) -> Result<Self, Self::Error> {
224 let value = ListScalar::try_from(value)?;
225 let mut elems = vec![];
226 for e in value
227 .elements()
228 .ok_or_else(|| vortex_err!("Expected non-null list"))?
229 {
230 elems.push(T::try_from(&e)?);
231 }
232 Ok(elems)
233 }
234}
235
236impl<T> TryFrom<Scalar> for Vec<T>
237where
238 T: TryFrom<Scalar, Error = VortexError>,
239{
240 type Error = VortexError;
241
242 fn try_from(value: Scalar) -> Result<Self, Self::Error> {
243 let value = ListScalar::try_from(&value)?;
244 let mut elems = vec![];
245 for e in value
246 .elements()
247 .ok_or_else(|| vortex_err!("Expected non-null list"))?
248 {
249 elems.push(T::try_from(e)?);
250 }
251 Ok(elems)
252 }
253}
254
/// Implements `From<Vec<$T>>` for [`ScalarValue`], producing a list value whose elements are
/// each converted with `ScalarValue::from`.
macro_rules! from_vec_for_scalar_value {
    ($T:ty) => {
        impl From<Vec<$T>> for ScalarValue {
            fn from(value: Vec<$T>) -> Self {
                let elements: Arc<[ScalarValue]> =
                    value.into_iter().map(ScalarValue::from).collect();
                ScalarValue(InnerScalarValue::List(elements))
            }
        }
    };
}
269
270from_vec_for_scalar_value!(u16);
272from_vec_for_scalar_value!(u32);
273from_vec_for_scalar_value!(u64);
274from_vec_for_scalar_value!(usize); from_vec_for_scalar_value!(i8);
276from_vec_for_scalar_value!(i16);
277from_vec_for_scalar_value!(i32);
278from_vec_for_scalar_value!(i64);
279from_vec_for_scalar_value!(f16);
280from_vec_for_scalar_value!(f32);
281from_vec_for_scalar_value!(f64);
282from_vec_for_scalar_value!(String);
283from_vec_for_scalar_value!(BufferString);
284from_vec_for_scalar_value!(ByteBuffer);
285
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use vortex_dtype::{DType, Nullability, PType};

    use super::*;

    /// Non-nullable `i32` element dtype shared by most tests.
    fn i32_dtype() -> Arc<DType> {
        Arc::new(DType::Primitive(PType::I32, Nullability::NonNullable))
    }

    /// Builds a non-nullable list scalar holding non-nullable `i32` values.
    fn i32_list(values: &[i32]) -> Scalar {
        let children: Vec<Scalar> = values
            .iter()
            .map(|&v| Scalar::primitive(v, Nullability::NonNullable))
            .collect();
        Scalar::list(i32_dtype(), children, Nullability::NonNullable)
    }

    /// Builds a nullable list of nullable `i32` whose value is null.
    fn null_i32_list() -> Scalar {
        Scalar::list_empty(
            Arc::new(DType::Primitive(PType::I32, Nullability::Nullable)),
            Nullability::Nullable,
        )
    }

    /// Reads element `idx` of an `i32` list, panicking if it is missing or not an `i32`.
    fn i32_at(list: &ListScalar<'_>, idx: usize) -> i32 {
        list.element(idx)
            .unwrap()
            .as_primitive()
            .typed_value::<i32>()
            .unwrap()
    }

    #[test]
    fn test_list_scalar_creation() {
        let scalar = i32_list(&[1, 2, 3]);

        let list = ListScalar::try_from(&scalar).unwrap();
        assert_eq!(list.len(), 3);
        assert!(!list.is_empty());
        assert!(!list.is_null());
    }

    #[test]
    fn test_empty_list() {
        let scalar = i32_list(&[]);

        let list = ListScalar::try_from(&scalar).unwrap();
        assert_eq!(list.len(), 0);
        assert!(list.is_empty());
        assert!(!list.is_null());
    }

    #[test]
    fn test_null_list() {
        let scalar = null_i32_list();

        let list = ListScalar::try_from(&scalar).unwrap();
        assert_eq!(list.len(), 0);
        assert!(list.is_empty());
        assert!(list.is_null());
    }

    #[test]
    fn test_list_element_access() {
        let scalar = i32_list(&[10, 20, 30]);
        let list = ListScalar::try_from(&scalar).unwrap();

        for (idx, expected) in [10, 20, 30].into_iter().enumerate() {
            assert_eq!(i32_at(&list, idx), expected);
        }
        assert!(list.element(3).is_none());
    }

    #[test]
    fn test_list_elements() {
        let scalar = i32_list(&[100, 200]);
        let list = ListScalar::try_from(&scalar).unwrap();

        let values: Vec<i32> = list
            .elements()
            .unwrap()
            .iter()
            .map(|element| element.as_primitive().typed_value::<i32>().unwrap())
            .collect();
        assert_eq!(values.len(), 2);
        assert_eq!(values, vec![100, 200]);
    }

    #[test]
    fn test_list_display() {
        let scalar = i32_list(&[1, 2]);
        let list = ListScalar::try_from(&scalar).unwrap();

        let display = format!("{list}");
        assert!(display.contains("1"));
        assert!(display.contains("2"));
    }

    #[test]
    fn test_null_list_display() {
        let scalar = null_i32_list();
        let list = ListScalar::try_from(&scalar).unwrap();

        assert_eq!(format!("{list}"), "null");
    }

    #[test]
    fn test_list_equality() {
        let first = i32_list(&[1, 2]);
        let second = i32_list(&[1, 2]);

        let lhs = ListScalar::try_from(&first).unwrap();
        let rhs = ListScalar::try_from(&second).unwrap();
        assert_eq!(lhs, rhs);
    }

    #[test]
    fn test_list_inequality() {
        let first = i32_list(&[1, 2]);
        let second = i32_list(&[1, 3]);

        let lhs = ListScalar::try_from(&first).unwrap();
        let rhs = ListScalar::try_from(&second).unwrap();
        assert_ne!(lhs, rhs);
    }

    #[test]
    fn test_list_partial_ord() {
        let smaller = i32_list(&[1]);
        let larger = i32_list(&[2]);

        let lhs = ListScalar::try_from(&smaller).unwrap();
        let rhs = ListScalar::try_from(&larger).unwrap();
        assert!(lhs < rhs);
    }

    #[test]
    fn test_list_partial_ord_different_types() {
        let ints = i32_list(&[1]);
        let longs = Scalar::list(
            Arc::new(DType::Primitive(PType::I64, Nullability::NonNullable)),
            vec![Scalar::primitive(1i64, Nullability::NonNullable)],
            Nullability::NonNullable,
        );

        let lhs = ListScalar::try_from(&ints).unwrap();
        let rhs = ListScalar::try_from(&longs).unwrap();
        assert!(lhs.partial_cmp(&rhs).is_none());
    }

    #[test]
    fn test_list_hash() {
        use std::collections::hash_map::DefaultHasher;
        use std::hash::{Hash, Hasher};

        let scalar = i32_list(&[1, 2]);
        let list = ListScalar::try_from(&scalar).unwrap();

        let hash_once = || {
            let mut hasher = DefaultHasher::new();
            list.hash(&mut hasher);
            hasher.finish()
        };
        assert_eq!(hash_once(), hash_once());
    }

    #[test]
    fn test_vec_conversion() {
        let scalar = i32_list(&[10, 20, 30]);

        let values: Vec<i32> = Vec::try_from(&scalar).unwrap();
        assert_eq!(values, vec![10, 20, 30]);
    }

    #[test]
    fn test_vec_conversion_null_list() {
        let scalar = null_i32_list();

        let result: Result<Vec<i32>, VortexError> = Vec::try_from(&scalar);
        assert!(result.is_err());
    }

    #[test]
    fn test_list_cast() {
        let scalar = i32_list(&[1, 2]);
        let list = ListScalar::try_from(&scalar).unwrap();

        let target_dtype = DType::List(
            Arc::new(DType::Primitive(PType::I64, Nullability::NonNullable)),
            Nullability::NonNullable,
        );
        let casted = list.cast(&target_dtype).unwrap();
        let casted_list = ListScalar::try_from(&casted).unwrap();

        assert_eq!(casted_list.len(), 2);
        let first = casted_list.element(0).unwrap();
        assert_eq!(first.as_primitive().typed_value::<i64>().unwrap(), 1);
    }

    #[test]
    #[should_panic(expected = "tried to create list of i32 with values of type i64")]
    fn test_list_wrong_element_type_panic() {
        let children = vec![Scalar::primitive(1i64, Nullability::NonNullable)];
        let _ = Scalar::list(i32_dtype(), children, Nullability::NonNullable);
    }

    #[test]
    fn test_try_from_wrong_dtype() {
        let scalar = Scalar::primitive(42i32, Nullability::NonNullable);
        assert!(ListScalar::try_from(&scalar).is_err());
    }

    #[test]
    fn test_string_list() {
        let words = ["hello", "world"];
        let children: Vec<Scalar> = words
            .into_iter()
            .map(|word| Scalar::utf8(word.to_string(), Nullability::NonNullable))
            .collect();
        let scalar = Scalar::list(
            Arc::new(DType::Utf8(Nullability::NonNullable)),
            children,
            Nullability::NonNullable,
        );

        let list = ListScalar::try_from(&scalar).unwrap();
        assert_eq!(list.len(), 2);

        for (idx, expected) in words.into_iter().enumerate() {
            let element = list.element(idx).unwrap();
            assert_eq!(element.as_utf8().value().unwrap().as_str(), expected);
        }
    }

    #[test]
    fn test_nested_lists() {
        let inner_list_dtype = Arc::new(DType::List(i32_dtype(), Nullability::NonNullable));
        let outer = Scalar::list(
            inner_list_dtype,
            vec![i32_list(&[1, 2]), i32_list(&[3, 4])],
            Nullability::NonNullable,
        );

        let list = ListScalar::try_from(&outer).unwrap();
        assert_eq!(list.len(), 2);

        let first_inner = list.element(0).unwrap();
        let nested = ListScalar::try_from(&first_inner).unwrap();
        assert_eq!(nested.len(), 2);
    }
}