1use std::cmp::Ordering;
7use std::hash::Hash;
8use std::hash::Hasher;
9
10use vortex_error::VortexResult;
11use vortex_error::vortex_ensure_eq;
12use vortex_error::vortex_panic;
13
14use crate::dtype::DType;
15use crate::dtype::NativeDType;
16use crate::dtype::PType;
17use crate::dtype::StructFields;
18use crate::scalar::Scalar;
19use crate::scalar::ScalarValue;
20
21impl Scalar {
22 pub fn null(dtype: DType) -> Self {
30 assert!(
31 dtype.is_nullable(),
32 "Cannot create null scalar with non-nullable dtype {dtype}"
33 );
34
35 Self { dtype, value: None }
36 }
37
38 pub fn null_native<T: NativeDType>() -> Self {
43 Self {
44 dtype: T::dtype().as_nullable(),
45 value: None,
46 }
47 }
48
49 #[cfg(test)]
59 pub fn new(dtype: DType, value: Option<ScalarValue>) -> Self {
60 use vortex_error::VortexExpect;
61
62 Self::try_new(dtype, value).vortex_expect("Failed to create Scalar")
63 }
64
65 pub fn try_new(dtype: DType, value: Option<ScalarValue>) -> VortexResult<Self> {
72 Self::validate(&dtype, value.as_ref())?;
73
74 Ok(Self { dtype, value })
75 }
76
77 pub unsafe fn new_unchecked(dtype: DType, value: Option<ScalarValue>) -> Self {
85 #[cfg(debug_assertions)]
86 {
87 use vortex_error::VortexExpect;
88
89 Self::validate(&dtype, value.as_ref())
90 .vortex_expect("Scalar::new_unchecked called with incompatible dtype and value");
91 }
92
93 Self { dtype, value }
94 }
95
96 pub fn default_value(dtype: &DType) -> Self {
107 let value = ScalarValue::default_value(dtype);
108
109 unsafe { Self::new_unchecked(dtype.clone(), value) }
111 }
112
113 pub fn zero_value(dtype: &DType) -> Self {
132 let value = ScalarValue::zero_value(dtype);
133
134 unsafe { Self::new_unchecked(dtype.clone(), Some(value)) }
136 }
137
138 pub fn eq_ignore_nullability(&self, other: &Self) -> bool {
142 self.dtype.eq_ignore_nullability(&other.dtype) && self.value == other.value
143 }
144
145 pub fn into_parts(self) -> (DType, Option<ScalarValue>) {
147 (self.dtype, self.value)
148 }
149
150 pub fn dtype(&self) -> &DType {
152 &self.dtype
153 }
154
155 pub fn value(&self) -> Option<&ScalarValue> {
157 self.value.as_ref()
158 }
159
160 pub fn into_value(self) -> Option<ScalarValue> {
163 self.value
164 }
165
166 pub fn is_valid(&self) -> bool {
168 self.value.is_some()
169 }
170
171 pub fn is_null(&self) -> bool {
173 self.value.is_none()
174 }
175
176 pub fn is_zero(&self) -> Option<bool> {
181 let value = self.value()?;
182
183 let is_zero = match self.dtype() {
184 DType::Null => vortex_panic!("non-null value somehow had `DType::Null`"),
185 DType::Bool(_) => !value.as_bool(),
186 DType::Primitive(..) => value.as_primitive().is_zero(),
187 DType::Decimal(..) => value.as_decimal().is_zero(),
188 DType::Utf8(_) => value.as_utf8().is_empty(),
189 DType::Binary(_) => value.as_binary().is_empty(),
190 DType::List(..) => value.as_list().is_empty(),
191 DType::FixedSizeList(_, list_size, _) => {
194 let list = self.as_list();
195 list.len() == *list_size as usize
196 && (0..list.len())
197 .all(|i| list.element(i).is_some_and(|e| e.is_zero() == Some(true)))
198 }
199 DType::Struct(..) => self
201 .as_struct()
202 .fields_iter()
203 .is_some_and(|mut fields| fields.all(|f| f.is_zero() == Some(true))),
204 DType::Union(..) => todo!("TODO(connor)[Union]: unimplemented"),
205 DType::Variant(_) => self.as_variant().is_zero()?,
206 DType::Extension(_) => self.as_extension().to_storage_scalar().is_zero()?,
207 };
208
209 Some(is_zero)
210 }
211
212 pub fn primitive_reinterpret_cast(&self, ptype: PType) -> VortexResult<Self> {
218 let primitive = self.as_primitive();
219 if primitive.ptype() == ptype {
220 return Ok(self.clone());
221 }
222
223 vortex_ensure_eq!(
224 primitive.ptype().byte_width(),
225 ptype.byte_width(),
226 "can't reinterpret cast between integers of two different widths"
227 );
228
229 Scalar::try_new(
230 DType::Primitive(ptype, self.dtype().nullability()),
231 primitive
232 .pvalue()
233 .map(|p| p.reinterpret_cast(ptype))
234 .map(ScalarValue::Primitive),
235 )
236 }
237
238 pub fn approx_nbytes(&self) -> usize {
243 use crate::dtype::NativeDecimalType;
244 use crate::dtype::i256;
245
246 match self.dtype() {
247 DType::Null => 0,
248 DType::Bool(_) => 1,
249 DType::Primitive(ptype, _) => ptype.byte_width(),
250 DType::Decimal(dt, _) => {
251 if dt.precision() <= i128::MAX_PRECISION {
252 size_of::<i128>()
253 } else {
254 size_of::<i256>()
255 }
256 }
257 DType::Utf8(_) => self
258 .value()
259 .map_or_else(|| 0, |value| value.as_utf8().len()),
260 DType::Binary(_) => self
261 .value()
262 .map_or_else(|| 0, |value| value.as_binary().len()),
263 DType::List(..) | DType::FixedSizeList(..) => self
264 .as_list()
265 .elements()
266 .map(|fields| fields.into_iter().map(|f| f.approx_nbytes()).sum::<usize>())
267 .unwrap_or_default(),
268 DType::Struct(..) => self
269 .as_struct()
270 .fields_iter()
271 .map(|fields| fields.into_iter().map(|f| f.approx_nbytes()).sum::<usize>())
272 .unwrap_or_default(),
273 DType::Union(..) => todo!("TODO(connor)[Union]: unimplemented"),
274 DType::Variant(_) => self.as_variant().value().map_or(0, Scalar::approx_nbytes),
275 DType::Extension(_) => self.as_extension().to_storage_scalar().approx_nbytes(),
276 }
277 }
278}
279
280impl Hash for Scalar {
284 fn hash<H: Hasher>(&self, state: &mut H) {
285 self.dtype.as_nonnullable().hash(state);
286 self.value.hash(state);
287 }
288}
289
290impl PartialEq for Scalar {
297 fn eq(&self, other: &Self) -> bool {
298 self.dtype.eq_ignore_nullability(&other.dtype) && self.value == other.value
299 }
300}
301
302impl PartialOrd for Scalar {
303 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
339 if !self.dtype().eq_ignore_nullability(other.dtype()) {
340 return None;
341 }
342
343 partial_cmp_scalar_values(self.dtype(), self.value(), other.value())
344 }
345}
346
347fn partial_cmp_scalar_values(
349 dtype: &DType,
350 lhs: Option<&ScalarValue>,
351 rhs: Option<&ScalarValue>,
352) -> Option<Ordering> {
353 match (lhs, rhs) {
354 (None, None) => Some(Ordering::Equal),
355 (None, Some(_)) => Some(Ordering::Less),
356 (Some(_), None) => Some(Ordering::Greater),
357 (Some(lhs), Some(rhs)) => partial_cmp_non_null_scalar_values(dtype, lhs, rhs),
358 }
359}
360
361fn partial_cmp_non_null_scalar_values(
363 dtype: &DType,
364 lhs: &ScalarValue,
365 rhs: &ScalarValue,
366) -> Option<Ordering> {
367 match (lhs, rhs) {
370 (ScalarValue::Bool(lhs), ScalarValue::Bool(rhs)) => lhs.partial_cmp(rhs),
371 (ScalarValue::Primitive(lhs), ScalarValue::Primitive(rhs)) => lhs.partial_cmp(rhs),
372 (ScalarValue::Decimal(lhs), ScalarValue::Decimal(rhs)) => lhs.partial_cmp(rhs),
373 (ScalarValue::Utf8(lhs), ScalarValue::Utf8(rhs)) => lhs.partial_cmp(rhs),
374 (ScalarValue::Binary(lhs), ScalarValue::Binary(rhs)) => lhs.partial_cmp(rhs),
375 (ScalarValue::Tuple(lhs), ScalarValue::Tuple(rhs)) => {
378 partial_cmp_tuple_values(dtype, lhs, rhs)
379 }
380 (ScalarValue::Variant(_), ScalarValue::Variant(_)) => None,
383 _ => None,
384 }
385}
386
387fn partial_cmp_tuple_values(
389 dtype: &DType,
390 lhs: &[Option<ScalarValue>],
391 rhs: &[Option<ScalarValue>],
392) -> Option<Ordering> {
393 match dtype {
394 DType::List(element_dtype, _) | DType::FixedSizeList(element_dtype, ..) => {
395 partial_cmp_list_values(element_dtype, lhs, rhs)
396 }
397 DType::Struct(fields, _) => partial_cmp_struct_values(fields, lhs, rhs),
398 DType::Extension(ext_dtype) => {
399 partial_cmp_tuple_values(ext_dtype.storage_dtype(), lhs, rhs)
400 }
401 _ => None,
402 }
403}
404
405fn partial_cmp_list_values(
407 element_dtype: &DType,
408 lhs: &[Option<ScalarValue>],
409 rhs: &[Option<ScalarValue>],
410) -> Option<Ordering> {
411 for (lhs, rhs) in lhs.iter().zip(rhs.iter()) {
412 match partial_cmp_scalar_values(element_dtype, lhs.as_ref(), rhs.as_ref())? {
413 Ordering::Equal => continue,
414 ordering => return Some(ordering),
415 }
416 }
417
418 Some(lhs.len().cmp(&rhs.len()))
419}
420
421fn partial_cmp_struct_values(
423 fields: &StructFields,
424 lhs: &[Option<ScalarValue>],
425 rhs: &[Option<ScalarValue>],
426) -> Option<Ordering> {
427 if lhs.len() != fields.nfields() || rhs.len() != fields.nfields() {
428 return None;
429 }
430
431 for ((field_dtype, lhs), rhs) in fields.fields().zip(lhs.iter()).zip(rhs.iter()) {
432 match partial_cmp_scalar_values(&field_dtype, lhs.as_ref(), rhs.as_ref())? {
433 Ordering::Equal => continue,
434 ordering => return Some(ordering),
435 }
436 }
437
438 Some(Ordering::Equal)
439}
440
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use rstest::rstest;

    use crate::dtype::DType;
    use crate::dtype::Nullability;
    use crate::dtype::PType;
    use crate::dtype::StructFields;
    use crate::scalar::Scalar;

    /// Builds a non-nullable `i32` scalar.
    fn i32_scalar(value: i32) -> Scalar {
        Scalar::primitive::<i32>(value, Nullability::NonNullable)
    }

    /// Builds a nullable `i32` scalar; `None` becomes a typed null.
    fn nullable_i32(value: Option<i32>) -> Scalar {
        value.map_or_else(
            || Scalar::null(DType::Primitive(PType::I32, Nullability::Nullable)),
            |v| Scalar::primitive::<i32>(v, Nullability::Nullable),
        )
    }

    /// Struct dtype with a non-nullable `i32` field `a` and a non-nullable UTF-8 field `b`.
    fn ab_struct_dtype(nullability: Nullability) -> DType {
        let field_dtypes = vec![
            DType::Primitive(PType::I32, Nullability::NonNullable),
            DType::Utf8(Nullability::NonNullable),
        ];
        DType::Struct(
            StructFields::new(["a", "b"].into(), field_dtypes),
            nullability,
        )
    }

    #[rstest]
    #[case(vec![0, 0], Some(true))]
    #[case(vec![0], Some(true))]
    #[case(vec![0, 5], Some(false))]
    #[case(vec![5, 0], Some(false))]
    #[case(vec![1, 2], Some(false))]
    fn fixed_size_list_is_zero(#[case] values: Vec<i32>, #[case] expected: Option<bool>) {
        let children = values.into_iter().map(i32_scalar).collect::<Vec<_>>();
        let scalar = Scalar::fixed_size_list(
            DType::Primitive(PType::I32, Nullability::NonNullable),
            children,
            Nullability::NonNullable,
        );
        assert_eq!(scalar.is_zero(), expected);
    }

    #[test]
    fn null_fixed_size_list_is_zero_is_none() {
        let dtype = DType::FixedSizeList(
            Arc::new(DType::Primitive(PType::I32, Nullability::NonNullable)),
            2,
            Nullability::Nullable,
        );
        assert_eq!(Scalar::null(dtype).is_zero(), None);
    }

    #[test]
    fn fixed_size_list_with_null_element_is_not_zero() {
        // A null element is never considered zero, even next to a zero element.
        let scalar = Scalar::fixed_size_list(
            DType::Primitive(PType::I32, Nullability::Nullable),
            vec![nullable_i32(Some(0)), nullable_i32(None)],
            Nullability::NonNullable,
        );
        assert_eq!(scalar.is_zero(), Some(false));
    }

    #[test]
    fn struct_with_all_zero_fields_is_zero() {
        let field_values = vec![i32_scalar(0), Scalar::utf8("", Nullability::NonNullable)];
        let scalar = Scalar::struct_(ab_struct_dtype(Nullability::NonNullable), field_values);
        assert_eq!(scalar.is_zero(), Some(true));
    }

    #[rstest]
    #[case(5, "")]
    #[case(0, "x")]
    #[case(7, "y")]
    fn struct_with_non_zero_field_is_not_zero(#[case] a: i32, #[case] b: &str) {
        let field_values = vec![i32_scalar(a), Scalar::utf8(b, Nullability::NonNullable)];
        let scalar = Scalar::struct_(ab_struct_dtype(Nullability::NonNullable), field_values);
        assert_eq!(scalar.is_zero(), Some(false));
    }

    #[test]
    fn null_struct_is_zero_is_none() {
        let null_struct = Scalar::null(ab_struct_dtype(Nullability::Nullable));
        assert_eq!(null_struct.is_zero(), None);
    }

    #[test]
    fn struct_with_null_field_is_not_zero() {
        // Both fields are nullable so that one of them can hold a null.
        let nullable_field = DType::Primitive(PType::I32, Nullability::Nullable);
        let dtype = DType::Struct(
            StructFields::new(
                ["a", "b"].into(),
                vec![nullable_field.clone(), nullable_field],
            ),
            Nullability::NonNullable,
        );
        let scalar = Scalar::struct_(dtype, vec![nullable_i32(Some(0)), nullable_i32(None)]);
        assert_eq!(scalar.is_zero(), Some(false));
    }

    #[test]
    fn nested_struct_of_fixed_size_list_recurses() {
        // struct { fsl: fixed_size_list<i32, 2> }: zero-ness must recurse into the list.
        let element_dtype = Arc::new(DType::Primitive(PType::I32, Nullability::NonNullable));
        let struct_dtype = DType::Struct(
            StructFields::new(
                ["fsl"].into(),
                vec![DType::FixedSizeList(
                    Arc::clone(&element_dtype),
                    2,
                    Nullability::NonNullable,
                )],
            ),
            Nullability::NonNullable,
        );

        let wrap_pair = |first: i32, second: i32| {
            Scalar::struct_(
                struct_dtype.clone(),
                vec![Scalar::fixed_size_list(
                    Arc::clone(&element_dtype),
                    vec![i32_scalar(first), i32_scalar(second)],
                    Nullability::NonNullable,
                )],
            )
        };

        assert_eq!(wrap_pair(0, 0).is_zero(), Some(true));
        assert_eq!(wrap_pair(0, 9).is_zero(), Some(false));
    }
}