1use std::fmt::{Display, Formatter};
2use std::hash::{Hash, Hasher};
3use std::ops::Deref;
4use std::sync::Arc;
5
6use itertools::Itertools;
7use vortex_dtype::{DType, FieldName, FieldNames, StructFields};
8use vortex_error::{VortexError, VortexExpect, VortexResult, vortex_bail, vortex_err};
9
10use crate::{InnerScalarValue, Scalar, ScalarValue};
11
/// A borrowed view of a struct-typed scalar value.
///
/// Pairs a struct [`DType`] with the list of per-field values. Both are borrowed for lifetime
/// `'a` (e.g. from a [`Scalar`] via `TryFrom<&Scalar>`).
pub struct StructScalar<'a> {
    /// The scalar's dtype; `try_new` only accepts `DType::Struct`.
    dtype: &'a DType,
    /// One value per struct field, in field order, or `None` when the struct itself is null.
    fields: Option<&'a Arc<[ScalarValue]>>,
}
16
17impl Display for StructScalar<'_> {
18 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
19 match &self.fields {
20 None => write!(f, "null"),
21 Some(fields) => {
22 write!(f, "{{")?;
23 let formatted_fields = self
24 .names()
25 .iter()
26 .zip_eq(self.struct_fields().fields())
27 .zip_eq(fields.iter())
28 .map(|((name, dtype), value)| {
29 let val = Scalar::new(dtype, value.clone());
30 format!("{name}: {val}")
31 })
32 .format(", ");
33 write!(f, "{formatted_fields}")?;
34 write!(f, "}}")
35 }
36 }
37 }
38}
39
40impl PartialEq for StructScalar<'_> {
41 fn eq(&self, other: &Self) -> bool {
42 if !self.dtype.eq_ignore_nullability(other.dtype) {
43 return false;
44 }
45 self.fields() == other.fields()
46 }
47}
48
49impl Eq for StructScalar<'_> {}
50
51impl PartialOrd for StructScalar<'_> {
53 fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
54 if !self.dtype.eq_ignore_nullability(other.dtype) {
55 return None;
56 }
57 self.fields().partial_cmp(&other.fields())
58 }
59}
60
61impl Hash for StructScalar<'_> {
62 fn hash<H: Hasher>(&self, state: &mut H) {
63 self.dtype.hash(state);
64 self.fields().hash(state);
65 }
66}
67
68impl<'a> StructScalar<'a> {
69 pub(crate) fn try_new(dtype: &'a DType, value: &'a ScalarValue) -> VortexResult<Self> {
70 if !matches!(dtype, DType::Struct(..)) {
71 vortex_bail!("Expected struct scalar, found {}", dtype)
72 }
73 Ok(Self {
74 dtype,
75 fields: value.as_list()?,
76 })
77 }
78
79 #[inline]
80 pub fn dtype(&self) -> &'a DType {
81 self.dtype
82 }
83
84 #[inline]
85 pub fn struct_fields(&self) -> &StructFields {
86 self.dtype
87 .as_struct()
88 .vortex_expect("StructScalar always has struct dtype")
89 }
90
91 pub fn names(&self) -> &FieldNames {
92 self.struct_fields().names()
93 }
94
95 pub fn is_null(&self) -> bool {
96 self.fields.is_none()
97 }
98
99 pub fn field(&self, name: impl AsRef<str>) -> Option<Scalar> {
100 let idx = self.struct_fields().find(name)?;
101 self.field_by_idx(idx)
102 }
103
104 pub fn field_by_idx(&self, idx: usize) -> Option<Scalar> {
105 let fields = self
106 .fields
107 .vortex_expect("Can't take field out of null struct scalar");
108 Some(Scalar {
109 dtype: self.struct_fields().field_by_index(idx)?,
110 value: fields[idx].clone(),
111 })
112 }
113
114 pub fn fields(&self) -> Option<Vec<Scalar>> {
116 let fields = self.fields?;
117 Some(
118 (0..fields.len())
119 .map(|index| {
120 self.field_by_idx(index)
121 .vortex_expect("never out of bounds")
122 })
123 .collect::<Vec<_>>(),
124 )
125 }
126
127 pub(crate) fn field_values(&self) -> Option<&[ScalarValue]> {
128 self.fields.map(Arc::deref)
129 }
130
131 pub fn cast(&self, dtype: &DType) -> VortexResult<Scalar> {
132 let DType::Struct(st, _) = dtype else {
133 vortex_bail!("Can only cast struct to another struct")
134 };
135 let own_st = self.struct_fields();
136
137 if st.fields().len() != own_st.fields().len() {
138 vortex_bail!(
139 "Cannot cast between structs with different number of fields: {} and {}",
140 own_st.fields().len(),
141 st.fields().len()
142 );
143 }
144
145 if let Some(fs) = self.field_values() {
146 let fields = fs
147 .iter()
148 .enumerate()
149 .map(|(i, f)| {
150 Scalar {
151 dtype: own_st
152 .field_by_index(i)
153 .vortex_expect("Iterating over scalar fields"),
154 value: f.clone(),
155 }
156 .cast(
157 &st.field_by_index(i)
158 .vortex_expect("Iterating over scalar fields"),
159 )
160 .map(|s| s.value)
161 })
162 .collect::<VortexResult<Vec<_>>>()?;
163 Ok(Scalar {
164 dtype: dtype.clone(),
165 value: ScalarValue(InnerScalarValue::List(fields.into())),
166 })
167 } else {
168 Ok(Scalar::null(dtype.clone()))
169 }
170 }
171
172 pub fn project(&self, projection: &[FieldName]) -> VortexResult<Scalar> {
173 let struct_dtype = self
174 .dtype
175 .as_struct()
176 .ok_or_else(|| vortex_err!("Not a struct dtype"))?;
177 let projected_dtype = struct_dtype.project(projection)?;
178 let new_fields = if let Some(fs) = self.field_values() {
179 ScalarValue(InnerScalarValue::List(
180 projection
181 .iter()
182 .map(|name| {
183 struct_dtype
184 .find(name)
185 .vortex_expect("DType has been successfully projected already")
186 })
187 .map(|i| fs[i].clone())
188 .collect(),
189 ))
190 } else {
191 ScalarValue(InnerScalarValue::Null)
192 };
193 Ok(Scalar::new(
194 DType::Struct(projected_dtype, self.dtype().nullability()),
195 new_fields,
196 ))
197 }
198}
199
200impl Scalar {
201 pub fn struct_(dtype: DType, children: Vec<Scalar>) -> Self {
202 Self {
203 dtype,
204 value: ScalarValue(InnerScalarValue::List(
205 children
206 .into_iter()
207 .map(|x| x.into_value())
208 .collect_vec()
209 .into(),
210 )),
211 }
212 }
213}
214
215impl<'a> TryFrom<&'a Scalar> for StructScalar<'a> {
216 type Error = VortexError;
217
218 fn try_from(value: &'a Scalar) -> Result<Self, Self::Error> {
219 Self::try_new(value.dtype(), &value.value)
220 }
221}
222
#[cfg(test)]
mod tests {
    use vortex_dtype::PType::I32;
    use vortex_dtype::{DType, Nullability, StructFields};

