Skip to main content

vortex_array/scalar/
proto.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4//! Protobuf serialization and deserialization for scalars.
5
6use num_traits::ToBytes;
7use num_traits::ToPrimitive;
8use prost::Message;
9use vortex_buffer::BufferString;
10use vortex_buffer::ByteBuffer;
11use vortex_error::VortexExpect;
12use vortex_error::VortexResult;
13use vortex_error::vortex_bail;
14use vortex_error::vortex_ensure;
15use vortex_error::vortex_err;
16use vortex_proto::scalar as pb;
17use vortex_proto::scalar::ListValue;
18use vortex_proto::scalar::scalar_value::Kind;
19use vortex_session::VortexSession;
20
21use crate::dtype::DType;
22use crate::dtype::PType;
23use crate::dtype::half::f16;
24use crate::dtype::i256;
25use crate::scalar::DecimalValue;
26use crate::scalar::PValue;
27use crate::scalar::Scalar;
28use crate::scalar::ScalarValue;
29
30////////////////////////////////////////////////////////////////////////////////////////////////////
31// Serialize INTO proto.
32////////////////////////////////////////////////////////////////////////////////////////////////////
33
34impl From<&Scalar> for pb::Scalar {
35    fn from(value: &Scalar) -> Self {
36        pb::Scalar {
37            dtype: Some(
38                (value.dtype())
39                    .try_into()
40                    .vortex_expect("Failed to convert DType to proto"),
41            ),
42            value: Some(ScalarValue::to_proto(value.value())),
43        }
44    }
45}
46
47impl ScalarValue {
48    /// Ideally, we would not have this function and instead implement this `From` implementation:
49    ///
50    /// ```ignore
51    /// impl From<Option<&ScalarValue>> for pb::ScalarValue { ... }
52    /// ```
53    ///
54    /// However, we are not allowed to do this because of the Orphan rule (`Option` and
55    /// `pb::ScalarValue` are not types defined in this crate). So we must make this a method on
56    /// `vortex_array::scalar::ScalarValue` directly.
57    pub fn to_proto(this: Option<&Self>) -> pb::ScalarValue {
58        match this {
59            None => pb::ScalarValue {
60                kind: Some(Kind::NullValue(0)),
61            },
62            Some(this) => pb::ScalarValue::from(this),
63        }
64    }
65
66    /// Serialize an optional [`ScalarValue`] to protobuf bytes (handles null values).
67    pub fn to_proto_bytes<B: Default + bytes::BufMut>(value: Option<&ScalarValue>) -> B {
68        let proto = Self::to_proto(value);
69        let mut buf = B::default();
70        proto
71            .encode(&mut buf)
72            .vortex_expect("Failed to encode scalar value");
73        buf
74    }
75}
76
77impl From<&ScalarValue> for pb::ScalarValue {
78    fn from(value: &ScalarValue) -> Self {
79        match value {
80            ScalarValue::Bool(v) => pb::ScalarValue {
81                kind: Some(Kind::BoolValue(*v)),
82            },
83            ScalarValue::Primitive(v) => pb::ScalarValue::from(v),
84            ScalarValue::Decimal(v) => {
85                let inner_value = match v {
86                    DecimalValue::I8(v) => v.to_le_bytes().to_vec(),
87                    DecimalValue::I16(v) => v.to_le_bytes().to_vec(),
88                    DecimalValue::I32(v) => v.to_le_bytes().to_vec(),
89                    DecimalValue::I64(v) => v.to_le_bytes().to_vec(),
90                    DecimalValue::I128(v128) => v128.to_le_bytes().to_vec(),
91                    DecimalValue::I256(v256) => v256.to_le_bytes().to_vec(),
92                };
93
94                pb::ScalarValue {
95                    kind: Some(Kind::BytesValue(inner_value)),
96                }
97            }
98            ScalarValue::Utf8(v) => pb::ScalarValue {
99                kind: Some(Kind::StringValue(v.to_string())),
100            },
101            ScalarValue::Binary(v) => pb::ScalarValue {
102                kind: Some(Kind::BytesValue(v.to_vec())),
103            },
104            ScalarValue::List(v) => {
105                let mut values = Vec::with_capacity(v.len());
106                for elem in v.iter() {
107                    values.push(ScalarValue::to_proto(elem.as_ref()));
108                }
109                pb::ScalarValue {
110                    kind: Some(Kind::ListValue(ListValue { values })),
111                }
112            }
113        }
114    }
115}
116
117impl From<&PValue> for pb::ScalarValue {
118    fn from(value: &PValue) -> Self {
119        match value {
120            PValue::I8(v) => pb::ScalarValue {
121                kind: Some(Kind::Int64Value(*v as i64)),
122            },
123            PValue::I16(v) => pb::ScalarValue {
124                kind: Some(Kind::Int64Value(*v as i64)),
125            },
126            PValue::I32(v) => pb::ScalarValue {
127                kind: Some(Kind::Int64Value(*v as i64)),
128            },
129            PValue::I64(v) => pb::ScalarValue {
130                kind: Some(Kind::Int64Value(*v)),
131            },
132            PValue::U8(v) => pb::ScalarValue {
133                kind: Some(Kind::Uint64Value(*v as u64)),
134            },
135            PValue::U16(v) => pb::ScalarValue {
136                kind: Some(Kind::Uint64Value(*v as u64)),
137            },
138            PValue::U32(v) => pb::ScalarValue {
139                kind: Some(Kind::Uint64Value(*v as u64)),
140            },
141            PValue::U64(v) => pb::ScalarValue {
142                kind: Some(Kind::Uint64Value(*v)),
143            },
144            PValue::F16(v) => pb::ScalarValue {
145                kind: Some(Kind::F16Value(v.to_bits() as u64)),
146            },
147            PValue::F32(v) => pb::ScalarValue {
148                kind: Some(Kind::F32Value(*v)),
149            },
150            PValue::F64(v) => pb::ScalarValue {
151                kind: Some(Kind::F64Value(*v)),
152            },
153        }
154    }
155}
156
157////////////////////////////////////////////////////////////////////////////////////////////////////
158// Serialize FROM proto.
159////////////////////////////////////////////////////////////////////////////////////////////////////
160
161impl Scalar {
162    /// Creates a [`Scalar`] from a [protobuf `ScalarValue`](pb::ScalarValue) representation.
163    ///
164    /// Note that we need to provide a [`DType`] since protobuf serialization only supports 64-bit
165    /// integers, and serializing _into_ protobuf loses that type information.
166    ///
167    /// # Errors
168    ///
169    /// Returns an error if type validation fails.
170    pub fn from_proto_value(
171        value: &pb::ScalarValue,
172        dtype: &DType,
173        session: &VortexSession,
174    ) -> VortexResult<Self> {
175        let scalar_value = ScalarValue::from_proto(value, dtype, session)?;
176
177        Scalar::try_new(dtype.clone(), scalar_value)
178    }
179
180    /// Creates a [`Scalar`] from its [protobuf](pb::Scalar) representation.
181    ///
182    /// # Errors
183    ///
184    /// Returns an error if the protobuf is missing required fields or if type validation fails.
185    pub fn from_proto(value: &pb::Scalar, session: &VortexSession) -> VortexResult<Self> {
186        let dtype = DType::from_proto(
187            value
188                .dtype
189                .as_ref()
190                .ok_or_else(|| vortex_err!(Serde: "Scalar missing dtype"))?,
191            session,
192        )?;
193
194        let pb_scalar_value: &pb::ScalarValue = value
195            .value
196            .as_ref()
197            .ok_or_else(|| vortex_err!(Serde: "Scalar missing value"))?;
198
199        let value: Option<ScalarValue> = ScalarValue::from_proto(pb_scalar_value, &dtype, session)?;
200
201        Scalar::try_new(dtype, value)
202    }
203}
204
205impl ScalarValue {
206    /// Deserialize a [`ScalarValue`] from protobuf bytes.
207    ///
208    /// Note that we need to provide a [`DType`] since protobuf serialization only supports 64-bit
209    /// integers, and serializing _into_ protobuf loses that type information.
210    ///
211    /// # Errors
212    ///
213    /// Returns an error if decoding or type validation fails.
214    pub fn from_proto_bytes(
215        bytes: &[u8],
216        dtype: &DType,
217        session: &VortexSession,
218    ) -> VortexResult<Option<Self>> {
219        let proto = pb::ScalarValue::decode(bytes)?;
220        Self::from_proto(&proto, dtype, session)
221    }
222
223    /// Creates a [`ScalarValue`] from its [protobuf](pb::ScalarValue) representation.
224    ///
225    /// Note that we need to provide a [`DType`] since protobuf serialization only supports 64-bit
226    /// integers, and serializing _into_ protobuf loses that type information.
227    ///
228    /// # Errors
229    ///
230    /// Returns an error if the protobuf value cannot be converted to the given [`DType`].
231    pub fn from_proto(
232        value: &pb::ScalarValue,
233        dtype: &DType,
234        session: &VortexSession,
235    ) -> VortexResult<Option<Self>> {
236        let kind = value
237            .kind
238            .as_ref()
239            .ok_or_else(|| vortex_err!(Serde: "Scalar value missing kind"))?;
240
241        // `DType::Extension` store their serialized values using the storage `DType`.
242        let dtype = match dtype {
243            DType::Extension(ext) => ext.storage_dtype(),
244            _ => dtype,
245        };
246
247        Ok(Some(match kind {
248            Kind::NullValue(_) => return Ok(None),
249            Kind::BoolValue(v) => bool_from_proto(*v, dtype)?,
250            Kind::Int64Value(v) => int64_from_proto(*v, dtype)?,
251            Kind::Uint64Value(v) => uint64_from_proto(*v, dtype)?,
252            Kind::F16Value(v) => f16_from_proto(*v, dtype)?,
253            Kind::F32Value(v) => f32_from_proto(*v, dtype)?,
254            Kind::F64Value(v) => f64_from_proto(*v, dtype)?,
255            Kind::StringValue(s) => string_from_proto(s, dtype)?,
256            Kind::BytesValue(b) => bytes_from_proto(b, dtype)?,
257            Kind::ListValue(v) => list_from_proto(v, dtype, session)?,
258        }))
259    }
260}
261
262/// Deserialize a [`ScalarValue::Bool`] from a protobuf `BoolValue`.
263fn bool_from_proto(v: bool, dtype: &DType) -> VortexResult<ScalarValue> {
264    vortex_ensure!(
265        dtype.is_boolean(),
266        Serde: "expected Bool dtype for BoolValue, got {dtype}"
267    );
268
269    Ok(ScalarValue::Bool(v))
270}
271
272/// Deserialize a [`ScalarValue::Primitive`] from a protobuf `Int64Value`.
273///
274/// Protobuf consolidates all signed integers into `i64`, so we narrow back to the original
275/// type using the provided [`DType`].
276fn int64_from_proto(v: i64, dtype: &DType) -> VortexResult<ScalarValue> {
277    vortex_ensure!(
278        dtype.is_primitive(),
279        Serde: "expected Primitive dtype for Int64Value, got {dtype}"
280    );
281
282    let pvalue = match dtype.as_ptype() {
283        PType::I8 => v.to_i8().map(PValue::I8),
284        PType::I16 => v.to_i16().map(PValue::I16),
285        PType::I32 => v.to_i32().map(PValue::I32),
286        PType::I64 => Some(PValue::I64(v)),
287        // It was previously possible for unsigned types to get their stats serialised as signed,
288        // so we allow casting back to unsigned for backwards compatibility.
289        PType::U8 => v.to_u8().map(PValue::U8),
290        PType::U16 => v.to_u16().map(PValue::U16),
291        PType::U32 => v.to_u32().map(PValue::U32),
292        PType::U64 => v.to_u64().map(PValue::U64),
293        ftype @ (PType::F16 | PType::F32 | PType::F64) => vortex_bail!(
294            Serde: "expected signed integer ptype for serialized Int64Value, got float {ftype}"
295        ),
296    }
297    .ok_or_else(|| vortex_err!(Serde: "Int64 value {v} out of range for dtype {dtype}"))?;
298
299    Ok(ScalarValue::Primitive(pvalue))
300}
301
302/// Deserialize a [`ScalarValue::Primitive`] from a protobuf `Uint64Value`.
303///
304/// Protobuf consolidates all unsigned integers into `u64`, so we narrow back to the original
305/// type using the provided [`DType`]. Also handles the backwards-compatible case where `f16`
306/// values were serialized as `u64` (via `f16::to_bits() as u64`).
307fn uint64_from_proto(v: u64, dtype: &DType) -> VortexResult<ScalarValue> {
308    vortex_ensure!(
309        dtype.is_primitive(),
310        Serde: "expected Primitive dtype for Uint64Value, got {dtype}"
311    );
312
313    let pvalue = match dtype.as_ptype() {
314        PType::U8 => v.to_u8().map(PValue::U8),
315        PType::U16 => v.to_u16().map(PValue::U16),
316        PType::U32 => v.to_u32().map(PValue::U32),
317        PType::U64 => Some(PValue::U64(v)),
318        // It was previously possible for signed types to get their stats serialised as unsigned,
319        // so we allow casting back to signed for backwards compatibility.
320        PType::I8 => v.to_i8().map(PValue::I8),
321        PType::I16 => v.to_i16().map(PValue::I16),
322        PType::I32 => v.to_i32().map(PValue::I32),
323        PType::I64 => v.to_i64().map(PValue::I64),
324        // f16 values used to be serialized as u64, so we need to be able to read an f16 from a u64.
325        PType::F16 => v.to_u16().map(f16::from_bits).map(PValue::F16),
326        ftype @ (PType::F32 | PType::F64) => vortex_bail!(
327            Serde: "expected unsigned integer ptype for serialized Uint64Value, got {ftype}"
328        ),
329    }
330    .ok_or_else(|| vortex_err!(Serde: "Uint64 value {v} out of range for dtype {dtype}"))?;
331
332    Ok(ScalarValue::Primitive(pvalue))
333}
334
335/// Deserialize a [`ScalarValue::Primitive`] from a protobuf `F16Value`.
336fn f16_from_proto(v: u64, dtype: &DType) -> VortexResult<ScalarValue> {
337    vortex_ensure!(
338        matches!(dtype, DType::Primitive(PType::F16, _)),
339        Serde: "expected F16 dtype for F16Value, got {dtype}"
340    );
341
342    let bits = u16::try_from(v)
343        .map_err(|_| vortex_err!(Serde: "f16 bitwise representation has more than 16 bits: {v}"))?;
344
345    Ok(ScalarValue::Primitive(PValue::F16(f16::from_bits(bits))))
346}
347
348/// Deserialize a [`ScalarValue::Primitive`] from a protobuf `F32Value`.
349fn f32_from_proto(v: f32, dtype: &DType) -> VortexResult<ScalarValue> {
350    vortex_ensure!(
351        matches!(dtype, DType::Primitive(PType::F32, _)),
352        Serde: "expected F32 dtype for F32Value, got {dtype}"
353    );
354
355    Ok(ScalarValue::Primitive(PValue::F32(v)))
356}
357
358/// Deserialize a [`ScalarValue::Primitive`] from a protobuf `F64Value`.
359fn f64_from_proto(v: f64, dtype: &DType) -> VortexResult<ScalarValue> {
360    vortex_ensure!(
361        matches!(dtype, DType::Primitive(PType::F64, _)),
362        Serde: "expected F64 dtype for F64Value, got {dtype}"
363    );
364
365    Ok(ScalarValue::Primitive(PValue::F64(v)))
366}
367
368/// Deserialize a [`ScalarValue::Utf8`] or [`ScalarValue::Binary`] from a protobuf
369/// `StringValue`.
370fn string_from_proto(s: &str, dtype: &DType) -> VortexResult<ScalarValue> {
371    match dtype {
372        DType::Utf8(_) => Ok(ScalarValue::Utf8(BufferString::from(s))),
373        DType::Binary(_) => Ok(ScalarValue::Binary(ByteBuffer::copy_from(s.as_bytes()))),
374        _ => vortex_bail!(
375            Serde: "expected Utf8 or Binary dtype for StringValue, got {dtype}"
376        ),
377    }
378}
379
380/// Deserialize a [`ScalarValue`] from a protobuf bytes and a `DType`.
381///
382/// Handles [`Utf8`](ScalarValue::Utf8), [`Binary`](ScalarValue::Binary), and
383/// [`Decimal`](ScalarValue::Decimal) dtypes.
384fn bytes_from_proto(bytes: &[u8], dtype: &DType) -> VortexResult<ScalarValue> {
385    match dtype {
386        DType::Utf8(_) => Ok(ScalarValue::Utf8(BufferString::try_from(bytes)?)),
387        DType::Binary(_) => Ok(ScalarValue::Binary(ByteBuffer::copy_from(bytes))),
388        // TODO(connor): This is incorrect, we need to verify this matches the inner decimal_dtype.
389        DType::Decimal(..) => Ok(ScalarValue::Decimal(match bytes.len() {
390            1 => DecimalValue::I8(bytes[0] as i8),
391            2 => DecimalValue::I16(i16::from_le_bytes(
392                bytes
393                    .try_into()
394                    .ok()
395                    .vortex_expect("Buffer has invalid number of bytes"),
396            )),
397            4 => DecimalValue::I32(i32::from_le_bytes(
398                bytes
399                    .try_into()
400                    .ok()
401                    .vortex_expect("Buffer has invalid number of bytes"),
402            )),
403            8 => DecimalValue::I64(i64::from_le_bytes(
404                bytes
405                    .try_into()
406                    .ok()
407                    .vortex_expect("Buffer has invalid number of bytes"),
408            )),
409            16 => DecimalValue::I128(i128::from_le_bytes(
410                bytes
411                    .try_into()
412                    .ok()
413                    .vortex_expect("Buffer has invalid number of bytes"),
414            )),
415            32 => DecimalValue::I256(i256::from_le_bytes(
416                bytes
417                    .try_into()
418                    .ok()
419                    .vortex_expect("Buffer has invalid number of bytes"),
420            )),
421            l => vortex_bail!(Serde: "invalid decimal byte length: {l}"),
422        })),
423        _ => vortex_bail!(
424            Serde: "expected Utf8, Binary, or Decimal dtype for BytesValue, got {dtype}"
425        ),
426    }
427}
428
429/// Deserialize a [`ScalarValue::List`] from a protobuf `ListValue`.
430fn list_from_proto(
431    v: &ListValue,
432    dtype: &DType,
433    session: &VortexSession,
434) -> VortexResult<ScalarValue> {
435    let element_dtype = dtype
436        .as_list_element_opt()
437        .ok_or_else(|| vortex_err!(Serde: "expected List dtype for ListValue, got {dtype}"))?;
438
439    let mut values = Vec::with_capacity(v.values.len());
440    for elem in v.values.iter() {
441        values.push(ScalarValue::from_proto(
442            elem,
443            element_dtype.as_ref(),
444            session,
445        )?);
446    }
447
448    Ok(ScalarValue::List(values))
449}
450
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use vortex_buffer::BufferString;
    use vortex_error::vortex_panic;
    use vortex_proto::scalar as pb;
    use vortex_session::VortexSession;

    use super::*;
    use crate::dtype::DType;
    use crate::dtype::DecimalDType;
    use crate::dtype::Nullability;
    use crate::dtype::PType;
    use crate::dtype::half::f16;
    use crate::scalar::DecimalValue;
    use crate::scalar::Scalar;
    use crate::scalar::ScalarValue;

    /// Returns an empty session to pass to the deserialization entry points.
    fn session() -> VortexSession {
        VortexSession::empty()
    }

    /// Serializes `scalar` to protobuf, deserializes it back, and asserts the result is equal
    /// to the input.
    fn round_trip(scalar: Scalar) {
        assert_eq!(
            scalar,
            Scalar::from_proto(&pb::Scalar::from(&scalar), &session()).unwrap(),
        );
    }

    /// Decodes a protobuf value with the given `kind` against `dtype` and asserts that
    /// deserialization is rejected with an error.
    fn assert_from_proto_fails(kind: Option<Kind>, dtype: DType) {
        let pb_value = pb::ScalarValue { kind };
        assert!(
            ScalarValue::from_proto(&pb_value, &dtype, &session()).is_err(),
            "expected deserialization of {pb_value:?} as {dtype} to fail"
        );
    }

    #[test]
    fn test_null() {
        round_trip(Scalar::null(DType::Null));
    }

    #[test]
    fn test_bool() {
        round_trip(Scalar::new(
            DType::Bool(Nullability::Nullable),
            Some(ScalarValue::Bool(true)),
        ));
    }

    #[test]
    fn test_primitive() {
        round_trip(Scalar::new(
            DType::Primitive(PType::I32, Nullability::Nullable),
            Some(ScalarValue::Primitive(42i32.into())),
        ));
    }

    #[test]
    fn test_buffer() {
        round_trip(Scalar::new(
            DType::Binary(Nullability::Nullable),
            Some(ScalarValue::Binary(vec![1, 2, 3].into())),
        ));
    }

    #[test]
    fn test_buffer_string() {
        round_trip(Scalar::new(
            DType::Utf8(Nullability::Nullable),
            Some(ScalarValue::Utf8(BufferString::from("hello".to_string()))),
        ));
    }

    #[test]
    fn test_list() {
        round_trip(Scalar::new(
            DType::List(
                Arc::new(DType::Primitive(PType::I32, Nullability::Nullable)),
                Nullability::Nullable,
            ),
            Some(ScalarValue::List(vec![
                Some(ScalarValue::Primitive(42i32.into())),
                Some(ScalarValue::Primitive(43i32.into())),
            ])),
        ));
    }

    #[test]
    fn test_f16() {
        round_trip(Scalar::primitive(
            f16::from_f32(0.42),
            Nullability::Nullable,
        ));
    }

    #[test]
    fn test_i8() {
        round_trip(Scalar::new(
            DType::Primitive(PType::I8, Nullability::Nullable),
            Some(ScalarValue::Primitive(i8::MIN.into())),
        ));

        round_trip(Scalar::new(
            DType::Primitive(PType::I8, Nullability::Nullable),
            Some(ScalarValue::Primitive(0i8.into())),
        ));

        round_trip(Scalar::new(
            DType::Primitive(PType::I8, Nullability::Nullable),
            Some(ScalarValue::Primitive(i8::MAX.into())),
        ));
    }

    #[test]
    fn test_decimal_i32_roundtrip() {
        // A typical decimal with moderate precision and scale.
        round_trip(Scalar::decimal(
            DecimalValue::I32(123_456),
            DecimalDType::new(10, 2),
            Nullability::NonNullable,
        ));
    }

    #[test]
    fn test_decimal_i128_roundtrip() {
        // A large decimal value that requires i128 storage.
        round_trip(Scalar::decimal(
            DecimalValue::I128(99_999_999_999_999_999_999),
            DecimalDType::new(38, 6),
            Nullability::Nullable,
        ));
    }

    #[test]
    fn test_decimal_null_roundtrip() {
        round_trip(Scalar::null(DType::Decimal(
            DecimalDType::new(10, 2),
            Nullability::Nullable,
        )));
    }

    #[test]
    fn test_scalar_value_serde_roundtrip_binary() {
        round_trip(Scalar::binary(
            ByteBuffer::copy_from(b"hello"),
            Nullability::NonNullable,
        ));
    }

    #[test]
    fn test_scalar_value_serde_roundtrip_utf8() {
        round_trip(Scalar::utf8("hello", Nullability::NonNullable));
    }

    #[test]
    fn test_backcompat_f16_serialized_as_u64() {
        // Backwards compatibility test for the legacy f16 serialization format.
        //
        // f16 ScalarValues used to be serialized as `Uint64Value(v.to_bits() as u64)`. The
        // proto schema only had 64-bit integer types, so f16's underlying u16 bit pattern was
        // widened to u64.
        //
        // The current implementation writes a dedicated `F16Value` proto field, but it must
        // still read the old format. This test verifies that:
        //
        // 1. A `Uint64Value` containing f16 bits reads as a U64 primitive (the raw bits) when
        //    the dtype is U64.
        // 2. The same `Uint64Value` is reinterpreted as an f16 when the dtype is F16.
        //
        // This ensures data written with the old serialization format remains readable.

        // Simulate the old serialization: f16(0.42) stored as Uint64Value with its bit pattern.
        let f16_value = f16::from_f32(0.42);
        let f16_bits_as_u64 = f16_value.to_bits() as u64; // 14008

        let pb_scalar_value = pb::ScalarValue {
            kind: Some(Kind::Uint64Value(f16_bits_as_u64)),
        };

        // Step 1: Verify the normal U64 scalar.
        let scalar_value = ScalarValue::from_proto(
            &pb_scalar_value,
            &DType::Primitive(PType::U64, Nullability::NonNullable),
            &session(),
        )
        .unwrap();
        assert_eq!(
            scalar_value.as_ref().map(|v| v.as_primitive()),
            Some(&PValue::U64(14008u64)),
        );

        // Step 2: Verify that when we use F16 dtype, the Uint64Value is correctly interpreted.
        let scalar_value_f16 = ScalarValue::from_proto(
            &pb_scalar_value,
            &DType::Primitive(PType::F16, Nullability::Nullable),
            &session(),
        )
        .unwrap();

        let scalar = Scalar::new(
            DType::Primitive(PType::F16, Nullability::Nullable),
            scalar_value_f16,
        );

        assert_eq!(
            scalar.as_primitive().pvalue().unwrap(),
            PValue::F16(f16::from_f32(0.42)),
            "Uint64Value should be correctly interpreted as f16 when dtype is F16"
        );
    }

    #[test]
    fn test_scalar_value_direct_roundtrip_f16() {
        // Test that ScalarValue with f16 roundtrips correctly without going through Scalar.
        let f16_values = vec![
            f16::from_f32(0.0),
            f16::from_f32(1.0),
            f16::from_f32(-1.0),
            f16::from_f32(0.42),
            f16::from_f32(5.722046e-6),
            f16::from_f32(std::f32::consts::PI),
            f16::INFINITY,
            f16::NEG_INFINITY,
            f16::NAN,
        ];

        for f16_val in f16_values {
            let scalar_value = ScalarValue::Primitive(PValue::F16(f16_val));
            let pb_value = ScalarValue::to_proto(Some(&scalar_value));
            let read_back = ScalarValue::from_proto(
                &pb_value,
                &DType::Primitive(PType::F16, Nullability::NonNullable),
                &session(),
            )
            .unwrap();

            match (&scalar_value, read_back.as_ref()) {
                (
                    ScalarValue::Primitive(PValue::F16(original)),
                    Some(ScalarValue::Primitive(PValue::F16(roundtripped))),
                ) => {
                    if original.is_nan() && roundtripped.is_nan() {
                        // NaN values are equal for our purposes.
                        continue;
                    }
                    assert_eq!(
                        original, roundtripped,
                        "F16 value {original:?} did not roundtrip correctly"
                    );
                }
                _ => {
                    vortex_panic!(
                        "Expected f16 primitive values, got {scalar_value:?} and {read_back:?}"
                    )
                }
            }
        }
    }

    #[test]
    fn test_scalar_value_direct_roundtrip_preserves_values() {
        // Every case below must roundtrip exactly, including the narrow integer types. Protobuf
        // widens all integers to 64 bits on the wire (u8/u16/u32 -> u64, i8/i16/i32 -> i64), but
        // deserialization narrows them back to the width named by the supplied dtype.
        let exact_roundtrip_cases: Vec<(&str, Option<ScalarValue>, DType)> = vec![
            ("null", None, DType::Null),
            (
                "bool_true",
                Some(ScalarValue::Bool(true)),
                DType::Bool(Nullability::Nullable),
            ),
            (
                "bool_false",
                Some(ScalarValue::Bool(false)),
                DType::Bool(Nullability::Nullable),
            ),
            (
                "u8",
                Some(ScalarValue::Primitive(PValue::U8(u8::MAX))),
                DType::Primitive(PType::U8, Nullability::Nullable),
            ),
            (
                "u16",
                Some(ScalarValue::Primitive(PValue::U16(u16::MAX))),
                DType::Primitive(PType::U16, Nullability::Nullable),
            ),
            (
                "u32",
                Some(ScalarValue::Primitive(PValue::U32(u32::MAX))),
                DType::Primitive(PType::U32, Nullability::Nullable),
            ),
            (
                "u64",
                Some(ScalarValue::Primitive(PValue::U64(18446744073709551615))),
                DType::Primitive(PType::U64, Nullability::Nullable),
            ),
            (
                "i8",
                Some(ScalarValue::Primitive(PValue::I8(i8::MIN))),
                DType::Primitive(PType::I8, Nullability::Nullable),
            ),
            (
                "i16",
                Some(ScalarValue::Primitive(PValue::I16(i16::MIN))),
                DType::Primitive(PType::I16, Nullability::Nullable),
            ),
            (
                "i32",
                Some(ScalarValue::Primitive(PValue::I32(i32::MIN))),
                DType::Primitive(PType::I32, Nullability::Nullable),
            ),
            (
                "i64",
                Some(ScalarValue::Primitive(PValue::I64(-9223372036854775808))),
                DType::Primitive(PType::I64, Nullability::Nullable),
            ),
            (
                "f32",
                Some(ScalarValue::Primitive(PValue::F32(std::f32::consts::E))),
                DType::Primitive(PType::F32, Nullability::Nullable),
            ),
            (
                "f64",
                Some(ScalarValue::Primitive(PValue::F64(std::f64::consts::PI))),
                DType::Primitive(PType::F64, Nullability::Nullable),
            ),
            (
                "string",
                Some(ScalarValue::Utf8(BufferString::from("test"))),
                DType::Utf8(Nullability::Nullable),
            ),
            (
                "bytes",
                Some(ScalarValue::Binary(vec![1, 2, 3, 4, 5].into())),
                DType::Binary(Nullability::Nullable),
            ),
        ];

        for (name, value, dtype) in exact_roundtrip_cases {
            let pb_value = ScalarValue::to_proto(value.as_ref());
            let read_back = ScalarValue::from_proto(&pb_value, &dtype, &session()).unwrap();

            // Compare via `Debug` so the exact variant (and therefore the width) is checked.
            let original_debug = format!("{value:?}");
            let roundtrip_debug = format!("{read_back:?}");
            assert_eq!(
                original_debug, roundtrip_debug,
                "ScalarValue {name} did not roundtrip exactly"
            );
        }
    }

    // Backwards compatibility: unsigned integer stats could previously be serialized as signed
    // (`Int64Value`). A signed wire value must therefore be readable with an unsigned dtype.
    #[test]
    fn test_backcompat_signed_integer_deserialized_as_unsigned() {
        let v = ScalarValue::Primitive(PValue::I64(0));
        assert_eq!(
            Scalar::from_proto_value(
                &pb::ScalarValue::from(&v),
                &DType::Primitive(PType::U64, Nullability::Nullable),
                &session()
            )
            .unwrap(),
            Scalar::primitive(0u64, Nullability::Nullable)
        );
    }

    // Backwards compatibility: signed integer stats could previously be serialized as unsigned
    // (`Uint64Value`). An unsigned wire value must therefore be readable with a signed dtype.
    #[test]
    fn test_backcompat_unsigned_integer_deserialized_as_signed() {
        let v = ScalarValue::Primitive(PValue::U64(0));
        assert_eq!(
            Scalar::from_proto_value(
                &pb::ScalarValue::from(&v),
                &DType::Primitive(PType::I64, Nullability::Nullable),
                &session()
            )
            .unwrap(),
            Scalar::primitive(0i64, Nullability::Nullable)
        );
    }

    #[test]
    fn test_missing_kind_is_rejected() {
        assert_from_proto_fails(None, DType::Bool(Nullability::Nullable));
    }

    #[test]
    fn test_integer_out_of_range_for_dtype_is_rejected() {
        // One past i8::MAX cannot narrow to i8.
        assert_from_proto_fails(
            Some(Kind::Int64Value(i64::from(i8::MAX) + 1)),
            DType::Primitive(PType::I8, Nullability::Nullable),
        );
        // A negative signed value cannot be read back as unsigned.
        assert_from_proto_fails(
            Some(Kind::Int64Value(-1)),
            DType::Primitive(PType::U64, Nullability::Nullable),
        );
        // One past u16::MAX cannot narrow to u16.
        assert_from_proto_fails(
            Some(Kind::Uint64Value(u64::from(u16::MAX) + 1)),
            DType::Primitive(PType::U16, Nullability::Nullable),
        );
    }

    #[test]
    fn test_kind_dtype_mismatch_is_rejected() {
        assert_from_proto_fails(
            Some(Kind::BoolValue(true)),
            DType::Primitive(PType::I32, Nullability::Nullable),
        );
        assert_from_proto_fails(
            Some(Kind::F32Value(1.0)),
            DType::Primitive(PType::F64, Nullability::Nullable),
        );
        assert_from_proto_fails(
            Some(Kind::StringValue("x".to_string())),
            DType::Bool(Nullability::Nullable),
        );
    }

    #[test]
    fn test_invalid_decimal_byte_length_is_rejected() {
        // 3 bytes does not correspond to any supported decimal backing width.
        assert_from_proto_fails(
            Some(Kind::BytesValue(vec![0; 3])),
            DType::Decimal(DecimalDType::new(10, 2), Nullability::Nullable),
        );
    }

    #[test]
    fn test_f16_value_wider_than_16_bits_is_rejected() {
        assert_from_proto_fails(
            Some(Kind::F16Value(1 << 16)),
            DType::Primitive(PType::F16, Nullability::Nullable),
        );
    }
}