1use std::sync::Arc;
5
6use num_traits::ToBytes;
7use vortex_buffer::BufferString;
8use vortex_buffer::ByteBuffer;
9use vortex_dtype::DType;
10use vortex_dtype::half::f16;
11use vortex_error::VortexError;
12use vortex_error::vortex_err;
13use vortex_proto::scalar as pb;
14use vortex_proto::scalar::ListValue;
15use vortex_proto::scalar::scalar_value::Kind;
16
17use crate::DecimalValue;
18use crate::InnerScalarValue;
19use crate::Scalar;
20use crate::ScalarValue;
21use crate::pvalue::PValue;
22
23impl From<&Scalar> for pb::Scalar {
24 fn from(value: &Scalar) -> Self {
25 pb::Scalar {
26 dtype: Some((value.dtype()).into()),
27 value: Some((value.value()).into()),
28 }
29 }
30}
31
32impl From<&ScalarValue> for pb::ScalarValue {
33 fn from(value: &ScalarValue) -> Self {
34 match value {
35 ScalarValue(InnerScalarValue::Null) => pb::ScalarValue {
36 kind: Some(Kind::NullValue(0)),
37 },
38 ScalarValue(InnerScalarValue::Bool(v)) => pb::ScalarValue {
39 kind: Some(Kind::BoolValue(*v)),
40 },
41 ScalarValue(InnerScalarValue::Primitive(v)) => v.into(),
42 ScalarValue(InnerScalarValue::Decimal(v)) => {
43 let inner_value = match v {
44 DecimalValue::I8(v) => v.to_le_bytes().to_vec(),
45 DecimalValue::I16(v) => v.to_le_bytes().to_vec(),
46 DecimalValue::I32(v) => v.to_le_bytes().to_vec(),
47 DecimalValue::I64(v) => v.to_le_bytes().to_vec(),
48 DecimalValue::I128(v128) => v128.to_le_bytes().to_vec(),
49 DecimalValue::I256(v256) => v256.to_le_bytes().to_vec(),
50 };
51
52 pb::ScalarValue {
53 kind: Some(Kind::BytesValue(inner_value)),
54 }
55 }
56 ScalarValue(InnerScalarValue::Buffer(v)) => pb::ScalarValue {
57 kind: Some(Kind::BytesValue(v.as_slice().to_vec())),
58 },
59 ScalarValue(InnerScalarValue::BufferString(v)) => pb::ScalarValue {
60 kind: Some(Kind::StringValue(v.as_str().to_string())),
61 },
62 ScalarValue(InnerScalarValue::List(v)) => {
63 let mut values = Vec::with_capacity(v.len());
64 for elem in v.iter() {
65 values.push(pb::ScalarValue::from(elem));
66 }
67 pb::ScalarValue {
68 kind: Some(Kind::ListValue(ListValue { values })),
69 }
70 }
71 }
72 }
73}
74
75impl From<&PValue> for pb::ScalarValue {
76 fn from(value: &PValue) -> Self {
77 match value {
78 PValue::I8(v) => pb::ScalarValue {
79 kind: Some(Kind::Int64Value(*v as i64)),
80 },
81 PValue::I16(v) => pb::ScalarValue {
82 kind: Some(Kind::Int64Value(*v as i64)),
83 },
84 PValue::I32(v) => pb::ScalarValue {
85 kind: Some(Kind::Int64Value(*v as i64)),
86 },
87 PValue::I64(v) => pb::ScalarValue {
88 kind: Some(Kind::Int64Value(*v)),
89 },
90 PValue::U8(v) => pb::ScalarValue {
91 kind: Some(Kind::Uint64Value(*v as u64)),
92 },
93 PValue::U16(v) => pb::ScalarValue {
94 kind: Some(Kind::Uint64Value(*v as u64)),
95 },
96 PValue::U32(v) => pb::ScalarValue {
97 kind: Some(Kind::Uint64Value(*v as u64)),
98 },
99 PValue::U64(v) => pb::ScalarValue {
100 kind: Some(Kind::Uint64Value(*v)),
101 },
102 PValue::F16(v) => pb::ScalarValue {
103 kind: Some(Kind::F16Value(v.to_bits() as u64)),
104 },
105 PValue::F32(v) => pb::ScalarValue {
106 kind: Some(Kind::F32Value(*v)),
107 },
108 PValue::F64(v) => pb::ScalarValue {
109 kind: Some(Kind::F64Value(*v)),
110 },
111 }
112 }
113}
114
115impl TryFrom<&pb::Scalar> for Scalar {
116 type Error = VortexError;
117
118 fn try_from(value: &pb::Scalar) -> Result<Self, Self::Error> {
119 let dtype = DType::try_from(
120 value
121 .dtype
122 .as_ref()
123 .ok_or_else(|| vortex_err!(InvalidSerde: "Scalar missing dtype"))?,
124 )?;
125
126 let value = ScalarValue::try_from(
127 value
128 .value
129 .as_ref()
130 .ok_or_else(|| vortex_err!(InvalidSerde: "Scalar missing value"))?,
131 )?;
132
133 Ok(Scalar::new(dtype, value))
134 }
135}
136
137impl TryFrom<&pb::ScalarValue> for ScalarValue {
138 type Error = VortexError;
139
140 fn try_from(value: &pb::ScalarValue) -> Result<Self, Self::Error> {
141 let kind = value
142 .kind
143 .as_ref()
144 .ok_or_else(|| vortex_err!(InvalidSerde: "ScalarValue missing kind"))?;
145
146 match kind {
147 Kind::NullValue(_) => Ok(ScalarValue(InnerScalarValue::Null)),
148 Kind::BoolValue(v) => Ok(ScalarValue(InnerScalarValue::Bool(*v))),
149 Kind::Int64Value(v) => Ok(ScalarValue(InnerScalarValue::Primitive(PValue::I64(*v)))),
150 Kind::Uint64Value(v) => Ok(ScalarValue(InnerScalarValue::Primitive(PValue::U64(*v)))),
151 Kind::F16Value(v) => Ok(ScalarValue(InnerScalarValue::Primitive(PValue::F16(
152 f16::from_bits(u16::try_from(*v).map_err(|_| {
153 vortex_err!("f16 bitwise representation has more than 16 bits: {}", v)
154 })?),
155 )))),
156 Kind::F32Value(v) => Ok(ScalarValue(InnerScalarValue::Primitive(PValue::F32(*v)))),
157 Kind::F64Value(v) => Ok(ScalarValue(InnerScalarValue::Primitive(PValue::F64(*v)))),
158 Kind::StringValue(v) => Ok(ScalarValue(InnerScalarValue::BufferString(Arc::new(
159 BufferString::from(v.clone()),
160 )))),
161 Kind::BytesValue(v) => Ok(ScalarValue(InnerScalarValue::Buffer(Arc::new(
162 ByteBuffer::from(v.clone()),
163 )))),
164 Kind::ListValue(v) => {
165 let mut values = Vec::with_capacity(v.values.len());
166 for elem in v.values.iter() {
167 values.push(elem.try_into()?);
168 }
169 Ok(ScalarValue(InnerScalarValue::List(values.into())))
170 }
171 }
172 }
173}
174
#[cfg(test)]
mod tests {
    // Round-trip tests for the protobuf encoding of `Scalar` and `ScalarValue`.
    use std::sync::Arc;

