1use std::fmt::{Display, Formatter};
2use std::hash::{Hash, Hasher};
3use std::ops::Deref;
4use std::sync::Arc;
5
6use itertools::Itertools;
7use vortex_dtype::{DType, FieldName, FieldNames, StructFields};
8use vortex_error::{
9 VortexError, VortexExpect, VortexResult, vortex_bail, vortex_err, vortex_panic,
10};
11
12use crate::{InnerScalarValue, Scalar, ScalarValue};
13
/// A borrowed view of a [`Scalar`] whose dtype is a [`DType::Struct`].
pub struct StructScalar<'a> {
    /// The scalar's dtype; [`StructScalar::try_new`] rejects anything that is not `DType::Struct`.
    dtype: &'a DType,
    /// The field values, positionally aligned with the struct's fields, or `None` when the
    /// struct scalar itself is null.
    fields: Option<&'a Arc<[ScalarValue]>>,
}
18
19impl Display for StructScalar<'_> {
20 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
21 match &self.fields {
22 None => write!(f, "null"),
23 Some(fields) => {
24 write!(f, "{{")?;
25 let formatted_fields = self
26 .names()
27 .iter()
28 .zip_eq(self.struct_fields().fields())
29 .zip_eq(fields.iter())
30 .map(|((name, dtype), value)| {
31 let val = Scalar::new(dtype, value.clone());
32 format!("{name}: {val}")
33 })
34 .format(", ");
35 write!(f, "{formatted_fields}")?;
36 write!(f, "}}")
37 }
38 }
39 }
40}
41
42impl PartialEq for StructScalar<'_> {
43 fn eq(&self, other: &Self) -> bool {
44 if !self.dtype.eq_ignore_nullability(other.dtype) {
45 return false;
46 }
47 self.fields() == other.fields()
48 }
49}
50
51impl Eq for StructScalar<'_> {}
52
53impl PartialOrd for StructScalar<'_> {
55 fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
56 if !self.dtype.eq_ignore_nullability(other.dtype) {
57 return None;
58 }
59 self.fields().partial_cmp(&other.fields())
60 }
61}
62
63impl Hash for StructScalar<'_> {
64 fn hash<H: Hasher>(&self, state: &mut H) {
65 self.dtype.hash(state);
66 self.fields().hash(state);
67 }
68}
69
70impl<'a> StructScalar<'a> {
71 pub(crate) fn try_new(dtype: &'a DType, value: &'a ScalarValue) -> VortexResult<Self> {
72 if !matches!(dtype, DType::Struct(..)) {
73 vortex_bail!("Expected struct scalar, found {}", dtype)
74 }
75 Ok(Self {
76 dtype,
77 fields: value.as_list()?,
78 })
79 }
80
81 #[inline]
82 pub fn dtype(&self) -> &'a DType {
83 self.dtype
84 }
85
86 #[inline]
87 pub fn struct_fields(&self) -> &Arc<StructFields> {
88 let DType::Struct(sdtype, ..) = self.dtype else {
89 vortex_panic!("StructScalar always has struct dtype");
90 };
91 sdtype
92 }
93
94 pub fn names(&self) -> &FieldNames {
95 self.struct_fields().names()
96 }
97
98 pub fn is_null(&self) -> bool {
99 self.fields.is_none()
100 }
101
102 pub fn field(&self, name: impl AsRef<str>) -> VortexResult<Scalar> {
103 let DType::Struct(st, _) = self.dtype() else {
104 unreachable!()
105 };
106 let idx = st.find(name)?;
107 self.field_by_idx(idx)
108 }
109
110 pub fn field_by_idx(&self, idx: usize) -> VortexResult<Scalar> {
111 let fields = self
112 .fields
113 .vortex_expect("Can't take field out of null struct scalar");
114 let DType::Struct(st, _) = self.dtype() else {
115 unreachable!()
116 };
117
118 Ok(Scalar {
119 dtype: st.field_by_index(idx)?,
120 value: fields[idx].clone(),
121 })
122 }
123
124 pub fn fields(&self) -> Option<Vec<Scalar>> {
126 let fields = self.fields?;
127 Some(
128 (0..fields.len())
129 .map(|index| {
130 self.field_by_idx(index)
131 .vortex_expect("never out of bounds")
132 })
133 .collect::<Vec<_>>(),
134 )
135 }
136
137 pub(crate) fn field_values(&self) -> Option<&[ScalarValue]> {
138 self.fields.map(Arc::deref)
139 }
140
141 pub fn cast(&self, dtype: &DType) -> VortexResult<Scalar> {
142 let DType::Struct(st, _) = dtype else {
143 vortex_bail!("Can only cast struct to another struct")
144 };
145 let DType::Struct(own_st, _) = self.dtype() else {
146 unreachable!()
147 };
148
149 if st.fields().len() != own_st.fields().len() {
150 vortex_bail!(
151 "Cannot cast between structs with different number of fields: {} and {}",
152 own_st.fields().len(),
153 st.fields().len()
154 );
155 }
156
157 if let Some(fs) = self.field_values() {
158 let fields = fs
159 .iter()
160 .enumerate()
161 .map(|(i, f)| {
162 Scalar {
163 dtype: own_st.field_by_index(i)?,
164 value: f.clone(),
165 }
166 .cast(&st.field_by_index(i)?)
167 .map(|s| s.value)
168 })
169 .collect::<VortexResult<Vec<_>>>()?;
170 Ok(Scalar {
171 dtype: dtype.clone(),
172 value: ScalarValue(InnerScalarValue::List(fields.into())),
173 })
174 } else {
175 Ok(Scalar::null(dtype.clone()))
176 }
177 }
178
179 pub fn project(&self, projection: &[FieldName]) -> VortexResult<Scalar> {
180 let struct_dtype = self
181 .dtype
182 .as_struct()
183 .ok_or_else(|| vortex_err!("Not a struct dtype"))?;
184 let projected_dtype = struct_dtype.project(projection)?;
185 let new_fields = if let Some(fs) = self.field_values() {
186 ScalarValue(InnerScalarValue::List(
187 projection
188 .iter()
189 .map(|name| {
190 struct_dtype
191 .find(name)
192 .vortex_expect("DType has been successfully projected already")
193 })
194 .map(|i| fs[i].clone())
195 .collect(),
196 ))
197 } else {
198 ScalarValue(InnerScalarValue::Null)
199 };
200 Ok(Scalar::new(
201 DType::Struct(Arc::new(projected_dtype), self.dtype().nullability()),
202 new_fields,
203 ))
204 }
205}
206
207impl Scalar {
208 pub fn struct_(dtype: DType, children: Vec<Scalar>) -> Self {
209 Self {
210 dtype,
211 value: ScalarValue(InnerScalarValue::List(
212 children
213 .into_iter()
214 .map(|x| x.into_value())
215 .collect_vec()
216 .into(),
217 )),
218 }
219 }
220}
221
222impl<'a> TryFrom<&'a Scalar> for StructScalar<'a> {
223 type Error = VortexError;
224
225 fn try_from(value: &'a Scalar) -> Result<Self, Self::Error> {
226 Self::try_new(value.dtype(), &value.value)
227 }
228}
229
#[cfg(test)]
mod tests {
    use vortex_dtype::PType::I32;
    use vortex_dtype::{DType, Nullability, StructFields};

