vortex_scalar/
list.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::{Display, Formatter};
5use std::hash::Hash;
6use std::ops::Deref;
7use std::sync::Arc;
8
9use itertools::Itertools as _;
10use vortex_dtype::{DType, Nullability};
11use vortex_error::{
12    VortexError, VortexExpect as _, VortexResult, vortex_bail, vortex_err, vortex_panic,
13};
14
15use crate::{InnerScalarValue, Scalar, ScalarValue};
16
/// A scalar value representing a list (array) of elements.
///
/// This type provides a view into a list scalar value, which can contain
/// zero or more elements of the same type, or be null.
pub struct ListScalar<'a> {
    /// Data type of the list scalar itself, borrowed from the source scalar.
    dtype: &'a DType,
    /// Data type of the list's elements, borrowed from inside `dtype`.
    element_dtype: &'a Arc<DType>,
    /// Raw element values; `None` when the list is null.
    elements: Option<Arc<[ScalarValue]>>,
}
26
27impl Display for ListScalar<'_> {
28    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
29        match &self.elements {
30            None => write!(f, "null"),
31            Some(elems) => {
32                write!(
33                    f,
34                    "[{}]",
35                    elems
36                        .iter()
37                        .map(|e| Scalar::new(self.element_dtype().clone(), e.clone()))
38                        .format(", ")
39                )
40            }
41        }
42    }
43}
44
45impl PartialEq for ListScalar<'_> {
46    fn eq(&self, other: &Self) -> bool {
47        self.dtype.eq_ignore_nullability(other.dtype) && self.elements() == other.elements()
48    }
49}
50
51impl Eq for ListScalar<'_> {}
52
53/// Ord is not implemented since it's undefined for different element DTypes
54impl PartialOrd for ListScalar<'_> {
55    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
56        if !self
57            .element_dtype
58            .eq_ignore_nullability(other.element_dtype)
59        {
60            return None;
61        }
62        self.elements().partial_cmp(&other.elements())
63    }
64}
65
66impl Hash for ListScalar<'_> {
67    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
68        self.dtype.hash(state);
69        self.elements().hash(state);
70    }
71}
72
73impl<'a> ListScalar<'a> {
74    /// Returns the data type of this list scalar.
75    #[inline]
76    pub fn dtype(&self) -> &'a DType {
77        self.dtype
78    }
79
80    /// Returns the number of elements in the list.
81    ///
82    /// Returns 0 if the list is null.
83    #[inline]
84    pub fn len(&self) -> usize {
85        self.elements.as_ref().map(|e| e.len()).unwrap_or(0)
86    }
87
88    /// Returns true if the list has no elements or is null.
89    #[inline]
90    pub fn is_empty(&self) -> bool {
91        match self.elements.as_ref() {
92            None => true,
93            Some(l) => l.is_empty(),
94        }
95    }
96
97    /// Returns true if the list is null.
98    #[inline]
99    pub fn is_null(&self) -> bool {
100        self.elements.is_none()
101    }
102
103    /// Returns the data type of the list's elements.
104    pub fn element_dtype(&self) -> &DType {
105        let DType::List(element_type, _) = self.dtype() else {
106            unreachable!();
107        };
108        (*element_type).deref()
109    }
110
111    /// Returns the element at the given index as a scalar.
112    ///
113    /// Returns None if the list is null or the index is out of bounds.
114    pub fn element(&self, idx: usize) -> Option<Scalar> {
115        self.elements
116            .as_ref()
117            .and_then(|l| l.get(idx))
118            .map(|value| Scalar {
119                dtype: self.element_dtype().clone(),
120                value: value.clone(),
121            })
122    }
123
124    /// Returns all elements in the list as a vector of scalars.
125    ///
126    /// Returns None if the list is null.
127    pub fn elements(&self) -> Option<Vec<Scalar>> {
128        self.elements.as_ref().map(|elems| {
129            elems
130                .iter()
131                .map(|e| Scalar::new(self.element_dtype().clone(), e.clone()))
132                .collect_vec()
133        })
134    }
135
136    pub(crate) fn cast(&self, dtype: &DType) -> VortexResult<Scalar> {
137        let DType::List(element_dtype, ..) = dtype else {
138            vortex_bail!("Can't cast {:?} to {}", self.dtype(), dtype)
139        };
140
141        Ok(Scalar::new(
142            dtype.clone(),
143            ScalarValue(InnerScalarValue::List(
144                self.elements
145                    .as_ref()
146                    .vortex_expect("nullness handled in Scalar::cast")
147                    .iter()
148                    .map(|element| {
149                        Scalar::new(DType::clone(self.element_dtype), element.clone())
150                            .cast(element_dtype)
151                            .map(|x| x.value().clone())
152                    })
153                    .process_results(|iter| iter.collect())?,
154            )),
155        ))
156    }
157}
158
159impl Scalar {
160    /// Creates a new list scalar with the given element type and children.
161    ///
162    /// # Panics
163    ///
164    /// Panics if any child scalar has a different type than the element type.
165    pub fn list(
166        element_dtype: impl Into<Arc<DType>>,
167        children: Vec<Scalar>,
168        nullability: Nullability,
169    ) -> Self {
170        let element_dtype = element_dtype.into();
171        for child in &children {
172            if child.dtype() != &*element_dtype {
173                vortex_panic!(
174                    "tried to create list of {} with values of type {}",
175                    element_dtype,
176                    child.dtype()
177                );
178            }
179        }
180        Self {
181            dtype: DType::List(element_dtype, nullability),
182            value: ScalarValue(InnerScalarValue::List(
183                children.into_iter().map(|x| x.value).collect(),
184            )),
185        }
186    }
187
188    /// Creates a new empty list scalar with the given element type.
189    pub fn list_empty(element_dtype: Arc<DType>, nullability: Nullability) -> Self {
190        Self {
191            dtype: DType::List(element_dtype, nullability),
192            value: ScalarValue(InnerScalarValue::Null),
193        }
194    }
195}
196
197impl<'a> TryFrom<&'a Scalar> for ListScalar<'a> {
198    type Error = VortexError;
199
200    fn try_from(value: &'a Scalar) -> Result<Self, Self::Error> {
201        let DType::List(element_dtype, ..) = value.dtype() else {
202            vortex_bail!("Expected list scalar, found {}", value.dtype())
203        };
204
205        Ok(Self {
206            dtype: value.dtype(),
207            element_dtype,
208            elements: value.value.as_list()?.cloned(),
209        })
210    }
211}
212
213impl<'a, T: for<'b> TryFrom<&'b Scalar, Error = VortexError>> TryFrom<&'a Scalar> for Vec<T> {
214    type Error = VortexError;
215
216    fn try_from(value: &'a Scalar) -> Result<Self, Self::Error> {
217        let value = ListScalar::try_from(value)?;
218        let mut elems = vec![];
219        for e in value
220            .elements()
221            .ok_or_else(|| vortex_err!("Expected non-null list"))?
222        {
223            elems.push(T::try_from(&e)?);
224        }
225        Ok(elems)
226    }
227}