1use num_traits::ToBytes;
7use num_traits::ToPrimitive;
8use prost::Message;
9use vortex_buffer::BufferString;
10use vortex_buffer::ByteBuffer;
11use vortex_error::VortexExpect;
12use vortex_error::VortexResult;
13use vortex_error::vortex_bail;
14use vortex_error::vortex_ensure;
15use vortex_error::vortex_err;
16use vortex_proto::scalar as pb;
17use vortex_proto::scalar::ListValue;
18use vortex_proto::scalar::scalar_value::Kind;
19use vortex_session::VortexSession;
20
21use crate::dtype::DType;
22use crate::dtype::PType;
23use crate::dtype::half::f16;
24use crate::dtype::i256;
25use crate::scalar::DecimalValue;
26use crate::scalar::PValue;
27use crate::scalar::Scalar;
28use crate::scalar::ScalarValue;
29
30impl From<&Scalar> for pb::Scalar {
35 fn from(value: &Scalar) -> Self {
36 pb::Scalar {
37 dtype: Some(
38 (value.dtype())
39 .try_into()
40 .vortex_expect("Failed to convert DType to proto"),
41 ),
42 value: Some(Box::new(ScalarValue::to_proto(value.value()))),
43 }
44 }
45}
46
47impl ScalarValue {
48 pub fn to_proto(this: Option<&Self>) -> pb::ScalarValue {
58 match this {
59 None => pb::ScalarValue {
60 kind: Some(Kind::NullValue(0)),
61 },
62 Some(this) => pb::ScalarValue::from(this),
63 }
64 }
65
66 pub fn to_proto_bytes<B: Default + bytes::BufMut>(value: Option<&ScalarValue>) -> B {
68 let proto = Self::to_proto(value);
69 let mut buf = B::default();
70 proto
71 .encode(&mut buf)
72 .vortex_expect("Failed to encode scalar value");
73 buf
74 }
75}
76
77impl From<&ScalarValue> for pb::ScalarValue {
78 fn from(value: &ScalarValue) -> Self {
79 match value {
80 ScalarValue::Bool(v) => pb::ScalarValue {
81 kind: Some(Kind::BoolValue(*v)),
82 },
83 ScalarValue::Primitive(v) => pb::ScalarValue::from(v),
84 ScalarValue::Decimal(v) => {
85 let inner_value = match v {
86 DecimalValue::I8(v) => v.to_le_bytes().to_vec(),
87 DecimalValue::I16(v) => v.to_le_bytes().to_vec(),
88 DecimalValue::I32(v) => v.to_le_bytes().to_vec(),
89 DecimalValue::I64(v) => v.to_le_bytes().to_vec(),
90 DecimalValue::I128(v128) => v128.to_le_bytes().to_vec(),
91 DecimalValue::I256(v256) => v256.to_le_bytes().to_vec(),
92 };
93
94 pb::ScalarValue {
95 kind: Some(Kind::BytesValue(inner_value)),
96 }
97 }
98 ScalarValue::Utf8(v) => pb::ScalarValue {
99 kind: Some(Kind::StringValue(v.to_string())),
100 },
101 ScalarValue::Binary(v) => pb::ScalarValue {
102 kind: Some(Kind::BytesValue(v.to_vec())),
103 },
104 ScalarValue::List(v) => {
105 let mut values = Vec::with_capacity(v.len());
106 for elem in v.iter() {
107 values.push(ScalarValue::to_proto(elem.as_ref()));
108 }
109 pb::ScalarValue {
110 kind: Some(Kind::ListValue(ListValue { values })),
111 }
112 }
113 ScalarValue::Variant(v) => pb::ScalarValue {
114 kind: Some(Kind::VariantValue(Box::new(pb::Scalar::from(v.as_ref())))),
115 },
116 }
117 }
118}
119
120impl From<&PValue> for pb::ScalarValue {
121 fn from(value: &PValue) -> Self {
122 match value {
123 PValue::I8(v) => pb::ScalarValue {
124 kind: Some(Kind::Int64Value(*v as i64)),
125 },
126 PValue::I16(v) => pb::ScalarValue {
127 kind: Some(Kind::Int64Value(*v as i64)),
128 },
129 PValue::I32(v) => pb::ScalarValue {
130 kind: Some(Kind::Int64Value(*v as i64)),
131 },
132 PValue::I64(v) => pb::ScalarValue {
133 kind: Some(Kind::Int64Value(*v)),
134 },
135 PValue::U8(v) => pb::ScalarValue {
136 kind: Some(Kind::Uint64Value(*v as u64)),
137 },
138 PValue::U16(v) => pb::ScalarValue {
139 kind: Some(Kind::Uint64Value(*v as u64)),
140 },
141 PValue::U32(v) => pb::ScalarValue {
142 kind: Some(Kind::Uint64Value(*v as u64)),
143 },
144 PValue::U64(v) => pb::ScalarValue {
145 kind: Some(Kind::Uint64Value(*v)),
146 },
147 PValue::F16(v) => pb::ScalarValue {
148 kind: Some(Kind::F16Value(v.to_bits() as u64)),
149 },
150 PValue::F32(v) => pb::ScalarValue {
151 kind: Some(Kind::F32Value(*v)),
152 },
153 PValue::F64(v) => pb::ScalarValue {
154 kind: Some(Kind::F64Value(*v)),
155 },
156 }
157 }
158}
159
160impl Scalar {
165 pub fn from_proto_value(
174 value: &pb::ScalarValue,
175 dtype: &DType,
176 session: &VortexSession,
177 ) -> VortexResult<Self> {
178 let scalar_value = ScalarValue::from_proto(value, dtype, session)?;
179
180 Scalar::try_new(dtype.clone(), scalar_value)
181 }
182
183 pub fn from_proto(value: &pb::Scalar, session: &VortexSession) -> VortexResult<Self> {
189 let dtype = DType::from_proto(
190 value
191 .dtype
192 .as_ref()
193 .ok_or_else(|| vortex_err!(Serde: "Scalar missing dtype"))?,
194 session,
195 )?;
196
197 let pb_scalar_value: &pb::ScalarValue = value
198 .value
199 .as_ref()
200 .ok_or_else(|| vortex_err!(Serde: "Scalar missing value"))?;
201
202 let value: Option<ScalarValue> = ScalarValue::from_proto(pb_scalar_value, &dtype, session)?;
203
204 Scalar::try_new(dtype, value)
205 }
206}
207
208impl ScalarValue {
209 pub fn from_proto_bytes(
218 bytes: &[u8],
219 dtype: &DType,
220 session: &VortexSession,
221 ) -> VortexResult<Option<Self>> {
222 let proto = pb::ScalarValue::decode(bytes)?;
223 Self::from_proto(&proto, dtype, session)
224 }
225
226 pub fn from_proto(
235 value: &pb::ScalarValue,
236 dtype: &DType,
237 session: &VortexSession,
238 ) -> VortexResult<Option<Self>> {
239 let kind = value
240 .kind
241 .as_ref()
242 .ok_or_else(|| vortex_err!(Serde: "Scalar value missing kind"))?;
243
244 let dtype = match dtype {
246 DType::Extension(ext) => ext.storage_dtype(),
247 _ => dtype,
248 };
249
250 Ok(match kind {
251 Kind::NullValue(_) => None,
252 Kind::BoolValue(v) => Some(bool_from_proto(*v, dtype)?),
253 Kind::Int64Value(v) => Some(int64_from_proto(*v, dtype)?),
254 Kind::Uint64Value(v) => Some(uint64_from_proto(*v, dtype)?),
255 Kind::F16Value(v) => Some(f16_from_proto(*v, dtype)?),
256 Kind::F32Value(v) => Some(f32_from_proto(*v, dtype)?),
257 Kind::F64Value(v) => Some(f64_from_proto(*v, dtype)?),
258 Kind::StringValue(s) => Some(string_from_proto(s, dtype)?),
259 Kind::BytesValue(b) => Some(bytes_from_proto(b, dtype)?),
260 Kind::ListValue(v) => Some(list_from_proto(v, dtype, session)?),
261 Kind::VariantValue(v) => match dtype {
262 DType::Variant(_) => Some(ScalarValue::Variant(Box::new(Scalar::from_proto(
263 v, session,
264 )?))),
265 _ => vortex_bail!(Serde: "expected non-Variant scalar proto for dtype {dtype}"),
266 },
267 })
268 }
269}
270
271fn bool_from_proto(v: bool, dtype: &DType) -> VortexResult<ScalarValue> {
273 vortex_ensure!(
274 dtype.is_boolean(),
275 Serde: "expected Bool dtype for BoolValue, got {dtype}"
276 );
277
278 Ok(ScalarValue::Bool(v))
279}
280
281fn int64_from_proto(v: i64, dtype: &DType) -> VortexResult<ScalarValue> {
286 vortex_ensure!(
287 dtype.is_primitive(),
288 Serde: "expected Primitive dtype for Int64Value, got {dtype}"
289 );
290
291 let pvalue = match dtype.as_ptype() {
292 PType::I8 => v.to_i8().map(PValue::I8),
293 PType::I16 => v.to_i16().map(PValue::I16),
294 PType::I32 => v.to_i32().map(PValue::I32),
295 PType::I64 => Some(PValue::I64(v)),
296 PType::U8 => v.to_u8().map(PValue::U8),
299 PType::U16 => v.to_u16().map(PValue::U16),
300 PType::U32 => v.to_u32().map(PValue::U32),
301 PType::U64 => v.to_u64().map(PValue::U64),
302 ftype @ (PType::F16 | PType::F32 | PType::F64) => vortex_bail!(
303 Serde: "expected signed integer ptype for serialized Int64Value, got float {ftype}"
304 ),
305 }
306 .ok_or_else(|| vortex_err!(Serde: "Int64 value {v} out of range for dtype {dtype}"))?;
307
308 Ok(ScalarValue::Primitive(pvalue))
309}
310
311fn uint64_from_proto(v: u64, dtype: &DType) -> VortexResult<ScalarValue> {
317 vortex_ensure!(
318 dtype.is_primitive(),
319 Serde: "expected Primitive dtype for Uint64Value, got {dtype}"
320 );
321
322 let pvalue = match dtype.as_ptype() {
323 PType::U8 => v.to_u8().map(PValue::U8),
324 PType::U16 => v.to_u16().map(PValue::U16),
325 PType::U32 => v.to_u32().map(PValue::U32),
326 PType::U64 => Some(PValue::U64(v)),
327 PType::I8 => v.to_i8().map(PValue::I8),
330 PType::I16 => v.to_i16().map(PValue::I16),
331 PType::I32 => v.to_i32().map(PValue::I32),
332 PType::I64 => v.to_i64().map(PValue::I64),
333 PType::F16 => v.to_u16().map(f16::from_bits).map(PValue::F16),
335 ftype @ (PType::F32 | PType::F64) => vortex_bail!(
336 Serde: "expected unsigned integer ptype for serialized Uint64Value, got {ftype}"
337 ),
338 }
339 .ok_or_else(|| vortex_err!(Serde: "Uint64 value {v} out of range for dtype {dtype}"))?;
340
341 Ok(ScalarValue::Primitive(pvalue))
342}
343
344fn f16_from_proto(v: u64, dtype: &DType) -> VortexResult<ScalarValue> {
346 vortex_ensure!(
347 matches!(dtype, DType::Primitive(PType::F16, _)),
348 Serde: "expected F16 dtype for F16Value, got {dtype}"
349 );
350
351 let bits = u16::try_from(v)
352 .map_err(|_| vortex_err!(Serde: "f16 bitwise representation has more than 16 bits: {v}"))?;
353
354 Ok(ScalarValue::Primitive(PValue::F16(f16::from_bits(bits))))
355}
356
357fn f32_from_proto(v: f32, dtype: &DType) -> VortexResult<ScalarValue> {
359 vortex_ensure!(
360 matches!(dtype, DType::Primitive(PType::F32, _)),
361 Serde: "expected F32 dtype for F32Value, got {dtype}"
362 );
363
364 Ok(ScalarValue::Primitive(PValue::F32(v)))
365}
366
367fn f64_from_proto(v: f64, dtype: &DType) -> VortexResult<ScalarValue> {
369 vortex_ensure!(
370 matches!(dtype, DType::Primitive(PType::F64, _)),
371 Serde: "expected F64 dtype for F64Value, got {dtype}"
372 );
373
374 Ok(ScalarValue::Primitive(PValue::F64(v)))
375}
376
377fn string_from_proto(s: &str, dtype: &DType) -> VortexResult<ScalarValue> {
380 match dtype {
381 DType::Utf8(_) => Ok(ScalarValue::Utf8(BufferString::from(s))),
382 DType::Binary(_) => Ok(ScalarValue::Binary(ByteBuffer::copy_from(s.as_bytes()))),
383 _ => vortex_bail!(
384 Serde: "expected Utf8 or Binary dtype for StringValue, got {dtype}"
385 ),
386 }
387}
388
389fn bytes_from_proto(bytes: &[u8], dtype: &DType) -> VortexResult<ScalarValue> {
394 match dtype {
395 DType::Utf8(_) => Ok(ScalarValue::Utf8(BufferString::try_from(bytes)?)),
396 DType::Binary(_) => Ok(ScalarValue::Binary(ByteBuffer::copy_from(bytes))),
397 DType::Decimal(..) => Ok(ScalarValue::Decimal(match bytes.len() {
399 1 => DecimalValue::I8(bytes[0] as i8),
400 2 => DecimalValue::I16(i16::from_le_bytes(
401 bytes
402 .try_into()
403 .ok()
404 .vortex_expect("Buffer has invalid number of bytes"),
405 )),
406 4 => DecimalValue::I32(i32::from_le_bytes(
407 bytes
408 .try_into()
409 .ok()
410 .vortex_expect("Buffer has invalid number of bytes"),
411 )),
412 8 => DecimalValue::I64(i64::from_le_bytes(
413 bytes
414 .try_into()
415 .ok()
416 .vortex_expect("Buffer has invalid number of bytes"),
417 )),
418 16 => DecimalValue::I128(i128::from_le_bytes(
419 bytes
420 .try_into()
421 .ok()
422 .vortex_expect("Buffer has invalid number of bytes"),
423 )),
424 32 => DecimalValue::I256(i256::from_le_bytes(
425 bytes
426 .try_into()
427 .ok()
428 .vortex_expect("Buffer has invalid number of bytes"),
429 )),
430 l => vortex_bail!(Serde: "invalid decimal byte length: {l}"),
431 })),
432 _ => vortex_bail!(
433 Serde: "expected Utf8, Binary, or Decimal dtype for BytesValue, got {dtype}"
434 ),
435 }
436}
437
438fn list_from_proto(
440 v: &ListValue,
441 dtype: &DType,
442 session: &VortexSession,
443) -> VortexResult<ScalarValue> {
444 let element_dtype = dtype
445 .as_list_element_opt()
446 .ok_or_else(|| vortex_err!(Serde: "expected List dtype for ListValue, got {dtype}"))?;
447
448 let mut values = Vec::with_capacity(v.values.len());
449 for elem in v.values.iter() {
450 values.push(ScalarValue::from_proto(
451 elem,
452 element_dtype.as_ref(),
453 session,
454 )?);
455 }
456
457 Ok(ScalarValue::List(values))
458}
459
#[cfg(test)]
mod tests {
    //! Round-trip and backwards-compatibility tests for scalar protobuf (de)serialization.

