vortex_runend/
array.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::Debug;
5
6use vortex_array::arrays::PrimitiveVTable;
7use vortex_array::search_sorted::{SearchSorted, SearchSortedSide};
8use vortex_array::stats::{ArrayStats, StatsSetRef};
9use vortex_array::vtable::{ArrayVTable, CanonicalVTable, NotSupported, VTable, ValidityVTable};
10use vortex_array::{
11    Array, ArrayRef, Canonical, EncodingId, EncodingRef, IntoArray, ToCanonical, vtable,
12};
13use vortex_dtype::DType;
14use vortex_error::{VortexExpect as _, VortexResult, vortex_bail};
15use vortex_mask::Mask;
16use vortex_scalar::PValue;
17
18use crate::compress::{runend_decode_bools, runend_decode_primitive, runend_encode};
19
20vtable!(RunEnd);
21
22impl VTable for RunEndVTable {
23    type Array = RunEndArray;
24    type Encoding = RunEndEncoding;
25
26    type ArrayVTable = Self;
27    type CanonicalVTable = Self;
28    type OperationsVTable = Self;
29    type ValidityVTable = Self;
30    type VisitorVTable = Self;
31    type ComputeVTable = NotSupported;
32    type EncodeVTable = Self;
33    type SerdeVTable = Self;
34
35    fn id(_encoding: &Self::Encoding) -> EncodingId {
36        EncodingId::new_ref("vortex.runend")
37    }
38
39    fn encoding(_array: &Self::Array) -> EncodingRef {
40        EncodingRef::new_ref(RunEndEncoding.as_ref())
41    }
42}
43
44#[derive(Clone, Debug)]
45pub struct RunEndArray {
46    ends: ArrayRef,
47    values: ArrayRef,
48    offset: usize,
49    length: usize,
50    stats_set: ArrayStats,
51}
52
/// Unit marker type naming the run-end encoding (`"vortex.runend"`).
#[derive(Debug, Clone)]
pub struct RunEndEncoding;
55
56impl RunEndArray {
57    pub fn try_new(ends: ArrayRef, values: ArrayRef) -> VortexResult<Self> {
58        let length = if ends.is_empty() {
59            0
60        } else {
61            ends.scalar_at(ends.len() - 1)?.as_ref().try_into()?
62        };
63        Self::with_offset_and_length(ends, values, 0, length)
64    }
65
66    pub(crate) fn with_offset_and_length(
67        ends: ArrayRef,
68        values: ArrayRef,
69        offset: usize,
70        length: usize,
71    ) -> VortexResult<Self> {
72        if !matches!(values.dtype(), &DType::Bool(_) | &DType::Primitive(_, _)) {
73            vortex_bail!(
74                "RunEnd array can only have Bool or Primitive values, {} given",
75                values.dtype()
76            );
77        }
78
79        if offset != 0 {
80            let first_run_end: usize = ends.scalar_at(0)?.as_ref().try_into()?;
81            if first_run_end <= offset {
82                vortex_bail!("First run end {first_run_end} must be bigger than offset {offset}");
83            }
84        }
85
86        if !ends.dtype().is_unsigned_int() || ends.dtype().is_nullable() {
87            vortex_bail!(MismatchedTypes: "non-nullable unsigned int", ends.dtype());
88        }
89        if !ends.statistics().compute_is_strict_sorted().unwrap_or(true) {
90            vortex_bail!("Ends array must be strictly sorted");
91        }
92
93        Ok(Self {
94            ends,
95            values,
96            offset,
97            length,
98            stats_set: Default::default(),
99        })
100    }
101
102    /// Convert the given logical index to an index into the `values` array
103    pub fn find_physical_index(&self, index: usize) -> VortexResult<usize> {
104        Ok(self
105            .ends()
106            .as_primitive_typed()
107            .search_sorted(
108                &PValue::from(index + self.offset()),
109                SearchSortedSide::Right,
110            )
111            .to_ends_index(self.ends().len()))
112    }
113
114    /// Run the array through run-end encoding.
115    pub fn encode(array: ArrayRef) -> VortexResult<Self> {
116        if let Some(parray) = array.as_opt::<PrimitiveVTable>() {
117            let (ends, values) = runend_encode(parray)?;
118            Self::try_new(ends.into_array(), values)
119        } else {
120            vortex_bail!("REE can only encode primitive arrays")
121        }
122    }
123
124    /// The offset that the `ends` is relative to.
125    ///
126    /// This is generally zero for a "new" array, and non-zero after a slicing operation.
127    #[inline]
128    pub fn offset(&self) -> usize {
129        self.offset
130    }
131
132    /// The encoded "ends" of value runs.
133    ///
134    /// The `i`-th element indicates that there is a run of the same value, beginning
135    /// at `ends[i]` (inclusive) and terminating at `ends[i+1]` (exclusive).
136    #[inline]
137    pub fn ends(&self) -> &ArrayRef {
138        &self.ends
139    }
140
141    /// The scalar values.
142    ///
143    /// The `i`-th element is the scalar value for the `i`-th repeated run. The run begins
144    /// at `ends[i]` (inclusive) and terminates at `ends[i+1]` (exclusive).
145    #[inline]
146    pub fn values(&self) -> &ArrayRef {
147        &self.values
148    }
149}
150
151impl ArrayVTable<RunEndVTable> for RunEndVTable {
152    fn len(array: &RunEndArray) -> usize {
153        array.length
154    }
155
156    fn dtype(array: &RunEndArray) -> &DType {
157        array.values.dtype()
158    }
159
160    fn stats(array: &RunEndArray) -> StatsSetRef<'_> {
161        array.stats_set.to_ref(array.as_ref())
162    }
163}
164
165impl ValidityVTable<RunEndVTable> for RunEndVTable {
166    fn is_valid(array: &RunEndArray, index: usize) -> VortexResult<bool> {
167        let physical_idx = array
168            .find_physical_index(index)
169            .vortex_expect("Invalid index");
170        array.values().is_valid(physical_idx)
171    }
172
173    fn all_valid(array: &RunEndArray) -> VortexResult<bool> {
174        array.values().all_valid()
175    }
176
177    fn all_invalid(array: &RunEndArray) -> VortexResult<bool> {
178        array.values().all_invalid()
179    }
180
181    fn validity_mask(array: &RunEndArray) -> VortexResult<Mask> {
182        Ok(match array.values().validity_mask()? {
183            Mask::AllTrue(_) => Mask::AllTrue(array.len()),
184            Mask::AllFalse(_) => Mask::AllFalse(array.len()),
185            Mask::Values(values) => {
186                let ree_validity = RunEndArray::with_offset_and_length(
187                    array.ends().clone(),
188                    values.into_array(),
189                    array.offset(),
190                    array.len(),
191                )
192                .vortex_expect("invalid array")
193                .into_array();
194                Mask::from_buffer(ree_validity.to_bool()?.boolean_buffer().clone())
195            }
196        })
197    }
198}
199
200impl CanonicalVTable<RunEndVTable> for RunEndVTable {
201    fn canonicalize(array: &RunEndArray) -> VortexResult<Canonical> {
202        let pends = array.ends().to_primitive()?;
203        match array.dtype() {
204            DType::Bool(_) => {
205                let bools = array.values().to_bool()?;
206                runend_decode_bools(pends, bools, array.offset(), array.len()).map(Canonical::Bool)
207            }
208            DType::Primitive(..) => {
209                let pvalues = array.values().to_primitive()?;
210                runend_decode_primitive(pends, pvalues, array.offset(), array.len())
211                    .map(Canonical::Primitive)
212            }
213            _ => vortex_bail!("Only Primitive and Bool values are supported"),
214        }
215    }
216}
217
#[cfg(test)]
mod tests {
    use vortex_array::IntoArray;
    use vortex_buffer::buffer;
    use vortex_dtype::{DType, Nullability, PType};

    use crate::RunEndArray;

    /// Builds a three-run array and checks its length, dtype and a sample of elements.
    #[test]
    fn test_runend_constructor() {
        let ends = buffer![2u32, 5, 10].into_array();
        let values = buffer![1i32, 2, 3].into_array();
        let ree = RunEndArray::try_new(ends, values).unwrap();

        assert_eq!(ree.len(), 10);
        assert_eq!(
            ree.dtype(),
            &DType::Primitive(PType::I32, Nullability::NonNullable)
        );

        // Logical positions covered by each run:
        //   0..2  => 1
        //   2..5  => 2
        //   5..10 => 3
        for (index, expected) in [(0usize, 1i32), (2, 2), (5, 3), (9, 3)] {
            assert_eq!(ree.scalar_at(index).unwrap(), expected.into());
        }
    }
}
247}