vortex_scalar/
proto.rs

1use std::sync::Arc;
2
3use num_traits::ToBytes;
4use vortex_buffer::{BufferString, ByteBuffer};
5use vortex_dtype::DType;
6use vortex_error::{VortexError, vortex_err};
7use vortex_proto::scalar as pb;
8use vortex_proto::scalar::ListValue;
9use vortex_proto::scalar::scalar_value::Kind;
10
11use crate::pvalue::PValue;
12use crate::{DecimalValue, InnerScalarValue, Scalar, ScalarValue};
13
14impl From<&Scalar> for pb::Scalar {
15    fn from(value: &Scalar) -> Self {
16        pb::Scalar {
17            dtype: Some((&value.dtype).into()),
18            value: Some((&value.value).into()),
19        }
20    }
21}
22
23impl From<&ScalarValue> for pb::ScalarValue {
24    fn from(value: &ScalarValue) -> Self {
25        match value {
26            ScalarValue(InnerScalarValue::Null) => pb::ScalarValue {
27                kind: Some(Kind::NullValue(0)),
28            },
29            ScalarValue(InnerScalarValue::Bool(v)) => pb::ScalarValue {
30                kind: Some(Kind::BoolValue(*v)),
31            },
32            ScalarValue(InnerScalarValue::Primitive(v)) => v.into(),
33            ScalarValue(InnerScalarValue::Decimal(v)) => {
34                let inner_value = match v {
35                    DecimalValue::I8(v) => v.to_le_bytes().to_vec(),
36                    DecimalValue::I16(v) => v.to_le_bytes().to_vec(),
37                    DecimalValue::I32(v) => v.to_le_bytes().to_vec(),
38                    DecimalValue::I64(v) => v.to_le_bytes().to_vec(),
39                    DecimalValue::I128(v128) => v128.to_le_bytes().to_vec(),
40                    DecimalValue::I256(v256) => v256.to_le_bytes().to_vec(),
41                };
42
43                pb::ScalarValue {
44                    kind: Some(Kind::BytesValue(inner_value)),
45                }
46            }
47            ScalarValue(InnerScalarValue::Buffer(v)) => pb::ScalarValue {
48                kind: Some(Kind::BytesValue(v.as_slice().to_vec())),
49            },
50            ScalarValue(InnerScalarValue::BufferString(v)) => pb::ScalarValue {
51                kind: Some(Kind::StringValue(v.as_str().to_string())),
52            },
53            ScalarValue(InnerScalarValue::List(v)) => {
54                let mut values = Vec::with_capacity(v.len());
55                for elem in v.iter() {
56                    values.push(pb::ScalarValue::from(elem));
57                }
58                pb::ScalarValue {
59                    kind: Some(Kind::ListValue(ListValue { values })),
60                }
61            }
62        }
63    }
64}
65
66impl From<&PValue> for pb::ScalarValue {
67    fn from(value: &PValue) -> Self {
68        match value {
69            PValue::I8(v) => pb::ScalarValue {
70                kind: Some(Kind::Int64Value(*v as i64)),
71            },
72            PValue::I16(v) => pb::ScalarValue {
73                kind: Some(Kind::Int64Value(*v as i64)),
74            },
75            PValue::I32(v) => pb::ScalarValue {
76                kind: Some(Kind::Int64Value(*v as i64)),
77            },
78            PValue::I64(v) => pb::ScalarValue {
79                kind: Some(Kind::Int64Value(*v)),
80            },
81            PValue::U8(v) => pb::ScalarValue {
82                kind: Some(Kind::Uint64Value(*v as u64)),
83            },
84            PValue::U16(v) => pb::ScalarValue {
85                kind: Some(Kind::Uint64Value(*v as u64)),
86            },
87            PValue::U32(v) => pb::ScalarValue {
88                kind: Some(Kind::Uint64Value(*v as u64)),
89            },
90            PValue::U64(v) => pb::ScalarValue {
91                kind: Some(Kind::Uint64Value(*v)),
92            },
93            PValue::F16(v) => pb::ScalarValue {
94                kind: Some(Kind::Uint64Value(v.to_bits() as u64)),
95            },
96            PValue::F32(v) => pb::ScalarValue {
97                kind: Some(Kind::F32Value(*v)),
98            },
99            PValue::F64(v) => pb::ScalarValue {
100                kind: Some(Kind::F64Value(*v)),
101            },
102        }
103    }
104}
105
106impl TryFrom<&pb::Scalar> for Scalar {
107    type Error = VortexError;
108
109    fn try_from(value: &pb::Scalar) -> Result<Self, Self::Error> {
110        let dtype = DType::try_from(
111            value
112                .dtype
113                .as_ref()
114                .ok_or_else(|| vortex_err!(InvalidSerde: "Scalar missing dtype"))?,
115        )?;
116
117        let value = ScalarValue::try_from(
118            value
119                .value
120                .as_ref()
121                .ok_or_else(|| vortex_err!(InvalidSerde: "Scalar missing value"))?,
122        )?;
123
124        Ok(Self { dtype, value })
125    }
126}
127
128impl TryFrom<&pb::ScalarValue> for ScalarValue {
129    type Error = VortexError;
130
131    fn try_from(value: &pb::ScalarValue) -> Result<Self, Self::Error> {
132        let kind = value
133            .kind
134            .as_ref()
135            .ok_or_else(|| vortex_err!(InvalidSerde: "ScalarValue missing kind"))?;
136
137        match kind {
138            Kind::NullValue(_) => Ok(ScalarValue(InnerScalarValue::Null)),
139            Kind::BoolValue(v) => Ok(ScalarValue(InnerScalarValue::Bool(*v))),
140            Kind::Int64Value(v) => Ok(ScalarValue(InnerScalarValue::Primitive(PValue::I64(*v)))),
141            Kind::Uint64Value(v) => Ok(ScalarValue(InnerScalarValue::Primitive(PValue::U64(*v)))),
142            Kind::F32Value(v) => Ok(ScalarValue(InnerScalarValue::Primitive(PValue::F32(*v)))),
143            Kind::F64Value(v) => Ok(ScalarValue(InnerScalarValue::Primitive(PValue::F64(*v)))),
144            Kind::StringValue(v) => Ok(ScalarValue(InnerScalarValue::BufferString(Arc::new(
145                BufferString::from(v.clone()),
146            )))),
147            Kind::BytesValue(v) => Ok(ScalarValue(InnerScalarValue::Buffer(Arc::new(
148                ByteBuffer::from(v.clone()),
149            )))),
150            Kind::ListValue(v) => {
151                let mut values = Vec::with_capacity(v.values.len());
152                for elem in v.values.iter() {
153                    values.push(elem.try_into()?);
154                }
155                Ok(ScalarValue(InnerScalarValue::List(values.into())))
156            }
157        }
158    }
159}
160
#[cfg(test)]
mod test {
    use std::sync::Arc;

