1use std::sync::Arc;
7
8use arrow_array::Scalar as ArrowScalar;
9use arrow_array::*;
10use vortex_dtype::DType;
11use vortex_dtype::PType;
12use vortex_dtype::datetime::AnyTemporal;
13use vortex_dtype::datetime::TemporalMetadata;
14use vortex_dtype::datetime::TimeUnit;
15use vortex_error::VortexError;
16use vortex_error::vortex_bail;
17use vortex_error::vortex_err;
18
19use crate::scalar::BinaryScalar;
20use crate::scalar::BoolScalar;
21use crate::scalar::DecimalScalar;
22use crate::scalar::DecimalValue;
23use crate::scalar::ExtScalar;
24use crate::scalar::PrimitiveScalar;
25use crate::scalar::Scalar;
26use crate::scalar::Utf8Scalar;
27
/// Length of the Arrow array that backs a converted scalar.
///
/// Passed to `NullArray::new` and to every `new_null` call in this file, so that a null scalar
/// is represented by exactly one (null) element.
const SCALAR_ARRAY_LEN: usize = 1;
30
/// Wraps an optional value in an Arrow scalar backed by the array type `$AR`.
///
/// * `Some(v)` is turned into `<$AR>::new_scalar(v)`.
/// * `None` is turned into a `Scalar` around `<$AR>::new_null(SCALAR_ARRAY_LEN)`.
///
/// The expansion evaluates to `Ok(Arc::new(..))`, so it is intended to be used as the result
/// expression of a function returning `Result<Arc<dyn Datum>, _>`.
macro_rules! value_to_arrow_scalar {
    ($V:expr, $AR:ty) => {{
        let datum = match $V {
            Some(present) => <$AR>::new_scalar(present),
            None => arrow_array::Scalar::new(<$AR>::new_null(SCALAR_ARRAY_LEN)),
        };
        Ok(std::sync::Arc::new(datum))
    }};
}
40
/// Builds an Arrow timestamp scalar of array type `$AR` from an optional value `$V`, attaching
/// the optional timezone `$TZ`.
///
/// The one-element array is created first (either from the value or as a null array) and the
/// timezone is applied afterwards via `with_timezone_opt`, so both the valued and the null case
/// receive the timezone. The expansion evaluates to `Ok(Arc::new(..))`.
macro_rules! timestamp_to_arrow_scalar {
    ($V:expr, $TZ:expr, $AR:ty) => {{
        let naive = if let Some(ts) = $V {
            <$AR>::new_scalar(ts).into_inner()
        } else {
            <$AR>::new_null(SCALAR_ARRAY_LEN)
        };
        let zoned = naive.with_timezone_opt($TZ);
        Ok(Arc::new(ArrowScalar::new(zoned)))
    }};
}
52
53impl TryFrom<&Scalar> for Arc<dyn Datum> {
54 type Error = VortexError;
55
56 fn try_from(value: &Scalar) -> Result<Arc<dyn Datum>, Self::Error> {
57 match value.dtype() {
58 DType::Null => Ok(Arc::new(NullArray::new(SCALAR_ARRAY_LEN))),
59 DType::Bool(_) => bool_to_arrow(value.as_bool()),
60 DType::Primitive(..) => primitive_to_arrow(value.as_primitive()),
61 DType::Decimal(..) => decimal_to_arrow(value.as_decimal()),
62 DType::Utf8(_) => utf8_to_arrow(value.as_utf8()),
63 DType::Binary(_) => binary_to_arrow(value.as_binary()),
64 DType::Struct(..) => unimplemented!("struct scalar conversion"),
65 DType::List(..) => unimplemented!("list scalar conversion"),
66 DType::FixedSizeList(..) => unimplemented!("fixed-size list scalar conversion"),
67 DType::Extension(..) => extension_to_arrow(value.as_extension()),
68 }
69 }
70}
71
72fn bool_to_arrow(scalar: BoolScalar<'_>) -> Result<Arc<dyn Datum>, VortexError> {
74 value_to_arrow_scalar!(scalar.value(), BooleanArray)
75}
76
77fn primitive_to_arrow(scalar: PrimitiveScalar<'_>) -> Result<Arc<dyn Datum>, VortexError> {
79 match scalar.ptype() {
80 PType::U8 => value_to_arrow_scalar!(scalar.typed_value(), UInt8Array),
81 PType::U16 => value_to_arrow_scalar!(scalar.typed_value(), UInt16Array),
82 PType::U32 => value_to_arrow_scalar!(scalar.typed_value(), UInt32Array),
83 PType::U64 => value_to_arrow_scalar!(scalar.typed_value(), UInt64Array),
84 PType::I8 => value_to_arrow_scalar!(scalar.typed_value(), Int8Array),
85 PType::I16 => value_to_arrow_scalar!(scalar.typed_value(), Int16Array),
86 PType::I32 => value_to_arrow_scalar!(scalar.typed_value(), Int32Array),
87 PType::I64 => value_to_arrow_scalar!(scalar.typed_value(), Int64Array),
88 PType::F16 => value_to_arrow_scalar!(scalar.typed_value(), Float16Array),
89 PType::F32 => value_to_arrow_scalar!(scalar.typed_value(), Float32Array),
90 PType::F64 => value_to_arrow_scalar!(scalar.typed_value(), Float64Array),
91 }
92}
93
94fn decimal_to_arrow(scalar: DecimalScalar<'_>) -> Result<Arc<dyn Datum>, VortexError> {
96 match scalar.decimal_value() {
98 Some(DecimalValue::I8(v)) => Ok(Arc::new(Decimal128Array::new_scalar(v as i128))),
99 Some(DecimalValue::I16(v)) => Ok(Arc::new(Decimal128Array::new_scalar(v as i128))),
100 Some(DecimalValue::I32(v)) => Ok(Arc::new(Decimal128Array::new_scalar(v as i128))),
101 Some(DecimalValue::I64(v)) => Ok(Arc::new(Decimal128Array::new_scalar(v as i128))),
102 Some(DecimalValue::I128(v128)) => Ok(Arc::new(Decimal128Array::new_scalar(v128))),
103 Some(DecimalValue::I256(v256)) => Ok(Arc::new(Decimal256Array::new_scalar(v256.into()))),
104 None => Ok(Arc::new(arrow_array::Scalar::new(
105 Decimal128Array::new_null(SCALAR_ARRAY_LEN),
106 ))),
107 }
108}
109
110fn utf8_to_arrow(scalar: Utf8Scalar<'_>) -> Result<Arc<dyn Datum>, VortexError> {
112 value_to_arrow_scalar!(scalar.value(), StringViewArray)
113}
114
115fn binary_to_arrow(scalar: BinaryScalar<'_>) -> Result<Arc<dyn Datum>, VortexError> {
117 value_to_arrow_scalar!(scalar.value(), BinaryViewArray)
118}
119
120fn extension_to_arrow(scalar: ExtScalar<'_>) -> Result<Arc<dyn Datum>, VortexError> {
124 let ext_dtype = scalar.ext_dtype();
125 let Some(temporal) = ext_dtype.metadata_opt::<AnyTemporal>() else {
126 vortex_bail!(
127 "Cannot convert extension scalar {} to Arrow",
128 ext_dtype.id()
129 )
130 };
131
132 let storage_scalar = scalar.to_storage_scalar();
133 let primitive = storage_scalar
134 .as_primitive_opt()
135 .ok_or_else(|| vortex_err!("Expected primitive scalar"))?;
136
137 match temporal {
138 TemporalMetadata::Timestamp(unit, tz) => {
139 let value = primitive.as_::<i64>();
140 match unit {
141 TimeUnit::Nanoseconds => {
142 timestamp_to_arrow_scalar!(value, tz.clone(), TimestampNanosecondArray)
143 }
144 TimeUnit::Microseconds => {
145 timestamp_to_arrow_scalar!(value, tz.clone(), TimestampMicrosecondArray)
146 }
147 TimeUnit::Milliseconds => {
148 timestamp_to_arrow_scalar!(value, tz.clone(), TimestampMillisecondArray)
149 }
150 TimeUnit::Seconds => {
151 timestamp_to_arrow_scalar!(value, tz.clone(), TimestampSecondArray)
152 }
153 TimeUnit::Days => {
154 vortex_bail!("Unsupported TimeUnit {unit} for {}", ext_dtype.id())
155 }
156 }
157 }
158 TemporalMetadata::Date(unit) => match unit {
159 TimeUnit::Milliseconds => {
160 value_to_arrow_scalar!(primitive.as_::<i64>(), Date64Array)
161 }
162 TimeUnit::Days => {
163 value_to_arrow_scalar!(primitive.as_::<i32>(), Date32Array)
164 }
165 TimeUnit::Nanoseconds | TimeUnit::Microseconds | TimeUnit::Seconds => {
166 vortex_bail!("Unsupported TimeUnit {unit} for {}", ext_dtype.id())
167 }
168 },
169 TemporalMetadata::Time(unit) => match unit {
170 TimeUnit::Nanoseconds => {
171 value_to_arrow_scalar!(primitive.as_::<i64>(), Time64NanosecondArray)
172 }
173 TimeUnit::Microseconds => {
174 value_to_arrow_scalar!(primitive.as_::<i64>(), Time64MicrosecondArray)
175 }
176 TimeUnit::Milliseconds => {
177 value_to_arrow_scalar!(primitive.as_::<i32>(), Time32MillisecondArray)
178 }
179 TimeUnit::Seconds => {
180 value_to_arrow_scalar!(primitive.as_::<i32>(), Time32SecondArray)
181 }
182 TimeUnit::Days => {
183 vortex_bail!("Unsupported TimeUnit {unit} for {}", ext_dtype.id())
184 }
185 },
186 }
187}
188
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow_array::Datum;
    use rstest::rstest;
    use vortex_dtype::DType;
    use vortex_dtype::DecimalDType;
    use vortex_dtype::NativeDType;
    use vortex_dtype::Nullability;
    use vortex_dtype::PType;
    use vortex_dtype::datetime::Date;
    use vortex_dtype::datetime::Time;
    use vortex_dtype::datetime::TimeUnit;
    use vortex_dtype::datetime::Timestamp;
    use vortex_dtype::datetime::TimestampOptions;
    use vortex_dtype::extension::ExtDTypeVTable;
    use vortex_error::VortexResult;
    use vortex_error::vortex_bail;

    use crate::scalar::DecimalValue;
    use crate::scalar::Scalar;

    /// Asserts that converting `scalar` into an Arrow datum succeeds.
    #[track_caller]
    fn assert_converts(scalar: &Scalar) {
        assert!(Arc::<dyn Datum>::try_from(scalar).is_ok());
    }

    /// Builds the non-nullable storage scalar for a temporal test case, narrowing `value` to
    /// `i32` when the case asks for 32-bit storage.
    fn temporal_storage(ptype: PType, value: i64) -> Scalar {
        match ptype {
            PType::I64 => Scalar::primitive(value, Nullability::NonNullable),
            PType::I32 => {
                let narrowed = i32::try_from(value).expect("test value should fit in i32");
                Scalar::primitive(narrowed, Nullability::NonNullable)
            }
            _ => unreachable!(),
        }
    }

    // ---- null and bool -------------------------------------------------------------------

    #[test]
    fn test_null_scalar_to_arrow() {
        assert_converts(&Scalar::null(DType::Null));
    }

    #[test]
    fn test_bool_scalar_to_arrow() {
        assert_converts(&Scalar::bool(true, Nullability::NonNullable));
    }

    #[test]
    fn test_null_bool_scalar_to_arrow() {
        assert_converts(&Scalar::null(bool::dtype().as_nullable()));
    }

    // ---- primitives ----------------------------------------------------------------------

    #[test]
    fn test_primitive_u8_to_arrow() {
        assert_converts(&Scalar::primitive(42u8, Nullability::NonNullable));
    }

    #[test]
    fn test_primitive_u16_to_arrow() {
        assert_converts(&Scalar::primitive(1000u16, Nullability::NonNullable));
    }

    #[test]
    fn test_primitive_u32_to_arrow() {
        assert_converts(&Scalar::primitive(100000u32, Nullability::NonNullable));
    }

    #[test]
    fn test_primitive_u64_to_arrow() {
        assert_converts(&Scalar::primitive(10000000000u64, Nullability::NonNullable));
    }

    #[test]
    fn test_primitive_i8_to_arrow() {
        assert_converts(&Scalar::primitive(-42i8, Nullability::NonNullable));
    }

    #[test]
    fn test_primitive_i16_to_arrow() {
        assert_converts(&Scalar::primitive(-1000i16, Nullability::NonNullable));
    }

    #[test]
    fn test_primitive_i32_to_arrow() {
        assert_converts(&Scalar::primitive(-100000i32, Nullability::NonNullable));
    }

    #[test]
    fn test_primitive_i64_to_arrow() {
        assert_converts(&Scalar::primitive(-10000000000i64, Nullability::NonNullable));
    }

    #[test]
    fn test_primitive_f16_to_arrow() {
        use vortex_dtype::half::f16;

        assert_converts(&Scalar::primitive(
            f16::from_f32(1.234),
            Nullability::NonNullable,
        ));
    }

    #[test]
    fn test_primitive_f32_to_arrow() {
        assert_converts(&Scalar::primitive(1.234f32, Nullability::NonNullable));
    }

    #[test]
    fn test_primitive_f64_to_arrow() {
        assert_converts(&Scalar::primitive(
            1.234567890123f64,
            Nullability::NonNullable,
        ));
    }

    #[test]
    fn test_null_primitive_to_arrow() {
        assert_converts(&Scalar::null(i32::dtype().as_nullable()));
    }

    // ---- utf8 and binary -----------------------------------------------------------------

    #[test]
    fn test_utf8_scalar_to_arrow() {
        assert_converts(&Scalar::utf8(
            "hello world".to_string(),
            Nullability::NonNullable,
        ));
    }

    #[test]
    fn test_null_utf8_scalar_to_arrow() {
        assert_converts(&Scalar::null(String::dtype().as_nullable()));
    }

    #[test]
    fn test_binary_scalar_to_arrow() {
        let bytes = vec![1u8, 2, 3, 4, 5];
        assert_converts(&Scalar::binary(bytes, Nullability::NonNullable));
    }

    #[test]
    fn test_null_binary_scalar_to_arrow() {
        assert_converts(&Scalar::null(DType::Binary(Nullability::Nullable)));
    }

    // ---- decimals ------------------------------------------------------------------------

    #[test]
    fn test_decimal_scalars_to_arrow() {
        use vortex_dtype::i256;

        let decimal_dtype = DecimalDType::new(5, 2);

        // One case per storage width of `DecimalValue`.
        let cases = [
            DecimalValue::I8(100),
            DecimalValue::I16(10000),
            DecimalValue::I32(99999),
            DecimalValue::I64(99999),
            DecimalValue::I128(99999),
            DecimalValue::I256(i256::from_i128(99999)),
        ];

        for value in cases {
            assert_converts(&Scalar::decimal(
                value,
                decimal_dtype,
                Nullability::NonNullable,
            ));
        }
    }

    #[test]
    fn test_null_decimal_to_arrow() {
        let dtype = DType::Decimal(DecimalDType::new(10, 2), Nullability::Nullable);
        assert_converts(&Scalar::null(dtype));
    }

    // ---- unsupported types ---------------------------------------------------------------

    #[test]
    #[should_panic(expected = "struct scalar conversion")]
    fn test_struct_scalar_to_arrow_todo() {
        use vortex_dtype::FieldDType;
        use vortex_dtype::StructFields;

        let field = (
            "field1",
            FieldDType::from(DType::Primitive(PType::I32, Nullability::NonNullable)),
        );
        let struct_dtype = DType::Struct(
            StructFields::from_iter([field]),
            Nullability::NonNullable,
        );
        let struct_scalar = Scalar::struct_(
            struct_dtype,
            vec![Scalar::primitive(42i32, Nullability::NonNullable)],
        );

        Arc::<dyn Datum>::try_from(&struct_scalar).unwrap();
    }

    #[test]
    #[should_panic(expected = "list scalar conversion")]
    fn test_list_scalar_to_arrow_todo() {
        let elements = vec![
            Scalar::primitive(1i32, Nullability::NonNullable),
            Scalar::primitive(2i32, Nullability::NonNullable),
        ];
        let list_scalar = Scalar::list(
            Arc::new(DType::Primitive(PType::I32, Nullability::NonNullable)),
            elements,
            Nullability::NonNullable,
        );

        Arc::<dyn Datum>::try_from(&list_scalar).unwrap();
    }

    #[test]
    #[should_panic(expected = "Cannot convert extension scalar")]
    fn test_non_temporal_extension_to_arrow_todo() {
        use vortex_dtype::ExtID;

        /// Minimal extension type whose metadata is not temporal.
        #[derive(Clone, Debug, Default, PartialEq, Eq, Hash)]
        struct SomeExt;

        impl ExtDTypeVTable for SomeExt {
            type Metadata = String;

            fn id(&self) -> ExtID {
                ExtID::new_ref("some_ext")
            }

            fn serialize(&self, _options: &Self::Metadata) -> VortexResult<Vec<u8>> {
                vortex_bail!("not implemented")
            }

            fn deserialize(&self, _data: &[u8]) -> VortexResult<Self::Metadata> {
                vortex_bail!("not implemented")
            }

            fn validate_dtype(
                &self,
                _options: &Self::Metadata,
                _storage_dtype: &DType,
            ) -> VortexResult<()> {
                Ok(())
            }
        }

        let storage = Scalar::primitive(42i32, Nullability::NonNullable);
        let scalar = Scalar::extension::<SomeExt>("".into(), storage);

        Arc::<dyn Datum>::try_from(&scalar).unwrap();
    }

    // ---- temporal extensions -------------------------------------------------------------

    #[rstest]
    #[case(TimeUnit::Nanoseconds, PType::I64, 123456789i64)]
    #[case(TimeUnit::Microseconds, PType::I64, 123456789i64)]
    #[case(TimeUnit::Milliseconds, PType::I32, 123456i64)]
    #[case(TimeUnit::Seconds, PType::I32, 1234i64)]
    fn test_temporal_time_to_arrow(
        #[case] time_unit: TimeUnit,
        #[case] ptype: PType,
        #[case] value: i64,
    ) {
        let storage = temporal_storage(ptype, value);
        assert_converts(&Scalar::extension::<Time>(time_unit, storage));
    }

    #[rstest]
    #[case(TimeUnit::Milliseconds, PType::I64, 1234567890000i64)]
    #[case(TimeUnit::Days, PType::I32, 19000i64)]
    fn test_temporal_date_to_arrow(
        #[case] time_unit: TimeUnit,
        #[case] ptype: PType,
        #[case] value: i64,
    ) {
        let storage = temporal_storage(ptype, value);
        assert_converts(&Scalar::extension::<Date>(time_unit, storage));
    }

    #[rstest]
    #[case(TimeUnit::Nanoseconds, 1234567890000000000i64)]
    #[case(TimeUnit::Microseconds, 1234567890000000i64)]
    #[case(TimeUnit::Milliseconds, 1234567890000i64)]
    #[case(TimeUnit::Seconds, 1234567890i64)]
    fn test_temporal_timestamp_to_arrow(#[case] time_unit: TimeUnit, #[case] value: i64) {
        let options = TimestampOptions {
            unit: time_unit,
            tz: None,
        };
        let storage = Scalar::primitive(value, Nullability::NonNullable);

        assert_converts(&Scalar::extension::<Timestamp>(options, storage));
    }

    #[rstest]
    #[case(TimeUnit::Nanoseconds, "UTC", 1234567890000000000i64)]
    #[case(TimeUnit::Microseconds, "EST", 1234567890000000i64)]
    #[case(TimeUnit::Milliseconds, "ABC", 1234567890000i64)]
    #[case(TimeUnit::Seconds, "UTC", 1234567890i64)]
    fn test_temporal_timestamp_tz_to_arrow(
        #[case] time_unit: TimeUnit,
        #[case] tz: &str,
        #[case] value: i64,
    ) {
        let options = TimestampOptions {
            unit: time_unit,
            tz: Some(tz.into()),
        };
        let storage = Scalar::primitive(value, Nullability::NonNullable);

        assert_converts(&Scalar::extension::<Timestamp>(options, storage));
    }

    #[test]
    fn test_temporal_with_null_value() {
        let storage = Scalar::null(DType::Primitive(PType::I32, Nullability::Nullable));
        let scalar = Scalar::extension::<Time>(TimeUnit::Milliseconds, storage);

        // Unwrap (rather than `is_ok`) so a failure reports the conversion error itself.
        let _datum = Arc::<dyn Datum>::try_from(&scalar).unwrap();
    }

    #[test]
    #[should_panic(expected = "DType utf8 is not a primitive type")]
    fn test_temporal_non_primitive_storage_error() {
        // Constructing the extension scalar is what panics; no conversion is attempted.
        let storage = Scalar::utf8("not a timestamp", Nullability::NonNullable);
        let _scalar = Scalar::extension::<Time>(TimeUnit::Nanoseconds, storage);
    }
}