    use rstest::rstest;
    use vortex_buffer::BufferString;
    use vortex_dtype::DType;
    use vortex_dtype::DecimalDType;
    use vortex_dtype::FieldDType;
    use vortex_dtype::Nullability;
    use vortex_dtype::PType;
    use vortex_dtype::StructFields;
    use vortex_dtype::half::f16;
    use vortex_dtype::i256;
    use vortex_error::vortex_panic;
    use vortex_proto::scalar as pb;

    use super::*;
    use crate::InnerScalarValue;
    use crate::Scalar;
    use crate::ScalarValue;

    /// Encodes `scalar` as a `pb::Scalar`, decodes it back and asserts the result equals the input.
    fn round_trip(scalar: Scalar) {
        assert_eq!(
            scalar,
            Scalar::try_from(&pb::Scalar::from(&scalar)).unwrap(),
        );
    }

    #[test]
    fn test_null() {
        round_trip(Scalar::null(DType::Null));
    }

    #[test]
    fn test_bool() {
        round_trip(Scalar::new(
            DType::Bool(Nullability::Nullable),
            ScalarValue(InnerScalarValue::Bool(true)),
        ));
    }

    // The I32 value is written as `Int64Value` and decoded as `PValue::I64`.
    // This assertion therefore relies on `Scalar` equality treating the widened value
    // as equal under the I32 dtype.
    #[test]
    fn test_primitive() {
        round_trip(Scalar::new(
            DType::Primitive(PType::I32, Nullability::Nullable),
            ScalarValue(InnerScalarValue::Primitive(42i32.into())),
        ));
    }

