1use std::fmt::{Display, Formatter};
5use std::hash::Hash;
6use std::sync::Arc;
7
8use itertools::Itertools as _;
9use vortex_dtype::{DType, Nullability};
10use vortex_error::{
11 VortexError, VortexExpect as _, VortexResult, vortex_bail, vortex_err, vortex_panic,
12};
13
14use crate::{InnerScalarValue, Scalar, ScalarValue};
15
/// A borrowed, typed view over a list-valued [`Scalar`] (variable-size or fixed-size list).
///
/// Obtained through `TryFrom<&Scalar>`, which rejects scalars whose dtype has no list
/// element type.
#[derive(Debug, Clone)]
pub struct ListScalar<'a> {
    /// The list dtype of the underlying scalar, borrowed from it.
    dtype: &'a DType,
    /// The element dtype extracted from `dtype` when this view was created.
    element_dtype: &'a Arc<DType>,
    /// The raw element values, or `None` when the list scalar is null.
    elements: Option<Arc<[ScalarValue]>>,
}
32
33impl Display for ListScalar<'_> {
34 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
35 match &self.elements {
36 None => write!(f, "null"),
37 Some(elems) => {
38 let fixed_size_list_str: &dyn Display =
39 if let DType::FixedSizeList(_, size, _) = self.dtype {
40 &format!("fixed_size<{size}>")
41 } else {
42 &""
43 };
44
45 write!(
46 f,
47 "{fixed_size_list_str}[{}]",
48 elems
49 .iter()
50 .map(|e| Scalar::new(self.element_dtype().clone(), e.clone()))
51 .format(", ")
52 )
53 }
54 }
55 }
56}
57
58impl PartialEq for ListScalar<'_> {
59 fn eq(&self, other: &Self) -> bool {
60 self.dtype.eq_ignore_nullability(other.dtype) && self.elements() == other.elements()
61 }
62}
63
64impl Eq for ListScalar<'_> {}
65
66impl PartialOrd for ListScalar<'_> {
68 fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
69 if !self
70 .element_dtype
71 .eq_ignore_nullability(other.element_dtype)
72 {
73 return None;
74 }
75 self.elements().partial_cmp(&other.elements())
76 }
77}
78
79impl Hash for ListScalar<'_> {
80 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
81 self.dtype.hash(state);
82 self.elements().hash(state);
83 }
84}
85
86impl<'a> ListScalar<'a> {
87 #[inline]
89 pub fn dtype(&self) -> &'a DType {
90 self.dtype
91 }
92
93 #[inline]
97 pub fn len(&self) -> usize {
98 self.elements.as_ref().map(|e| e.len()).unwrap_or(0)
99 }
100
101 #[inline]
103 pub fn is_empty(&self) -> bool {
104 match self.elements.as_ref() {
105 None => true,
106 Some(l) => l.is_empty(),
107 }
108 }
109
110 #[inline]
112 pub fn is_null(&self) -> bool {
113 self.elements.is_none()
114 }
115
116 pub fn element_dtype(&self) -> &DType {
118 self.dtype
119 .as_any_size_list_element_opt()
120 .unwrap_or_else(|| vortex_panic!("`ListScalar` somehow had dtype {}", self.dtype))
121 .as_ref()
122 }
123
124 pub fn element(&self, idx: usize) -> Option<Scalar> {
128 self.elements
129 .as_ref()
130 .and_then(|l| l.get(idx))
131 .map(|value| Scalar::new(self.element_dtype().clone(), value.clone()))
132 }
133
134 pub fn elements(&self) -> Option<Vec<Scalar>> {
138 self.elements.as_ref().map(|elems| {
139 elems
140 .iter()
141 .map(|e| Scalar::new(self.element_dtype().clone(), e.clone()))
142 .collect_vec()
143 })
144 }
145
146 pub(crate) fn cast(&self, dtype: &DType) -> VortexResult<Scalar> {
156 let target_element_dtype = dtype
157 .as_any_size_list_element_opt()
158 .ok_or_else(|| {
159 vortex_err!(
160 "Cannot cast {} to {}: list can only be cast to a list or fixed-size list",
161 self.dtype(),
162 dtype
163 )
164 })?
165 .as_ref();
166
167 if let DType::FixedSizeList(_, size, _) = dtype
168 && *size as usize != self.len()
169 {
170 vortex_bail!(
171 "tried to cast to a `FixedSizeList[{size}]` but had {} elements",
172 self.len()
173 )
174 }
175
176 Ok(Scalar::new(
177 dtype.clone(),
178 ScalarValue(InnerScalarValue::List(
179 self.elements
180 .as_ref()
181 .vortex_expect("nullness handled in Scalar::cast")
182 .iter()
183 .map(|element| {
184 Scalar::new(DType::clone(self.element_dtype), element.clone())
186 .cast(target_element_dtype)
187 .map(|x| x.into_value())
188 })
189 .collect::<VortexResult<Arc<[ScalarValue]>>>()?,
190 )),
191 ))
192 }
193}
194
/// Chooses which list dtype `Scalar::create_list` builds.
enum ListKind {
    /// Build a variable-size `DType::List`.
    Variable,
    /// Build a `DType::FixedSizeList` whose size is the number of children.
    FixedSize,
}
200
201impl Scalar {
203 fn create_list(
204 element_dtype: impl Into<Arc<DType>>,
205 children: Vec<Scalar>,
206 nullability: Nullability,
207 list_kind: ListKind,
208 ) -> Self {
209 let element_dtype = element_dtype.into();
210
211 let children: Arc<[ScalarValue]> = children
212 .into_iter()
213 .map(|child| {
214 if child.dtype() != &*element_dtype {
215 vortex_panic!(
216 "tried to create list of {} with values of type {}",
217 element_dtype,
218 child.dtype()
219 );
220 }
221 child.into_value()
222 })
223 .collect();
224 let size: u32 = children
225 .len()
226 .try_into()
227 .vortex_expect("tried to create a list that was too large");
228
229 let dtype = match list_kind {
230 ListKind::Variable => DType::List(element_dtype, nullability),
231 ListKind::FixedSize => DType::FixedSizeList(element_dtype, size, nullability),
232 };
233
234 Self::new(dtype, ScalarValue(InnerScalarValue::List(children)))
235 }
236
237 pub fn list(
244 element_dtype: impl Into<Arc<DType>>,
245 children: Vec<Scalar>,
246 nullability: Nullability,
247 ) -> Self {
248 Self::create_list(element_dtype, children, nullability, ListKind::Variable)
249 }
250
251 pub fn list_empty(element_dtype: Arc<DType>, nullability: Nullability) -> Self {
253 Self::create_list(element_dtype, vec![], nullability, ListKind::Variable)
254 }
255
256 pub fn fixed_size_list(
263 element_dtype: impl Into<Arc<DType>>,
264 children: Vec<Scalar>,
265 nullability: Nullability,
266 ) -> Self {
267 Self::create_list(element_dtype, children, nullability, ListKind::FixedSize)
268 }
269}
270
271impl<'a> TryFrom<&'a Scalar> for ListScalar<'a> {
272 type Error = VortexError;
273
274 fn try_from(value: &'a Scalar) -> Result<Self, Self::Error> {
275 let element_dtype = value
276 .dtype()
277 .as_any_size_list_element_opt()
278 .ok_or_else(|| vortex_err!("Expected list scalar, found {}", value.dtype()))?;
279
280 Ok(Self {
281 dtype: value.dtype(),
282 element_dtype,
283 elements: value.value().as_list()?.cloned(),
284 })
285 }
286}
287
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use vortex_dtype::{DType, Nullability, PType};

    use super::*;

    /// Non-nullable `i32` dtype, used as the element type by most tests.
    fn i32_dtype() -> Arc<DType> {
        Arc::new(DType::Primitive(PType::I32, Nullability::NonNullable))
    }

    /// One non-nullable `i32` scalar per input value.
    fn i32_scalars(values: &[i32]) -> Vec<Scalar> {
        values
            .iter()
            .map(|&v| Scalar::primitive(v, Nullability::NonNullable))
            .collect()
    }

    /// A non-nullable variable-size list of non-nullable `i32` values.
    fn i32_list(values: &[i32]) -> Scalar {
        Scalar::list(i32_dtype(), i32_scalars(values), Nullability::NonNullable)
    }

    /// Reads element `idx` of `list` as an `i32`, panicking if it is missing.
    fn i32_at(list: &ListScalar<'_>, idx: usize) -> i32 {
        list.element(idx)
            .unwrap()
            .as_primitive()
            .typed_value::<i32>()
            .unwrap()
    }

    #[test]
    fn test_list_scalar_creation() {
        let list_scalar = i32_list(&[1, 2, 3]);
        let list = ListScalar::try_from(&list_scalar).unwrap();

        assert_eq!(list.len(), 3);
        assert!(!list.is_empty());
        assert!(!list.is_null());
    }

    #[test]
    fn test_empty_list() {
        let list_scalar = Scalar::list_empty(i32_dtype(), Nullability::NonNullable);
        let list = ListScalar::try_from(&list_scalar).unwrap();

        assert_eq!(list.len(), 0);
        assert!(list.is_empty());
        assert!(!list.is_null());
    }

    #[test]
    fn test_null_list() {
        let element_dtype = Arc::new(DType::Primitive(PType::I32, Nullability::Nullable));
        let null = Scalar::null(DType::List(element_dtype, Nullability::Nullable));
        let list = ListScalar::try_from(&null).unwrap();

        assert!(list.is_empty());
        assert!(list.is_null());
    }

    #[test]
    fn test_list_element_access() {
        let list_scalar = i32_list(&[10, 20, 30]);
        let list = ListScalar::try_from(&list_scalar).unwrap();

        for (idx, expected) in [10, 20, 30].into_iter().enumerate() {
            assert_eq!(i32_at(&list, idx), expected);
        }
        assert!(list.element(3).is_none());
    }

    #[test]
    fn test_list_elements() {
        let list_scalar = i32_list(&[100, 200]);
        let list = ListScalar::try_from(&list_scalar).unwrap();
        let elements = list.elements().unwrap();

        assert_eq!(elements.len(), 2);
        let values: Vec<i32> = elements
            .iter()
            .map(|e| e.as_primitive().typed_value::<i32>().unwrap())
            .collect();
        assert_eq!(values, vec![100, 200]);
    }

    #[test]
    fn test_list_display() {
        let list_scalar = i32_list(&[1, 2]);
        let list = ListScalar::try_from(&list_scalar).unwrap();

        let display = list.to_string();
        assert!(display.contains('1'));
        assert!(display.contains('2'));
    }

    #[test]
    fn test_list_equality() {
        let lhs_scalar = i32_list(&[1, 2]);
        let rhs_scalar = i32_list(&[1, 2]);

        assert_eq!(
            ListScalar::try_from(&lhs_scalar).unwrap(),
            ListScalar::try_from(&rhs_scalar).unwrap()
        );
    }

    #[test]
    fn test_list_inequality() {
        let lhs_scalar = i32_list(&[1, 2]);
        let rhs_scalar = i32_list(&[1, 3]);

        assert_ne!(
            ListScalar::try_from(&lhs_scalar).unwrap(),
            ListScalar::try_from(&rhs_scalar).unwrap()
        );
    }

    #[test]
    fn test_list_partial_ord() {
        let smaller_scalar = i32_list(&[1]);
        let larger_scalar = i32_list(&[2]);

        let smaller = ListScalar::try_from(&smaller_scalar).unwrap();
        let larger = ListScalar::try_from(&larger_scalar).unwrap();
        assert!(smaller < larger);
    }

    #[test]
    fn test_list_partial_ord_different_types() {
        let i32_scalar = i32_list(&[1]);
        let i64_scalar = Scalar::list(
            Arc::new(DType::Primitive(PType::I64, Nullability::NonNullable)),
            vec![Scalar::primitive(1i64, Nullability::NonNullable)],
            Nullability::NonNullable,
        );

        let lhs = ListScalar::try_from(&i32_scalar).unwrap();
        let rhs = ListScalar::try_from(&i64_scalar).unwrap();
        assert!(lhs.partial_cmp(&rhs).is_none());
    }

    #[test]
    fn test_list_hash() {
        use std::collections::hash_map::DefaultHasher;
        use std::hash::{Hash, Hasher};

        let list_scalar = i32_list(&[1, 2]);
        let list = ListScalar::try_from(&list_scalar).unwrap();

        // Hashing the same value twice with fresh hashers must be deterministic.
        let hash_once = || {
            let mut hasher = DefaultHasher::new();
            list.hash(&mut hasher);
            hasher.finish()
        };
        assert_eq!(hash_once(), hash_once());
    }

    #[test]
    fn test_vec_conversion() {
        let list_scalar = i32_list(&[10, 20, 30]);

        let vec: Vec<i32> = Vec::try_from(&list_scalar).unwrap();
        assert_eq!(vec, vec![10, 20, 30]);
    }

    #[test]
    fn test_vec_conversion_empty_list() {
        let element_dtype = Arc::new(DType::Primitive(PType::I32, Nullability::Nullable));
        let list_scalar = Scalar::list_empty(element_dtype, Nullability::Nullable);

        let result: VortexResult<Vec<i32>> = Vec::try_from(&list_scalar);
        assert!(result.unwrap().is_empty());
    }

    #[test]
    fn test_list_cast() {
        let list_scalar = i32_list(&[1, 2]);
        let list = ListScalar::try_from(&list_scalar).unwrap();

        let target_dtype = DType::List(
            Arc::new(DType::Primitive(PType::I64, Nullability::NonNullable)),
            Nullability::NonNullable,
        );
        let casted = list.cast(&target_dtype).unwrap();
        let casted_list = ListScalar::try_from(&casted).unwrap();

        assert_eq!(casted_list.len(), 2);
        let first = casted_list.element(0).unwrap();
        assert_eq!(first.as_primitive().typed_value::<i64>().unwrap(), 1);
    }

    #[test]
    #[should_panic(expected = "tried to create list of i32 with values of type i64")]
    fn test_list_wrong_element_type_panic() {
        Scalar::list(
            i32_dtype(),
            vec![Scalar::primitive(1i64, Nullability::NonNullable)],
            Nullability::NonNullable,
        );
    }

    #[test]
    fn test_try_from_wrong_dtype() {
        let scalar = Scalar::primitive(42i32, Nullability::NonNullable);
        assert!(ListScalar::try_from(&scalar).is_err());
    }

    #[test]
    fn test_string_list() {
        let words = ["hello", "world"];
        let children = words
            .into_iter()
            .map(|word| Scalar::utf8(word.to_string(), Nullability::NonNullable))
            .collect();
        let list_scalar = Scalar::list(
            Arc::new(DType::Utf8(Nullability::NonNullable)),
            children,
            Nullability::NonNullable,
        );

        let list = ListScalar::try_from(&list_scalar).unwrap();
        assert_eq!(list.len(), 2);

        for (idx, word) in words.into_iter().enumerate() {
            let elem = list.element(idx).unwrap();
            assert_eq!(elem.as_utf8().value().unwrap().as_str(), word);
        }
    }

    #[test]
    fn test_nested_lists() {
        let inner_list_dtype = Arc::new(DType::List(i32_dtype(), Nullability::NonNullable));
        let outer_list = Scalar::list(
            inner_list_dtype,
            vec![i32_list(&[1, 2]), i32_list(&[3, 4])],
            Nullability::NonNullable,
        );

        let list = ListScalar::try_from(&outer_list).unwrap();
        assert_eq!(list.len(), 2);

        let nested_list = list.element(0).unwrap();
        let nested = ListScalar::try_from(&nested_list).unwrap();
        assert_eq!(nested.len(), 2);
    }
}