Skip to main content

vortex_array/scalar/
proto.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4//! Protobuf serialization and deserialization for scalars.
5
6use num_traits::ToBytes;
7use num_traits::ToPrimitive;
8use prost::Message;
9use vortex_buffer::BufferString;
10use vortex_buffer::ByteBuffer;
11use vortex_error::VortexExpect;
12use vortex_error::VortexResult;
13use vortex_error::vortex_bail;
14use vortex_error::vortex_ensure;
15use vortex_error::vortex_err;
16use vortex_proto::scalar as pb;
17use vortex_proto::scalar::ListValue;
18use vortex_proto::scalar::scalar_value::Kind;
19use vortex_session::VortexSession;
20
21use crate::dtype::DType;
22use crate::dtype::PType;
23use crate::dtype::half::f16;
24use crate::dtype::i256;
25use crate::scalar::DecimalValue;
26use crate::scalar::PValue;
27use crate::scalar::Scalar;
28use crate::scalar::ScalarValue;
29
30////////////////////////////////////////////////////////////////////////////////////////////////////
31// Serialize INTO proto.
32////////////////////////////////////////////////////////////////////////////////////////////////////
33
34impl From<&Scalar> for pb::Scalar {
35    fn from(value: &Scalar) -> Self {
36        pb::Scalar {
37            dtype: Some(
38                (value.dtype())
39                    .try_into()
40                    .vortex_expect("Failed to convert DType to proto"),
41            ),
42            value: Some(Box::new(ScalarValue::to_proto(value.value()))),
43        }
44    }
45}
46
47impl ScalarValue {
48    /// Ideally, we would not have this function and instead implement this `From` implementation:
49    ///
50    /// ```ignore
51    /// impl From<Option<&ScalarValue>> for pb::ScalarValue { ... }
52    /// ```
53    ///
54    /// However, we are not allowed to do this because of the Orphan rule (`Option` and
55    /// `pb::ScalarValue` are not types defined in this crate). So we must make this a method on
56    /// `vortex_array::scalar::ScalarValue` directly.
57    pub fn to_proto(this: Option<&Self>) -> pb::ScalarValue {
58        match this {
59            None => pb::ScalarValue {
60                kind: Some(Kind::NullValue(0)),
61            },
62            Some(this) => pb::ScalarValue::from(this),
63        }
64    }
65
66    /// Serialize an optional [`ScalarValue`] to protobuf bytes (handles null values).
67    pub fn to_proto_bytes<B: Default + bytes::BufMut>(value: Option<&ScalarValue>) -> B {
68        let proto = Self::to_proto(value);
69        let mut buf = B::default();
70        proto
71            .encode(&mut buf)
72            .vortex_expect("Failed to encode scalar value");
73        buf
74    }
75}
76
77impl From<&ScalarValue> for pb::ScalarValue {
78    fn from(value: &ScalarValue) -> Self {
79        match value {
80            ScalarValue::Bool(v) => pb::ScalarValue {
81                kind: Some(Kind::BoolValue(*v)),
82            },
83            ScalarValue::Primitive(v) => pb::ScalarValue::from(v),
84            ScalarValue::Decimal(v) => {
85                let inner_value = match v {
86                    DecimalValue::I8(v) => v.to_le_bytes().to_vec(),
87                    DecimalValue::I16(v) => v.to_le_bytes().to_vec(),
88                    DecimalValue::I32(v) => v.to_le_bytes().to_vec(),
89                    DecimalValue::I64(v) => v.to_le_bytes().to_vec(),
90                    DecimalValue::I128(v128) => v128.to_le_bytes().to_vec(),
91                    DecimalValue::I256(v256) => v256.to_le_bytes().to_vec(),
92                };
93
94                pb::ScalarValue {
95                    kind: Some(Kind::BytesValue(inner_value)),
96                }
97            }
98            ScalarValue::Utf8(v) => pb::ScalarValue {
99                kind: Some(Kind::StringValue(v.to_string())),
100            },
101            ScalarValue::Binary(v) => pb::ScalarValue {
102                kind: Some(Kind::BytesValue(v.to_vec())),
103            },
104            ScalarValue::List(v) => {
105                let mut values = Vec::with_capacity(v.len());
106                for elem in v.iter() {
107                    values.push(ScalarValue::to_proto(elem.as_ref()));
108                }
109                pb::ScalarValue {
110                    kind: Some(Kind::ListValue(ListValue { values })),
111                }
112            }
113            ScalarValue::Variant(v) => pb::ScalarValue {
114                kind: Some(Kind::VariantValue(Box::new(pb::Scalar::from(v.as_ref())))),
115            },
116        }
117    }
118}
119
120impl From<&PValue> for pb::ScalarValue {
121    fn from(value: &PValue) -> Self {
122        match value {
123            PValue::I8(v) => pb::ScalarValue {
124                kind: Some(Kind::Int64Value(*v as i64)),
125            },
126            PValue::I16(v) => pb::ScalarValue {
127                kind: Some(Kind::Int64Value(*v as i64)),
128            },
129            PValue::I32(v) => pb::ScalarValue {
130                kind: Some(Kind::Int64Value(*v as i64)),
131            },
132            PValue::I64(v) => pb::ScalarValue {
133                kind: Some(Kind::Int64Value(*v)),
134            },
135            PValue::U8(v) => pb::ScalarValue {
136                kind: Some(Kind::Uint64Value(*v as u64)),
137            },
138            PValue::U16(v) => pb::ScalarValue {
139                kind: Some(Kind::Uint64Value(*v as u64)),
140            },
141            PValue::U32(v) => pb::ScalarValue {
142                kind: Some(Kind::Uint64Value(*v as u64)),
143            },
144            PValue::U64(v) => pb::ScalarValue {
145                kind: Some(Kind::Uint64Value(*v)),
146            },
147            PValue::F16(v) => pb::ScalarValue {
148                kind: Some(Kind::F16Value(v.to_bits() as u64)),
149            },
150            PValue::F32(v) => pb::ScalarValue {
151                kind: Some(Kind::F32Value(*v)),
152            },
153            PValue::F64(v) => pb::ScalarValue {
154                kind: Some(Kind::F64Value(*v)),
155            },
156        }
157    }
158}
159
160////////////////////////////////////////////////////////////////////////////////////////////////////
161// Serialize FROM proto.
162////////////////////////////////////////////////////////////////////////////////////////////////////
163
164impl Scalar {
165    /// Creates a [`Scalar`] from a [protobuf `ScalarValue`](pb::ScalarValue) representation.
166    ///
167    /// Note that we need to provide a [`DType`] since protobuf serialization only supports 64-bit
168    /// integers, and serializing _into_ protobuf loses that type information.
169    ///
170    /// # Errors
171    ///
172    /// Returns an error if type validation fails.
173    pub fn from_proto_value(
174        value: &pb::ScalarValue,
175        dtype: &DType,
176        session: &VortexSession,
177    ) -> VortexResult<Self> {
178        let scalar_value = ScalarValue::from_proto(value, dtype, session)?;
179
180        Scalar::try_new(dtype.clone(), scalar_value)
181    }
182
183    /// Creates a [`Scalar`] from its [protobuf](pb::Scalar) representation.
184    ///
185    /// # Errors
186    ///
187    /// Returns an error if the protobuf is missing required fields or if type validation fails.
188    pub fn from_proto(value: &pb::Scalar, session: &VortexSession) -> VortexResult<Self> {
189        let dtype = DType::from_proto(
190            value
191                .dtype
192                .as_ref()
193                .ok_or_else(|| vortex_err!(Serde: "Scalar missing dtype"))?,
194            session,
195        )?;
196
197        let pb_scalar_value: &pb::ScalarValue = value
198            .value
199            .as_ref()
200            .ok_or_else(|| vortex_err!(Serde: "Scalar missing value"))?;
201
202        let value: Option<ScalarValue> = ScalarValue::from_proto(pb_scalar_value, &dtype, session)?;
203
204        Scalar::try_new(dtype, value)
205    }
206}
207
208impl ScalarValue {
209    /// Deserialize a [`ScalarValue`] from protobuf bytes.
210    ///
211    /// Note that we need to provide a [`DType`] since protobuf serialization only supports 64-bit
212    /// integers, and serializing _into_ protobuf loses that type information.
213    ///
214    /// # Errors
215    ///
216    /// Returns an error if decoding or type validation fails.
217    pub fn from_proto_bytes(
218        bytes: &[u8],
219        dtype: &DType,
220        session: &VortexSession,
221    ) -> VortexResult<Option<Self>> {
222        let proto = pb::ScalarValue::decode(bytes)?;
223        Self::from_proto(&proto, dtype, session)
224    }
225
226    /// Creates a [`ScalarValue`] from its [protobuf](pb::ScalarValue) representation.
227    ///
228    /// Note that we need to provide a [`DType`] since protobuf serialization only supports 64-bit
229    /// integers, and serializing _into_ protobuf loses that type information.
230    ///
231    /// # Errors
232    ///
233    /// Returns an error if the protobuf value cannot be converted to the given [`DType`].
234    pub fn from_proto(
235        value: &pb::ScalarValue,
236        dtype: &DType,
237        session: &VortexSession,
238    ) -> VortexResult<Option<Self>> {
239        let kind = value
240            .kind
241            .as_ref()
242            .ok_or_else(|| vortex_err!(Serde: "Scalar value missing kind"))?;
243
244        // `DType::Extension` store their serialized values using the storage `DType`.
245        let dtype = match dtype {
246            DType::Extension(ext) => ext.storage_dtype(),
247            _ => dtype,
248        };
249
250        Ok(match kind {
251            Kind::NullValue(_) => None,
252            Kind::BoolValue(v) => Some(bool_from_proto(*v, dtype)?),
253            Kind::Int64Value(v) => Some(int64_from_proto(*v, dtype)?),
254            Kind::Uint64Value(v) => Some(uint64_from_proto(*v, dtype)?),
255            Kind::F16Value(v) => Some(f16_from_proto(*v, dtype)?),
256            Kind::F32Value(v) => Some(f32_from_proto(*v, dtype)?),
257            Kind::F64Value(v) => Some(f64_from_proto(*v, dtype)?),
258            Kind::StringValue(s) => Some(string_from_proto(s, dtype)?),
259            Kind::BytesValue(b) => Some(bytes_from_proto(b, dtype)?),
260            Kind::ListValue(v) => Some(list_from_proto(v, dtype, session)?),
261            Kind::VariantValue(v) => match dtype {
262                DType::Variant(_) => Some(ScalarValue::Variant(Box::new(Scalar::from_proto(
263                    v, session,
264                )?))),
265                _ => vortex_bail!(Serde: "expected non-Variant scalar proto for dtype {dtype}"),
266            },
267        })
268    }
269}
270
271/// Deserialize a [`ScalarValue::Bool`] from a protobuf `BoolValue`.
272fn bool_from_proto(v: bool, dtype: &DType) -> VortexResult<ScalarValue> {
273    vortex_ensure!(
274        dtype.is_boolean(),
275        Serde: "expected Bool dtype for BoolValue, got {dtype}"
276    );
277
278    Ok(ScalarValue::Bool(v))
279}
280
281/// Deserialize a [`ScalarValue::Primitive`] from a protobuf `Int64Value`.
282///
283/// Protobuf consolidates all signed integers into `i64`, so we narrow back to the original
284/// type using the provided [`DType`].
285fn int64_from_proto(v: i64, dtype: &DType) -> VortexResult<ScalarValue> {
286    vortex_ensure!(
287        dtype.is_primitive(),
288        Serde: "expected Primitive dtype for Int64Value, got {dtype}"
289    );
290
291    let pvalue = match dtype.as_ptype() {
292        PType::I8 => v.to_i8().map(PValue::I8),
293        PType::I16 => v.to_i16().map(PValue::I16),
294        PType::I32 => v.to_i32().map(PValue::I32),
295        PType::I64 => Some(PValue::I64(v)),
296        // It was previously possible for unsigned types to get their stats serialised as signed,
297        // so we allow casting back to unsigned for backwards compatibility.
298        PType::U8 => v.to_u8().map(PValue::U8),
299        PType::U16 => v.to_u16().map(PValue::U16),
300        PType::U32 => v.to_u32().map(PValue::U32),
301        PType::U64 => v.to_u64().map(PValue::U64),
302        ftype @ (PType::F16 | PType::F32 | PType::F64) => vortex_bail!(
303            Serde: "expected signed integer ptype for serialized Int64Value, got float {ftype}"
304        ),
305    }
306    .ok_or_else(|| vortex_err!(Serde: "Int64 value {v} out of range for dtype {dtype}"))?;
307
308    Ok(ScalarValue::Primitive(pvalue))
309}
310
311/// Deserialize a [`ScalarValue::Primitive`] from a protobuf `Uint64Value`.
312///
313/// Protobuf consolidates all unsigned integers into `u64`, so we narrow back to the original
314/// type using the provided [`DType`]. Also handles the backwards-compatible case where `f16`
315/// values were serialized as `u64` (via `f16::to_bits() as u64`).
316fn uint64_from_proto(v: u64, dtype: &DType) -> VortexResult<ScalarValue> {
317    vortex_ensure!(
318        dtype.is_primitive(),
319        Serde: "expected Primitive dtype for Uint64Value, got {dtype}"
320    );
321
322    let pvalue = match dtype.as_ptype() {
323        PType::U8 => v.to_u8().map(PValue::U8),
324        PType::U16 => v.to_u16().map(PValue::U16),
325        PType::U32 => v.to_u32().map(PValue::U32),
326        PType::U64 => Some(PValue::U64(v)),
327        // It was previously possible for signed types to get their stats serialised as unsigned,
328        // so we allow casting back to signed for backwards compatibility.
329        PType::I8 => v.to_i8().map(PValue::I8),
330        PType::I16 => v.to_i16().map(PValue::I16),
331        PType::I32 => v.to_i32().map(PValue::I32),
332        PType::I64 => v.to_i64().map(PValue::I64),
333        // f16 values used to be serialized as u64, so we need to be able to read an f16 from a u64.
334        PType::F16 => v.to_u16().map(f16::from_bits).map(PValue::F16),
335        ftype @ (PType::F32 | PType::F64) => vortex_bail!(
336            Serde: "expected unsigned integer ptype for serialized Uint64Value, got {ftype}"
337        ),
338    }
339    .ok_or_else(|| vortex_err!(Serde: "Uint64 value {v} out of range for dtype {dtype}"))?;
340
341    Ok(ScalarValue::Primitive(pvalue))
342}
343
344/// Deserialize a [`ScalarValue::Primitive`] from a protobuf `F16Value`.
345fn f16_from_proto(v: u64, dtype: &DType) -> VortexResult<ScalarValue> {
346    vortex_ensure!(
347        matches!(dtype, DType::Primitive(PType::F16, _)),
348        Serde: "expected F16 dtype for F16Value, got {dtype}"
349    );
350
351    let bits = u16::try_from(v)
352        .map_err(|_| vortex_err!(Serde: "f16 bitwise representation has more than 16 bits: {v}"))?;
353
354    Ok(ScalarValue::Primitive(PValue::F16(f16::from_bits(bits))))
355}
356
357/// Deserialize a [`ScalarValue::Primitive`] from a protobuf `F32Value`.
358fn f32_from_proto(v: f32, dtype: &DType) -> VortexResult<ScalarValue> {
359    vortex_ensure!(
360        matches!(dtype, DType::Primitive(PType::F32, _)),
361        Serde: "expected F32 dtype for F32Value, got {dtype}"
362    );
363
364    Ok(ScalarValue::Primitive(PValue::F32(v)))
365}
366
367/// Deserialize a [`ScalarValue::Primitive`] from a protobuf `F64Value`.
368fn f64_from_proto(v: f64, dtype: &DType) -> VortexResult<ScalarValue> {
369    vortex_ensure!(
370        matches!(dtype, DType::Primitive(PType::F64, _)),
371        Serde: "expected F64 dtype for F64Value, got {dtype}"
372    );
373
374    Ok(ScalarValue::Primitive(PValue::F64(v)))
375}
376
377/// Deserialize a [`ScalarValue::Utf8`] or [`ScalarValue::Binary`] from a protobuf
378/// `StringValue`.
379fn string_from_proto(s: &str, dtype: &DType) -> VortexResult<ScalarValue> {
380    match dtype {
381        DType::Utf8(_) => Ok(ScalarValue::Utf8(BufferString::from(s))),
382        DType::Binary(_) => Ok(ScalarValue::Binary(ByteBuffer::copy_from(s.as_bytes()))),
383        _ => vortex_bail!(
384            Serde: "expected Utf8 or Binary dtype for StringValue, got {dtype}"
385        ),
386    }
387}
388
389/// Deserialize a [`ScalarValue`] from a protobuf bytes and a `DType`.
390///
391/// Handles [`Utf8`](ScalarValue::Utf8), [`Binary`](ScalarValue::Binary), and
392/// [`Decimal`](ScalarValue::Decimal) dtypes.
393fn bytes_from_proto(bytes: &[u8], dtype: &DType) -> VortexResult<ScalarValue> {
394    match dtype {
395        DType::Utf8(_) => Ok(ScalarValue::Utf8(BufferString::try_from(bytes)?)),
396        DType::Binary(_) => Ok(ScalarValue::Binary(ByteBuffer::copy_from(bytes))),
397        // TODO(connor): This is incorrect, we need to verify this matches the inner decimal_dtype.
398        DType::Decimal(..) => Ok(ScalarValue::Decimal(match bytes.len() {
399            1 => DecimalValue::I8(bytes[0] as i8),
400            2 => DecimalValue::I16(i16::from_le_bytes(
401                bytes
402                    .try_into()
403                    .ok()
404                    .vortex_expect("Buffer has invalid number of bytes"),
405            )),
406            4 => DecimalValue::I32(i32::from_le_bytes(
407                bytes
408                    .try_into()
409                    .ok()
410                    .vortex_expect("Buffer has invalid number of bytes"),
411            )),
412            8 => DecimalValue::I64(i64::from_le_bytes(
413                bytes
414                    .try_into()
415                    .ok()
416                    .vortex_expect("Buffer has invalid number of bytes"),
417            )),
418            16 => DecimalValue::I128(i128::from_le_bytes(
419                bytes
420                    .try_into()
421                    .ok()
422                    .vortex_expect("Buffer has invalid number of bytes"),
423            )),
424            32 => DecimalValue::I256(i256::from_le_bytes(
425                bytes
426                    .try_into()
427                    .ok()
428                    .vortex_expect("Buffer has invalid number of bytes"),
429            )),
430            l => vortex_bail!(Serde: "invalid decimal byte length: {l}"),
431        })),
432        _ => vortex_bail!(
433            Serde: "expected Utf8, Binary, or Decimal dtype for BytesValue, got {dtype}"
434        ),
435    }
436}
437
438/// Deserialize a [`ScalarValue::List`] from a protobuf `ListValue`.
439fn list_from_proto(
440    v: &ListValue,
441    dtype: &DType,
442    session: &VortexSession,
443) -> VortexResult<ScalarValue> {
444    let element_dtype = dtype
445        .as_list_element_opt()
446        .ok_or_else(|| vortex_err!(Serde: "expected List dtype for ListValue, got {dtype}"))?;
447
448    let mut values = Vec::with_capacity(v.values.len());
449    for elem in v.values.iter() {
450        values.push(ScalarValue::from_proto(
451            elem,
452            element_dtype.as_ref(),
453            session,
454        )?);
455    }
456
457    Ok(ScalarValue::List(values))
458}
459
460#[cfg(test)]
461mod tests {
462    use std::sync::Arc;
463
464    use vortex_buffer::BufferString;
465    use vortex_error::vortex_panic;
466    use vortex_proto::scalar as pb;
467    use vortex_session::VortexSession;
468
469    use super::*;
470    use crate::dtype::DType;
471    use crate::dtype::DecimalDType;
472    use crate::dtype::Nullability;
473    use crate::dtype::PType;
474    use crate::dtype::half::f16;
475    use crate::scalar::DecimalValue;
476    use crate::scalar::Scalar;
477    use crate::scalar::ScalarValue;
478
479    fn session() -> VortexSession {
480        VortexSession::empty()
481    }
482
483    fn round_trip(scalar: Scalar) {
484        assert_eq!(
485            scalar,
486            Scalar::from_proto(&pb::Scalar::from(&scalar), &session()).unwrap(),
487        );
488    }
489
490    #[test]
491    fn test_null() {
492        round_trip(Scalar::null(DType::Null));
493    }
494
495    #[test]
496    fn test_bool() {
497        round_trip(Scalar::new(
498            DType::Bool(Nullability::Nullable),
499            Some(ScalarValue::Bool(true)),
500        ));
501    }
502
503    #[test]
504    fn test_primitive() {
505        round_trip(Scalar::new(
506            DType::Primitive(PType::I32, Nullability::Nullable),
507            Some(ScalarValue::Primitive(42i32.into())),
508        ));
509    }
510
511    #[test]
512    fn test_buffer() {
513        round_trip(Scalar::new(
514            DType::Binary(Nullability::Nullable),
515            Some(ScalarValue::Binary(vec![1, 2, 3].into())),
516        ));
517    }
518
519    #[test]
520    fn test_buffer_string() {
521        round_trip(Scalar::new(
522            DType::Utf8(Nullability::Nullable),
523            Some(ScalarValue::Utf8(BufferString::from("hello".to_string()))),
524        ));
525    }
526
527    #[test]
528    fn test_list() {
529        round_trip(Scalar::new(
530            DType::List(
531                Arc::new(DType::Primitive(PType::I32, Nullability::Nullable)),
532                Nullability::Nullable,
533            ),
534            Some(ScalarValue::List(vec![
535                Some(ScalarValue::Primitive(42i32.into())),
536                Some(ScalarValue::Primitive(43i32.into())),
537            ])),
538        ));
539    }
540
541    #[test]
542    fn test_f16() {
543        round_trip(Scalar::primitive(
544            f16::from_f32(0.42),
545            Nullability::Nullable,
546        ));
547    }
548
549    #[test]
550    fn test_i8() {
551        round_trip(Scalar::new(
552            DType::Primitive(PType::I8, Nullability::Nullable),
553            Some(ScalarValue::Primitive(i8::MIN.into())),
554        ));
555
556        round_trip(Scalar::new(
557            DType::Primitive(PType::I8, Nullability::Nullable),
558            Some(ScalarValue::Primitive(0i8.into())),
559        ));
560
561        round_trip(Scalar::new(
562            DType::Primitive(PType::I8, Nullability::Nullable),
563            Some(ScalarValue::Primitive(i8::MAX.into())),
564        ));
565    }
566
567    #[test]
568    fn test_decimal_i32_roundtrip() {
569        // A typical decimal with moderate precision and scale.
570        round_trip(Scalar::decimal(
571            DecimalValue::I32(123_456),
572            DecimalDType::new(10, 2),
573            Nullability::NonNullable,
574        ));
575    }
576
577    #[test]
578    fn test_decimal_i128_roundtrip() {
579        // A large decimal value that requires i128 storage.
580        round_trip(Scalar::decimal(
581            DecimalValue::I128(99_999_999_999_999_999_999),
582            DecimalDType::new(38, 6),
583            Nullability::Nullable,
584        ));
585    }
586
587    #[test]
588    fn test_decimal_null_roundtrip() {
589        round_trip(Scalar::null(DType::Decimal(
590            DecimalDType::new(10, 2),
591            Nullability::Nullable,
592        )));
593    }
594
595    #[test]
596    fn test_scalar_value_serde_roundtrip_binary() {
597        round_trip(Scalar::binary(
598            ByteBuffer::copy_from(b"hello"),
599            Nullability::NonNullable,
600        ));
601    }
602
603    #[test]
604    fn test_scalar_value_serde_roundtrip_utf8() {
605        round_trip(Scalar::utf8("hello", Nullability::NonNullable));
606    }
607
608    #[test]
609    fn test_variant_scalar_roundtrip() {
610        let nums = Scalar::list(
611            Arc::new(DType::Variant(Nullability::NonNullable)),
612            vec![
613                Scalar::variant(Scalar::primitive(-7_i16, Nullability::NonNullable)),
614                Scalar::variant(Scalar::primitive(42_u32, Nullability::NonNullable)),
615                Scalar::variant(Scalar::decimal(
616                    DecimalValue::I128(123_456_789),
617                    DecimalDType::new(18, 0),
618                    Nullability::NonNullable,
619                )),
620            ],
621            Nullability::NonNullable,
622        );
623
624        let nested = Scalar::list(
625            Arc::new(DType::Variant(Nullability::NonNullable)),
626            vec![
627                Scalar::variant(Scalar::from(true)),
628                Scalar::variant(nums),
629                Scalar::variant(Scalar::binary(
630                    ByteBuffer::copy_from(b"abc"),
631                    Nullability::NonNullable,
632                )),
633                Scalar::variant(Scalar::null(DType::Null)),
634            ],
635            Nullability::NonNullable,
636        );
637
638        round_trip(Scalar::variant(nested));
639    }
640
641    #[test]
642    fn test_variant_scalar_proto_preserves_scalar_null_vs_variant_null() {
643        let scalar_null = Scalar::null(DType::Variant(Nullability::Nullable));
644        let variant_null = Scalar::variant(Scalar::null(DType::Null));
645
646        let scalar_null_pb = pb::Scalar::from(&scalar_null);
647        let variant_null_pb = pb::Scalar::from(&variant_null);
648
649        assert_ne!(scalar_null_pb, variant_null_pb);
650        assert_eq!(
651            Scalar::from_proto(&scalar_null_pb, &session()).unwrap(),
652            scalar_null,
653        );
654        assert_eq!(
655            Scalar::from_proto(&variant_null_pb, &session()).unwrap(),
656            variant_null,
657        );
658    }
659
660    #[test]
661    fn test_backcompat_f16_serialized_as_u64() {
662        // Backwards compatibility test for the legacy f16 serialization format.
663        //
664        // Previously, f16 ScalarValues were serialized as `Uint64Value(v.to_bits() as u64)` because
665        // the proto schema only had 64-bit integer types, and f16's underlying representation is
666        // u16 which got widened to u64.
667        //
668        // The current implementation uses a dedicated `F16Value` proto field, but we must still be
669        // able to deserialize the old format. This test verifies that:
670        //
671        // 1. A `Uint64Value` containing f16 bits can be read as a U64 primitive (the raw bits).
672        // 2. When wrapped in a Scalar with F16 dtype, the value is correctly interpreted as f16.
673        //
674        // This ensures data written with the old serialization format remains readable.
675
676        // Simulate the old serialization: f16(0.42) stored as Uint64Value with its bit pattern.
677        let f16_value = f16::from_f32(0.42);
678        let f16_bits_as_u64 = f16_value.to_bits() as u64; // 14008
679
680        let pb_scalar_value = pb::ScalarValue {
681            kind: Some(Kind::Uint64Value(f16_bits_as_u64)),
682        };
683
684        // Step 1: Verify the normal U64 scalar.
685        let scalar_value = ScalarValue::from_proto(
686            &pb_scalar_value,
687            &DType::Primitive(PType::U64, Nullability::NonNullable),
688            &session(),
689        )
690        .unwrap();
691        assert_eq!(
692            scalar_value.as_ref().map(|v| v.as_primitive()),
693            Some(&PValue::U64(14008u64)),
694        );
695
696        // Step 2: Verify that when we use F16 dtype, the Uint64Value is correctly interpreted.
697        let scalar_value_f16 = ScalarValue::from_proto(
698            &pb_scalar_value,
699            &DType::Primitive(PType::F16, Nullability::Nullable),
700            &session(),
701        )
702        .unwrap();
703
704        let scalar = Scalar::new(
705            DType::Primitive(PType::F16, Nullability::Nullable),
706            scalar_value_f16,
707        );
708
709        assert_eq!(
710            scalar.as_primitive().pvalue().unwrap(),
711            PValue::F16(f16::from_f32(0.42)),
712            "Uint64Value should be correctly interpreted as f16 when dtype is F16"
713        );
714    }
715
716    #[test]
717    fn test_scalar_value_direct_roundtrip_f16() {
718        // Test that ScalarValue with f16 roundtrips correctly without going through Scalar.
719        let f16_values = vec![
720            f16::from_f32(0.0),
721            f16::from_f32(1.0),
722            f16::from_f32(-1.0),
723            f16::from_f32(0.42),
724            f16::from_f32(5.722046e-6),
725            f16::from_f32(std::f32::consts::PI),
726            f16::INFINITY,
727            f16::NEG_INFINITY,
728            f16::NAN,
729        ];
730
731        for f16_val in f16_values {
732            let scalar_value = ScalarValue::Primitive(PValue::F16(f16_val));
733            let pb_value = ScalarValue::to_proto(Some(&scalar_value));
734            let read_back = ScalarValue::from_proto(
735                &pb_value,
736                &DType::Primitive(PType::F16, Nullability::NonNullable),
737                &session(),
738            )
739            .unwrap();
740
741            match (&scalar_value, read_back.as_ref()) {
742                (
743                    ScalarValue::Primitive(PValue::F16(original)),
744                    Some(ScalarValue::Primitive(PValue::F16(roundtripped))),
745                ) => {
746                    if original.is_nan() && roundtripped.is_nan() {
747                        // NaN values are equal for our purposes.
748                        continue;
749                    }
750                    assert_eq!(
751                        original, roundtripped,
752                        "F16 value {original:?} did not roundtrip correctly"
753                    );
754                }
755                _ => {
756                    vortex_panic!(
757                        "Expected f16 primitive values, got {scalar_value:?} and {read_back:?}"
758                    )
759                }
760            }
761        }
762    }
763
764    #[test]
765    fn test_scalar_value_direct_roundtrip_preserves_values() {
766        // Test that ScalarValue roundtripping preserves values (but not necessarily exact types).
767        // Note: Proto encoding consolidates integer types (u8/u16/u32 → u64, i8/i16/i32 → i64).
768
769        // Test cases that should roundtrip exactly.
770        let exact_roundtrip_cases: Vec<(&str, Option<ScalarValue>, DType)> = vec![
771            ("null", None, DType::Null),
772            (
773                "bool_true",
774                Some(ScalarValue::Bool(true)),
775                DType::Bool(Nullability::Nullable),
776            ),
777            (
778                "bool_false",
779                Some(ScalarValue::Bool(false)),
780                DType::Bool(Nullability::Nullable),
781            ),
782            (
783                "u64",
784                Some(ScalarValue::Primitive(PValue::U64(18446744073709551615))),
785                DType::Primitive(PType::U64, Nullability::Nullable),
786            ),
787            (
788                "i64",
789                Some(ScalarValue::Primitive(PValue::I64(-9223372036854775808))),
790                DType::Primitive(PType::I64, Nullability::Nullable),
791            ),
792            (
793                "f32",
794                Some(ScalarValue::Primitive(PValue::F32(std::f32::consts::E))),
795                DType::Primitive(PType::F32, Nullability::Nullable),
796            ),
797            (
798                "f64",
799                Some(ScalarValue::Primitive(PValue::F64(std::f64::consts::PI))),
800                DType::Primitive(PType::F64, Nullability::Nullable),
801            ),
802            (
803                "string",
804                Some(ScalarValue::Utf8(BufferString::from("test"))),
805                DType::Utf8(Nullability::Nullable),
806            ),
807            (
808                "bytes",
809                Some(ScalarValue::Binary(vec![1, 2, 3, 4, 5].into())),
810                DType::Binary(Nullability::Nullable),
811            ),
812        ];
813
814        for (name, value, dtype) in exact_roundtrip_cases {
815            let pb_value = ScalarValue::to_proto(value.as_ref());
816            let read_back = ScalarValue::from_proto(&pb_value, &dtype, &session()).unwrap();
817
818            let original_debug = format!("{value:?}");
819            let roundtrip_debug = format!("{read_back:?}");
820            assert_eq!(
821                original_debug, roundtrip_debug,
822                "ScalarValue {name} did not roundtrip exactly"
823            );
824        }
825
826        // Test cases where type changes but value is preserved.
827        // Unsigned integers consolidate to U64.
828        let unsigned_cases = vec![
829            (
830                "u8",
831                ScalarValue::Primitive(PValue::U8(255)),
832                DType::Primitive(PType::U8, Nullability::Nullable),
833                255u64,
834            ),
835            (
836                "u16",
837                ScalarValue::Primitive(PValue::U16(65535)),
838                DType::Primitive(PType::U16, Nullability::Nullable),
839                65535u64,
840            ),
841            (
842                "u32",
843                ScalarValue::Primitive(PValue::U32(4294967295)),
844                DType::Primitive(PType::U32, Nullability::Nullable),
845                4294967295u64,
846            ),
847        ];
848
849        for (name, value, dtype, expected) in unsigned_cases {
850            let pb_value = ScalarValue::to_proto(Some(&value));
851            let read_back = ScalarValue::from_proto(&pb_value, &dtype, &session()).unwrap();
852
853            match read_back.as_ref() {
854                Some(ScalarValue::Primitive(pv)) => {
855                    let v = match pv {
856                        PValue::U8(v) => *v as u64,
857                        PValue::U16(v) => *v as u64,
858                        PValue::U32(v) => *v as u64,
859                        PValue::U64(v) => *v,
860                        _ => vortex_panic!("Unexpected primitive type for {name}: {pv:?}"),
861                    };
862                    assert_eq!(
863                        v, expected,
864                        "ScalarValue {name} value not preserved: expected {expected}, got {v}"
865                    );
866                }
867                _ => vortex_panic!("Unexpected type after roundtrip for {name}: {read_back:?}"),
868            }
869        }
870
871        // Signed integers consolidate to I64.
872        let signed_cases = vec![
873            (
874                "i8",
875                ScalarValue::Primitive(PValue::I8(-128)),
876                DType::Primitive(PType::I8, Nullability::Nullable),
877                -128i64,
878            ),
879            (
880                "i16",
881                ScalarValue::Primitive(PValue::I16(-32768)),
882                DType::Primitive(PType::I16, Nullability::Nullable),
883                -32768i64,
884            ),
885            (
886                "i32",
887                ScalarValue::Primitive(PValue::I32(-2147483648)),
888                DType::Primitive(PType::I32, Nullability::Nullable),
889                -2147483648i64,
890            ),
891        ];
892
893        for (name, value, dtype, expected) in signed_cases {
894            let pb_value = ScalarValue::to_proto(Some(&value));
895            let read_back = ScalarValue::from_proto(&pb_value, &dtype, &session()).unwrap();
896
897            match read_back.as_ref() {
898                Some(ScalarValue::Primitive(pv)) => {
899                    let v = match pv {
900                        PValue::I8(v) => *v as i64,
901                        PValue::I16(v) => *v as i64,
902                        PValue::I32(v) => *v as i64,
903                        PValue::I64(v) => *v,
904                        _ => vortex_panic!("Unexpected primitive type for {name}: {pv:?}"),
905                    };
906                    assert_eq!(
907                        v, expected,
908                        "ScalarValue {name} value not preserved: expected {expected}, got {v}"
909                    );
910                }
911                _ => vortex_panic!("Unexpected type after roundtrip for {name}: {read_back:?}"),
912            }
913        }
914    }
915
916    // Backwards compatibility: signed integer stats could previously be serialized as unsigned.
917    // Therefore, we allow casting between signed and unsigned integers of the same bit width.
918    #[test]
919    fn test_backcompat_signed_integer_deserialized_as_unsigned() {
920        let v = ScalarValue::Primitive(PValue::I64(0));
921        assert_eq!(
922            Scalar::from_proto_value(
923                &pb::ScalarValue::from(&v),
924                &DType::Primitive(PType::U64, Nullability::Nullable),
925                &session()
926            )
927            .unwrap(),
928            Scalar::primitive(0u64, Nullability::Nullable)
929        );
930    }
931
932    // Backwards compatibility: unsigned integer stats could previously be serialized as signed.
933    // Therefore, we allow casting between signed and unsigned integers of the same bit width.
934    #[test]
935    fn test_backcompat_unsigned_integer_deserialized_as_signed() {
936        let v = ScalarValue::Primitive(PValue::U64(0));
937        assert_eq!(
938            Scalar::from_proto_value(
939                &pb::ScalarValue::from(&v),
940                &DType::Primitive(PType::I64, Nullability::Nullable),
941                &session()
942            )
943            .unwrap(),
944            Scalar::primitive(0i64, Nullability::Nullable)
945        );
946    }
947}