1use std::fmt::{Display, Formatter};
5use std::hash::Hash;
6use std::sync::Arc;
7
8use vortex_dtype::datetime::{TemporalMetadata, is_temporal_ext_type};
9use vortex_dtype::{DType, ExtDType};
10use vortex_error::{VortexError, VortexResult, vortex_bail};
11
12use crate::{Scalar, ScalarValue};
13
/// A borrowed view over a scalar whose dtype is [`DType::Extension`].
///
/// It pairs the extension dtype with the raw [`ScalarValue`]. The value is interpreted
/// through the extension's storage dtype; see [`ExtScalar::storage`].
#[derive(Debug)]
pub struct ExtScalar<'a> {
    /// The extension dtype that describes this scalar.
    ext_dtype: &'a ExtDType,
    /// The raw value, which is read using `ext_dtype`'s storage dtype.
    value: &'a ScalarValue,
}
22
23impl Display for ExtScalar<'_> {
24 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
25 if is_temporal_ext_type(self.ext_dtype().id()) {
27 let metadata =
28 TemporalMetadata::try_from(self.ext_dtype()).map_err(|_| std::fmt::Error)?;
29
30 let maybe_timestamp = self
31 .storage()
32 .as_primitive()
33 .as_::<i64>()
34 .and_then(|maybe_timestamp| {
35 maybe_timestamp.map(|v| metadata.to_jiff(v)).transpose()
36 })
37 .map_err(|_| std::fmt::Error)?;
38
39 match maybe_timestamp {
40 None => write!(f, "null"),
41 Some(v) => write!(f, "{v}"),
42 }
43 } else {
44 write!(f, "{}({})", self.ext_dtype().id(), self.storage())
45 }
46 }
47}
48
49impl PartialEq for ExtScalar<'_> {
50 fn eq(&self, other: &Self) -> bool {
51 self.ext_dtype.eq_ignore_nullability(other.ext_dtype) && self.storage() == other.storage()
52 }
53}
54
55impl Eq for ExtScalar<'_> {}
56
57impl PartialOrd for ExtScalar<'_> {
59 fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
60 if !self.ext_dtype.eq_ignore_nullability(other.ext_dtype) {
61 return None;
62 }
63 self.storage().partial_cmp(&other.storage())
64 }
65}
66
67impl Hash for ExtScalar<'_> {
68 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
69 self.ext_dtype.hash(state);
70 self.storage().hash(state);
71 }
72}
73
74impl<'a> ExtScalar<'a> {
75 pub fn try_new(dtype: &'a DType, value: &'a ScalarValue) -> VortexResult<Self> {
81 let DType::Extension(ext_dtype) = dtype else {
82 vortex_bail!("Expected extension scalar, found {}", dtype)
83 };
84
85 Ok(Self { ext_dtype, value })
86 }
87
88 pub fn storage(&self) -> Scalar {
90 Scalar::new(self.ext_dtype.storage_dtype().clone(), self.value.clone())
91 }
92
93 pub fn ext_dtype(&self) -> &'a ExtDType {
95 self.ext_dtype
96 }
97
98 pub(crate) fn cast(&self, dtype: &DType) -> VortexResult<Scalar> {
99 if self.value.is_null() && !dtype.is_nullable() {
100 vortex_bail!(
101 "cannot cast extension dtype with id {} and storage type {} to {}",
102 self.ext_dtype.id(),
103 self.ext_dtype.storage_dtype(),
104 dtype
105 );
106 }
107
108 if self.ext_dtype.storage_dtype().eq_ignore_nullability(dtype) {
109 return Ok(Scalar::new(dtype.clone(), self.value.clone()));
111 }
112
113 if let DType::Extension(ext_dtype) = dtype
114 && self.ext_dtype.eq_ignore_nullability(ext_dtype)
115 {
116 return Ok(Scalar::new(dtype.clone(), self.value.clone()));
117 }
118
119 vortex_bail!(
120 "cannot cast extension dtype with id {} and storage type {} to {}",
121 self.ext_dtype.id(),
122 self.ext_dtype.storage_dtype(),
123 dtype
124 );
125 }
126}
127
128impl<'a> TryFrom<&'a Scalar> for ExtScalar<'a> {
129 type Error = VortexError;
130
131 fn try_from(value: &'a Scalar) -> Result<Self, Self::Error> {
132 ExtScalar::try_new(value.dtype(), &value.value)
133 }
134}
135
136impl Scalar {
137 pub fn extension(ext_dtype: Arc<ExtDType>, value: Scalar) -> Self {
139 Self {
142 dtype: DType::Extension(ext_dtype),
143 value: value.value().clone(),
144 }
145 }
146}
147
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use vortex_dtype::{DType, ExtDType, ExtID, ExtMetadata, Nullability, PType};

    use crate::{ExtScalar, InnerScalarValue, Scalar, ScalarValue};

    /// Extension scalars sharing one extension dtype compare by their storage values.
    #[test]
    fn test_ext_scalar_equality() {
        let ext_dtype = Arc::new(ExtDType::new(
            ExtID::new("test_ext".into()),
            Arc::new(DType::Primitive(PType::I32, Nullability::NonNullable)),
            None,
        ));

        let scalar1 = Scalar::extension(
            ext_dtype.clone(),
            Scalar::primitive(42i32, Nullability::NonNullable),
        );
        let scalar2 = Scalar::extension(
            ext_dtype.clone(),
            Scalar::primitive(42i32, Nullability::NonNullable),
        );
        let scalar3 = Scalar::extension(
            ext_dtype,
            Scalar::primitive(43i32, Nullability::NonNullable),
        );

        let ext1 = ExtScalar::try_from(&scalar1).unwrap();
        let ext2 = ExtScalar::try_from(&scalar2).unwrap();
        let ext3 = ExtScalar::try_from(&scalar3).unwrap();

        assert_eq!(ext1, ext2);
        assert_ne!(ext1, ext3);
    }

    /// Ordering of same-typed extension scalars follows their storage values.
    #[test]
    fn test_ext_scalar_partial_ord() {
        let ext_dtype = Arc::new(ExtDType::new(
            ExtID::new("test_ext".into()),
            Arc::new(DType::Primitive(PType::I32, Nullability::NonNullable)),
            None,
        ));

        let scalar1 = Scalar::extension(
            ext_dtype.clone(),
            Scalar::primitive(10i32, Nullability::NonNullable),
        );
        let scalar2 = Scalar::extension(
            ext_dtype,
            Scalar::primitive(20i32, Nullability::NonNullable),
        );

        let ext1 = ExtScalar::try_from(&scalar1).unwrap();
        let ext2 = ExtScalar::try_from(&scalar2).unwrap();

        assert!(ext1 < ext2);
        assert!(ext2 > ext1);
    }

    /// Scalars of different extension ids are not comparable: `partial_cmp` yields `None`.
    #[test]
    fn test_ext_scalar_partial_ord_different_types() {
        let ext_dtype1 = Arc::new(ExtDType::new(
            ExtID::new("type1".into()),
            Arc::new(DType::Primitive(PType::I32, Nullability::NonNullable)),
            None,
        ));
        let ext_dtype2 = Arc::new(ExtDType::new(
            ExtID::new("type2".into()),
            Arc::new(DType::Primitive(PType::I32, Nullability::NonNullable)),
            None,
        ));

        let scalar1 = Scalar::extension(
            ext_dtype1,
            Scalar::primitive(10i32, Nullability::NonNullable),
        );
        let scalar2 = Scalar::extension(
            ext_dtype2,
            Scalar::primitive(20i32, Nullability::NonNullable),
        );

        let ext1 = ExtScalar::try_from(&scalar1).unwrap();
        let ext2 = ExtScalar::try_from(&scalar2).unwrap();

        assert_eq!(ext1.partial_cmp(&ext2), None);
    }

    /// Equal extension scalars collapse to one `HashSet` entry; a different value adds a
    /// second entry.
    #[test]
    fn test_ext_scalar_hash() {
        use vortex_utils::aliases::hash_set::HashSet;

        let ext_dtype = Arc::new(ExtDType::new(
            ExtID::new("test_ext".into()),
            Arc::new(DType::Primitive(PType::I32, Nullability::NonNullable)),
            None,
        ));

        let scalar1 = Scalar::extension(
            ext_dtype.clone(),
            Scalar::primitive(42i32, Nullability::NonNullable),
        );
        let scalar2 = Scalar::extension(
            ext_dtype,
            Scalar::primitive(42i32, Nullability::NonNullable),
        );

        let mut set = HashSet::new();
        set.insert(scalar2);
        set.insert(scalar1);

        assert_eq!(set.len(), 1);

        // A freshly constructed but identical dtype, holding a different storage value.
        let ext_dtype2 = Arc::new(ExtDType::new(
            ExtID::new("test_ext".into()),
            Arc::new(DType::Primitive(PType::I32, Nullability::NonNullable)),
            None,
        ));
        let scalar3 = Scalar::extension(
            ext_dtype2,
            Scalar::primitive(43i32, Nullability::NonNullable),
        );
        set.insert(scalar3);
        assert_eq!(set.len(), 2);
    }

    /// `storage()` returns the storage scalar the extension scalar was built from.
    #[test]
    fn test_ext_scalar_storage() {
        let ext_dtype = Arc::new(ExtDType::new(
            ExtID::new("test_ext".into()),
            Arc::new(DType::Primitive(PType::I32, Nullability::NonNullable)),
            None,
        ));

        let storage_scalar = Scalar::primitive(42i32, Nullability::NonNullable);
        let ext_scalar = Scalar::extension(ext_dtype, storage_scalar.clone());

        let ext = ExtScalar::try_from(&ext_scalar).unwrap();
        assert_eq!(ext.storage(), storage_scalar);
    }

    /// `ext_dtype()` exposes the dtype (and thus the id) the scalar was built with.
    #[test]
    fn test_ext_scalar_ext_dtype() {
        let ext_id = ExtID::new("test_ext".into());
        let ext_dtype = Arc::new(ExtDType::new(
            ext_id.clone(),
            Arc::new(DType::Primitive(PType::I32, Nullability::NonNullable)),
            None,
        ));

        let scalar = Scalar::extension(
            ext_dtype.clone(),
            Scalar::primitive(42i32, Nullability::NonNullable),
        );

        let ext = ExtScalar::try_from(&scalar).unwrap();
        assert_eq!(ext.ext_dtype().id(), &ext_id);
        assert_eq!(ext.ext_dtype(), ext_dtype.as_ref());
    }

    /// Casting to the storage dtype works for both the exact and the nullable variant,
    /// and the value is preserved.
    #[test]
    fn test_ext_scalar_cast_to_storage() {
        let ext_dtype = Arc::new(ExtDType::new(
            ExtID::new("test_ext".into()),
            Arc::new(DType::Primitive(PType::I32, Nullability::NonNullable)),
            None,
        ));

        let scalar = Scalar::extension(
            ext_dtype,
            Scalar::primitive(42i32, Nullability::NonNullable),
        );

        let ext = ExtScalar::try_from(&scalar).unwrap();

        let casted = ext
            .cast(&DType::Primitive(PType::I32, Nullability::NonNullable))
            .unwrap();
        assert_eq!(
            casted.dtype(),
            &DType::Primitive(PType::I32, Nullability::NonNullable)
        );
        assert_eq!(casted.as_primitive().typed_value::<i32>(), Some(42));

        let casted_nullable = ext
            .cast(&DType::Primitive(PType::I32, Nullability::Nullable))
            .unwrap();
        assert_eq!(
            casted_nullable.dtype(),
            &DType::Primitive(PType::I32, Nullability::Nullable)
        );
        assert_eq!(
            casted_nullable.as_primitive().typed_value::<i32>(),
            Some(42)
        );
    }

    /// Casting to the same extension dtype, or its nullable variant, relabels the dtype.
    #[test]
    fn test_ext_scalar_cast_to_self() {
        let ext_dtype = Arc::new(ExtDType::new(
            ExtID::new("test_ext".into()),
            Arc::new(DType::Primitive(PType::I32, Nullability::NonNullable)),
            None,
        ));

        let scalar = Scalar::extension(
            ext_dtype.clone(),
            Scalar::primitive(42i32, Nullability::NonNullable),
        );

        let ext = ExtScalar::try_from(&scalar).unwrap();

        let casted = ext.cast(&DType::Extension(ext_dtype.clone())).unwrap();
        assert_eq!(casted.dtype(), &DType::Extension(ext_dtype.clone()));

        let nullable_ext = DType::Extension(ext_dtype).as_nullable();
        let casted_nullable = ext.cast(&nullable_ext).unwrap();
        assert_eq!(casted_nullable.dtype(), &nullable_ext);
    }

    /// Casting to an unrelated dtype (here `Utf8`) is rejected.
    #[test]
    fn test_ext_scalar_cast_incompatible() {
        let ext_dtype = Arc::new(ExtDType::new(
            ExtID::new("test_ext".into()),
            Arc::new(DType::Primitive(PType::I32, Nullability::NonNullable)),
            None,
        ));

        let scalar = Scalar::extension(
            ext_dtype,
            Scalar::primitive(42i32, Nullability::NonNullable),
        );

        let ext = ExtScalar::try_from(&scalar).unwrap();

        let result = ext.cast(&DType::Utf8(Nullability::NonNullable));
        assert!(result.is_err());
    }

    /// A null extension value cannot be cast into a non-nullable dtype, even when that
    /// dtype matches the storage type.
    #[test]
    fn test_ext_scalar_cast_null_to_non_nullable() {
        let ext_dtype = Arc::new(ExtDType::new(
            ExtID::new("test_ext".into()),
            Arc::new(DType::Primitive(PType::I32, Nullability::Nullable)),
            None,
        ));

        let scalar = Scalar::extension(
            ext_dtype,
            Scalar::null(DType::Primitive(PType::I32, Nullability::Nullable)),
        );

        let ext = ExtScalar::try_from(&scalar).unwrap();

        let result = ext.cast(&DType::Primitive(PType::I32, Nullability::NonNullable));
        assert!(result.is_err());
    }

    /// `try_new` rejects dtypes that are not extension dtypes.
    #[test]
    fn test_ext_scalar_try_new_non_extension() {
        let dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
        let value = ScalarValue(InnerScalarValue::Primitive(crate::PValue::I32(42)));

        let result = ExtScalar::try_new(&dtype, &value);
        assert!(result.is_err());
    }

    /// Extension metadata is carried through to `ext_dtype()`.
    #[test]
    fn test_ext_scalar_with_metadata() {
        let metadata = ExtMetadata::new(vec![1u8, 2, 3].into());
        let ext_dtype = Arc::new(ExtDType::new(
            ExtID::new("test_ext_with_meta".into()),
            Arc::new(DType::Primitive(PType::I32, Nullability::NonNullable)),
            Some(metadata),
        ));

        let scalar = Scalar::extension(
            ext_dtype.clone(),
            Scalar::primitive(42i32, Nullability::NonNullable),
        );

        let ext = ExtScalar::try_from(&scalar).unwrap();
        assert_eq!(ext.ext_dtype(), ext_dtype.as_ref());
        assert!(ext.ext_dtype().metadata().is_some());
    }

    /// Equality ignores the nullability of the extension's storage dtype.
    #[test]
    fn test_ext_scalar_equality_ignores_nullability() {
        let ext_dtype_non_null = Arc::new(ExtDType::new(
            ExtID::new("test_ext".into()),
            Arc::new(DType::Primitive(PType::I32, Nullability::NonNullable)),
            None,
        ));
        let ext_dtype_nullable = Arc::new(ExtDType::new(
            ExtID::new("test_ext".into()),
            Arc::new(DType::Primitive(PType::I32, Nullability::Nullable)),
            None,
        ));

        let scalar1 = Scalar::extension(
            ext_dtype_non_null,
            Scalar::primitive(42i32, Nullability::NonNullable),
        );
        let scalar2 = Scalar::extension(
            ext_dtype_nullable,
            Scalar::primitive(42i32, Nullability::Nullable),
        );

        let ext1 = ExtScalar::try_from(&scalar1).unwrap();
        let ext2 = ExtScalar::try_from(&scalar2).unwrap();

        assert_eq!(ext1, ext2);
    }
}