vortex_scalar/
proto.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::sync::Arc;
5
6use num_traits::ToBytes;
7use vortex_buffer::{BufferString, ByteBuffer};
8use vortex_dtype::DType;
9use vortex_error::{VortexError, vortex_err};
10use vortex_proto::scalar as pb;
11use vortex_proto::scalar::ListValue;
12use vortex_proto::scalar::scalar_value::Kind;
13
14use crate::pvalue::PValue;
15use crate::{DecimalValue, InnerScalarValue, Scalar, ScalarValue};
16
17impl From<&Scalar> for pb::Scalar {
18    fn from(value: &Scalar) -> Self {
19        pb::Scalar {
20            dtype: Some((&value.dtype).into()),
21            value: Some((&value.value).into()),
22        }
23    }
24}
25
26impl From<&ScalarValue> for pb::ScalarValue {
27    fn from(value: &ScalarValue) -> Self {
28        match value {
29            ScalarValue(InnerScalarValue::Null) => pb::ScalarValue {
30                kind: Some(Kind::NullValue(0)),
31            },
32            ScalarValue(InnerScalarValue::Bool(v)) => pb::ScalarValue {
33                kind: Some(Kind::BoolValue(*v)),
34            },
35            ScalarValue(InnerScalarValue::Primitive(v)) => v.into(),
36            ScalarValue(InnerScalarValue::Decimal(v)) => {
37                let inner_value = match v {
38                    DecimalValue::I8(v) => v.to_le_bytes().to_vec(),
39                    DecimalValue::I16(v) => v.to_le_bytes().to_vec(),
40                    DecimalValue::I32(v) => v.to_le_bytes().to_vec(),
41                    DecimalValue::I64(v) => v.to_le_bytes().to_vec(),
42                    DecimalValue::I128(v128) => v128.to_le_bytes().to_vec(),
43                    DecimalValue::I256(v256) => v256.to_le_bytes().to_vec(),
44                };
45
46                pb::ScalarValue {
47                    kind: Some(Kind::BytesValue(inner_value)),
48                }
49            }
50            ScalarValue(InnerScalarValue::Buffer(v)) => pb::ScalarValue {
51                kind: Some(Kind::BytesValue(v.as_slice().to_vec())),
52            },
53            ScalarValue(InnerScalarValue::BufferString(v)) => pb::ScalarValue {
54                kind: Some(Kind::StringValue(v.as_str().to_string())),
55            },
56            ScalarValue(InnerScalarValue::List(v)) => {
57                let mut values = Vec::with_capacity(v.len());
58                for elem in v.iter() {
59                    values.push(pb::ScalarValue::from(elem));
60                }
61                pb::ScalarValue {
62                    kind: Some(Kind::ListValue(ListValue { values })),
63                }
64            }
65        }
66    }
67}
68
69impl From<&PValue> for pb::ScalarValue {
70    fn from(value: &PValue) -> Self {
71        match value {
72            PValue::I8(v) => pb::ScalarValue {
73                kind: Some(Kind::Int64Value(*v as i64)),
74            },
75            PValue::I16(v) => pb::ScalarValue {
76                kind: Some(Kind::Int64Value(*v as i64)),
77            },
78            PValue::I32(v) => pb::ScalarValue {
79                kind: Some(Kind::Int64Value(*v as i64)),
80            },
81            PValue::I64(v) => pb::ScalarValue {
82                kind: Some(Kind::Int64Value(*v)),
83            },
84            PValue::U8(v) => pb::ScalarValue {
85                kind: Some(Kind::Uint64Value(*v as u64)),
86            },
87            PValue::U16(v) => pb::ScalarValue {
88                kind: Some(Kind::Uint64Value(*v as u64)),
89            },
90            PValue::U32(v) => pb::ScalarValue {
91                kind: Some(Kind::Uint64Value(*v as u64)),
92            },
93            PValue::U64(v) => pb::ScalarValue {
94                kind: Some(Kind::Uint64Value(*v)),
95            },
96            PValue::F16(v) => pb::ScalarValue {
97                kind: Some(Kind::Uint64Value(v.to_bits() as u64)),
98            },
99            PValue::F32(v) => pb::ScalarValue {
100                kind: Some(Kind::F32Value(*v)),
101            },
102            PValue::F64(v) => pb::ScalarValue {
103                kind: Some(Kind::F64Value(*v)),
104            },
105        }
106    }
107}
108
109impl TryFrom<&pb::Scalar> for Scalar {
110    type Error = VortexError;
111
112    fn try_from(value: &pb::Scalar) -> Result<Self, Self::Error> {
113        let dtype = DType::try_from(
114            value
115                .dtype
116                .as_ref()
117                .ok_or_else(|| vortex_err!(InvalidSerde: "Scalar missing dtype"))?,
118        )?;
119
120        let value = ScalarValue::try_from(
121            value
122                .value
123                .as_ref()
124                .ok_or_else(|| vortex_err!(InvalidSerde: "Scalar missing value"))?,
125        )?;
126
127        Ok(Self { dtype, value })
128    }
129}
130
131impl TryFrom<&pb::ScalarValue> for ScalarValue {
132    type Error = VortexError;
133
134    fn try_from(value: &pb::ScalarValue) -> Result<Self, Self::Error> {
135        let kind = value
136            .kind
137            .as_ref()
138            .ok_or_else(|| vortex_err!(InvalidSerde: "ScalarValue missing kind"))?;
139
140        match kind {
141            Kind::NullValue(_) => Ok(ScalarValue(InnerScalarValue::Null)),
142            Kind::BoolValue(v) => Ok(ScalarValue(InnerScalarValue::Bool(*v))),
143            Kind::Int64Value(v) => Ok(ScalarValue(InnerScalarValue::Primitive(PValue::I64(*v)))),
144            Kind::Uint64Value(v) => Ok(ScalarValue(InnerScalarValue::Primitive(PValue::U64(*v)))),
145            Kind::F32Value(v) => Ok(ScalarValue(InnerScalarValue::Primitive(PValue::F32(*v)))),
146            Kind::F64Value(v) => Ok(ScalarValue(InnerScalarValue::Primitive(PValue::F64(*v)))),
147            Kind::StringValue(v) => Ok(ScalarValue(InnerScalarValue::BufferString(Arc::new(
148                BufferString::from(v.clone()),
149            )))),
150            Kind::BytesValue(v) => Ok(ScalarValue(InnerScalarValue::Buffer(Arc::new(
151                ByteBuffer::from(v.clone()),
152            )))),
153            Kind::ListValue(v) => {
154                let mut values = Vec::with_capacity(v.values.len());
155                for elem in v.values.iter() {
156                    values.push(elem.try_into()?);
157                }
158                Ok(ScalarValue(InnerScalarValue::List(values.into())))
159            }
160        }
161    }
162}
163
#[cfg(test)]
mod test {
    use std::sync::Arc;

