1use std::sync::Arc;
7
8use vortex_array::ArrayRef;
9use vortex_array::IntoArray;
10use vortex_array::VortexSessionExecute;
11use vortex_array::aggregate_fn::fns::all_nan::AllNan;
12use vortex_array::aggregate_fn::fns::all_non_nan::AllNonNan;
13use vortex_array::aggregate_fn::fns::all_non_null::AllNonNull;
14use vortex_array::aggregate_fn::fns::all_null::AllNull;
15use vortex_array::aggregate_fn::fns::nan_count::NanCount;
16use vortex_array::arrays::ConstantArray;
17use vortex_array::arrays::PrimitiveArray;
18use vortex_array::arrays::StructArray;
19use vortex_array::arrays::struct_::StructArrayExt;
20use vortex_array::dtype::DType;
21use vortex_array::dtype::Nullability;
22use vortex_array::expr::Expression;
23use vortex_array::expr::eq;
24use vortex_array::expr::get_item;
25use vortex_array::expr::is_root;
26use vortex_array::expr::lit;
27use vortex_array::expr::root;
28use vortex_array::expr::stats::Stat;
29use vortex_array::expr::traversal::NodeExt;
30use vortex_array::expr::traversal::Transformed;
31use vortex_array::scalar::Scalar;
32use vortex_array::scalar_fn::EmptyOptions;
33use vortex_array::scalar_fn::ScalarFnVTableExt;
34use vortex_array::scalar_fn::fns::stat::StatFn;
35use vortex_array::scalar_fn::internal::row_count::RowCount;
36use vortex_array::scalar_fn::internal::row_count::contains_row_count;
37use vortex_array::scalar_fn::internal::row_count::substitute_row_count;
38use vortex_array::validity::Validity;
39use vortex_buffer::buffer;
40use vortex_error::VortexResult;
41use vortex_error::vortex_bail;
42use vortex_mask::Mask;
43use vortex_runend::RunEnd;
44use vortex_session::VortexSession;
45
46use crate::layouts::zoned::schema::stats_table_dtype;
47
/// Per-zone statistics for a single column.
///
/// The statistics live in a struct array (the "stats table") with one row per zone and one
/// field per present statistic. Predicates over those statistics are evaluated by
/// [`ZoneMap::prune`].
#[derive(Clone)]
pub struct ZoneMap {
    /// DType of the column the statistics describe; stat-function inputs are resolved against it.
    column_dtype: DType,
    /// The stats table: one row per zone, one field per statistic (looked up by `Stat::name`).
    array: StructArray,
    /// Number of rows in every zone except possibly the last one.
    zone_len: u64,
    /// Total number of rows covered by all zones; determines the length of the final zone.
    row_count: u64,
}
63
64impl ZoneMap {
65 pub fn try_new(
68 column_dtype: DType,
69 array: StructArray,
70 stats: Arc<[Stat]>,
71 zone_len: u64,
72 row_count: u64,
73 ) -> VortexResult<Self> {
74 let expected_dtype = stats_table_dtype(&column_dtype, &stats);
75 if &expected_dtype != array.dtype() {
76 vortex_bail!("Array dtype does not match expected zone map dtype: {expected_dtype}");
77 }
78
79 Ok(unsafe { Self::new_unchecked(column_dtype, array, zone_len, row_count) })
81 }
82
83 pub(super) unsafe fn new_unchecked(
84 column_dtype: DType,
85 array: StructArray,
86 zone_len: u64,
87 row_count: u64,
88 ) -> Self {
89 Self {
90 column_dtype,
91 array,
92 zone_len,
93 row_count,
94 }
95 }
96
97 #[deprecated(note = "zone-map stats table dtypes are an internal layout detail")]
101 pub fn dtype_for_stats_table(column_dtype: &DType, present_stats: &[Stat]) -> DType {
102 stats_table_dtype(column_dtype, present_stats)
103 }
104
105 pub fn prune(&self, predicate: &Expression, session: &VortexSession) -> VortexResult<Mask> {
119 let mut ctx = session.create_execution_ctx();
120 let num_zones = self.array.len();
121 let predicate = self.lower_stats(predicate.clone())?;
122
123 let applied = self.array.clone().into_array().apply(&predicate)?;
124
125 if !contains_row_count(&applied) {
126 return applied.execute::<Mask>(&mut ctx);
127 }
128
129 let row_count_array = row_count_array(self.zone_len, self.row_count, num_zones)?;
130 let substituted = substitute_row_count(applied, &row_count_array)?;
131 substituted.execute::<Mask>(&mut ctx)
132 }
133
134 fn lower_stats(&self, predicate: Expression) -> VortexResult<Expression> {
135 predicate
139 .transform_down(|expr| {
140 if expr.is::<StatFn>() {
141 return self.lower_stat_fn(expr).map(Transformed::yes);
142 }
143
144 Ok(Transformed::no(expr))
145 })
146 .map(Transformed::into_inner)
147 }
148
149 fn lower_stat_fn(&self, expr: Expression) -> VortexResult<Expression> {
150 let options = expr.as_::<StatFn>();
154 let input = expr.child(0);
155 let input_dtype = input.return_dtype(&self.column_dtype)?;
156 let input_is_root = is_root(input);
157
158 if options.aggregate_fn().is::<AllNan>() {
159 if !has_nans(&input_dtype) {
160 return Ok(lit(false));
161 }
162 if !input_is_root {
163 return Ok(null_expr(DType::Bool(Nullability::NonNullable)));
164 }
165 return Ok(eq(self.stat_field_expr(Stat::NaNCount)?, row_count_expr()));
166 }
167
168 if options.aggregate_fn().is::<AllNonNan>() {
169 if !has_nans(&input_dtype) {
170 return Ok(lit(true));
171 }
172 if !input_is_root {
173 return Ok(null_expr(DType::Bool(Nullability::NonNullable)));
174 }
175 return Ok(eq(self.stat_field_expr(Stat::NaNCount)?, lit(0u64)));
176 }
177
178 if options.aggregate_fn().is::<NanCount>() && !has_nans(&input_dtype) {
179 return Ok(lit(0u64));
180 }
181
182 let return_dtype = match options.aggregate_fn().return_dtype(&input_dtype) {
183 Some(return_dtype) => return_dtype,
184 None => vortex_bail!(
185 "Aggregate function {} does not support input dtype {}",
186 options.aggregate_fn(),
187 input_dtype
188 ),
189 };
190
191 if !input_is_root {
192 return Ok(null_expr(return_dtype));
193 }
194
195 if options.aggregate_fn().is::<AllNull>() {
196 return Ok(eq(self.stat_field_expr(Stat::NullCount)?, row_count_expr()));
197 }
198
199 if options.aggregate_fn().is::<AllNonNull>() {
200 return Ok(eq(self.stat_field_expr(Stat::NullCount)?, lit(0u64)));
201 }
202
203 let Some(stat) = Stat::from_aggregate_fn(options.aggregate_fn()) else {
204 return Ok(null_expr(return_dtype));
205 };
206
207 self.stat_field_expr(stat)
208 }
209
210 fn stat_field_expr(&self, stat: Stat) -> VortexResult<Expression> {
211 if self.array.unmasked_field_by_name_opt(stat.name()).is_some() {
212 return Ok(get_item(stat.name(), root()));
213 }
214
215 let Some(dtype) = stat.dtype(&self.column_dtype) else {
216 vortex_bail!(
217 "Stat {} does not support column dtype {}",
218 stat,
219 self.column_dtype
220 );
221 };
222 Ok(null_expr(dtype))
223 }
224}
225
/// Builds a `RowCount` placeholder expression with no children.
///
/// `ZoneMap::prune` later replaces it with the per-zone row counts via `substitute_row_count`
/// whenever the applied predicate still contains it.
fn row_count_expr() -> Expression {
    RowCount.new_expr(EmptyOptions, [])
}
229
230fn null_expr(dtype: DType) -> Expression {
231 lit(Scalar::null(dtype.as_nullable()))
232}
233
234fn has_nans(dtype: &DType) -> bool {
235 matches!(dtype, DType::Primitive(ptype, _) if ptype.is_float())
236}
237
238fn row_count_array(zone_len: u64, row_count: u64, num_zones: usize) -> VortexResult<ArrayRef> {
244 if num_zones == 0 {
245 return Ok(ConstantArray::new(0u64, 0).into_array());
246 }
247
248 let last_zone_len = row_count - zone_len.saturating_mul((num_zones as u64) - 1);
249 if num_zones == 1 || last_zone_len == zone_len {
250 return Ok(ConstantArray::new(last_zone_len, num_zones).into_array());
251 }
252
253 let ends = unsafe {
254 PrimitiveArray::new_unchecked(
255 buffer![num_zones as u64 - 1, num_zones as u64],
256 Validity::NonNullable,
257 )
258 }
259 .into_array();
260 let values = unsafe {
261 PrimitiveArray::new_unchecked(buffer![zone_len, last_zone_len], Validity::NonNullable)
262 }
263 .into_array();
264
265 Ok(unsafe { RunEnd::new_unchecked(ends, values, 0, num_zones) }.into_array())
268}
269
/// Tests for `ZoneMap::prune` and the stat-function lowering it performs.
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use vortex_array::IntoArray;
    use vortex_array::arrays::BoolArray;
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::arrays::StructArray;
    use vortex_array::assert_arrays_eq;
    use vortex_array::dtype::DType;
    use vortex_array::dtype::DecimalDType;
    use vortex_array::dtype::FieldNames;
    use vortex_array::dtype::Nullability;
    use vortex_array::dtype::PType;
    use vortex_array::expr::Expression;
    use vortex_array::expr::cast;
    use vortex_array::expr::gt;
    use vortex_array::expr::gt_eq;
    use vortex_array::expr::is_not_null;
    use vortex_array::expr::is_null;
    use vortex_array::expr::lit;
    use vortex_array::expr::lt;
    use vortex_array::expr::not_eq;
    use vortex_array::expr::root;
    use vortex_array::expr::stats::Stat;
    use vortex_array::stats::all_nan;
    use vortex_array::stats::all_non_nan;
    use vortex_array::stats::all_non_null;
    use vortex_array::stats::all_null;
    use vortex_array::validity::Validity;
    use vortex_buffer::buffer;

    use crate::layouts::zoned::zone_map::ZoneMap;
    use crate::test::SESSION;

    /// Falsifies `expr` for a column of `dtype` and unwraps the result.
    ///
    /// Panics if falsification errors or yields `None` (presumably: no falsifying predicate
    /// could be derived).
    fn falsify(expr: &Expression, dtype: DType) -> Expression {
        expr.falsify(&dtype, &SESSION).unwrap().unwrap()
    }

    #[test]
    fn test_zone_map_prunes() {
        // Three I32 zones with min [1, 2, 3] and max [5, 6, 7]; the *_is_truncated flags are
        // all false.
        // NOTE(review): `zone_len = 3` with `row_count = 10` implies a 4-row trailing zone
        // (10 - 2 * 3), longer than `zone_len`. Presumably harmless because these min/max
        // predicates never need the row count, but 9 would be the consistent value.
        let zone_map = ZoneMap::try_new(
            PType::I32.into(),
            StructArray::from_fields(&[
                (
                    "max",
                    PrimitiveArray::new(buffer![5i32, 6i32, 7i32], Validity::AllValid).into_array(),
                ),
                (
                    "max_is_truncated",
                    BoolArray::from_iter([false, false, false]).into_array(),
                ),
                (
                    "min",
                    PrimitiveArray::new(buffer![1i32, 2i32, 3i32], Validity::AllValid).into_array(),
                ),
                (
                    "min_is_truncated",
                    BoolArray::from_iter([false, false, false]).into_array(),
                ),
            ])
            .unwrap(),
            Arc::new([Stat::Max, Stat::Min]),
            3,
            10,
        )
        .unwrap();

        // `root >= 6`: only zone 0 (max 5) is expected to be pruned.
        let expr = gt_eq(root(), lit(6i32));
        let pruning_expr = falsify(&expr, PType::I32.into());
        let mask = zone_map.prune(&pruning_expr, &SESSION).unwrap();
        assert_arrays_eq!(
            mask.into_array(),
            BoolArray::from_iter([true, false, false])
        );

        // `root > 5`: again only zone 0 (max 5) is expected to be pruned.
        let expr = gt(root(), lit(5i32));
        let pruning_expr = falsify(&expr, PType::I32.into());
        let mask = zone_map.prune(&pruning_expr, &SESSION).unwrap();
        assert_arrays_eq!(
            mask.into_array(),
            BoolArray::from_iter([true, false, false])
        );

        // `root < 2`: zones 1 and 2 (min 2 and 3) are expected to be pruned; zone 0 (min 1)
        // is kept.
        let expr = lt(root(), lit(2i32));
        let pruning_expr = falsify(&expr, PType::I32.into());
        let mask = zone_map.prune(&pruning_expr, &SESSION).unwrap();
        assert_arrays_eq!(mask.into_array(), BoolArray::from_iter([false, true, true]));
    }

    #[test]
    fn row_count_prunes_short_trailing_zone() {
        // 10 rows in zones of 4 gives zone lengths [4, 4, 2]. The per-zone row counts are
        // therefore a two-run array rather than a constant.
        let zone_map = ZoneMap::try_new(
            PType::U64.into(),
            StructArray::from_fields(&[(
                "null_count",
                PrimitiveArray::new(buffer![0u64, 0, 2], Validity::AllValid).into_array(),
            )])
            .unwrap(),
            Arc::new([Stat::NullCount]),
            4,
            10,
        )
        .unwrap();

        // Only the 2-row trailing zone is entirely null (null_count 2 of 2 rows), so only it is
        // expected to be pruned for `is_not_null`.
        let expr = is_not_null(root());
        let pruning_expr = falsify(&expr, PType::U64.into());

        let mask = zone_map.prune(&pruning_expr, &SESSION).unwrap();
        assert_arrays_eq!(
            mask.into_array(),
            BoolArray::from_iter([false, false, true])
        );
    }

    #[test]
    fn row_count_substitution_handles_empty_zone_map() {
        // Zero zones and zero rows. Presumably the falsified predicate references the row
        // count, exercising the zero-zone path of the row-count array.
        let zone_map = ZoneMap::try_new(
            PType::U64.into(),
            StructArray::from_fields(&[(
                "null_count",
                PrimitiveArray::new::<u64>(buffer![], Validity::AllValid).into_array(),
            )])
            .unwrap(),
            Arc::new([Stat::NullCount]),
            4,
            0,
        )
        .unwrap();

        let expr = is_not_null(root());
        let pruning_expr = falsify(&expr, PType::U64.into());

        let mask = zone_map.prune(&pruning_expr, &SESSION).unwrap();
        assert_eq!(mask.len(), 0);
    }

    #[test]
    fn all_null_stat_fn_lowers_to_null_count_and_row_count() {
        // Zone lengths are [4, 4, 2]. `all_null` lowers to `null_count == row_count`, which
        // holds for zone 1 (4 of 4 null) and zone 2 (2 of 2 null).
        let zone_map = ZoneMap::try_new(
            PType::U64.into(),
            StructArray::from_fields(&[(
                "null_count",
                PrimitiveArray::new(buffer![0u64, 4, 2], Validity::AllValid).into_array(),
            )])
            .unwrap(),
            Arc::new([Stat::NullCount]),
            4,
            10,
        )
        .unwrap();

        let mask = zone_map.prune(&all_null(root()), &SESSION).unwrap();
        assert_arrays_eq!(mask.into_array(), BoolArray::from_iter([false, true, true]));
    }

    #[test]
    fn all_non_null_stat_fn_lowers_to_null_count() {
        // `all_non_null` lowers to `null_count == 0`, which holds only for zone 0.
        let zone_map = ZoneMap::try_new(
            PType::U64.into(),
            StructArray::from_fields(&[(
                "null_count",
                PrimitiveArray::new(buffer![0u64, 4, 2], Validity::AllValid).into_array(),
            )])
            .unwrap(),
            Arc::new([Stat::NullCount]),
            4,
            10,
        )
        .unwrap();

        let mask = zone_map.prune(&all_non_null(root()), &SESSION).unwrap();
        assert_arrays_eq!(
            mask.into_array(),
            BoolArray::from_iter([true, false, false])
        );
    }

    #[test]
    fn non_float_nan_stat_fns_lower_to_constants() {
        // I32 values cannot be NaN, so `all_nan` lowers to `false` and `all_non_nan` to `true`
        // without consulting the (empty) stats table.
        let zone_map = ZoneMap::try_new(
            PType::I32.into(),
            StructArray::try_new(FieldNames::empty(), vec![], 2, Validity::NonNullable).unwrap(),
            Arc::new([]),
            4,
            8,
        )
        .unwrap();

        let mask = zone_map.prune(&all_nan(root()), &SESSION).unwrap();
        assert_arrays_eq!(mask.into_array(), BoolArray::from_iter([false, false]));

        let mask = zone_map.prune(&all_non_nan(root()), &SESSION).unwrap();
        assert_arrays_eq!(mask.into_array(), BoolArray::from_iter([true, true]));
    }

    #[test]
    fn unavailable_stat_fn_lowers_to_unknown_mask() {
        // The stats table has no fields, so missing statistics lower to null literals. A null
        // predicate result is expected to mean "cannot prune" (all-false masks).
        let zone_map = ZoneMap::try_new(
            PType::U64.into(),
            StructArray::try_new(FieldNames::empty(), vec![], 3, Validity::NonNullable).unwrap(),
            Arc::new([]),
            4,
            10,
        )
        .unwrap();

        let mask = zone_map.prune(&all_non_null(root()), &SESSION).unwrap();
        assert_arrays_eq!(
            mask.into_array(),
            BoolArray::from_iter([false, false, false])
        );

        let expr = gt(root(), lit(5u64));
        let pruning_expr = falsify(&expr, PType::U64.into());
        let mask = zone_map.prune(&pruning_expr, &SESSION).unwrap();
        assert_arrays_eq!(
            mask.into_array(),
            BoolArray::from_iter([false, false, false])
        );
    }

    #[test]
    fn float_min_max_stat_fn_requires_nan_count() {
        // F32 zones with a max field but no nan_count. Nothing is expected to be pruned for
        // `root > 5.0`, presumably because any zone could be hiding a NaN.
        let zone_map = ZoneMap::try_new(
            PType::F32.into(),
            StructArray::from_fields(&[
                (
                    "max",
                    PrimitiveArray::new(buffer![5.0f32, 6.0, 7.0], Validity::AllValid).into_array(),
                ),
                (
                    "max_is_truncated",
                    BoolArray::from_iter([false, false, false]).into_array(),
                ),
            ])
            .unwrap(),
            Arc::new([Stat::Max]),
            4,
            12,
        )
        .unwrap();

        let expr = gt(root(), lit(5.0f32));
        let pruning_expr = falsify(&expr, PType::F32.into());
        let mask = zone_map.prune(&pruning_expr, &SESSION).unwrap();
        assert_arrays_eq!(
            mask.into_array(),
            BoolArray::from_iter([false, false, false])
        );

        // Same max values plus an all-zero nan_count: zone 0 (max 5.0) can now be pruned.
        let zone_map = ZoneMap::try_new(
            PType::F32.into(),
            StructArray::from_fields(&[
                (
                    "max",
                    PrimitiveArray::new(buffer![5.0f32, 6.0, 7.0], Validity::AllValid).into_array(),
                ),
                (
                    "max_is_truncated",
                    BoolArray::from_iter([false, false, false]).into_array(),
                ),
                (
                    "nan_count",
                    PrimitiveArray::new(buffer![0u64, 0, 0], Validity::AllValid).into_array(),
                ),
            ])
            .unwrap(),
            Arc::new([Stat::Max, Stat::NaNCount]),
            4,
            12,
        )
        .unwrap();

        let mask = zone_map.prune(&pruning_expr, &SESSION).unwrap();
        assert_arrays_eq!(
            mask.into_array(),
            BoolArray::from_iter([true, false, false])
        );
    }

    #[test]
    fn float_cast_min_max_stat_fn_uses_source_nan_count() {
        // Both zones have min == max == 5.0, but zone 0 contains one NaN.
        let zone_map = ZoneMap::try_new(
            PType::F32.into(),
            StructArray::from_fields(&[
                (
                    "max",
                    PrimitiveArray::new(buffer![5.0f32, 5.0], Validity::AllValid).into_array(),
                ),
                (
                    "max_is_truncated",
                    BoolArray::from_iter([false, false]).into_array(),
                ),
                (
                    "min",
                    PrimitiveArray::new(buffer![5.0f32, 5.0], Validity::AllValid).into_array(),
                ),
                (
                    "min_is_truncated",
                    BoolArray::from_iter([false, false]).into_array(),
                ),
                (
                    "nan_count",
                    PrimitiveArray::new(buffer![1u64, 0], Validity::AllValid).into_array(),
                ),
            ])
            .unwrap(),
            Arc::new([Stat::Max, Stat::Min, Stat::NaNCount]),
            4,
            8,
        )
        .unwrap();

        // `cast(root, i32) != 5` is expected to be falsifiable only for zone 1, which has no NaN.
        let cast_dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
        let expr = not_eq(cast(root(), cast_dtype), lit(5i32));
        let pruning_expr = falsify(&expr, PType::F32.into());

        let mask = zone_map.prune(&pruning_expr, &SESSION).unwrap();
        assert_arrays_eq!(mask.into_array(), BoolArray::from_iter([false, true]));
    }

    #[test]
    fn fixed_size_list_min_max_stat_fn_lowers_to_unknown_mask() {
        // FixedSizeList<Decimal(10, 2)> column with an empty stats table.
        let elem_dtype = Arc::new(DType::Decimal(
            DecimalDType::new(10, 2),
            Nullability::Nullable,
        ));
        let column_dtype = DType::FixedSizeList(elem_dtype, 1, Nullability::Nullable);

        let zone_map = ZoneMap::try_new(
            column_dtype,
            StructArray::try_new(FieldNames::empty(), vec![], 3, Validity::NonNullable).unwrap(),
            Arc::new([]),
            4,
            10,
        )
        .unwrap();

        let max_fn = Stat::Max
            .aggregate_fn()
            .expect("max should have an aggregate function");
        let predicate = is_null(vortex_array::stats::stat(root(), max_fn));

        // With no max field, `max` is expected to lower to a null literal, so `is_null(...)`
        // holds for every zone.
        let mask = zone_map.prune(&predicate, &SESSION).unwrap();
        assert_arrays_eq!(mask.into_array(), BoolArray::from_iter([true, true, true]));
    }

    #[test]
    fn unsupported_aggregate_input_dtype_errors() {
        // `max` has no return dtype for a `Null` input. Lowering must fail with a descriptive
        // error rather than produce a predicate.
        let zone_map = ZoneMap::try_new(
            DType::Null,
            StructArray::try_new(FieldNames::empty(), vec![], 3, Validity::NonNullable).unwrap(),
            Arc::new([]),
            4,
            10,
        )
        .unwrap();

        let max_fn = Stat::Max
            .aggregate_fn()
            .expect("max should have an aggregate function");
        let predicate = is_null(vortex_array::stats::stat(root(), max_fn));
        let error = zone_map.prune(&predicate, &SESSION).unwrap_err();

        assert!(
            error
                .to_string()
                .contains("Aggregate function vortex.max() does not support input dtype null"),
            "{error}"
        );
    }

    #[test]
    fn row_count_prunes_all_null_uniform_zones() {
        // 12 rows in zones of 4: every zone has 4 rows, so the row counts form a constant array.
        let zone_map = ZoneMap::try_new(
            PType::U64.into(),
            StructArray::from_fields(&[(
                "null_count",
                PrimitiveArray::new(buffer![0u64, 4, 0], Validity::AllValid).into_array(),
            )])
            .unwrap(),
            Arc::new([Stat::NullCount]),
            4,
            12,
        )
        .unwrap();

        let expr = is_not_null(root());
        let pruning_expr = falsify(&expr, PType::U64.into());

        // Only zone 1 (4 of 4 rows null) is entirely null and therefore expected to be pruned.
        let mask = zone_map.prune(&pruning_expr, &SESSION).unwrap();
        assert_arrays_eq!(
            mask.into_array(),
            BoolArray::from_iter([false, true, false])
        );
    }
}