vortex_scalar/
extension.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::{Display, Formatter};
5use std::hash::Hash;
6use std::sync::Arc;
7
8use vortex_dtype::datetime::{TemporalMetadata, is_temporal_ext_type};
9use vortex_dtype::{DType, ExtDType};
10use vortex_error::{VortexError, VortexResult, vortex_bail};
11
12use crate::{Scalar, ScalarValue};
13
14/// A scalar value representing an extension type.
15///
16/// Extension types allow wrapping a storage type with custom semantics.
17#[derive(Debug)]
18pub struct ExtScalar<'a> {
19    ext_dtype: &'a ExtDType,
20    value: &'a ScalarValue,
21}
22
23impl Display for ExtScalar<'_> {
24    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
25        // Specialized handling for date/time/timestamp builtin extension types.
26        if is_temporal_ext_type(self.ext_dtype().id()) {
27            let metadata =
28                TemporalMetadata::try_from(self.ext_dtype()).map_err(|_| std::fmt::Error)?;
29
30            let maybe_timestamp = self
31                .storage()
32                .as_primitive()
33                .as_::<i64>()
34                .map(|maybe_timestamp| metadata.to_jiff(maybe_timestamp))
35                .transpose()
36                .map_err(|_| std::fmt::Error)?;
37
38            match maybe_timestamp {
39                None => write!(f, "null"),
40                Some(v) => write!(f, "{v}"),
41            }
42        } else {
43            write!(f, "{}({})", self.ext_dtype().id(), self.storage())
44        }
45    }
46}
47
48impl PartialEq for ExtScalar<'_> {
49    fn eq(&self, other: &Self) -> bool {
50        self.ext_dtype.eq_ignore_nullability(other.ext_dtype) && self.storage() == other.storage()
51    }
52}
53
54impl Eq for ExtScalar<'_> {}
55
56// Ord is not implemented since it's undefined for different Extension DTypes
57impl PartialOrd for ExtScalar<'_> {
58    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
59        if !self.ext_dtype.eq_ignore_nullability(other.ext_dtype) {
60            return None;
61        }
62        self.storage().partial_cmp(&other.storage())
63    }
64}
65
66impl Hash for ExtScalar<'_> {
67    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
68        self.ext_dtype.hash(state);
69        self.storage().hash(state);
70    }
71}
72
73impl<'a> ExtScalar<'a> {
74    /// Creates a new extension scalar from a data type and scalar value.
75    ///
76    /// # Errors
77    ///
78    /// Returns an error if the data type is not an extension type.
79    pub fn try_new(dtype: &'a DType, value: &'a ScalarValue) -> VortexResult<Self> {
80        let DType::Extension(ext_dtype) = dtype else {
81            vortex_bail!("Expected extension scalar, found {}", dtype)
82        };
83
84        Ok(Self { ext_dtype, value })
85    }
86
87    /// Returns the storage scalar of the extension scalar.
88    pub fn storage(&self) -> Scalar {
89        Scalar::new(self.ext_dtype.storage_dtype().clone(), self.value.clone())
90    }
91
92    /// Returns the extension data type.
93    pub fn ext_dtype(&self) -> &'a ExtDType {
94        self.ext_dtype
95    }
96
97    pub(crate) fn cast(&self, dtype: &DType) -> VortexResult<Scalar> {
98        if self.value.is_null() && !dtype.is_nullable() {
99            vortex_bail!(
100                "cannot cast extension dtype with id {} and storage type {} to {}",
101                self.ext_dtype.id(),
102                self.ext_dtype.storage_dtype(),
103                dtype
104            );
105        }
106
107        if self.ext_dtype.storage_dtype().eq_ignore_nullability(dtype) {
108            // Casting from an extension type to the underlying storage type is OK.
109            return Ok(Scalar::new(dtype.clone(), self.value.clone()));
110        }
111
112        if let DType::Extension(ext_dtype) = dtype
113            && self.ext_dtype.eq_ignore_nullability(ext_dtype)
114        {
115            return Ok(Scalar::new(dtype.clone(), self.value.clone()));
116        }
117
118        vortex_bail!(
119            "cannot cast extension dtype with id {} and storage type {} to {}",
120            self.ext_dtype.id(),
121            self.ext_dtype.storage_dtype(),
122            dtype
123        );
124    }
125}
126
127impl<'a> TryFrom<&'a Scalar> for ExtScalar<'a> {
128    type Error = VortexError;
129
130    fn try_from(scalar: &'a Scalar) -> Result<Self, Self::Error> {
131        ExtScalar::try_new(scalar.dtype(), scalar.value())
132    }
133}
134
135impl Scalar {
136    /// Creates a new extension scalar wrapping the given storage value.
137    pub fn extension(ext_dtype: Arc<ExtDType>, value: Scalar) -> Self {
138        // TODO(joe): enable once we use rust duckdb
139        // assert_eq!(ext_dtype.storage_dtype(), value.dtype());
140        Self::new(DType::Extension(ext_dtype), value.value().clone())
141    }
142}
143
#[cfg(test)]
mod tests {
    use std::hash::{DefaultHasher, Hash, Hasher};
    use std::sync::Arc;

