Skip to main content

vortex_array/arrays/patched/vtable/
operations.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_error::VortexResult;
5
6use crate::ExecutionCtx;
7use crate::array::ArrayView;
8use crate::array::OperationsVTable;
9use crate::arrays::PrimitiveArray;
10use crate::arrays::patched::Patched;
11use crate::arrays::patched::PatchedArrayExt;
12use crate::arrays::patched::PatchedArraySlotsExt;
13use crate::optimizer::ArrayOptimizer;
14use crate::scalar::Scalar;
15
16impl OperationsVTable<Patched> for Patched {
17    fn scalar_at(
18        array: ArrayView<'_, Patched>,
19        index: usize,
20        ctx: &mut ExecutionCtx,
21    ) -> VortexResult<Scalar> {
22        let chunk = (index + array.offset()) / 1024;
23
24        #[expect(
25            clippy::cast_possible_truncation,
26            reason = "N % 1024 always fits in u16"
27        )]
28        let chunk_index = ((index + array.offset()) % 1024) as u16;
29
30        let lane = (index + array.offset()) % array.n_lanes();
31
32        let range = array.lane_range(chunk, lane)?;
33
34        // Get the range of indices corresponding to the lane, potentially decoding them to avoid
35        // the overhead of repeated scalar_at calls.
36        let patch_indices = array
37            .patch_indices()
38            .slice(range.clone())?
39            .optimize()?
40            .execute::<PrimitiveArray>(ctx)?;
41
42        // NOTE: we do linear scan as lane has <= 32 patches, binary search would likely
43        //  be slower.
44        for (&patch_index, idx) in std::iter::zip(patch_indices.as_slice::<u16>(), range) {
45            if patch_index == chunk_index {
46                return array
47                    .patch_values()
48                    .execute_scalar(idx, ctx)?
49                    .cast(array.dtype());
50            }
51        }
52
53        // Otherwise, access the underlying value.
54        array.inner().execute_scalar(index, ctx)
55    }
56}
57
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;
    use vortex_session::VortexSession;

    use crate::IntoArray;
    use crate::VortexSessionExecute;
    use crate::array_session;
    use crate::arrays::Patched;
    use crate::dtype::Nullability;
    use crate::optimizer::ArrayOptimizer;
    use crate::patches::Patches;
    use crate::scalar::Scalar;

    /// Single chunk: positions 1..=3 are patched to 1, every other position reads the inner
    /// array's 0.
    #[test]
    fn test_simple() {
        let values = buffer![0u16; 1024].into_array();
        let patches = Patches::new(
            1024,
            0,
            buffer![1u32, 2, 3].into_array(),
            buffer![1u16; 3].into_array(),
            None,
        )
        .unwrap();

        let session = VortexSession::empty();
        let mut ctx = session.create_execution_ctx();

        let array = Patched::from_array_and_patches(values, &patches, &mut ctx)
            .unwrap()
            .into_array();

        // One execution context for every lookup instead of a fresh session per call.
        let exec_session = array_session();
        let mut exec_ctx = exec_session.create_execution_ctx();

        // Covers the patched positions, their unpatched neighbours, and the chunk's last slot.
        for (index, expected) in [(0, 0u16), (1, 1), (2, 1), (3, 1), (4, 0), (1023, 0)] {
            assert_eq!(
                array.execute_scalar(index, &mut exec_ctx).unwrap(),
                Scalar::primitive(expected, Nullability::NonNullable),
                "index {index}"
            );
        }
    }

    /// Patches spread across all four 1024-element chunks, so the chunk/lane lookup is
    /// exercised for chunks other than the first (including the very last position).
    #[test]
    fn test_multi_chunk() {
        let values = buffer![0u16; 4096].into_array();
        let patches = Patches::new(
            4096,
            0,
            buffer![1u32, 2, 3, 1025, 2050, 4095].into_array(),
            buffer![1u16; 6].into_array(),
            None,
        )
        .unwrap();
        let patched: [usize; 6] = [1, 2, 3, 1025, 2050, 4095];

        let session = VortexSession::empty();
        let mut ctx = session.create_execution_ctx();

        let array = Patched::from_array_and_patches(values, &patches, &mut ctx)
            .unwrap()
            .into_array();

        let exec_session = array_session();
        let mut exec_ctx = exec_session.create_execution_ctx();

        for index in 0..array.len() {
            let value = array.execute_scalar(index, &mut exec_ctx).unwrap();
            let expected = if patched.contains(&index) { 1u16 } else { 0u16 };
            assert_eq!(value, expected.into(), "index {index}");
        }
    }

    /// Slicing off the first three elements leaves the array with a non-zero offset; patch
    /// lookups in every chunk must account for it.
    #[test]
    fn test_multi_chunk_sliced() {
        let values = buffer![0u16; 4096].into_array();
        let patches = Patches::new(
            4096,
            0,
            buffer![1u32, 2, 3, 1025, 2050, 4095].into_array(),
            buffer![1u16; 6].into_array(),
            None,
        )
        .unwrap();

        let session = VortexSession::empty();
        let mut ctx = session.create_execution_ctx();

        let array = Patched::from_array_and_patches(values, &patches, &mut ctx)
            .unwrap()
            .into_array()
            .slice(3..4096)
            .unwrap()
            .optimize()
            .unwrap();

        assert!(array.is::<Patched>());

        // Original positions 3, 1025, 2050 and 4095 shifted down by the slice start of 3;
        // positions 1 and 2 fall before the slice and must not be visible.
        let patched: [usize; 4] = [0, 1022, 2047, 4092];

        let exec_session = array_session();
        let mut exec_ctx = exec_session.create_execution_ctx();

        for index in 0..array.len() {
            let value = array.execute_scalar(index, &mut exec_ctx).unwrap();
            let expected = if patched.contains(&index) { 1u16 } else { 0u16 };
            assert_eq!(value, expected.into(), "index {index}");
        }
    }
}