1use std::fmt::Debug;
5
6use vortex_array::arrays::PrimitiveVTable;
7use vortex_array::search_sorted::{SearchSorted, SearchSortedSide};
8use vortex_array::stats::{ArrayStats, StatsSetRef};
9use vortex_array::vtable::{ArrayVTable, CanonicalVTable, NotSupported, VTable, ValidityVTable};
10use vortex_array::{
11 Array, ArrayRef, Canonical, EncodingId, EncodingRef, IntoArray, ToCanonical, vtable,
12};
13use vortex_dtype::DType;
14use vortex_error::{VortexExpect as _, VortexResult, vortex_bail};
15use vortex_mask::Mask;
16use vortex_scalar::PValue;
17
18use crate::compress::{runend_decode_bools, runend_decode_primitive, runend_encode};
19
// Invokes the crate-wide `vtable!` helper for the RunEnd encoding. `RunEndVTable`, which is
// used below, is not declared anywhere else in this file, so it is presumably generated by
// this macro — TODO confirm against the `vtable!` definition in `vortex_array`.
vtable!(RunEnd);
21
22impl VTable for RunEndVTable {
23 type Array = RunEndArray;
24 type Encoding = RunEndEncoding;
25
26 type ArrayVTable = Self;
27 type CanonicalVTable = Self;
28 type OperationsVTable = Self;
29 type ValidityVTable = Self;
30 type VisitorVTable = Self;
31 type ComputeVTable = NotSupported;
32 type EncodeVTable = Self;
33 type SerdeVTable = Self;
34
35 fn id(_encoding: &Self::Encoding) -> EncodingId {
36 EncodingId::new_ref("vortex.runend")
37 }
38
39 fn encoding(_array: &Self::Array) -> EncodingRef {
40 EncodingRef::new_ref(RunEndEncoding.as_ref())
41 }
42}
43
/// An array stored as runs: `values[i]` is repeated up to the (exclusive) logical position
/// `ends[i]`. `offset` and `length` select the logical window this array exposes.
#[derive(Clone, Debug)]
pub struct RunEndArray {
    /// Exclusive logical end of each run. The checked constructor requires a non-nullable
    /// unsigned integer array that is not known to be unsorted.
    ends: ArrayRef,
    /// Value of each run, indexed like `ends`. The checked constructor requires a Bool or
    /// Primitive dtype.
    values: ArrayRef,
    /// Logical position, in the coordinate space of `ends`, at which this array starts.
    offset: usize,
    /// Number of logical elements exposed by this array.
    length: usize,
    /// Lazily populated statistics for this array.
    stats_set: ArrayStats,
}
52
/// Encoding marker for [`RunEndArray`]; `RunEndVTable::id` reports it as `vortex.runend`.
#[derive(Clone, Debug)]
pub struct RunEndEncoding;
55
56impl RunEndArray {
57 pub fn try_new(ends: ArrayRef, values: ArrayRef) -> VortexResult<Self> {
58 let length = if ends.is_empty() {
59 0
60 } else {
61 ends.scalar_at(ends.len() - 1)?.as_ref().try_into()?
62 };
63 Self::with_offset_and_length(ends, values, 0, length)
64 }
65
66 pub(crate) fn with_offset_and_length(
67 ends: ArrayRef,
68 values: ArrayRef,
69 offset: usize,
70 length: usize,
71 ) -> VortexResult<Self> {
72 if !matches!(values.dtype(), &DType::Bool(_) | &DType::Primitive(_, _)) {
73 vortex_bail!(
74 "RunEnd array can only have Bool or Primitive values, {} given",
75 values.dtype()
76 );
77 }
78
79 if offset != 0 {
80 let first_run_end: usize = ends.scalar_at(0)?.as_ref().try_into()?;
81 if first_run_end <= offset {
82 vortex_bail!("First run end {first_run_end} must be bigger than offset {offset}");
83 }
84 }
85
86 if !ends.dtype().is_unsigned_int() || ends.dtype().is_nullable() {
87 vortex_bail!(MismatchedTypes: "non-nullable unsigned int", ends.dtype());
88 }
89 if !ends.statistics().compute_is_strict_sorted().unwrap_or(true) {
90 vortex_bail!("Ends array must be strictly sorted");
91 }
92
93 Ok(Self {
94 ends,
95 values,
96 offset,
97 length,
98 stats_set: Default::default(),
99 })
100 }
101
102 pub fn find_physical_index(&self, index: usize) -> VortexResult<usize> {
104 Ok(self
105 .ends()
106 .as_primitive_typed()
107 .search_sorted(
108 &PValue::from(index + self.offset()),
109 SearchSortedSide::Right,
110 )
111 .to_ends_index(self.ends().len()))
112 }
113
114 pub fn encode(array: ArrayRef) -> VortexResult<Self> {
116 if let Some(parray) = array.as_opt::<PrimitiveVTable>() {
117 let (ends, values) = runend_encode(parray)?;
118 Self::try_new(ends.into_array(), values)
119 } else {
120 vortex_bail!("REE can only encode primitive arrays")
121 }
122 }
123
124 #[inline]
128 pub fn offset(&self) -> usize {
129 self.offset
130 }
131
132 #[inline]
137 pub fn ends(&self) -> &ArrayRef {
138 &self.ends
139 }
140
141 #[inline]
146 pub fn values(&self) -> &ArrayRef {
147 &self.values
148 }
149}
150
151impl ArrayVTable<RunEndVTable> for RunEndVTable {
152 fn len(array: &RunEndArray) -> usize {
153 array.length
154 }
155
156 fn dtype(array: &RunEndArray) -> &DType {
157 array.values.dtype()
158 }
159
160 fn stats(array: &RunEndArray) -> StatsSetRef<'_> {
161 array.stats_set.to_ref(array.as_ref())
162 }
163}
164
165impl ValidityVTable<RunEndVTable> for RunEndVTable {
166 fn is_valid(array: &RunEndArray, index: usize) -> VortexResult<bool> {
167 let physical_idx = array
168 .find_physical_index(index)
169 .vortex_expect("Invalid index");
170 array.values().is_valid(physical_idx)
171 }
172
173 fn all_valid(array: &RunEndArray) -> VortexResult<bool> {
174 array.values().all_valid()
175 }
176
177 fn all_invalid(array: &RunEndArray) -> VortexResult<bool> {
178 array.values().all_invalid()
179 }
180
181 fn validity_mask(array: &RunEndArray) -> VortexResult<Mask> {
182 Ok(match array.values().validity_mask()? {
183 Mask::AllTrue(_) => Mask::AllTrue(array.len()),
184 Mask::AllFalse(_) => Mask::AllFalse(array.len()),
185 Mask::Values(values) => {
186 let ree_validity = RunEndArray::with_offset_and_length(
187 array.ends().clone(),
188 values.into_array(),
189 array.offset(),
190 array.len(),
191 )
192 .vortex_expect("invalid array")
193 .into_array();
194 Mask::from_buffer(ree_validity.to_bool()?.boolean_buffer().clone())
195 }
196 })
197 }
198}
199
200impl CanonicalVTable<RunEndVTable> for RunEndVTable {
201 fn canonicalize(array: &RunEndArray) -> VortexResult<Canonical> {
202 let pends = array.ends().to_primitive()?;
203 match array.dtype() {
204 DType::Bool(_) => {
205 let bools = array.values().to_bool()?;
206 runend_decode_bools(pends, bools, array.offset(), array.len()).map(Canonical::Bool)
207 }
208 DType::Primitive(..) => {
209 let pvalues = array.values().to_primitive()?;
210 runend_decode_primitive(pends, pvalues, array.offset(), array.len())
211 .map(Canonical::Primitive)
212 }
213 _ => vortex_bail!("Only Primitive and Bool values are supported"),
214 }
215 }
216}
217
#[cfg(test)]
mod tests {
    use vortex_array::IntoArray;
    use vortex_buffer::buffer;
    use vortex_dtype::{DType, Nullability, PType};

    use crate::RunEndArray;

    /// Runs of lengths 2, 3 and 5 holding 1, 2 and 3 expose ten logical elements.
    #[test]
    fn test_runend_constructor() {
        let run_ends = buffer![2u32, 5, 10].into_array();
        let run_values = buffer![1i32, 2, 3].into_array();
        let arr = RunEndArray::try_new(run_ends, run_values).unwrap();

        assert_eq!(arr.len(), 10);
        let expected_dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
        assert_eq!(arr.dtype(), &expected_dtype);

        // (logical index, expected value) probes, including the first and last positions.
        for (idx, expected) in [(0usize, 1i32), (2, 2), (5, 3), (9, 3)] {
            assert_eq!(arr.scalar_at(idx).unwrap(), expected.into());
        }
    }
}