vortex_scalar/
list.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::{Display, Formatter};
5use std::hash::Hash;
6use std::ops::Deref;
7use std::sync::Arc;
8
9use itertools::Itertools as _;
10use vortex_buffer::{BufferString, ByteBuffer};
11use vortex_dtype::half::f16;
12use vortex_dtype::{DType, Nullability};
13use vortex_error::{
14    VortexError, VortexExpect as _, VortexResult, vortex_bail, vortex_err, vortex_panic,
15};
16
17use crate::{InnerScalarValue, Scalar, ScalarValue};
18
/// A scalar value representing a list (array) of elements.
///
/// This type provides a view into a list scalar value, which can contain
/// zero or more elements of the same type, or be null.
///
/// The view borrows its dtypes from the source [`Scalar`] it was created from (via
/// `TryFrom<&Scalar>`), while the element values are held in a shared `Arc` slice.
#[derive(Debug)]
pub struct ListScalar<'a> {
    /// The list's dtype, as set by the `TryFrom<&Scalar>` constructor (always a `DType::List`).
    dtype: &'a DType,
    /// The element dtype, destructured out of `dtype` when the view is constructed.
    element_dtype: &'a Arc<DType>,
    /// The list's element values; `None` when the list scalar is null.
    elements: Option<Arc<[ScalarValue]>>,
}
29
30impl Display for ListScalar<'_> {
31    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
32        match &self.elements {
33            None => write!(f, "null"),
34            Some(elems) => {
35                write!(
36                    f,
37                    "[{}]",
38                    elems
39                        .iter()
40                        .map(|e| Scalar::new(self.element_dtype().clone(), e.clone()))
41                        .format(", ")
42                )
43            }
44        }
45    }
46}
47
48impl PartialEq for ListScalar<'_> {
49    fn eq(&self, other: &Self) -> bool {
50        self.dtype.eq_ignore_nullability(other.dtype) && self.elements() == other.elements()
51    }
52}
53
54impl Eq for ListScalar<'_> {}
55
56/// Ord is not implemented since it's undefined for different element DTypes
57impl PartialOrd for ListScalar<'_> {
58    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
59        if !self
60            .element_dtype
61            .eq_ignore_nullability(other.element_dtype)
62        {
63            return None;
64        }
65        self.elements().partial_cmp(&other.elements())
66    }
67}
68
69impl Hash for ListScalar<'_> {
70    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
71        self.dtype.hash(state);
72        self.elements().hash(state);
73    }
74}
75
76impl<'a> ListScalar<'a> {
77    /// Returns the data type of this list scalar.
78    #[inline]
79    pub fn dtype(&self) -> &'a DType {
80        self.dtype
81    }
82
83    /// Returns the number of elements in the list.
84    ///
85    /// Returns 0 if the list is null.
86    #[inline]
87    pub fn len(&self) -> usize {
88        self.elements.as_ref().map(|e| e.len()).unwrap_or(0)
89    }
90
91    /// Returns true if the list has no elements or is null.
92    #[inline]
93    pub fn is_empty(&self) -> bool {
94        match self.elements.as_ref() {
95            None => true,
96            Some(l) => l.is_empty(),
97        }
98    }
99
100    /// Returns true if the list is null.
101    #[inline]
102    pub fn is_null(&self) -> bool {
103        self.elements.is_none()
104    }
105
106    /// Returns the data type of the list's elements.
107    pub fn element_dtype(&self) -> &DType {
108        let DType::List(element_type, _) = self.dtype() else {
109            unreachable!();
110        };
111        (*element_type).deref()
112    }
113
114    /// Returns the element at the given index as a scalar.
115    ///
116    /// Returns None if the list is null or the index is out of bounds.
117    pub fn element(&self, idx: usize) -> Option<Scalar> {
118        self.elements
119            .as_ref()
120            .and_then(|l| l.get(idx))
121            .map(|value| Scalar::new(self.element_dtype().clone(), value.clone()))
122    }
123
124    /// Returns all elements in the list as a vector of scalars.
125    ///
126    /// Returns None if the list is null.
127    pub fn elements(&self) -> Option<Vec<Scalar>> {
128        self.elements.as_ref().map(|elems| {
129            elems
130                .iter()
131                .map(|e| Scalar::new(self.element_dtype().clone(), e.clone()))
132                .collect_vec()
133        })
134    }
135
136    pub(crate) fn cast(&self, dtype: &DType) -> VortexResult<Scalar> {
137        let DType::List(element_dtype, ..) = dtype else {
138            vortex_bail!(
139                "Cannot cast {} to {}: list can only be cast to list",
140                self.dtype(),
141                dtype
142            )
143        };
144
145        Ok(Scalar::new(
146            dtype.clone(),
147            ScalarValue(InnerScalarValue::List(
148                self.elements
149                    .as_ref()
150                    .vortex_expect("nullness handled in Scalar::cast")
151                    .iter()
152                    .map(|element| {
153                        Scalar::new(DType::clone(self.element_dtype), element.clone())
154                            .cast(element_dtype)
155                            .map(|x| x.value().clone())
156                    })
157                    .process_results(|iter| iter.collect())?,
158            )),
159        ))
160    }
161}
162
163impl Scalar {
164    /// Creates a new list scalar with the given element type and children.
165    ///
166    /// # Panics
167    ///
168    /// Panics if any child scalar has a different type than the element type.
169    pub fn list(
170        element_dtype: impl Into<Arc<DType>>,
171        children: Vec<Scalar>,
172        nullability: Nullability,
173    ) -> Self {
174        let element_dtype = element_dtype.into();
175        for child in &children {
176            if child.dtype() != &*element_dtype {
177                vortex_panic!(
178                    "tried to create list of {} with values of type {}",
179                    element_dtype,
180                    child.dtype()
181                );
182            }
183        }
184        Self::new(
185            DType::List(element_dtype, nullability),
186            ScalarValue(InnerScalarValue::List(
187                children.into_iter().map(|x| x.value).collect(),
188            )),
189        )
190    }
191
192    /// Creates a new empty list scalar with the given element type.
193    pub fn list_empty(element_dtype: Arc<DType>, nullability: Nullability) -> Self {
194        Self::new(
195            DType::List(element_dtype, nullability),
196            ScalarValue(InnerScalarValue::Null),
197        )
198    }
199}
200
201impl<'a> TryFrom<&'a Scalar> for ListScalar<'a> {
202    type Error = VortexError;
203
204    fn try_from(value: &'a Scalar) -> Result<Self, Self::Error> {
205        let DType::List(element_dtype, ..) = value.dtype() else {
206            vortex_bail!("Expected list scalar, found {}", value.dtype())
207        };
208
209        Ok(Self {
210            dtype: value.dtype(),
211            element_dtype,
212            elements: value.value.as_list()?.cloned(),
213        })
214    }
215}
216
217impl<'a, T> TryFrom<&'a Scalar> for Vec<T>
218where
219    T: for<'b> TryFrom<&'b Scalar, Error = VortexError>,
220{
221    type Error = VortexError;
222
223    fn try_from(value: &'a Scalar) -> Result<Self, Self::Error> {
224        let value = ListScalar::try_from(value)?;
225        let mut elems = vec![];
226        for e in value
227            .elements()
228            .ok_or_else(|| vortex_err!("Expected non-null list"))?
229        {
230            elems.push(T::try_from(&e)?);
231        }
232        Ok(elems)
233    }
234}
235
236impl<T> TryFrom<Scalar> for Vec<T>
237where
238    T: TryFrom<Scalar, Error = VortexError>,
239{
240    type Error = VortexError;
241
242    fn try_from(value: Scalar) -> Result<Self, Self::Error> {
243        let value = ListScalar::try_from(&value)?;
244        let mut elems = vec![];
245        for e in value
246            .elements()
247            .ok_or_else(|| vortex_err!("Expected non-null list"))?
248        {
249            elems.push(T::try_from(e)?);
250        }
251        Ok(elems)
252    }
253}
254
255macro_rules! from_vec_for_scalar_value {
256    ($T:ty) => {
257        impl From<Vec<$T>> for ScalarValue {
258            fn from(value: Vec<$T>) -> Self {
259                ScalarValue(InnerScalarValue::List(
260                    value
261                        .into_iter()
262                        .map(ScalarValue::from)
263                        .collect::<Arc<[_]>>(),
264                ))
265            }
266        }
267    };
268}
269
270// no From<Vec<u8>> because it could either be a List or a Buffer
271from_vec_for_scalar_value!(u16);
272from_vec_for_scalar_value!(u32);
273from_vec_for_scalar_value!(u64);
274from_vec_for_scalar_value!(usize); // For usize only, we implicitly cast for better ergonomics.
275from_vec_for_scalar_value!(i8);
276from_vec_for_scalar_value!(i16);
277from_vec_for_scalar_value!(i32);
278from_vec_for_scalar_value!(i64);
279from_vec_for_scalar_value!(f16);
280from_vec_for_scalar_value!(f32);
281from_vec_for_scalar_value!(f64);
282from_vec_for_scalar_value!(String);
283from_vec_for_scalar_value!(BufferString);
284from_vec_for_scalar_value!(ByteBuffer);
285
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use vortex_dtype::{DType, Nullability, PType};

    use super::*;

    #[test]
    fn test_list_scalar_creation() {
        let element_dtype = Arc::new(DType::Primitive(PType::I32, Nullability::NonNullable));
        let children = vec![
            Scalar::primitive(1i32, Nullability::NonNullable),
            Scalar::primitive(2i32, Nullability::NonNullable),
            Scalar::primitive(3i32, Nullability::NonNullable),
        ];
        let list_scalar = Scalar::list(element_dtype, children, Nullability::NonNullable);

        let list = ListScalar::try_from(&list_scalar).unwrap();
        assert_eq!(list.len(), 3);
        assert!(!list.is_empty());
        assert!(!list.is_null());
    }

    #[test]
    fn test_empty_list() {
        let element_dtype = Arc::new(DType::Primitive(PType::I32, Nullability::NonNullable));
        let list_scalar = Scalar::list(element_dtype, vec![], Nullability::NonNullable);

        let list = ListScalar::try_from(&list_scalar).unwrap();
        assert_eq!(list.len(), 0);
        assert!(list.is_empty());
        assert!(!list.is_null());
        // A non-null empty list still yields `Some` (with no elements).
        assert!(list.element(0).is_none());
        assert_eq!(list.elements().map(|elems| elems.len()), Some(0));
    }

    #[test]
    fn test_null_list() {
        // `list_empty` produces a *null* list, not a zero-length one.
        let element_dtype = Arc::new(DType::Primitive(PType::I32, Nullability::Nullable));
        let list_scalar = Scalar::list_empty(element_dtype, Nullability::Nullable);

        let list = ListScalar::try_from(&list_scalar).unwrap();
        assert_eq!(list.len(), 0);
        assert!(list.is_empty());
        assert!(list.is_null());
        assert!(list.element(0).is_none());
        assert!(list.elements().is_none());
    }

    #[test]
    fn test_list_element_access() {
        let element_dtype = Arc::new(DType::Primitive(PType::I32, Nullability::NonNullable));
        let children = vec![
            Scalar::primitive(10i32, Nullability::NonNullable),
            Scalar::primitive(20i32, Nullability::NonNullable),
            Scalar::primitive(30i32, Nullability::NonNullable),
        ];
        let list_scalar = Scalar::list(element_dtype, children, Nullability::NonNullable);

        let list = ListScalar::try_from(&list_scalar).unwrap();

        // Test element access
        let elem0 = list.element(0).unwrap();
        assert_eq!(elem0.as_primitive().typed_value::<i32>().unwrap(), 10);

        let elem1 = list.element(1).unwrap();
        assert_eq!(elem1.as_primitive().typed_value::<i32>().unwrap(), 20);

        let elem2 = list.element(2).unwrap();
        assert_eq!(elem2.as_primitive().typed_value::<i32>().unwrap(), 30);

        // Test out of bounds
        assert!(list.element(3).is_none());
    }

    #[test]
    fn test_list_elements() {
        let element_dtype = Arc::new(DType::Primitive(PType::I32, Nullability::NonNullable));
        let children = vec![
            Scalar::primitive(100i32, Nullability::NonNullable),
            Scalar::primitive(200i32, Nullability::NonNullable),
        ];
        let list_scalar = Scalar::list(element_dtype, children, Nullability::NonNullable);

        let list = ListScalar::try_from(&list_scalar).unwrap();
        let elements = list.elements().unwrap();

        assert_eq!(elements.len(), 2);
        assert_eq!(
            elements[0].as_primitive().typed_value::<i32>().unwrap(),
            100
        );
        assert_eq!(
            elements[1].as_primitive().typed_value::<i32>().unwrap(),
            200
        );
    }

    #[test]
    fn test_list_display() {
        let element_dtype = Arc::new(DType::Primitive(PType::I32, Nullability::NonNullable));
        let children = vec![
            Scalar::primitive(1i32, Nullability::NonNullable),
            Scalar::primitive(2i32, Nullability::NonNullable),
        ];
        let list_scalar = Scalar::list(element_dtype, children, Nullability::NonNullable);

        let list = ListScalar::try_from(&list_scalar).unwrap();
        let display = format!("{list}");
        assert!(display.contains("1"));
        assert!(display.contains("2"));
        // Elements are wrapped in brackets and separated by ", ".
        assert!(display.starts_with('['));
        assert!(display.ends_with(']'));
        assert!(display.contains(", "));
    }

    #[test]
    fn test_empty_list_display() {
        let element_dtype = Arc::new(DType::Primitive(PType::I32, Nullability::NonNullable));
        let list_scalar = Scalar::list(element_dtype, vec![], Nullability::NonNullable);

        let list = ListScalar::try_from(&list_scalar).unwrap();
        assert_eq!(format!("{list}"), "[]");
    }

    #[test]
    fn test_null_list_display() {
        let element_dtype = Arc::new(DType::Primitive(PType::I32, Nullability::Nullable));
        let list_scalar = Scalar::list_empty(element_dtype, Nullability::Nullable);

        let list = ListScalar::try_from(&list_scalar).unwrap();
        let display = format!("{list}");
        assert_eq!(display, "null");
    }

    #[test]
    fn test_list_equality() {
        let element_dtype = Arc::new(DType::Primitive(PType::I32, Nullability::NonNullable));
        let children1 = vec![
            Scalar::primitive(1i32, Nullability::NonNullable),
            Scalar::primitive(2i32, Nullability::NonNullable),
        ];
        let list_scalar1 = Scalar::list(element_dtype.clone(), children1, Nullability::NonNullable);

        let children2 = vec![
            Scalar::primitive(1i32, Nullability::NonNullable),
            Scalar::primitive(2i32, Nullability::NonNullable),
        ];
        let list_scalar2 = Scalar::list(element_dtype, children2, Nullability::NonNullable);

        let list1 = ListScalar::try_from(&list_scalar1).unwrap();
        let list2 = ListScalar::try_from(&list_scalar2).unwrap();

        assert_eq!(list1, list2);
    }

    #[test]
    fn test_list_equality_ignores_list_nullability() {
        let element_dtype = Arc::new(DType::Primitive(PType::I32, Nullability::NonNullable));
        let children = || vec![Scalar::primitive(7i32, Nullability::NonNullable)];
        let non_nullable =
            Scalar::list(element_dtype.clone(), children(), Nullability::NonNullable);
        let nullable = Scalar::list(element_dtype, children(), Nullability::Nullable);

        let list1 = ListScalar::try_from(&non_nullable).unwrap();
        let list2 = ListScalar::try_from(&nullable).unwrap();

        assert_eq!(list1, list2);
    }

    #[test]
    fn test_list_inequality() {
        let element_dtype = Arc::new(DType::Primitive(PType::I32, Nullability::NonNullable));
        let children1 = vec![
            Scalar::primitive(1i32, Nullability::NonNullable),
            Scalar::primitive(2i32, Nullability::NonNullable),
        ];
        let list_scalar1 = Scalar::list(element_dtype.clone(), children1, Nullability::NonNullable);

        let children2 = vec![
            Scalar::primitive(1i32, Nullability::NonNullable),
            Scalar::primitive(3i32, Nullability::NonNullable),
        ];
        let list_scalar2 = Scalar::list(element_dtype, children2, Nullability::NonNullable);

        let list1 = ListScalar::try_from(&list_scalar1).unwrap();
        let list2 = ListScalar::try_from(&list_scalar2).unwrap();

        assert_ne!(list1, list2);
    }

    #[test]
    fn test_list_partial_ord() {
        let element_dtype = Arc::new(DType::Primitive(PType::I32, Nullability::NonNullable));

        let children1 = vec![Scalar::primitive(1i32, Nullability::NonNullable)];
        let list_scalar1 = Scalar::list(element_dtype.clone(), children1, Nullability::NonNullable);

        let children2 = vec![Scalar::primitive(2i32, Nullability::NonNullable)];
        let list_scalar2 = Scalar::list(element_dtype, children2, Nullability::NonNullable);

        let list1 = ListScalar::try_from(&list_scalar1).unwrap();
        let list2 = ListScalar::try_from(&list_scalar2).unwrap();

        assert!(list1 < list2);
    }

    #[test]
    fn test_list_partial_ord_different_types() {
        let element_dtype1 = Arc::new(DType::Primitive(PType::I32, Nullability::NonNullable));
        let element_dtype2 = Arc::new(DType::Primitive(PType::I64, Nullability::NonNullable));

        let children1 = vec![Scalar::primitive(1i32, Nullability::NonNullable)];
        let list_scalar1 = Scalar::list(element_dtype1, children1, Nullability::NonNullable);

        let children2 = vec![Scalar::primitive(1i64, Nullability::NonNullable)];
        let list_scalar2 = Scalar::list(element_dtype2, children2, Nullability::NonNullable);

        let list1 = ListScalar::try_from(&list_scalar1).unwrap();
        let list2 = ListScalar::try_from(&list_scalar2).unwrap();

        assert!(list1.partial_cmp(&list2).is_none());
    }

    #[test]
    fn test_list_hash() {
        use std::collections::hash_map::DefaultHasher;
        use std::hash::{Hash, Hasher};

        fn hash_of(list: &ListScalar<'_>) -> u64 {
            let mut hasher = DefaultHasher::new();
            list.hash(&mut hasher);
            hasher.finish()
        }

        let element_dtype = Arc::new(DType::Primitive(PType::I32, Nullability::NonNullable));
        let make = || {
            Scalar::list(
                element_dtype.clone(),
                vec![
                    Scalar::primitive(1i32, Nullability::NonNullable),
                    Scalar::primitive(2i32, Nullability::NonNullable),
                ],
                Nullability::NonNullable,
            )
        };
        let list_scalar1 = make();
        let list_scalar2 = make();

        let list1 = ListScalar::try_from(&list_scalar1).unwrap();
        let list2 = ListScalar::try_from(&list_scalar2).unwrap();

        // Hashing the same value twice is deterministic...
        assert_eq!(hash_of(&list1), hash_of(&list1));
        // ...and equal lists built independently hash identically.
        assert_eq!(list1, list2);
        assert_eq!(hash_of(&list1), hash_of(&list2));
    }

    #[test]
    fn test_vec_conversion() {
        let element_dtype = Arc::new(DType::Primitive(PType::I32, Nullability::NonNullable));
        let children = vec![
            Scalar::primitive(10i32, Nullability::NonNullable),
            Scalar::primitive(20i32, Nullability::NonNullable),
            Scalar::primitive(30i32, Nullability::NonNullable),
        ];
        let list_scalar = Scalar::list(element_dtype, children, Nullability::NonNullable);

        let vec: Vec<i32> = Vec::try_from(&list_scalar).unwrap();
        assert_eq!(vec, vec![10, 20, 30]);
    }

    #[test]
    fn test_vec_conversion_null_list() {
        let element_dtype = Arc::new(DType::Primitive(PType::I32, Nullability::Nullable));
        let list_scalar = Scalar::list_empty(element_dtype, Nullability::Nullable);

        let result: Result<Vec<i32>, VortexError> = Vec::try_from(&list_scalar);
        assert!(result.is_err());
    }

    #[test]
    fn test_list_cast() {
        let element_dtype = Arc::new(DType::Primitive(PType::I32, Nullability::NonNullable));
        let children = vec![
            Scalar::primitive(1i32, Nullability::NonNullable),
            Scalar::primitive(2i32, Nullability::NonNullable),
        ];
        let list_scalar = Scalar::list(element_dtype, children, Nullability::NonNullable);

        let list = ListScalar::try_from(&list_scalar).unwrap();

        // Cast to list with i64 elements
        let target_dtype = DType::List(
            Arc::new(DType::Primitive(PType::I64, Nullability::NonNullable)),
            Nullability::NonNullable,
        );

        let casted = list.cast(&target_dtype).unwrap();
        let casted_list = ListScalar::try_from(&casted).unwrap();

        assert_eq!(casted_list.len(), 2);
        let elem0 = casted_list.element(0).unwrap();
        assert_eq!(elem0.as_primitive().typed_value::<i64>().unwrap(), 1);
    }

    #[test]
    #[should_panic(expected = "tried to create list of i32 with values of type i64")]
    fn test_list_wrong_element_type_panic() {
        let element_dtype = Arc::new(DType::Primitive(PType::I32, Nullability::NonNullable));
        let children = vec![
            Scalar::primitive(1i64, Nullability::NonNullable), // Wrong type!
        ];
        let _ = Scalar::list(element_dtype, children, Nullability::NonNullable);
    }

    #[test]
    fn test_try_from_wrong_dtype() {
        let scalar = Scalar::primitive(42i32, Nullability::NonNullable);
        let result = ListScalar::try_from(&scalar);
        assert!(result.is_err());
    }

    #[test]
    fn test_string_list() {
        let element_dtype = Arc::new(DType::Utf8(Nullability::NonNullable));
        let children = vec![
            Scalar::utf8("hello".to_string(), Nullability::NonNullable),
            Scalar::utf8("world".to_string(), Nullability::NonNullable),
        ];
        let list_scalar = Scalar::list(element_dtype, children, Nullability::NonNullable);

        let list = ListScalar::try_from(&list_scalar).unwrap();
        assert_eq!(list.len(), 2);

        let elem0 = list.element(0).unwrap();
        assert_eq!(elem0.as_utf8().value().unwrap().as_str(), "hello");

        let elem1 = list.element(1).unwrap();
        assert_eq!(elem1.as_utf8().value().unwrap().as_str(), "world");
    }

    #[test]
    fn test_nested_lists() {
        let inner_element_dtype = Arc::new(DType::Primitive(PType::I32, Nullability::NonNullable));
        let inner_list_dtype = Arc::new(DType::List(
            inner_element_dtype.clone(),
            Nullability::NonNullable,
        ));

        let inner_list1 = Scalar::list(
            inner_element_dtype.clone(),
            vec![
                Scalar::primitive(1i32, Nullability::NonNullable),
                Scalar::primitive(2i32, Nullability::NonNullable),
            ],
            Nullability::NonNullable,
        );

        let inner_list2 = Scalar::list(
            inner_element_dtype,
            vec![
                Scalar::primitive(3i32, Nullability::NonNullable),
                Scalar::primitive(4i32, Nullability::NonNullable),
            ],
            Nullability::NonNullable,
        );

        let outer_list = Scalar::list(
            inner_list_dtype,
            vec![inner_list1, inner_list2],
            Nullability::NonNullable,
        );

        let list = ListScalar::try_from(&outer_list).unwrap();
        assert_eq!(list.len(), 2);

        let nested_list = list.element(0).unwrap();
        let nested = ListScalar::try_from(&nested_list).unwrap();
        assert_eq!(nested.len(), 2);
    }
}