1use std::sync::Arc;
5
6use num_traits::ToBytes;
7use vortex_buffer::{BufferString, ByteBuffer};
8use vortex_dtype::DType;
9use vortex_error::{VortexError, vortex_err};
10use vortex_proto::scalar as pb;
11use vortex_proto::scalar::ListValue;
12use vortex_proto::scalar::scalar_value::Kind;
13
14use crate::pvalue::PValue;
15use crate::{DecimalValue, InnerScalarValue, Scalar, ScalarValue};
16
17impl From<&Scalar> for pb::Scalar {
18 fn from(value: &Scalar) -> Self {
19 pb::Scalar {
20 dtype: Some((&value.dtype).into()),
21 value: Some((&value.value).into()),
22 }
23 }
24}
25
26impl From<&ScalarValue> for pb::ScalarValue {
27 fn from(value: &ScalarValue) -> Self {
28 match value {
29 ScalarValue(InnerScalarValue::Null) => pb::ScalarValue {
30 kind: Some(Kind::NullValue(0)),
31 },
32 ScalarValue(InnerScalarValue::Bool(v)) => pb::ScalarValue {
33 kind: Some(Kind::BoolValue(*v)),
34 },
35 ScalarValue(InnerScalarValue::Primitive(v)) => v.into(),
36 ScalarValue(InnerScalarValue::Decimal(v)) => {
37 let inner_value = match v {
38 DecimalValue::I8(v) => v.to_le_bytes().to_vec(),
39 DecimalValue::I16(v) => v.to_le_bytes().to_vec(),
40 DecimalValue::I32(v) => v.to_le_bytes().to_vec(),
41 DecimalValue::I64(v) => v.to_le_bytes().to_vec(),
42 DecimalValue::I128(v128) => v128.to_le_bytes().to_vec(),
43 DecimalValue::I256(v256) => v256.to_le_bytes().to_vec(),
44 };
45
46 pb::ScalarValue {
47 kind: Some(Kind::BytesValue(inner_value)),
48 }
49 }
50 ScalarValue(InnerScalarValue::Buffer(v)) => pb::ScalarValue {
51 kind: Some(Kind::BytesValue(v.as_slice().to_vec())),
52 },
53 ScalarValue(InnerScalarValue::BufferString(v)) => pb::ScalarValue {
54 kind: Some(Kind::StringValue(v.as_str().to_string())),
55 },
56 ScalarValue(InnerScalarValue::List(v)) => {
57 let mut values = Vec::with_capacity(v.len());
58 for elem in v.iter() {
59 values.push(pb::ScalarValue::from(elem));
60 }
61 pb::ScalarValue {
62 kind: Some(Kind::ListValue(ListValue { values })),
63 }
64 }
65 }
66 }
67}
68
69impl From<&PValue> for pb::ScalarValue {
70 fn from(value: &PValue) -> Self {
71 match value {
72 PValue::I8(v) => pb::ScalarValue {
73 kind: Some(Kind::Int64Value(*v as i64)),
74 },
75 PValue::I16(v) => pb::ScalarValue {
76 kind: Some(Kind::Int64Value(*v as i64)),
77 },
78 PValue::I32(v) => pb::ScalarValue {
79 kind: Some(Kind::Int64Value(*v as i64)),
80 },
81 PValue::I64(v) => pb::ScalarValue {
82 kind: Some(Kind::Int64Value(*v)),
83 },
84 PValue::U8(v) => pb::ScalarValue {
85 kind: Some(Kind::Uint64Value(*v as u64)),
86 },
87 PValue::U16(v) => pb::ScalarValue {
88 kind: Some(Kind::Uint64Value(*v as u64)),
89 },
90 PValue::U32(v) => pb::ScalarValue {
91 kind: Some(Kind::Uint64Value(*v as u64)),
92 },
93 PValue::U64(v) => pb::ScalarValue {
94 kind: Some(Kind::Uint64Value(*v)),
95 },
96 PValue::F16(v) => pb::ScalarValue {
97 kind: Some(Kind::Uint64Value(v.to_bits() as u64)),
98 },
99 PValue::F32(v) => pb::ScalarValue {
100 kind: Some(Kind::F32Value(*v)),
101 },
102 PValue::F64(v) => pb::ScalarValue {
103 kind: Some(Kind::F64Value(*v)),
104 },
105 }
106 }
107}
108
109impl TryFrom<&pb::Scalar> for Scalar {
110 type Error = VortexError;
111
112 fn try_from(value: &pb::Scalar) -> Result<Self, Self::Error> {
113 let dtype = DType::try_from(
114 value
115 .dtype
116 .as_ref()
117 .ok_or_else(|| vortex_err!(InvalidSerde: "Scalar missing dtype"))?,
118 )?;
119
120 let value = ScalarValue::try_from(
121 value
122 .value
123 .as_ref()
124 .ok_or_else(|| vortex_err!(InvalidSerde: "Scalar missing value"))?,
125 )?;
126
127 Ok(Self { dtype, value })
128 }
129}
130
131impl TryFrom<&pb::ScalarValue> for ScalarValue {
132 type Error = VortexError;
133
134 fn try_from(value: &pb::ScalarValue) -> Result<Self, Self::Error> {
135 let kind = value
136 .kind
137 .as_ref()
138 .ok_or_else(|| vortex_err!(InvalidSerde: "ScalarValue missing kind"))?;
139
140 match kind {
141 Kind::NullValue(_) => Ok(ScalarValue(InnerScalarValue::Null)),
142 Kind::BoolValue(v) => Ok(ScalarValue(InnerScalarValue::Bool(*v))),
143 Kind::Int64Value(v) => Ok(ScalarValue(InnerScalarValue::Primitive(PValue::I64(*v)))),
144 Kind::Uint64Value(v) => Ok(ScalarValue(InnerScalarValue::Primitive(PValue::U64(*v)))),
145 Kind::F32Value(v) => Ok(ScalarValue(InnerScalarValue::Primitive(PValue::F32(*v)))),
146 Kind::F64Value(v) => Ok(ScalarValue(InnerScalarValue::Primitive(PValue::F64(*v)))),
147 Kind::StringValue(v) => Ok(ScalarValue(InnerScalarValue::BufferString(Arc::new(
148 BufferString::from(v.clone()),
149 )))),
150 Kind::BytesValue(v) => Ok(ScalarValue(InnerScalarValue::Buffer(Arc::new(
151 ByteBuffer::from(v.clone()),
152 )))),
153 Kind::ListValue(v) => {
154 let mut values = Vec::with_capacity(v.values.len());
155 for elem in v.values.iter() {
156 values.push(elem.try_into()?);
157 }
158 Ok(ScalarValue(InnerScalarValue::List(values.into())))
159 }
160 }
161 }
162}
163
#[cfg(test)]
mod test {
    use std::sync::Arc;

