1use std::sync::Arc;
2
3use num_traits::ToBytes;
4use vortex_buffer::{BufferString, ByteBuffer};
5use vortex_dtype::DType;
6use vortex_error::{VortexError, vortex_err};
7use vortex_proto::scalar as pb;
8use vortex_proto::scalar::ListValue;
9use vortex_proto::scalar::scalar_value::Kind;
10
11use crate::pvalue::PValue;
12use crate::{DecimalValue, InnerScalarValue, Scalar, ScalarValue};
13
14impl From<&Scalar> for pb::Scalar {
15 fn from(value: &Scalar) -> Self {
16 pb::Scalar {
17 dtype: Some((&value.dtype).into()),
18 value: Some((&value.value).into()),
19 }
20 }
21}
22
23impl From<&ScalarValue> for pb::ScalarValue {
24 fn from(value: &ScalarValue) -> Self {
25 match value {
26 ScalarValue(InnerScalarValue::Null) => pb::ScalarValue {
27 kind: Some(Kind::NullValue(0)),
28 },
29 ScalarValue(InnerScalarValue::Bool(v)) => pb::ScalarValue {
30 kind: Some(Kind::BoolValue(*v)),
31 },
32 ScalarValue(InnerScalarValue::Primitive(v)) => v.into(),
33 ScalarValue(InnerScalarValue::Decimal(v)) => {
34 let inner_value = match v {
35 DecimalValue::I8(v) => v.to_le_bytes().to_vec(),
36 DecimalValue::I16(v) => v.to_le_bytes().to_vec(),
37 DecimalValue::I32(v) => v.to_le_bytes().to_vec(),
38 DecimalValue::I64(v) => v.to_le_bytes().to_vec(),
39 DecimalValue::I128(v128) => v128.to_le_bytes().to_vec(),
40 DecimalValue::I256(v256) => v256.to_le_bytes().to_vec(),
41 };
42
43 pb::ScalarValue {
44 kind: Some(Kind::BytesValue(inner_value)),
45 }
46 }
47 ScalarValue(InnerScalarValue::Buffer(v)) => pb::ScalarValue {
48 kind: Some(Kind::BytesValue(v.as_slice().to_vec())),
49 },
50 ScalarValue(InnerScalarValue::BufferString(v)) => pb::ScalarValue {
51 kind: Some(Kind::StringValue(v.as_str().to_string())),
52 },
53 ScalarValue(InnerScalarValue::List(v)) => {
54 let mut values = Vec::with_capacity(v.len());
55 for elem in v.iter() {
56 values.push(pb::ScalarValue::from(elem));
57 }
58 pb::ScalarValue {
59 kind: Some(Kind::ListValue(ListValue { values })),
60 }
61 }
62 }
63 }
64}
65
66impl From<&PValue> for pb::ScalarValue {
67 fn from(value: &PValue) -> Self {
68 match value {
69 PValue::I8(v) => pb::ScalarValue {
70 kind: Some(Kind::Int64Value(*v as i64)),
71 },
72 PValue::I16(v) => pb::ScalarValue {
73 kind: Some(Kind::Int64Value(*v as i64)),
74 },
75 PValue::I32(v) => pb::ScalarValue {
76 kind: Some(Kind::Int64Value(*v as i64)),
77 },
78 PValue::I64(v) => pb::ScalarValue {
79 kind: Some(Kind::Int64Value(*v)),
80 },
81 PValue::U8(v) => pb::ScalarValue {
82 kind: Some(Kind::Uint64Value(*v as u64)),
83 },
84 PValue::U16(v) => pb::ScalarValue {
85 kind: Some(Kind::Uint64Value(*v as u64)),
86 },
87 PValue::U32(v) => pb::ScalarValue {
88 kind: Some(Kind::Uint64Value(*v as u64)),
89 },
90 PValue::U64(v) => pb::ScalarValue {
91 kind: Some(Kind::Uint64Value(*v)),
92 },
93 PValue::F16(v) => pb::ScalarValue {
94 kind: Some(Kind::Uint64Value(v.to_bits() as u64)),
95 },
96 PValue::F32(v) => pb::ScalarValue {
97 kind: Some(Kind::F32Value(*v)),
98 },
99 PValue::F64(v) => pb::ScalarValue {
100 kind: Some(Kind::F64Value(*v)),
101 },
102 }
103 }
104}
105
106impl TryFrom<&pb::Scalar> for Scalar {
107 type Error = VortexError;
108
109 fn try_from(value: &pb::Scalar) -> Result<Self, Self::Error> {
110 let dtype = DType::try_from(
111 value
112 .dtype
113 .as_ref()
114 .ok_or_else(|| vortex_err!(InvalidSerde: "Scalar missing dtype"))?,
115 )?;
116
117 let value = ScalarValue::try_from(
118 value
119 .value
120 .as_ref()
121 .ok_or_else(|| vortex_err!(InvalidSerde: "Scalar missing value"))?,
122 )?;
123
124 Ok(Self { dtype, value })
125 }
126}
127
128impl TryFrom<&pb::ScalarValue> for ScalarValue {
129 type Error = VortexError;
130
131 fn try_from(value: &pb::ScalarValue) -> Result<Self, Self::Error> {
132 let kind = value
133 .kind
134 .as_ref()
135 .ok_or_else(|| vortex_err!(InvalidSerde: "ScalarValue missing kind"))?;
136
137 match kind {
138 Kind::NullValue(_) => Ok(ScalarValue(InnerScalarValue::Null)),
139 Kind::BoolValue(v) => Ok(ScalarValue(InnerScalarValue::Bool(*v))),
140 Kind::Int64Value(v) => Ok(ScalarValue(InnerScalarValue::Primitive(PValue::I64(*v)))),
141 Kind::Uint64Value(v) => Ok(ScalarValue(InnerScalarValue::Primitive(PValue::U64(*v)))),
142 Kind::F32Value(v) => Ok(ScalarValue(InnerScalarValue::Primitive(PValue::F32(*v)))),
143 Kind::F64Value(v) => Ok(ScalarValue(InnerScalarValue::Primitive(PValue::F64(*v)))),
144 Kind::StringValue(v) => Ok(ScalarValue(InnerScalarValue::BufferString(Arc::new(
145 BufferString::from(v.clone()),
146 )))),
147 Kind::BytesValue(v) => Ok(ScalarValue(InnerScalarValue::Buffer(Arc::new(
148 ByteBuffer::from(v.clone()),
149 )))),
150 Kind::ListValue(v) => {
151 let mut values = Vec::with_capacity(v.values.len());
152 for elem in v.values.iter() {
153 values.push(elem.try_into()?);
154 }
155 Ok(ScalarValue(InnerScalarValue::List(values.into())))
156 }
157 }
158 }
159}
160
// Round-trip tests for the `Scalar` <-> `pb::Scalar` conversions defined in this file.
#[cfg(test)]
mod test {
    use std::sync::Arc;