    #[test]
    fn test_buffer() {
        round_trip(Scalar::new(
            DType::Binary(Nullability::Nullable),
            ScalarValue(InnerScalarValue::Buffer(Arc::new(vec![1, 2, 3].into()))),
        ));
    }

    #[test]
    fn test_buffer_string() {
        round_trip(Scalar::new(
            DType::Utf8(Nullability::Nullable),
            ScalarValue(InnerScalarValue::BufferString(Arc::new(
                BufferString::from("hello".to_string()),
            ))),
        ));
    }

    // Elements of the list go through the same widening Int64Value encoding as test_primitive.
    #[test]
    fn test_list() {
        round_trip(Scalar::new(
            DType::List(
                Arc::new(DType::Primitive(PType::I32, Nullability::Nullable)),
                Nullability::Nullable,
            ),
            ScalarValue(InnerScalarValue::List(
                vec![
                    ScalarValue(InnerScalarValue::Primitive(42i32.into())),
                    ScalarValue(InnerScalarValue::Primitive(43i32.into())),
                ]
                .into(),
            )),
        ));
    }

    #[test]
    fn test_f16() {
        round_trip(Scalar::primitive(
            f16::from_f32(0.42),
            Nullability::Nullable,
        ));
    }

    // Covers the i8 extremes and zero through the widening Int64Value encoding.
    #[test]
    fn test_i8() {
        round_trip(Scalar::new(
            DType::Primitive(PType::I8, Nullability::Nullable),
            ScalarValue(InnerScalarValue::Primitive(i8::MIN.into())),
        ));

        round_trip(Scalar::new(
            DType::Primitive(PType::I8, Nullability::Nullable),
            ScalarValue(InnerScalarValue::Primitive(0i8.into())),
        ));

        round_trip(Scalar::new(
            DType::Primitive(PType::I8, Nullability::Nullable),
            ScalarValue(InnerScalarValue::Primitive(i8::MAX.into())),
        ));
    }