    use super::*;

    /// Returns the dtypes shared by these tests: a non-nullable `i32`, a non-nullable `utf8`,
    /// and a nullable struct `{a: i32, b: utf8}` built from them.
    fn setup_types() -> (DType, DType, DType) {
        let f0_dt = DType::Primitive(I32, Nullability::NonNullable);
        let f1_dt = DType::Utf8(Nullability::NonNullable);

        let dtype = DType::Struct(
            Arc::new(StructFields::new(
                vec!["a".into(), "b".into()].into(),
                vec![f0_dt.clone(), f1_dt.clone()],
            )),
            Nullability::Nullable,
        );

        (f0_dt, f1_dt, dtype)
    }

    /// Returns the non-null struct scalar `{a: 1, b: "hello"}` of the struct dtype from
    /// `setup_types`.
    fn sample_struct() -> Scalar {
        let (_, _, dtype) = setup_types();
        Scalar::struct_(
            dtype,
            vec![
                Scalar::primitive::<i32>(1, Nullability::NonNullable),
                Scalar::utf8("hello", Nullability::NonNullable),
            ],
        )
    }

    #[test]
    #[should_panic]
    fn test_struct_scalar_null() {
        let (_, _, dtype) = setup_types();

        let scalar = Scalar::null(dtype);

        scalar.as_struct().field_by_idx(0).unwrap();
    }

