1use std::sync::Arc;
7
8use arrow_array::Scalar as ArrowScalar;
9use arrow_array::*;
10use vortex_error::VortexError;
11use vortex_error::vortex_bail;
12use vortex_error::vortex_err;
13
14use crate::dtype::DType;
15use crate::dtype::PType;
16use crate::extension::datetime::AnyTemporal;
17use crate::extension::datetime::TemporalMetadata;
18use crate::extension::datetime::TimeUnit;
19use crate::scalar::BinaryScalar;
20use crate::scalar::BoolScalar;
21use crate::scalar::DecimalScalar;
22use crate::scalar::DecimalValue;
23use crate::scalar::ExtScalar;
24use crate::scalar::PrimitiveScalar;
25use crate::scalar::Scalar;
26use crate::scalar::Utf8Scalar;
27
/// Length of the one-element Arrow arrays built in this file to stand in for a single
/// (possibly null) scalar value.
const SCALAR_ARRAY_LEN: usize = 1;
30
/// Wraps an optional value in an Arrow scalar and returns it as `Ok(Arc<..>)`.
///
/// * `$V`  - expression evaluating to an `Option` of the value to convert.
/// * `$AR` - Arrow array type used to hold the value.
///
/// `Some(v)` is passed to `<$AR>::new_scalar`; `None` becomes a scalar over an all-null `$AR`
/// of length `SCALAR_ARRAY_LEN`.
macro_rules! value_to_arrow_scalar {
    ($V:expr, $AR:ty) => {
        Ok(std::sync::Arc::new(match $V {
            Some(present) => <$AR>::new_scalar(present),
            None => arrow_array::Scalar::new(<$AR>::new_null(SCALAR_ARRAY_LEN)),
        }))
    };
}
40
/// Wraps an optional timestamp value in an Arrow scalar carrying an optional timezone.
///
/// * `$V`  - expression evaluating to an `Option` of the timestamp value.
/// * `$TZ` - expression handed to `with_timezone_opt` on the one-element array.
/// * `$AR` - Arrow timestamp array type to build.
///
/// The one-element array is built first (from the value, or as an all-null array), the
/// timezone is attached, and only then is the array wrapped in an Arrow scalar.
macro_rules! timestamp_to_arrow_scalar {
    ($V:expr, $TZ:expr, $AR:ty) => {{
        let untagged = $V.map_or_else(
            || <$AR>::new_null(SCALAR_ARRAY_LEN),
            |instant| <$AR>::new_scalar(instant).into_inner(),
        );
        Ok(Arc::new(ArrowScalar::new(untagged.with_timezone_opt($TZ))))
    }};
}
52
53impl TryFrom<&Scalar> for Arc<dyn Datum> {
54 type Error = VortexError;
55
56 fn try_from(value: &Scalar) -> Result<Arc<dyn Datum>, Self::Error> {
57 match value.dtype() {
58 DType::Null => Ok(Arc::new(NullArray::new(SCALAR_ARRAY_LEN))),
59 DType::Bool(_) => bool_to_arrow(value.as_bool()),
60 DType::Primitive(..) => primitive_to_arrow(value.as_primitive()),
61 DType::Decimal(..) => decimal_to_arrow(value.as_decimal()),
62 DType::Utf8(_) => utf8_to_arrow(value.as_utf8()),
63 DType::Binary(_) => binary_to_arrow(value.as_binary()),
64 DType::Struct(..) => unimplemented!("struct scalar conversion"),
65 DType::List(..) => unimplemented!("list scalar conversion"),
66 DType::FixedSizeList(..) => unimplemented!("fixed-size list scalar conversion"),
67 DType::Extension(..) => extension_to_arrow(value.as_extension()),
68 }
69 }
70}
71
72fn bool_to_arrow(scalar: BoolScalar<'_>) -> Result<Arc<dyn Datum>, VortexError> {
74 value_to_arrow_scalar!(scalar.value(), BooleanArray)
75}
76
77fn primitive_to_arrow(scalar: PrimitiveScalar<'_>) -> Result<Arc<dyn Datum>, VortexError> {
79 match scalar.ptype() {
80 PType::U8 => value_to_arrow_scalar!(scalar.typed_value(), UInt8Array),
81 PType::U16 => value_to_arrow_scalar!(scalar.typed_value(), UInt16Array),
82 PType::U32 => value_to_arrow_scalar!(scalar.typed_value(), UInt32Array),
83 PType::U64 => value_to_arrow_scalar!(scalar.typed_value(), UInt64Array),
84 PType::I8 => value_to_arrow_scalar!(scalar.typed_value(), Int8Array),
85 PType::I16 => value_to_arrow_scalar!(scalar.typed_value(), Int16Array),
86 PType::I32 => value_to_arrow_scalar!(scalar.typed_value(), Int32Array),
87 PType::I64 => value_to_arrow_scalar!(scalar.typed_value(), Int64Array),
88 PType::F16 => value_to_arrow_scalar!(scalar.typed_value(), Float16Array),
89 PType::F32 => value_to_arrow_scalar!(scalar.typed_value(), Float32Array),
90 PType::F64 => value_to_arrow_scalar!(scalar.typed_value(), Float64Array),
91 }
92}
93
94fn decimal_to_arrow(scalar: DecimalScalar<'_>) -> Result<Arc<dyn Datum>, VortexError> {
96 match scalar.decimal_value() {
98 Some(DecimalValue::I8(v)) => Ok(Arc::new(Decimal128Array::new_scalar(v as i128))),
99 Some(DecimalValue::I16(v)) => Ok(Arc::new(Decimal128Array::new_scalar(v as i128))),
100 Some(DecimalValue::I32(v)) => Ok(Arc::new(Decimal128Array::new_scalar(v as i128))),
101 Some(DecimalValue::I64(v)) => Ok(Arc::new(Decimal128Array::new_scalar(v as i128))),
102 Some(DecimalValue::I128(v128)) => Ok(Arc::new(Decimal128Array::new_scalar(v128))),
103 Some(DecimalValue::I256(v256)) => Ok(Arc::new(Decimal256Array::new_scalar(v256.into()))),
104 None => Ok(Arc::new(arrow_array::Scalar::new(
105 Decimal128Array::new_null(SCALAR_ARRAY_LEN),
106 ))),
107 }
108}
109
110fn utf8_to_arrow(scalar: Utf8Scalar<'_>) -> Result<Arc<dyn Datum>, VortexError> {
112 value_to_arrow_scalar!(scalar.value(), StringViewArray)
113}
114
115fn binary_to_arrow(scalar: BinaryScalar<'_>) -> Result<Arc<dyn Datum>, VortexError> {
117 value_to_arrow_scalar!(scalar.value(), BinaryViewArray)
118}
119
120fn extension_to_arrow(scalar: ExtScalar<'_>) -> Result<Arc<dyn Datum>, VortexError> {
124 let ext_dtype = scalar.ext_dtype();
125 let Some(temporal) = ext_dtype.metadata_opt::<AnyTemporal>() else {
126 vortex_bail!(
127 "Cannot convert extension scalar {} to Arrow",
128 ext_dtype.id()
129 )
130 };
131
132 let storage_scalar = scalar.to_storage_scalar();
133 let primitive = storage_scalar
134 .as_primitive_opt()
135 .ok_or_else(|| vortex_err!("Expected primitive scalar"))?;
136
137 match temporal {
138 TemporalMetadata::Timestamp(unit, tz) => {
139 let value = primitive.as_::<i64>();
140 match unit {
141 TimeUnit::Nanoseconds => {
142 timestamp_to_arrow_scalar!(value, tz.clone(), TimestampNanosecondArray)
143 }
144 TimeUnit::Microseconds => {
145 timestamp_to_arrow_scalar!(value, tz.clone(), TimestampMicrosecondArray)
146 }
147 TimeUnit::Milliseconds => {
148 timestamp_to_arrow_scalar!(value, tz.clone(), TimestampMillisecondArray)
149 }
150 TimeUnit::Seconds => {
151 timestamp_to_arrow_scalar!(value, tz.clone(), TimestampSecondArray)
152 }
153 TimeUnit::Days => {
154 vortex_bail!("Unsupported TimeUnit {unit} for {}", ext_dtype.id())
155 }
156 }
157 }
158 TemporalMetadata::Date(unit) => match unit {
159 TimeUnit::Milliseconds => {
160 value_to_arrow_scalar!(primitive.as_::<i64>(), Date64Array)
161 }
162 TimeUnit::Days => {
163 value_to_arrow_scalar!(primitive.as_::<i32>(), Date32Array)
164 }
165 TimeUnit::Nanoseconds | TimeUnit::Microseconds | TimeUnit::Seconds => {
166 vortex_bail!("Unsupported TimeUnit {unit} for {}", ext_dtype.id())
167 }
168 },
169 TemporalMetadata::Time(unit) => match unit {
170 TimeUnit::Nanoseconds => {
171 value_to_arrow_scalar!(primitive.as_::<i64>(), Time64NanosecondArray)
172 }
173 TimeUnit::Microseconds => {
174 value_to_arrow_scalar!(primitive.as_::<i64>(), Time64MicrosecondArray)
175 }
176 TimeUnit::Milliseconds => {
177 value_to_arrow_scalar!(primitive.as_::<i32>(), Time32MillisecondArray)
178 }
179 TimeUnit::Seconds => {
180 value_to_arrow_scalar!(primitive.as_::<i32>(), Time32SecondArray)
181 }
182 TimeUnit::Days => {
183 vortex_bail!("Unsupported TimeUnit {unit} for {}", ext_dtype.id())
184 }
185 },
186 }
187}
188
/// Tests for converting Vortex scalars into Arrow `Datum`s.
///
/// Most tests only check that the conversion returns `Ok`; they do not inspect the resulting
/// Arrow value.
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow_array::Datum;
    use rstest::rstest;
    use vortex_error::VortexResult;
    use vortex_error::vortex_bail;

    use crate::dtype::DType;
    use crate::dtype::DecimalDType;
    use crate::dtype::FieldDType;
    use crate::dtype::NativeDType;
    use crate::dtype::Nullability;
    use crate::dtype::PType;
    use crate::dtype::StructFields;
    use crate::dtype::extension::ExtId;
    use crate::dtype::extension::ExtVTable;
    use crate::dtype::i256;
    use crate::extension::datetime::Date;
    use crate::extension::datetime::Time;
    use crate::extension::datetime::TimeUnit;
    use crate::extension::datetime::Timestamp;
    use crate::extension::datetime::TimestampOptions;
    use crate::scalar::DecimalValue;
    use crate::scalar::Scalar;
    use crate::scalar::ScalarValue;

    /// A scalar of dtype `Null` converts successfully.
    #[test]
    fn test_null_scalar_to_arrow() {
        let scalar = Scalar::null(DType::Null);
        let result = Arc::<dyn Datum>::try_from(&scalar);
        assert!(result.is_ok());
    }

    /// A non-null boolean converts successfully.
    #[test]
    fn test_bool_scalar_to_arrow() {
        let scalar = Scalar::bool(true, Nullability::NonNullable);
        let result = Arc::<dyn Datum>::try_from(&scalar);
        assert!(result.is_ok());
    }

    /// A null scalar of nullable boolean dtype converts successfully.
    #[test]
    fn test_null_bool_scalar_to_arrow() {
        let scalar = Scalar::null(bool::dtype().as_nullable());
        let result = Arc::<dyn Datum>::try_from(&scalar);
        assert!(result.is_ok());
    }

    // One test per primitive type follows; each converts a single non-null value.

    #[test]
    fn test_primitive_u8_to_arrow() {
        let scalar = Scalar::primitive(42u8, Nullability::NonNullable);
        let result = Arc::<dyn Datum>::try_from(&scalar);
        assert!(result.is_ok());
    }

    #[test]
    fn test_primitive_u16_to_arrow() {
        let scalar = Scalar::primitive(1000u16, Nullability::NonNullable);
        let result = Arc::<dyn Datum>::try_from(&scalar);
        assert!(result.is_ok());
    }

    #[test]
    fn test_primitive_u32_to_arrow() {
        let scalar = Scalar::primitive(100000u32, Nullability::NonNullable);
        let result = Arc::<dyn Datum>::try_from(&scalar);
        assert!(result.is_ok());
    }

    #[test]
    fn test_primitive_u64_to_arrow() {
        let scalar = Scalar::primitive(10000000000u64, Nullability::NonNullable);
        let result = Arc::<dyn Datum>::try_from(&scalar);
        assert!(result.is_ok());
    }

    #[test]
    fn test_primitive_i8_to_arrow() {
        let scalar = Scalar::primitive(-42i8, Nullability::NonNullable);
        let result = Arc::<dyn Datum>::try_from(&scalar);
        assert!(result.is_ok());
    }

    #[test]
    fn test_primitive_i16_to_arrow() {
        let scalar = Scalar::primitive(-1000i16, Nullability::NonNullable);
        let result = Arc::<dyn Datum>::try_from(&scalar);
        assert!(result.is_ok());
    }

    #[test]
    fn test_primitive_i32_to_arrow() {
        let scalar = Scalar::primitive(-100000i32, Nullability::NonNullable);
        let result = Arc::<dyn Datum>::try_from(&scalar);
        assert!(result.is_ok());
    }

    #[test]
    fn test_primitive_i64_to_arrow() {
        let scalar = Scalar::primitive(-10000000000i64, Nullability::NonNullable);
        let result = Arc::<dyn Datum>::try_from(&scalar);
        assert!(result.is_ok());
    }

    #[test]
    fn test_primitive_f16_to_arrow() {
        use crate::dtype::half::f16;

        let scalar = Scalar::primitive(f16::from_f32(1.234), Nullability::NonNullable);
        let result = Arc::<dyn Datum>::try_from(&scalar);
        assert!(result.is_ok());
    }

    #[test]
    fn test_primitive_f32_to_arrow() {
        let scalar = Scalar::primitive(1.234f32, Nullability::NonNullable);
        let result = Arc::<dyn Datum>::try_from(&scalar);
        assert!(result.is_ok());
    }

    #[test]
    fn test_primitive_f64_to_arrow() {
        let scalar = Scalar::primitive(1.234567890123f64, Nullability::NonNullable);
        let result = Arc::<dyn Datum>::try_from(&scalar);
        assert!(result.is_ok());
    }

    /// A null scalar of nullable `i32` dtype converts successfully.
    #[test]
    fn test_null_primitive_to_arrow() {
        let scalar = Scalar::null(i32::dtype().as_nullable());
        let result = Arc::<dyn Datum>::try_from(&scalar);
        assert!(result.is_ok());
    }

    /// A non-null UTF-8 string converts successfully.
    #[test]
    fn test_utf8_scalar_to_arrow() {
        let scalar = Scalar::utf8("hello world".to_string(), Nullability::NonNullable);
        let result = Arc::<dyn Datum>::try_from(&scalar);
        assert!(result.is_ok());
    }

    /// A null scalar of nullable string dtype converts successfully.
    #[test]
    fn test_null_utf8_scalar_to_arrow() {
        let scalar = Scalar::null(String::dtype().as_nullable());
        let result = Arc::<dyn Datum>::try_from(&scalar);
        assert!(result.is_ok());
    }

    /// A non-null byte buffer converts successfully.
    #[test]
    fn test_binary_scalar_to_arrow() {
        let data = vec![1u8, 2, 3, 4, 5];
        let scalar = Scalar::binary(data, Nullability::NonNullable);
        let result = Arc::<dyn Datum>::try_from(&scalar);
        assert!(result.is_ok());
    }

    /// A null scalar of nullable binary dtype converts successfully.
    #[test]
    fn test_null_binary_scalar_to_arrow() {
        let scalar = Scalar::null(DType::Binary(Nullability::Nullable));
        let result = Arc::<dyn Datum>::try_from(&scalar);
        assert!(result.is_ok());
    }

    /// Every `DecimalValue` storage width converts successfully for the same decimal dtype.
    #[test]
    fn test_decimal_scalars_to_arrow() {
        let decimal_dtype = DecimalDType::new(5, 2);

        let scalar_i8 = Scalar::decimal(
            DecimalValue::I8(100),
            decimal_dtype,
            Nullability::NonNullable,
        );
        assert!(Arc::<dyn Datum>::try_from(&scalar_i8).is_ok());

        let scalar_i16 = Scalar::decimal(
            DecimalValue::I16(10000),
            decimal_dtype,
            Nullability::NonNullable,
        );
        assert!(Arc::<dyn Datum>::try_from(&scalar_i16).is_ok());

        let scalar_i32 = Scalar::decimal(
            DecimalValue::I32(99999),
            decimal_dtype,
            Nullability::NonNullable,
        );
        assert!(Arc::<dyn Datum>::try_from(&scalar_i32).is_ok());

        let scalar_i64 = Scalar::decimal(
            DecimalValue::I64(99999),
            decimal_dtype,
            Nullability::NonNullable,
        );
        assert!(Arc::<dyn Datum>::try_from(&scalar_i64).is_ok());

        let scalar_i128 = Scalar::decimal(
            DecimalValue::I128(99999),
            decimal_dtype,
            Nullability::NonNullable,
        );
        assert!(Arc::<dyn Datum>::try_from(&scalar_i128).is_ok());

        let value_i256 = i256::from_i128(99999);
        let scalar_i256 = Scalar::decimal(
            DecimalValue::I256(value_i256),
            decimal_dtype,
            Nullability::NonNullable,
        );
        assert!(Arc::<dyn Datum>::try_from(&scalar_i256).is_ok());
    }

    /// A null scalar of nullable decimal dtype converts successfully.
    #[test]
    fn test_null_decimal_to_arrow() {
        let decimal_dtype = DecimalDType::new(10, 2);
        let scalar = Scalar::null(DType::Decimal(decimal_dtype, Nullability::Nullable));
        let result = Arc::<dyn Datum>::try_from(&scalar);
        assert!(result.is_ok());
    }

    /// Struct scalars are not supported yet: the test expects a panic whose message contains
    /// "struct scalar conversion".
    #[test]
    #[should_panic(expected = "struct scalar conversion")]
    fn test_struct_scalar_to_arrow_todo() {
        let struct_dtype = DType::Struct(
            StructFields::from_iter([(
                "field1",
                FieldDType::from(DType::Primitive(PType::I32, Nullability::NonNullable)),
            )]),
            Nullability::NonNullable,
        );

        let struct_scalar = Scalar::struct_(
            struct_dtype,
            vec![Scalar::primitive(42i32, Nullability::NonNullable)],
        );
        Arc::<dyn Datum>::try_from(&struct_scalar).unwrap();
    }

    /// List scalars are not supported yet: the test expects a panic whose message contains
    /// "list scalar conversion".
    #[test]
    #[should_panic(expected = "list scalar conversion")]
    fn test_list_scalar_to_arrow_todo() {
        let element_dtype = Arc::new(DType::Primitive(PType::I32, Nullability::NonNullable));
        let list_scalar = Scalar::list(
            element_dtype,
            vec![
                Scalar::primitive(1i32, Nullability::NonNullable),
                Scalar::primitive(2i32, Nullability::NonNullable),
            ],
            Nullability::NonNullable,
        );

        Arc::<dyn Datum>::try_from(&list_scalar).unwrap();
    }

    /// An extension type without temporal metadata cannot be converted: unwrapping the result
    /// is expected to panic with "Cannot convert extension scalar".
    #[test]
    #[should_panic(expected = "Cannot convert extension scalar")]
    fn test_non_temporal_extension_to_arrow_todo() {
        // Minimal extension type used only by this test; its metadata is a plain `String`.
        #[derive(Clone, Debug, Default, PartialEq, Eq, Hash)]
        struct SomeExt;
        impl ExtVTable for SomeExt {
            type Metadata = String;
            type NativeValue<'a> = &'a str;

            fn id(&self) -> ExtId {
                ExtId::new_ref("some_ext")
            }

            fn serialize_metadata(&self, _options: &Self::Metadata) -> VortexResult<Vec<u8>> {
                vortex_bail!("not implemented")
            }

            fn deserialize_metadata(&self, _data: &[u8]) -> VortexResult<Self::Metadata> {
                vortex_bail!("not implemented")
            }

            fn validate_dtype(
                &self,
                _options: &Self::Metadata,
                _storage_dtype: &DType,
            ) -> VortexResult<()> {
                Ok(())
            }

            fn unpack_native<'a>(
                &self,
                _metadata: &'a Self::Metadata,
                _storage_dtype: &'a DType,
                _storage_value: &'a ScalarValue,
            ) -> VortexResult<Self::NativeValue<'a>> {
                Ok("")
            }
        }

        let scalar = Scalar::extension::<SomeExt>(
            "".into(),
            Scalar::primitive(42i32, Nullability::NonNullable),
        );

        Arc::<dyn Datum>::try_from(&scalar).unwrap();
    }

    /// `Time` extension scalars convert for each supported unit. The `ptype` case argument
    /// selects whether the storage scalar is built as `i32` or `i64`.
    #[rstest]
    #[case(TimeUnit::Nanoseconds, PType::I64, 123456789i64)]
    #[case(TimeUnit::Microseconds, PType::I64, 123456789i64)]
    #[case(TimeUnit::Milliseconds, PType::I32, 123456i64)]
    #[case(TimeUnit::Seconds, PType::I32, 1234i64)]
    fn test_temporal_time_to_arrow(
        #[case] time_unit: TimeUnit,
        #[case] ptype: PType,
        #[case] value: i64,
    ) {
        let scalar = Scalar::extension::<Time>(
            time_unit,
            match ptype {
                PType::I32 => {
                    let i32_value = i32::try_from(value).expect("test value should fit in i32");
                    Scalar::primitive(i32_value, Nullability::NonNullable)
                }
                PType::I64 => Scalar::primitive(value, Nullability::NonNullable),
                _ => unreachable!(),
            },
        );

        let result = Arc::<dyn Datum>::try_from(&scalar);
        assert!(result.is_ok());
    }

    /// `Date` extension scalars convert for millisecond (`i64` storage) and day (`i32`
    /// storage) units.
    #[rstest]
    #[case(TimeUnit::Milliseconds, PType::I64, 1234567890000i64)]
    #[case(TimeUnit::Days, PType::I32, 19000i64)]
    fn test_temporal_date_to_arrow(
        #[case] time_unit: TimeUnit,
        #[case] ptype: PType,
        #[case] value: i64,
    ) {
        let scalar = Scalar::extension::<Date>(
            time_unit,
            match ptype {
                PType::I32 => {
                    let i32_value = i32::try_from(value).expect("test value should fit in i32");
                    Scalar::primitive(i32_value, Nullability::NonNullable)
                }
                PType::I64 => Scalar::primitive(value, Nullability::NonNullable),
                _ => unreachable!(),
            },
        );

        let result = Arc::<dyn Datum>::try_from(&scalar);
        assert!(result.is_ok());
    }

    /// `Timestamp` extension scalars without a timezone convert for each supported unit.
    #[rstest]
    #[case(TimeUnit::Nanoseconds, 1234567890000000000i64)]
    #[case(TimeUnit::Microseconds, 1234567890000000i64)]
    #[case(TimeUnit::Milliseconds, 1234567890000i64)]
    #[case(TimeUnit::Seconds, 1234567890i64)]
    fn test_temporal_timestamp_to_arrow(#[case] time_unit: TimeUnit, #[case] value: i64) {
        let scalar = Scalar::extension::<Timestamp>(
            TimestampOptions {
                unit: time_unit,
                tz: None,
            },
            Scalar::primitive(value, Nullability::NonNullable),
        );

        let result = Arc::<dyn Datum>::try_from(&scalar);
        assert!(result.is_ok());
    }

    /// `Timestamp` extension scalars with a timezone string convert for each supported unit.
    /// The "ABC" case shows that an arbitrary timezone string is also expected to convert.
    #[rstest]
    #[case(TimeUnit::Nanoseconds, "UTC", 1234567890000000000i64)]
    #[case(TimeUnit::Microseconds, "EST", 1234567890000000i64)]
    #[case(TimeUnit::Milliseconds, "ABC", 1234567890000i64)]
    #[case(TimeUnit::Seconds, "UTC", 1234567890i64)]
    fn test_temporal_timestamp_tz_to_arrow(
        #[case] time_unit: TimeUnit,
        #[case] tz: &str,
        #[case] value: i64,
    ) {
        let scalar = Scalar::extension::<Timestamp>(
            TimestampOptions {
                unit: time_unit,
                tz: Some(tz.into()),
            },
            Scalar::primitive(value, Nullability::NonNullable),
        );

        let result = Arc::<dyn Datum>::try_from(&scalar);
        assert!(result.is_ok());
    }

    /// A temporal extension scalar whose storage value is null converts without error.
    #[test]
    fn test_temporal_with_null_value() {
        let scalar = Scalar::extension::<Time>(
            TimeUnit::Milliseconds,
            Scalar::null(DType::Primitive(PType::I32, Nullability::Nullable)),
        );

        let _result = Arc::<dyn Datum>::try_from(&scalar).unwrap();
    }

    /// Building a `Time` extension scalar over UTF-8 storage is expected to panic during
    /// construction; no Arrow conversion is attempted in this test.
    #[test]
    #[should_panic(expected = "DType utf8 is not a primitive type")]
    fn test_temporal_non_primitive_storage_error() {
        let _scalar = Scalar::extension::<Time>(
            TimeUnit::Nanoseconds,
            Scalar::utf8("not a timestamp", Nullability::NonNullable),
        );
    }
}