    use vortex_buffer::BufferString;
    use vortex_dtype::PType::{self, I32};
    use vortex_dtype::half::f16;
    use vortex_dtype::{DType, Nullability};
    use vortex_proto::scalar as pb;

    use crate::{InnerScalarValue, PValue, Scalar, ScalarValue};

    // Encodes `scalar` with `pb::Scalar::from`, decodes it with `Scalar::try_from`
    // (unwrapping, so a decode error panics), and asserts the result equals the input.
    fn round_trip(scalar: Scalar) {
        assert_eq!(
            scalar,
            Scalar::try_from(&pb::Scalar::from(&scalar)).unwrap(),
        );
    }

    // Null is written as `Kind::NullValue(0)` and read back as `InnerScalarValue::Null`.
    #[test]
    fn test_null() {
        round_trip(Scalar::null(DType::Null));
    }

    #[test]
    fn test_bool() {
        round_trip(Scalar::new(
            DType::Bool(Nullability::Nullable),
            ScalarValue(InnerScalarValue::Bool(true)),
        ));
    }

    // I32 is widened to `Kind::Int64Value` on the wire and decoded as `PValue::I64`.
    // NOTE(review): equality here presumably relies on the I32 dtype; confirm.
    #[test]
    fn test_primitive() {
        round_trip(Scalar::new(
            DType::Primitive(I32, Nullability::Nullable),
            ScalarValue(InnerScalarValue::Primitive(42i32.into())),
        ));
    }

    #[test]
    fn test_buffer() {
        round_trip(Scalar::new(
            DType::Binary(Nullability::Nullable),
            ScalarValue(InnerScalarValue::Buffer(Arc::new(vec![1, 2, 3].into()))),
        ));
    }

    #[test]
    fn test_buffer_string() {
        round_trip(Scalar::new(
            DType::Utf8(Nullability::Nullable),
            ScalarValue(InnerScalarValue::BufferString(Arc::new(
                BufferString::from("hello".to_string()),
            ))),
        ));
    }

    // Lists are encoded element by element as `Kind::ListValue`.
    #[test]
    fn test_list() {
        round_trip(Scalar::new(
            DType::List(
                Arc::new(DType::Primitive(I32, Nullability::Nullable)),
                Nullability::Nullable,
            ),
            ScalarValue(InnerScalarValue::List(
                vec![
                    ScalarValue(InnerScalarValue::Primitive(42i32.into())),
                    ScalarValue(InnerScalarValue::Primitive(43i32.into())),
                ]
                .into(),
            )),
        ));
    }