    #[test]
    fn test_struct_scalar_non_null() {
        let (f0_dt, f1_dt, dtype) = setup_types();

        let f0_val = Scalar::primitive::<i32>(1, Nullability::NonNullable);
        let f1_val = Scalar::utf8("hello", Nullability::NonNullable);

        // Scalar equality ignores nullability, so nullable expectations still compare equal.
        let f0_val_null = Scalar::primitive::<i32>(1, Nullability::Nullable);
        let f1_val_null = Scalar::utf8("hello", Nullability::Nullable);

        let scalar = Scalar::struct_(dtype, vec![f0_val, f1_val]);

        let scalar_f0 = scalar.as_struct().field_by_idx(0).unwrap();
        assert_eq!(scalar_f0, f0_val_null);
        assert_eq!(scalar_f0.dtype(), &f0_dt);

        let scalar_f1 = scalar.as_struct().field_by_idx(1).unwrap();
        assert_eq!(scalar_f1, f1_val_null);
        assert_eq!(scalar_f1.dtype(), &f1_dt);
    }

    #[test]
    fn test_null_struct_has_no_fields() {
        let (_, _, dtype) = setup_types();
        let scalar = Scalar::null(dtype);
        let view = scalar.as_struct();

        assert!(view.is_null());
        assert!(view.fields().is_none());
    }

    #[test]
    fn test_field_by_name() {
        let scalar = sample_struct();
        let view = scalar.as_struct();

        assert_eq!(
            view.field("b").unwrap(),
            Scalar::utf8("hello", Nullability::NonNullable)
        );
        assert!(view.field("missing").is_err());
    }

    #[test]
    fn test_project_keeps_selected_fields() {
        let scalar = sample_struct();
        let projection: Vec<FieldName> = vec!["b".into()];

        let projected = scalar.as_struct().project(&projection).unwrap();
        let view = projected.as_struct();

        assert_eq!(view.fields().unwrap().len(), 1);
        assert_eq!(
            view.field_by_idx(0).unwrap(),
            Scalar::utf8("hello", Nullability::NonNullable)
        );
    }

    #[test]
    fn test_cast_to_same_type_roundtrips() {
        let (_, _, dtype) = setup_types();
        let scalar = sample_struct();

        let cast = scalar.as_struct().cast(&dtype).unwrap();
        assert_eq!(cast, scalar);
    }

    #[test]
    fn test_cast_rejects_field_count_mismatch() {
        let (f0_dt, _, _) = setup_types();
        let narrower = DType::Struct(
            Arc::new(StructFields::new(vec!["a".into()].into(), vec![f0_dt])),
            Nullability::Nullable,
        );

        let scalar = sample_struct();
        assert!(scalar.as_struct().cast(&narrower).is_err());
    }

    #[test]
    fn test_cast_null_struct_to_nullable_stays_null() {
        let (_, _, dtype) = setup_types();
        let scalar = Scalar::null(dtype.clone());

        let cast = scalar.as_struct().cast(&dtype).unwrap();
        assert!(cast.as_struct().is_null());
    }
}