1use std::hash::Hash;
2use std::ops::Deref;
3use std::sync::Arc;
4
5use itertools::Itertools as _;
6use vortex_dtype::{DType, Nullability};
7use vortex_error::{
8 VortexError, VortexExpect as _, VortexResult, vortex_bail, vortex_err, vortex_panic,
9};
10
11use crate::{InnerScalarValue, Scalar, ScalarValue};
12
13pub struct ListScalar<'a> {
14 dtype: &'a DType,
15 element_dtype: &'a Arc<DType>,
16 elements: Option<Arc<[ScalarValue]>>,
17}
18
19impl PartialEq for ListScalar<'_> {
20 fn eq(&self, other: &Self) -> bool {
21 self.dtype.eq_ignore_nullability(other.dtype) && self.elements() == other.elements()
22 }
23}
24
25impl Eq for ListScalar<'_> {}
26
27impl PartialOrd for ListScalar<'_> {
29 fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
30 if !self
31 .element_dtype
32 .eq_ignore_nullability(other.element_dtype)
33 {
34 return None;
35 }
36 self.elements().partial_cmp(&other.elements())
37 }
38}
39
40impl Hash for ListScalar<'_> {
41 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
42 self.dtype.hash(state);
43 self.elements().hash(state);
44 }
45}
46
47impl<'a> ListScalar<'a> {
48 #[inline]
49 pub fn dtype(&self) -> &'a DType {
50 self.dtype
51 }
52
53 #[inline]
54 pub fn len(&self) -> usize {
55 self.elements.as_ref().map(|e| e.len()).unwrap_or(0)
56 }
57
58 #[inline]
59 pub fn is_empty(&self) -> bool {
60 match self.elements.as_ref() {
61 None => true,
62 Some(l) => l.is_empty(),
63 }
64 }
65
66 #[inline]
67 pub fn is_null(&self) -> bool {
68 self.elements.is_none()
69 }
70
71 pub fn element_dtype(&self) -> &DType {
72 let DType::List(element_type, _) = self.dtype() else {
73 unreachable!();
74 };
75 (*element_type).deref()
76 }
77
78 pub fn element(&self, idx: usize) -> Option<Scalar> {
79 self.elements
80 .as_ref()
81 .and_then(|l| l.get(idx))
82 .map(|value| Scalar {
83 dtype: self.element_dtype().clone(),
84 value: value.clone(),
85 })
86 }
87
88 pub fn elements(&self) -> Option<Vec<Scalar>> {
89 self.elements.as_ref().map(|elems| {
90 elems
91 .iter()
92 .map(|e| Scalar::new(self.element_dtype().clone(), e.clone()))
93 .collect_vec()
94 })
95 }
96
97 pub(crate) fn cast(&self, dtype: &DType) -> VortexResult<Scalar> {
98 let DType::List(element_dtype, ..) = dtype else {
99 vortex_bail!("Can't cast {:?} to {}", self.dtype(), dtype)
100 };
101
102 Ok(Scalar::new(
103 dtype.clone(),
104 ScalarValue(InnerScalarValue::List(
105 self.elements
106 .as_ref()
107 .vortex_expect("nullness handled in Scalar::cast")
108 .iter()
109 .map(|element| {
110 Scalar::new(DType::clone(self.element_dtype), element.clone())
111 .cast(element_dtype)
112 .map(|x| x.value().clone())
113 })
114 .process_results(|iter| iter.collect())?,
115 )),
116 ))
117 }
118}
119
120impl Scalar {
121 pub fn list(
122 element_dtype: Arc<DType>,
123 children: Vec<Scalar>,
124 nullability: Nullability,
125 ) -> Self {
126 for child in &children {
127 if child.dtype() != &*element_dtype {
128 vortex_panic!(
129 "tried to create list of {} with values of type {}",
130 element_dtype,
131 child.dtype()
132 );
133 }
134 }
135 Self {
136 dtype: DType::List(element_dtype, nullability),
137 value: ScalarValue(InnerScalarValue::List(
138 children.into_iter().map(|x| x.value).collect::<Arc<[_]>>(),
139 )),
140 }
141 }
142
143 pub fn list_empty(element_dtype: Arc<DType>, nullability: Nullability) -> Self {
144 Self {
145 dtype: DType::List(element_dtype, nullability),
146 value: ScalarValue(InnerScalarValue::Null),
147 }
148 }
149}
150
151impl<'a> TryFrom<&'a Scalar> for ListScalar<'a> {
152 type Error = VortexError;
153
154 fn try_from(value: &'a Scalar) -> Result<Self, Self::Error> {
155 let DType::List(element_dtype, ..) = value.dtype() else {
156 vortex_bail!("Expected list scalar, found {}", value.dtype())
157 };
158
159 Ok(Self {
160 dtype: value.dtype(),
161 element_dtype,
162 elements: value.value.as_list()?.cloned(),
163 })
164 }
165}
166
167impl<'a, T: for<'b> TryFrom<&'b Scalar, Error = VortexError>> TryFrom<&'a Scalar> for Vec<T> {
168 type Error = VortexError;
169
170 fn try_from(value: &'a Scalar) -> Result<Self, Self::Error> {
171 let value = ListScalar::try_from(value)?;
172 let mut elems = vec![];
173 for e in value
174 .elements()
175 .ok_or_else(|| vortex_err!("Expected non-null list"))?
176 {
177 elems.push(T::try_from(&e)?);
178 }
179 Ok(elems)
180 }
181}