vortex_scalar/
proto.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::sync::Arc;
5
6use num_traits::ToBytes;
7use vortex_buffer::{BufferString, ByteBuffer};
8use vortex_dtype::DType;
9use vortex_dtype::half::f16;
10use vortex_error::{VortexError, vortex_err};
11use vortex_proto::scalar as pb;
12use vortex_proto::scalar::ListValue;
13use vortex_proto::scalar::scalar_value::Kind;
14
15use crate::pvalue::PValue;
16use crate::{DecimalValue, InnerScalarValue, Scalar, ScalarValue};
17
18impl From<&Scalar> for pb::Scalar {
19    fn from(value: &Scalar) -> Self {
20        pb::Scalar {
21            dtype: Some((value.dtype()).into()),
22            value: Some((value.value()).into()),
23        }
24    }
25}
26
27impl From<&ScalarValue> for pb::ScalarValue {
28    fn from(value: &ScalarValue) -> Self {
29        match value {
30            ScalarValue(InnerScalarValue::Null) => pb::ScalarValue {
31                kind: Some(Kind::NullValue(0)),
32            },
33            ScalarValue(InnerScalarValue::Bool(v)) => pb::ScalarValue {
34                kind: Some(Kind::BoolValue(*v)),
35            },
36            ScalarValue(InnerScalarValue::Primitive(v)) => v.into(),
37            ScalarValue(InnerScalarValue::Decimal(v)) => {
38                let inner_value = match v {
39                    DecimalValue::I8(v) => v.to_le_bytes().to_vec(),
40                    DecimalValue::I16(v) => v.to_le_bytes().to_vec(),
41                    DecimalValue::I32(v) => v.to_le_bytes().to_vec(),
42                    DecimalValue::I64(v) => v.to_le_bytes().to_vec(),
43                    DecimalValue::I128(v128) => v128.to_le_bytes().to_vec(),
44                    DecimalValue::I256(v256) => v256.to_le_bytes().to_vec(),
45                };
46
47                pb::ScalarValue {
48                    kind: Some(Kind::BytesValue(inner_value)),
49                }
50            }
51            ScalarValue(InnerScalarValue::Buffer(v)) => pb::ScalarValue {
52                kind: Some(Kind::BytesValue(v.as_slice().to_vec())),
53            },
54            ScalarValue(InnerScalarValue::BufferString(v)) => pb::ScalarValue {
55                kind: Some(Kind::StringValue(v.as_str().to_string())),
56            },
57            ScalarValue(InnerScalarValue::List(v)) => {
58                let mut values = Vec::with_capacity(v.len());
59                for elem in v.iter() {
60                    values.push(pb::ScalarValue::from(elem));
61                }
62                pb::ScalarValue {
63                    kind: Some(Kind::ListValue(ListValue { values })),
64                }
65            }
66        }
67    }
68}
69
70impl From<&PValue> for pb::ScalarValue {
71    fn from(value: &PValue) -> Self {
72        match value {
73            PValue::I8(v) => pb::ScalarValue {
74                kind: Some(Kind::Int64Value(*v as i64)),
75            },
76            PValue::I16(v) => pb::ScalarValue {
77                kind: Some(Kind::Int64Value(*v as i64)),
78            },
79            PValue::I32(v) => pb::ScalarValue {
80                kind: Some(Kind::Int64Value(*v as i64)),
81            },
82            PValue::I64(v) => pb::ScalarValue {
83                kind: Some(Kind::Int64Value(*v)),
84            },
85            PValue::U8(v) => pb::ScalarValue {
86                kind: Some(Kind::Uint64Value(*v as u64)),
87            },
88            PValue::U16(v) => pb::ScalarValue {
89                kind: Some(Kind::Uint64Value(*v as u64)),
90            },
91            PValue::U32(v) => pb::ScalarValue {
92                kind: Some(Kind::Uint64Value(*v as u64)),
93            },
94            PValue::U64(v) => pb::ScalarValue {
95                kind: Some(Kind::Uint64Value(*v)),
96            },
97            PValue::F16(v) => pb::ScalarValue {
98                kind: Some(Kind::F16Value(v.to_bits() as u64)),
99            },
100            PValue::F32(v) => pb::ScalarValue {
101                kind: Some(Kind::F32Value(*v)),
102            },
103            PValue::F64(v) => pb::ScalarValue {
104                kind: Some(Kind::F64Value(*v)),
105            },
106        }
107    }
108}
109
110impl TryFrom<&pb::Scalar> for Scalar {
111    type Error = VortexError;
112
113    fn try_from(value: &pb::Scalar) -> Result<Self, Self::Error> {
114        let dtype = DType::try_from(
115            value
116                .dtype
117                .as_ref()
118                .ok_or_else(|| vortex_err!(InvalidSerde: "Scalar missing dtype"))?,
119        )?;
120
121        let value = ScalarValue::try_from(
122            value
123                .value
124                .as_ref()
125                .ok_or_else(|| vortex_err!(InvalidSerde: "Scalar missing value"))?,
126        )?;
127
128        Ok(Scalar::new(dtype, value))
129    }
130}
131
132impl TryFrom<&pb::ScalarValue> for ScalarValue {
133    type Error = VortexError;
134
135    fn try_from(value: &pb::ScalarValue) -> Result<Self, Self::Error> {
136        let kind = value
137            .kind
138            .as_ref()
139            .ok_or_else(|| vortex_err!(InvalidSerde: "ScalarValue missing kind"))?;
140
141        match kind {
142            Kind::NullValue(_) => Ok(ScalarValue(InnerScalarValue::Null)),
143            Kind::BoolValue(v) => Ok(ScalarValue(InnerScalarValue::Bool(*v))),
144            Kind::Int64Value(v) => Ok(ScalarValue(InnerScalarValue::Primitive(PValue::I64(*v)))),
145            Kind::Uint64Value(v) => Ok(ScalarValue(InnerScalarValue::Primitive(PValue::U64(*v)))),
146            Kind::F16Value(v) => Ok(ScalarValue(InnerScalarValue::Primitive(PValue::F16(
147                f16::from_bits(u16::try_from(*v).map_err(|_| {
148                    vortex_err!("f16 bitwise representation has more than 16 bits: {}", v)
149                })?),
150            )))),
151            Kind::F32Value(v) => Ok(ScalarValue(InnerScalarValue::Primitive(PValue::F32(*v)))),
152            Kind::F64Value(v) => Ok(ScalarValue(InnerScalarValue::Primitive(PValue::F64(*v)))),
153            Kind::StringValue(v) => Ok(ScalarValue(InnerScalarValue::BufferString(Arc::new(
154                BufferString::from(v.clone()),
155            )))),
156            Kind::BytesValue(v) => Ok(ScalarValue(InnerScalarValue::Buffer(Arc::new(
157                ByteBuffer::from(v.clone()),
158            )))),
159            Kind::ListValue(v) => {
160                let mut values = Vec::with_capacity(v.values.len());
161                for elem in v.values.iter() {
162                    values.push(elem.try_into()?);
163                }
164                Ok(ScalarValue(InnerScalarValue::List(values.into())))
165            }
166        }
167    }
168}
169
170#[cfg(test)]
171mod tests {
172    use std::sync::Arc;
173
174    use rstest::rstest;
175    use vortex_buffer::BufferString;
176    use vortex_dtype::half::f16;
177    use vortex_dtype::{DType, DecimalDType, FieldDType, Nullability, PType, StructFields, i256};
178    use vortex_error::vortex_panic;
179    use vortex_proto::scalar as pb;
180
181    use super::*;
182    use crate::{InnerScalarValue, Scalar, ScalarValue};
183
184    fn round_trip(scalar: Scalar) {
185        assert_eq!(
186            scalar,
187            Scalar::try_from(&pb::Scalar::from(&scalar)).unwrap(),
188        );
189    }
190
191    #[test]
192    fn test_null() {
193        round_trip(Scalar::null(DType::Null));
194    }
195
196    #[test]
197    fn test_bool() {
198        round_trip(Scalar::new(
199            DType::Bool(Nullability::Nullable),
200            ScalarValue(InnerScalarValue::Bool(true)),
201        ));
202    }
203
204    #[test]
205    fn test_primitive() {
206        round_trip(Scalar::new(
207            DType::Primitive(PType::I32, Nullability::Nullable),
208            ScalarValue(InnerScalarValue::Primitive(42i32.into())),
209        ));
210    }
211
212    #[test]
213    fn test_buffer() {
214        round_trip(Scalar::new(
215            DType::Binary(Nullability::Nullable),
216            ScalarValue(InnerScalarValue::Buffer(Arc::new(vec![1, 2, 3].into()))),
217        ));
218    }
219
220    #[test]
221    fn test_buffer_string() {
222        round_trip(Scalar::new(
223            DType::Utf8(Nullability::Nullable),
224            ScalarValue(InnerScalarValue::BufferString(Arc::new(
225                BufferString::from("hello".to_string()),
226            ))),
227        ));
228    }
229
230    #[test]
231    fn test_list() {
232        round_trip(Scalar::new(
233            DType::List(
234                Arc::new(DType::Primitive(PType::I32, Nullability::Nullable)),
235                Nullability::Nullable,
236            ),
237            ScalarValue(InnerScalarValue::List(
238                vec![
239                    ScalarValue(InnerScalarValue::Primitive(42i32.into())),
240                    ScalarValue(InnerScalarValue::Primitive(43i32.into())),
241                ]
242                .into(),
243            )),
244        ));
245    }
246
247    #[test]
248    fn test_f16() {
249        round_trip(Scalar::primitive(
250            f16::from_f32(0.42),
251            Nullability::Nullable,
252        ));
253    }
254
255    #[test]
256    fn test_i8() {
257        round_trip(Scalar::new(
258            DType::Primitive(PType::I8, Nullability::Nullable),
259            ScalarValue(InnerScalarValue::Primitive(i8::MIN.into())),
260        ));
261
262        round_trip(Scalar::new(
263            DType::Primitive(PType::I8, Nullability::Nullable),
264            ScalarValue(InnerScalarValue::Primitive(0i8.into())),
265        ));
266
267        round_trip(Scalar::new(
268            DType::Primitive(PType::I8, Nullability::Nullable),
269            ScalarValue(InnerScalarValue::Primitive(i8::MAX.into())),
270        ));
271    }
272
273    #[rstest]
274    #[case(Scalar::binary(ByteBuffer::copy_from(b"hello"), Nullability::NonNullable))]
275    #[case(Scalar::utf8("hello", Nullability::NonNullable))]
276    #[case(Scalar::primitive(1u8, Nullability::NonNullable))]
277    #[case(Scalar::primitive(
278        f32::from_bits(u32::from_le_bytes([0xFFu8, 0x8A, 0xF9, 0xFF])),
279        Nullability::NonNullable
280    ))]
281    #[case(Scalar::list(Arc::new(PType::U8.into()), vec![Scalar::primitive(1u8, Nullability::NonNullable)], Nullability::NonNullable
282    ))]
283    #[case(Scalar::struct_(DType::Struct(
284        StructFields::from_iter([
285            ("a", FieldDType::from(DType::Primitive(PType::U32, Nullability::NonNullable))),
286            ("b", FieldDType::from(DType::Primitive(PType::F16, Nullability::NonNullable))),
287        ]),
288        Nullability::NonNullable),
289        vec![
290            Scalar::primitive(23592960u32, Nullability::NonNullable),
291            Scalar::primitive(f16::from_f32(2.6584664e36f32), Nullability::NonNullable),
292        ],
293    ))]
294    #[case(Scalar::struct_(DType::Struct(
295        StructFields::from_iter([
296            ("a", FieldDType::from(DType::Primitive(PType::U64, Nullability::NonNullable))),
297            ("b", FieldDType::from(DType::Primitive(PType::F32, Nullability::NonNullable))),
298            ("c", FieldDType::from(DType::Primitive(PType::F16, Nullability::NonNullable))),
299        ]),
300        Nullability::NonNullable),
301        vec![
302            Scalar::primitive(415118687234u64, Nullability::NonNullable),
303            Scalar::primitive(2.6584664e36f32, Nullability::NonNullable),
304            Scalar::primitive(f16::from_f32(2.6584664e36f32), Nullability::NonNullable),
305        ],
306    ))]
307    #[case(Scalar::decimal(
308        DecimalValue::I256(i256::from_i128(12345643673471)),
309        DecimalDType::new(10, 2),
310        Nullability::NonNullable
311    ))]
312    #[case(Scalar::decimal(
313        DecimalValue::I16(23412),
314        DecimalDType::new(3, 2),
315        Nullability::NonNullable
316    ))]
317    fn test_scalar_value_serde_roundtrip(#[case] scalar: Scalar) {
318        let written = scalar.value().to_protobytes::<Vec<u8>>();
319        let scalar_read_back = ScalarValue::from_protobytes(&written).unwrap();
320        assert_eq!(
321            Scalar::new(scalar.dtype().clone(), scalar_read_back),
322            scalar
323        );
324    }
325
326    #[test]
327    fn test_backcompat_f16_serialized_as_u64() {
328        // Note that this is a backwards compatibility test for poor design in the previous implementation.
329        // Previously, f16 ScalarValues were serialized as `pb::ScalarValue::Uint64Value(v.to_bits() as u64)`.
330        let pb_scalar_value = pb::ScalarValue {
331            kind: Some(Kind::Uint64Value(f16::from_f32(0.42).to_bits() as u64)),
332        };
333        let scalar_value = ScalarValue::try_from(&pb_scalar_value).unwrap();
334        assert_eq!(
335            scalar_value.as_pvalue().unwrap(),
336            Some(PValue::U64(14008u64))
337        );
338
339        let scalar = Scalar::new(
340            DType::Primitive(PType::F16, Nullability::Nullable),
341            scalar_value,
342        );
343
344        assert_eq!(
345            scalar.as_primitive().pvalue().unwrap(),
346            PValue::F16(f16::from_f32(0.42))
347        );
348    }
349
350    #[test]
351    fn test_scalar_value_direct_roundtrip_f16() {
352        // Test that ScalarValue with f16 roundtrips correctly without going through Scalar
353        let f16_values = vec![
354            f16::from_f32(0.0),
355            f16::from_f32(1.0),
356            f16::from_f32(-1.0),
357            f16::from_f32(0.42),
358            f16::from_f32(5.722046e-6),
359            f16::from_f32(std::f32::consts::PI),
360            f16::INFINITY,
361            f16::NEG_INFINITY,
362            f16::NAN,
363        ];
364
365        for f16_val in f16_values {
366            let scalar_value = ScalarValue(InnerScalarValue::Primitive(PValue::F16(f16_val)));
367            let written = scalar_value.to_protobytes::<Vec<u8>>();
368            let read_back = ScalarValue::from_protobytes(&written).unwrap();
369
370            match (&scalar_value.0, &read_back.0) {
371                (
372                    InnerScalarValue::Primitive(PValue::F16(original)),
373                    InnerScalarValue::Primitive(PValue::F16(roundtripped)),
374                ) => {
375                    if original.is_nan() && roundtripped.is_nan() {
376                        // NaN values are equal for our purposes
377                        continue;
378                    }
379                    assert_eq!(
380                        original, roundtripped,
381                        "F16 value {original:?} did not roundtrip correctly"
382                    );
383                }
384                _ => {
385                    vortex_panic!(
386                        "Expected f16 primitive values, got {scalar_value:?} and {read_back:?}"
387                    )
388                }
389            }
390        }
391    }
392
393    #[test]
394    fn test_scalar_value_direct_roundtrip_preserves_values() {
395        // Test that ScalarValue roundtripping preserves values (but not necessarily exact types)
396        // Note: Proto encoding consolidates integer types (u8/u16/u32 → u64, i8/i16/i32 → i64)
397
398        // Test cases that should roundtrip exactly
399        let exact_roundtrip_cases = vec![
400            ("null", ScalarValue(InnerScalarValue::Null)),
401            ("bool_true", ScalarValue(InnerScalarValue::Bool(true))),
402            ("bool_false", ScalarValue(InnerScalarValue::Bool(false))),
403            (
404                "u64",
405                ScalarValue(InnerScalarValue::Primitive(PValue::U64(
406                    18446744073709551615,
407                ))),
408            ),
409            (
410                "i64",
411                ScalarValue(InnerScalarValue::Primitive(PValue::I64(
412                    -9223372036854775808,
413                ))),
414            ),
415            (
416                "f32",
417                ScalarValue(InnerScalarValue::Primitive(PValue::F32(
418                    std::f32::consts::E,
419                ))),
420            ),
421            (
422                "f64",
423                ScalarValue(InnerScalarValue::Primitive(PValue::F64(
424                    std::f64::consts::PI,
425                ))),
426            ),
427            (
428                "string",
429                ScalarValue(InnerScalarValue::BufferString(Arc::new(
430                    BufferString::from("test"),
431                ))),
432            ),
433            (
434                "bytes",
435                ScalarValue(InnerScalarValue::Buffer(Arc::new(
436                    vec![1, 2, 3, 4, 5].into(),
437                ))),
438            ),
439        ];
440
441        for (name, value) in exact_roundtrip_cases {
442            let written = value.to_protobytes::<Vec<u8>>();
443            let read_back = ScalarValue::from_protobytes(&written).unwrap();
444
445            let original_debug = format!("{value:?}");
446            let roundtrip_debug = format!("{read_back:?}");
447            assert_eq!(
448                original_debug, roundtrip_debug,
449                "ScalarValue {name} did not roundtrip exactly"
450            );
451        }
452
453        // Test cases where type changes but value is preserved
454        // Unsigned integers consolidate to U64
455        let unsigned_cases = vec![
456            (
457                "u8",
458                ScalarValue(InnerScalarValue::Primitive(PValue::U8(255))),
459                255u64,
460            ),
461            (
462                "u16",
463                ScalarValue(InnerScalarValue::Primitive(PValue::U16(65535))),
464                65535u64,
465            ),
466            (
467                "u32",
468                ScalarValue(InnerScalarValue::Primitive(PValue::U32(4294967295))),
469                4294967295u64,
470            ),
471        ];
472
473        for (name, value, expected) in unsigned_cases {
474            let written = value.to_protobytes::<Vec<u8>>();
475            let read_back = ScalarValue::from_protobytes(&written).unwrap();
476
477            match &read_back.0 {
478                InnerScalarValue::Primitive(PValue::U64(v)) => {
479                    assert_eq!(
480                        *v, expected,
481                        "ScalarValue {name} value not preserved: expected {expected}, got {v}"
482                    );
483                }
484                _ => vortex_panic!("Unexpected type after roundtrip for {name}: {read_back:?}"),
485            }
486        }
487
488        // Signed integers consolidate to I64
489        let signed_cases = vec![
490            (
491                "i8",
492                ScalarValue(InnerScalarValue::Primitive(PValue::I8(-128))),
493                -128i64,
494            ),
495            (
496                "i16",
497                ScalarValue(InnerScalarValue::Primitive(PValue::I16(-32768))),
498                -32768i64,
499            ),
500            (
501                "i32",
502                ScalarValue(InnerScalarValue::Primitive(PValue::I32(-2147483648))),
503                -2147483648i64,
504            ),
505        ];
506
507        for (name, value, expected) in signed_cases {
508            let written = value.to_protobytes::<Vec<u8>>();
509            let read_back = ScalarValue::from_protobytes(&written).unwrap();
510
511            match &read_back.0 {
512                InnerScalarValue::Primitive(PValue::I64(v)) => {
513                    assert_eq!(
514                        *v, expected,
515                        "ScalarValue {name} value not preserved: expected {expected}, got {v}"
516                    );
517                }
518                _ => vortex_panic!("Unexpected type after roundtrip for {name}: {read_back:?}"),
519            }
520        }
521    }
522}