1use std::sync::Arc;
5
6use num_traits::ToBytes;
7use vortex_buffer::{BufferString, ByteBuffer};
8use vortex_dtype::DType;
9use vortex_dtype::half::f16;
10use vortex_error::{VortexError, vortex_err};
11use vortex_proto::scalar as pb;
12use vortex_proto::scalar::ListValue;
13use vortex_proto::scalar::scalar_value::Kind;
14
15use crate::pvalue::PValue;
16use crate::{DecimalValue, InnerScalarValue, Scalar, ScalarValue};
17
18impl From<&Scalar> for pb::Scalar {
19 fn from(value: &Scalar) -> Self {
20 pb::Scalar {
21 dtype: Some((value.dtype()).into()),
22 value: Some((value.value()).into()),
23 }
24 }
25}
26
27impl From<&ScalarValue> for pb::ScalarValue {
28 fn from(value: &ScalarValue) -> Self {
29 match value {
30 ScalarValue(InnerScalarValue::Null) => pb::ScalarValue {
31 kind: Some(Kind::NullValue(0)),
32 },
33 ScalarValue(InnerScalarValue::Bool(v)) => pb::ScalarValue {
34 kind: Some(Kind::BoolValue(*v)),
35 },
36 ScalarValue(InnerScalarValue::Primitive(v)) => v.into(),
37 ScalarValue(InnerScalarValue::Decimal(v)) => {
38 let inner_value = match v {
39 DecimalValue::I8(v) => v.to_le_bytes().to_vec(),
40 DecimalValue::I16(v) => v.to_le_bytes().to_vec(),
41 DecimalValue::I32(v) => v.to_le_bytes().to_vec(),
42 DecimalValue::I64(v) => v.to_le_bytes().to_vec(),
43 DecimalValue::I128(v128) => v128.to_le_bytes().to_vec(),
44 DecimalValue::I256(v256) => v256.to_le_bytes().to_vec(),
45 };
46
47 pb::ScalarValue {
48 kind: Some(Kind::BytesValue(inner_value)),
49 }
50 }
51 ScalarValue(InnerScalarValue::Buffer(v)) => pb::ScalarValue {
52 kind: Some(Kind::BytesValue(v.as_slice().to_vec())),
53 },
54 ScalarValue(InnerScalarValue::BufferString(v)) => pb::ScalarValue {
55 kind: Some(Kind::StringValue(v.as_str().to_string())),
56 },
57 ScalarValue(InnerScalarValue::List(v)) => {
58 let mut values = Vec::with_capacity(v.len());
59 for elem in v.iter() {
60 values.push(pb::ScalarValue::from(elem));
61 }
62 pb::ScalarValue {
63 kind: Some(Kind::ListValue(ListValue { values })),
64 }
65 }
66 }
67 }
68}
69
70impl From<&PValue> for pb::ScalarValue {
71 fn from(value: &PValue) -> Self {
72 match value {
73 PValue::I8(v) => pb::ScalarValue {
74 kind: Some(Kind::Int64Value(*v as i64)),
75 },
76 PValue::I16(v) => pb::ScalarValue {
77 kind: Some(Kind::Int64Value(*v as i64)),
78 },
79 PValue::I32(v) => pb::ScalarValue {
80 kind: Some(Kind::Int64Value(*v as i64)),
81 },
82 PValue::I64(v) => pb::ScalarValue {
83 kind: Some(Kind::Int64Value(*v)),
84 },
85 PValue::U8(v) => pb::ScalarValue {
86 kind: Some(Kind::Uint64Value(*v as u64)),
87 },
88 PValue::U16(v) => pb::ScalarValue {
89 kind: Some(Kind::Uint64Value(*v as u64)),
90 },
91 PValue::U32(v) => pb::ScalarValue {
92 kind: Some(Kind::Uint64Value(*v as u64)),
93 },
94 PValue::U64(v) => pb::ScalarValue {
95 kind: Some(Kind::Uint64Value(*v)),
96 },
97 PValue::F16(v) => pb::ScalarValue {
98 kind: Some(Kind::F16Value(v.to_bits() as u64)),
99 },
100 PValue::F32(v) => pb::ScalarValue {
101 kind: Some(Kind::F32Value(*v)),
102 },
103 PValue::F64(v) => pb::ScalarValue {
104 kind: Some(Kind::F64Value(*v)),
105 },
106 }
107 }
108}
109
110impl TryFrom<&pb::Scalar> for Scalar {
111 type Error = VortexError;
112
113 fn try_from(value: &pb::Scalar) -> Result<Self, Self::Error> {
114 let dtype = DType::try_from(
115 value
116 .dtype
117 .as_ref()
118 .ok_or_else(|| vortex_err!(InvalidSerde: "Scalar missing dtype"))?,
119 )?;
120
121 let value = ScalarValue::try_from(
122 value
123 .value
124 .as_ref()
125 .ok_or_else(|| vortex_err!(InvalidSerde: "Scalar missing value"))?,
126 )?;
127
128 Ok(Scalar::new(dtype, value))
129 }
130}
131
132impl TryFrom<&pb::ScalarValue> for ScalarValue {
133 type Error = VortexError;
134
135 fn try_from(value: &pb::ScalarValue) -> Result<Self, Self::Error> {
136 let kind = value
137 .kind
138 .as_ref()
139 .ok_or_else(|| vortex_err!(InvalidSerde: "ScalarValue missing kind"))?;
140
141 match kind {
142 Kind::NullValue(_) => Ok(ScalarValue(InnerScalarValue::Null)),
143 Kind::BoolValue(v) => Ok(ScalarValue(InnerScalarValue::Bool(*v))),
144 Kind::Int64Value(v) => Ok(ScalarValue(InnerScalarValue::Primitive(PValue::I64(*v)))),
145 Kind::Uint64Value(v) => Ok(ScalarValue(InnerScalarValue::Primitive(PValue::U64(*v)))),
146 Kind::F16Value(v) => Ok(ScalarValue(InnerScalarValue::Primitive(PValue::F16(
147 f16::from_bits(u16::try_from(*v).map_err(|_| {
148 vortex_err!("f16 bitwise representation has more than 16 bits: {}", v)
149 })?),
150 )))),
151 Kind::F32Value(v) => Ok(ScalarValue(InnerScalarValue::Primitive(PValue::F32(*v)))),
152 Kind::F64Value(v) => Ok(ScalarValue(InnerScalarValue::Primitive(PValue::F64(*v)))),
153 Kind::StringValue(v) => Ok(ScalarValue(InnerScalarValue::BufferString(Arc::new(
154 BufferString::from(v.clone()),
155 )))),
156 Kind::BytesValue(v) => Ok(ScalarValue(InnerScalarValue::Buffer(Arc::new(
157 ByteBuffer::from(v.clone()),
158 )))),
159 Kind::ListValue(v) => {
160 let mut values = Vec::with_capacity(v.values.len());
161 for elem in v.values.iter() {
162 values.push(elem.try_into()?);
163 }
164 Ok(ScalarValue(InnerScalarValue::List(values.into())))
165 }
166 }
167 }
168}
169
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use rstest::rstest;
    use vortex_buffer::BufferString;
    use vortex_dtype::half::f16;
    use vortex_dtype::{DType, DecimalDType, FieldDType, Nullability, PType, StructFields, i256};
    use vortex_error::vortex_panic;
    use vortex_proto::scalar as pb;

    use super::*;
    use crate::{InnerScalarValue, Scalar, ScalarValue};

    /// Converts `scalar` to its protobuf message and back, asserting the result equals the input.
    fn round_trip(scalar: Scalar) {
        let encoded = pb::Scalar::from(&scalar);
        let decoded = Scalar::try_from(&encoded).unwrap();
        assert_eq!(scalar, decoded);
    }

    /// Serializes `value` to protobuf bytes and immediately parses those bytes back.
    fn protobytes_round_trip(value: &ScalarValue) -> ScalarValue {
        let written = value.to_protobytes::<Vec<u8>>();
        ScalarValue::from_protobytes(&written).unwrap()
    }

    #[test]
    fn test_null() {
        round_trip(Scalar::null(DType::Null));
    }

    #[test]
    fn test_bool() {
        let dtype = DType::Bool(Nullability::Nullable);
        let value = ScalarValue(InnerScalarValue::Bool(true));
        round_trip(Scalar::new(dtype, value));
    }

    #[test]
    fn test_primitive() {
        let dtype = DType::Primitive(PType::I32, Nullability::Nullable);
        let value = ScalarValue(InnerScalarValue::Primitive(42i32.into()));
        round_trip(Scalar::new(dtype, value));
    }

    #[test]
    fn test_buffer() {
        let dtype = DType::Binary(Nullability::Nullable);
        let value = ScalarValue(InnerScalarValue::Buffer(Arc::new(vec![1, 2, 3].into())));
        round_trip(Scalar::new(dtype, value));
    }

    #[test]
    fn test_buffer_string() {
        let dtype = DType::Utf8(Nullability::Nullable);
        let value = ScalarValue(InnerScalarValue::BufferString(Arc::new(
            BufferString::from("hello".to_string()),
        )));
        round_trip(Scalar::new(dtype, value));
    }

    #[test]
    fn test_list() {
        let dtype = DType::List(
            Arc::new(DType::Primitive(PType::I32, Nullability::Nullable)),
            Nullability::Nullable,
        );
        let value = ScalarValue(InnerScalarValue::List(
            vec![
                ScalarValue(InnerScalarValue::Primitive(42i32.into())),
                ScalarValue(InnerScalarValue::Primitive(43i32.into())),
            ]
            .into(),
        ));
        round_trip(Scalar::new(dtype, value));
    }

    #[test]
    fn test_f16() {
        let scalar = Scalar::primitive(f16::from_f32(0.42), Nullability::Nullable);
        round_trip(scalar);
    }

    #[test]
    fn test_i8() {
        // Both extremes and zero of the narrowest signed type.
        for v in [i8::MIN, 0i8, i8::MAX] {
            round_trip(Scalar::new(
                DType::Primitive(PType::I8, Nullability::Nullable),
                ScalarValue(InnerScalarValue::Primitive(v.into())),
            ));
        }
    }

    #[rstest]
    #[case(Scalar::binary(ByteBuffer::copy_from(b"hello"), Nullability::NonNullable))]
    #[case(Scalar::utf8("hello", Nullability::NonNullable))]
    #[case(Scalar::primitive(1u8, Nullability::NonNullable))]
    #[case(Scalar::primitive(
        f32::from_bits(u32::from_le_bytes([0xFFu8, 0x8A, 0xF9, 0xFF])),
        Nullability::NonNullable
    ))]
    #[case(Scalar::list(
        Arc::new(PType::U8.into()),
        vec![Scalar::primitive(1u8, Nullability::NonNullable)],
        Nullability::NonNullable
    ))]
    #[case(Scalar::struct_(
        DType::Struct(
            StructFields::from_iter([
                ("a", FieldDType::from(DType::Primitive(PType::U32, Nullability::NonNullable))),
                ("b", FieldDType::from(DType::Primitive(PType::F16, Nullability::NonNullable))),
            ]),
            Nullability::NonNullable
        ),
        vec![
            Scalar::primitive(23592960u32, Nullability::NonNullable),
            Scalar::primitive(f16::from_f32(2.6584664e36f32), Nullability::NonNullable),
        ],
    ))]
    #[case(Scalar::struct_(
        DType::Struct(
            StructFields::from_iter([
                ("a", FieldDType::from(DType::Primitive(PType::U64, Nullability::NonNullable))),
                ("b", FieldDType::from(DType::Primitive(PType::F32, Nullability::NonNullable))),
                ("c", FieldDType::from(DType::Primitive(PType::F16, Nullability::NonNullable))),
            ]),
            Nullability::NonNullable
        ),
        vec![
            Scalar::primitive(415118687234u64, Nullability::NonNullable),
            Scalar::primitive(2.6584664e36f32, Nullability::NonNullable),
            Scalar::primitive(f16::from_f32(2.6584664e36f32), Nullability::NonNullable),
        ],
    ))]
    #[case(Scalar::decimal(
        DecimalValue::I256(i256::from_i128(12345643673471)),
        DecimalDType::new(10, 2),
        Nullability::NonNullable
    ))]
    #[case(Scalar::decimal(
        DecimalValue::I16(23412),
        DecimalDType::new(3, 2),
        Nullability::NonNullable
    ))]
    fn test_scalar_value_serde_roundtrip(#[case] scalar: Scalar) {
        // Only the value travels through protobuf bytes; the dtype is re-attached afterwards.
        let scalar_read_back = protobytes_round_trip(scalar.value());
        let rebuilt = Scalar::new(scalar.dtype().clone(), scalar_read_back);
        assert_eq!(rebuilt, scalar);
    }

    #[test]
    fn test_backcompat_f16_serialized_as_u64() {
        // A message that carries the f16 bit pattern in the generic unsigned-integer kind.
        let f16_bits = f16::from_f32(0.42).to_bits();
        let pb_scalar_value = pb::ScalarValue {
            kind: Some(Kind::Uint64Value(f16_bits as u64)),
        };

        // Decoding on its own yields a plain u64 holding the raw bits.
        let scalar_value = ScalarValue::try_from(&pb_scalar_value).unwrap();
        assert_eq!(
            scalar_value.as_pvalue().unwrap(),
            Some(PValue::U64(14008u64))
        );

        // Once paired with an F16 dtype, the primitive view reports the original f16.
        let scalar = Scalar::new(
            DType::Primitive(PType::F16, Nullability::Nullable),
            scalar_value,
        );
        assert_eq!(
            scalar.as_primitive().pvalue().unwrap(),
            PValue::F16(f16::from_f32(0.42))
        );
    }

    #[test]
    fn test_scalar_value_direct_roundtrip_f16() {
        let f16_values = [
            f16::from_f32(0.0),
            f16::from_f32(1.0),
            f16::from_f32(-1.0),
            f16::from_f32(0.42),
            f16::from_f32(5.722046e-6),
            f16::from_f32(std::f32::consts::PI),
            f16::INFINITY,
            f16::NEG_INFINITY,
            f16::NAN,
        ];

        for f16_val in f16_values {
            let scalar_value = ScalarValue(InnerScalarValue::Primitive(PValue::F16(f16_val)));
            let read_back = protobytes_round_trip(&scalar_value);

            match (&scalar_value.0, &read_back.0) {
                (
                    InnerScalarValue::Primitive(PValue::F16(original)),
                    InnerScalarValue::Primitive(PValue::F16(roundtripped)),
                ) => {
                    // NaN never compares equal to itself, so NaN -> NaN is accepted as-is.
                    let both_nan = original.is_nan() && roundtripped.is_nan();
                    if !both_nan {
                        assert_eq!(
                            original, roundtripped,
                            "F16 value {original:?} did not roundtrip correctly"
                        );
                    }
                }
                _ => vortex_panic!(
                    "Expected f16 primitive values, got {scalar_value:?} and {read_back:?}"
                ),
            }
        }
    }

    #[test]
    fn test_scalar_value_direct_roundtrip_preserves_values() {
        // Values whose debug representation must be identical after the round trip.
        let exact_roundtrip_cases = [
            ("null", ScalarValue(InnerScalarValue::Null)),
            ("bool_true", ScalarValue(InnerScalarValue::Bool(true))),
            ("bool_false", ScalarValue(InnerScalarValue::Bool(false))),
            (
                "u64",
                ScalarValue(InnerScalarValue::Primitive(PValue::U64(u64::MAX))),
            ),
            (
                "i64",
                ScalarValue(InnerScalarValue::Primitive(PValue::I64(i64::MIN))),
            ),
            (
                "f32",
                ScalarValue(InnerScalarValue::Primitive(PValue::F32(
                    std::f32::consts::E,
                ))),
            ),
            (
                "f64",
                ScalarValue(InnerScalarValue::Primitive(PValue::F64(
                    std::f64::consts::PI,
                ))),
            ),
            (
                "string",
                ScalarValue(InnerScalarValue::BufferString(Arc::new(
                    BufferString::from("test"),
                ))),
            ),
            (
                "bytes",
                ScalarValue(InnerScalarValue::Buffer(Arc::new(
                    vec![1, 2, 3, 4, 5].into(),
                ))),
            ),
        ];

        for (name, value) in exact_roundtrip_cases {
            let read_back = protobytes_round_trip(&value);
            assert_eq!(
                format!("{value:?}"),
                format!("{read_back:?}"),
                "ScalarValue {name} did not roundtrip exactly"
            );
        }

        // Narrow unsigned integers come back as U64 with the same numeric value.
        let unsigned_cases = [
            (
                "u8",
                ScalarValue(InnerScalarValue::Primitive(PValue::U8(u8::MAX))),
                u64::from(u8::MAX),
            ),
            (
                "u16",
                ScalarValue(InnerScalarValue::Primitive(PValue::U16(u16::MAX))),
                u64::from(u16::MAX),
            ),
            (
                "u32",
                ScalarValue(InnerScalarValue::Primitive(PValue::U32(u32::MAX))),
                u64::from(u32::MAX),
            ),
        ];

        for (name, value, expected) in unsigned_cases {
            let read_back = protobytes_round_trip(&value);

            if let InnerScalarValue::Primitive(PValue::U64(v)) = &read_back.0 {
                assert_eq!(
                    *v, expected,
                    "ScalarValue {name} value not preserved: expected {expected}, got {v}"
                );
            } else {
                vortex_panic!("Unexpected type after roundtrip for {name}: {read_back:?}")
            }
        }

        // Narrow signed integers come back as I64 with the same numeric value.
        let signed_cases = [
            (
                "i8",
                ScalarValue(InnerScalarValue::Primitive(PValue::I8(i8::MIN))),
                i64::from(i8::MIN),
            ),
            (
                "i16",
                ScalarValue(InnerScalarValue::Primitive(PValue::I16(i16::MIN))),
                i64::from(i16::MIN),
            ),
            (
                "i32",
                ScalarValue(InnerScalarValue::Primitive(PValue::I32(i32::MIN))),
                i64::from(i32::MIN),
            ),
        ];

        for (name, value, expected) in signed_cases {
            let read_back = protobytes_round_trip(&value);

            if let InnerScalarValue::Primitive(PValue::I64(v)) = &read_back.0 {
                assert_eq!(
                    *v, expected,
                    "ScalarValue {name} value not preserved: expected {expected}, got {v}"
                );
            } else {
                vortex_panic!("Unexpected type after roundtrip for {name}: {read_back:?}")
            }
        }
    }
}