vortex_array/arrays/primitive/compute/
pipeline.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::any::Any;
5use std::rc::Rc;
6
7use vortex_buffer::{Buffer, ByteBuffer};
8use vortex_dtype::{NativePType, PType, match_each_native_ptype};
9use vortex_error::{VortexResult, vortex_bail};
10
11use crate::arrays::{PrimitiveArray, PrimitiveVTable};
12use crate::pipeline::bits::BitView;
13use crate::pipeline::operators::{BindContext, Operator};
14use crate::pipeline::view::ViewMut;
15use crate::pipeline::{Element, Kernel, KernelContext, N, PipelineVTable, VType};
16use crate::vtable::ValidityHelper;
17
18impl PipelineVTable<PrimitiveVTable> for PrimitiveVTable {
19    fn to_operator(array: &PrimitiveArray) -> VortexResult<Option<Rc<dyn Operator>>> {
20        if !array.validity().all_valid()? {
21            vortex_bail!(
22                "PipelineVTable::to_operator is not supported for arrays with invalid values"
23            );
24        }
25        Ok(Some(Rc::new(PrimitiveOperator::new(
26            array.ptype(),
27            array.byte_buffer().clone(),
28        ))))
29    }
30}
31
32/// Pipeline operator for primitive arrays that produces values from a byte buffer.
33#[derive(Debug, Clone, Hash)]
34pub struct PrimitiveOperator {
35    ptype: PType,
36    byte_buffer: ByteBuffer,
37}
38
39impl PrimitiveOperator {
40    pub fn new(ptype: PType, byte_buffer: ByteBuffer) -> Self {
41        Self { ptype, byte_buffer }
42    }
43}
44
45impl Operator for PrimitiveOperator {
46    fn as_any(&self) -> &dyn Any {
47        self
48    }
49
50    fn vtype(&self) -> VType {
51        VType::Primitive(self.ptype)
52    }
53
54    fn children(&self) -> &[Rc<dyn Operator>] {
55        &[]
56    }
57
58    fn with_children(&self, _children: Vec<Rc<dyn Operator>>) -> Rc<dyn Operator> {
59        Rc::new(self.clone())
60    }
61
62    fn bind(&self, _ctx: &dyn BindContext) -> VortexResult<Box<dyn Kernel>> {
63        match_each_native_ptype!(self.ptype, |T| {
64            Ok(Box::new(PrimitiveKernel::<T> {
65                buffer: Buffer::from_byte_buffer(self.byte_buffer.clone()),
66                offset: 0,
67            }) as Box<dyn Kernel>)
68        })
69    }
70}
71
72/// A kernel that produces primitive values from a byte buffer.
73/// A kernel that produces primitive values from a byte buffer.
74pub struct PrimitiveKernel<T: NativePType> {
75    buffer: Buffer<T>,
76    offset: usize,
77}
78
79impl<T: Element + NativePType> Kernel for PrimitiveKernel<T> {
80    fn seek(&mut self, chunk_idx: usize) -> VortexResult<()> {
81        self.offset = chunk_idx * N;
82        Ok(())
83    }
84
85    fn step(&mut self, _ctx: &KernelContext, mask: BitView, out: &mut ViewMut) -> VortexResult<()> {
86        let buffer = &self.buffer;
87        let remaining = buffer.len() - self.offset;
88
89        let out_slice = out.as_slice_mut::<T>();
90
91        if remaining > N {
92            out_slice.copy_from_slice(&buffer[self.offset..][..N]);
93            self.offset += N;
94        } else {
95            out_slice[..remaining].copy_from_slice(&buffer[self.offset..]);
96            self.offset += remaining;
97        }
98
99        // TODO(joe): use mask in copy_from_slice, if faster.
100        out.select_mask::<T>(&mask);
101
102        Ok(())
103    }
104}
105
#[cfg(test)]
mod tests {
    use itertools::Itertools;
    use vortex_buffer::BufferMut;
    use vortex_mask::Mask;

    use super::*;
    use crate::pipeline::export_canonical_pipeline;
    use crate::{IntoArray, ToCanonical};

    /// With every element selected, the exported output must reproduce the input values.
    #[test]
    fn test_primitive_kernel_basic_operation() {
        // Input: the sixteen consecutive integers 0, 1, ..., 15.
        let size = 16;
        let upper = i32::try_from(size).unwrap();
        let source: BufferMut<i32> = (0..upper).collect();
        let primitive_array = source.into_array().to_primitive().unwrap();

        // Kernel positioned at the start of the array's buffer.
        let mut kernel = PrimitiveKernel::<i32> {
            buffer: primitive_array.buffer(),
            offset: 0,
        };

        let select_all = Mask::AllTrue(size);
        let exported =
            export_canonical_pipeline(primitive_array.dtype(), size, &mut kernel, &select_all)
                .unwrap();
        let out = exported.into_primitive().unwrap();

        let output = out.as_slice::<i32>();

        // Each of the first `size` outputs must equal its own position.
        for (i, &actual) in output[..size].iter().enumerate() {
            assert_eq!(
                actual,
                i32::try_from(i).unwrap(),
                "Mismatch at position {}: expected {}, got {}",
                i,
                i,
                actual
            );
        }
    }

    /// With every other element selected, the output holds half as many values and starts
    /// with the first input value.
    #[test]
    fn test_primitive_kernel_with_mask() {
        // Input: the sixteen consecutive integers 0, 1, ..., 15.
        let size = 16;
        let upper = i32::try_from(size).unwrap();
        let primitive_array: PrimitiveArray = (0i32..upper).collect();

        // Kernel positioned at the start of the array's buffer.
        let mut kernel = PrimitiveKernel::<i32> {
            buffer: primitive_array.buffer(),
            offset: 0,
        };

        // Select the even positions only: 0, 2, 4, ...
        let even_positions = (0..size).step_by(2).collect_vec();
        let mask = Mask::from_indices(size, even_positions);

        let exported =
            export_canonical_pipeline(primitive_array.dtype(), size, &mut kernel, &mask).unwrap();
        let out = exported.into_primitive().unwrap();

        let output = out.as_slice::<i32>();

        // Position 0 is selected by the mask, so its value must come first.
        assert_eq!(output[0], 0, "First element should be 0 since bit 0 is set");

        // Exactly one output value per selected position.
        assert_eq!(
            out.len(),
            size / 2,
            "Selected element count should match true_count"
        );
    }
}