vortex_scalar/
extension.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::{Display, Formatter};
5use std::hash::Hash;
6use std::sync::Arc;
7
8use vortex_dtype::datetime::{TemporalMetadata, is_temporal_ext_type};
9use vortex_dtype::{DType, ExtDType};
10use vortex_error::{VortexError, VortexResult, vortex_bail};
11
12use crate::{Scalar, ScalarValue};
13
14/// A scalar value representing an extension type.
15///
16/// Extension types allow wrapping a storage type with custom semantics.
17#[derive(Debug)]
18pub struct ExtScalar<'a> {
19    ext_dtype: &'a ExtDType,
20    value: &'a ScalarValue,
21}
22
23impl Display for ExtScalar<'_> {
24    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
25        // Specialized handling for date/time/timestamp builtin extension types.
26        if is_temporal_ext_type(self.ext_dtype().id()) {
27            let metadata =
28                TemporalMetadata::try_from(self.ext_dtype()).map_err(|_| std::fmt::Error)?;
29
30            let maybe_timestamp = self
31                .storage()
32                .as_primitive()
33                .as_::<i64>()
34                .and_then(|maybe_timestamp| {
35                    maybe_timestamp.map(|v| metadata.to_jiff(v)).transpose()
36                })
37                .map_err(|_| std::fmt::Error)?;
38
39            match maybe_timestamp {
40                None => write!(f, "null"),
41                Some(v) => write!(f, "{v}"),
42            }
43        } else {
44            write!(f, "{}({})", self.ext_dtype().id(), self.storage())
45        }
46    }
47}
48
49impl PartialEq for ExtScalar<'_> {
50    fn eq(&self, other: &Self) -> bool {
51        self.ext_dtype.eq_ignore_nullability(other.ext_dtype) && self.storage() == other.storage()
52    }
53}
54
55impl Eq for ExtScalar<'_> {}
56
57// Ord is not implemented since it's undefined for different Extension DTypes
58impl PartialOrd for ExtScalar<'_> {
59    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
60        if !self.ext_dtype.eq_ignore_nullability(other.ext_dtype) {
61            return None;
62        }
63        self.storage().partial_cmp(&other.storage())
64    }
65}
66
67impl Hash for ExtScalar<'_> {
68    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
69        self.ext_dtype.hash(state);
70        self.storage().hash(state);
71    }
72}
73
74impl<'a> ExtScalar<'a> {
75    /// Creates a new extension scalar from a data type and scalar value.
76    ///
77    /// # Errors
78    ///
79    /// Returns an error if the data type is not an extension type.
80    pub fn try_new(dtype: &'a DType, value: &'a ScalarValue) -> VortexResult<Self> {
81        let DType::Extension(ext_dtype) = dtype else {
82            vortex_bail!("Expected extension scalar, found {}", dtype)
83        };
84
85        Ok(Self { ext_dtype, value })
86    }
87
88    /// Returns the storage scalar of the extension scalar.
89    pub fn storage(&self) -> Scalar {
90        Scalar::new(self.ext_dtype.storage_dtype().clone(), self.value.clone())
91    }
92
93    /// Returns the extension data type.
94    pub fn ext_dtype(&self) -> &'a ExtDType {
95        self.ext_dtype
96    }
97
98    pub(crate) fn cast(&self, dtype: &DType) -> VortexResult<Scalar> {
99        if self.value.is_null() && !dtype.is_nullable() {
100            vortex_bail!(
101                "cannot cast extension dtype with id {} and storage type {} to {}",
102                self.ext_dtype.id(),
103                self.ext_dtype.storage_dtype(),
104                dtype
105            );
106        }
107
108        if self.ext_dtype.storage_dtype().eq_ignore_nullability(dtype) {
109            // Casting from an extension type to the underlying storage type is OK.
110            return Ok(Scalar::new(dtype.clone(), self.value.clone()));
111        }
112
113        if let DType::Extension(ext_dtype) = dtype
114            && self.ext_dtype.eq_ignore_nullability(ext_dtype)
115        {
116            return Ok(Scalar::new(dtype.clone(), self.value.clone()));
117        }
118
119        vortex_bail!(
120            "cannot cast extension dtype with id {} and storage type {} to {}",
121            self.ext_dtype.id(),
122            self.ext_dtype.storage_dtype(),
123            dtype
124        );
125    }
126}
127
128impl<'a> TryFrom<&'a Scalar> for ExtScalar<'a> {
129    type Error = VortexError;
130
131    fn try_from(value: &'a Scalar) -> Result<Self, Self::Error> {
132        ExtScalar::try_new(value.dtype(), &value.value)
133    }
134}
135
136impl Scalar {
137    /// Creates a new extension scalar wrapping the given storage value.
138    pub fn extension(ext_dtype: Arc<ExtDType>, value: Scalar) -> Self {
139        // TODO(joe): enable once we use rust duckdb
140        // assert_eq!(ext_dtype.storage_dtype(), value.dtype());
141        Self {
142            dtype: DType::Extension(ext_dtype),
143            value: value.value().clone(),
144        }
145    }
146}
147
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use vortex_dtype::{DType, ExtDType, ExtID, ExtMetadata, Nullability, PType};

    use crate::{ExtScalar, InnerScalarValue, Scalar, ScalarValue};

    /// Builds an extension dtype named `id` over `i32` storage with the given storage
    /// nullability and optional metadata.
    fn i32_ext_dtype(
        id: &'static str,
        nullability: Nullability,
        metadata: Option<ExtMetadata>,
    ) -> Arc<ExtDType> {
        Arc::new(ExtDType::new(
            ExtID::new(id.into()),
            Arc::new(DType::Primitive(PType::I32, nullability)),
            metadata,
        ))
    }

    /// Wraps `value` as a non-nullable `i32` storage scalar typed as `ext_dtype`.
    fn ext_i32(ext_dtype: Arc<ExtDType>, value: i32) -> Scalar {
        Scalar::extension(
            ext_dtype,
            Scalar::primitive(value, Nullability::NonNullable),
        )
    }

    #[test]
    fn test_ext_scalar_equality() {
        let ext_dtype = i32_ext_dtype("test_ext", Nullability::NonNullable, None);

        let scalar1 = ext_i32(ext_dtype.clone(), 42);
        let scalar2 = ext_i32(ext_dtype.clone(), 42);
        let scalar3 = ext_i32(ext_dtype, 43);

        let ext1 = ExtScalar::try_from(&scalar1).unwrap();
        let ext2 = ExtScalar::try_from(&scalar2).unwrap();
        let ext3 = ExtScalar::try_from(&scalar3).unwrap();

        assert_eq!(ext1, ext2);
        assert_ne!(ext1, ext3);
    }

    #[test]
    fn test_ext_scalar_partial_ord() {
        let ext_dtype = i32_ext_dtype("test_ext", Nullability::NonNullable, None);

        let smaller = ext_i32(ext_dtype.clone(), 10);
        let larger = ext_i32(ext_dtype, 20);

        let ext1 = ExtScalar::try_from(&smaller).unwrap();
        let ext2 = ExtScalar::try_from(&larger).unwrap();

        assert!(ext1 < ext2);
        assert!(ext2 > ext1);
    }

    #[test]
    fn test_ext_scalar_partial_ord_different_types() {
        let scalar1 = ext_i32(
            i32_ext_dtype("type1", Nullability::NonNullable, None),
            10,
        );
        let scalar2 = ext_i32(
            i32_ext_dtype("type2", Nullability::NonNullable, None),
            20,
        );

        let ext1 = ExtScalar::try_from(&scalar1).unwrap();
        let ext2 = ExtScalar::try_from(&scalar2).unwrap();

        // Different extension types should not be comparable
        assert_eq!(ext1.partial_cmp(&ext2), None);
    }

    #[test]
    fn test_ext_scalar_hash() {
        use vortex_utils::aliases::hash_set::HashSet;

        let ext_dtype = i32_ext_dtype("test_ext", Nullability::NonNullable, None);

        let scalar1 = ext_i32(ext_dtype.clone(), 42);
        let scalar2 = ext_i32(ext_dtype, 42);

        let mut set = HashSet::new();
        set.insert(scalar2);
        set.insert(scalar1);

        // Same value should hash the same
        assert_eq!(set.len(), 1);

        // Different value should hash differently
        let scalar3 = ext_i32(
            i32_ext_dtype("test_ext", Nullability::NonNullable, None),
            43,
        );
        set.insert(scalar3);
        assert_eq!(set.len(), 2);
    }

    #[test]
    fn test_ext_scalar_storage() {
        let ext_dtype = i32_ext_dtype("test_ext", Nullability::NonNullable, None);

        let storage_scalar = Scalar::primitive(42i32, Nullability::NonNullable);
        let ext_scalar = Scalar::extension(ext_dtype, storage_scalar.clone());

        let ext = ExtScalar::try_from(&ext_scalar).unwrap();
        assert_eq!(ext.storage(), storage_scalar);
    }

    #[test]
    fn test_ext_scalar_ext_dtype() {
        let ext_id = ExtID::new("test_ext".into());
        let ext_dtype = i32_ext_dtype("test_ext", Nullability::NonNullable, None);

        let scalar = ext_i32(ext_dtype.clone(), 42);

        let ext = ExtScalar::try_from(&scalar).unwrap();
        assert_eq!(ext.ext_dtype().id(), &ext_id);
        assert_eq!(ext.ext_dtype(), ext_dtype.as_ref());
    }

    #[test]
    fn test_ext_scalar_cast_to_storage() {
        let scalar = ext_i32(
            i32_ext_dtype("test_ext", Nullability::NonNullable, None),
            42,
        );
        let ext = ExtScalar::try_from(&scalar).unwrap();

        // Cast to the storage type, first non-nullable and then nullable.
        for nullability in [Nullability::NonNullable, Nullability::Nullable] {
            let target = DType::Primitive(PType::I32, nullability);
            let casted = ext.cast(&target).unwrap();

            assert_eq!(casted.dtype(), &target);
            assert_eq!(casted.as_primitive().typed_value::<i32>(), Some(42));
        }
    }

    #[test]
    fn test_ext_scalar_cast_to_self() {
        let ext_dtype = i32_ext_dtype("test_ext", Nullability::NonNullable, None);
        let scalar = ext_i32(ext_dtype.clone(), 42);
        let ext = ExtScalar::try_from(&scalar).unwrap();

        // Cast to same extension type
        let same_ext = DType::Extension(ext_dtype.clone());
        let casted = ext.cast(&same_ext).unwrap();
        assert_eq!(casted.dtype(), &same_ext);

        // Cast to nullable version of same extension type
        let nullable_ext = DType::Extension(ext_dtype).as_nullable();
        let casted_nullable = ext.cast(&nullable_ext).unwrap();
        assert_eq!(casted_nullable.dtype(), &nullable_ext);
    }

    #[test]
    fn test_ext_scalar_cast_incompatible() {
        let scalar = ext_i32(
            i32_ext_dtype("test_ext", Nullability::NonNullable, None),
            42,
        );
        let ext = ExtScalar::try_from(&scalar).unwrap();

        // Cast to incompatible type should fail
        let result = ext.cast(&DType::Utf8(Nullability::NonNullable));
        assert!(result.is_err());
    }

    #[test]
    fn test_ext_scalar_cast_null_to_non_nullable() {
        let ext_dtype = i32_ext_dtype("test_ext", Nullability::Nullable, None);

        let scalar = Scalar::extension(
            ext_dtype,
            Scalar::null(DType::Primitive(PType::I32, Nullability::Nullable)),
        );
        let ext = ExtScalar::try_from(&scalar).unwrap();

        // Cast null to non-nullable should fail
        let result = ext.cast(&DType::Primitive(PType::I32, Nullability::NonNullable));
        assert!(result.is_err());
    }

    #[test]
    fn test_ext_scalar_try_new_non_extension() {
        let dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
        let value = ScalarValue(InnerScalarValue::Primitive(crate::PValue::I32(42)));

        // A primitive dtype is not an extension dtype, so construction must be rejected.
        assert!(ExtScalar::try_new(&dtype, &value).is_err());
    }

    #[test]
    fn test_ext_scalar_with_metadata() {
        let metadata = ExtMetadata::new(vec![1u8, 2, 3].into());
        let ext_dtype = i32_ext_dtype(
            "test_ext_with_meta",
            Nullability::NonNullable,
            Some(metadata),
        );

        let scalar = ext_i32(ext_dtype.clone(), 42);

        let ext = ExtScalar::try_from(&scalar).unwrap();
        assert_eq!(ext.ext_dtype(), ext_dtype.as_ref());
        assert!(ext.ext_dtype().metadata().is_some());
    }

    #[test]
    fn test_ext_scalar_equality_ignores_nullability() {
        let scalar1 = ext_i32(
            i32_ext_dtype("test_ext", Nullability::NonNullable, None),
            42,
        );
        let scalar2 = Scalar::extension(
            i32_ext_dtype("test_ext", Nullability::Nullable, None),
            Scalar::primitive(42i32, Nullability::Nullable),
        );

        let ext1 = ExtScalar::try_from(&scalar1).unwrap();
        let ext2 = ExtScalar::try_from(&scalar2).unwrap();

        // Equality should ignore nullability differences
        assert_eq!(ext1, ext2);
    }
}