    use std::sync::Arc;

    use vortex_buffer::BufferString;
    use vortex_error::vortex_panic;
    use vortex_proto::scalar as pb;
    use vortex_session::VortexSession;

    use super::*;
    use crate::dtype::DType;
    use crate::dtype::DecimalDType;
    use crate::dtype::Nullability;
    use crate::dtype::PType;
    use crate::dtype::half::f16;
    use crate::scalar::DecimalValue;
    use crate::scalar::Scalar;
    use crate::scalar::ScalarValue;

    /// Fresh empty session used for every decode in these tests.
    fn session() -> VortexSession {
        VortexSession::empty()
    }

    /// Asserts that `scalar` survives `Scalar -> pb::Scalar -> Scalar` unchanged.
    fn round_trip(scalar: Scalar) {
        assert_eq!(
            scalar,
            Scalar::from_proto(&pb::Scalar::from(&scalar), &session()).unwrap(),
        );
    }

    #[test]
    fn test_null() {
        round_trip(Scalar::null(DType::Null));
    }

    #[test]
    fn test_bool() {
        round_trip(Scalar::new(
            DType::Bool(Nullability::Nullable),
            Some(ScalarValue::Bool(true)),
        ));
    }

    #[test]
    fn test_primitive() {
        round_trip(Scalar::new(
            DType::Primitive(PType::I32, Nullability::Nullable),
            Some(ScalarValue::Primitive(42i32.into())),
        ));
    }

