vortex_scalar/
proto.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::sync::Arc;
5
6use num_traits::ToBytes;
7use vortex_buffer::BufferString;
8use vortex_buffer::ByteBuffer;
9use vortex_dtype::DType;
10use vortex_dtype::half::f16;
11use vortex_error::VortexError;
12use vortex_error::vortex_err;
13use vortex_proto::scalar as pb;
14use vortex_proto::scalar::ListValue;
15use vortex_proto::scalar::scalar_value::Kind;
16
17use crate::DecimalValue;
18use crate::InnerScalarValue;
19use crate::Scalar;
20use crate::ScalarValue;
21use crate::pvalue::PValue;
22
23impl From<&Scalar> for pb::Scalar {
24    fn from(value: &Scalar) -> Self {
25        pb::Scalar {
26            dtype: Some((value.dtype()).into()),
27            value: Some((value.value()).into()),
28        }
29    }
30}
31
32impl From<&ScalarValue> for pb::ScalarValue {
33    fn from(value: &ScalarValue) -> Self {
34        match value {
35            ScalarValue(InnerScalarValue::Null) => pb::ScalarValue {
36                kind: Some(Kind::NullValue(0)),
37            },
38            ScalarValue(InnerScalarValue::Bool(v)) => pb::ScalarValue {
39                kind: Some(Kind::BoolValue(*v)),
40            },
41            ScalarValue(InnerScalarValue::Primitive(v)) => v.into(),
42            ScalarValue(InnerScalarValue::Decimal(v)) => {
43                let inner_value = match v {
44                    DecimalValue::I8(v) => v.to_le_bytes().to_vec(),
45                    DecimalValue::I16(v) => v.to_le_bytes().to_vec(),
46                    DecimalValue::I32(v) => v.to_le_bytes().to_vec(),
47                    DecimalValue::I64(v) => v.to_le_bytes().to_vec(),
48                    DecimalValue::I128(v128) => v128.to_le_bytes().to_vec(),
49                    DecimalValue::I256(v256) => v256.to_le_bytes().to_vec(),
50                };
51
52                pb::ScalarValue {
53                    kind: Some(Kind::BytesValue(inner_value)),
54                }
55            }
56            ScalarValue(InnerScalarValue::Buffer(v)) => pb::ScalarValue {
57                kind: Some(Kind::BytesValue(v.as_slice().to_vec())),
58            },
59            ScalarValue(InnerScalarValue::BufferString(v)) => pb::ScalarValue {
60                kind: Some(Kind::StringValue(v.as_str().to_string())),
61            },
62            ScalarValue(InnerScalarValue::List(v)) => {
63                let mut values = Vec::with_capacity(v.len());
64                for elem in v.iter() {
65                    values.push(pb::ScalarValue::from(elem));
66                }
67                pb::ScalarValue {
68                    kind: Some(Kind::ListValue(ListValue { values })),
69                }
70            }
71        }
72    }
73}
74
75impl From<&PValue> for pb::ScalarValue {
76    fn from(value: &PValue) -> Self {
77        match value {
78            PValue::I8(v) => pb::ScalarValue {
79                kind: Some(Kind::Int64Value(*v as i64)),
80            },
81            PValue::I16(v) => pb::ScalarValue {
82                kind: Some(Kind::Int64Value(*v as i64)),
83            },
84            PValue::I32(v) => pb::ScalarValue {
85                kind: Some(Kind::Int64Value(*v as i64)),
86            },
87            PValue::I64(v) => pb::ScalarValue {
88                kind: Some(Kind::Int64Value(*v)),
89            },
90            PValue::U8(v) => pb::ScalarValue {
91                kind: Some(Kind::Uint64Value(*v as u64)),
92            },
93            PValue::U16(v) => pb::ScalarValue {
94                kind: Some(Kind::Uint64Value(*v as u64)),
95            },
96            PValue::U32(v) => pb::ScalarValue {
97                kind: Some(Kind::Uint64Value(*v as u64)),
98            },
99            PValue::U64(v) => pb::ScalarValue {
100                kind: Some(Kind::Uint64Value(*v)),
101            },
102            PValue::F16(v) => pb::ScalarValue {
103                kind: Some(Kind::F16Value(v.to_bits() as u64)),
104            },
105            PValue::F32(v) => pb::ScalarValue {
106                kind: Some(Kind::F32Value(*v)),
107            },
108            PValue::F64(v) => pb::ScalarValue {
109                kind: Some(Kind::F64Value(*v)),
110            },
111        }
112    }
113}
114
115impl TryFrom<&pb::Scalar> for Scalar {
116    type Error = VortexError;
117
118    fn try_from(value: &pb::Scalar) -> Result<Self, Self::Error> {
119        let dtype = DType::try_from(
120            value
121                .dtype
122                .as_ref()
123                .ok_or_else(|| vortex_err!(InvalidSerde: "Scalar missing dtype"))?,
124        )?;
125
126        let value = ScalarValue::try_from(
127            value
128                .value
129                .as_ref()
130                .ok_or_else(|| vortex_err!(InvalidSerde: "Scalar missing value"))?,
131        )?;
132
133        Ok(Scalar::new(dtype, value))
134    }
135}
136
137impl TryFrom<&pb::ScalarValue> for ScalarValue {
138    type Error = VortexError;
139
140    fn try_from(value: &pb::ScalarValue) -> Result<Self, Self::Error> {
141        let kind = value
142            .kind
143            .as_ref()
144            .ok_or_else(|| vortex_err!(InvalidSerde: "ScalarValue missing kind"))?;
145
146        match kind {
147            Kind::NullValue(_) => Ok(ScalarValue(InnerScalarValue::Null)),
148            Kind::BoolValue(v) => Ok(ScalarValue(InnerScalarValue::Bool(*v))),
149            Kind::Int64Value(v) => Ok(ScalarValue(InnerScalarValue::Primitive(PValue::I64(*v)))),
150            Kind::Uint64Value(v) => Ok(ScalarValue(InnerScalarValue::Primitive(PValue::U64(*v)))),
151            Kind::F16Value(v) => Ok(ScalarValue(InnerScalarValue::Primitive(PValue::F16(
152                f16::from_bits(u16::try_from(*v).map_err(|_| {
153                    vortex_err!("f16 bitwise representation has more than 16 bits: {}", v)
154                })?),
155            )))),
156            Kind::F32Value(v) => Ok(ScalarValue(InnerScalarValue::Primitive(PValue::F32(*v)))),
157            Kind::F64Value(v) => Ok(ScalarValue(InnerScalarValue::Primitive(PValue::F64(*v)))),
158            Kind::StringValue(v) => Ok(ScalarValue(InnerScalarValue::BufferString(Arc::new(
159                BufferString::from(v.clone()),
160            )))),
161            Kind::BytesValue(v) => Ok(ScalarValue(InnerScalarValue::Buffer(Arc::new(
162                ByteBuffer::from(v.clone()),
163            )))),
164            Kind::ListValue(v) => {
165                let mut values = Vec::with_capacity(v.values.len());
166                for elem in v.values.iter() {
167                    values.push(elem.try_into()?);
168                }
169                Ok(ScalarValue(InnerScalarValue::List(values.into())))
170            }
171        }
172    }
173}
174
// Round-trip and error-path tests for the protobuf conversions in this module.
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use rstest::rstest;
    use vortex_buffer::BufferString;
    use vortex_dtype::DType;
    use vortex_dtype::DecimalDType;
    use vortex_dtype::FieldDType;
    use vortex_dtype::Nullability;
    use vortex_dtype::PType;
    use vortex_dtype::StructFields;
    use vortex_dtype::half::f16;
    use vortex_dtype::i256;
    use vortex_error::vortex_panic;
    use vortex_proto::scalar as pb;

    use super::*;
    use crate::InnerScalarValue;
    use crate::Scalar;
    use crate::ScalarValue;

    /// Asserts that `scalar` is unchanged after converting to `pb::Scalar` and back.
    fn round_trip(scalar: Scalar) {
        assert_eq!(
            scalar,
            Scalar::try_from(&pb::Scalar::from(&scalar)).unwrap(),
        );
    }

    #[test]
    fn test_null() {
        round_trip(Scalar::null(DType::Null));
    }

    #[test]
    fn test_bool() {
        round_trip(Scalar::new(
            DType::Bool(Nullability::Nullable),
            ScalarValue(InnerScalarValue::Bool(true)),
        ));
    }

    #[test]
    fn test_primitive() {
        round_trip(Scalar::new(
            DType::Primitive(PType::I32, Nullability::Nullable),
            ScalarValue(InnerScalarValue::Primitive(42i32.into())),
        ));
    }

    #[test]
    fn test_buffer() {
        round_trip(Scalar::new(
            DType::Binary(Nullability::Nullable),
            ScalarValue(InnerScalarValue::Buffer(Arc::new(vec![1, 2, 3].into()))),
        ));
    }

    #[test]
    fn test_buffer_string() {
        round_trip(Scalar::new(
            DType::Utf8(Nullability::Nullable),
            ScalarValue(InnerScalarValue::BufferString(Arc::new(
                BufferString::from("hello".to_string()),
            ))),
        ));
    }

    #[test]
    fn test_list() {
        round_trip(Scalar::new(
            DType::List(
                Arc::new(DType::Primitive(PType::I32, Nullability::Nullable)),
                Nullability::Nullable,
            ),
            ScalarValue(InnerScalarValue::List(
                vec![
                    ScalarValue(InnerScalarValue::Primitive(42i32.into())),
                    ScalarValue(InnerScalarValue::Primitive(43i32.into())),
                ]
                .into(),
            )),
        ));
    }

    #[test]
    fn test_f16() {
        round_trip(Scalar::primitive(
            f16::from_f32(0.42),
            Nullability::Nullable,
        ));
    }

    #[test]
    fn test_i8() {
        round_trip(Scalar::new(
            DType::Primitive(PType::I8, Nullability::Nullable),
            ScalarValue(InnerScalarValue::Primitive(i8::MIN.into())),
        ));

        round_trip(Scalar::new(
            DType::Primitive(PType::I8, Nullability::Nullable),
            ScalarValue(InnerScalarValue::Primitive(0i8.into())),
        ));

        round_trip(Scalar::new(
            DType::Primitive(PType::I8, Nullability::Nullable),
            ScalarValue(InnerScalarValue::Primitive(i8::MAX.into())),
        ));
    }

    #[rstest]
    #[case(Scalar::binary(ByteBuffer::copy_from(b"hello"), Nullability::NonNullable))]
    #[case(Scalar::utf8("hello", Nullability::NonNullable))]
    #[case(Scalar::primitive(1u8, Nullability::NonNullable))]
    #[case(Scalar::primitive(
        f32::from_bits(u32::from_le_bytes([0xFFu8, 0x8A, 0xF9, 0xFF])),
        Nullability::NonNullable
    ))]
    #[case(Scalar::list(
        Arc::new(PType::U8.into()),
        vec![Scalar::primitive(1u8, Nullability::NonNullable)],
        Nullability::NonNullable
    ))]
    #[case(Scalar::struct_(DType::Struct(
        StructFields::from_iter([
            ("a", FieldDType::from(DType::Primitive(PType::U32, Nullability::NonNullable))),
            ("b", FieldDType::from(DType::Primitive(PType::F16, Nullability::NonNullable))),
        ]),
        Nullability::NonNullable),
        vec![
            Scalar::primitive(23592960u32, Nullability::NonNullable),
            Scalar::primitive(f16::from_f32(2.6584664e36f32), Nullability::NonNullable),
        ],
    ))]
    #[case(Scalar::struct_(DType::Struct(
        StructFields::from_iter([
            ("a", FieldDType::from(DType::Primitive(PType::U64, Nullability::NonNullable))),
            ("b", FieldDType::from(DType::Primitive(PType::F32, Nullability::NonNullable))),
            ("c", FieldDType::from(DType::Primitive(PType::F16, Nullability::NonNullable))),
        ]),
        Nullability::NonNullable),
        vec![
            Scalar::primitive(415118687234u64, Nullability::NonNullable),
            Scalar::primitive(2.6584664e36f32, Nullability::NonNullable),
            Scalar::primitive(f16::from_f32(2.6584664e36f32), Nullability::NonNullable),
        ],
    ))]
    #[case(Scalar::decimal(
        DecimalValue::I256(i256::from_i128(12345643673471)),
        DecimalDType::new(10, 2),
        Nullability::NonNullable
    ))]
    #[case(Scalar::decimal(
        DecimalValue::I16(23412),
        DecimalDType::new(3, 2),
        Nullability::NonNullable
    ))]
    fn test_scalar_value_serde_roundtrip(#[case] scalar: Scalar) {
        let written = scalar.value().to_protobytes::<Vec<u8>>();
        let scalar_read_back = ScalarValue::from_protobytes(&written).unwrap();
        assert_eq!(
            Scalar::new(scalar.dtype().clone(), scalar_read_back),
            scalar
        );
    }

    #[test]
    fn test_backcompat_f16_serialized_as_u64() {
        // Note that this is a backwards compatibility test for poor design in the previous implementation.
        // Previously, f16 ScalarValues were serialized as `Kind::Uint64Value(v.to_bits() as u64)`.
        let pb_scalar_value = pb::ScalarValue {
            kind: Some(Kind::Uint64Value(f16::from_f32(0.42).to_bits() as u64)),
        };
        let scalar_value = ScalarValue::try_from(&pb_scalar_value).unwrap();
        assert_eq!(
            scalar_value.as_pvalue().unwrap(),
            Some(PValue::U64(14008u64))
        );

        let scalar = Scalar::new(
            DType::Primitive(PType::F16, Nullability::Nullable),
            scalar_value,
        );

        assert_eq!(
            scalar.as_primitive().pvalue().unwrap(),
            PValue::F16(f16::from_f32(0.42))
        );
    }

    #[test]
    fn test_scalar_value_direct_roundtrip_f16() {
        // Test that ScalarValue with f16 roundtrips correctly without going through Scalar
        let f16_values = vec![
            f16::from_f32(0.0),
            f16::from_f32(1.0),
            f16::from_f32(-1.0),
            f16::from_f32(0.42),
            f16::from_f32(5.722046e-6),
            f16::from_f32(std::f32::consts::PI),
            f16::INFINITY,
            f16::NEG_INFINITY,
            f16::NAN,
        ];

        for f16_val in f16_values {
            let scalar_value = ScalarValue(InnerScalarValue::Primitive(PValue::F16(f16_val)));
            let written = scalar_value.to_protobytes::<Vec<u8>>();
            let read_back = ScalarValue::from_protobytes(&written).unwrap();

            match (&scalar_value.0, &read_back.0) {
                (
                    InnerScalarValue::Primitive(PValue::F16(original)),
                    InnerScalarValue::Primitive(PValue::F16(roundtripped)),
                ) => {
                    if original.is_nan() && roundtripped.is_nan() {
                        // NaN values are equal for our purposes
                        continue;
                    }
                    assert_eq!(
                        original, roundtripped,
                        "F16 value {original:?} did not roundtrip correctly"
                    );
                }
                _ => {
                    vortex_panic!(
                        "Expected f16 primitive values, got {scalar_value:?} and {read_back:?}"
                    )
                }
            }
        }
    }

    #[test]
    fn test_scalar_value_direct_roundtrip_preserves_values() {
        // Test that ScalarValue roundtripping preserves values (but not necessarily exact types)
        // Note: Proto encoding consolidates integer types (u8/u16/u32 → u64, i8/i16/i32 → i64)

        // Test cases that should roundtrip exactly
        let exact_roundtrip_cases = vec![
            ("null", ScalarValue(InnerScalarValue::Null)),
            ("bool_true", ScalarValue(InnerScalarValue::Bool(true))),
            ("bool_false", ScalarValue(InnerScalarValue::Bool(false))),
            (
                "u64",
                ScalarValue(InnerScalarValue::Primitive(PValue::U64(
                    18446744073709551615,
                ))),
            ),
            (
                "i64",
                ScalarValue(InnerScalarValue::Primitive(PValue::I64(
                    -9223372036854775808,
                ))),
            ),
            (
                "f32",
                ScalarValue(InnerScalarValue::Primitive(PValue::F32(
                    std::f32::consts::E,
                ))),
            ),
            (
                "f64",
                ScalarValue(InnerScalarValue::Primitive(PValue::F64(
                    std::f64::consts::PI,
                ))),
            ),
            (
                "string",
                ScalarValue(InnerScalarValue::BufferString(Arc::new(
                    BufferString::from("test"),
                ))),
            ),
            (
                "bytes",
                ScalarValue(InnerScalarValue::Buffer(Arc::new(
                    vec![1, 2, 3, 4, 5].into(),
                ))),
            ),
        ];

        for (name, value) in exact_roundtrip_cases {
            let written = value.to_protobytes::<Vec<u8>>();
            let read_back = ScalarValue::from_protobytes(&written).unwrap();

            let original_debug = format!("{value:?}");
            let roundtrip_debug = format!("{read_back:?}");
            assert_eq!(
                original_debug, roundtrip_debug,
                "ScalarValue {name} did not roundtrip exactly"
            );
        }

        // Test cases where type changes but value is preserved
        // Unsigned integers consolidate to U64
        let unsigned_cases = vec![
            (
                "u8",
                ScalarValue(InnerScalarValue::Primitive(PValue::U8(255))),
                255u64,
            ),
            (
                "u16",
                ScalarValue(InnerScalarValue::Primitive(PValue::U16(65535))),
                65535u64,
            ),
            (
                "u32",
                ScalarValue(InnerScalarValue::Primitive(PValue::U32(4294967295))),
                4294967295u64,
            ),
        ];

        for (name, value, expected) in unsigned_cases {
            let written = value.to_protobytes::<Vec<u8>>();
            let read_back = ScalarValue::from_protobytes(&written).unwrap();

            match &read_back.0 {
                InnerScalarValue::Primitive(PValue::U64(v)) => {
                    assert_eq!(
                        *v, expected,
                        "ScalarValue {name} value not preserved: expected {expected}, got {v}"
                    );
                }
                _ => vortex_panic!("Unexpected type after roundtrip for {name}: {read_back:?}"),
            }
        }

        // Signed integers consolidate to I64
        let signed_cases = vec![
            (
                "i8",
                ScalarValue(InnerScalarValue::Primitive(PValue::I8(-128))),
                -128i64,
            ),
            (
                "i16",
                ScalarValue(InnerScalarValue::Primitive(PValue::I16(-32768))),
                -32768i64,
            ),
            (
                "i32",
                ScalarValue(InnerScalarValue::Primitive(PValue::I32(-2147483648))),
                -2147483648i64,
            ),
        ];

        for (name, value, expected) in signed_cases {
            let written = value.to_protobytes::<Vec<u8>>();
            let read_back = ScalarValue::from_protobytes(&written).unwrap();

            match &read_back.0 {
                InnerScalarValue::Primitive(PValue::I64(v)) => {
                    assert_eq!(
                        *v, expected,
                        "ScalarValue {name} value not preserved: expected {expected}, got {v}"
                    );
                }
                _ => vortex_panic!("Unexpected type after roundtrip for {name}: {read_back:?}"),
            }
        }
    }

    #[test]
    fn test_missing_dtype_is_error() {
        // A message with a value but no dtype must be rejected rather than guessed at.
        let pb_scalar = pb::Scalar {
            dtype: None,
            value: Some(pb::ScalarValue::from(&ScalarValue(InnerScalarValue::Null))),
        };
        assert!(Scalar::try_from(&pb_scalar).is_err());
    }

    #[test]
    fn test_missing_value_is_error() {
        let mut pb_scalar = pb::Scalar::from(&Scalar::null(DType::Null));
        pb_scalar.value = None;
        assert!(Scalar::try_from(&pb_scalar).is_err());
    }

    #[test]
    fn test_missing_kind_is_error() {
        assert!(ScalarValue::try_from(&pb::ScalarValue { kind: None }).is_err());
    }

    #[test]
    fn test_f16_bits_out_of_range_is_error() {
        // F16Value is carried as a u64; anything that does not fit in 16 bits is malformed.
        let pb_scalar_value = pb::ScalarValue {
            kind: Some(Kind::F16Value(u64::from(u16::MAX) + 1)),
        };
        assert!(ScalarValue::try_from(&pb_scalar_value).is_err());

        // The largest 16-bit pattern is still accepted.
        let pb_scalar_value = pb::ScalarValue {
            kind: Some(Kind::F16Value(u64::from(u16::MAX))),
        };
        assert!(ScalarValue::try_from(&pb_scalar_value).is_ok());
    }

    #[test]
    fn test_nested_and_empty_list_value_roundtrip() {
        // Lists of lists, including an empty inner list, must survive the recursive conversion.
        let nested = ScalarValue(InnerScalarValue::List(
            vec![
                ScalarValue(InnerScalarValue::List(
                    vec![ScalarValue(InnerScalarValue::Primitive(PValue::I64(1)))].into(),
                )),
                ScalarValue(InnerScalarValue::List(Vec::<ScalarValue>::new().into())),
            ]
            .into(),
        ));
        let read_back = ScalarValue::try_from(&pb::ScalarValue::from(&nested)).unwrap();
        assert_eq!(format!("{nested:?}"), format!("{read_back:?}"));
    }
}