    use vortex_buffer::BufferString;
    use vortex_dtype::half::f16;
    use vortex_dtype::{DType, Nullability, PType};
    use vortex_proto::scalar as pb;

    use crate::{InnerScalarValue, PValue, Scalar, ScalarValue};

    /// Converts `scalar` to a `pb::Scalar` and back, and asserts that the decoded scalar
    /// compares equal to the original.
    fn round_trip(scalar: Scalar) {
        let message = pb::Scalar::from(&scalar);
        let decoded = Scalar::try_from(&message).unwrap();
        assert_eq!(scalar, decoded);
    }

    /// Builds a nullable primitive scalar of type `ptype` holding `value`.
    fn nullable_primitive(ptype: PType, value: PValue) -> Scalar {
        Scalar::new(
            DType::Primitive(ptype, Nullability::Nullable),
            ScalarValue(InnerScalarValue::Primitive(value)),
        )
    }

    #[test]
    fn test_null() {
        round_trip(Scalar::null(DType::Null));
    }

    #[test]
    fn test_bool() {
        let value = ScalarValue(InnerScalarValue::Bool(true));
        round_trip(Scalar::new(DType::Bool(Nullability::Nullable), value));
    }

    #[test]
    fn test_primitive() {
        round_trip(nullable_primitive(PType::I32, 42i32.into()));
    }

    #[test]
    fn test_buffer() {
        let buffer = InnerScalarValue::Buffer(Arc::new(vec![1, 2, 3].into()));
        round_trip(Scalar::new(
            DType::Binary(Nullability::Nullable),
            ScalarValue(buffer),
        ));
    }