    use vortex_buffer::BufferString;
    use vortex_dtype::half::f16;
    use vortex_dtype::{DType, Nullability, PType};
    use vortex_proto::scalar as pb;

    use crate::{InnerScalarValue, PValue, Scalar, ScalarValue};

    /// Encodes `scalar` as a protobuf `Scalar`, decodes it again and asserts equality.
    fn round_trip(scalar: Scalar) {
        let encoded = pb::Scalar::from(&scalar);
        let decoded = Scalar::try_from(&encoded).unwrap();
        assert_eq!(scalar, decoded);
    }

    /// Builds a scalar from `dtype` and a raw inner value, then round-trips it.
    fn round_trip_inner(dtype: DType, inner: InnerScalarValue) {
        round_trip(Scalar::new(dtype, ScalarValue(inner)));
    }

    #[test]
    fn test_null() {
        round_trip(Scalar::null(DType::Null));
    }

    #[test]
    fn test_bool() {
        round_trip_inner(
            DType::Bool(Nullability::Nullable),
            InnerScalarValue::Bool(true),
        );
    }

    #[test]
    fn test_primitive() {
        round_trip_inner(
            DType::Primitive(PType::I32, Nullability::Nullable),
            InnerScalarValue::Primitive(42i32.into()),
        );
    }

    #[test]
    fn test_buffer() {
        round_trip_inner(
            DType::Binary(Nullability::Nullable),
            InnerScalarValue::Buffer(Arc::new(vec![1, 2, 3].into())),
        );
    }