    #[test]
    fn test_buffer() {
        round_trip(Scalar::new(
            DType::Binary(Nullability::Nullable),
            Some(ScalarValue::Binary(vec![1, 2, 3].into())),
        ));
    }

    #[test]
    fn test_buffer_string() {
        round_trip(Scalar::new(
            DType::Utf8(Nullability::Nullable),
            Some(ScalarValue::Utf8(BufferString::from("hello".to_string()))),
        ));
    }

    #[test]
    fn test_list() {
        round_trip(Scalar::new(
            DType::List(
                Arc::new(DType::Primitive(PType::I32, Nullability::Nullable)),
                Nullability::Nullable,
            ),
            Some(ScalarValue::List(vec![
                Some(ScalarValue::Primitive(42i32.into())),
                Some(ScalarValue::Primitive(43i32.into())),
            ])),
        ));
    }

    #[test]
    fn test_f16() {
        round_trip(Scalar::primitive(
            f16::from_f32(0.42),
            Nullability::Nullable,
        ));
    }

    // i8 boundaries exercise the checked narrowing from the Int64Value payload.
    #[test]
    fn test_i8() {
        round_trip(Scalar::new(
            DType::Primitive(PType::I8, Nullability::Nullable),
            Some(ScalarValue::Primitive(i8::MIN.into())),
        ));

        round_trip(Scalar::new(
            DType::Primitive(PType::I8, Nullability::Nullable),
            Some(ScalarValue::Primitive(0i8.into())),
        ));

        round_trip(Scalar::new(
            DType::Primitive(PType::I8, Nullability::Nullable),
            Some(ScalarValue::Primitive(i8::MAX.into())),
        ));
    }

