vortex_scalar/
list.rs

1use std::fmt::{Display, Formatter};
2use std::hash::Hash;
3use std::ops::Deref;
4use std::sync::Arc;
5
6use itertools::Itertools as _;
7use vortex_dtype::{DType, Nullability};
8use vortex_error::{
9    VortexError, VortexExpect as _, VortexResult, vortex_bail, vortex_err, vortex_panic,
10};
11
12use crate::{InnerScalarValue, Scalar, ScalarValue};
13
14pub struct ListScalar<'a> {
15    dtype: &'a DType,
16    element_dtype: &'a Arc<DType>,
17    elements: Option<Arc<[ScalarValue]>>,
18}
19
20impl Display for ListScalar<'_> {
21    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
22        match &self.elements {
23            None => write!(f, "null"),
24            Some(elems) => {
25                write!(
26                    f,
27                    "[{}]",
28                    elems
29                        .iter()
30                        .map(|e| Scalar::new(self.element_dtype().clone(), e.clone()))
31                        .format(", ")
32                )
33            }
34        }
35    }
36}
37
38impl PartialEq for ListScalar<'_> {
39    fn eq(&self, other: &Self) -> bool {
40        self.dtype.eq_ignore_nullability(other.dtype) && self.elements() == other.elements()
41    }
42}
43
44impl Eq for ListScalar<'_> {}
45
46/// Ord is not implemented since it's undefined for different element DTypes
47impl PartialOrd for ListScalar<'_> {
48    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
49        if !self
50            .element_dtype
51            .eq_ignore_nullability(other.element_dtype)
52        {
53            return None;
54        }
55        self.elements().partial_cmp(&other.elements())
56    }
57}
58
59impl Hash for ListScalar<'_> {
60    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
61        self.dtype.hash(state);
62        self.elements().hash(state);
63    }
64}
65
66impl<'a> ListScalar<'a> {
67    #[inline]
68    pub fn dtype(&self) -> &'a DType {
69        self.dtype
70    }
71
72    #[inline]
73    pub fn len(&self) -> usize {
74        self.elements.as_ref().map(|e| e.len()).unwrap_or(0)
75    }
76
77    #[inline]
78    pub fn is_empty(&self) -> bool {
79        match self.elements.as_ref() {
80            None => true,
81            Some(l) => l.is_empty(),
82        }
83    }
84
85    #[inline]
86    pub fn is_null(&self) -> bool {
87        self.elements.is_none()
88    }
89
90    pub fn element_dtype(&self) -> &DType {
91        let DType::List(element_type, _) = self.dtype() else {
92            unreachable!();
93        };
94        (*element_type).deref()
95    }
96
97    pub fn element(&self, idx: usize) -> Option<Scalar> {
98        self.elements
99            .as_ref()
100            .and_then(|l| l.get(idx))
101            .map(|value| Scalar {
102                dtype: self.element_dtype().clone(),
103                value: value.clone(),
104            })
105    }
106
107    pub fn elements(&self) -> Option<Vec<Scalar>> {
108        self.elements.as_ref().map(|elems| {
109            elems
110                .iter()
111                .map(|e| Scalar::new(self.element_dtype().clone(), e.clone()))
112                .collect_vec()
113        })
114    }
115
116    pub(crate) fn cast(&self, dtype: &DType) -> VortexResult<Scalar> {
117        let DType::List(element_dtype, ..) = dtype else {
118            vortex_bail!("Can't cast {:?} to {}", self.dtype(), dtype)
119        };
120
121        Ok(Scalar::new(
122            dtype.clone(),
123            ScalarValue(InnerScalarValue::List(
124                self.elements
125                    .as_ref()
126                    .vortex_expect("nullness handled in Scalar::cast")
127                    .iter()
128                    .map(|element| {
129                        Scalar::new(DType::clone(self.element_dtype), element.clone())
130                            .cast(element_dtype)
131                            .map(|x| x.value().clone())
132                    })
133                    .process_results(|iter| iter.collect())?,
134            )),
135        ))
136    }
137}
138
139impl Scalar {
140    pub fn list(
141        element_dtype: Arc<DType>,
142        children: Vec<Scalar>,
143        nullability: Nullability,
144    ) -> Self {
145        for child in &children {
146            if child.dtype() != &*element_dtype {
147                vortex_panic!(
148                    "tried to create list of {} with values of type {}",
149                    element_dtype,
150                    child.dtype()
151                );
152            }
153        }
154        Self {
155            dtype: DType::List(element_dtype, nullability),
156            value: ScalarValue(InnerScalarValue::List(
157                children.into_iter().map(|x| x.value).collect::<Arc<[_]>>(),
158            )),
159        }
160    }
161
162    pub fn list_empty(element_dtype: Arc<DType>, nullability: Nullability) -> Self {
163        Self {
164            dtype: DType::List(element_dtype, nullability),
165            value: ScalarValue(InnerScalarValue::Null),
166        }
167    }
168}
169
170impl<'a> TryFrom<&'a Scalar> for ListScalar<'a> {
171    type Error = VortexError;
172
173    fn try_from(value: &'a Scalar) -> Result<Self, Self::Error> {
174        let DType::List(element_dtype, ..) = value.dtype() else {
175            vortex_bail!("Expected list scalar, found {}", value.dtype())
176        };
177
178        Ok(Self {
179            dtype: value.dtype(),
180            element_dtype,
181            elements: value.value.as_list()?.cloned(),
182        })
183    }
184}
185
186impl<'a, T: for<'b> TryFrom<&'b Scalar, Error = VortexError>> TryFrom<&'a Scalar> for Vec<T> {
187    type Error = VortexError;
188
189    fn try_from(value: &'a Scalar) -> Result<Self, Self::Error> {
190        let value = ListScalar::try_from(value)?;
191        let mut elems = vec![];
192        for e in value
193            .elements()
194            .ok_or_else(|| vortex_err!("Expected non-null list"))?
195        {
196            elems.push(T::try_from(&e)?);
197        }
198        Ok(elems)
199    }
200}