    use vortex_dtype::{DType, ExtDType, ExtID, ExtMetadata, Nullability, PType};

    use crate::{ExtScalar, InnerScalarValue, Scalar, ScalarValue};

    /// Builds a metadata-free extension dtype named `id` whose storage type is `i32` with the
    /// given nullability.
    fn i32_ext_dtype(id: &str, nullability: Nullability) -> Arc<ExtDType> {
        Arc::new(ExtDType::new(
            ExtID::new(id.into()),
            Arc::new(DType::Primitive(PType::I32, nullability)),
            None,
        ))
    }

    /// Hashes `value` with a deterministic std hasher.
    fn hash_of<T: Hash>(value: &T) -> u64 {
        let mut hasher = DefaultHasher::new();
        value.hash(&mut hasher);
        hasher.finish()
    }

    #[test]
    fn test_ext_scalar_equality() {
        let ext_dtype = i32_ext_dtype("test_ext", Nullability::NonNullable);

        let scalar1 = Scalar::extension(
            ext_dtype.clone(),
            Scalar::primitive(42i32, Nullability::NonNullable),
        );
        let scalar2 = Scalar::extension(
            ext_dtype.clone(),
            Scalar::primitive(42i32, Nullability::NonNullable),
        );
        let scalar3 = Scalar::extension(
            ext_dtype,
            Scalar::primitive(43i32, Nullability::NonNullable),
        );

        let ext1 = ExtScalar::try_from(&scalar1).unwrap();
        let ext2 = ExtScalar::try_from(&scalar2).unwrap();
        let ext3 = ExtScalar::try_from(&scalar3).unwrap();

        assert_eq!(ext1, ext2);
        assert_ne!(ext1, ext3);
    }

    #[test]
    fn test_ext_scalar_partial_ord() {
        let ext_dtype = i32_ext_dtype("test_ext", Nullability::NonNullable);

        let scalar1 = Scalar::extension(
            ext_dtype.clone(),
            Scalar::primitive(10i32, Nullability::NonNullable),
        );
        let scalar2 = Scalar::extension(
            ext_dtype,
            Scalar::primitive(20i32, Nullability::NonNullable),
        );

        let ext1 = ExtScalar::try_from(&scalar1).unwrap();
        let ext2 = ExtScalar::try_from(&scalar2).unwrap();

        assert!(ext1 < ext2);
        assert!(ext2 > ext1);
    }

    #[test]
    fn test_ext_scalar_partial_ord_different_types() {
        let ext_dtype1 = i32_ext_dtype("type1", Nullability::NonNullable);
        let ext_dtype2 = i32_ext_dtype("type2", Nullability::NonNullable);

        let scalar1 = Scalar::extension(
            ext_dtype1,
            Scalar::primitive(10i32, Nullability::NonNullable),
        );
        let scalar2 = Scalar::extension(
            ext_dtype2,
            Scalar::primitive(20i32, Nullability::NonNullable),
        );

        let ext1 = ExtScalar::try_from(&scalar1).unwrap();
        let ext2 = ExtScalar::try_from(&scalar2).unwrap();

        // Different extension types are neither comparable nor equal.
        assert_eq!(ext1.partial_cmp(&ext2), None);
        assert_ne!(ext1, ext2);
    }