    #[test]
    fn test_decimal_i32_roundtrip() {
        round_trip(Scalar::decimal(
            DecimalValue::I32(123_456),
            DecimalDType::new(10, 2),
            Nullability::NonNullable,
        ));
    }

    #[test]
    fn test_decimal_i128_roundtrip() {
        round_trip(Scalar::decimal(
            DecimalValue::I128(99_999_999_999_999_999_999),
            DecimalDType::new(38, 6),
            Nullability::Nullable,
        ));
    }

    #[test]
    fn test_decimal_null_roundtrip() {
        round_trip(Scalar::null(DType::Decimal(
            DecimalDType::new(10, 2),
            Nullability::Nullable,
        )));
    }

    #[test]
    fn test_scalar_value_serde_roundtrip_binary() {
        round_trip(Scalar::binary(
            ByteBuffer::copy_from(b"hello"),
            Nullability::NonNullable,
        ));
    }

    #[test]
    fn test_scalar_value_serde_roundtrip_utf8() {
        round_trip(Scalar::utf8("hello", Nullability::NonNullable));
    }

    // Variants nest full scalars (with their own dtypes), including lists of variants and a
    // variant wrapping a Null-typed scalar.
    #[test]
    fn test_variant_scalar_roundtrip() {
        let nums = Scalar::list(
            Arc::new(DType::Variant(Nullability::NonNullable)),
            vec![
                Scalar::variant(Scalar::primitive(-7_i16, Nullability::NonNullable)),
                Scalar::variant(Scalar::primitive(42_u32, Nullability::NonNullable)),
                Scalar::variant(Scalar::decimal(
                    DecimalValue::I128(123_456_789),
                    DecimalDType::new(18, 0),
                    Nullability::NonNullable,
                )),
            ],
            Nullability::NonNullable,
        );

        let nested = Scalar::list(
            Arc::new(DType::Variant(Nullability::NonNullable)),
            vec![
                Scalar::variant(Scalar::from(true)),
                Scalar::variant(nums),
                Scalar::variant(Scalar::binary(
                    ByteBuffer::copy_from(b"abc"),
                    Nullability::NonNullable,
                )),
                Scalar::variant(Scalar::null(DType::Null)),
            ],
            Nullability::NonNullable,
        );

        round_trip(Scalar::variant(nested));
    }