    // Round-trips only the `ScalarValue` through `to_protobytes` / `from_protobytes`.
    // It then re-attaches the original dtype and compares against the input scalar.
    #[rstest]
    #[case(Scalar::binary(ByteBuffer::copy_from(b"hello"), Nullability::NonNullable))]
    #[case(Scalar::utf8("hello", Nullability::NonNullable))]
    #[case(Scalar::primitive(1u8, Nullability::NonNullable))]
    // Bit pattern 0xFFF98AFF is a NaN: sign bit set, all-ones exponent, non-zero mantissa.
    #[case(Scalar::primitive(
        f32::from_bits(u32::from_le_bytes([0xFFu8, 0x8A, 0xF9, 0xFF])),
        Nullability::NonNullable
    ))]
    #[case(Scalar::list(Arc::new(PType::U8.into()), vec![Scalar::primitive(1u8, Nullability::NonNullable)], Nullability::NonNullable
    ))]
    // 2.6584664e36 exceeds f16's largest finite value (65504).
    // `f16::from_f32` therefore yields f16 infinity for field "b".
    #[case(Scalar::struct_(DType::Struct(
        StructFields::from_iter([
            ("a", FieldDType::from(DType::Primitive(PType::U32, Nullability::NonNullable))),
            ("b", FieldDType::from(DType::Primitive(PType::F16, Nullability::NonNullable))),
        ]),
        Nullability::NonNullable),
        vec![
            Scalar::primitive(23592960u32, Nullability::NonNullable),
            Scalar::primitive(f16::from_f32(2.6584664e36f32), Nullability::NonNullable),
        ],
    ))]
    // Same large constant, stored both as a finite f32 and as an infinite f16.
    #[case(Scalar::struct_(DType::Struct(
        StructFields::from_iter([
            ("a", FieldDType::from(DType::Primitive(PType::U64, Nullability::NonNullable))),
            ("b", FieldDType::from(DType::Primitive(PType::F32, Nullability::NonNullable))),
            ("c", FieldDType::from(DType::Primitive(PType::F16, Nullability::NonNullable))),
        ]),
        Nullability::NonNullable),
        vec![
            Scalar::primitive(415118687234u64, Nullability::NonNullable),
            Scalar::primitive(2.6584664e36f32, Nullability::NonNullable),
            Scalar::primitive(f16::from_f32(2.6584664e36f32), Nullability::NonNullable),
        ],
    ))]
    // Decimals are encoded as raw little-endian bytes and decoded as a `Buffer`.
    // These cases therefore rely on `Scalar` reinterpreting that buffer under the
    // decimal dtype.
    #[case(Scalar::decimal(
        DecimalValue::I256(i256::from_i128(12345643673471)),
        DecimalDType::new(10, 2),
        Nullability::NonNullable
    ))]
    // NOTE(review): 23412 has five digits, more than the declared precision of 3.
    // This case exercises the I16 byte encoding rather than precision checks —
    // confirm that is intended.
    #[case(Scalar::decimal(
        DecimalValue::I16(23412),
        DecimalDType::new(3, 2),
        Nullability::NonNullable
    ))]
    fn test_scalar_value_serde_roundtrip(#[case] scalar: Scalar) {
        let written = scalar.value().to_protobytes::<Vec<u8>>();
        let scalar_read_back = ScalarValue::from_protobytes(&written).unwrap();
        assert_eq!(
            Scalar::new(scalar.dtype().clone(), scalar_read_back),
            scalar
        );
    }

    // Per the test name, this simulates an f16 whose raw bits were written into a
    // `Uint64Value` rather than an `F16Value`.
    // Decoding alone yields a plain `PValue::U64` (0x36B8 == 14008 is the f16 bit pattern
    // of 0.42). Wrapping it in a Scalar with an F16 dtype must recover the f16 value.
    #[test]
    fn test_backcompat_f16_serialized_as_u64() {
        let pb_scalar_value = pb::ScalarValue {
            kind: Some(Kind::Uint64Value(f16::from_f32(0.42).to_bits() as u64)),
        };
        let scalar_value = ScalarValue::try_from(&pb_scalar_value).unwrap();
        assert_eq!(
            scalar_value.as_pvalue().unwrap(),
            Some(PValue::U64(14008u64))
        );

        let scalar = Scalar::new(
            DType::Primitive(PType::F16, Nullability::Nullable),
            scalar_value,
        );

        assert_eq!(
            scalar.as_primitive().pvalue().unwrap(),
            PValue::F16(f16::from_f32(0.42))
        );
    }

    // Checks that f16 values, including subnormal-range, infinite and NaN inputs,
    // survive the `F16Value` encoding as `PValue::F16` without any dtype involved.
    #[test]
    fn test_scalar_value_direct_roundtrip_f16() {
        let f16_values = vec![
            f16::from_f32(0.0),
            f16::from_f32(1.0),
            f16::from_f32(-1.0),
            f16::from_f32(0.42),
            f16::from_f32(5.722046e-6),
            f16::from_f32(std::f32::consts::PI),
            f16::INFINITY,
            f16::NEG_INFINITY,
            f16::NAN,
        ];

        for f16_val in f16_values {
            let scalar_value = ScalarValue(InnerScalarValue::Primitive(PValue::F16(f16_val)));
            let written = scalar_value.to_protobytes::<Vec<u8>>();
            let read_back = ScalarValue::from_protobytes(&written).unwrap();

            match (&scalar_value.0, &read_back.0) {
                (
                    InnerScalarValue::Primitive(PValue::F16(original)),
                    InnerScalarValue::Primitive(PValue::F16(roundtripped)),
                ) => {
                    // NaN never compares equal to itself, so accept any NaN coming back for a NaN input.
                    if original.is_nan() && roundtripped.is_nan() {
                        continue;
                    }
                    assert_eq!(
                        original, roundtripped,
                        "F16 value {original:?} did not roundtrip correctly"
                    );
                }
                _ => {
                    vortex_panic!(
                        "Expected f16 primitive values, got {scalar_value:?} and {read_back:?}"
                    )
                }
            }
        }
    }

    // Exercises `ScalarValue` directly (no dtype) across three groups of inputs.
    #[test]
    fn test_scalar_value_direct_roundtrip_preserves_values() {
        // Values whose wire encoding maps back to the identical variant.
        // Comparing `Debug` output checks the variant as well as the payload.
        let exact_roundtrip_cases = vec![
            ("null", ScalarValue(InnerScalarValue::Null)),
            ("bool_true", ScalarValue(InnerScalarValue::Bool(true))),
            ("bool_false", ScalarValue(InnerScalarValue::Bool(false))),
            (
                "u64",
                ScalarValue(InnerScalarValue::Primitive(PValue::U64(
                    18446744073709551615,
                ))),
            ),
            (
                "i64",
                ScalarValue(InnerScalarValue::Primitive(PValue::I64(
                    -9223372036854775808,
                ))),
            ),
            (
                "f32",
                ScalarValue(InnerScalarValue::Primitive(PValue::F32(
                    std::f32::consts::E,
                ))),
            ),
            (
                "f64",
                ScalarValue(InnerScalarValue::Primitive(PValue::F64(
                    std::f64::consts::PI,
                ))),
            ),
            (
                "string",
                ScalarValue(InnerScalarValue::BufferString(Arc::new(
                    BufferString::from("test"),
                ))),
            ),
            (
                "bytes",
                ScalarValue(InnerScalarValue::Buffer(Arc::new(
                    vec![1, 2, 3, 4, 5].into(),
                ))),
            ),
        ];

        for (name, value) in exact_roundtrip_cases {
            let written = value.to_protobytes::<Vec<u8>>();
            let read_back = ScalarValue::from_protobytes(&written).unwrap();

            let original_debug = format!("{value:?}");
            let roundtrip_debug = format!("{read_back:?}");
            assert_eq!(
                original_debug, roundtrip_debug,
                "ScalarValue {name} did not roundtrip exactly"
            );
        }

        // Narrow unsigned integers travel as Uint64Value and decode as `PValue::U64`.
        // Only the numeric value is preserved.
        let unsigned_cases = vec![
            (
                "u8",
                ScalarValue(InnerScalarValue::Primitive(PValue::U8(255))),
                255u64,
            ),
            (
                "u16",
                ScalarValue(InnerScalarValue::Primitive(PValue::U16(65535))),
                65535u64,
            ),
            (
                "u32",
                ScalarValue(InnerScalarValue::Primitive(PValue::U32(4294967295))),
                4294967295u64,
            ),
        ];

        for (name, value, expected) in unsigned_cases {
            let written = value.to_protobytes::<Vec<u8>>();
            let read_back = ScalarValue::from_protobytes(&written).unwrap();

            match &read_back.0 {
                InnerScalarValue::Primitive(PValue::U64(v)) => {
                    assert_eq!(
                        *v, expected,
                        "ScalarValue {name} value not preserved: expected {expected}, got {v}"
                    );
                }
                _ => vortex_panic!("Unexpected type after roundtrip for {name}: {read_back:?}"),
            }
        }

        // Narrow signed integers (at their minimum values) travel as Int64Value and decode
        // as `PValue::I64`.
        let signed_cases = vec![
            (
                "i8",
                ScalarValue(InnerScalarValue::Primitive(PValue::I8(-128))),
                -128i64,
            ),
            (
                "i16",
                ScalarValue(InnerScalarValue::Primitive(PValue::I16(-32768))),
                -32768i64,
            ),
            (
                "i32",
                ScalarValue(InnerScalarValue::Primitive(PValue::I32(-2147483648))),
                -2147483648i64,
            ),
        ];

        for (name, value, expected) in signed_cases {
            let written = value.to_protobytes::<Vec<u8>>();
            let read_back = ScalarValue::from_protobytes(&written).unwrap();

            match &read_back.0 {
                InnerScalarValue::Primitive(PValue::I64(v)) => {
                    assert_eq!(
                        *v, expected,
                        "ScalarValue {name} value not preserved: expected {expected}, got {v}"
                    );
                }
                _ => vortex_panic!("Unexpected type after roundtrip for {name}: {read_back:?}"),
            }
        }
    }
}