Skip to main content

vortex_array/arrays/patched/vtable/operations.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_error::VortexResult;
5
6use crate::ExecutionCtx;
7use crate::array::ArrayView;
8use crate::array::OperationsVTable;
9use crate::arrays::PrimitiveArray;
10use crate::arrays::patched::Patched;
11use crate::arrays::patched::PatchedArrayExt;
12use crate::arrays::patched::PatchedArraySlotsExt;
13use crate::optimizer::ArrayOptimizer;
14use crate::scalar::Scalar;
15
16impl OperationsVTable<Patched> for Patched {
17    fn scalar_at(
18        array: ArrayView<'_, Patched>,
19        index: usize,
20        ctx: &mut ExecutionCtx,
21    ) -> VortexResult<Scalar> {
22        let chunk = (index + array.offset()) / 1024;
23
24        #[expect(
25            clippy::cast_possible_truncation,
26            reason = "N % 1024 always fits in u16"
27        )]
28        let chunk_index = ((index + array.offset()) % 1024) as u16;
29
30        let lane = (index + array.offset()) % array.n_lanes();
31
32        let range = array.lane_range(chunk, lane)?;
33
34        // Get the range of indices corresponding to the lane, potentially decoding them to avoid
35        // the overhead of repeated scalar_at calls.
36        let patch_indices = array
37            .patch_indices()
38            .slice(range.clone())?
39            .optimize()?
40            .execute::<PrimitiveArray>(ctx)?;
41
42        // NOTE: we do linear scan as lane has <= 32 patches, binary search would likely
43        //  be slower.
44        for (&patch_index, idx) in std::iter::zip(patch_indices.as_slice::<u16>(), range) {
45            if patch_index == chunk_index {
46                return array.patch_values().scalar_at(idx)?.cast(array.dtype());
47            }
48        }
49
50        // Otherwise, access the underlying value.
51        array.inner().scalar_at(index)
52    }
53}
54
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;
    use vortex_session::VortexSession;

    use crate::ExecutionCtx;
    use crate::IntoArray;
    use crate::arrays::Patched;
    use crate::dtype::Nullability;
    use crate::optimizer::ArrayOptimizer;
    use crate::patches::Patches;
    use crate::scalar::Scalar;

    /// Patch positions used by the multi-chunk tests. They deliberately land in chunk 0
    /// (1, 2, 3), chunk 1 (1500), chunk 2 (2048, its first element) and chunk 3 (4095, the
    /// last element), so that lookups are exercised beyond the first 1024-element chunk.
    const MULTI_CHUNK_PATCHES: [usize; 6] = [1, 2, 3, 1500, 2048, 4095];

    /// Builds a 4096-element all-zero `u16` array patched to `1` at `MULTI_CHUNK_PATCHES`.
    fn multi_chunk_array(ctx: &mut ExecutionCtx) -> crate::ArrayRef {
        let values = buffer![0u16; 4096].into_array();
        let patches = Patches::new(
            4096,
            0,
            buffer![1u32, 2, 3, 1500, 2048, 4095].into_array(),
            buffer![1u16; 6].into_array(),
            None,
        )
        .unwrap();

        Patched::from_array_and_patches(values, &patches, ctx)
            .unwrap()
            .into_array()
    }

    #[test]
    fn test_simple() {
        let values = buffer![0u16; 1024].into_array();
        let patches = Patches::new(
            1024,
            0,
            buffer![1u32, 2, 3].into_array(),
            buffer![1u16; 3].into_array(),
            None,
        )
        .unwrap();

        let session = VortexSession::empty();
        let mut ctx = ExecutionCtx::new(session);

        let array = Patched::from_array_and_patches(values, &patches, &mut ctx)
            .unwrap()
            .into_array();

        assert_eq!(
            array.scalar_at(0).unwrap(),
            Scalar::primitive(0u16, Nullability::NonNullable)
        );
        assert_eq!(
            array.scalar_at(1).unwrap(),
            Scalar::primitive(1u16, Nullability::NonNullable)
        );
        assert_eq!(
            array.scalar_at(2).unwrap(),
            Scalar::primitive(1u16, Nullability::NonNullable)
        );
        assert_eq!(
            array.scalar_at(3).unwrap(),
            Scalar::primitive(1u16, Nullability::NonNullable)
        );
    }

    #[test]
    fn test_multi_chunk() {
        let session = VortexSession::empty();
        let mut ctx = ExecutionCtx::new(session);

        let array = multi_chunk_array(&mut ctx);

        for index in 0..array.len() {
            let value = array.scalar_at(index).unwrap();

            if MULTI_CHUNK_PATCHES.contains(&index) {
                assert_eq!(value, 1u16.into(), "expected patched value at {index}");
            } else {
                assert_eq!(value, 0u16.into(), "expected inner value at {index}");
            }
        }
    }

    #[test]
    fn test_multi_chunk_sliced() {
        const START: usize = 3;

        let session = VortexSession::empty();
        let mut ctx = ExecutionCtx::new(session);

        let array = multi_chunk_array(&mut ctx)
            .slice(START..4096)
            .unwrap()
            .optimize()
            .unwrap();

        assert!(array.is::<Patched>());

        // After slicing, every surviving patch position shifts down by START. Patches before
        // START are dropped. The offset-adjusted lookup must still find the patches in
        // later chunks.
        for index in 0..array.len() {
            let value = array.scalar_at(index).unwrap();

            if MULTI_CHUNK_PATCHES.contains(&(index + START)) {
                assert_eq!(value, 1u16.into(), "expected patched value at {index}");
            } else {
                assert_eq!(value, 0u16.into(), "expected inner value at {index}");
            }
        }
    }
}