Skip to main content

vortex_array/scalar/
proto.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4//! Protobuf serialization and deserialization for scalars.
5
6use num_traits::ToBytes;
7use num_traits::ToPrimitive;
8use prost::Message;
9use vortex_buffer::BufferString;
10use vortex_buffer::ByteBuffer;
11use vortex_error::VortexExpect;
12use vortex_error::VortexResult;
13use vortex_error::vortex_bail;
14use vortex_error::vortex_ensure;
15use vortex_error::vortex_err;
16use vortex_proto::scalar as pb;
17use vortex_proto::scalar::ListValue;
18use vortex_proto::scalar::scalar_value::Kind;
19use vortex_session::VortexSession;
20
21use crate::dtype::DType;
22use crate::dtype::PType;
23use crate::dtype::half::f16;
24use crate::dtype::i256;
25use crate::scalar::DecimalValue;
26use crate::scalar::PValue;
27use crate::scalar::Scalar;
28use crate::scalar::ScalarValue;
29
30////////////////////////////////////////////////////////////////////////////////////////////////////
31// Serialize INTO proto.
32////////////////////////////////////////////////////////////////////////////////////////////////////
33
34impl From<&Scalar> for pb::Scalar {
35    fn from(value: &Scalar) -> Self {
36        pb::Scalar {
37            dtype: Some(
38                (value.dtype())
39                    .try_into()
40                    .vortex_expect("Failed to convert DType to proto"),
41            ),
42            value: Some(Box::new(ScalarValue::to_proto(value.value()))),
43        }
44    }
45}
46
47impl ScalarValue {
48    /// Ideally, we would not have this function and instead implement this `From` implementation:
49    ///
50    /// ```ignore
51    /// impl From<Option<&ScalarValue>> for pb::ScalarValue { ... }
52    /// ```
53    ///
54    /// However, we are not allowed to do this because of the Orphan rule (`Option` and
55    /// `pb::ScalarValue` are not types defined in this crate). So we must make this a method on
56    /// `vortex_array::scalar::ScalarValue` directly.
57    pub fn to_proto(this: Option<&Self>) -> pb::ScalarValue {
58        match this {
59            None => pb::ScalarValue {
60                kind: Some(Kind::NullValue(0)),
61            },
62            Some(this) => pb::ScalarValue::from(this),
63        }
64    }
65
66    /// Serialize an optional [`ScalarValue`] to protobuf bytes (handles null values).
67    pub fn to_proto_bytes<B: Default + bytes::BufMut>(value: Option<&ScalarValue>) -> B {
68        let proto = Self::to_proto(value);
69        let mut buf = B::default();
70        proto
71            .encode(&mut buf)
72            .vortex_expect("Failed to encode scalar value");
73        buf
74    }
75}
76
77impl From<&ScalarValue> for pb::ScalarValue {
78    fn from(value: &ScalarValue) -> Self {
79        match value {
80            ScalarValue::Bool(v) => pb::ScalarValue {
81                kind: Some(Kind::BoolValue(*v)),
82            },
83            ScalarValue::Primitive(v) => pb::ScalarValue::from(v),
84            ScalarValue::Decimal(v) => {
85                let inner_value = match v {
86                    DecimalValue::I8(v) => v.to_le_bytes().to_vec(),
87                    DecimalValue::I16(v) => v.to_le_bytes().to_vec(),
88                    DecimalValue::I32(v) => v.to_le_bytes().to_vec(),
89                    DecimalValue::I64(v) => v.to_le_bytes().to_vec(),
90                    DecimalValue::I128(v128) => v128.to_le_bytes().to_vec(),
91                    DecimalValue::I256(v256) => v256.to_le_bytes().to_vec(),
92                };
93
94                pb::ScalarValue {
95                    kind: Some(Kind::BytesValue(inner_value)),
96                }
97            }
98            ScalarValue::Utf8(v) => pb::ScalarValue {
99                kind: Some(Kind::StringValue(v.to_string())),
100            },
101            ScalarValue::Binary(v) => pb::ScalarValue {
102                kind: Some(Kind::BytesValue(v.to_vec())),
103            },
104            ScalarValue::Tuple(v) => {
105                let mut values = Vec::with_capacity(v.len());
106                for elem in v.iter() {
107                    values.push(ScalarValue::to_proto(elem.as_ref()));
108                }
109                pb::ScalarValue {
110                    kind: Some(Kind::ListValue(ListValue { values })),
111                }
112            }
113            ScalarValue::Variant(v) => pb::ScalarValue {
114                kind: Some(Kind::VariantValue(Box::new(pb::Scalar::from(v.as_ref())))),
115            },
116        }
117    }
118}
119
120impl From<&PValue> for pb::ScalarValue {
121    fn from(value: &PValue) -> Self {
122        match value {
123            PValue::I8(v) => pb::ScalarValue {
124                kind: Some(Kind::Int64Value(*v as i64)),
125            },
126            PValue::I16(v) => pb::ScalarValue {
127                kind: Some(Kind::Int64Value(*v as i64)),
128            },
129            PValue::I32(v) => pb::ScalarValue {
130                kind: Some(Kind::Int64Value(*v as i64)),
131            },
132            PValue::I64(v) => pb::ScalarValue {
133                kind: Some(Kind::Int64Value(*v)),
134            },
135            PValue::U8(v) => pb::ScalarValue {
136                kind: Some(Kind::Uint64Value(*v as u64)),
137            },
138            PValue::U16(v) => pb::ScalarValue {
139                kind: Some(Kind::Uint64Value(*v as u64)),
140            },
141            PValue::U32(v) => pb::ScalarValue {
142                kind: Some(Kind::Uint64Value(*v as u64)),
143            },
144            PValue::U64(v) => pb::ScalarValue {
145                kind: Some(Kind::Uint64Value(*v)),
146            },
147            PValue::F16(v) => pb::ScalarValue {
148                kind: Some(Kind::F16Value(v.to_bits() as u64)),
149            },
150            PValue::F32(v) => pb::ScalarValue {
151                kind: Some(Kind::F32Value(*v)),
152            },
153            PValue::F64(v) => pb::ScalarValue {
154                kind: Some(Kind::F64Value(*v)),
155            },
156        }
157    }
158}
159
160////////////////////////////////////////////////////////////////////////////////////////////////////
161// Serialize FROM proto.
162////////////////////////////////////////////////////////////////////////////////////////////////////
163
164impl Scalar {
165    /// Creates a [`Scalar`] from a [protobuf `ScalarValue`](pb::ScalarValue) representation.
166    ///
167    /// Note that we need to provide a [`DType`] since protobuf serialization only supports 64-bit
168    /// integers, and serializing _into_ protobuf loses that type information.
169    ///
170    /// # Errors
171    ///
172    /// Returns an error if type validation fails.
173    pub fn from_proto_value(
174        value: &pb::ScalarValue,
175        dtype: &DType,
176        session: &VortexSession,
177    ) -> VortexResult<Self> {
178        let scalar_value = ScalarValue::from_proto(value, dtype, session)?;
179
180        Scalar::try_new(dtype.clone(), scalar_value)
181    }
182
183    /// Creates a [`Scalar`] from its [protobuf](pb::Scalar) representation.
184    ///
185    /// # Errors
186    ///
187    /// Returns an error if the protobuf is missing required fields or if type validation fails.
188    pub fn from_proto(value: &pb::Scalar, session: &VortexSession) -> VortexResult<Self> {
189        let dtype = DType::from_proto(
190            value
191                .dtype
192                .as_ref()
193                .ok_or_else(|| vortex_err!(Serde: "Scalar missing dtype"))?,
194            session,
195        )?;
196
197        let pb_scalar_value: &pb::ScalarValue = value
198            .value
199            .as_ref()
200            .ok_or_else(|| vortex_err!(Serde: "Scalar missing value"))?;
201
202        let value: Option<ScalarValue> = ScalarValue::from_proto(pb_scalar_value, &dtype, session)?;
203
204        Scalar::try_new(dtype, value)
205    }
206}
207
208impl ScalarValue {
209    /// Deserialize a [`ScalarValue`] from protobuf bytes.
210    ///
211    /// Note that we need to provide a [`DType`] since protobuf serialization only supports 64-bit
212    /// integers, and serializing _into_ protobuf loses that type information.
213    ///
214    /// # Errors
215    ///
216    /// Returns an error if decoding or type validation fails.
217    pub fn from_proto_bytes(
218        bytes: &[u8],
219        dtype: &DType,
220        session: &VortexSession,
221    ) -> VortexResult<Option<Self>> {
222        let proto = pb::ScalarValue::decode(bytes)?;
223        Self::from_proto(&proto, dtype, session)
224    }
225
226    /// Creates a [`ScalarValue`] from its [protobuf](pb::ScalarValue) representation.
227    ///
228    /// Note that we need to provide a [`DType`] since protobuf serialization only supports 64-bit
229    /// integers, and serializing _into_ protobuf loses that type information.
230    ///
231    /// # Errors
232    ///
233    /// Returns an error if the protobuf value cannot be converted to the given [`DType`].
234    pub fn from_proto(
235        value: &pb::ScalarValue,
236        dtype: &DType,
237        session: &VortexSession,
238    ) -> VortexResult<Option<Self>> {
239        let kind = value
240            .kind
241            .as_ref()
242            .ok_or_else(|| vortex_err!(Serde: "Scalar value missing kind"))?;
243
244        // `DType::Extension` store their serialized values using the storage `DType`.
245        let dtype = match dtype {
246            DType::Extension(ext) => ext.storage_dtype(),
247            _ => dtype,
248        };
249
250        Ok(match kind {
251            Kind::NullValue(_) => None,
252            Kind::BoolValue(v) => Some(bool_from_proto(*v, dtype)?),
253            Kind::Int64Value(v) => Some(int64_from_proto(*v, dtype)?),
254            Kind::Uint64Value(v) => Some(uint64_from_proto(*v, dtype)?),
255            Kind::F16Value(v) => Some(f16_from_proto(*v, dtype)?),
256            Kind::F32Value(v) => Some(f32_from_proto(*v, dtype)?),
257            Kind::F64Value(v) => Some(f64_from_proto(*v, dtype)?),
258            Kind::StringValue(s) => Some(string_from_proto(s, dtype)?),
259            Kind::BytesValue(b) => Some(bytes_from_proto(b, dtype)?),
260            Kind::ListValue(v) => Some(list_from_proto(v, dtype, session)?),
261            Kind::VariantValue(v) => match dtype {
262                DType::Variant(_) => Some(ScalarValue::Variant(Box::new(Scalar::from_proto(
263                    v, session,
264                )?))),
265                _ => vortex_bail!(Serde: "expected non-Variant scalar proto for dtype {dtype}"),
266            },
267        })
268    }
269}
270
271/// Deserialize a [`ScalarValue::Bool`] from a protobuf `BoolValue`.
272fn bool_from_proto(v: bool, dtype: &DType) -> VortexResult<ScalarValue> {
273    vortex_ensure!(
274        dtype.is_boolean(),
275        Serde: "expected Bool dtype for BoolValue, got {dtype}"
276    );
277
278    Ok(ScalarValue::Bool(v))
279}
280
281/// Deserialize a [`ScalarValue::Primitive`] from a protobuf `Int64Value`.
282///
283/// Protobuf consolidates all signed integers into `i64`, so we narrow back to the original
284/// type using the provided [`DType`].
285fn int64_from_proto(v: i64, dtype: &DType) -> VortexResult<ScalarValue> {
286    vortex_ensure!(
287        dtype.is_primitive(),
288        Serde: "expected Primitive dtype for Int64Value, got {dtype}"
289    );
290
291    let pvalue = match dtype.as_ptype() {
292        PType::I8 => v.to_i8().map(PValue::I8),
293        PType::I16 => v.to_i16().map(PValue::I16),
294        PType::I32 => v.to_i32().map(PValue::I32),
295        PType::I64 => Some(PValue::I64(v)),
296        // It was previously possible for unsigned types to get their stats serialised as signed,
297        // so we allow casting back to unsigned for backwards compatibility.
298        PType::U8 => v.to_u8().map(PValue::U8),
299        PType::U16 => v.to_u16().map(PValue::U16),
300        PType::U32 => v.to_u32().map(PValue::U32),
301        PType::U64 => v.to_u64().map(PValue::U64),
302        ftype @ (PType::F16 | PType::F32 | PType::F64) => vortex_bail!(
303            Serde: "expected signed integer ptype for serialized Int64Value, got float {ftype}"
304        ),
305    }
306    .ok_or_else(|| vortex_err!(Serde: "Int64 value {v} out of range for dtype {dtype}"))?;
307
308    Ok(ScalarValue::Primitive(pvalue))
309}
310
311/// Deserialize a [`ScalarValue::Primitive`] from a protobuf `Uint64Value`.
312///
313/// Protobuf consolidates all unsigned integers into `u64`, so we narrow back to the original
314/// type using the provided [`DType`]. Also handles the backwards-compatible case where `f16`
315/// values were serialized as `u64` (via `f16::to_bits() as u64`).
316fn uint64_from_proto(v: u64, dtype: &DType) -> VortexResult<ScalarValue> {
317    vortex_ensure!(
318        dtype.is_primitive(),
319        Serde: "expected Primitive dtype for Uint64Value, got {dtype}"
320    );
321
322    let pvalue = match dtype.as_ptype() {
323        PType::U8 => v.to_u8().map(PValue::U8),
324        PType::U16 => v.to_u16().map(PValue::U16),
325        PType::U32 => v.to_u32().map(PValue::U32),
326        PType::U64 => Some(PValue::U64(v)),
327        // It was previously possible for signed types to get their stats serialised as unsigned,
328        // so we allow casting back to signed for backwards compatibility.
329        PType::I8 => v.to_i8().map(PValue::I8),
330        PType::I16 => v.to_i16().map(PValue::I16),
331        PType::I32 => v.to_i32().map(PValue::I32),
332        PType::I64 => v.to_i64().map(PValue::I64),
333        // f16 values used to be serialized as u64, so we need to be able to read an f16 from a u64.
334        PType::F16 => v.to_u16().map(f16::from_bits).map(PValue::F16),
335        ftype @ (PType::F32 | PType::F64) => vortex_bail!(
336            Serde: "expected unsigned integer ptype for serialized Uint64Value, got {ftype}"
337        ),
338    }
339    .ok_or_else(|| vortex_err!(Serde: "Uint64 value {v} out of range for dtype {dtype}"))?;
340
341    Ok(ScalarValue::Primitive(pvalue))
342}
343
344/// Deserialize a [`ScalarValue::Primitive`] from a protobuf `F16Value`.
345fn f16_from_proto(v: u64, dtype: &DType) -> VortexResult<ScalarValue> {
346    vortex_ensure!(
347        matches!(dtype, DType::Primitive(PType::F16, _)),
348        Serde: "expected F16 dtype for F16Value, got {dtype}"
349    );
350
351    let bits = u16::try_from(v)
352        .map_err(|_| vortex_err!(Serde: "f16 bitwise representation has more than 16 bits: {v}"))?;
353
354    Ok(ScalarValue::Primitive(PValue::F16(f16::from_bits(bits))))
355}
356
357/// Deserialize a [`ScalarValue::Primitive`] from a protobuf `F32Value`.
358fn f32_from_proto(v: f32, dtype: &DType) -> VortexResult<ScalarValue> {
359    vortex_ensure!(
360        matches!(dtype, DType::Primitive(PType::F32, _)),
361        Serde: "expected F32 dtype for F32Value, got {dtype}"
362    );
363
364    Ok(ScalarValue::Primitive(PValue::F32(v)))
365}
366
367/// Deserialize a [`ScalarValue::Primitive`] from a protobuf `F64Value`.
368fn f64_from_proto(v: f64, dtype: &DType) -> VortexResult<ScalarValue> {
369    vortex_ensure!(
370        matches!(dtype, DType::Primitive(PType::F64, _)),
371        Serde: "expected F64 dtype for F64Value, got {dtype}"
372    );
373
374    Ok(ScalarValue::Primitive(PValue::F64(v)))
375}
376
377/// Deserialize a [`ScalarValue::Utf8`] or [`ScalarValue::Binary`] from a protobuf
378/// `StringValue`.
379fn string_from_proto(s: &str, dtype: &DType) -> VortexResult<ScalarValue> {
380    match dtype {
381        DType::Utf8(_) => Ok(ScalarValue::Utf8(BufferString::from(s))),
382        DType::Binary(_) => Ok(ScalarValue::Binary(ByteBuffer::copy_from(s.as_bytes()))),
383        _ => vortex_bail!(
384            Serde: "expected Utf8 or Binary dtype for StringValue, got {dtype}"
385        ),
386    }
387}
388
389/// Deserialize a [`ScalarValue`] from a protobuf bytes and a `DType`.
390///
391/// Handles [`Utf8`](ScalarValue::Utf8), [`Binary`](ScalarValue::Binary), and
392/// [`Decimal`](ScalarValue::Decimal) dtypes.
393fn bytes_from_proto(bytes: &[u8], dtype: &DType) -> VortexResult<ScalarValue> {
394    match dtype {
395        DType::Utf8(_) => Ok(ScalarValue::Utf8(BufferString::try_from(bytes)?)),
396        DType::Binary(_) => Ok(ScalarValue::Binary(ByteBuffer::copy_from(bytes))),
397        // TODO(connor): This is incorrect, we need to verify this matches the inner decimal_dtype.
398        DType::Decimal(..) => Ok(ScalarValue::Decimal(match bytes.len() {
399            1 => DecimalValue::I8(bytes[0] as i8),
400            2 => DecimalValue::I16(i16::from_le_bytes(
401                bytes
402                    .try_into()
403                    .ok()
404                    .vortex_expect("Buffer has invalid number of bytes"),
405            )),
406            4 => DecimalValue::I32(i32::from_le_bytes(
407                bytes
408                    .try_into()
409                    .ok()
410                    .vortex_expect("Buffer has invalid number of bytes"),
411            )),
412            8 => DecimalValue::I64(i64::from_le_bytes(
413                bytes
414                    .try_into()
415                    .ok()
416                    .vortex_expect("Buffer has invalid number of bytes"),
417            )),
418            16 => DecimalValue::I128(i128::from_le_bytes(
419                bytes
420                    .try_into()
421                    .ok()
422                    .vortex_expect("Buffer has invalid number of bytes"),
423            )),
424            32 => DecimalValue::I256(i256::from_le_bytes(
425                bytes
426                    .try_into()
427                    .ok()
428                    .vortex_expect("Buffer has invalid number of bytes"),
429            )),
430            l => vortex_bail!(Serde: "invalid decimal byte length: {l}"),
431        })),
432        _ => vortex_bail!(
433            Serde: "expected Utf8, Binary, or Decimal dtype for BytesValue, got {dtype}"
434        ),
435    }
436}
437
438/// Deserialize a [`ScalarValue::Tuple`] from a protobuf `ListValue`.
439fn list_from_proto(
440    v: &ListValue,
441    dtype: &DType,
442    session: &VortexSession,
443) -> VortexResult<ScalarValue> {
444    let element_dtype = dtype
445        .as_list_element_opt()
446        .ok_or_else(|| vortex_err!(Serde: "expected List dtype for ListValue, got {dtype}"))?;
447
448    let mut values = Vec::with_capacity(v.values.len());
449    for elem in v.values.iter() {
450        values.push(ScalarValue::from_proto(
451            elem,
452            element_dtype.as_ref(),
453            session,
454        )?);
455    }
456
457    Ok(ScalarValue::Tuple(values))
458}
459
460#[cfg(test)]
461mod tests {
462    use std::f32;
463    use std::f64;
464    use std::sync::Arc;
465
466    use vortex_buffer::BufferString;
467    use vortex_error::vortex_panic;
468    use vortex_proto::scalar as pb;
469    use vortex_session::VortexSession;
470
471    use super::*;
472    use crate::dtype::DType;
473    use crate::dtype::DecimalDType;
474    use crate::dtype::Nullability;
475    use crate::dtype::PType;
476    use crate::dtype::half::f16;
477    use crate::scalar::DecimalValue;
478    use crate::scalar::Scalar;
479    use crate::scalar::ScalarValue;
480
481    fn session() -> VortexSession {
482        VortexSession::empty()
483    }
484
485    fn round_trip(scalar: Scalar) {
486        assert_eq!(
487            scalar,
488            Scalar::from_proto(&pb::Scalar::from(&scalar), &session()).unwrap(),
489        );
490    }
491
492    #[test]
493    fn test_null() {
494        round_trip(Scalar::null(DType::Null));
495    }
496
497    #[test]
498    fn test_bool() {
499        round_trip(Scalar::new(
500            DType::Bool(Nullability::Nullable),
501            Some(ScalarValue::Bool(true)),
502        ));
503    }
504
505    #[test]
506    fn test_primitive() {
507        round_trip(Scalar::new(
508            DType::Primitive(PType::I32, Nullability::Nullable),
509            Some(ScalarValue::Primitive(42i32.into())),
510        ));
511    }
512
513    #[test]
514    fn test_buffer() {
515        round_trip(Scalar::new(
516            DType::Binary(Nullability::Nullable),
517            Some(ScalarValue::Binary(vec![1, 2, 3].into())),
518        ));
519    }
520
521    #[test]
522    fn test_buffer_string() {
523        round_trip(Scalar::new(
524            DType::Utf8(Nullability::Nullable),
525            Some(ScalarValue::Utf8(BufferString::from("hello".to_string()))),
526        ));
527    }
528
529    #[test]
530    fn test_list() {
531        round_trip(Scalar::new(
532            DType::List(
533                Arc::new(DType::Primitive(PType::I32, Nullability::Nullable)),
534                Nullability::Nullable,
535            ),
536            Some(ScalarValue::Tuple(vec![
537                Some(ScalarValue::Primitive(42i32.into())),
538                Some(ScalarValue::Primitive(43i32.into())),
539            ])),
540        ));
541    }
542
543    #[test]
544    fn test_f16() {
545        round_trip(Scalar::primitive(
546            f16::from_f32(0.42),
547            Nullability::Nullable,
548        ));
549    }
550
551    #[test]
552    fn test_i8() {
553        round_trip(Scalar::new(
554            DType::Primitive(PType::I8, Nullability::Nullable),
555            Some(ScalarValue::Primitive(i8::MIN.into())),
556        ));
557
558        round_trip(Scalar::new(
559            DType::Primitive(PType::I8, Nullability::Nullable),
560            Some(ScalarValue::Primitive(0i8.into())),
561        ));
562
563        round_trip(Scalar::new(
564            DType::Primitive(PType::I8, Nullability::Nullable),
565            Some(ScalarValue::Primitive(i8::MAX.into())),
566        ));
567    }
568
569    #[test]
570    fn test_decimal_i32_roundtrip() {
571        // A typical decimal with moderate precision and scale.
572        round_trip(Scalar::decimal(
573            DecimalValue::I32(123_456),
574            DecimalDType::new(10, 2),
575            Nullability::NonNullable,
576        ));
577    }
578
579    #[test]
580    fn test_decimal_i128_roundtrip() {
581        // A large decimal value that requires i128 storage.
582        round_trip(Scalar::decimal(
583            DecimalValue::I128(99_999_999_999_999_999_999),
584            DecimalDType::new(38, 6),
585            Nullability::Nullable,
586        ));
587    }
588
589    #[test]
590    fn test_decimal_null_roundtrip() {
591        round_trip(Scalar::null(DType::Decimal(
592            DecimalDType::new(10, 2),
593            Nullability::Nullable,
594        )));
595    }
596
597    #[test]
598    fn test_scalar_value_serde_roundtrip_binary() {
599        round_trip(Scalar::binary(
600            ByteBuffer::copy_from(b"hello"),
601            Nullability::NonNullable,
602        ));
603    }
604
605    #[test]
606    fn test_scalar_value_serde_roundtrip_utf8() {
607        round_trip(Scalar::utf8("hello", Nullability::NonNullable));
608    }
609
610    #[test]
611    fn test_variant_scalar_roundtrip() {
612        let nums = Scalar::list(
613            Arc::new(DType::Variant(Nullability::NonNullable)),
614            vec![
615                Scalar::variant(Scalar::primitive(-7_i16, Nullability::NonNullable)),
616                Scalar::variant(Scalar::primitive(42_u32, Nullability::NonNullable)),
617                Scalar::variant(Scalar::decimal(
618                    DecimalValue::I128(123_456_789),
619                    DecimalDType::new(18, 0),
620                    Nullability::NonNullable,
621                )),
622            ],
623            Nullability::NonNullable,
624        );
625
626        let nested = Scalar::list(
627            Arc::new(DType::Variant(Nullability::NonNullable)),
628            vec![
629                Scalar::variant(Scalar::from(true)),
630                Scalar::variant(nums),
631                Scalar::variant(Scalar::binary(
632                    ByteBuffer::copy_from(b"abc"),
633                    Nullability::NonNullable,
634                )),
635                Scalar::variant(Scalar::null(DType::Null)),
636            ],
637            Nullability::NonNullable,
638        );
639
640        round_trip(Scalar::variant(nested));
641    }
642
643    #[test]
644    fn test_variant_scalar_proto_preserves_scalar_null_vs_variant_null() {
645        let scalar_null = Scalar::null(DType::Variant(Nullability::Nullable));
646        let variant_null = Scalar::variant(Scalar::null(DType::Null));
647
648        let scalar_null_pb = pb::Scalar::from(&scalar_null);
649        let variant_null_pb = pb::Scalar::from(&variant_null);
650
651        assert_ne!(scalar_null_pb, variant_null_pb);
652        assert_eq!(
653            Scalar::from_proto(&scalar_null_pb, &session()).unwrap(),
654            scalar_null,
655        );
656        assert_eq!(
657            Scalar::from_proto(&variant_null_pb, &session()).unwrap(),
658            variant_null,
659        );
660    }
661
662    #[test]
663    fn test_backcompat_f16_serialized_as_u64() {
664        // Backwards compatibility test for the legacy f16 serialization format.
665        //
666        // Previously, f16 ScalarValues were serialized as `Uint64Value(v.to_bits() as u64)` because
667        // the proto schema only had 64-bit integer types, and f16's underlying representation is
668        // u16 which got widened to u64.
669        //
670        // The current implementation uses a dedicated `F16Value` proto field, but we must still be
671        // able to deserialize the old format. This test verifies that:
672        //
673        // 1. A `Uint64Value` containing f16 bits can be read as a U64 primitive (the raw bits).
674        // 2. When wrapped in a Scalar with F16 dtype, the value is correctly interpreted as f16.
675        //
676        // This ensures data written with the old serialization format remains readable.
677
678        // Simulate the old serialization: f16(0.42) stored as Uint64Value with its bit pattern.
679        let f16_value = f16::from_f32(0.42);
680        let f16_bits_as_u64 = f16_value.to_bits() as u64; // 14008
681
682        let pb_scalar_value = pb::ScalarValue {
683            kind: Some(Kind::Uint64Value(f16_bits_as_u64)),
684        };
685
686        // Step 1: Verify the normal U64 scalar.
687        let scalar_value = ScalarValue::from_proto(
688            &pb_scalar_value,
689            &DType::Primitive(PType::U64, Nullability::NonNullable),
690            &session(),
691        )
692        .unwrap();
693        assert_eq!(
694            scalar_value.as_ref().map(|v| v.as_primitive()),
695            Some(&PValue::U64(14008u64)),
696        );
697
698        // Step 2: Verify that when we use F16 dtype, the Uint64Value is correctly interpreted.
699        let scalar_value_f16 = ScalarValue::from_proto(
700            &pb_scalar_value,
701            &DType::Primitive(PType::F16, Nullability::Nullable),
702            &session(),
703        )
704        .unwrap();
705
706        let scalar = Scalar::new(
707            DType::Primitive(PType::F16, Nullability::Nullable),
708            scalar_value_f16,
709        );
710
711        assert_eq!(
712            scalar.as_primitive().pvalue().unwrap(),
713            PValue::F16(f16::from_f32(0.42)),
714            "Uint64Value should be correctly interpreted as f16 when dtype is F16"
715        );
716    }
717
718    #[test]
719    fn test_scalar_value_direct_roundtrip_f16() {
720        // Test that ScalarValue with f16 roundtrips correctly without going through Scalar.
721        let f16_values = vec![
722            f16::from_f32(0.0),
723            f16::from_f32(1.0),
724            f16::from_f32(-1.0),
725            f16::from_f32(0.42),
726            f16::from_f32(5.722046e-6),
727            f16::from_f32(f32::consts::PI),
728            f16::INFINITY,
729            f16::NEG_INFINITY,
730            f16::NAN,
731        ];
732
733        for f16_val in f16_values {
734            let scalar_value = ScalarValue::Primitive(PValue::F16(f16_val));
735            let pb_value = ScalarValue::to_proto(Some(&scalar_value));
736            let read_back = ScalarValue::from_proto(
737                &pb_value,
738                &DType::Primitive(PType::F16, Nullability::NonNullable),
739                &session(),
740            )
741            .unwrap();
742
743            match (&scalar_value, read_back.as_ref()) {
744                (
745                    ScalarValue::Primitive(PValue::F16(original)),
746                    Some(ScalarValue::Primitive(PValue::F16(roundtripped))),
747                ) => {
748                    if original.is_nan() && roundtripped.is_nan() {
749                        // NaN values are equal for our purposes.
750                        continue;
751                    }
752                    assert_eq!(
753                        original, roundtripped,
754                        "F16 value {original:?} did not roundtrip correctly"
755                    );
756                }
757                _ => {
758                    vortex_panic!(
759                        "Expected f16 primitive values, got {scalar_value:?} and {read_back:?}"
760                    )
761                }
762            }
763        }
764    }
765
766    #[test]
767    fn test_scalar_value_direct_roundtrip_preserves_values() {
768        // Test that ScalarValue roundtripping preserves values (but not necessarily exact types).
769        // Note: Proto encoding consolidates integer types (u8/u16/u32 → u64, i8/i16/i32 → i64).
770
771        // Test cases that should roundtrip exactly.
772        let exact_roundtrip_cases: Vec<(&str, Option<ScalarValue>, DType)> = vec![
773            ("null", None, DType::Null),
774            (
775                "bool_true",
776                Some(ScalarValue::Bool(true)),
777                DType::Bool(Nullability::Nullable),
778            ),
779            (
780                "bool_false",
781                Some(ScalarValue::Bool(false)),
782                DType::Bool(Nullability::Nullable),
783            ),
784            (
785                "u64",
786                Some(ScalarValue::Primitive(PValue::U64(18446744073709551615))),
787                DType::Primitive(PType::U64, Nullability::Nullable),
788            ),
789            (
790                "i64",
791                Some(ScalarValue::Primitive(PValue::I64(-9223372036854775808))),
792                DType::Primitive(PType::I64, Nullability::Nullable),
793            ),
794            (
795                "f32",
796                Some(ScalarValue::Primitive(PValue::F32(f32::consts::E))),
797                DType::Primitive(PType::F32, Nullability::Nullable),
798            ),
799            (
800                "f64",
801                Some(ScalarValue::Primitive(PValue::F64(f64::consts::PI))),
802                DType::Primitive(PType::F64, Nullability::Nullable),
803            ),
804            (
805                "string",
806                Some(ScalarValue::Utf8(BufferString::from("test"))),
807                DType::Utf8(Nullability::Nullable),
808            ),
809            (
810                "bytes",
811                Some(ScalarValue::Binary(vec![1, 2, 3, 4, 5].into())),
812                DType::Binary(Nullability::Nullable),
813            ),
814        ];
815
816        for (name, value, dtype) in exact_roundtrip_cases {
817            let pb_value = ScalarValue::to_proto(value.as_ref());
818            let read_back = ScalarValue::from_proto(&pb_value, &dtype, &session()).unwrap();
819
820            let original_debug = format!("{value:?}");
821            let roundtrip_debug = format!("{read_back:?}");
822            assert_eq!(
823                original_debug, roundtrip_debug,
824                "ScalarValue {name} did not roundtrip exactly"
825            );
826        }
827
828        // Test cases where type changes but value is preserved.
829        // Unsigned integers consolidate to U64.
830        let unsigned_cases = vec![
831            (
832                "u8",
833                ScalarValue::Primitive(PValue::U8(255)),
834                DType::Primitive(PType::U8, Nullability::Nullable),
835                255u64,
836            ),
837            (
838                "u16",
839                ScalarValue::Primitive(PValue::U16(65535)),
840                DType::Primitive(PType::U16, Nullability::Nullable),
841                65535u64,
842            ),
843            (
844                "u32",
845                ScalarValue::Primitive(PValue::U32(4294967295)),
846                DType::Primitive(PType::U32, Nullability::Nullable),
847                4294967295u64,
848            ),
849        ];
850
851        for (name, value, dtype, expected) in unsigned_cases {
852            let pb_value = ScalarValue::to_proto(Some(&value));
853            let read_back = ScalarValue::from_proto(&pb_value, &dtype, &session()).unwrap();
854
855            match read_back.as_ref() {
856                Some(ScalarValue::Primitive(pv)) => {
857                    let v = match pv {
858                        PValue::U8(v) => *v as u64,
859                        PValue::U16(v) => *v as u64,
860                        PValue::U32(v) => *v as u64,
861                        PValue::U64(v) => *v,
862                        _ => vortex_panic!("Unexpected primitive type for {name}: {pv:?}"),
863                    };
864                    assert_eq!(
865                        v, expected,
866                        "ScalarValue {name} value not preserved: expected {expected}, got {v}"
867                    );
868                }
869                _ => vortex_panic!("Unexpected type after roundtrip for {name}: {read_back:?}"),
870            }
871        }
872
873        // Signed integers consolidate to I64.
874        let signed_cases = vec![
875            (
876                "i8",
877                ScalarValue::Primitive(PValue::I8(-128)),
878                DType::Primitive(PType::I8, Nullability::Nullable),
879                -128i64,
880            ),
881            (
882                "i16",
883                ScalarValue::Primitive(PValue::I16(-32768)),
884                DType::Primitive(PType::I16, Nullability::Nullable),
885                -32768i64,
886            ),
887            (
888                "i32",
889                ScalarValue::Primitive(PValue::I32(-2147483648)),
890                DType::Primitive(PType::I32, Nullability::Nullable),
891                -2147483648i64,
892            ),
893        ];
894
895        for (name, value, dtype, expected) in signed_cases {
896            let pb_value = ScalarValue::to_proto(Some(&value));
897            let read_back = ScalarValue::from_proto(&pb_value, &dtype, &session()).unwrap();
898
899            match read_back.as_ref() {
900                Some(ScalarValue::Primitive(pv)) => {
901                    let v = match pv {
902                        PValue::I8(v) => *v as i64,
903                        PValue::I16(v) => *v as i64,
904                        PValue::I32(v) => *v as i64,
905                        PValue::I64(v) => *v,
906                        _ => vortex_panic!("Unexpected primitive type for {name}: {pv:?}"),
907                    };
908                    assert_eq!(
909                        v, expected,
910                        "ScalarValue {name} value not preserved: expected {expected}, got {v}"
911                    );
912                }
913                _ => vortex_panic!("Unexpected type after roundtrip for {name}: {read_back:?}"),
914            }
915        }
916    }
917
918    // Backwards compatibility: signed integer stats could previously be serialized as unsigned.
919    // Therefore, we allow casting between signed and unsigned integers of the same bit width.
920    #[test]
921    fn test_backcompat_signed_integer_deserialized_as_unsigned() {
922        let v = ScalarValue::Primitive(PValue::I64(0));
923        assert_eq!(
924            Scalar::from_proto_value(
925                &pb::ScalarValue::from(&v),
926                &DType::Primitive(PType::U64, Nullability::Nullable),
927                &session()
928            )
929            .unwrap(),
930            Scalar::primitive(0u64, Nullability::Nullable)
931        );
932    }
933
934    // Backwards compatibility: unsigned integer stats could previously be serialized as signed.
935    // Therefore, we allow casting between signed and unsigned integers of the same bit width.
936    #[test]
937    fn test_backcompat_unsigned_integer_deserialized_as_signed() {
938        let v = ScalarValue::Primitive(PValue::U64(0));
939        assert_eq!(
940            Scalar::from_proto_value(
941                &pb::ScalarValue::from(&v),
942                &DType::Primitive(PType::I64, Nullability::Nullable),
943                &session()
944            )
945            .unwrap(),
946            Scalar::primitive(0i64, Nullability::Nullable)
947        );
948    }
949}