    use vortex_buffer::BufferString;
    use vortex_dtype::PType::{self, I32};
    use vortex_dtype::half::f16;
    use vortex_dtype::{DType, Nullability};
    use vortex_proto::scalar as pb;

    use crate::{InnerScalarValue, PValue, Scalar, ScalarValue};

    /// Encodes `scalar` to protobuf, decodes it again, and asserts that the
    /// decoded scalar equals the input.
    fn round_trip(scalar: Scalar) {
        let encoded = pb::Scalar::from(&scalar);
        let decoded = Scalar::try_from(&encoded).unwrap();
        assert_eq!(scalar, decoded);
    }

    #[test]
    fn test_null() {
        let scalar = Scalar::null(DType::Null);
        round_trip(scalar);
    }

    #[test]
    fn test_bool() {
        let dtype = DType::Bool(Nullability::Nullable);
        let value = ScalarValue(InnerScalarValue::Bool(true));
        round_trip(Scalar::new(dtype, value));
    }

    #[test]
    fn test_primitive() {
        let dtype = DType::Primitive(I32, Nullability::Nullable);
        let value = ScalarValue(InnerScalarValue::Primitive(42i32.into()));
        round_trip(Scalar::new(dtype, value));
    }

    #[test]
    fn test_buffer() {
        let dtype = DType::Binary(Nullability::Nullable);
        let value = ScalarValue(InnerScalarValue::Buffer(Arc::new(vec![1, 2, 3].into())));
        round_trip(Scalar::new(dtype, value));
    }

    #[test]
    fn test_buffer_string() {
        let dtype = DType::Utf8(Nullability::Nullable);
        let value = ScalarValue(InnerScalarValue::BufferString(Arc::new(
            BufferString::from("hello".to_string()),
        )));
        round_trip(Scalar::new(dtype, value));
    }

