vortex_array/arrays/primitive/compute/
pipeline.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::any::Any;
5use std::sync::Arc;
6
7use log;
8use vortex_buffer::{Buffer, ByteBuffer};
9use vortex_dtype::{NativePType, PType, match_each_native_ptype};
10use vortex_error::VortexResult;
11
12use crate::arrays::{PrimitiveArray, PrimitiveVTable};
13use crate::pipeline::bits::BitView;
14use crate::pipeline::operators::{BindContext, Operator, OperatorRef};
15use crate::pipeline::view::ViewMut;
16use crate::pipeline::{Element, Kernel, KernelContext, N, PipelineVTable, VType};
17use crate::vtable::ValidityHelper;
18
19impl PipelineVTable<PrimitiveVTable> for PrimitiveVTable {
20    fn to_operator(array: &PrimitiveArray) -> VortexResult<Option<OperatorRef>> {
21        if !array.validity().all_valid() {
22            log::debug!(
23                "PipelineVTable::to_operator is not supported for arrays with invalid values"
24            );
25            return Ok(None);
26        }
27        Ok(Some(Arc::new(PrimitiveOperator::new(
28            array.ptype(),
29            array.byte_buffer().clone(),
30        ))))
31    }
32}
33
/// Pipeline operator for primitive arrays that produces values from a byte buffer.
///
/// The operator is a leaf: it reports no children and, when bound, yields a
/// [`PrimitiveKernel`] whose element type is chosen from `ptype`.
#[derive(Debug, Clone, Hash)]
pub struct PrimitiveOperator {
    /// Primitive element type; reported by `vtype` and used by `bind` to pick the kernel's `T`.
    ptype: PType,
    /// Untyped bytes holding the values; `bind` converts them with `Buffer::from_byte_buffer`.
    byte_buffer: ByteBuffer,
}
40
41impl PrimitiveOperator {
42    pub fn new(ptype: PType, byte_buffer: ByteBuffer) -> Self {
43        Self { ptype, byte_buffer }
44    }
45}
46
47impl Operator for PrimitiveOperator {
48    fn as_any(&self) -> &dyn Any {
49        self
50    }
51
52    fn vtype(&self) -> VType {
53        VType::Primitive(self.ptype)
54    }
55
56    fn children(&self) -> &[OperatorRef] {
57        &[]
58    }
59
60    fn with_children(&self, _children: Vec<OperatorRef>) -> OperatorRef {
61        Arc::new(self.clone())
62    }
63
64    fn bind(&self, _ctx: &dyn BindContext) -> VortexResult<Box<dyn Kernel>> {
65        match_each_native_ptype!(self.ptype, |T| {
66            Ok(Box::new(PrimitiveKernel::<T> {
67                buffer: Buffer::from_byte_buffer(self.byte_buffer.clone()),
68                offset: 0,
69            }) as Box<dyn Kernel>)
70        })
71    }
72}
73
/// A kernel that produces primitive values of type `T` from a typed buffer.
pub struct PrimitiveKernel<T: NativePType> {
    /// The values this kernel emits.
    buffer: Buffer<T>,
    /// Index into `buffer` of the next element to emit. `seek` sets it to `chunk_idx * N`
    /// and `step` advances it by the number of elements copied.
    offset: usize,
}
80
81impl<T: Element + NativePType> Kernel for PrimitiveKernel<T> {
82    fn seek(&mut self, chunk_idx: usize) -> VortexResult<()> {
83        self.offset = chunk_idx * N;
84        Ok(())
85    }
86
87    fn step(&mut self, _ctx: &KernelContext, mask: BitView, out: &mut ViewMut) -> VortexResult<()> {
88        let buffer = &self.buffer;
89        let remaining = buffer.len() - self.offset;
90
91        let out_slice = out.as_slice_mut::<T>();
92
93        if remaining > N {
94            out_slice.copy_from_slice(&buffer[self.offset..][..N]);
95            self.offset += N;
96        } else {
97            out_slice[..remaining].copy_from_slice(&buffer[self.offset..]);
98            self.offset += remaining;
99        }
100
101        // TODO(joe): use mask in copy_from_slice, if faster.
102        out.select_mask::<T>(&mask);
103
104        Ok(())
105    }
106}
107
#[cfg(test)]
mod tests {
    use itertools::Itertools;
    use vortex_buffer::BufferMut;
    use vortex_mask::Mask;

    use super::*;
    use crate::pipeline::export_canonical_pipeline;
    use crate::{IntoArray, ToCanonical};

    /// With an all-true mask, the exported values must equal the input values in order.
    #[test]
    fn test_primitive_kernel_basic_operation() {
        // Input: the sixteen consecutive values 0..16.
        let size = 16;
        let upper = i32::try_from(size).unwrap();
        let primitive_array = (0..upper)
            .collect::<BufferMut<_>>()
            .into_array()
            .to_primitive();

        // Build the kernel directly over the array's buffer, starting at element 0.
        let mut kernel = PrimitiveKernel::<i32> {
            buffer: primitive_array.buffer(),
            offset: 0,
        };

        let all_selected = Mask::AllTrue(size);
        let exported =
            export_canonical_pipeline(primitive_array.dtype(), size, &mut kernel, &all_selected)
                .unwrap()
                .into_primitive();
        let output = exported.as_slice::<i32>();

        // Each position must hold the value equal to its own index.
        for (position, expected) in (0..size).zip(0_i32..) {
            let actual = output[position];
            assert_eq!(
                actual, expected,
                "Mismatch at position {}: expected {}, got {}",
                position, expected, actual
            );
        }
    }

    /// With every other element selected, half of the elements must come out.
    #[test]
    fn test_primitive_kernel_with_mask() {
        // Input: the sixteen consecutive values 0..16.
        let size = 16;
        let upper = i32::try_from(size).unwrap();
        let primitive_array = (0i32..upper).collect::<PrimitiveArray>();

        // Build the kernel directly over the array's buffer, starting at element 0.
        let mut kernel = PrimitiveKernel::<i32> {
            buffer: primitive_array.buffer(),
            offset: 0,
        };

        // Select the even indices only (0, 2, 4, ...).
        let even_indices = (0..size).step_by(2).collect_vec();
        let mask = Mask::from_indices(size, even_indices);

        let exported =
            export_canonical_pipeline(primitive_array.dtype(), size, &mut kernel, &mask)
                .unwrap()
                .into_primitive();
        let output = exported.as_slice::<i32>();

        // Index 0 is selected by the mask, so the first exported value is 0.
        assert_eq!(output[0], 0, "First element should be 0 since bit 0 is set");

        // The number of exported elements equals the number of selected indices.
        assert_eq!(
            exported.len(),
            size / 2,
            "Selected element count should match true_count"
        );
    }
}
185}