1use std::fmt::{Display, Formatter};
2use std::hash::{Hash, Hasher};
3use std::ops::Deref;
4use std::sync::Arc;
5
6use itertools::Itertools;
7use vortex_dtype::{DType, FieldName, StructDType};
8use vortex_error::{
9 VortexError, VortexExpect, VortexResult, vortex_bail, vortex_err, vortex_panic,
10};
11
12use crate::{InnerScalarValue, Scalar, ScalarValue};
13
/// A borrowed, struct-typed view over a scalar's dtype and value.
///
/// The view does not own anything: it keeps a reference to the scalar's
/// [`DType`] and, for non-null scalars, a reference to the list of
/// per-field [`ScalarValue`]s.
pub struct StructScalar<'a> {
    /// Dtype of the scalar; `try_new` rejects anything that is not `DType::Struct`.
    dtype: &'a DType,
    /// Field values in positional order; `None` means the struct scalar is null.
    fields: Option<&'a Arc<[ScalarValue]>>,
}
18
19impl Display for StructScalar<'_> {
20 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
21 match &self.fields {
22 None => write!(f, "null"),
23 Some(fields) => {
24 write!(f, "{{")?;
25 let formatted_fields = self
26 .struct_dtype()
27 .names()
28 .iter()
29 .zip_eq(self.struct_dtype().fields())
30 .zip_eq(fields.iter())
31 .map(|((name, dtype), value)| {
32 let val = Scalar::new(dtype, value.clone());
33 format!("{name}: {val}")
34 })
35 .format(", ");
36 write!(f, "{}", formatted_fields)?;
37 write!(f, "}}")
38 }
39 }
40 }
41}
42
43impl PartialEq for StructScalar<'_> {
44 fn eq(&self, other: &Self) -> bool {
45 if !self.dtype.eq_ignore_nullability(other.dtype) {
46 return false;
47 }
48 self.fields() == other.fields()
49 }
50}
51
52impl Eq for StructScalar<'_> {}
53
54impl PartialOrd for StructScalar<'_> {
56 fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
57 if !self.dtype.eq_ignore_nullability(other.dtype) {
58 return None;
59 }
60 self.fields().partial_cmp(&other.fields())
61 }
62}
63
64impl Hash for StructScalar<'_> {
65 fn hash<H: Hasher>(&self, state: &mut H) {
66 self.dtype.hash(state);
67 self.fields().hash(state);
68 }
69}
70
71impl<'a> StructScalar<'a> {
72 pub(crate) fn try_new(dtype: &'a DType, value: &'a ScalarValue) -> VortexResult<Self> {
73 if !matches!(dtype, DType::Struct(..)) {
74 vortex_bail!("Expected struct scalar, found {}", dtype)
75 }
76 Ok(Self {
77 dtype,
78 fields: value.as_list()?,
79 })
80 }
81
82 #[inline]
83 pub fn dtype(&self) -> &'a DType {
84 self.dtype
85 }
86
87 #[inline]
88 pub fn struct_dtype(&self) -> &Arc<StructDType> {
89 let DType::Struct(sdtype, ..) = self.dtype else {
90 vortex_panic!("StructScalar always has struct dtype");
91 };
92 sdtype
93 }
94
95 pub fn is_null(&self) -> bool {
96 self.fields.is_none()
97 }
98
99 pub fn field(&self, name: impl AsRef<str>) -> VortexResult<Scalar> {
100 let DType::Struct(st, _) = self.dtype() else {
101 unreachable!()
102 };
103 let idx = st.find(name)?;
104 self.field_by_idx(idx)
105 }
106
107 pub fn field_by_idx(&self, idx: usize) -> VortexResult<Scalar> {
108 let fields = self
109 .fields
110 .vortex_expect("Can't take field out of null struct scalar");
111 let DType::Struct(st, _) = self.dtype() else {
112 unreachable!()
113 };
114
115 Ok(Scalar {
116 dtype: st.field_by_index(idx)?,
117 value: fields[idx].clone(),
118 })
119 }
120
121 pub fn fields(&self) -> Option<Vec<Scalar>> {
123 let fields = self.fields?;
124 Some(
125 (0..fields.len())
126 .map(|index| {
127 self.field_by_idx(index)
128 .vortex_expect("never out of bounds")
129 })
130 .collect::<Vec<_>>(),
131 )
132 }
133
134 pub(crate) fn field_values(&self) -> Option<&[ScalarValue]> {
135 self.fields.map(Arc::deref)
136 }
137
138 pub fn cast(&self, dtype: &DType) -> VortexResult<Scalar> {
139 let DType::Struct(st, _) = dtype else {
140 vortex_bail!("Can only cast struct to another struct")
141 };
142 let DType::Struct(own_st, _) = self.dtype() else {
143 unreachable!()
144 };
145
146 if st.fields().len() != own_st.fields().len() {
147 vortex_bail!(
148 "Cannot cast between structs with different number of fields: {} and {}",
149 own_st.fields().len(),
150 st.fields().len()
151 );
152 }
153
154 if let Some(fs) = self.field_values() {
155 let fields = fs
156 .iter()
157 .enumerate()
158 .map(|(i, f)| {
159 Scalar {
160 dtype: own_st.field_by_index(i)?,
161 value: f.clone(),
162 }
163 .cast(&st.field_by_index(i)?)
164 .map(|s| s.value)
165 })
166 .collect::<VortexResult<Vec<_>>>()?;
167 Ok(Scalar {
168 dtype: dtype.clone(),
169 value: ScalarValue(InnerScalarValue::List(fields.into())),
170 })
171 } else {
172 Ok(Scalar::null(dtype.clone()))
173 }
174 }
175
176 pub fn project(&self, projection: &[FieldName]) -> VortexResult<Scalar> {
177 let struct_dtype = self
178 .dtype
179 .as_struct()
180 .ok_or_else(|| vortex_err!("Not a struct dtype"))?;
181 let projected_dtype = struct_dtype.project(projection)?;
182 let new_fields = if let Some(fs) = self.field_values() {
183 ScalarValue(InnerScalarValue::List(
184 projection
185 .iter()
186 .map(|name| {
187 struct_dtype
188 .find(name)
189 .vortex_expect("DType has been successfully projected already")
190 })
191 .map(|i| fs[i].clone())
192 .collect(),
193 ))
194 } else {
195 ScalarValue(InnerScalarValue::Null)
196 };
197 Ok(Scalar::new(
198 DType::Struct(Arc::new(projected_dtype), self.dtype().nullability()),
199 new_fields,
200 ))
201 }
202}
203
204impl Scalar {
205 pub fn struct_(dtype: DType, children: Vec<Scalar>) -> Self {
206 Self {
207 dtype,
208 value: ScalarValue(InnerScalarValue::List(
209 children
210 .into_iter()
211 .map(|x| x.into_value())
212 .collect_vec()
213 .into(),
214 )),
215 }
216 }
217}
218
219impl<'a> TryFrom<&'a Scalar> for StructScalar<'a> {
220 type Error = VortexError;
221
222 fn try_from(value: &'a Scalar) -> Result<Self, Self::Error> {
223 Self::try_new(value.dtype(), &value.value)
224 }
225}
226
#[cfg(test)]
mod tests {
    use vortex_dtype::PType::I32;
    use vortex_dtype::{DType, Nullability, StructDType};

