vortex_scalar/
lib.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4//! Scalar values and types for the Vortex system.
5//!
6//! This crate provides scalar types and values that can be used to represent individual
7//! data elements in the Vortex array system. Scalars are composed of a logical data type
8//! ([`DType`]) and a value ([`ScalarValue`]).
9
10#![deny(missing_docs)]
11
12use std::cmp::Ordering;
13use std::hash::Hash;
14use std::sync::Arc;
15
16pub use scalar_type::ScalarType;
17use vortex_buffer::{Buffer, BufferString, ByteBuffer};
18use vortex_dtype::half::f16;
19use vortex_dtype::{DECIMAL128_MAX_PRECISION, DType, Nullability};
20#[cfg(feature = "arbitrary")]
21pub mod arbitrary;
22mod arrow;
23mod bigint;
24mod binary;
25mod bool;
26mod decimal;
27mod display;
28mod extension;
29mod list;
30mod null;
31mod primitive;
32mod proto;
33mod pvalue;
34mod scalar_type;
35mod scalar_value;
36mod struct_;
37mod utf8;
38
39pub use bigint::*;
40pub use binary::*;
41pub use bool::*;
42pub use decimal::*;
43pub use extension::*;
44pub use list::*;
45pub use primitive::*;
46pub use pvalue::*;
47pub use scalar_value::*;
48pub use struct_::*;
49pub use utf8::*;
50use vortex_error::{VortexExpect, VortexResult, vortex_bail};
51
52/// A single logical item, composed of both a [`ScalarValue`] and a logical [`DType`].
53///
54/// A [`ScalarValue`] is opaque, and should be accessed via one of the type-specific scalar wrappers
55/// for example [`BoolScalar`], [`PrimitiveScalar`], etc.
56///
57/// Note that [`PartialOrd`] is implemented only for an exact match of the scalar's dtype,
58/// including nullability. When the DType does match, ordering is nulls first (lowest), then the
59/// natural ordering of the scalar value.
60#[derive(Debug, Clone)]
61pub struct Scalar {
62    dtype: DType,
63    value: ScalarValue,
64}
65
66impl Scalar {
67    /// Creates a new scalar with the given data type and value.
68    ///
69    /// FIXME(ngates): this is unchecked... we don't know that the scalar value is compatible
70    ///  with the data type.
71    pub fn new(dtype: DType, value: ScalarValue) -> Self {
72        Self { dtype, value }
73    }
74
75    /// Returns a reference to the scalar's data type.
76    #[inline]
77    pub fn dtype(&self) -> &DType {
78        &self.dtype
79    }
80
81    /// Returns a reference to the scalar's underlying value.
82    #[inline]
83    pub fn value(&self) -> &ScalarValue {
84        &self.value
85    }
86
87    /// Consumes the scalar and returns its data type and value as a tuple.
88    #[inline]
89    pub fn into_parts(self) -> (DType, ScalarValue) {
90        (self.dtype, self.value)
91    }
92
93    /// Consumes the scalar and returns its underlying value.
94    #[inline]
95    pub fn into_value(self) -> ScalarValue {
96        self.value
97    }
98
99    /// Returns true if the scalar is not null.
100    pub fn is_valid(&self) -> bool {
101        !self.value.is_null()
102    }
103
104    /// Returns true if the scalar is null.
105    pub fn is_null(&self) -> bool {
106        self.value.is_null()
107    }
108
109    /// Creates a null scalar with the given nullable data type.
110    ///
111    /// # Panics
112    ///
113    /// Panics if the data type is not nullable.
114    pub fn null(dtype: DType) -> Self {
115        assert!(
116            dtype.is_nullable(),
117            "Creating null scalar for non-nullable DType {dtype}"
118        );
119        Self {
120            dtype,
121            value: ScalarValue(InnerScalarValue::Null),
122        }
123    }
124
125    /// Creates a null scalar for the given scalar type.
126    ///
127    /// The resulting scalar will have a nullable version of the type's data type.
128    pub fn null_typed<T: ScalarType>() -> Self {
129        Self {
130            dtype: T::dtype().as_nullable(),
131            value: ScalarValue(InnerScalarValue::Null),
132        }
133    }
134
135    /// Casts the scalar to the target data type.
136    ///
137    /// Returns an error if the cast is not supported or if the value cannot be represented
138    /// in the target type.
139    pub fn cast(&self, target: &DType) -> VortexResult<Self> {
140        if let DType::Extension(ext_dtype) = target {
141            let storage_scalar = self.cast_to_non_extension(ext_dtype.storage_dtype())?;
142            Ok(Scalar::extension(ext_dtype.clone(), storage_scalar))
143        } else {
144            self.cast_to_non_extension(target)
145        }
146    }
147
148    fn cast_to_non_extension(&self, target: &DType) -> VortexResult<Self> {
149        assert!(!matches!(target, DType::Extension(..)));
150        if self.is_null() {
151            if target.is_nullable() {
152                return Ok(Scalar::new(target.clone(), self.value.clone()));
153            } else {
154                vortex_bail!("Can't cast null scalar to non-nullable type {}", target)
155            }
156        }
157
158        if self.dtype().eq_ignore_nullability(target) {
159            return Ok(Scalar::new(target.clone(), self.value.clone()));
160        }
161
162        match &self.dtype {
163            DType::Null => unreachable!(), // handled by if is_null case
164            DType::Bool(_) => self.as_bool().cast(target),
165            DType::Primitive(..) => self.as_primitive().cast(target),
166            DType::Decimal(..) => todo!("(aduffy): implement DecimalScalar casting"),
167            DType::Utf8(_) => self.as_utf8().cast(target),
168            DType::Binary(_) => self.as_binary().cast(target),
169            DType::Struct(..) => self.as_struct().cast(target),
170            DType::List(..) => self.as_list().cast(target),
171            DType::Extension(..) => self.as_extension().cast(target),
172        }
173    }
174
175    /// Converts the scalar to have a nullable version of its data type.
176    pub fn into_nullable(self) -> Self {
177        Self {
178            dtype: self.dtype.as_nullable(),
179            value: self.value,
180        }
181    }
182
183    /// Returns the size of the scalar in bytes, uncompressed.
184    pub fn nbytes(&self) -> usize {
185        match self.dtype() {
186            DType::Null => 0,
187            DType::Bool(_) => 1,
188            DType::Primitive(ptype, _) => ptype.byte_width(),
189            DType::Decimal(dt, _) => {
190                if dt.precision() >= DECIMAL128_MAX_PRECISION {
191                    size_of::<i128>()
192                } else {
193                    size_of::<i256>()
194                }
195            }
196            DType::Binary(_) | DType::Utf8(_) => self
197                .value()
198                .as_buffer()
199                .ok()
200                .flatten()
201                .map_or(0, |s| s.len()),
202            DType::Struct(_dtype, _) => self
203                .as_struct()
204                .fields()
205                .map(|fields| fields.into_iter().map(|f| f.nbytes()).sum::<usize>())
206                .unwrap_or_default(),
207            DType::List(_dtype, _) => self
208                .as_list()
209                .elements()
210                .map(|fields| fields.into_iter().map(|f| f.nbytes()).sum::<usize>())
211                .unwrap_or_default(),
212            DType::Extension(_ext_dtype) => self.as_extension().storage().nbytes(),
213        }
214    }
215
216    /// Creates a "default" scalar value for the given data type.
217    ///
218    /// For nullable types, returns null. For non-nullable types, returns
219    /// an appropriate zero/empty value.
220    pub fn default_value(dtype: DType) -> Self {
221        if dtype.is_nullable() {
222            return Self::null(dtype);
223        }
224
225        match dtype {
226            DType::Null => Self::null(dtype),
227            DType::Bool(nullability) => Self::bool(false, nullability),
228            DType::Primitive(pt, nullability) => {
229                Self::primitive_value(PValue::zero(pt), pt, nullability)
230            }
231            DType::Decimal(dt, nullability) => {
232                Self::decimal(DecimalValue::from(0), dt, nullability)
233            }
234            DType::Utf8(nullability) => Self::utf8("", nullability),
235            DType::Binary(nullability) => Self::binary(Buffer::empty(), nullability),
236            DType::Struct(sf, nullability) => {
237                let fields: Vec<_> = sf.fields().map(Scalar::default_value).collect();
238                Self::struct_(DType::Struct(sf, nullability), fields)
239            }
240            DType::List(dt, nullability) => Self::list(dt, vec![], nullability),
241            DType::Extension(dt) => {
242                let scalar = Self::default_value(dt.storage_dtype().clone());
243                Self::extension(dt, scalar)
244            }
245        }
246    }
247}
248
249impl Scalar {
250    /// Returns a view of the scalar as a boolean scalar.
251    ///
252    /// # Panics
253    ///
254    /// Panics if the scalar is not a boolean type.
255    pub fn as_bool(&self) -> BoolScalar<'_> {
256        BoolScalar::try_from(self).vortex_expect("Failed to convert scalar to bool")
257    }
258
259    /// Returns a view of the scalar as a boolean scalar if it has a boolean type.
260    pub fn as_bool_opt(&self) -> Option<BoolScalar<'_>> {
261        matches!(self.dtype, DType::Bool(..)).then(|| self.as_bool())
262    }
263
264    /// Returns a view of the scalar as a primitive scalar.
265    ///
266    /// # Panics
267    ///
268    /// Panics if the scalar is not a primitive type.
269    pub fn as_primitive(&self) -> PrimitiveScalar<'_> {
270        PrimitiveScalar::try_from(self).vortex_expect("Failed to convert scalar to primitive")
271    }
272
273    /// Returns a view of the scalar as a primitive scalar if it has a primitive type.
274    pub fn as_primitive_opt(&self) -> Option<PrimitiveScalar<'_>> {
275        matches!(self.dtype, DType::Primitive(..)).then(|| self.as_primitive())
276    }
277
278    /// Returns a view of the scalar as a decimal scalar.
279    ///
280    /// # Panics
281    ///
282    /// Panics if the scalar is not a decimal type.
283    pub fn as_decimal(&self) -> DecimalScalar<'_> {
284        DecimalScalar::try_from(self).vortex_expect("Failed to convert scalar to decimal")
285    }
286
287    /// Returns a view of the scalar as a decimal scalar if it has a decimal type.
288    pub fn as_decimal_opt(&self) -> Option<DecimalScalar<'_>> {
289        matches!(self.dtype, DType::Decimal(..)).then(|| self.as_decimal())
290    }
291
292    /// Returns a view of the scalar as a UTF-8 string scalar.
293    ///
294    /// # Panics
295    ///
296    /// Panics if the scalar is not a UTF-8 type.
297    pub fn as_utf8(&self) -> Utf8Scalar<'_> {
298        Utf8Scalar::try_from(self).vortex_expect("Failed to convert scalar to utf8")
299    }
300
301    /// Returns a view of the scalar as a UTF-8 string scalar if it has a UTF-8 type.
302    pub fn as_utf8_opt(&self) -> Option<Utf8Scalar<'_>> {
303        matches!(self.dtype, DType::Utf8(..)).then(|| self.as_utf8())
304    }
305
306    /// Returns a view of the scalar as a binary scalar.
307    ///
308    /// # Panics
309    ///
310    /// Panics if the scalar is not a binary type.
311    pub fn as_binary(&self) -> BinaryScalar<'_> {
312        BinaryScalar::try_from(self).vortex_expect("Failed to convert scalar to binary")
313    }
314
315    /// Returns a view of the scalar as a binary scalar if it has a binary type.
316    pub fn as_binary_opt(&self) -> Option<BinaryScalar<'_>> {
317        matches!(self.dtype, DType::Binary(..)).then(|| self.as_binary())
318    }
319
320    /// Returns a view of the scalar as a struct scalar.
321    ///
322    /// # Panics
323    ///
324    /// Panics if the scalar is not a struct type.
325    pub fn as_struct(&self) -> StructScalar<'_> {
326        StructScalar::try_from(self).vortex_expect("Failed to convert scalar to struct")
327    }
328
329    /// Returns a view of the scalar as a struct scalar if it has a struct type.
330    pub fn as_struct_opt(&self) -> Option<StructScalar<'_>> {
331        matches!(self.dtype, DType::Struct(..)).then(|| self.as_struct())
332    }
333
334    /// Returns a view of the scalar as a list scalar.
335    ///
336    /// # Panics
337    ///
338    /// Panics if the scalar is not a list type.
339    pub fn as_list(&self) -> ListScalar<'_> {
340        ListScalar::try_from(self).vortex_expect("Failed to convert scalar to list")
341    }
342
343    /// Returns a view of the scalar as a list scalar if it has a list type.
344    pub fn as_list_opt(&self) -> Option<ListScalar<'_>> {
345        matches!(self.dtype, DType::List(..)).then(|| self.as_list())
346    }
347
348    /// Returns a view of the scalar as an extension scalar.
349    ///
350    /// # Panics
351    ///
352    /// Panics if the scalar is not an extension type.
353    pub fn as_extension(&self) -> ExtScalar<'_> {
354        ExtScalar::try_from(self).vortex_expect("Failed to convert scalar to extension")
355    }
356
357    /// Returns a view of the scalar as an extension scalar if it has an extension type.
358    pub fn as_extension_opt(&self) -> Option<ExtScalar<'_>> {
359        matches!(self.dtype, DType::Extension(..)).then(|| self.as_extension())
360    }
361}
362
363impl PartialEq for Scalar {
364    fn eq(&self, other: &Self) -> bool {
365        if !self.dtype.eq_ignore_nullability(&other.dtype) {
366            return false;
367        }
368
369        match self.dtype() {
370            DType::Null => true,
371            DType::Bool(_) => self.as_bool() == other.as_bool(),
372            DType::Primitive(..) => self.as_primitive() == other.as_primitive(),
373            DType::Decimal(..) => self.as_decimal() == other.as_decimal(),
374            DType::Utf8(_) => self.as_utf8() == other.as_utf8(),
375            DType::Binary(_) => self.as_binary() == other.as_binary(),
376            DType::Struct(..) => self.as_struct() == other.as_struct(),
377            DType::List(..) => self.as_list() == other.as_list(),
378            DType::Extension(_) => self.as_extension() == other.as_extension(),
379        }
380    }
381}
382
383impl Eq for Scalar {}
384
385impl PartialOrd for Scalar {
386    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
387        if !self.dtype().eq_ignore_nullability(other.dtype()) {
388            return None;
389        }
390        match self.dtype() {
391            DType::Null => Some(Ordering::Equal),
392            DType::Bool(_) => self.as_bool().partial_cmp(&other.as_bool()),
393            DType::Primitive(..) => self.as_primitive().partial_cmp(&other.as_primitive()),
394            DType::Decimal(..) => self.as_decimal().partial_cmp(&other.as_decimal()),
395            DType::Utf8(_) => self.as_utf8().partial_cmp(&other.as_utf8()),
396            DType::Binary(_) => self.as_binary().partial_cmp(&other.as_binary()),
397            DType::Struct(..) => self.as_struct().partial_cmp(&other.as_struct()),
398            DType::List(..) => self.as_list().partial_cmp(&other.as_list()),
399            DType::Extension(_) => self.as_extension().partial_cmp(&other.as_extension()),
400        }
401    }
402}
403
404impl Hash for Scalar {
405    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
406        match self.dtype() {
407            DType::Null => self.dtype().hash(state), // Hash the dtype instead of the value
408            DType::Bool(_) => self.as_bool().hash(state),
409            DType::Primitive(..) => self.as_primitive().hash(state),
410            DType::Decimal(..) => self.as_decimal().hash(state),
411            DType::Utf8(_) => self.as_utf8().hash(state),
412            DType::Binary(_) => self.as_binary().hash(state),
413            DType::Struct(..) => self.as_struct().hash(state),
414            DType::List(..) => self.as_list().hash(state),
415            DType::Extension(_) => self.as_extension().hash(state),
416        }
417    }
418}
419
420impl AsRef<Self> for Scalar {
421    fn as_ref(&self) -> &Self {
422        self
423    }
424}
425
426impl<T> From<Option<T>> for Scalar
427where
428    T: ScalarType,
429    Scalar: From<T>,
430{
431    fn from(value: Option<T>) -> Self {
432        value
433            .map(Scalar::from)
434            .map(|x| x.into_nullable())
435            .unwrap_or_else(|| Scalar {
436                dtype: T::dtype().as_nullable(),
437                value: ScalarValue(InnerScalarValue::Null),
438            })
439    }
440}
441
442impl From<PrimitiveScalar<'_>> for Scalar {
443    fn from(pscalar: PrimitiveScalar<'_>) -> Self {
444        let dtype = pscalar.dtype().clone();
445        let value = pscalar
446            .pvalue()
447            .map(|pvalue| ScalarValue(InnerScalarValue::Primitive(pvalue)))
448            .unwrap_or_else(|| ScalarValue(InnerScalarValue::Null));
449        Self::new(dtype, value)
450    }
451}
452
453impl From<DecimalScalar<'_>> for Scalar {
454    fn from(decimal_scalar: DecimalScalar<'_>) -> Self {
455        let dtype = decimal_scalar.dtype().clone();
456        let value = decimal_scalar
457            .decimal_value()
458            .map(|value| ScalarValue(InnerScalarValue::Decimal(value)))
459            .unwrap_or_else(|| ScalarValue(InnerScalarValue::Null));
460        Self::new(dtype, value)
461    }
462}
463
464macro_rules! from_vec_for_scalar {
465    ($T:ty) => {
466        impl From<Vec<$T>> for Scalar {
467            fn from(value: Vec<$T>) -> Self {
468                Scalar {
469                    dtype: DType::List(Arc::from(<$T>::dtype()), Nullability::NonNullable),
470                    value: ScalarValue(InnerScalarValue::List(
471                        value
472                            .into_iter()
473                            .map(Scalar::from)
474                            .map(|s| s.into_value())
475                            .collect::<Arc<[_]>>(),
476                    )),
477                }
478            }
479        }
480    };
481}
482
483// no From<Vec<u8>> because it could either be a List or a Buffer
484from_vec_for_scalar!(u16);
485from_vec_for_scalar!(u32);
486from_vec_for_scalar!(u64);
487from_vec_for_scalar!(usize); // For usize only, we implicitly cast for better ergonomics.
488from_vec_for_scalar!(i8);
489from_vec_for_scalar!(i16);
490from_vec_for_scalar!(i32);
491from_vec_for_scalar!(i64);
492from_vec_for_scalar!(f16);
493from_vec_for_scalar!(f32);
494from_vec_for_scalar!(f64);
495from_vec_for_scalar!(String);
496from_vec_for_scalar!(BufferString);
497from_vec_for_scalar!(ByteBuffer);
498
#[cfg(test)]
mod test {
    use std::sync::Arc;

