Skip to main content

vortex_array/arrays/patched/vtable/
operations.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_error::VortexResult;
5
6use crate::ExecutionCtx;
7use crate::array::ArrayView;
8use crate::array::OperationsVTable;
9use crate::arrays::PrimitiveArray;
10use crate::arrays::patched::Patched;
11use crate::arrays::patched::PatchedArrayExt;
12use crate::arrays::patched::PatchedArraySlotsExt;
13use crate::optimizer::ArrayOptimizer;
14use crate::scalar::Scalar;
15
16impl OperationsVTable<Patched> for Patched {
17    fn scalar_at(
18        array: ArrayView<'_, Patched>,
19        index: usize,
20        ctx: &mut ExecutionCtx,
21    ) -> VortexResult<Scalar> {
22        let chunk = (index + array.offset()) / 1024;
23
24        #[expect(
25            clippy::cast_possible_truncation,
26            reason = "N % 1024 always fits in u16"
27        )]
28        let chunk_index = ((index + array.offset()) % 1024) as u16;
29
30        let lane = (index + array.offset()) % array.n_lanes();
31
32        let range = array.lane_range(chunk, lane)?;
33
34        // Get the range of indices corresponding to the lane, potentially decoding them to avoid
35        // the overhead of repeated scalar_at calls.
36        let patch_indices = array
37            .patch_indices()
38            .slice(range.clone())?
39            .optimize()?
40            .execute::<PrimitiveArray>(ctx)?;
41
42        // NOTE: we do linear scan as lane has <= 32 patches, binary search would likely
43        //  be slower.
44        for (&patch_index, idx) in std::iter::zip(patch_indices.as_slice::<u16>(), range) {
45            if patch_index == chunk_index {
46                return array
47                    .patch_values()
48                    .execute_scalar(idx, ctx)?
49                    .cast(array.dtype());
50            }
51        }
52
53        // Otherwise, access the underlying value.
54        array.inner().execute_scalar(index, ctx)
55    }
56}
57
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;
    use vortex_session::VortexSession;

    use crate::ExecutionCtx;
    use crate::IntoArray;
    use crate::LEGACY_SESSION;
    use crate::VortexSessionExecute;
    use crate::arrays::Patched;
    use crate::dtype::Nullability;
    use crate::optimizer::ArrayOptimizer;
    use crate::patches::Patches;
    use crate::scalar::Scalar;

    /// Patch positions used by the multi-chunk tests. They are spread over all four
    /// 1024-element chunks, including the first and last slot of a chunk. Must stay in sync
    /// with the index buffers built in those tests.
    const MULTI_CHUNK_POSITIONS: [usize; 7] = [1, 2, 3, 1024, 2047, 3000, 4095];
    /// Patch values matching `MULTI_CHUNK_POSITIONS`. They are distinct so that reading the
    /// wrong patch slot is detected.
    const MULTI_CHUNK_VALUES: [u16; 7] = [10, 20, 30, 40, 50, 60, 70];

    /// Expected value at `position` of the unsliced multi-chunk array (inner values are 0).
    fn expected_multi_chunk(position: usize) -> u16 {
        MULTI_CHUNK_POSITIONS
            .iter()
            .position(|&p| p == position)
            .map_or(0, |i| MULTI_CHUNK_VALUES[i])
    }

    #[test]
    fn test_simple() {
        let values = buffer![0u16; 1024].into_array();
        // Distinct patch values so that a lookup returning the wrong patch slot fails.
        let patches = Patches::new(
            1024,
            0,
            buffer![1u32, 2, 3].into_array(),
            buffer![10u16, 20, 30].into_array(),
            None,
        )
        .unwrap();

        let session = VortexSession::empty();
        let mut ctx = ExecutionCtx::new(session);

        let array = Patched::from_array_and_patches(values, &patches, &mut ctx)
            .unwrap()
            .into_array();

        let mut exec = LEGACY_SESSION.create_execution_ctx();
        for (index, expected) in [(0, 0u16), (1, 10), (2, 20), (3, 30), (4, 0)] {
            assert_eq!(
                array.execute_scalar(index, &mut exec).unwrap(),
                Scalar::primitive(expected, Nullability::NonNullable),
                "unexpected value at index {index}"
            );
        }
    }

    #[test]
    fn test_multi_chunk() {
        let values = buffer![0u16; 4096].into_array();
        let patches = Patches::new(
            4096,
            0,
            buffer![1u32, 2, 3, 1024, 2047, 3000, 4095].into_array(),
            buffer![10u16, 20, 30, 40, 50, 60, 70].into_array(),
            None,
        )
        .unwrap();

        let session = VortexSession::empty();
        let mut ctx = ExecutionCtx::new(session);

        let array = Patched::from_array_and_patches(values, &patches, &mut ctx)
            .unwrap()
            .into_array();

        let mut exec = LEGACY_SESSION.create_execution_ctx();
        for index in 0..array.len() {
            let value = array.execute_scalar(index, &mut exec).unwrap();
            assert_eq!(
                value,
                expected_multi_chunk(index).into(),
                "unexpected value at index {index}"
            );
        }
    }

    #[test]
    fn test_multi_chunk_sliced() {
        let values = buffer![0u16; 4096].into_array();
        let patches = Patches::new(
            4096,
            0,
            buffer![1u32, 2, 3].into_array(),
            buffer![1u16; 3].into_array(),
            None,
        )
        .unwrap();

        let session = VortexSession::empty();
        let mut ctx = ExecutionCtx::new(session);

        let array = Patched::from_array_and_patches(values, &patches, &mut ctx)
            .unwrap()
            .into_array()
            .slice(3..4096)
            .unwrap()
            .optimize()
            .unwrap();

        assert!(array.is::<Patched>());

        let mut exec = LEGACY_SESSION.create_execution_ctx();
        assert_eq!(array.execute_scalar(0, &mut exec).unwrap(), 1u16.into());
        for index in 1..array.len() {
            assert_eq!(
                array.execute_scalar(index, &mut exec).unwrap(),
                0u16.into(),
                "unexpected value at index {index}"
            );
        }
    }

    #[test]
    fn test_multi_chunk_sliced_across_chunks() {
        let values = buffer![0u16; 4096].into_array();
        let patches = Patches::new(
            4096,
            0,
            buffer![1u32, 2, 3, 1024, 2047, 3000, 4095].into_array(),
            buffer![10u16, 20, 30, 40, 50, 60, 70].into_array(),
            None,
        )
        .unwrap();

        let session = VortexSession::empty();
        let mut ctx = ExecutionCtx::new(session);

        // The slice starts mid-chunk and ends mid-chunk, so the remaining patches (1024, 2047,
        // 3000) have to be found through a non-zero offset in three different chunks.
        let start = 1000;
        let end = 3100;
        let array = Patched::from_array_and_patches(values, &patches, &mut ctx)
            .unwrap()
            .into_array()
            .slice(start..end)
            .unwrap()
            .optimize()
            .unwrap();

        assert_eq!(array.len(), end - start);

        let mut exec = LEGACY_SESSION.create_execution_ctx();
        for index in 0..array.len() {
            let value = array.execute_scalar(index, &mut exec).unwrap();
            assert_eq!(
                value,
                expected_multi_chunk(start + index).into(),
                "unexpected value at sliced index {index}"
            );
        }
    }
}