    // A null Variant scalar and a Variant holding a null must encode differently and each
    // decode back to itself.
    #[test]
    fn test_variant_scalar_proto_preserves_scalar_null_vs_variant_null() {
        let scalar_null = Scalar::null(DType::Variant(Nullability::Nullable));
        let variant_null = Scalar::variant(Scalar::null(DType::Null));

        let scalar_null_pb = pb::Scalar::from(&scalar_null);
        let variant_null_pb = pb::Scalar::from(&variant_null);

        assert_ne!(scalar_null_pb, variant_null_pb);
        assert_eq!(
            Scalar::from_proto(&scalar_null_pb, &session()).unwrap(),
            scalar_null,
        );
        assert_eq!(
            Scalar::from_proto(&variant_null_pb, &session()).unwrap(),
            variant_null,
        );
    }

    // A Uint64Value holding f16 bits decodes as a plain u64 under a U64 dtype, and as the
    // original f16 under an F16 dtype.
    #[test]
    fn test_backcompat_f16_serialized_as_u64() {
        let f16_value = f16::from_f32(0.42);
        let f16_bits_as_u64 = f16_value.to_bits() as u64;
        let pb_scalar_value = pb::ScalarValue {
            kind: Some(Kind::Uint64Value(f16_bits_as_u64)),
        };

        let scalar_value = ScalarValue::from_proto(
            &pb_scalar_value,
            &DType::Primitive(PType::U64, Nullability::NonNullable),
            &session(),
        )
        .unwrap();
        // 14008 is the raw bit pattern of f16::from_f32(0.42).
        assert_eq!(
            scalar_value.as_ref().map(|v| v.as_primitive()),
            Some(&PValue::U64(14008u64)),
        );

        let scalar_value_f16 = ScalarValue::from_proto(
            &pb_scalar_value,
            &DType::Primitive(PType::F16, Nullability::Nullable),
            &session(),
        )
        .unwrap();

        let scalar = Scalar::new(
            DType::Primitive(PType::F16, Nullability::Nullable),
            scalar_value_f16,
        );

        assert_eq!(
            scalar.as_primitive().pvalue().unwrap(),
            PValue::F16(f16::from_f32(0.42)),
            "Uint64Value should be correctly interpreted as f16 when dtype is F16"
        );
    }