    #[test]
    fn test_list() {
        let element_dtype = Arc::new(DType::Primitive(I32, Nullability::Nullable));
        let dtype = DType::List(element_dtype, Nullability::Nullable);
        let value = ScalarValue(InnerScalarValue::List(
            vec![
                ScalarValue(InnerScalarValue::Primitive(42i32.into())),
                ScalarValue(InnerScalarValue::Primitive(43i32.into())),
            ]
            .into(),
        ));
        round_trip(Scalar::new(dtype, value));
    }

    #[test]
    fn test_f16() {
        let dtype = DType::Primitive(PType::F16, Nullability::Nullable);
        let half = PValue::F16(f16::from_f32(0.42));
        round_trip(Scalar::new(
            dtype,
            ScalarValue(InnerScalarValue::Primitive(half)),
        ));
    }

    #[test]
    fn test_i8() {
        // Minimum, zero, and maximum, checked in that order.
        for n in [i8::MIN, 0i8, i8::MAX] {
            round_trip(Scalar::new(
                DType::Primitive(PType::I8, Nullability::Nullable),
                ScalarValue(InnerScalarValue::Primitive(n.into())),
            ));
        }
    }
}
267
#[cfg(test)]
mod tests {
    use half::f16;
    use rstest::rstest;
    use vortex_dtype::{DType, DecimalDType, FieldDType, Nullability, PType, StructFields, half};

    use super::*;
    use crate::{Scalar, i256};

    /// Serializes only the scalar's value to protobuf bytes, reads it back,
    /// pairs it with the original dtype, and asserts that the rebuilt scalar
    /// equals the input.
    #[rstest]
    #[case(Scalar::binary(ByteBuffer::copy_from(b"hello"), Nullability::NonNullable))]
    #[case(Scalar::utf8("hello", Nullability::NonNullable))]
    #[case(Scalar::primitive(1u8, Nullability::NonNullable))]
    #[case(Scalar::primitive(
        f32::from_bits(u32::from_le_bytes([0xFFu8, 0x8A, 0xF9, 0xFF])),
        Nullability::NonNullable
    ))]
    #[case(Scalar::list(
        Arc::new(PType::U8.into()),
        vec![Scalar::primitive(1u8, Nullability::NonNullable)],
        Nullability::NonNullable
    ))]
    #[case(Scalar::struct_(
        DType::Struct(
            StructFields::from_iter([
                ("a", FieldDType::from(DType::Primitive(PType::U32, Nullability::NonNullable))),
                ("b", FieldDType::from(DType::Primitive(PType::F16, Nullability::NonNullable))),
            ]),
            Nullability::NonNullable
        ),
        vec![
            Scalar::primitive(23592960u32, Nullability::NonNullable),
            Scalar::primitive(f16::from_f32(2.6584664e36f32), Nullability::NonNullable),
        ],
    ))]
    #[case(Scalar::struct_(
        DType::Struct(
            StructFields::from_iter([
                ("a", FieldDType::from(DType::Primitive(PType::U64, Nullability::NonNullable))),
                ("b", FieldDType::from(DType::Primitive(PType::F32, Nullability::NonNullable))),
                ("c", FieldDType::from(DType::Primitive(PType::F16, Nullability::NonNullable))),
            ]),
            Nullability::NonNullable
        ),
        vec![
            Scalar::primitive(415118687234u64, Nullability::NonNullable),
            Scalar::primitive(2.6584664e36f32, Nullability::NonNullable),
            Scalar::primitive(f16::from_f32(2.6584664e36f32), Nullability::NonNullable),
        ],
    ))]
    #[case(Scalar::decimal(
        DecimalValue::I256(i256::from_i128(12345643673471)),
        DecimalDType::new(10, 2),
        Nullability::NonNullable
    ))]
    #[case(Scalar::decimal(
        DecimalValue::I16(23412),
        DecimalDType::new(3, 2),
        Nullability::NonNullable
    ))]
    fn test_scalar_value_serde_roundtrip(#[case] scalar: Scalar) {
        let encoded = scalar.value.to_protobytes::<Vec<u8>>();
        let decoded = ScalarValue::from_protobytes(&encoded).unwrap();
        let rebuilt = Scalar::new(scalar.dtype().clone(), decoded);
        assert_eq!(scalar, rebuilt);
    }
}