    #[test]
    fn test_buffer_string() {
        let text = BufferString::from("hello".to_string());
        round_trip(Scalar::new(
            DType::Utf8(Nullability::Nullable),
            ScalarValue(InnerScalarValue::BufferString(Arc::new(text))),
        ));
    }

    #[test]
    fn test_list() {
        let element_dtype = Arc::new(DType::Primitive(PType::I32, Nullability::Nullable));
        let elements: Vec<ScalarValue> = [42i32, 43]
            .into_iter()
            .map(|n| ScalarValue(InnerScalarValue::Primitive(n.into())))
            .collect();
        round_trip(Scalar::new(
            DType::List(element_dtype, Nullability::Nullable),
            ScalarValue(InnerScalarValue::List(elements.into())),
        ));
    }

    #[test]
    fn test_f16() {
        round_trip(nullable_primitive(
            PType::F16,
            PValue::F16(f16::from_f32(0.42)),
        ));
    }

    #[test]
    fn test_i8() {
        // Cover both extremes plus zero, since every i8 is widened to an Int64Value on the wire.
        for value in [i8::MIN, 0, i8::MAX] {
            round_trip(nullable_primitive(PType::I8, value.into()));
        }
    }
}
267
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_dtype::half::f16;
    use vortex_dtype::{DType, DecimalDType, FieldDType, Nullability, PType, StructFields};

    use super::*;
    use crate::{Scalar, i256};

    /// Builds a non-nullable struct dtype whose fields are all non-nullable primitives,
    /// in the order given.
    fn primitive_struct_dtype(fields: &[(&'static str, PType)]) -> DType {
        let fields = fields.iter().map(|&(name, ptype)| {
            (
                name,
                FieldDType::from(DType::Primitive(ptype, Nullability::NonNullable)),
            )
        });
        DType::Struct(StructFields::from_iter(fields), Nullability::NonNullable)
    }

    /// Serializes `scalar.value` to protobuf bytes and parses it back. Checks that pairing the
    /// decoded value with the original dtype reproduces the original scalar.
    #[rstest]
    #[case(Scalar::binary(ByteBuffer::copy_from(b"hello"), Nullability::NonNullable))]
    #[case(Scalar::utf8("hello", Nullability::NonNullable))]
    #[case(Scalar::primitive(1u8, Nullability::NonNullable))]
    // Bit pattern 0xFFF98AFF: an f32 NaN with the sign bit set and a non-zero payload.
    #[case(Scalar::primitive(
        f32::from_bits(u32::from_le_bytes([0xFFu8, 0x8A, 0xF9, 0xFF])),
        Nullability::NonNullable
    ))]
    #[case(Scalar::list(
        Arc::new(PType::U8.into()),
        vec![Scalar::primitive(1u8, Nullability::NonNullable)],
        Nullability::NonNullable
    ))]
    #[case(Scalar::struct_(
        primitive_struct_dtype(&[("a", PType::U32), ("b", PType::F16)]),
        vec![
            Scalar::primitive(23592960u32, Nullability::NonNullable),
            Scalar::primitive(f16::from_f32(2.6584664e36f32), Nullability::NonNullable),
        ],
    ))]
    #[case(Scalar::struct_(
        primitive_struct_dtype(&[("a", PType::U64), ("b", PType::F32), ("c", PType::F16)]),
        vec![
            Scalar::primitive(415118687234u64, Nullability::NonNullable),
            Scalar::primitive(2.6584664e36f32, Nullability::NonNullable),
            Scalar::primitive(f16::from_f32(2.6584664e36f32), Nullability::NonNullable),
        ],
    ))]
    #[case(Scalar::decimal(
        DecimalValue::I256(i256::from_i128(12345643673471)),
        DecimalDType::new(10, 2),
        Nullability::NonNullable
    ))]
    #[case(Scalar::decimal(
        DecimalValue::I16(23412),
        DecimalDType::new(3, 2),
        Nullability::NonNullable
    ))]
    fn test_scalar_value_serde_roundtrip(#[case] scalar: Scalar) {
        let encoded = scalar.value.to_protobytes::<Vec<u8>>();
        let decoded = ScalarValue::from_protobytes(&encoded).unwrap();
        let rebuilt = Scalar::new(scalar.dtype().clone(), decoded);
        assert_eq!(scalar, rebuilt);
    }
}