    #[test]
    fn test_ext_scalar_hash() {
        use vortex_utils::aliases::hash_set::HashSet;

        let ext_dtype = i32_ext_dtype("test_ext", Nullability::NonNullable);

        let scalar1 = Scalar::extension(
            ext_dtype.clone(),
            Scalar::primitive(42i32, Nullability::NonNullable),
        );
        let scalar2 = Scalar::extension(
            ext_dtype.clone(),
            Scalar::primitive(42i32, Nullability::NonNullable),
        );

        let mut set = HashSet::new();
        set.insert(scalar2);
        set.insert(scalar1);

        // Equal scalars collapse into a single set entry.
        assert_eq!(set.len(), 1);

        // A scalar with a different value is a distinct set entry.
        let scalar3 = Scalar::extension(
            ext_dtype,
            Scalar::primitive(43i32, Nullability::NonNullable),
        );
        set.insert(scalar3);
        assert_eq!(set.len(), 2);
    }

    #[test]
    fn test_ext_scalar_hash_matches_equality() {
        let ext_dtype = i32_ext_dtype("test_ext", Nullability::NonNullable);

        let scalar1 = Scalar::extension(
            ext_dtype.clone(),
            Scalar::primitive(42i32, Nullability::NonNullable),
        );
        let scalar2 = Scalar::extension(
            ext_dtype,
            Scalar::primitive(42i32, Nullability::NonNullable),
        );

        let ext1 = ExtScalar::try_from(&scalar1).unwrap();
        let ext2 = ExtScalar::try_from(&scalar2).unwrap();

        // `Hash` must agree with `Eq`: equal extension scalars hash identically.
        assert_eq!(ext1, ext2);
        assert_eq!(hash_of(&ext1), hash_of(&ext2));
    }

    #[test]
    fn test_ext_scalar_storage() {
        let ext_dtype = i32_ext_dtype("test_ext", Nullability::NonNullable);

        let storage_scalar = Scalar::primitive(42i32, Nullability::NonNullable);
        let ext_scalar = Scalar::extension(ext_dtype, storage_scalar.clone());

        let ext = ExtScalar::try_from(&ext_scalar).unwrap();
        assert_eq!(ext.storage(), storage_scalar);
    }

    #[test]
    fn test_ext_scalar_ext_dtype() {
        let ext_id = ExtID::new("test_ext".into());
        let ext_dtype = Arc::new(ExtDType::new(
            ext_id.clone(),
            Arc::new(DType::Primitive(PType::I32, Nullability::NonNullable)),
            None,
        ));

        let scalar = Scalar::extension(
            ext_dtype.clone(),
            Scalar::primitive(42i32, Nullability::NonNullable),
        );

        let ext = ExtScalar::try_from(&scalar).unwrap();
        assert_eq!(ext.ext_dtype().id(), &ext_id);
        assert_eq!(ext.ext_dtype(), ext_dtype.as_ref());
    }

    #[test]
    fn test_ext_scalar_cast_to_storage() {
        let ext_dtype = i32_ext_dtype("test_ext", Nullability::NonNullable);

        let scalar = Scalar::extension(
            ext_dtype,
            Scalar::primitive(42i32, Nullability::NonNullable),
        );

        let ext = ExtScalar::try_from(&scalar).unwrap();

        // Cast to storage type
        let casted = ext
            .cast(&DType::Primitive(PType::I32, Nullability::NonNullable))
            .unwrap();
        assert_eq!(
            casted.dtype(),
            &DType::Primitive(PType::I32, Nullability::NonNullable)
        );
        assert_eq!(casted.as_primitive().typed_value::<i32>(), Some(42));

        // Cast to nullable storage type
        let casted_nullable = ext
            .cast(&DType::Primitive(PType::I32, Nullability::Nullable))
            .unwrap();
        assert_eq!(
            casted_nullable.dtype(),
            &DType::Primitive(PType::I32, Nullability::Nullable)
        );
        assert_eq!(
            casted_nullable.as_primitive().typed_value::<i32>(),
            Some(42)
        );
    }

    #[test]
    fn test_ext_scalar_cast_to_self() {
        let ext_dtype = i32_ext_dtype("test_ext", Nullability::NonNullable);

        let scalar = Scalar::extension(
            ext_dtype.clone(),
            Scalar::primitive(42i32, Nullability::NonNullable),
        );

        let ext = ExtScalar::try_from(&scalar).unwrap();

        // Cast to same extension type
        let casted = ext.cast(&DType::Extension(ext_dtype.clone())).unwrap();
        assert_eq!(casted.dtype(), &DType::Extension(ext_dtype.clone()));

        // Cast to nullable version of same extension type
        let nullable_ext = DType::Extension(ext_dtype).as_nullable();
        let casted_nullable = ext.cast(&nullable_ext).unwrap();
        assert_eq!(casted_nullable.dtype(), &nullable_ext);
    }