    #[test]
    fn test_buffer_string() {
        round_trip_inner(
            DType::Utf8(Nullability::Nullable),
            InnerScalarValue::BufferString(Arc::new(BufferString::from("hello".to_string()))),
        );
    }

    #[test]
    fn test_list() {
        let element_dtype = Arc::new(DType::Primitive(PType::I32, Nullability::Nullable));
        let elements = vec![
            ScalarValue(InnerScalarValue::Primitive(42i32.into())),
            ScalarValue(InnerScalarValue::Primitive(43i32.into())),
        ];
        round_trip_inner(
            DType::List(element_dtype, Nullability::Nullable),
            InnerScalarValue::List(elements.into()),
        );
    }

    #[test]
    fn test_f16() {
        round_trip_inner(
            DType::Primitive(PType::F16, Nullability::Nullable),
            InnerScalarValue::Primitive(PValue::F16(f16::from_f32(0.42))),
        );
    }

    #[test]
    fn test_i8() {
        // Both extremes plus zero, in the same order as before.
        for v in [i8::MIN, 0i8, i8::MAX] {
            round_trip_inner(
                DType::Primitive(PType::I8, Nullability::Nullable),
                InnerScalarValue::Primitive(v.into()),
            );
        }
    }
}
264
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_dtype::half::f16;
    use vortex_dtype::{DType, DecimalDType, FieldDType, Nullability, PType, StructFields};

    use super::*;
    use crate::{Scalar, i256};

    // Serializes only the value half of each scalar to protobuf bytes and decodes it. It then
    // checks that re-attaching the original dtype yields a scalar equal to the input.
    #[rstest]
    #[case(Scalar::binary(ByteBuffer::copy_from(b"hello"), Nullability::NonNullable))]
    #[case(Scalar::utf8("hello", Nullability::NonNullable))]
    #[case(Scalar::primitive(1u8, Nullability::NonNullable))]
    #[case(Scalar::primitive(
        f32::from_bits(u32::from_le_bytes([0xFFu8, 0x8A, 0xF9, 0xFF])),
        Nullability::NonNullable
    ))]
    #[case(Scalar::list(
        Arc::new(PType::U8.into()),
        vec![Scalar::primitive(1u8, Nullability::NonNullable)],
        Nullability::NonNullable,
    ))]
    #[case(Scalar::struct_(
        DType::Struct(
            Arc::new(StructFields::from_iter([
                ("a", FieldDType::from(DType::Primitive(PType::U32, Nullability::NonNullable))),
                ("b", FieldDType::from(DType::Primitive(PType::F16, Nullability::NonNullable))),
            ])),
            Nullability::NonNullable,
        ),
        vec![
            Scalar::primitive(23592960, Nullability::NonNullable),
            Scalar::primitive(f16::from_f32(2.6584664e36f32), Nullability::NonNullable),
        ],
    ))]
    #[case(Scalar::struct_(
        DType::Struct(
            Arc::new(StructFields::from_iter([
                ("a", FieldDType::from(DType::Primitive(PType::U64, Nullability::NonNullable))),
                ("b", FieldDType::from(DType::Primitive(PType::F32, Nullability::NonNullable))),
                ("c", FieldDType::from(DType::Primitive(PType::F16, Nullability::NonNullable))),
            ])),
            Nullability::NonNullable,
        ),
        vec![
            Scalar::primitive(415118687234i64, Nullability::NonNullable),
            Scalar::primitive(2.6584664e36f32, Nullability::NonNullable),
            Scalar::primitive(f16::from_f32(2.6584664e36f32), Nullability::NonNullable),
        ],
    ))]
    #[case(Scalar::decimal(
        DecimalValue::I256(i256::from_i128(12345643673471)),
        DecimalDType::new(10, 2),
        Nullability::NonNullable
    ))]
    #[case(Scalar::decimal(
        DecimalValue::I16(23412),
        DecimalDType::new(3, 2),
        Nullability::NonNullable
    ))]
    fn test_scalar_value_serde_roundtrip(#[case] scalar: Scalar) {
        let encoded = scalar.value.to_protobytes::<Vec<u8>>();
        let decoded = ScalarValue::from_protobytes(&encoded).unwrap();
        let rebuilt = Scalar::new(scalar.dtype().clone(), decoded);
        assert_eq!(scalar, rebuilt);
    }
}