    use rstest::rstest;
    use vortex_dtype::{DType, ExtDType, ExtID, Nullability, PType};

    use crate::{InnerScalarValue, PValue, Scalar, ScalarValue};

    #[rstest]
    fn null_can_cast_to_anything_nullable(
        #[values(
            DType::Null,
            DType::Bool(Nullability::Nullable),
            DType::Primitive(PType::I32, Nullability::Nullable),
            DType::Extension(Arc::from(ExtDType::new(
                ExtID::from("a"),
                Arc::from(DType::Primitive(PType::U32, Nullability::Nullable)),
                None,
            ))),
            DType::Extension(Arc::from(ExtDType::new(
                ExtID::from("b"),
                Arc::from(DType::Utf8(Nullability::Nullable)),
                None,
            )))
        )]
        source_dtype: DType,
        #[values(
            DType::Null,
            DType::Bool(Nullability::Nullable),
            DType::Primitive(PType::I32, Nullability::Nullable),
            DType::Extension(Arc::from(ExtDType::new(
                ExtID::from("a"),
                Arc::from(DType::Primitive(PType::U32, Nullability::Nullable)),
                None,
            ))),
            DType::Extension(Arc::from(ExtDType::new(
                ExtID::from("b"),
                Arc::from(DType::Utf8(Nullability::Nullable)),
                None,
            )))
        )]
        target_dtype: DType,
    ) {
        assert_eq!(
            Scalar::null(source_dtype)
                .cast(&target_dtype)
                .unwrap()
                .dtype(),
            &target_dtype
        );
    }

    #[test]
    fn list_casts() {
        let list = Scalar::new(
            DType::List(
                Arc::from(DType::Primitive(PType::U16, Nullability::Nullable)),
                Nullability::Nullable,
            ),
            ScalarValue(InnerScalarValue::List(Arc::from([ScalarValue(
                InnerScalarValue::Primitive(PValue::U16(6)),
            )]))),
        );

        let target_u32 = DType::List(
            Arc::from(DType::Primitive(PType::U32, Nullability::Nullable)),
            Nullability::Nullable,
        );
        assert_eq!(list.cast(&target_u32).unwrap().dtype(), &target_u32);

        let target_u32_nonnull = DType::List(
            Arc::from(DType::Primitive(PType::U32, Nullability::NonNullable)),
            Nullability::Nullable,
        );
        assert_eq!(
            list.cast(&target_u32_nonnull).unwrap().dtype(),
            &target_u32_nonnull
        );

        let target_nonnull = DType::List(
            Arc::from(DType::Primitive(PType::U32, Nullability::Nullable)),
            Nullability::NonNullable,
        );
        assert_eq!(list.cast(&target_nonnull).unwrap().dtype(), &target_nonnull);

        let target_u8 = DType::List(
            Arc::from(DType::Primitive(PType::U8, Nullability::Nullable)),
            Nullability::Nullable,
        );
        assert_eq!(list.cast(&target_u8).unwrap().dtype(), &target_u8);

        let list_with_null = Scalar::new(
            DType::List(
                Arc::from(DType::Primitive(PType::U16, Nullability::Nullable)),
                Nullability::Nullable,
            ),
            ScalarValue(InnerScalarValue::List(Arc::from([
                ScalarValue(InnerScalarValue::Primitive(PValue::U16(6))),
                ScalarValue(InnerScalarValue::Null),
            ]))),
        );
        let target_u8 = DType::List(
            Arc::from(DType::Primitive(PType::U8, Nullability::Nullable)),
            Nullability::Nullable,
        );
        assert_eq!(list_with_null.cast(&target_u8).unwrap().dtype(), &target_u8);

        let target_u32_nonnull = DType::List(
            Arc::from(DType::Primitive(PType::U32, Nullability::NonNullable)),
            Nullability::Nullable,
        );
        assert!(list_with_null.cast(&target_u32_nonnull).is_err());
    }

    #[test]
    fn cast_to_from_extension_types() {
        let apples = ExtDType::new(
            ExtID::new(Arc::from("apples")),
            Arc::from(DType::Primitive(PType::U16, Nullability::NonNullable)),
            None,
        );
        let ext_dtype = DType::Extension(Arc::from(apples.clone()));
        // The extension's storage dtype is u16, so the stored value must be a u16 as well.
        let ext_scalar = Scalar::new(
            ext_dtype.clone(),
            ScalarValue(InnerScalarValue::Primitive(PValue::U16(1000))),
        );
        let storage_scalar = Scalar::new(
            DType::clone(apples.storage_dtype()),
            ScalarValue(InnerScalarValue::Primitive(PValue::U16(1000))),
        );

        // to self
        let expected_dtype = &ext_dtype;
        let actual = ext_scalar.cast(expected_dtype).unwrap();
        assert_eq!(actual.dtype(), expected_dtype);

        // to nullable self
        let expected_dtype = &ext_dtype.as_nullable();
        let actual = ext_scalar.cast(expected_dtype).unwrap();
        assert_eq!(actual.dtype(), expected_dtype);

        // cast to the storage type
        let expected_dtype = apples.storage_dtype();
        let actual = ext_scalar.cast(expected_dtype).unwrap();
        assert_eq!(actual.dtype(), expected_dtype);

        // cast to the storage type, nullable
        let expected_dtype = &apples.storage_dtype().as_nullable();
        let actual = ext_scalar.cast(expected_dtype).unwrap();
        assert_eq!(actual.dtype(), expected_dtype);

        // cast from storage type to extension
        let expected_dtype = &ext_dtype;
        let actual = storage_scalar.cast(expected_dtype).unwrap();
        assert_eq!(actual.dtype(), expected_dtype);

        // cast from storage type to extension, nullable
        let expected_dtype = &ext_dtype.as_nullable();
        let actual = storage_scalar.cast(expected_dtype).unwrap();
        assert_eq!(actual.dtype(), expected_dtype);

        // cast from *compatible* storage type to extension: a genuinely u64-typed scalar whose
        // value fits in u16 must go through a real primitive cast rather than a re-tag.
        let storage_scalar_u64 = Scalar::new(
            DType::Primitive(PType::U64, Nullability::NonNullable),
            ScalarValue(InnerScalarValue::Primitive(PValue::U64(1000))),
        );
        let expected_dtype = &ext_dtype;
        let actual = storage_scalar_u64.cast(expected_dtype).unwrap();
        assert_eq!(actual.dtype(), expected_dtype);

        // cast from *incompatible* storage type to extension
        let apples_u8 = ExtDType::new(
            ExtID::new(Arc::from("apples")),
            Arc::from(DType::Primitive(PType::U8, Nullability::NonNullable)),
            None,
        );
        let expected_dtype = &DType::Extension(Arc::from(apples_u8));
        let result = storage_scalar.cast(expected_dtype);
        assert!(
            result.as_ref().is_err_and(|err| {
                err
                    .to_string()
                    .contains("Can't cast u16 scalar 1000u16 to u8 (cause: Cannot read primitive value U16(1000) as u8")
            }),
            "{result:?}"
        );
    }

    #[test]
    fn default_value_for_complex_dtype() {
        let struct_dtype = DType::struct_(
            [
                ("a", DType::Primitive(PType::I32, Nullability::NonNullable)),
                (
                    "b",
                    DType::list(
                        DType::Primitive(PType::I8, Nullability::Nullable),
                        Nullability::NonNullable,
                    ),
                ),
                ("c", DType::Primitive(PType::I32, Nullability::Nullable)),
            ],
            Nullability::NonNullable,
        );

        let scalar = Scalar::default_value(struct_dtype.clone());
        assert_eq!(scalar.dtype(), &struct_dtype);

        let scalar = scalar.as_struct();

        let a_field = scalar.field("a").unwrap();
        assert_eq!(a_field.as_primitive().pvalue().unwrap(), PValue::I32(0));

        let b_field = scalar.field("b").unwrap();
        assert!(b_field.as_list().is_empty());

        let c_field = scalar.field("c").unwrap();
        assert!(c_field.is_null());
    }
}