    #[test]
    fn test_ext_scalar_cast_to_different_extension() {
        let source_dtype = i32_ext_dtype("type1", Nullability::NonNullable);
        let target_dtype = i32_ext_dtype("type2", Nullability::NonNullable);

        let scalar = Scalar::extension(
            source_dtype,
            Scalar::primitive(42i32, Nullability::NonNullable),
        );

        let ext = ExtScalar::try_from(&scalar).unwrap();

        // Extension types sharing a storage type but with different ids are not interchangeable.
        assert!(ext.cast(&DType::Extension(target_dtype)).is_err());
    }

    #[test]
    fn test_ext_scalar_cast_incompatible() {
        let ext_dtype = i32_ext_dtype("test_ext", Nullability::NonNullable);

        let scalar = Scalar::extension(
            ext_dtype,
            Scalar::primitive(42i32, Nullability::NonNullable),
        );

        let ext = ExtScalar::try_from(&scalar).unwrap();

        // Cast to incompatible type should fail
        let result = ext.cast(&DType::Utf8(Nullability::NonNullable));
        assert!(result.is_err());
    }

    #[test]
    fn test_ext_scalar_cast_null_to_non_nullable() {
        let ext_dtype = i32_ext_dtype("test_ext", Nullability::Nullable);

        let scalar = Scalar::extension(
            ext_dtype,
            Scalar::null(DType::Primitive(PType::I32, Nullability::Nullable)),
        );

        let ext = ExtScalar::try_from(&scalar).unwrap();

        // Cast null to non-nullable should fail
        let result = ext.cast(&DType::Primitive(PType::I32, Nullability::NonNullable));
        assert!(result.is_err());
    }

    #[test]
    fn test_ext_scalar_cast_null_to_nullable_storage() {
        let ext_dtype = i32_ext_dtype("test_ext", Nullability::Nullable);

        let scalar = Scalar::extension(
            ext_dtype,
            Scalar::null(DType::Primitive(PType::I32, Nullability::Nullable)),
        );

        let ext = ExtScalar::try_from(&scalar).unwrap();

        // A null value may be cast to a nullable storage type and stays null.
        let storage_dtype = DType::Primitive(PType::I32, Nullability::Nullable);
        let casted = ext.cast(&storage_dtype).unwrap();
        assert_eq!(casted.dtype(), &storage_dtype);
        assert!(casted.value().is_null());
    }

    #[test]
    fn test_ext_scalar_try_new_non_extension() {
        let dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
        let value = ScalarValue(InnerScalarValue::Primitive(crate::PValue::I32(42)));

        let result = ExtScalar::try_new(&dtype, &value);
        assert!(result.is_err());
    }

    #[test]
    fn test_ext_scalar_with_metadata() {
        let metadata = ExtMetadata::new(vec![1u8, 2, 3].into());
        let ext_dtype = Arc::new(ExtDType::new(
            ExtID::new("test_ext_with_meta".into()),
            Arc::new(DType::Primitive(PType::I32, Nullability::NonNullable)),
            Some(metadata),
        ));

        let scalar = Scalar::extension(
            ext_dtype.clone(),
            Scalar::primitive(42i32, Nullability::NonNullable),
        );

        let ext = ExtScalar::try_from(&scalar).unwrap();
        assert_eq!(ext.ext_dtype(), ext_dtype.as_ref());
        assert!(ext.ext_dtype().metadata().is_some());
    }

    #[test]
    fn test_ext_scalar_equality_ignores_nullability() {
        let ext_dtype_non_null = i32_ext_dtype("test_ext", Nullability::NonNullable);
        let ext_dtype_nullable = i32_ext_dtype("test_ext", Nullability::Nullable);

        let scalar1 = Scalar::extension(
            ext_dtype_non_null,
            Scalar::primitive(42i32, Nullability::NonNullable),
        );
        let scalar2 = Scalar::extension(
            ext_dtype_nullable,
            Scalar::primitive(42i32, Nullability::Nullable),
        );

        let ext1 = ExtScalar::try_from(&scalar1).unwrap();
        let ext2 = ExtScalar::try_from(&scalar2).unwrap();

        // Equality should ignore nullability differences
        assert_eq!(ext1, ext2);
    }
}