vortex_array/compute/
slice.rs1use vortex_error::{VortexExpect, VortexResult, vortex_bail, vortex_err};
2
3use crate::encoding::Encoding;
4use crate::stats::{Precision, Stat, StatsProviderExt, StatsSet};
5use crate::{Array, ArrayRef, Canonical, IntoArray};
6
7pub trait SliceFn<A> {
9 fn slice(&self, array: A, start: usize, stop: usize) -> VortexResult<ArrayRef>;
13}
14
15impl<E: Encoding> SliceFn<&dyn Array> for E
16where
17 E: for<'a> SliceFn<&'a E::Array>,
18{
19 fn slice(&self, array: &dyn Array, start: usize, stop: usize) -> VortexResult<ArrayRef> {
20 let array_ref = array
21 .as_any()
22 .downcast_ref::<E::Array>()
23 .vortex_expect("Failed to downcast array");
24 SliceFn::slice(self, array_ref, start, stop)
25 }
26}
27
28pub fn slice(array: &dyn Array, start: usize, stop: usize) -> VortexResult<ArrayRef> {
38 if start == 0 && stop == array.len() {
39 return Ok(array.to_array());
40 }
41
42 if start == stop {
43 return Ok(Canonical::empty(array.dtype()).into_array());
44 }
45
46 check_slice_bounds(array, start, stop)?;
47
48 let derived_stats = (!array.is_constant()).then(|| derive_sliced_stats(array));
51
52 let sliced = array
53 .vtable()
54 .slice_fn()
55 .map(|f| f.slice(array, start, stop))
56 .unwrap_or_else(|| {
57 Err(vortex_err!(
58 NotImplemented: "slice",
59 array.encoding()
60 ))
61 })?;
62
63 if let Some(derived_stats) = derived_stats {
64 let mut stats = sliced.statistics().to_owned();
65 stats.combine_sets(&derived_stats, array.dtype())?;
66 for (stat, val) in stats.into_iter() {
67 sliced.statistics().set(stat, val)
68 }
69 }
70
71 debug_assert_eq!(
72 sliced.len(),
73 stop - start,
74 "Slice length mismatch {}",
75 array.encoding()
76 );
77 debug_assert_eq!(
78 sliced.dtype(),
79 array.dtype(),
80 "Slice dtype mismatch {}",
81 array.encoding()
82 );
83
84 Ok(sliced)
85}
86
87fn derive_sliced_stats(arr: &dyn Array) -> StatsSet {
88 let stats = arr.statistics().to_owned();
89
90 let is_constant = stats.get_as::<bool>(Stat::IsConstant);
92 let is_sorted = stats.get_as::<bool>(Stat::IsConstant);
93 let is_strict_sorted = stats.get_as::<bool>(Stat::IsConstant);
94
95 let mut stats = stats.keep_inexact_stats(&[
96 Stat::Max,
97 Stat::Min,
98 Stat::NullCount,
99 Stat::UncompressedSizeInBytes,
100 ]);
101
102 if is_constant == Some(Precision::Exact(true)) {
103 stats.set(Stat::IsConstant, Precision::exact(true));
104 }
105 if is_sorted == Some(Precision::Exact(true)) {
106 stats.set(Stat::IsSorted, Precision::exact(true));
107 }
108 if is_strict_sorted == Some(Precision::Exact(true)) {
109 stats.set(Stat::IsStrictSorted, Precision::exact(true));
110 }
111
112 stats
113}
114
115fn check_slice_bounds(array: &dyn Array, start: usize, stop: usize) -> VortexResult<()> {
116 if start > array.len() {
117 vortex_bail!(OutOfBounds: start, 0, array.len());
118 }
119 if stop > array.len() {
120 vortex_bail!(OutOfBounds: stop, 0, array.len());
121 }
122 if start > stop {
123 vortex_bail!("start ({start}) must be <= stop ({stop})");
124 }
125 Ok(())
126}
127
#[cfg(test)]
mod tests {
    use vortex_scalar::Scalar;

    use crate::Array;
    use crate::arrays::{ConstantArray, PrimitiveArray};
    use crate::compute::slice;
    use crate::stats::{Precision, STATS_TO_WRITE, Stat, StatsProviderExt};

    /// Slicing a non-constant primitive array keeps the source's min/max, but
    /// only as inexact bounds.
    #[test]
    fn test_slice_primitive() {
        let source = PrimitiveArray::from_iter(0i32..100);
        source.statistics().compute_all(STATS_TO_WRITE).unwrap();

        let window = slice(&source, 10, 20).unwrap();
        let window_stats = window.statistics().to_owned();

        let max = window_stats.get_as::<i32>(Stat::Max);
        assert_eq!(max, Some(Precision::inexact(99)));

        let min = window_stats.get_as::<i32>(Stat::Min);
        assert_eq!(min, Some(Precision::inexact(0)));
    }

    /// Slicing a constant array yields exact min/max equal to the constant.
    #[test]
    fn test_slice_const() {
        let source = ConstantArray::new(Scalar::from(10), 100);
        source.statistics().compute_all(STATS_TO_WRITE).unwrap();

        let window = slice(&source, 10, 20).unwrap();
        let window_stats = window.statistics().to_owned();

        let max = window_stats.get_as::<i32>(Stat::Max);
        assert_eq!(max, Some(Precision::exact(10)));

        let min = window_stats.get_as::<i32>(Stat::Min);
        assert_eq!(min, Some(Precision::exact(10)));
    }
}