Skip to main content

vortex_array/scalar/
arrow.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
//! Conversions from [`Scalar`] to Arrow scalars (single-element arrays implementing `Datum`).
5
6use std::sync::Arc;
7
8use arrow_array::Scalar as ArrowScalar;
9use arrow_array::*;
10use vortex_dtype::DType;
11use vortex_dtype::PType;
12use vortex_dtype::datetime::AnyTemporal;
13use vortex_dtype::datetime::TemporalMetadata;
14use vortex_dtype::datetime::TimeUnit;
15use vortex_error::VortexError;
16use vortex_error::vortex_bail;
17use vortex_error::vortex_err;
18
19use crate::scalar::BinaryScalar;
20use crate::scalar::BoolScalar;
21use crate::scalar::DecimalScalar;
22use crate::scalar::DecimalValue;
23use crate::scalar::ExtScalar;
24use crate::scalar::PrimitiveScalar;
25use crate::scalar::Scalar;
26use crate::scalar::Utf8Scalar;
27
/// Length of every array built by this module: Arrow models a scalar value as an array holding
/// exactly one element.
const SCALAR_ARRAY_LEN: usize = 1;
30
/// Builds an Arrow scalar `Datum` of array type `$AR` from `$V`, an `Option` of that array's
/// native value. `Some` becomes a one-element array holding the value; `None` becomes a
/// one-element null array of the same type. Expands to `Ok(Arc<...>)`.
macro_rules! value_to_arrow_scalar {
    ($V:expr, $AR:ty) => {{
        let arrow_scalar = match $V {
            Some(native) => <$AR>::new_scalar(native),
            None => arrow_array::Scalar::new(<$AR>::new_null(SCALAR_ARRAY_LEN)),
        };
        Ok(std::sync::Arc::new(arrow_scalar))
    }};
}
40
/// Builds an Arrow timestamp scalar `Datum` of array type `$AR` from `$V`, an `Option` of the
/// array's native value, and tags the array with the optional timezone `$TZ`.
///
/// `None` produces a one-element null array of the same (timezone-tagged) timestamp type.
/// Expands to `Ok(Arc<...>)`.
///
/// `Arc` and Arrow's `Scalar` are fully qualified so the macro does not depend on the call
/// site's imports or aliases.
macro_rules! timestamp_to_arrow_scalar {
    ($V:expr, $TZ:expr, $AR:ty) => {{
        let array = match $V {
            Some(v) => <$AR>::from_value(v, SCALAR_ARRAY_LEN),
            None => <$AR>::new_null(SCALAR_ARRAY_LEN),
        }
        .with_timezone_opt($TZ);
        Ok(std::sync::Arc::new(arrow_array::Scalar::new(array)))
    }};
}
52
53impl TryFrom<&Scalar> for Arc<dyn Datum> {
54    type Error = VortexError;
55
56    fn try_from(value: &Scalar) -> Result<Arc<dyn Datum>, Self::Error> {
57        match value.dtype() {
58            DType::Null => Ok(Arc::new(NullArray::new(SCALAR_ARRAY_LEN))),
59            DType::Bool(_) => bool_to_arrow(value.as_bool()),
60            DType::Primitive(..) => primitive_to_arrow(value.as_primitive()),
61            DType::Decimal(..) => decimal_to_arrow(value.as_decimal()),
62            DType::Utf8(_) => utf8_to_arrow(value.as_utf8()),
63            DType::Binary(_) => binary_to_arrow(value.as_binary()),
64            DType::Struct(..) => unimplemented!("struct scalar conversion"),
65            DType::List(..) => unimplemented!("list scalar conversion"),
66            DType::FixedSizeList(..) => unimplemented!("fixed-size list scalar conversion"),
67            DType::Extension(..) => extension_to_arrow(value.as_extension()),
68        }
69    }
70}
71
72/// Convert a [`BoolScalar`] to an Arrow [`Datum`].
73fn bool_to_arrow(scalar: BoolScalar<'_>) -> Result<Arc<dyn Datum>, VortexError> {
74    value_to_arrow_scalar!(scalar.value(), BooleanArray)
75}
76
77/// Convert a [`PrimitiveScalar`] to an Arrow [`Datum`].
78fn primitive_to_arrow(scalar: PrimitiveScalar<'_>) -> Result<Arc<dyn Datum>, VortexError> {
79    match scalar.ptype() {
80        PType::U8 => value_to_arrow_scalar!(scalar.typed_value(), UInt8Array),
81        PType::U16 => value_to_arrow_scalar!(scalar.typed_value(), UInt16Array),
82        PType::U32 => value_to_arrow_scalar!(scalar.typed_value(), UInt32Array),
83        PType::U64 => value_to_arrow_scalar!(scalar.typed_value(), UInt64Array),
84        PType::I8 => value_to_arrow_scalar!(scalar.typed_value(), Int8Array),
85        PType::I16 => value_to_arrow_scalar!(scalar.typed_value(), Int16Array),
86        PType::I32 => value_to_arrow_scalar!(scalar.typed_value(), Int32Array),
87        PType::I64 => value_to_arrow_scalar!(scalar.typed_value(), Int64Array),
88        PType::F16 => value_to_arrow_scalar!(scalar.typed_value(), Float16Array),
89        PType::F32 => value_to_arrow_scalar!(scalar.typed_value(), Float32Array),
90        PType::F64 => value_to_arrow_scalar!(scalar.typed_value(), Float64Array),
91    }
92}
93
94/// Convert a [`DecimalScalar`] to an Arrow [`Datum`].
95fn decimal_to_arrow(scalar: DecimalScalar<'_>) -> Result<Arc<dyn Datum>, VortexError> {
96    // TODO(joe): Replace with decimal32, etc. once Arrow supports them.
97    match scalar.decimal_value() {
98        Some(DecimalValue::I8(v)) => Ok(Arc::new(Decimal128Array::new_scalar(v as i128))),
99        Some(DecimalValue::I16(v)) => Ok(Arc::new(Decimal128Array::new_scalar(v as i128))),
100        Some(DecimalValue::I32(v)) => Ok(Arc::new(Decimal128Array::new_scalar(v as i128))),
101        Some(DecimalValue::I64(v)) => Ok(Arc::new(Decimal128Array::new_scalar(v as i128))),
102        Some(DecimalValue::I128(v128)) => Ok(Arc::new(Decimal128Array::new_scalar(v128))),
103        Some(DecimalValue::I256(v256)) => Ok(Arc::new(Decimal256Array::new_scalar(v256.into()))),
104        None => Ok(Arc::new(arrow_array::Scalar::new(
105            Decimal128Array::new_null(SCALAR_ARRAY_LEN),
106        ))),
107    }
108}
109
110/// Convert a [`Utf8Scalar`] to an Arrow [`Datum`].
111fn utf8_to_arrow(scalar: Utf8Scalar<'_>) -> Result<Arc<dyn Datum>, VortexError> {
112    value_to_arrow_scalar!(scalar.value(), StringViewArray)
113}
114
115/// Convert a [`BinaryScalar`] to an Arrow [`Datum`].
116fn binary_to_arrow(scalar: BinaryScalar<'_>) -> Result<Arc<dyn Datum>, VortexError> {
117    value_to_arrow_scalar!(scalar.value(), BinaryViewArray)
118}
119
120/// Convert an [`ExtScalar`] to an Arrow [`Datum`].
121///
122/// Currently only temporal extension types (timestamps, dates, and times) are supported.
123fn extension_to_arrow(scalar: ExtScalar<'_>) -> Result<Arc<dyn Datum>, VortexError> {
124    let ext_dtype = scalar.ext_dtype();
125    let Some(temporal) = ext_dtype.metadata_opt::<AnyTemporal>() else {
126        vortex_bail!(
127            "Cannot convert extension scalar {} to Arrow",
128            ext_dtype.id()
129        )
130    };
131
132    let storage_scalar = scalar.to_storage_scalar();
133    let primitive = storage_scalar
134        .as_primitive_opt()
135        .ok_or_else(|| vortex_err!("Expected primitive scalar"))?;
136
137    match temporal {
138        TemporalMetadata::Timestamp(unit, tz) => {
139            let value = primitive.as_::<i64>();
140            match unit {
141                TimeUnit::Nanoseconds => {
142                    timestamp_to_arrow_scalar!(value, tz.clone(), TimestampNanosecondArray)
143                }
144                TimeUnit::Microseconds => {
145                    timestamp_to_arrow_scalar!(value, tz.clone(), TimestampMicrosecondArray)
146                }
147                TimeUnit::Milliseconds => {
148                    timestamp_to_arrow_scalar!(value, tz.clone(), TimestampMillisecondArray)
149                }
150                TimeUnit::Seconds => {
151                    timestamp_to_arrow_scalar!(value, tz.clone(), TimestampSecondArray)
152                }
153                TimeUnit::Days => {
154                    vortex_bail!("Unsupported TimeUnit {unit} for {}", ext_dtype.id())
155                }
156            }
157        }
158        TemporalMetadata::Date(unit) => match unit {
159            TimeUnit::Milliseconds => {
160                value_to_arrow_scalar!(primitive.as_::<i64>(), Date64Array)
161            }
162            TimeUnit::Days => {
163                value_to_arrow_scalar!(primitive.as_::<i32>(), Date32Array)
164            }
165            TimeUnit::Nanoseconds | TimeUnit::Microseconds | TimeUnit::Seconds => {
166                vortex_bail!("Unsupported TimeUnit {unit} for {}", ext_dtype.id())
167            }
168        },
169        TemporalMetadata::Time(unit) => match unit {
170            TimeUnit::Nanoseconds => {
171                value_to_arrow_scalar!(primitive.as_::<i64>(), Time64NanosecondArray)
172            }
173            TimeUnit::Microseconds => {
174                value_to_arrow_scalar!(primitive.as_::<i64>(), Time64MicrosecondArray)
175            }
176            TimeUnit::Milliseconds => {
177                value_to_arrow_scalar!(primitive.as_::<i32>(), Time32MillisecondArray)
178            }
179            TimeUnit::Seconds => {
180                value_to_arrow_scalar!(primitive.as_::<i32>(), Time32SecondArray)
181            }
182            TimeUnit::Days => {
183                vortex_bail!("Unsupported TimeUnit {unit} for {}", ext_dtype.id())
184            }
185        },
186    }
187}
188
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow_array::Array;
    use arrow_array::BinaryViewArray;
    use arrow_array::BooleanArray;
    use arrow_array::Datum;
    use arrow_array::Date32Array;
    use arrow_array::Date64Array;
    use arrow_array::Decimal128Array;
    use arrow_array::Decimal256Array;
    use arrow_array::Float16Array;
    use arrow_array::Float32Array;
    use arrow_array::Float64Array;
    use arrow_array::Int8Array;
    use arrow_array::Int16Array;
    use arrow_array::Int32Array;
    use arrow_array::Int64Array;
    use arrow_array::NullArray;
    use arrow_array::StringViewArray;
    use arrow_array::Time32MillisecondArray;
    use arrow_array::Time32SecondArray;
    use arrow_array::Time64MicrosecondArray;
    use arrow_array::Time64NanosecondArray;
    use arrow_array::TimestampMicrosecondArray;
    use arrow_array::TimestampMillisecondArray;
    use arrow_array::TimestampNanosecondArray;
    use arrow_array::TimestampSecondArray;
    use arrow_array::UInt8Array;
    use arrow_array::UInt16Array;
    use arrow_array::UInt32Array;
    use arrow_array::UInt64Array;
    use rstest::rstest;
    use vortex_dtype::DType;
    use vortex_dtype::DecimalDType;
    use vortex_dtype::NativeDType;
    use vortex_dtype::Nullability;
    use vortex_dtype::PType;
    use vortex_dtype::datetime::Date;
    use vortex_dtype::datetime::Time;
    use vortex_dtype::datetime::TimeUnit;
    use vortex_dtype::datetime::Timestamp;
    use vortex_dtype::datetime::TimestampOptions;
    use vortex_dtype::extension::ExtDTypeVTable;
    use vortex_error::VortexResult;
    use vortex_error::vortex_bail;

    use crate::scalar::DecimalValue;
    use crate::scalar::Scalar;

    /// Converts `scalar` to an Arrow datum, panicking if the conversion fails.
    fn to_datum(scalar: &Scalar) -> Arc<dyn Datum> {
        Arc::<dyn Datum>::try_from(scalar).expect("scalar should convert to an Arrow datum")
    }

    /// Returns the single-element Arrow array backing `datum`, downcast to `A`.
    fn single<A: Array + 'static>(datum: &Arc<dyn Datum>) -> &A {
        let (array, _) = datum.get();
        assert_eq!(array.len(), 1, "Arrow scalars are single-element arrays");
        array.as_any().downcast_ref::<A>().unwrap_or_else(|| {
            panic!(
                "expected {}, found {:?}",
                std::any::type_name::<A>(),
                array.data_type()
            )
        })
    }

    /// Asserts that Arrow compute kernels will treat `datum` as a scalar.
    fn assert_scalar(datum: &Arc<dyn Datum>) {
        assert!(datum.get().1, "datum should be flagged as an Arrow scalar");
    }

    /// Returns the stored value and timezone of a timestamp `datum` with the given `unit`.
    fn timestamp_parts(datum: &Arc<dyn Datum>, unit: TimeUnit) -> (i64, Option<&str>) {
        match unit {
            TimeUnit::Nanoseconds => {
                let array = single::<TimestampNanosecondArray>(datum);
                (array.value(0), array.timezone())
            }
            TimeUnit::Microseconds => {
                let array = single::<TimestampMicrosecondArray>(datum);
                (array.value(0), array.timezone())
            }
            TimeUnit::Milliseconds => {
                let array = single::<TimestampMillisecondArray>(datum);
                (array.value(0), array.timezone())
            }
            TimeUnit::Seconds => {
                let array = single::<TimestampSecondArray>(datum);
                (array.value(0), array.timezone())
            }
            _ => unreachable!(),
        }
    }

    #[test]
    fn test_null_scalar_to_arrow() {
        let datum = to_datum(&Scalar::null(DType::Null));
        single::<NullArray>(&datum);
    }

    #[test]
    fn test_bool_scalar_to_arrow() {
        let datum = to_datum(&Scalar::bool(true, Nullability::NonNullable));
        assert_scalar(&datum);
        assert!(single::<BooleanArray>(&datum).value(0));
    }

    #[test]
    fn test_null_bool_scalar_to_arrow() {
        let datum = to_datum(&Scalar::null(bool::dtype().as_nullable()));
        assert_scalar(&datum);
        assert!(single::<BooleanArray>(&datum).is_null(0));
    }

    #[test]
    fn test_primitive_u8_to_arrow() {
        let datum = to_datum(&Scalar::primitive(42u8, Nullability::NonNullable));
        assert_scalar(&datum);
        assert_eq!(single::<UInt8Array>(&datum).value(0), 42);
    }

    #[test]
    fn test_primitive_u16_to_arrow() {
        let datum = to_datum(&Scalar::primitive(1000u16, Nullability::NonNullable));
        assert_scalar(&datum);
        assert_eq!(single::<UInt16Array>(&datum).value(0), 1000);
    }

    #[test]
    fn test_primitive_u32_to_arrow() {
        let datum = to_datum(&Scalar::primitive(100000u32, Nullability::NonNullable));
        assert_scalar(&datum);
        assert_eq!(single::<UInt32Array>(&datum).value(0), 100000);
    }

    #[test]
    fn test_primitive_u64_to_arrow() {
        let datum = to_datum(&Scalar::primitive(10000000000u64, Nullability::NonNullable));
        assert_scalar(&datum);
        assert_eq!(single::<UInt64Array>(&datum).value(0), 10000000000);
    }

    #[test]
    fn test_primitive_i8_to_arrow() {
        let datum = to_datum(&Scalar::primitive(-42i8, Nullability::NonNullable));
        assert_scalar(&datum);
        assert_eq!(single::<Int8Array>(&datum).value(0), -42);
    }

    #[test]
    fn test_primitive_i16_to_arrow() {
        let datum = to_datum(&Scalar::primitive(-1000i16, Nullability::NonNullable));
        assert_scalar(&datum);
        assert_eq!(single::<Int16Array>(&datum).value(0), -1000);
    }

    #[test]
    fn test_primitive_i32_to_arrow() {
        let datum = to_datum(&Scalar::primitive(-100000i32, Nullability::NonNullable));
        assert_scalar(&datum);
        assert_eq!(single::<Int32Array>(&datum).value(0), -100000);
    }

    #[test]
    fn test_primitive_i64_to_arrow() {
        let datum = to_datum(&Scalar::primitive(-10000000000i64, Nullability::NonNullable));
        assert_scalar(&datum);
        assert_eq!(single::<Int64Array>(&datum).value(0), -10000000000);
    }

    #[test]
    fn test_primitive_f16_to_arrow() {
        use vortex_dtype::half::f16;

        let expected = f16::from_f32(1.234);
        let datum = to_datum(&Scalar::primitive(expected, Nullability::NonNullable));
        assert_scalar(&datum);
        // Compare bit patterns: the value must round-trip exactly.
        assert_eq!(
            single::<Float16Array>(&datum).value(0).to_bits(),
            expected.to_bits()
        );
    }

    #[test]
    fn test_primitive_f32_to_arrow() {
        let datum = to_datum(&Scalar::primitive(1.234f32, Nullability::NonNullable));
        assert_scalar(&datum);
        assert_eq!(
            single::<Float32Array>(&datum).value(0).to_bits(),
            1.234f32.to_bits()
        );
    }

    #[test]
    fn test_primitive_f64_to_arrow() {
        let datum = to_datum(&Scalar::primitive(
            1.234567890123f64,
            Nullability::NonNullable,
        ));
        assert_scalar(&datum);
        assert_eq!(
            single::<Float64Array>(&datum).value(0).to_bits(),
            1.234567890123f64.to_bits()
        );
    }

    #[test]
    fn test_null_primitive_to_arrow() {
        let datum = to_datum(&Scalar::null(i32::dtype().as_nullable()));
        assert_scalar(&datum);
        assert!(single::<Int32Array>(&datum).is_null(0));
    }

    #[test]
    fn test_utf8_scalar_to_arrow() {
        let datum = to_datum(&Scalar::utf8(
            "hello world".to_string(),
            Nullability::NonNullable,
        ));
        assert_scalar(&datum);
        assert_eq!(single::<StringViewArray>(&datum).value(0), "hello world");
    }

    #[test]
    fn test_null_utf8_scalar_to_arrow() {
        let datum = to_datum(&Scalar::null(String::dtype().as_nullable()));
        assert_scalar(&datum);
        assert!(single::<StringViewArray>(&datum).is_null(0));
    }

    #[test]
    fn test_binary_scalar_to_arrow() {
        let data = vec![1u8, 2, 3, 4, 5];
        let datum = to_datum(&Scalar::binary(data, Nullability::NonNullable));
        assert_scalar(&datum);
        assert_eq!(
            single::<BinaryViewArray>(&datum).value(0),
            [1u8, 2, 3, 4, 5].as_slice()
        );
    }

    #[test]
    fn test_null_binary_scalar_to_arrow() {
        let datum = to_datum(&Scalar::null(DType::Binary(Nullability::Nullable)));
        assert_scalar(&datum);
        assert!(single::<BinaryViewArray>(&datum).is_null(0));
    }

    #[test]
    fn test_decimal_scalars_to_arrow() {
        // Every sub-256-bit storage width maps to Decimal128 and keeps its unscaled value.
        let decimal_dtype = DecimalDType::new(5, 2);

        let narrow_cases = [
            (DecimalValue::I8(100), 100i128),
            (DecimalValue::I16(10000), 10000),
            (DecimalValue::I32(99999), 99999),
            (DecimalValue::I64(99999), 99999),
            (DecimalValue::I128(99999), 99999),
        ];
        for (value, expected) in narrow_cases {
            let scalar = Scalar::decimal(value, decimal_dtype, Nullability::NonNullable);
            let datum = to_datum(&scalar);
            assert_scalar(&datum);
            assert_eq!(single::<Decimal128Array>(&datum).value(0), expected);
        }

        // 256-bit storage maps to Decimal256.
        use vortex_dtype::i256;
        let value_i256 = i256::from_i128(99999);
        let scalar_i256 = Scalar::decimal(
            DecimalValue::I256(value_i256),
            decimal_dtype,
            Nullability::NonNullable,
        );
        let datum = to_datum(&scalar_i256);
        assert_scalar(&datum);
        assert_eq!(
            single::<Decimal256Array>(&datum).value(0).to_string(),
            "99999"
        );
    }

    #[test]
    fn test_null_decimal_to_arrow() {
        let decimal_dtype = DecimalDType::new(10, 2);
        let scalar = Scalar::null(DType::Decimal(decimal_dtype, Nullability::Nullable));
        let datum = to_datum(&scalar);
        assert_scalar(&datum);
        assert!(single::<Decimal128Array>(&datum).is_null(0));
    }

    #[test]
    #[should_panic(expected = "struct scalar conversion")]
    fn test_struct_scalar_to_arrow_todo() {
        use vortex_dtype::FieldDType;
        use vortex_dtype::StructFields;

        let struct_dtype = DType::Struct(
            StructFields::from_iter([(
                "field1",
                FieldDType::from(DType::Primitive(PType::I32, Nullability::NonNullable)),
            )]),
            Nullability::NonNullable,
        );

        let struct_scalar = Scalar::struct_(
            struct_dtype,
            vec![Scalar::primitive(42i32, Nullability::NonNullable)],
        );
        Arc::<dyn Datum>::try_from(&struct_scalar).unwrap();
    }

    #[test]
    #[should_panic(expected = "list scalar conversion")]
    fn test_list_scalar_to_arrow_todo() {
        let element_dtype = Arc::new(DType::Primitive(PType::I32, Nullability::NonNullable));
        let list_scalar = Scalar::list(
            element_dtype,
            vec![
                Scalar::primitive(1i32, Nullability::NonNullable),
                Scalar::primitive(2i32, Nullability::NonNullable),
            ],
            Nullability::NonNullable,
        );

        Arc::<dyn Datum>::try_from(&list_scalar).unwrap();
    }

    #[test]
    #[should_panic(expected = "Cannot convert extension scalar")]
    fn test_non_temporal_extension_to_arrow_todo() {
        use vortex_dtype::ExtID;

        #[derive(Clone, Debug, Default, PartialEq, Eq, Hash)]
        struct SomeExt;
        impl ExtDTypeVTable for SomeExt {
            type Metadata = String;

            fn id(&self) -> ExtID {
                ExtID::new_ref("some_ext")
            }

            fn serialize(&self, _options: &Self::Metadata) -> VortexResult<Vec<u8>> {
                vortex_bail!("not implemented")
            }

            fn deserialize(&self, _data: &[u8]) -> VortexResult<Self::Metadata> {
                vortex_bail!("not implemented")
            }

            fn validate_dtype(
                &self,
                _options: &Self::Metadata,
                _storage_dtype: &DType,
            ) -> VortexResult<()> {
                Ok(())
            }
        }

        let scalar = Scalar::extension::<SomeExt>(
            "".into(),
            Scalar::primitive(42i32, Nullability::NonNullable),
        );

        Arc::<dyn Datum>::try_from(&scalar).unwrap();
    }

    #[rstest]
    #[case(TimeUnit::Nanoseconds, PType::I64, 123456789i64)]
    #[case(TimeUnit::Microseconds, PType::I64, 123456789i64)]
    #[case(TimeUnit::Milliseconds, PType::I32, 123456i64)]
    #[case(TimeUnit::Seconds, PType::I32, 1234i64)]
    fn test_temporal_time_to_arrow(
        #[case] time_unit: TimeUnit,
        #[case] ptype: PType,
        #[case] value: i64,
    ) {
        let scalar = Scalar::extension::<Time>(
            time_unit,
            match ptype {
                PType::I32 => {
                    let i32_value = i32::try_from(value).expect("test value should fit in i32");
                    Scalar::primitive(i32_value, Nullability::NonNullable)
                }
                PType::I64 => Scalar::primitive(value, Nullability::NonNullable),
                _ => unreachable!(),
            },
        );

        let datum = to_datum(&scalar);
        assert_scalar(&datum);
        let stored = match time_unit {
            TimeUnit::Nanoseconds => single::<Time64NanosecondArray>(&datum).value(0),
            TimeUnit::Microseconds => single::<Time64MicrosecondArray>(&datum).value(0),
            TimeUnit::Milliseconds => {
                i64::from(single::<Time32MillisecondArray>(&datum).value(0))
            }
            TimeUnit::Seconds => i64::from(single::<Time32SecondArray>(&datum).value(0)),
            _ => unreachable!(),
        };
        assert_eq!(stored, value);
    }

    #[rstest]
    #[case(TimeUnit::Milliseconds, PType::I64, 1234567890000i64)]
    #[case(TimeUnit::Days, PType::I32, 19000i64)]
    fn test_temporal_date_to_arrow(
        #[case] time_unit: TimeUnit,
        #[case] ptype: PType,
        #[case] value: i64,
    ) {
        let scalar = Scalar::extension::<Date>(
            time_unit,
            match ptype {
                PType::I32 => {
                    let i32_value = i32::try_from(value).expect("test value should fit in i32");
                    Scalar::primitive(i32_value, Nullability::NonNullable)
                }
                PType::I64 => Scalar::primitive(value, Nullability::NonNullable),
                _ => unreachable!(),
            },
        );

        let datum = to_datum(&scalar);
        assert_scalar(&datum);
        let stored = match time_unit {
            TimeUnit::Milliseconds => single::<Date64Array>(&datum).value(0),
            TimeUnit::Days => i64::from(single::<Date32Array>(&datum).value(0)),
            _ => unreachable!(),
        };
        assert_eq!(stored, value);
    }

    #[rstest]
    #[case(TimeUnit::Nanoseconds, 1234567890000000000i64)]
    #[case(TimeUnit::Microseconds, 1234567890000000i64)]
    #[case(TimeUnit::Milliseconds, 1234567890000i64)]
    #[case(TimeUnit::Seconds, 1234567890i64)]
    fn test_temporal_timestamp_to_arrow(#[case] time_unit: TimeUnit, #[case] value: i64) {
        let scalar = Scalar::extension::<Timestamp>(
            TimestampOptions {
                unit: time_unit,
                tz: None,
            },
            Scalar::primitive(value, Nullability::NonNullable),
        );

        let datum = to_datum(&scalar);
        assert_scalar(&datum);
        assert_eq!(timestamp_parts(&datum, time_unit), (value, None));
    }

    #[rstest]
    #[case(TimeUnit::Nanoseconds, "UTC", 1234567890000000000i64)]
    #[case(TimeUnit::Microseconds, "EST", 1234567890000000i64)]
    #[case(TimeUnit::Milliseconds, "ABC", 1234567890000i64)]
    #[case(TimeUnit::Seconds, "UTC", 1234567890i64)]
    fn test_temporal_timestamp_tz_to_arrow(
        #[case] time_unit: TimeUnit,
        #[case] tz: &str,
        #[case] value: i64,
    ) {
        let scalar = Scalar::extension::<Timestamp>(
            TimestampOptions {
                unit: time_unit,
                tz: Some(tz.into()),
            },
            Scalar::primitive(value, Nullability::NonNullable),
        );

        let datum = to_datum(&scalar);
        assert_scalar(&datum);
        assert_eq!(timestamp_parts(&datum, time_unit), (value, Some(tz)));
    }

    #[test]
    fn test_temporal_with_null_value() {
        let scalar = Scalar::extension::<Time>(
            TimeUnit::Milliseconds,
            Scalar::null(DType::Primitive(PType::I32, Nullability::Nullable)),
        );

        let datum = to_datum(&scalar);
        assert_scalar(&datum);
        assert!(single::<Time32MillisecondArray>(&datum).is_null(0));
    }

    #[test]
    #[should_panic(expected = "DType utf8 is not a primitive type")]
    fn test_temporal_non_primitive_storage_error() {
        let _scalar = Scalar::extension::<Time>(
            TimeUnit::Nanoseconds,
            Scalar::utf8("not a timestamp", Nullability::NonNullable),
        );
    }
}
589            TimeUnit::Nanoseconds,
590            Scalar::utf8("not a timestamp", Nullability::NonNullable),
591        );
592    }
593}