    // Encodes f16 values (including infinities and NaN) with ScalarValue::to_proto and decodes
    // them against an F16 dtype.
    #[test]
    fn test_scalar_value_direct_roundtrip_f16() {
        let f16_values = vec![
            f16::from_f32(0.0),
            f16::from_f32(1.0),
            f16::from_f32(-1.0),
            f16::from_f32(0.42),
            f16::from_f32(5.722046e-6),
            f16::from_f32(std::f32::consts::PI),
            f16::INFINITY,
            f16::NEG_INFINITY,
            f16::NAN,
        ];

        for f16_val in f16_values {
            let scalar_value = ScalarValue::Primitive(PValue::F16(f16_val));
            let pb_value = ScalarValue::to_proto(Some(&scalar_value));
            let read_back = ScalarValue::from_proto(
                &pb_value,
                &DType::Primitive(PType::F16, Nullability::NonNullable),
                &session(),
            )
            .unwrap();

            match (&scalar_value, read_back.as_ref()) {
                (
                    ScalarValue::Primitive(PValue::F16(original)),
                    Some(ScalarValue::Primitive(PValue::F16(roundtripped))),
                ) => {
                    // NaN never compares equal to itself, so only require NaN-ness to survive.
                    if original.is_nan() && roundtripped.is_nan() {
                        continue;
                    }
                    assert_eq!(
                        original, roundtripped,
                        "F16 value {original:?} did not roundtrip correctly"
                    );
                }
                _ => {
                    vortex_panic!(
                        "Expected f16 primitive values, got {scalar_value:?} and {read_back:?}"
                    )
                }
            }
        }
    }

