1use std::sync::Arc;
7
8use arrow_array::Scalar as ArrowScalar;
9use arrow_array::*;
10use vortex_error::VortexError;
11use vortex_error::vortex_bail;
12use vortex_error::vortex_err;
13
14use crate::dtype::DType;
15use crate::dtype::PType;
16use crate::extension::datetime::AnyTemporal;
17use crate::extension::datetime::TemporalMetadata;
18use crate::extension::datetime::TimeUnit;
19use crate::scalar::BinaryScalar;
20use crate::scalar::BoolScalar;
21use crate::scalar::DecimalScalar;
22use crate::scalar::DecimalValue;
23use crate::scalar::ExtScalar;
24use crate::scalar::PrimitiveScalar;
25use crate::scalar::Scalar;
26use crate::scalar::Utf8Scalar;
27
/// Length of the Arrow arrays built in this file to back a scalar: the `NullArray` for
/// `DType::Null` and the all-null arrays created with `new_null` for missing values.
const SCALAR_ARRAY_LEN: usize = 1;
30
/// Wraps an optional native value as an Arrow scalar backed by the array type `$AR`.
///
/// * `Some(v)` is turned into `<$AR>::new_scalar(v)`.
/// * `None` is turned into a scalar over an all-null `$AR` of `SCALAR_ARRAY_LEN` elements.
///
/// The expansion evaluates to `Ok(Arc<..>)`, so it is meant to be used as the tail
/// expression of a function returning `Result<Arc<dyn Datum>, _>`.
macro_rules! value_to_arrow_scalar {
    ($V:expr, $AR:ty) => {
        Ok(std::sync::Arc::new(match $V {
            Some(present) => <$AR>::new_scalar(present),
            None => arrow_array::Scalar::new(<$AR>::new_null(SCALAR_ARRAY_LEN)),
        }))
    };
}
40
/// Builds an Arrow timestamp scalar of array type `$AR` carrying the optional timezone `$TZ`.
///
/// The one-element array is created first (from the value, or all-null when the value is
/// `None`), the timezone is attached with `with_timezone_opt`, and only then is the array
/// wrapped back into an Arrow scalar. The expansion evaluates to `Ok(Arc<..>)`.
macro_rules! timestamp_to_arrow_scalar {
    ($V:expr, $TZ:expr, $AR:ty) => {{
        let without_tz = $V.map_or_else(
            || <$AR>::new_null(SCALAR_ARRAY_LEN),
            |v| <$AR>::new_scalar(v).into_inner(),
        );
        Ok(Arc::new(ArrowScalar::new(without_tz.with_timezone_opt($TZ))))
    }};
}
52
53impl TryFrom<&Scalar> for Arc<dyn Datum> {
54 type Error = VortexError;
55
56 fn try_from(value: &Scalar) -> Result<Arc<dyn Datum>, Self::Error> {
57 match value.dtype() {
58 DType::Null => Ok(Arc::new(NullArray::new(SCALAR_ARRAY_LEN))),
59 DType::Bool(_) => bool_to_arrow(value.as_bool()),
60 DType::Primitive(..) => primitive_to_arrow(value.as_primitive()),
61 DType::Decimal(..) => decimal_to_arrow(value.as_decimal()),
62 DType::Utf8(_) => utf8_to_arrow(value.as_utf8()),
63 DType::Binary(_) => binary_to_arrow(value.as_binary()),
64 DType::List(..) => unimplemented!("list scalar conversion"),
65 DType::FixedSizeList(..) => unimplemented!("fixed-size list scalar conversion"),
66 DType::Struct(..) => unimplemented!("struct scalar conversion"),
67 DType::Union(..) => unimplemented!("union scalar conversion"),
68 DType::Variant(_) => unimplemented!("Variant scalar conversion"),
69 DType::Extension(..) => extension_to_arrow(value.as_extension()),
70 }
71 }
72}
73
74fn bool_to_arrow(scalar: BoolScalar<'_>) -> Result<Arc<dyn Datum>, VortexError> {
76 value_to_arrow_scalar!(scalar.value(), BooleanArray)
77}
78
79fn primitive_to_arrow(scalar: PrimitiveScalar<'_>) -> Result<Arc<dyn Datum>, VortexError> {
81 match scalar.ptype() {
82 PType::U8 => value_to_arrow_scalar!(scalar.typed_value(), UInt8Array),
83 PType::U16 => value_to_arrow_scalar!(scalar.typed_value(), UInt16Array),
84 PType::U32 => value_to_arrow_scalar!(scalar.typed_value(), UInt32Array),
85 PType::U64 => value_to_arrow_scalar!(scalar.typed_value(), UInt64Array),
86 PType::I8 => value_to_arrow_scalar!(scalar.typed_value(), Int8Array),
87 PType::I16 => value_to_arrow_scalar!(scalar.typed_value(), Int16Array),
88 PType::I32 => value_to_arrow_scalar!(scalar.typed_value(), Int32Array),
89 PType::I64 => value_to_arrow_scalar!(scalar.typed_value(), Int64Array),
90 PType::F16 => value_to_arrow_scalar!(scalar.typed_value(), Float16Array),
91 PType::F32 => value_to_arrow_scalar!(scalar.typed_value(), Float32Array),
92 PType::F64 => value_to_arrow_scalar!(scalar.typed_value(), Float64Array),
93 }
94}
95
96fn decimal_to_arrow(scalar: DecimalScalar<'_>) -> Result<Arc<dyn Datum>, VortexError> {
98 match scalar.decimal_value() {
100 Some(DecimalValue::I8(v)) => Ok(Arc::new(Decimal128Array::new_scalar(v as i128))),
101 Some(DecimalValue::I16(v)) => Ok(Arc::new(Decimal128Array::new_scalar(v as i128))),
102 Some(DecimalValue::I32(v)) => Ok(Arc::new(Decimal128Array::new_scalar(v as i128))),
103 Some(DecimalValue::I64(v)) => Ok(Arc::new(Decimal128Array::new_scalar(v as i128))),
104 Some(DecimalValue::I128(v128)) => Ok(Arc::new(Decimal128Array::new_scalar(v128))),
105 Some(DecimalValue::I256(v256)) => Ok(Arc::new(Decimal256Array::new_scalar(v256.into()))),
106 None => Ok(Arc::new(arrow_array::Scalar::new(
107 Decimal128Array::new_null(SCALAR_ARRAY_LEN),
108 ))),
109 }
110}
111
112fn utf8_to_arrow(scalar: Utf8Scalar<'_>) -> Result<Arc<dyn Datum>, VortexError> {
114 value_to_arrow_scalar!(scalar.value(), StringViewArray)
115}
116
117fn binary_to_arrow(scalar: BinaryScalar<'_>) -> Result<Arc<dyn Datum>, VortexError> {
119 value_to_arrow_scalar!(scalar.value(), BinaryViewArray)
120}
121
122fn extension_to_arrow(scalar: ExtScalar<'_>) -> Result<Arc<dyn Datum>, VortexError> {
126 let ext_dtype = scalar.ext_dtype();
127 let Some(temporal) = ext_dtype.metadata_opt::<AnyTemporal>() else {
128 vortex_bail!(
129 "Cannot convert extension scalar {} to Arrow",
130 ext_dtype.id()
131 )
132 };
133
134 let storage_scalar = scalar.to_storage_scalar();
135 let primitive = storage_scalar
136 .as_primitive_opt()
137 .ok_or_else(|| vortex_err!("Expected primitive scalar"))?;
138
139 match temporal {
140 TemporalMetadata::Timestamp(unit, tz) => {
141 let value = primitive.as_::<i64>();
142 match unit {
143 TimeUnit::Nanoseconds => {
144 timestamp_to_arrow_scalar!(value, tz.clone(), TimestampNanosecondArray)
145 }
146 TimeUnit::Microseconds => {
147 timestamp_to_arrow_scalar!(value, tz.clone(), TimestampMicrosecondArray)
148 }
149 TimeUnit::Milliseconds => {
150 timestamp_to_arrow_scalar!(value, tz.clone(), TimestampMillisecondArray)
151 }
152 TimeUnit::Seconds => {
153 timestamp_to_arrow_scalar!(value, tz.clone(), TimestampSecondArray)
154 }
155 TimeUnit::Days => {
156 vortex_bail!("Unsupported TimeUnit {unit} for {}", ext_dtype.id())
157 }
158 }
159 }
160 TemporalMetadata::Date(unit) => match unit {
161 TimeUnit::Milliseconds => {
162 value_to_arrow_scalar!(primitive.as_::<i64>(), Date64Array)
163 }
164 TimeUnit::Days => {
165 value_to_arrow_scalar!(primitive.as_::<i32>(), Date32Array)
166 }
167 TimeUnit::Nanoseconds | TimeUnit::Microseconds | TimeUnit::Seconds => {
168 vortex_bail!("Unsupported TimeUnit {unit} for {}", ext_dtype.id())
169 }
170 },
171 TemporalMetadata::Time(unit) => match unit {
172 TimeUnit::Nanoseconds => {
173 value_to_arrow_scalar!(primitive.as_::<i64>(), Time64NanosecondArray)
174 }
175 TimeUnit::Microseconds => {
176 value_to_arrow_scalar!(primitive.as_::<i64>(), Time64MicrosecondArray)
177 }
178 TimeUnit::Milliseconds => {
179 value_to_arrow_scalar!(primitive.as_::<i32>(), Time32MillisecondArray)
180 }
181 TimeUnit::Seconds => {
182 value_to_arrow_scalar!(primitive.as_::<i32>(), Time32SecondArray)
183 }
184 TimeUnit::Days => {
185 vortex_bail!("Unsupported TimeUnit {unit} for {}", ext_dtype.id())
186 }
187 },
188 }
189}
190
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow_array::Datum;
    use rstest::rstest;
    use vortex_error::VortexResult;
    use vortex_error::vortex_bail;

    use crate::dtype::DType;
    use crate::dtype::DecimalDType;
    use crate::dtype::FieldDType;
    use crate::dtype::NativeDType;
    use crate::dtype::Nullability;
    use crate::dtype::PType;
    use crate::dtype::StructFields;
    use crate::dtype::extension::ExtDType;
    use crate::dtype::extension::ExtId;
    use crate::dtype::extension::ExtVTable;
    use crate::dtype::i256;
    use crate::extension::datetime::Date;
    use crate::extension::datetime::Time;
    use crate::extension::datetime::TimeUnit;
    use crate::extension::datetime::Timestamp;
    use crate::extension::datetime::TimestampOptions;
    use crate::scalar::DecimalValue;
    use crate::scalar::Scalar;
    use crate::scalar::ScalarValue;

    /// Runs the conversion under test.
    fn to_arrow(scalar: &Scalar) -> VortexResult<Arc<dyn Datum>> {
        Arc::<dyn Datum>::try_from(scalar)
    }

    /// Asserts that the scalar converts to an Arrow datum without error.
    fn assert_converts(scalar: &Scalar) {
        assert!(to_arrow(scalar).is_ok());
    }

    /// Builds the non-nullable storage scalar for a temporal test case, narrowing the value to
    /// `i32` when the case asks for 32-bit storage.
    fn temporal_storage(ptype: PType, value: i64) -> Scalar {
        match ptype {
            PType::I64 => Scalar::primitive(value, Nullability::NonNullable),
            PType::I32 => {
                let narrowed = i32::try_from(value).expect("test value should fit in i32");
                Scalar::primitive(narrowed, Nullability::NonNullable)
            }
            _ => unreachable!(),
        }
    }

    #[test]
    fn test_null_scalar_to_arrow() {
        assert_converts(&Scalar::null(DType::Null));
    }

    #[test]
    fn test_bool_scalar_to_arrow() {
        assert_converts(&Scalar::bool(true, Nullability::NonNullable));
    }

    #[test]
    fn test_null_bool_scalar_to_arrow() {
        assert_converts(&Scalar::null(bool::dtype().as_nullable()));
    }

    #[test]
    fn test_primitive_u8_to_arrow() {
        assert_converts(&Scalar::primitive(42u8, Nullability::NonNullable));
    }

    #[test]
    fn test_primitive_u16_to_arrow() {
        assert_converts(&Scalar::primitive(1000u16, Nullability::NonNullable));
    }

    #[test]
    fn test_primitive_u32_to_arrow() {
        assert_converts(&Scalar::primitive(100000u32, Nullability::NonNullable));
    }

    #[test]
    fn test_primitive_u64_to_arrow() {
        assert_converts(&Scalar::primitive(10000000000u64, Nullability::NonNullable));
    }

    #[test]
    fn test_primitive_i8_to_arrow() {
        assert_converts(&Scalar::primitive(-42i8, Nullability::NonNullable));
    }

    #[test]
    fn test_primitive_i16_to_arrow() {
        assert_converts(&Scalar::primitive(-1000i16, Nullability::NonNullable));
    }

    #[test]
    fn test_primitive_i32_to_arrow() {
        assert_converts(&Scalar::primitive(-100000i32, Nullability::NonNullable));
    }

    #[test]
    fn test_primitive_i64_to_arrow() {
        assert_converts(&Scalar::primitive(-10000000000i64, Nullability::NonNullable));
    }

    #[test]
    fn test_primitive_f16_to_arrow() {
        use crate::dtype::half::f16;

        assert_converts(&Scalar::primitive(
            f16::from_f32(1.234),
            Nullability::NonNullable,
        ));
    }

    #[test]
    fn test_primitive_f32_to_arrow() {
        assert_converts(&Scalar::primitive(1.234f32, Nullability::NonNullable));
    }

    #[test]
    fn test_primitive_f64_to_arrow() {
        assert_converts(&Scalar::primitive(
            1.234567890123f64,
            Nullability::NonNullable,
        ));
    }

    #[test]
    fn test_null_primitive_to_arrow() {
        assert_converts(&Scalar::null(i32::dtype().as_nullable()));
    }

    #[test]
    fn test_utf8_scalar_to_arrow() {
        assert_converts(&Scalar::utf8(
            "hello world".to_string(),
            Nullability::NonNullable,
        ));
    }

    #[test]
    fn test_null_utf8_scalar_to_arrow() {
        assert_converts(&Scalar::null(String::dtype().as_nullable()));
    }

    #[test]
    fn test_binary_scalar_to_arrow() {
        let payload = vec![1u8, 2, 3, 4, 5];
        assert_converts(&Scalar::binary(payload, Nullability::NonNullable));
    }

    #[test]
    fn test_null_binary_scalar_to_arrow() {
        assert_converts(&Scalar::null(DType::Binary(Nullability::Nullable)));
    }

    #[test]
    fn test_decimal_scalars_to_arrow() {
        let decimal_dtype = DecimalDType::new(5, 2);

        // One value per storage width supported by `DecimalValue`.
        let values = [
            DecimalValue::I8(100),
            DecimalValue::I16(10000),
            DecimalValue::I32(99999),
            DecimalValue::I64(99999),
            DecimalValue::I128(99999),
            DecimalValue::I256(i256::from_i128(99999)),
        ];

        for value in values {
            let scalar = Scalar::decimal(value, decimal_dtype, Nullability::NonNullable);
            assert_converts(&scalar);
        }
    }

    #[test]
    fn test_null_decimal_to_arrow() {
        let decimal_dtype = DecimalDType::new(10, 2);
        assert_converts(&Scalar::null(DType::Decimal(
            decimal_dtype,
            Nullability::Nullable,
        )));
    }

    #[test]
    #[should_panic(expected = "struct scalar conversion")]
    fn test_struct_scalar_to_arrow_todo() {
        let fields = StructFields::from_iter([(
            "field1",
            FieldDType::from(DType::Primitive(PType::I32, Nullability::NonNullable)),
        )]);
        let struct_scalar = Scalar::struct_(
            DType::Struct(fields, Nullability::NonNullable),
            vec![Scalar::primitive(42i32, Nullability::NonNullable)],
        );

        to_arrow(&struct_scalar).unwrap();
    }

    #[test]
    #[should_panic(expected = "list scalar conversion")]
    fn test_list_scalar_to_arrow_todo() {
        let element_dtype = Arc::new(DType::Primitive(PType::I32, Nullability::NonNullable));
        let elements = vec![
            Scalar::primitive(1i32, Nullability::NonNullable),
            Scalar::primitive(2i32, Nullability::NonNullable),
        ];
        let list_scalar = Scalar::list(element_dtype, elements, Nullability::NonNullable);

        to_arrow(&list_scalar).unwrap();
    }

    #[test]
    #[should_panic(expected = "Cannot convert extension scalar")]
    fn test_non_temporal_extension_to_arrow_todo() {
        /// Minimal extension type without temporal metadata.
        #[derive(Clone, Debug, Default, PartialEq, Eq, Hash)]
        struct SomeExt;
        impl ExtVTable for SomeExt {
            type Metadata = String;
            type NativeValue<'a> = &'a str;

            fn id(&self) -> ExtId {
                ExtId::new("some_ext")
            }

            fn serialize_metadata(&self, _options: &Self::Metadata) -> VortexResult<Vec<u8>> {
                vortex_bail!("not implemented")
            }

            fn deserialize_metadata(&self, _data: &[u8]) -> VortexResult<Self::Metadata> {
                vortex_bail!("not implemented")
            }

            fn validate_dtype(_ext_dtype: &ExtDType<Self>) -> VortexResult<()> {
                Ok(())
            }

            fn unpack_native<'a>(
                _ext_dtype: &'a ExtDType<Self>,
                _storage_value: &'a ScalarValue,
            ) -> VortexResult<Self::NativeValue<'a>> {
                Ok("")
            }
        }

        let storage = Scalar::primitive(42i32, Nullability::NonNullable);
        let scalar = Scalar::extension::<SomeExt>("".into(), storage);

        to_arrow(&scalar).unwrap();
    }

    #[rstest]
    #[case(TimeUnit::Nanoseconds, PType::I64, 123456789i64)]
    #[case(TimeUnit::Microseconds, PType::I64, 123456789i64)]
    #[case(TimeUnit::Milliseconds, PType::I32, 123456i64)]
    #[case(TimeUnit::Seconds, PType::I32, 1234i64)]
    fn test_temporal_time_to_arrow(
        #[case] time_unit: TimeUnit,
        #[case] ptype: PType,
        #[case] value: i64,
    ) {
        let scalar = Scalar::extension::<Time>(time_unit, temporal_storage(ptype, value));
        assert_converts(&scalar);
    }

    #[rstest]
    #[case(TimeUnit::Milliseconds, PType::I64, 1234567890000i64)]
    #[case(TimeUnit::Days, PType::I32, 19000i64)]
    fn test_temporal_date_to_arrow(
        #[case] time_unit: TimeUnit,
        #[case] ptype: PType,
        #[case] value: i64,
    ) {
        let scalar = Scalar::extension::<Date>(time_unit, temporal_storage(ptype, value));
        assert_converts(&scalar);
    }

    #[rstest]
    #[case(TimeUnit::Nanoseconds, 1234567890000000000i64)]
    #[case(TimeUnit::Microseconds, 1234567890000000i64)]
    #[case(TimeUnit::Milliseconds, 1234567890000i64)]
    #[case(TimeUnit::Seconds, 1234567890i64)]
    fn test_temporal_timestamp_to_arrow(#[case] time_unit: TimeUnit, #[case] value: i64) {
        let options = TimestampOptions {
            unit: time_unit,
            tz: None,
        };
        let scalar = Scalar::extension::<Timestamp>(
            options,
            Scalar::primitive(value, Nullability::NonNullable),
        );

        assert_converts(&scalar);
    }

    #[rstest]
    #[case(TimeUnit::Nanoseconds, "UTC", 1234567890000000000i64)]
    #[case(TimeUnit::Microseconds, "EST", 1234567890000000i64)]
    #[case(TimeUnit::Microseconds, "Asia/Qatar", 1234567890000000i64)]
    #[case(TimeUnit::Microseconds, "Australia/Sydney", 1234567890000000i64)]
    #[case(TimeUnit::Milliseconds, "HST", 1234567890000i64)]
    #[case(TimeUnit::Seconds, "GMT", 1234567890i64)]
    fn test_temporal_timestamp_tz_to_arrow(
        #[case] time_unit: TimeUnit,
        #[case] tz: &str,
        #[case] value: i64,
    ) {
        let options = TimestampOptions {
            unit: time_unit,
            tz: Some(tz.into()),
        };
        let scalar = Scalar::extension::<Timestamp>(
            options,
            Scalar::primitive(value, Nullability::NonNullable),
        );

        assert_converts(&scalar);
    }

    #[test]
    fn test_temporal_with_null_value() {
        let null_storage = Scalar::null(DType::Primitive(PType::I32, Nullability::Nullable));
        let scalar = Scalar::extension::<Time>(TimeUnit::Milliseconds, null_storage);

        let _datum = to_arrow(&scalar).unwrap();
    }

    #[test]
    #[should_panic(expected = "DType utf8 is not a primitive type")]
    fn test_temporal_non_primitive_storage_error() {
        // Constructing the extension scalar is itself expected to panic.
        let _scalar = Scalar::extension::<Time>(
            TimeUnit::Nanoseconds,
            Scalar::utf8("not a timestamp", Nullability::NonNullable),
        );
    }
}