vortex_scalar/
list.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::{Display, Formatter};
5use std::hash::Hash;
6use std::sync::Arc;
7
8use itertools::Itertools as _;
9use vortex_dtype::{DType, Nullability};
10use vortex_error::{
11    VortexError, VortexExpect as _, VortexResult, vortex_bail, vortex_err, vortex_panic,
12};
13
14use crate::{InnerScalarValue, Scalar, ScalarValue};
15
16/// A scalar value representing a list or fixed-size list (array) of elements.
17///
18/// We use the same [`ListScalar`] to represent both variants since a single list scalar's data is
19/// identical to a single fixed-size list scalar.
20///
21/// This type provides a view into a list or fixed-size list scalar value which can contain zero or
22/// more elements of the same type, or be null. If the `dtype` is a [`FixedSizeList`], then the
23/// number of `elements` is equal to the `size` field of the [`FixedSizeList`].
24///
25/// [`FixedSizeList`]: DType::FixedSizeList
26#[derive(Debug, Clone)]
27pub struct ListScalar<'a> {
28    dtype: &'a DType,
29    element_dtype: &'a Arc<DType>,
30    elements: Option<Arc<[ScalarValue]>>,
31}
32
33impl Display for ListScalar<'_> {
34    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
35        match &self.elements {
36            None => write!(f, "null"),
37            Some(elems) => {
38                let fixed_size_list_str: &dyn Display =
39                    if let DType::FixedSizeList(_, size, _) = self.dtype {
40                        &format!("fixed_size<{size}>")
41                    } else {
42                        &""
43                    };
44
45                write!(
46                    f,
47                    "{fixed_size_list_str}[{}]",
48                    elems
49                        .iter()
50                        .map(|e| Scalar::new(self.element_dtype().clone(), e.clone()))
51                        .format(", ")
52                )
53            }
54        }
55    }
56}
57
58impl PartialEq for ListScalar<'_> {
59    fn eq(&self, other: &Self) -> bool {
60        self.dtype.eq_ignore_nullability(other.dtype) && self.elements() == other.elements()
61    }
62}
63
64impl Eq for ListScalar<'_> {}
65
66/// Ord is not implemented since it's undefined for different element DTypes
67impl PartialOrd for ListScalar<'_> {
68    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
69        if !self
70            .element_dtype
71            .eq_ignore_nullability(other.element_dtype)
72        {
73            return None;
74        }
75        self.elements().partial_cmp(&other.elements())
76    }
77}
78
79impl Hash for ListScalar<'_> {
80    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
81        self.dtype.hash(state);
82        self.elements().hash(state);
83    }
84}
85
86impl<'a> ListScalar<'a> {
87    /// Returns the data type of this list scalar.
88    #[inline]
89    pub fn dtype(&self) -> &'a DType {
90        self.dtype
91    }
92
93    /// Returns the number of elements in the list.
94    ///
95    /// Returns 0 if the list is null.
96    #[inline]
97    pub fn len(&self) -> usize {
98        self.elements.as_ref().map(|e| e.len()).unwrap_or(0)
99    }
100
101    /// Returns true if the list has no elements or is null.
102    #[inline]
103    pub fn is_empty(&self) -> bool {
104        match self.elements.as_ref() {
105            None => true,
106            Some(l) => l.is_empty(),
107        }
108    }
109
110    /// Returns true if the list is null.
111    #[inline]
112    pub fn is_null(&self) -> bool {
113        self.elements.is_none()
114    }
115
116    /// Returns the data type of the list's elements.
117    pub fn element_dtype(&self) -> &DType {
118        self.dtype
119            .as_any_size_list_element_opt()
120            .unwrap_or_else(|| vortex_panic!("`ListScalar` somehow had dtype {}", self.dtype))
121            .as_ref()
122    }
123
124    /// Returns the element at the given index as a scalar.
125    ///
126    /// Returns None if the list is null or the index is out of bounds.
127    pub fn element(&self, idx: usize) -> Option<Scalar> {
128        self.elements
129            .as_ref()
130            .and_then(|l| l.get(idx))
131            .map(|value| Scalar::new(self.element_dtype().clone(), value.clone()))
132    }
133
134    /// Returns all elements in the list as a vector of scalars.
135    ///
136    /// Returns None if the list is null.
137    pub fn elements(&self) -> Option<Vec<Scalar>> {
138        self.elements.as_ref().map(|elems| {
139            elems
140                .iter()
141                .map(|e| Scalar::new(self.element_dtype().clone(), e.clone()))
142                .collect_vec()
143        })
144    }
145
146    /// Casts the list to the target [`DType`].
147    ///
148    /// # Panics
149    ///
150    /// Panics if the target [`DType`] is not a [`List`]: or [`FixedSizeList`], or if trying to cast
151    /// to a [`FixedSizeList`] with the incorrect number of elements.
152    ///
153    /// [`List`]: DType::List
154    /// [`FixedSizeList`]: DType::FixedSizeList
155    pub(crate) fn cast(&self, dtype: &DType) -> VortexResult<Scalar> {
156        let target_element_dtype = dtype
157            .as_any_size_list_element_opt()
158            .ok_or_else(|| {
159                vortex_err!(
160                    "Cannot cast {} to {}: list can only be cast to a list or fixed-size list",
161                    self.dtype(),
162                    dtype
163                )
164            })?
165            .as_ref();
166
167        if let DType::FixedSizeList(_, size, _) = dtype
168            && *size as usize != self.len()
169        {
170            vortex_bail!(
171                "tried to cast to a `FixedSizeList[{size}]` but had {} elements",
172                self.len()
173            )
174        }
175
176        Ok(Scalar::new(
177            dtype.clone(),
178            ScalarValue(InnerScalarValue::List(
179                self.elements
180                    .as_ref()
181                    .vortex_expect("nullness handled in Scalar::cast")
182                    .iter()
183                    .map(|element| {
184                        // Recursively cast the elements of the list.
185                        Scalar::new(DType::clone(self.element_dtype), element.clone())
186                            .cast(target_element_dtype)
187                            .map(|x| x.into_value())
188                    })
189                    .collect::<VortexResult<Arc<[ScalarValue]>>>()?,
190            )),
191        ))
192    }
193}
194
/// Selects which list flavour is produced when building a list [`Scalar`].
enum ListKind {
    /// Build a scalar with a [`DType::List`] dtype.
    Variable,
    /// Build a scalar with a [`DType::FixedSizeList`] dtype.
    FixedSize,
}
200
201/// Helper functions to create a [`ListScalar`] as a [`Scalar`].
202impl Scalar {
203    fn create_list(
204        element_dtype: impl Into<Arc<DType>>,
205        children: Vec<Scalar>,
206        nullability: Nullability,
207        list_kind: ListKind,
208    ) -> Self {
209        let element_dtype = element_dtype.into();
210
211        let children: Arc<[ScalarValue]> = children
212            .into_iter()
213            .map(|child| {
214                if child.dtype() != &*element_dtype {
215                    vortex_panic!(
216                        "tried to create list of {} with values of type {}",
217                        element_dtype,
218                        child.dtype()
219                    );
220                }
221                child.into_value()
222            })
223            .collect();
224        let size: u32 = children
225            .len()
226            .try_into()
227            .vortex_expect("tried to create a list that was too large");
228
229        let dtype = match list_kind {
230            ListKind::Variable => DType::List(element_dtype, nullability),
231            ListKind::FixedSize => DType::FixedSizeList(element_dtype, size, nullability),
232        };
233
234        Self::new(dtype, ScalarValue(InnerScalarValue::List(children)))
235    }
236
237    /// Creates a new list scalar with the given element type and children.
238    ///
239    /// # Panics
240    ///
241    /// Panics if any child scalar has a different type than the element type, or if there are too
242    /// many children.
243    pub fn list(
244        element_dtype: impl Into<Arc<DType>>,
245        children: Vec<Scalar>,
246        nullability: Nullability,
247    ) -> Self {
248        Self::create_list(element_dtype, children, nullability, ListKind::Variable)
249    }
250
251    /// Creates a new empty list scalar with the given element type.
252    pub fn list_empty(element_dtype: Arc<DType>, nullability: Nullability) -> Self {
253        Self::create_list(element_dtype, vec![], nullability, ListKind::Variable)
254    }
255
256    /// Creates a new fixed-size list scalar with the given element type and children.
257    ///
258    /// # Panics
259    ///
260    /// Panics if any child scalar has a different type than the element type, or if there are too
261    /// many children.
262    pub fn fixed_size_list(
263        element_dtype: impl Into<Arc<DType>>,
264        children: Vec<Scalar>,
265        nullability: Nullability,
266    ) -> Self {
267        Self::create_list(element_dtype, children, nullability, ListKind::FixedSize)
268    }
269}
270
271impl<'a> TryFrom<&'a Scalar> for ListScalar<'a> {
272    type Error = VortexError;
273
274    fn try_from(value: &'a Scalar) -> Result<Self, Self::Error> {
275        let element_dtype = value
276            .dtype()
277            .as_any_size_list_element_opt()
278            .ok_or_else(|| vortex_err!("Expected list scalar, found {}", value.dtype()))?;
279
280        Ok(Self {
281            dtype: value.dtype(),
282            element_dtype,
283            elements: value.value().as_list()?.cloned(),
284        })
285    }
286}
287
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use vortex_dtype::{DType, Nullability, PType};

    use super::*;

    /// Builds the `i32` element dtype used by most tests.
    fn i32_dtype(nullability: Nullability) -> Arc<DType> {
        Arc::new(DType::Primitive(PType::I32, nullability))
    }

    /// Builds a non-nullable list scalar holding the given non-nullable `i32` values.
    fn i32_list(values: &[i32]) -> Scalar {
        let children: Vec<Scalar> = values
            .iter()
            .map(|&value| Scalar::primitive(value, Nullability::NonNullable))
            .collect();
        Scalar::list(
            i32_dtype(Nullability::NonNullable),
            children,
            Nullability::NonNullable,
        )
    }

    #[test]
    fn test_list_scalar_creation() {
        let list_scalar = i32_list(&[1, 2, 3]);

        let list = ListScalar::try_from(&list_scalar).unwrap();
        assert_eq!(list.len(), 3);
        assert!(!list.is_empty());
        assert!(!list.is_null());
    }

    #[test]
    fn test_empty_list() {
        let list_scalar = Scalar::list_empty(
            i32_dtype(Nullability::NonNullable),
            Nullability::NonNullable,
        );

        let list = ListScalar::try_from(&list_scalar).unwrap();
        assert_eq!(list.len(), 0);
        assert!(list.is_empty());
        assert!(!list.is_null());
    }

    #[test]
    fn test_null_list() {
        let null = Scalar::null(DType::List(
            i32_dtype(Nullability::Nullable),
            Nullability::Nullable,
        ));

        let list = ListScalar::try_from(&null).unwrap();
        assert!(list.is_empty());
        assert!(list.is_null());
    }

    #[test]
    fn test_list_element_access() {
        let list_scalar = i32_list(&[10, 20, 30]);
        let list = ListScalar::try_from(&list_scalar).unwrap();

        // Every in-bounds index yields the matching value.
        for (idx, expected) in [10i32, 20, 30].into_iter().enumerate() {
            let elem = list.element(idx).unwrap();
            assert_eq!(elem.as_primitive().typed_value::<i32>().unwrap(), expected);
        }

        // One past the end is out of bounds.
        assert!(list.element(3).is_none());
    }

    #[test]
    fn test_list_elements() {
        let list_scalar = i32_list(&[100, 200]);
        let list = ListScalar::try_from(&list_scalar).unwrap();

        let elements = list.elements().unwrap();
        assert_eq!(elements.len(), 2);

        let values: Vec<i32> = elements
            .iter()
            .map(|elem| elem.as_primitive().typed_value::<i32>().unwrap())
            .collect();
        assert_eq!(values, vec![100, 200]);
    }

    #[test]
    fn test_list_display() {
        let list_scalar = i32_list(&[1, 2]);
        let list = ListScalar::try_from(&list_scalar).unwrap();

        let display = format!("{list}");
        assert!(display.contains("1"));
        assert!(display.contains("2"));
    }

    #[test]
    fn test_list_equality() {
        let list_scalar1 = i32_list(&[1, 2]);
        let list_scalar2 = i32_list(&[1, 2]);

        let list1 = ListScalar::try_from(&list_scalar1).unwrap();
        let list2 = ListScalar::try_from(&list_scalar2).unwrap();

        assert_eq!(list1, list2);
    }

    #[test]
    fn test_list_inequality() {
        let list_scalar1 = i32_list(&[1, 2]);
        let list_scalar2 = i32_list(&[1, 3]);

        let list1 = ListScalar::try_from(&list_scalar1).unwrap();
        let list2 = ListScalar::try_from(&list_scalar2).unwrap();

        assert_ne!(list1, list2);
    }

    #[test]
    fn test_list_partial_ord() {
        let list_scalar1 = i32_list(&[1]);
        let list_scalar2 = i32_list(&[2]);

        let list1 = ListScalar::try_from(&list_scalar1).unwrap();
        let list2 = ListScalar::try_from(&list_scalar2).unwrap();

        assert!(list1 < list2);
    }

    #[test]
    fn test_list_partial_ord_different_types() {
        let list_scalar1 = i32_list(&[1]);
        let list_scalar2 = Scalar::list(
            Arc::new(DType::Primitive(PType::I64, Nullability::NonNullable)),
            vec![Scalar::primitive(1i64, Nullability::NonNullable)],
            Nullability::NonNullable,
        );

        let list1 = ListScalar::try_from(&list_scalar1).unwrap();
        let list2 = ListScalar::try_from(&list_scalar2).unwrap();

        assert!(list1.partial_cmp(&list2).is_none());
    }

    #[test]
    fn test_list_hash() {
        use std::collections::hash_map::DefaultHasher;
        use std::hash::{Hash, Hasher};

        let list_scalar = i32_list(&[1, 2]);
        let list = ListScalar::try_from(&list_scalar).unwrap();

        // Hashing the same value with two fresh hashers must give the same result.
        let hash_of = |list: &ListScalar<'_>| {
            let mut hasher = DefaultHasher::new();
            list.hash(&mut hasher);
            hasher.finish()
        };

        assert_eq!(hash_of(&list), hash_of(&list));
    }

    #[test]
    fn test_vec_conversion() {
        let list_scalar = i32_list(&[10, 20, 30]);

        let vec: Vec<i32> = Vec::try_from(&list_scalar).unwrap();
        assert_eq!(vec, vec![10, 20, 30]);
    }

    #[test]
    fn test_vec_conversion_empty_list() {
        let list_scalar =
            Scalar::list_empty(i32_dtype(Nullability::Nullable), Nullability::Nullable);

        let result: VortexResult<Vec<i32>> = Vec::try_from(&list_scalar);
        assert!(result.unwrap().is_empty());
    }

    #[test]
    fn test_list_cast() {
        let list_scalar = i32_list(&[1, 2]);
        let list = ListScalar::try_from(&list_scalar).unwrap();

        // Cast to list with i64 elements
        let target_dtype = DType::List(
            Arc::new(DType::Primitive(PType::I64, Nullability::NonNullable)),
            Nullability::NonNullable,
        );

        let casted = list.cast(&target_dtype).unwrap();
        let casted_list = ListScalar::try_from(&casted).unwrap();
        assert_eq!(casted_list.len(), 2);

        let elem0 = casted_list.element(0).unwrap();
        assert_eq!(elem0.as_primitive().typed_value::<i64>().unwrap(), 1);
    }

    #[test]
    #[should_panic(expected = "tried to create list of i32 with values of type i64")]
    fn test_list_wrong_element_type_panic() {
        // The child is an i64 while the declared element dtype is i32.
        let children = vec![Scalar::primitive(1i64, Nullability::NonNullable)];
        Scalar::list(
            i32_dtype(Nullability::NonNullable),
            children,
            Nullability::NonNullable,
        );
    }

    #[test]
    fn test_try_from_wrong_dtype() {
        let scalar = Scalar::primitive(42i32, Nullability::NonNullable);
        assert!(ListScalar::try_from(&scalar).is_err());
    }

    #[test]
    fn test_string_list() {
        let children = vec![
            Scalar::utf8("hello".to_string(), Nullability::NonNullable),
            Scalar::utf8("world".to_string(), Nullability::NonNullable),
        ];
        let list_scalar = Scalar::list(
            Arc::new(DType::Utf8(Nullability::NonNullable)),
            children,
            Nullability::NonNullable,
        );

        let list = ListScalar::try_from(&list_scalar).unwrap();
        assert_eq!(list.len(), 2);

        let elem0 = list.element(0).unwrap();
        assert_eq!(elem0.as_utf8().value().unwrap().as_str(), "hello");

        let elem1 = list.element(1).unwrap();
        assert_eq!(elem1.as_utf8().value().unwrap().as_str(), "world");
    }

    #[test]
    fn test_nested_lists() {
        let inner_list_dtype = Arc::new(DType::List(
            i32_dtype(Nullability::NonNullable),
            Nullability::NonNullable,
        ));

        let outer_list = Scalar::list(
            inner_list_dtype,
            vec![i32_list(&[1, 2]), i32_list(&[3, 4])],
            Nullability::NonNullable,
        );

        let list = ListScalar::try_from(&outer_list).unwrap();
        assert_eq!(list.len(), 2);

        let nested_list = list.element(0).unwrap();
        let nested = ListScalar::try_from(&nested_list).unwrap();
        assert_eq!(nested.len(), 2);
    }
}