    // Direct ScalarValue round trips: exact cases are compared via their Debug output; the
    // narrow unsigned/signed cases compare the numeric value after widening to 64 bits.
    #[test]
    fn test_scalar_value_direct_roundtrip_preserves_values() {
        let exact_roundtrip_cases: Vec<(&str, Option<ScalarValue>, DType)> = vec![
            ("null", None, DType::Null),
            (
                "bool_true",
                Some(ScalarValue::Bool(true)),
                DType::Bool(Nullability::Nullable),
            ),
            (
                "bool_false",
                Some(ScalarValue::Bool(false)),
                DType::Bool(Nullability::Nullable),
            ),
            (
                "u64",
                Some(ScalarValue::Primitive(PValue::U64(18446744073709551615))),
                DType::Primitive(PType::U64, Nullability::Nullable),
            ),
            (
                "i64",
                Some(ScalarValue::Primitive(PValue::I64(-9223372036854775808))),
                DType::Primitive(PType::I64, Nullability::Nullable),
            ),
            (
                "f32",
                Some(ScalarValue::Primitive(PValue::F32(std::f32::consts::E))),
                DType::Primitive(PType::F32, Nullability::Nullable),
            ),
            (
                "f64",
                Some(ScalarValue::Primitive(PValue::F64(std::f64::consts::PI))),
                DType::Primitive(PType::F64, Nullability::Nullable),
            ),
            (
                "string",
                Some(ScalarValue::Utf8(BufferString::from("test"))),
                DType::Utf8(Nullability::Nullable),
            ),
            (
                "bytes",
                Some(ScalarValue::Binary(vec![1, 2, 3, 4, 5].into())),
                DType::Binary(Nullability::Nullable),
            ),
        ];

        for (name, value, dtype) in exact_roundtrip_cases {
            let pb_value = ScalarValue::to_proto(value.as_ref());
            let read_back = ScalarValue::from_proto(&pb_value, &dtype, &session()).unwrap();

            let original_debug = format!("{value:?}");
            let roundtrip_debug = format!("{read_back:?}");
            assert_eq!(
                original_debug, roundtrip_debug,
                "ScalarValue {name} did not roundtrip exactly"
            );
        }

        // Maximum values of the narrow unsigned types.
        let unsigned_cases = vec![
            (
                "u8",
                ScalarValue::Primitive(PValue::U8(255)),
                DType::Primitive(PType::U8, Nullability::Nullable),
                255u64,
            ),
            (
                "u16",
                ScalarValue::Primitive(PValue::U16(65535)),
                DType::Primitive(PType::U16, Nullability::Nullable),
                65535u64,
            ),
            (
                "u32",
                ScalarValue::Primitive(PValue::U32(4294967295)),
                DType::Primitive(PType::U32, Nullability::Nullable),
                4294967295u64,
            ),
        ];

        for (name, value, dtype, expected) in unsigned_cases {
            let pb_value = ScalarValue::to_proto(Some(&value));
            let read_back = ScalarValue::from_proto(&pb_value, &dtype, &session()).unwrap();

            match read_back.as_ref() {
                Some(ScalarValue::Primitive(pv)) => {
                    let v = match pv {
                        PValue::U8(v) => *v as u64,
                        PValue::U16(v) => *v as u64,
                        PValue::U32(v) => *v as u64,
                        PValue::U64(v) => *v,
                        _ => vortex_panic!("Unexpected primitive type for {name}: {pv:?}"),
                    };
                    assert_eq!(
                        v, expected,
                        "ScalarValue {name} value not preserved: expected {expected}, got {v}"
                    );
                }
                _ => vortex_panic!("Unexpected type after roundtrip for {name}: {read_back:?}"),
            }
        }

        // Minimum values of the narrow signed types.
        let signed_cases = vec![
            (
                "i8",
                ScalarValue::Primitive(PValue::I8(-128)),
                DType::Primitive(PType::I8, Nullability::Nullable),
                -128i64,
            ),
            (
                "i16",
                ScalarValue::Primitive(PValue::I16(-32768)),
                DType::Primitive(PType::I16, Nullability::Nullable),
                -32768i64,
            ),
            (
                "i32",
                ScalarValue::Primitive(PValue::I32(-2147483648)),
                DType::Primitive(PType::I32, Nullability::Nullable),
                -2147483648i64,
            ),
        ];

        for (name, value, dtype, expected) in signed_cases {
            let pb_value = ScalarValue::to_proto(Some(&value));
            let read_back = ScalarValue::from_proto(&pb_value, &dtype, &session()).unwrap();

            match read_back.as_ref() {
                Some(ScalarValue::Primitive(pv)) => {
                    let v = match pv {
                        PValue::I8(v) => *v as i64,
                        PValue::I16(v) => *v as i64,
                        PValue::I32(v) => *v as i64,
                        PValue::I64(v) => *v,
                        _ => vortex_panic!("Unexpected primitive type for {name}: {pv:?}"),
                    };
                    assert_eq!(
                        v, expected,
                        "ScalarValue {name} value not preserved: expected {expected}, got {v}"
                    );
                }
                _ => vortex_panic!("Unexpected type after roundtrip for {name}: {read_back:?}"),
            }
        }
    }

    // An Int64Value payload must still decode when the target dtype is unsigned.
    #[test]
    fn test_backcompat_signed_integer_deserialized_as_unsigned() {
        let v = ScalarValue::Primitive(PValue::I64(0));
        assert_eq!(
            Scalar::from_proto_value(
                &pb::ScalarValue::from(&v),
                &DType::Primitive(PType::U64, Nullability::Nullable),
                &session()
            )
            .unwrap(),
            Scalar::primitive(0u64, Nullability::Nullable)
        );
    }

    // A Uint64Value payload must still decode when the target dtype is signed.
    #[test]
    fn test_backcompat_unsigned_integer_deserialized_as_signed() {
        let v = ScalarValue::Primitive(PValue::U64(0));
        assert_eq!(
            Scalar::from_proto_value(
                &pb::ScalarValue::from(&v),
                &DType::Primitive(PType::I64, Nullability::Nullable),
                &session()
            )
            .unwrap(),
            Scalar::primitive(0i64, Nullability::Nullable)
        );
    }
}