    use super::*;

    /// Returns the dtypes shared by the tests: a non-nullable `i32` field,
    /// a non-nullable utf8 field, and a nullable struct `{a, b}` of the two.
    fn setup_types() -> (DType, DType, DType) {
        let f0_dt = DType::Primitive(I32, Nullability::NonNullable);
        let f1_dt = DType::Utf8(Nullability::NonNullable);

        let struct_dt = StructDType::new(
            vec!["a".into(), "b".into()].into(),
            vec![f0_dt.clone(), f1_dt.clone()],
        );
        let dtype = DType::Struct(Arc::new(struct_dt), Nullability::Nullable);

        (f0_dt, f1_dt, dtype)
    }

    /// Unwrapping a field taken from a null struct scalar must panic.
    #[test]
    #[should_panic]
    fn test_struct_scalar_null() {
        let (_, _, dtype) = setup_types();

        let null_struct = Scalar::null(dtype);

        null_struct.as_struct().field_by_idx(0).unwrap();
    }

    /// Fields of a non-null struct scalar come back with the dtype declared
    /// in the struct dtype.
    #[test]
    fn test_struct_scalar_non_null() {
        let (f0_dt, f1_dt, dtype) = setup_types();

        let scalar = Scalar::struct_(
            dtype,
            vec![
                Scalar::primitive::<i32>(1, Nullability::NonNullable),
                Scalar::utf8("hello", Nullability::NonNullable),
            ],
        );

        // The expected values are deliberately nullable while the struct's
        // fields are non-nullable. The value comparisons below presumably rely
        // on `Scalar` equality ignoring nullability; the dtype assertions
        // check the exact field dtype.
        let f0_val_null = Scalar::primitive::<i32>(1, Nullability::Nullable);
        let f1_val_null = Scalar::utf8("hello", Nullability::Nullable);

        let scalar_f0 = scalar.as_struct().field_by_idx(0);
        assert!(scalar_f0.is_ok());
        let scalar_f0 = scalar_f0.unwrap();
        assert_eq!(scalar_f0, f0_val_null);
        assert_eq!(scalar_f0.dtype(), &f0_dt);

        let scalar_f1 = scalar.as_struct().field_by_idx(1);
        assert!(scalar_f1.is_ok());
        let scalar_f1 = scalar_f1.unwrap();
        assert_eq!(scalar_f1, f1_val_null);
        assert_eq!(scalar_f1.dtype(), &f1_dt);
    }
}