1use std::sync::Arc;
7
8use arrow_array::Scalar as ArrowScalar;
9use arrow_array::*;
10use vortex_error::VortexError;
11use vortex_error::vortex_bail;
12use vortex_error::vortex_err;
13
14use crate::dtype::DType;
15use crate::dtype::PType;
16use crate::extension::datetime::AnyTemporal;
17use crate::extension::datetime::TemporalMetadata;
18use crate::extension::datetime::TimeUnit;
19use crate::scalar::BinaryScalar;
20use crate::scalar::BoolScalar;
21use crate::scalar::DecimalScalar;
22use crate::scalar::DecimalValue;
23use crate::scalar::ExtScalar;
24use crate::scalar::PrimitiveScalar;
25use crate::scalar::Scalar;
26use crate::scalar::Utf8Scalar;
27
/// Number of elements in the one-row Arrow arrays that back every scalar datum built in this
/// module (used for the null placeholders passed to `new_null` / `NullArray::new`).
const SCALAR_ARRAY_LEN: usize = 1;
30
/// Builds a one-element Arrow scalar datum of array type `$AR` from an `Option` of its native
/// value and wraps it as `Ok(Arc<_>)`, so a conversion function can return the expansion
/// directly.
///
/// `Some(v)` becomes `$AR::new_scalar(v)`; `None` becomes a null `$AR` of length
/// `SCALAR_ARRAY_LEN` wrapped in `arrow_array::Scalar`.
macro_rules! value_to_arrow_scalar {
    ($V:expr, $AR:ty) => {{
        let datum = match $V {
            Some(v) => <$AR>::new_scalar(v),
            None => arrow_array::Scalar::new(<$AR>::new_null(SCALAR_ARRAY_LEN)),
        };
        Ok(std::sync::Arc::new(datum))
    }};
}
40
/// Builds a one-element Arrow timestamp scalar of array type `$AR` from an `Option<i64>` value,
/// attaching the optional timezone `$TZ`, and wraps it as `Ok(Arc<_>)`.
///
/// A `None` value yields a null element of length `SCALAR_ARRAY_LEN`; the timezone is applied
/// in both cases.
macro_rules! timestamp_to_arrow_scalar {
    ($V:expr, $TZ:expr, $AR:ty) => {{
        let values = $V.map_or_else(
            || <$AR>::new_null(SCALAR_ARRAY_LEN),
            |v| <$AR>::new_scalar(v).into_inner(),
        );
        Ok(Arc::new(ArrowScalar::new(values.with_timezone_opt($TZ))))
    }};
}
52
53impl TryFrom<&Scalar> for Arc<dyn Datum> {
54 type Error = VortexError;
55
56 fn try_from(value: &Scalar) -> Result<Arc<dyn Datum>, Self::Error> {
57 match value.dtype() {
58 DType::Null => Ok(Arc::new(NullArray::new(SCALAR_ARRAY_LEN))),
59 DType::Bool(_) => bool_to_arrow(value.as_bool()),
60 DType::Primitive(..) => primitive_to_arrow(value.as_primitive()),
61 DType::Decimal(..) => decimal_to_arrow(value.as_decimal()),
62 DType::Utf8(_) => utf8_to_arrow(value.as_utf8()),
63 DType::Binary(_) => binary_to_arrow(value.as_binary()),
64 DType::Struct(..) => unimplemented!("struct scalar conversion"),
65 DType::List(..) => unimplemented!("list scalar conversion"),
66 DType::FixedSizeList(..) => unimplemented!("fixed-size list scalar conversion"),
67 DType::Extension(..) => extension_to_arrow(value.as_extension()),
68 DType::Variant(_) => unimplemented!("Variant scalar conversion"),
69 }
70 }
71}
72
73fn bool_to_arrow(scalar: BoolScalar<'_>) -> Result<Arc<dyn Datum>, VortexError> {
75 value_to_arrow_scalar!(scalar.value(), BooleanArray)
76}
77
78fn primitive_to_arrow(scalar: PrimitiveScalar<'_>) -> Result<Arc<dyn Datum>, VortexError> {
80 match scalar.ptype() {
81 PType::U8 => value_to_arrow_scalar!(scalar.typed_value(), UInt8Array),
82 PType::U16 => value_to_arrow_scalar!(scalar.typed_value(), UInt16Array),
83 PType::U32 => value_to_arrow_scalar!(scalar.typed_value(), UInt32Array),
84 PType::U64 => value_to_arrow_scalar!(scalar.typed_value(), UInt64Array),
85 PType::I8 => value_to_arrow_scalar!(scalar.typed_value(), Int8Array),
86 PType::I16 => value_to_arrow_scalar!(scalar.typed_value(), Int16Array),
87 PType::I32 => value_to_arrow_scalar!(scalar.typed_value(), Int32Array),
88 PType::I64 => value_to_arrow_scalar!(scalar.typed_value(), Int64Array),
89 PType::F16 => value_to_arrow_scalar!(scalar.typed_value(), Float16Array),
90 PType::F32 => value_to_arrow_scalar!(scalar.typed_value(), Float32Array),
91 PType::F64 => value_to_arrow_scalar!(scalar.typed_value(), Float64Array),
92 }
93}
94
95fn decimal_to_arrow(scalar: DecimalScalar<'_>) -> Result<Arc<dyn Datum>, VortexError> {
97 match scalar.decimal_value() {
99 Some(DecimalValue::I8(v)) => Ok(Arc::new(Decimal128Array::new_scalar(v as i128))),
100 Some(DecimalValue::I16(v)) => Ok(Arc::new(Decimal128Array::new_scalar(v as i128))),
101 Some(DecimalValue::I32(v)) => Ok(Arc::new(Decimal128Array::new_scalar(v as i128))),
102 Some(DecimalValue::I64(v)) => Ok(Arc::new(Decimal128Array::new_scalar(v as i128))),
103 Some(DecimalValue::I128(v128)) => Ok(Arc::new(Decimal128Array::new_scalar(v128))),
104 Some(DecimalValue::I256(v256)) => Ok(Arc::new(Decimal256Array::new_scalar(v256.into()))),
105 None => Ok(Arc::new(arrow_array::Scalar::new(
106 Decimal128Array::new_null(SCALAR_ARRAY_LEN),
107 ))),
108 }
109}
110
111fn utf8_to_arrow(scalar: Utf8Scalar<'_>) -> Result<Arc<dyn Datum>, VortexError> {
113 value_to_arrow_scalar!(scalar.value(), StringViewArray)
114}
115
116fn binary_to_arrow(scalar: BinaryScalar<'_>) -> Result<Arc<dyn Datum>, VortexError> {
118 value_to_arrow_scalar!(scalar.value(), BinaryViewArray)
119}
120
121fn extension_to_arrow(scalar: ExtScalar<'_>) -> Result<Arc<dyn Datum>, VortexError> {
125 let ext_dtype = scalar.ext_dtype();
126 let Some(temporal) = ext_dtype.metadata_opt::<AnyTemporal>() else {
127 vortex_bail!(
128 "Cannot convert extension scalar {} to Arrow",
129 ext_dtype.id()
130 )
131 };
132
133 let storage_scalar = scalar.to_storage_scalar();
134 let primitive = storage_scalar
135 .as_primitive_opt()
136 .ok_or_else(|| vortex_err!("Expected primitive scalar"))?;
137
138 match temporal {
139 TemporalMetadata::Timestamp(unit, tz) => {
140 let value = primitive.as_::<i64>();
141 match unit {
142 TimeUnit::Nanoseconds => {
143 timestamp_to_arrow_scalar!(value, tz.clone(), TimestampNanosecondArray)
144 }
145 TimeUnit::Microseconds => {
146 timestamp_to_arrow_scalar!(value, tz.clone(), TimestampMicrosecondArray)
147 }
148 TimeUnit::Milliseconds => {
149 timestamp_to_arrow_scalar!(value, tz.clone(), TimestampMillisecondArray)
150 }
151 TimeUnit::Seconds => {
152 timestamp_to_arrow_scalar!(value, tz.clone(), TimestampSecondArray)
153 }
154 TimeUnit::Days => {
155 vortex_bail!("Unsupported TimeUnit {unit} for {}", ext_dtype.id())
156 }
157 }
158 }
159 TemporalMetadata::Date(unit) => match unit {
160 TimeUnit::Milliseconds => {
161 value_to_arrow_scalar!(primitive.as_::<i64>(), Date64Array)
162 }
163 TimeUnit::Days => {
164 value_to_arrow_scalar!(primitive.as_::<i32>(), Date32Array)
165 }
166 TimeUnit::Nanoseconds | TimeUnit::Microseconds | TimeUnit::Seconds => {
167 vortex_bail!("Unsupported TimeUnit {unit} for {}", ext_dtype.id())
168 }
169 },
170 TemporalMetadata::Time(unit) => match unit {
171 TimeUnit::Nanoseconds => {
172 value_to_arrow_scalar!(primitive.as_::<i64>(), Time64NanosecondArray)
173 }
174 TimeUnit::Microseconds => {
175 value_to_arrow_scalar!(primitive.as_::<i64>(), Time64MicrosecondArray)
176 }
177 TimeUnit::Milliseconds => {
178 value_to_arrow_scalar!(primitive.as_::<i32>(), Time32MillisecondArray)
179 }
180 TimeUnit::Seconds => {
181 value_to_arrow_scalar!(primitive.as_::<i32>(), Time32SecondArray)
182 }
183 TimeUnit::Days => {
184 vortex_bail!("Unsupported TimeUnit {unit} for {}", ext_dtype.id())
185 }
186 },
187 }
188}
189
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow_array::Array;
    use arrow_array::Datum;
    use rstest::rstest;
    use vortex_error::VortexResult;
    use vortex_error::vortex_bail;

    use crate::dtype::DType;
    use crate::dtype::DecimalDType;
    use crate::dtype::FieldDType;
    use crate::dtype::NativeDType;
    use crate::dtype::Nullability;
    use crate::dtype::PType;
    use crate::dtype::StructFields;
    use crate::dtype::extension::ExtDType;
    use crate::dtype::extension::ExtId;
    use crate::dtype::extension::ExtVTable;
    use crate::dtype::i256;
    use crate::extension::datetime::Date;
    use crate::extension::datetime::Time;
    use crate::extension::datetime::TimeUnit;
    use crate::extension::datetime::Timestamp;
    use crate::extension::datetime::TimestampOptions;
    use crate::scalar::DecimalValue;
    use crate::scalar::Scalar;
    use crate::scalar::ScalarValue;

    /// Converts `scalar` into an Arrow datum, panicking with the conversion error's message on
    /// failure so that a broken conversion is diagnosable from the test output (a bare
    /// `assert!(result.is_ok())` hides the error).
    fn to_datum(scalar: &Scalar) -> Arc<dyn Datum> {
        Arc::<dyn Datum>::try_from(scalar)
            .unwrap_or_else(|err| panic!("scalar to Arrow conversion failed: {err}"))
    }

    /// Asserts that `scalar` converts into a one-element Arrow datum flagged as a scalar, whose
    /// only slot is null exactly when `expect_null` is `true`.
    fn assert_scalar_datum(scalar: &Scalar, expect_null: bool) {
        let datum = to_datum(scalar);
        let (array, is_scalar) = datum.get();
        assert!(is_scalar, "expected an Arrow scalar datum, got an array");
        assert_eq!(array.len(), 1);
        assert_eq!(array.is_null(0), expect_null);
    }

    #[test]
    fn test_null_scalar_to_arrow() {
        let scalar = Scalar::null(DType::Null);
        let datum = to_datum(&scalar);
        assert_eq!(datum.get().0.len(), 1);
    }

    #[test]
    fn test_bool_scalar_to_arrow() {
        let scalar = Scalar::bool(true, Nullability::NonNullable);
        assert_scalar_datum(&scalar, false);
    }

    #[test]
    fn test_null_bool_scalar_to_arrow() {
        let scalar = Scalar::null(bool::dtype().as_nullable());
        assert_scalar_datum(&scalar, true);
    }

    #[test]
    fn test_primitive_u8_to_arrow() {
        let scalar = Scalar::primitive(42u8, Nullability::NonNullable);
        assert_scalar_datum(&scalar, false);
    }

    #[test]
    fn test_primitive_u16_to_arrow() {
        let scalar = Scalar::primitive(1000u16, Nullability::NonNullable);
        assert_scalar_datum(&scalar, false);
    }

    #[test]
    fn test_primitive_u32_to_arrow() {
        let scalar = Scalar::primitive(100000u32, Nullability::NonNullable);
        assert_scalar_datum(&scalar, false);
    }

    #[test]
    fn test_primitive_u64_to_arrow() {
        let scalar = Scalar::primitive(10000000000u64, Nullability::NonNullable);
        assert_scalar_datum(&scalar, false);
    }

    #[test]
    fn test_primitive_i8_to_arrow() {
        let scalar = Scalar::primitive(-42i8, Nullability::NonNullable);
        assert_scalar_datum(&scalar, false);
    }

    #[test]
    fn test_primitive_i16_to_arrow() {
        let scalar = Scalar::primitive(-1000i16, Nullability::NonNullable);
        assert_scalar_datum(&scalar, false);
    }

    #[test]
    fn test_primitive_i32_to_arrow() {
        let scalar = Scalar::primitive(-100000i32, Nullability::NonNullable);
        assert_scalar_datum(&scalar, false);
    }

    #[test]
    fn test_primitive_i64_to_arrow() {
        let scalar = Scalar::primitive(-10000000000i64, Nullability::NonNullable);
        assert_scalar_datum(&scalar, false);
    }

    #[test]
    fn test_primitive_f16_to_arrow() {
        use crate::dtype::half::f16;

        let scalar = Scalar::primitive(f16::from_f32(1.234), Nullability::NonNullable);
        assert_scalar_datum(&scalar, false);
    }

    #[test]
    fn test_primitive_f32_to_arrow() {
        let scalar = Scalar::primitive(1.234f32, Nullability::NonNullable);
        assert_scalar_datum(&scalar, false);
    }

    #[test]
    fn test_primitive_f64_to_arrow() {
        let scalar = Scalar::primitive(1.234567890123f64, Nullability::NonNullable);
        assert_scalar_datum(&scalar, false);
    }

    #[test]
    fn test_null_primitive_to_arrow() {
        let scalar = Scalar::null(i32::dtype().as_nullable());
        assert_scalar_datum(&scalar, true);
    }

    #[test]
    fn test_utf8_scalar_to_arrow() {
        let scalar = Scalar::utf8("hello world".to_string(), Nullability::NonNullable);
        assert_scalar_datum(&scalar, false);
    }

    #[test]
    fn test_null_utf8_scalar_to_arrow() {
        let scalar = Scalar::null(String::dtype().as_nullable());
        assert_scalar_datum(&scalar, true);
    }

    #[test]
    fn test_binary_scalar_to_arrow() {
        let data = vec![1u8, 2, 3, 4, 5];
        let scalar = Scalar::binary(data, Nullability::NonNullable);
        assert_scalar_datum(&scalar, false);
    }

    #[test]
    fn test_null_binary_scalar_to_arrow() {
        let scalar = Scalar::null(DType::Binary(Nullability::Nullable));
        assert_scalar_datum(&scalar, true);
    }

    #[test]
    fn test_decimal_scalars_to_arrow() {
        let decimal_dtype = DecimalDType::new(5, 2);

        // One value per Vortex decimal storage width.
        let values = [
            DecimalValue::I8(100),
            DecimalValue::I16(10000),
            DecimalValue::I32(99999),
            DecimalValue::I64(99999),
            DecimalValue::I128(99999),
            DecimalValue::I256(i256::from_i128(99999)),
        ];

        for value in values {
            let scalar = Scalar::decimal(value, decimal_dtype, Nullability::NonNullable);
            assert_scalar_datum(&scalar, false);
        }
    }

    #[test]
    fn test_null_decimal_to_arrow() {
        let decimal_dtype = DecimalDType::new(10, 2);
        let scalar = Scalar::null(DType::Decimal(decimal_dtype, Nullability::Nullable));
        assert_scalar_datum(&scalar, true);
    }

    #[test]
    #[should_panic(expected = "struct scalar conversion")]
    fn test_struct_scalar_to_arrow_todo() {
        let struct_dtype = DType::Struct(
            StructFields::from_iter([(
                "field1",
                FieldDType::from(DType::Primitive(PType::I32, Nullability::NonNullable)),
            )]),
            Nullability::NonNullable,
        );

        let struct_scalar = Scalar::struct_(
            struct_dtype,
            vec![Scalar::primitive(42i32, Nullability::NonNullable)],
        );
        Arc::<dyn Datum>::try_from(&struct_scalar).unwrap();
    }

    #[test]
    #[should_panic(expected = "list scalar conversion")]
    fn test_list_scalar_to_arrow_todo() {
        let element_dtype = Arc::new(DType::Primitive(PType::I32, Nullability::NonNullable));
        let list_scalar = Scalar::list(
            element_dtype,
            vec![
                Scalar::primitive(1i32, Nullability::NonNullable),
                Scalar::primitive(2i32, Nullability::NonNullable),
            ],
            Nullability::NonNullable,
        );

        Arc::<dyn Datum>::try_from(&list_scalar).unwrap();
    }

    #[test]
    #[should_panic(expected = "Cannot convert extension scalar")]
    fn test_non_temporal_extension_to_arrow_todo() {
        #[derive(Clone, Debug, Default, PartialEq, Eq, Hash)]
        struct SomeExt;
        impl ExtVTable for SomeExt {
            type Metadata = String;
            type NativeValue<'a> = &'a str;

            fn id(&self) -> ExtId {
                ExtId::new_ref("some_ext")
            }

            fn serialize_metadata(&self, _options: &Self::Metadata) -> VortexResult<Vec<u8>> {
                vortex_bail!("not implemented")
            }

            fn deserialize_metadata(&self, _data: &[u8]) -> VortexResult<Self::Metadata> {
                vortex_bail!("not implemented")
            }

            fn validate_dtype(_ext_dtype: &ExtDType<Self>) -> VortexResult<()> {
                Ok(())
            }

            fn unpack_native<'a>(
                _ext_dtype: &'a ExtDType<Self>,
                _storage_value: &'a ScalarValue,
            ) -> VortexResult<Self::NativeValue<'a>> {
                Ok("")
            }
        }

        let scalar = Scalar::extension::<SomeExt>(
            "".into(),
            Scalar::primitive(42i32, Nullability::NonNullable),
        );

        Arc::<dyn Datum>::try_from(&scalar).unwrap();
    }

    #[rstest]
    #[case(TimeUnit::Nanoseconds, PType::I64, 123456789i64)]
    #[case(TimeUnit::Microseconds, PType::I64, 123456789i64)]
    #[case(TimeUnit::Milliseconds, PType::I32, 123456i64)]
    #[case(TimeUnit::Seconds, PType::I32, 1234i64)]
    fn test_temporal_time_to_arrow(
        #[case] time_unit: TimeUnit,
        #[case] ptype: PType,
        #[case] value: i64,
    ) {
        let scalar = Scalar::extension::<Time>(
            time_unit,
            match ptype {
                PType::I32 => {
                    let i32_value = i32::try_from(value).expect("test value should fit in i32");
                    Scalar::primitive(i32_value, Nullability::NonNullable)
                }
                PType::I64 => Scalar::primitive(value, Nullability::NonNullable),
                _ => unreachable!(),
            },
        );

        assert_scalar_datum(&scalar, false);
    }

    #[rstest]
    #[case(TimeUnit::Milliseconds, PType::I64, 1234567890000i64)]
    #[case(TimeUnit::Days, PType::I32, 19000i64)]
    fn test_temporal_date_to_arrow(
        #[case] time_unit: TimeUnit,
        #[case] ptype: PType,
        #[case] value: i64,
    ) {
        let scalar = Scalar::extension::<Date>(
            time_unit,
            match ptype {
                PType::I32 => {
                    let i32_value = i32::try_from(value).expect("test value should fit in i32");
                    Scalar::primitive(i32_value, Nullability::NonNullable)
                }
                PType::I64 => Scalar::primitive(value, Nullability::NonNullable),
                _ => unreachable!(),
            },
        );

        assert_scalar_datum(&scalar, false);
    }

    #[rstest]
    #[case(TimeUnit::Nanoseconds, 1234567890000000000i64)]
    #[case(TimeUnit::Microseconds, 1234567890000000i64)]
    #[case(TimeUnit::Milliseconds, 1234567890000i64)]
    #[case(TimeUnit::Seconds, 1234567890i64)]
    fn test_temporal_timestamp_to_arrow(#[case] time_unit: TimeUnit, #[case] value: i64) {
        let scalar = Scalar::extension::<Timestamp>(
            TimestampOptions {
                unit: time_unit,
                tz: None,
            },
            Scalar::primitive(value, Nullability::NonNullable),
        );

        assert_scalar_datum(&scalar, false);
    }

    #[rstest]
    #[case(TimeUnit::Nanoseconds, "UTC", 1234567890000000000i64)]
    #[case(TimeUnit::Microseconds, "EST", 1234567890000000i64)]
    #[case(TimeUnit::Microseconds, "Asia/Qatar", 1234567890000000i64)]
    #[case(TimeUnit::Microseconds, "Australia/Sydney", 1234567890000000i64)]
    #[case(TimeUnit::Milliseconds, "HST", 1234567890000i64)]
    #[case(TimeUnit::Seconds, "GMT", 1234567890i64)]
    fn test_temporal_timestamp_tz_to_arrow(
        #[case] time_unit: TimeUnit,
        #[case] tz: &str,
        #[case] value: i64,
    ) {
        let scalar = Scalar::extension::<Timestamp>(
            TimestampOptions {
                unit: time_unit,
                tz: Some(tz.into()),
            },
            Scalar::primitive(value, Nullability::NonNullable),
        );

        assert_scalar_datum(&scalar, false);
    }

    #[test]
    fn test_temporal_with_null_value() {
        let scalar = Scalar::extension::<Time>(
            TimeUnit::Milliseconds,
            Scalar::null(DType::Primitive(PType::I32, Nullability::Nullable)),
        );

        assert_scalar_datum(&scalar, true);
    }

    #[test]
    #[should_panic(expected = "DType utf8 is not a primitive type")]
    fn test_temporal_non_primitive_storage_error() {
        let _scalar = Scalar::extension::<Time>(
            TimeUnit::Nanoseconds,
            Scalar::utf8("not a timestamp", Nullability::NonNullable),
        );
    }
}