    use super::*;

    /// Returns `(dtype of "a", dtype of "b", nullable struct {a: i32, b: utf8})`.
    fn setup_types() -> (DType, DType, DType) {
        let a_dtype = DType::Primitive(I32, Nullability::NonNullable);
        let b_dtype = DType::Utf8(Nullability::NonNullable);

        let struct_dtype = DType::Struct(
            StructFields::new(
                vec!["a".into(), "b".into()].into(),
                vec![a_dtype.clone(), b_dtype.clone()],
            ),
            Nullability::Nullable,
        );

        (a_dtype, b_dtype, struct_dtype)
    }

    #[test]
    #[should_panic]
    fn test_struct_scalar_null() {
        let (_, _, struct_dtype) = setup_types();

        // Reading a field out of a null struct scalar must panic.
        Scalar::null(struct_dtype)
            .as_struct()
            .field_by_idx(0)
            .unwrap();
    }

    #[test]
    fn test_struct_scalar_non_null() {
        let (a_dtype, b_dtype, struct_dtype) = setup_types();

        let scalar = Scalar::struct_(
            struct_dtype,
            vec![
                Scalar::primitive::<i32>(1, Nullability::NonNullable),
                Scalar::utf8("hello", Nullability::NonNullable),
            ],
        );

        // The expected values are nullable on purpose: scalar equality is expected to ignore
        // nullability, while the returned dtype must be the non-nullable field dtype.
        let expected = [
            (Scalar::primitive::<i32>(1, Nullability::Nullable), &a_dtype),
            (Scalar::utf8("hello", Nullability::Nullable), &b_dtype),
        ];

        let view = scalar.as_struct();
        for (idx, (expected_value, expected_dtype)) in expected.iter().enumerate() {
            let field = view
                .field_by_idx(idx)
                .expect("field index is within the struct");
            assert_eq!(&field, expected_value);
            assert_eq!(field.dtype(), *expected_dtype);
        }
    }
}