Skip to main content

vortex_array/scalar/
proto.rs

// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright the Vortex contributors

//! Protobuf serialization and deserialization for scalars.

6use num_traits::ToBytes;
7use num_traits::ToPrimitive;
8use prost::Message;
9use vortex_buffer::BufferString;
10use vortex_buffer::ByteBuffer;
11use vortex_error::VortexExpect;
12use vortex_error::VortexResult;
13use vortex_error::vortex_bail;
14use vortex_error::vortex_ensure;
15use vortex_error::vortex_err;
16use vortex_proto::scalar as pb;
17use vortex_proto::scalar::ListValue;
18use vortex_proto::scalar::scalar_value::Kind;
19use vortex_session::VortexSession;
20
21use crate::dtype::DType;
22use crate::dtype::PType;
23use crate::dtype::half::f16;
24use crate::dtype::i256;
25use crate::scalar::DecimalValue;
26use crate::scalar::PValue;
27use crate::scalar::Scalar;
28use crate::scalar::ScalarValue;
29
////////////////////////////////////////////////////////////////////////////////////////////////////
// Serialize INTO proto.
////////////////////////////////////////////////////////////////////////////////////////////////////
33
34impl From<&Scalar> for pb::Scalar {
35    fn from(value: &Scalar) -> Self {
36        pb::Scalar {
37            dtype: Some(
38                (value.dtype())
39                    .try_into()
40                    .vortex_expect("Failed to convert DType to proto"),
41            ),
42            value: Some(ScalarValue::to_proto(value.value())),
43        }
44    }
45}
46
47impl ScalarValue {
48    /// Ideally, we would not have this function and instead implement this `From` implementation:
49    ///
50    /// ```ignore
51    /// impl From<Option<&ScalarValue>> for pb::ScalarValue { ... }
52    /// ```
53    ///
54    /// However, we are not allowed to do this because of the Orphan rule (`Option` and
55    /// `pb::ScalarValue` are not types defined in this crate). So we must make this a method on
56    /// `vortex_array::scalar::ScalarValue` directly.
57    pub fn to_proto(this: Option<&Self>) -> pb::ScalarValue {
58        match this {
59            None => pb::ScalarValue {
60                kind: Some(Kind::NullValue(0)),
61            },
62            Some(this) => pb::ScalarValue::from(this),
63        }
64    }
65
66    /// Serialize an optional [`ScalarValue`] to protobuf bytes (handles null values).
67    pub fn to_proto_bytes<B: Default + bytes::BufMut>(value: Option<&ScalarValue>) -> B {
68        let proto = Self::to_proto(value);
69        let mut buf = B::default();
70        proto
71            .encode(&mut buf)
72            .vortex_expect("Failed to encode scalar value");
73        buf
74    }
75}
76
77impl From<&ScalarValue> for pb::ScalarValue {
78    fn from(value: &ScalarValue) -> Self {
79        match value {
80            ScalarValue::Bool(v) => pb::ScalarValue {
81                kind: Some(Kind::BoolValue(*v)),
82            },
83            ScalarValue::Primitive(v) => pb::ScalarValue::from(v),
84            ScalarValue::Decimal(v) => {
85                let inner_value = match v {
86                    DecimalValue::I8(v) => v.to_le_bytes().to_vec(),
87                    DecimalValue::I16(v) => v.to_le_bytes().to_vec(),
88                    DecimalValue::I32(v) => v.to_le_bytes().to_vec(),
89                    DecimalValue::I64(v) => v.to_le_bytes().to_vec(),
90                    DecimalValue::I128(v128) => v128.to_le_bytes().to_vec(),
91                    DecimalValue::I256(v256) => v256.to_le_bytes().to_vec(),
92                };
93
94                pb::ScalarValue {
95                    kind: Some(Kind::BytesValue(inner_value)),
96                }
97            }
98            ScalarValue::Utf8(v) => pb::ScalarValue {
99                kind: Some(Kind::StringValue(v.to_string())),
100            },
101            ScalarValue::Binary(v) => pb::ScalarValue {
102                kind: Some(Kind::BytesValue(v.to_vec())),
103            },
104            ScalarValue::List(v) => {
105                let mut values = Vec::with_capacity(v.len());
106                for elem in v.iter() {
107                    values.push(ScalarValue::to_proto(elem.as_ref()));
108                }
109                pb::ScalarValue {
110                    kind: Some(Kind::ListValue(ListValue { values })),
111                }
112            }
113        }
114    }
115}
116
117impl From<&PValue> for pb::ScalarValue {
118    fn from(value: &PValue) -> Self {
119        match value {
120            PValue::I8(v) => pb::ScalarValue {
121                kind: Some(Kind::Int64Value(*v as i64)),
122            },
123            PValue::I16(v) => pb::ScalarValue {
124                kind: Some(Kind::Int64Value(*v as i64)),
125            },
126            PValue::I32(v) => pb::ScalarValue {
127                kind: Some(Kind::Int64Value(*v as i64)),
128            },
129            PValue::I64(v) => pb::ScalarValue {
130                kind: Some(Kind::Int64Value(*v)),
131            },
132            PValue::U8(v) => pb::ScalarValue {
133                kind: Some(Kind::Uint64Value(*v as u64)),
134            },
135            PValue::U16(v) => pb::ScalarValue {
136                kind: Some(Kind::Uint64Value(*v as u64)),
137            },
138            PValue::U32(v) => pb::ScalarValue {
139                kind: Some(Kind::Uint64Value(*v as u64)),
140            },
141            PValue::U64(v) => pb::ScalarValue {
142                kind: Some(Kind::Uint64Value(*v)),
143            },
144            PValue::F16(v) => pb::ScalarValue {
145                kind: Some(Kind::F16Value(v.to_bits() as u64)),
146            },
147            PValue::F32(v) => pb::ScalarValue {
148                kind: Some(Kind::F32Value(*v)),
149            },
150            PValue::F64(v) => pb::ScalarValue {
151                kind: Some(Kind::F64Value(*v)),
152            },
153        }
154    }
155}
156
////////////////////////////////////////////////////////////////////////////////////////////////////
// Deserialize from proto.
////////////////////////////////////////////////////////////////////////////////////////////////////
160
161impl Scalar {
162    /// Creates a [`Scalar`] from a [protobuf `ScalarValue`](pb::ScalarValue) representation.
163    ///
164    /// Note that we need to provide a [`DType`] since protobuf serialization only supports 64-bit
165    /// integers, and serializing _into_ protobuf loses that type information.
166    ///
167    /// # Errors
168    ///
169    /// Returns an error if type validation fails.
170    pub fn from_proto_value(value: &pb::ScalarValue, dtype: &DType) -> VortexResult<Self> {
171        let scalar_value = ScalarValue::from_proto(value, dtype)?;
172
173        Scalar::try_new(dtype.clone(), scalar_value)
174    }
175
176    /// Creates a [`Scalar`] from its [protobuf](pb::Scalar) representation.
177    ///
178    /// # Errors
179    ///
180    /// Returns an error if the protobuf is missing required fields or if type validation fails.
181    pub fn from_proto(value: &pb::Scalar, session: &VortexSession) -> VortexResult<Self> {
182        let dtype = DType::from_proto(
183            value
184                .dtype
185                .as_ref()
186                .ok_or_else(|| vortex_err!(Serde: "Scalar missing dtype"))?,
187            session,
188        )?;
189
190        let pb_scalar_value: &pb::ScalarValue = value
191            .value
192            .as_ref()
193            .ok_or_else(|| vortex_err!(Serde: "Scalar missing value"))?;
194
195        let value: Option<ScalarValue> = ScalarValue::from_proto(pb_scalar_value, &dtype)?;
196
197        Scalar::try_new(dtype, value)
198    }
199}
200
201impl ScalarValue {
202    /// Deserialize a [`ScalarValue`] from protobuf bytes.
203    ///
204    /// Note that we need to provide a [`DType`] since protobuf serialization only supports 64-bit
205    /// integers, and serializing _into_ protobuf loses that type information.
206    ///
207    /// # Errors
208    ///
209    /// Returns an error if decoding or type validation fails.
210    pub fn from_proto_bytes(bytes: &[u8], dtype: &DType) -> VortexResult<Option<Self>> {
211        let proto = pb::ScalarValue::decode(bytes)?;
212        Self::from_proto(&proto, dtype)
213    }
214
215    /// Creates a [`ScalarValue`] from its [protobuf](pb::ScalarValue) representation.
216    ///
217    /// Note that we need to provide a [`DType`] since protobuf serialization only supports 64-bit
218    /// integers, and serializing _into_ protobuf loses that type information.
219    ///
220    /// # Errors
221    ///
222    /// Returns an error if the protobuf value cannot be converted to the given [`DType`].
223    pub fn from_proto(value: &pb::ScalarValue, dtype: &DType) -> VortexResult<Option<Self>> {
224        let kind = value
225            .kind
226            .as_ref()
227            .ok_or_else(|| vortex_err!(Serde: "Scalar value missing kind"))?;
228
229        // `DType::Extension` store their serialized values using the storage `DType`.
230        let dtype = match dtype {
231            DType::Extension(ext) => ext.storage_dtype(),
232            _ => dtype,
233        };
234
235        Ok(Some(match kind {
236            Kind::NullValue(_) => return Ok(None),
237            Kind::BoolValue(v) => bool_from_proto(*v, dtype)?,
238            Kind::Int64Value(v) => int64_from_proto(*v, dtype)?,
239            Kind::Uint64Value(v) => uint64_from_proto(*v, dtype)?,
240            Kind::F16Value(v) => f16_from_proto(*v, dtype)?,
241            Kind::F32Value(v) => f32_from_proto(*v, dtype)?,
242            Kind::F64Value(v) => f64_from_proto(*v, dtype)?,
243            Kind::StringValue(s) => string_from_proto(s, dtype)?,
244            Kind::BytesValue(b) => bytes_from_proto(b, dtype)?,
245            Kind::ListValue(v) => list_from_proto(v, dtype)?,
246        }))
247    }
248}
249
250/// Deserialize a [`ScalarValue::Bool`] from a protobuf `BoolValue`.
251fn bool_from_proto(v: bool, dtype: &DType) -> VortexResult<ScalarValue> {
252    vortex_ensure!(
253        dtype.is_boolean(),
254        Serde: "expected Bool dtype for BoolValue, got {dtype}"
255    );
256
257    Ok(ScalarValue::Bool(v))
258}
259
260/// Deserialize a [`ScalarValue::Primitive`] from a protobuf `Int64Value`.
261///
262/// Protobuf consolidates all signed integers into `i64`, so we narrow back to the original
263/// type using the provided [`DType`].
264fn int64_from_proto(v: i64, dtype: &DType) -> VortexResult<ScalarValue> {
265    vortex_ensure!(
266        dtype.is_primitive(),
267        Serde: "expected Primitive dtype for Int64Value, got {dtype}"
268    );
269
270    let pvalue = match dtype.as_ptype() {
271        PType::I8 => v.to_i8().map(PValue::I8),
272        PType::I16 => v.to_i16().map(PValue::I16),
273        PType::I32 => v.to_i32().map(PValue::I32),
274        PType::I64 => Some(PValue::I64(v)),
275        // It was previously possible for unsigned types to get their stats serialised as signed,
276        // so we allow casting back to unsigned for backwards compatibility.
277        PType::U8 => v.to_u8().map(PValue::U8),
278        PType::U16 => v.to_u16().map(PValue::U16),
279        PType::U32 => v.to_u32().map(PValue::U32),
280        PType::U64 => v.to_u64().map(PValue::U64),
281        ftype @ (PType::F16 | PType::F32 | PType::F64) => vortex_bail!(
282            Serde: "expected signed integer ptype for serialized Int64Value, got float {ftype}"
283        ),
284    }
285    .ok_or_else(|| vortex_err!(Serde: "Int64 value {v} out of range for dtype {dtype}"))?;
286
287    Ok(ScalarValue::Primitive(pvalue))
288}
289
290/// Deserialize a [`ScalarValue::Primitive`] from a protobuf `Uint64Value`.
291///
292/// Protobuf consolidates all unsigned integers into `u64`, so we narrow back to the original
293/// type using the provided [`DType`]. Also handles the backwards-compatible case where `f16`
294/// values were serialized as `u64` (via `f16::to_bits() as u64`).
295fn uint64_from_proto(v: u64, dtype: &DType) -> VortexResult<ScalarValue> {
296    vortex_ensure!(
297        dtype.is_primitive(),
298        Serde: "expected Primitive dtype for Uint64Value, got {dtype}"
299    );
300
301    let pvalue = match dtype.as_ptype() {
302        PType::U8 => v.to_u8().map(PValue::U8),
303        PType::U16 => v.to_u16().map(PValue::U16),
304        PType::U32 => v.to_u32().map(PValue::U32),
305        PType::U64 => Some(PValue::U64(v)),
306        // It was previously possible for signed types to get their stats serialised as unsigned,
307        // so we allow casting back to signed for backwards compatibility.
308        PType::I8 => v.to_i8().map(PValue::I8),
309        PType::I16 => v.to_i16().map(PValue::I16),
310        PType::I32 => v.to_i32().map(PValue::I32),
311        PType::I64 => v.to_i64().map(PValue::I64),
312        // f16 values used to be serialized as u64, so we need to be able to read an f16 from a u64.
313        PType::F16 => v.to_u16().map(f16::from_bits).map(PValue::F16),
314        ftype @ (PType::F32 | PType::F64) => vortex_bail!(
315            Serde: "expected unsigned integer ptype for serialized Uint64Value, got {ftype}"
316        ),
317    }
318    .ok_or_else(|| vortex_err!(Serde: "Uint64 value {v} out of range for dtype {dtype}"))?;
319
320    Ok(ScalarValue::Primitive(pvalue))
321}
322
323/// Deserialize a [`ScalarValue::Primitive`] from a protobuf `F16Value`.
324fn f16_from_proto(v: u64, dtype: &DType) -> VortexResult<ScalarValue> {
325    vortex_ensure!(
326        matches!(dtype, DType::Primitive(PType::F16, _)),
327        Serde: "expected F16 dtype for F16Value, got {dtype}"
328    );
329
330    let bits = u16::try_from(v)
331        .map_err(|_| vortex_err!(Serde: "f16 bitwise representation has more than 16 bits: {v}"))?;
332
333    Ok(ScalarValue::Primitive(PValue::F16(f16::from_bits(bits))))
334}
335
336/// Deserialize a [`ScalarValue::Primitive`] from a protobuf `F32Value`.
337fn f32_from_proto(v: f32, dtype: &DType) -> VortexResult<ScalarValue> {
338    vortex_ensure!(
339        matches!(dtype, DType::Primitive(PType::F32, _)),
340        Serde: "expected F32 dtype for F32Value, got {dtype}"
341    );
342
343    Ok(ScalarValue::Primitive(PValue::F32(v)))
344}
345
346/// Deserialize a [`ScalarValue::Primitive`] from a protobuf `F64Value`.
347fn f64_from_proto(v: f64, dtype: &DType) -> VortexResult<ScalarValue> {
348    vortex_ensure!(
349        matches!(dtype, DType::Primitive(PType::F64, _)),
350        Serde: "expected F64 dtype for F64Value, got {dtype}"
351    );
352
353    Ok(ScalarValue::Primitive(PValue::F64(v)))
354}
355
356/// Deserialize a [`ScalarValue::Utf8`] or [`ScalarValue::Binary`] from a protobuf
357/// `StringValue`.
358fn string_from_proto(s: &str, dtype: &DType) -> VortexResult<ScalarValue> {
359    match dtype {
360        DType::Utf8(_) => Ok(ScalarValue::Utf8(BufferString::from(s))),
361        DType::Binary(_) => Ok(ScalarValue::Binary(ByteBuffer::copy_from(s.as_bytes()))),
362        _ => vortex_bail!(
363            Serde: "expected Utf8 or Binary dtype for StringValue, got {dtype}"
364        ),
365    }
366}
367
368/// Deserialize a [`ScalarValue`] from a protobuf bytes and a `DType`.
369///
370/// Handles [`Utf8`](ScalarValue::Utf8), [`Binary`](ScalarValue::Binary), and
371/// [`Decimal`](ScalarValue::Decimal) dtypes.
372fn bytes_from_proto(bytes: &[u8], dtype: &DType) -> VortexResult<ScalarValue> {
373    match dtype {
374        DType::Utf8(_) => Ok(ScalarValue::Utf8(BufferString::try_from(bytes)?)),
375        DType::Binary(_) => Ok(ScalarValue::Binary(ByteBuffer::copy_from(bytes))),
376        // TODO(connor): This is incorrect, we need to verify this matches the `dtype`.
377        DType::Decimal(..) => Ok(ScalarValue::Decimal(match bytes.len() {
378            1 => DecimalValue::I8(bytes[0] as i8),
379            2 => DecimalValue::I16(i16::from_le_bytes(
380                bytes
381                    .try_into()
382                    .ok()
383                    .vortex_expect("Buffer has invalid number of bytes"),
384            )),
385            4 => DecimalValue::I32(i32::from_le_bytes(
386                bytes
387                    .try_into()
388                    .ok()
389                    .vortex_expect("Buffer has invalid number of bytes"),
390            )),
391            8 => DecimalValue::I64(i64::from_le_bytes(
392                bytes
393                    .try_into()
394                    .ok()
395                    .vortex_expect("Buffer has invalid number of bytes"),
396            )),
397            16 => DecimalValue::I128(i128::from_le_bytes(
398                bytes
399                    .try_into()
400                    .ok()
401                    .vortex_expect("Buffer has invalid number of bytes"),
402            )),
403            32 => DecimalValue::I256(i256::from_le_bytes(
404                bytes
405                    .try_into()
406                    .ok()
407                    .vortex_expect("Buffer has invalid number of bytes"),
408            )),
409            l => vortex_bail!(Serde: "invalid decimal byte length: {l}"),
410        })),
411        _ => vortex_bail!(
412            Serde: "expected Utf8, Binary, or Decimal dtype for BytesValue, got {dtype}"
413        ),
414    }
415}
416
417/// Deserialize a [`ScalarValue::List`] from a protobuf `ListValue`.
418fn list_from_proto(v: &ListValue, dtype: &DType) -> VortexResult<ScalarValue> {
419    let element_dtype = dtype
420        .as_list_element_opt()
421        .ok_or_else(|| vortex_err!(Serde: "expected List dtype for ListValue, got {dtype}"))?;
422
423    let mut values = Vec::with_capacity(v.values.len());
424    for elem in v.values.iter() {
425        values.push(ScalarValue::from_proto(elem, element_dtype.as_ref())?);
426    }
427
428    Ok(ScalarValue::List(values))
429}
430
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use vortex_buffer::BufferString;
    use vortex_error::vortex_panic;
    use vortex_proto::scalar as pb;
    use vortex_session::VortexSession;

    use super::*;
    use crate::dtype::DType;
    use crate::dtype::DecimalDType;
    use crate::dtype::Nullability;
    use crate::dtype::PType;
    use crate::dtype::half::f16;
    use crate::scalar::DecimalValue;
    use crate::scalar::Scalar;
    use crate::scalar::ScalarValue;

    fn session() -> VortexSession {
        VortexSession::empty()
    }

    /// Serializes `scalar` to a [`pb::Scalar`] and back, asserting the result equals the input.
    fn round_trip(scalar: Scalar) {
        assert_eq!(
            scalar,
            Scalar::from_proto(&pb::Scalar::from(&scalar), &session()).unwrap(),
        );
    }

    #[test]
    fn test_null() {
        round_trip(Scalar::null(DType::Null));
    }

    #[test]
    fn test_bool() {
        round_trip(Scalar::new(
            DType::Bool(Nullability::Nullable),
            Some(ScalarValue::Bool(true)),
        ));
    }

    #[test]
    fn test_primitive() {
        round_trip(Scalar::new(
            DType::Primitive(PType::I32, Nullability::Nullable),
            Some(ScalarValue::Primitive(42i32.into())),
        ));
    }

    #[test]
    fn test_buffer() {
        round_trip(Scalar::new(
            DType::Binary(Nullability::Nullable),
            Some(ScalarValue::Binary(vec![1, 2, 3].into())),
        ));
    }

    #[test]
    fn test_buffer_string() {
        round_trip(Scalar::new(
            DType::Utf8(Nullability::Nullable),
            Some(ScalarValue::Utf8(BufferString::from("hello".to_string()))),
        ));
    }

    #[test]
    fn test_list() {
        round_trip(Scalar::new(
            DType::List(
                Arc::new(DType::Primitive(PType::I32, Nullability::Nullable)),
                Nullability::Nullable,
            ),
            Some(ScalarValue::List(vec![
                Some(ScalarValue::Primitive(42i32.into())),
                Some(ScalarValue::Primitive(43i32.into())),
            ])),
        ));
    }

    #[test]
    fn test_f16() {
        round_trip(Scalar::primitive(
            f16::from_f32(0.42),
            Nullability::Nullable,
        ));
    }

    #[test]
    fn test_i8() {
        round_trip(Scalar::new(
            DType::Primitive(PType::I8, Nullability::Nullable),
            Some(ScalarValue::Primitive(i8::MIN.into())),
        ));

        round_trip(Scalar::new(
            DType::Primitive(PType::I8, Nullability::Nullable),
            Some(ScalarValue::Primitive(0i8.into())),
        ));

        round_trip(Scalar::new(
            DType::Primitive(PType::I8, Nullability::Nullable),
            Some(ScalarValue::Primitive(i8::MAX.into())),
        ));
    }

    #[test]
    fn test_decimal_i32_roundtrip() {
        // A typical decimal with moderate precision and scale.
        round_trip(Scalar::decimal(
            DecimalValue::I32(123_456),
            DecimalDType::new(10, 2),
            Nullability::NonNullable,
        ));
    }

    #[test]
    fn test_decimal_i128_roundtrip() {
        // A large decimal value that requires i128 storage.
        round_trip(Scalar::decimal(
            DecimalValue::I128(99_999_999_999_999_999_999),
            DecimalDType::new(38, 6),
            Nullability::Nullable,
        ));
    }

    #[test]
    fn test_decimal_null_roundtrip() {
        round_trip(Scalar::null(DType::Decimal(
            DecimalDType::new(10, 2),
            Nullability::Nullable,
        )));
    }

    #[test]
    fn test_scalar_value_serde_roundtrip_binary() {
        round_trip(Scalar::binary(
            ByteBuffer::copy_from(b"hello"),
            Nullability::NonNullable,
        ));
    }

    #[test]
    fn test_scalar_value_serde_roundtrip_utf8() {
        round_trip(Scalar::utf8("hello", Nullability::NonNullable));
    }

    #[test]
    fn test_backcompat_f16_serialized_as_u64() {
        // Backwards compatibility test for the legacy f16 serialization format.
        //
        // Previously, f16 ScalarValues were serialized as `Uint64Value(v.to_bits() as u64)` because
        // the proto schema only had 64-bit integer types, and f16's underlying representation is
        // u16 which got widened to u64.
        //
        // The current implementation uses a dedicated `F16Value` proto field, but we must still be
        // able to deserialize the old format. This test verifies that:
        //
        // 1. A `Uint64Value` containing f16 bits can be read as a U64 primitive (the raw bits).
        // 2. When wrapped in a Scalar with F16 dtype, the value is correctly interpreted as f16.
        //
        // This ensures data written with the old serialization format remains readable.

        // Simulate the old serialization: f16(0.42) stored as Uint64Value with its bit pattern.
        let f16_value = f16::from_f32(0.42);
        let f16_bits_as_u64 = f16_value.to_bits() as u64; // 14008

        let pb_scalar_value = pb::ScalarValue {
            kind: Some(Kind::Uint64Value(f16_bits_as_u64)),
        };

        // Step 1: Verify the normal U64 scalar.
        let scalar_value = ScalarValue::from_proto(
            &pb_scalar_value,
            &DType::Primitive(PType::U64, Nullability::NonNullable),
        )
        .unwrap();
        assert_eq!(
            scalar_value.as_ref().map(|v| v.as_primitive()),
            Some(&PValue::U64(14008u64)),
        );

        // Step 2: Verify that when we use F16 dtype, the Uint64Value is correctly interpreted.
        let scalar_value_f16 = ScalarValue::from_proto(
            &pb_scalar_value,
            &DType::Primitive(PType::F16, Nullability::Nullable),
        )
        .unwrap();

        let scalar = Scalar::new(
            DType::Primitive(PType::F16, Nullability::Nullable),
            scalar_value_f16,
        );

        assert_eq!(
            scalar.as_primitive().pvalue().unwrap(),
            PValue::F16(f16::from_f32(0.42)),
            "Uint64Value should be correctly interpreted as f16 when dtype is F16"
        );
    }

    #[test]
    fn test_scalar_value_direct_roundtrip_f16() {
        // Test that ScalarValue with f16 roundtrips correctly without going through Scalar.
        let f16_values = vec![
            f16::from_f32(0.0),
            f16::from_f32(1.0),
            f16::from_f32(-1.0),
            f16::from_f32(0.42),
            f16::from_f32(5.722046e-6),
            f16::from_f32(std::f32::consts::PI),
            f16::INFINITY,
            f16::NEG_INFINITY,
            f16::NAN,
        ];

        for f16_val in f16_values {
            let scalar_value = ScalarValue::Primitive(PValue::F16(f16_val));
            let pb_value = ScalarValue::to_proto(Some(&scalar_value));
            let read_back = ScalarValue::from_proto(
                &pb_value,
                &DType::Primitive(PType::F16, Nullability::NonNullable),
            )
            .unwrap();

            match (&scalar_value, read_back.as_ref()) {
                (
                    ScalarValue::Primitive(PValue::F16(original)),
                    Some(ScalarValue::Primitive(PValue::F16(roundtripped))),
                ) => {
                    if original.is_nan() && roundtripped.is_nan() {
                        // NaN values are equal for our purposes.
                        continue;
                    }
                    assert_eq!(
                        original, roundtripped,
                        "F16 value {original:?} did not roundtrip correctly"
                    );
                }
                _ => {
                    vortex_panic!(
                        "Expected f16 primitive values, got {scalar_value:?} and {read_back:?}"
                    )
                }
            }
        }
    }

    #[test]
    fn test_scalar_value_direct_roundtrip_preserves_values() {
        // Proto encoding widens every integer to 64 bits (i8/i16/i32 -> `Int64Value`,
        // u8/u16/u32 -> `Uint64Value`). Decoding narrows back using the provided `DType`, so
        // every case below must roundtrip to exactly the same `ScalarValue`, including its
        // primitive width.
        let nullable = |ptype: PType| DType::Primitive(ptype, Nullability::Nullable);

        let cases: Vec<(&str, Option<ScalarValue>, DType)> = vec![
            ("null", None, DType::Null),
            (
                "bool_true",
                Some(ScalarValue::Bool(true)),
                DType::Bool(Nullability::Nullable),
            ),
            (
                "bool_false",
                Some(ScalarValue::Bool(false)),
                DType::Bool(Nullability::Nullable),
            ),
            (
                "u8",
                Some(ScalarValue::Primitive(PValue::U8(u8::MAX))),
                nullable(PType::U8),
            ),
            (
                "u16",
                Some(ScalarValue::Primitive(PValue::U16(u16::MAX))),
                nullable(PType::U16),
            ),
            (
                "u32",
                Some(ScalarValue::Primitive(PValue::U32(u32::MAX))),
                nullable(PType::U32),
            ),
            (
                "u64",
                Some(ScalarValue::Primitive(PValue::U64(u64::MAX))),
                nullable(PType::U64),
            ),
            (
                "i8",
                Some(ScalarValue::Primitive(PValue::I8(i8::MIN))),
                nullable(PType::I8),
            ),
            (
                "i16",
                Some(ScalarValue::Primitive(PValue::I16(i16::MIN))),
                nullable(PType::I16),
            ),
            (
                "i32",
                Some(ScalarValue::Primitive(PValue::I32(i32::MIN))),
                nullable(PType::I32),
            ),
            (
                "i64",
                Some(ScalarValue::Primitive(PValue::I64(i64::MIN))),
                nullable(PType::I64),
            ),
            (
                "f32",
                Some(ScalarValue::Primitive(PValue::F32(std::f32::consts::E))),
                nullable(PType::F32),
            ),
            (
                "f64",
                Some(ScalarValue::Primitive(PValue::F64(std::f64::consts::PI))),
                nullable(PType::F64),
            ),
            (
                "string",
                Some(ScalarValue::Utf8(BufferString::from("test"))),
                DType::Utf8(Nullability::Nullable),
            ),
            (
                "bytes",
                Some(ScalarValue::Binary(vec![1, 2, 3, 4, 5].into())),
                DType::Binary(Nullability::Nullable),
            ),
        ];

        for (name, value, dtype) in cases {
            let pb_value = ScalarValue::to_proto(value.as_ref());
            let read_back = ScalarValue::from_proto(&pb_value, &dtype).unwrap();

            let original_debug = format!("{value:?}");
            let roundtrip_debug = format!("{read_back:?}");
            assert_eq!(
                original_debug, roundtrip_debug,
                "ScalarValue {name} did not roundtrip exactly"
            );
        }
    }

    #[test]
    fn test_int64_out_of_range_for_narrow_dtype_is_error() {
        // 128 does not fit in an i8, so narrowing must fail rather than wrap.
        let pb_value = pb::ScalarValue {
            kind: Some(Kind::Int64Value(i64::from(i8::MAX) + 1)),
        };
        assert!(
            ScalarValue::from_proto(
                &pb_value,
                &DType::Primitive(PType::I8, Nullability::NonNullable)
            )
            .is_err()
        );
    }

    #[test]
    fn test_invalid_decimal_byte_length_is_error() {
        // 3 bytes is not a valid decimal storage width.
        let pb_value = pb::ScalarValue {
            kind: Some(Kind::BytesValue(vec![0u8; 3])),
        };
        assert!(
            ScalarValue::from_proto(
                &pb_value,
                &DType::Decimal(DecimalDType::new(10, 2), Nullability::NonNullable)
            )
            .is_err()
        );
    }

    #[test]
    fn test_mismatched_kind_and_dtype_is_error() {
        let pb_value = pb::ScalarValue {
            kind: Some(Kind::BoolValue(true)),
        };
        assert!(
            ScalarValue::from_proto(&pb_value, &DType::Utf8(Nullability::NonNullable)).is_err()
        );
    }

    #[test]
    fn test_missing_kind_is_error() {
        let pb_value = pb::ScalarValue { kind: None };
        assert!(ScalarValue::from_proto(&pb_value, &DType::Null).is_err());
    }

    // Backwards compatibility: stats for unsigned columns could previously be serialized as
    // signed `Int64Value`s, so an `Int64Value` must still decode into an unsigned dtype when the
    // value is in range.
    #[test]
    fn test_backcompat_signed_integer_deserialized_as_unsigned() {
        let v = ScalarValue::Primitive(PValue::I64(0));
        assert_eq!(
            Scalar::from_proto_value(
                &pb::ScalarValue::from(&v),
                &DType::Primitive(PType::U64, Nullability::Nullable)
            )
            .unwrap(),
            Scalar::primitive(0u64, Nullability::Nullable)
        );
    }

    // Backwards compatibility: stats for signed columns could previously be serialized as
    // unsigned `Uint64Value`s, so a `Uint64Value` must still decode into a signed dtype when the
    // value is in range.
    #[test]
    fn test_backcompat_unsigned_integer_deserialized_as_signed() {
        let v = ScalarValue::Primitive(PValue::U64(0));
        assert_eq!(
            Scalar::from_proto_value(
                &pb::ScalarValue::from(&v),
                &DType::Primitive(PType::I64, Nullability::Nullable)
            )
            .unwrap(),
            Scalar::primitive(0i64, Nullability::Nullable)
        );
    }
}