1use std::fmt::{Display, Formatter};
5use std::hash::Hash;
6use std::sync::Arc;
7
8use vortex_dtype::datetime::{TemporalMetadata, is_temporal_ext_type};
9use vortex_dtype::{DType, ExtDType};
10use vortex_error::{VortexError, VortexResult, vortex_bail};
11
12use crate::{Scalar, ScalarValue};
13
14#[derive(Debug)]
18pub struct ExtScalar<'a> {
19 ext_dtype: &'a ExtDType,
20 value: &'a ScalarValue,
21}
22
23impl Display for ExtScalar<'_> {
24 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
25 if is_temporal_ext_type(self.ext_dtype().id()) {
27 let metadata =
28 TemporalMetadata::try_from(self.ext_dtype()).map_err(|_| std::fmt::Error)?;
29
30 let maybe_timestamp = self
31 .storage()
32 .as_primitive()
33 .as_::<i64>()
34 .map(|maybe_timestamp| metadata.to_jiff(maybe_timestamp))
35 .transpose()
36 .map_err(|_| std::fmt::Error)?;
37
38 match maybe_timestamp {
39 None => write!(f, "null"),
40 Some(v) => write!(f, "{v}"),
41 }
42 } else {
43 write!(f, "{}({})", self.ext_dtype().id(), self.storage())
44 }
45 }
46}
47
48impl PartialEq for ExtScalar<'_> {
49 fn eq(&self, other: &Self) -> bool {
50 self.ext_dtype.eq_ignore_nullability(other.ext_dtype) && self.storage() == other.storage()
51 }
52}
53
54impl Eq for ExtScalar<'_> {}
55
56impl PartialOrd for ExtScalar<'_> {
58 fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
59 if !self.ext_dtype.eq_ignore_nullability(other.ext_dtype) {
60 return None;
61 }
62 self.storage().partial_cmp(&other.storage())
63 }
64}
65
66impl Hash for ExtScalar<'_> {
67 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
68 self.ext_dtype.hash(state);
69 self.storage().hash(state);
70 }
71}
72
73impl<'a> ExtScalar<'a> {
74 pub fn try_new(dtype: &'a DType, value: &'a ScalarValue) -> VortexResult<Self> {
80 let DType::Extension(ext_dtype) = dtype else {
81 vortex_bail!("Expected extension scalar, found {}", dtype)
82 };
83
84 Ok(Self { ext_dtype, value })
85 }
86
87 pub fn storage(&self) -> Scalar {
89 Scalar::new(self.ext_dtype.storage_dtype().clone(), self.value.clone())
90 }
91
92 pub fn ext_dtype(&self) -> &'a ExtDType {
94 self.ext_dtype
95 }
96
97 pub(crate) fn cast(&self, dtype: &DType) -> VortexResult<Scalar> {
98 if self.value.is_null() && !dtype.is_nullable() {
99 vortex_bail!(
100 "cannot cast extension dtype with id {} and storage type {} to {}",
101 self.ext_dtype.id(),
102 self.ext_dtype.storage_dtype(),
103 dtype
104 );
105 }
106
107 if self.ext_dtype.storage_dtype().eq_ignore_nullability(dtype) {
108 return Ok(Scalar::new(dtype.clone(), self.value.clone()));
110 }
111
112 if let DType::Extension(ext_dtype) = dtype
113 && self.ext_dtype.eq_ignore_nullability(ext_dtype)
114 {
115 return Ok(Scalar::new(dtype.clone(), self.value.clone()));
116 }
117
118 vortex_bail!(
119 "cannot cast extension dtype with id {} and storage type {} to {}",
120 self.ext_dtype.id(),
121 self.ext_dtype.storage_dtype(),
122 dtype
123 );
124 }
125}
126
127impl<'a> TryFrom<&'a Scalar> for ExtScalar<'a> {
128 type Error = VortexError;
129
130 fn try_from(scalar: &'a Scalar) -> Result<Self, Self::Error> {
131 ExtScalar::try_new(scalar.dtype(), scalar.value())
132 }
133}
134
135impl Scalar {
136 pub fn extension(ext_dtype: Arc<ExtDType>, value: Scalar) -> Self {
138 Self::new(DType::Extension(ext_dtype), value.value().clone())
141 }
142}
143
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use vortex_dtype::{DType, ExtDType, ExtID, ExtMetadata, Nullability, PType};

    use crate::{ExtScalar, InnerScalarValue, Scalar, ScalarValue};

    /// Builds an extension dtype named `id` with `i32` storage and no metadata.
    fn i32_ext_dtype(id: &str, storage_nullability: Nullability) -> Arc<ExtDType> {
        Arc::new(ExtDType::new(
            ExtID::new(id.into()),
            Arc::new(DType::Primitive(PType::I32, storage_nullability)),
            None,
        ))
    }

    /// Wraps a non-nullable `i32` value in the given extension dtype.
    fn ext_i32(ext_dtype: Arc<ExtDType>, v: i32) -> Scalar {
        Scalar::extension(ext_dtype, Scalar::primitive(v, Nullability::NonNullable))
    }

    /// Borrows `scalar` as an `ExtScalar`, panicking if it is not an extension scalar.
    fn view(scalar: &Scalar) -> ExtScalar<'_> {
        ExtScalar::try_from(scalar).unwrap()
    }

    #[test]
    fn test_ext_scalar_equality() {
        let ext_dtype = i32_ext_dtype("test_ext", Nullability::NonNullable);

        let first = ext_i32(ext_dtype.clone(), 42);
        let same = ext_i32(ext_dtype.clone(), 42);
        let other = ext_i32(ext_dtype, 43);

        assert_eq!(view(&first), view(&same));
        assert_ne!(view(&first), view(&other));
    }

    #[test]
    fn test_ext_scalar_partial_ord() {
        let ext_dtype = i32_ext_dtype("test_ext", Nullability::NonNullable);

        let low = ext_i32(ext_dtype.clone(), 10);
        let high = ext_i32(ext_dtype, 20);

        assert!(view(&low) < view(&high));
        assert!(view(&high) > view(&low));
    }

    #[test]
    fn test_ext_scalar_partial_ord_different_types() {
        let left = ext_i32(i32_ext_dtype("type1", Nullability::NonNullable), 10);
        let right = ext_i32(i32_ext_dtype("type2", Nullability::NonNullable), 20);

        // Different extension ids are not comparable.
        assert_eq!(view(&left).partial_cmp(&view(&right)), None);
    }

    #[test]
    fn test_ext_scalar_hash() {
        use vortex_utils::aliases::hash_set::HashSet;

        let ext_dtype = i32_ext_dtype("test_ext", Nullability::NonNullable);

        let mut set = HashSet::new();
        set.insert(ext_i32(ext_dtype.clone(), 42));
        set.insert(ext_i32(ext_dtype, 42));
        assert_eq!(set.len(), 1);

        // A separately built but identical dtype with a different value is a new entry.
        set.insert(ext_i32(i32_ext_dtype("test_ext", Nullability::NonNullable), 43));
        assert_eq!(set.len(), 2);
    }

    #[test]
    fn test_ext_scalar_storage() {
        let storage_scalar = Scalar::primitive(42i32, Nullability::NonNullable);
        let ext_scalar = Scalar::extension(
            i32_ext_dtype("test_ext", Nullability::NonNullable),
            storage_scalar.clone(),
        );

        assert_eq!(view(&ext_scalar).storage(), storage_scalar);
    }

    #[test]
    fn test_ext_scalar_ext_dtype() {
        let ext_dtype = i32_ext_dtype("test_ext", Nullability::NonNullable);
        let scalar = ext_i32(ext_dtype.clone(), 42);

        let ext = view(&scalar);
        assert_eq!(ext.ext_dtype().id(), &ExtID::new("test_ext".into()));
        assert_eq!(ext.ext_dtype(), ext_dtype.as_ref());
    }

    #[test]
    fn test_ext_scalar_cast_to_storage() {
        let scalar = ext_i32(i32_ext_dtype("test_ext", Nullability::NonNullable), 42);
        let ext = view(&scalar);

        // Unwrapping to the storage type works for both nullabilities.
        for nullability in [Nullability::NonNullable, Nullability::Nullable] {
            let target = DType::Primitive(PType::I32, nullability);
            let casted = ext.cast(&target).unwrap();
            assert_eq!(casted.dtype(), &target);
            assert_eq!(casted.as_primitive().typed_value::<i32>(), Some(42));
        }
    }

    #[test]
    fn test_ext_scalar_cast_to_self() {
        let ext_dtype = i32_ext_dtype("test_ext", Nullability::NonNullable);
        let scalar = ext_i32(ext_dtype.clone(), 42);
        let ext = view(&scalar);

        let same = DType::Extension(ext_dtype);
        assert_eq!(ext.cast(&same).unwrap().dtype(), &same);

        let nullable = same.as_nullable();
        assert_eq!(ext.cast(&nullable).unwrap().dtype(), &nullable);
    }

    #[test]
    fn test_ext_scalar_cast_incompatible() {
        let scalar = ext_i32(i32_ext_dtype("test_ext", Nullability::NonNullable), 42);

        assert!(view(&scalar)
            .cast(&DType::Utf8(Nullability::NonNullable))
            .is_err());
    }

    #[test]
    fn test_ext_scalar_cast_null_to_non_nullable() {
        let scalar = Scalar::extension(
            i32_ext_dtype("test_ext", Nullability::Nullable),
            Scalar::null(DType::Primitive(PType::I32, Nullability::Nullable)),
        );

        let target = DType::Primitive(PType::I32, Nullability::NonNullable);
        assert!(view(&scalar).cast(&target).is_err());
    }

    #[test]
    fn test_ext_scalar_try_new_non_extension() {
        let dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
        let value = ScalarValue(InnerScalarValue::Primitive(crate::PValue::I32(42)));

        assert!(ExtScalar::try_new(&dtype, &value).is_err());
    }

    #[test]
    fn test_ext_scalar_with_metadata() {
        let ext_dtype = Arc::new(ExtDType::new(
            ExtID::new("test_ext_with_meta".into()),
            Arc::new(DType::Primitive(PType::I32, Nullability::NonNullable)),
            Some(ExtMetadata::new(vec![1u8, 2, 3].into())),
        ));
        let scalar = ext_i32(ext_dtype.clone(), 42);

        let ext = view(&scalar);
        assert_eq!(ext.ext_dtype(), ext_dtype.as_ref());
        assert!(ext.ext_dtype().metadata().is_some());
    }

    #[test]
    fn test_ext_scalar_equality_ignores_nullability() {
        let non_nullable = ext_i32(i32_ext_dtype("test_ext", Nullability::NonNullable), 42);
        let nullable = Scalar::extension(
            i32_ext_dtype("test_ext", Nullability::Nullable),
            Scalar::primitive(42i32, Nullability::Nullable),
        );

        assert_eq!(view(&non_nullable), view(&nullable));
    }
}