    // F16 is written as its raw bits in `Kind::Uint64Value` and decoded as `PValue::U64`.
    // NOTE(review): the equality presumably relies on the F16 dtype to reinterpret
    // those bits; confirm.
    #[test]
    fn test_f16() {
        round_trip(Scalar::new(
            DType::Primitive(PType::F16, Nullability::Nullable),
            ScalarValue(InnerScalarValue::Primitive(PValue::F16(f16::from_f32(
                0.42,
            )))),
        ));
    }

    // I8 boundary values (MIN, 0, MAX), which go through the `Int64Value` widening.
    #[test]
    fn test_i8() {
        round_trip(Scalar::new(
            DType::Primitive(PType::I8, Nullability::Nullable),
            ScalarValue(InnerScalarValue::Primitive(i8::MIN.into())),
        ));

        round_trip(Scalar::new(
            DType::Primitive(PType::I8, Nullability::Nullable),
            ScalarValue(InnerScalarValue::Primitive(0i8.into())),
        ));

        round_trip(Scalar::new(
            DType::Primitive(PType::I8, Nullability::Nullable),
            ScalarValue(InnerScalarValue::Primitive(i8::MAX.into())),
        ));
    }
}
264
// Value-only round-trip tests: each case serializes just `scalar.value` with
// `to_protobytes`, reads it back with `ScalarValue::from_protobytes`, and re-attaches
// the original dtype before comparing. Both helpers are defined outside this file;
// presumably they are protobuf byte-encoding helpers (confirm).
#[cfg(test)]
mod tests {
    use half::f16;
    use rstest::rstest;
    use vortex_dtype::{DType, DecimalDType, FieldDType, Nullability, PType, StructFields, half};

    use super::*;
    use crate::{Scalar, i256};

    #[rstest]
    #[case(Scalar::binary(ByteBuffer::copy_from(b"hello"), Nullability::NonNullable))]
    #[case(Scalar::utf8("hello", Nullability::NonNullable))]
    #[case(Scalar::primitive(1u8, Nullability::NonNullable))]
    // f32 built from the raw bits 0xFFF98AFF. That bit pattern is a NaN, so this
    // presumably checks that an unusual NaN payload still compares equal after decoding.
    #[case(Scalar::primitive(
        f32::from_bits(u32::from_le_bytes([0xFFu8, 0x8A, 0xF9, 0xFF])),
        Nullability::NonNullable
    ))]
    #[case(Scalar::list(
        Arc::new(PType::U8.into()),
        vec![Scalar::primitive(1u8, Nullability::NonNullable)],
        Nullability::NonNullable
    ))]
    // Struct cases pair integer values with U32/U64 field dtypes and include an f16 built
    // from a value far outside f16's finite range (presumably infinity after conversion).
    #[case(Scalar::struct_(DType::Struct(
        Arc::new(StructFields::from_iter([
            ("a", FieldDType::from(DType::Primitive(PType::U32, Nullability::NonNullable))),
            ("b", FieldDType::from(DType::Primitive(PType::F16, Nullability::NonNullable))),
        ])),
        Nullability::NonNullable),
        vec![
            Scalar::primitive(23592960, Nullability::NonNullable),
            Scalar::primitive(f16::from_f32(2.6584664e36f32), Nullability::NonNullable),
        ],
    ))]
    // Here the first field is declared U64 but its value is an explicit i64.
    #[case(Scalar::struct_(DType::Struct(
        Arc::new(StructFields::from_iter([
            ("a", FieldDType::from(DType::Primitive(PType::U64, Nullability::NonNullable))),
            ("b", FieldDType::from(DType::Primitive(PType::F32, Nullability::NonNullable))),
            ("c", FieldDType::from(DType::Primitive(PType::F16, Nullability::NonNullable))),
        ])),
        Nullability::NonNullable),
        vec![
            Scalar::primitive(415118687234i64, Nullability::NonNullable),
            Scalar::primitive(2.6584664e36f32, Nullability::NonNullable),
            Scalar::primitive(f16::from_f32(2.6584664e36f32), Nullability::NonNullable),
        ],
    ))]
    // Decimals are written as little-endian bytes (`Kind::BytesValue`) and read back as a
    // byte buffer. NOTE(review): equality presumably relies on the decimal dtype; confirm.
    #[case(Scalar::decimal(
        DecimalValue::I256(i256::from_i128(12345643673471)),
        DecimalDType::new(10, 2),
        Nullability::NonNullable
    ))]
    #[case(Scalar::decimal(
        DecimalValue::I16(23412),
        DecimalDType::new(3, 2),
        Nullability::NonNullable
    ))]
    fn test_scalar_value_serde_roundtrip(#[case] scalar: Scalar) {
        let written = scalar.value.to_protobytes::<Vec<u8>>();
        let scalar_read_back = ScalarValue::from_protobytes(&written).unwrap();
        assert_eq!(
            scalar,
            Scalar::new(scalar.dtype().clone(), scalar_read_back)
        );
    }
}