vortex_array/arrays/primitive/compute/
pipeline.rs1use std::any::Any;
5use std::rc::Rc;
6
7use vortex_buffer::{Buffer, ByteBuffer};
8use vortex_dtype::{NativePType, PType, match_each_native_ptype};
9use vortex_error::{VortexResult, vortex_bail};
10
11use crate::arrays::{PrimitiveArray, PrimitiveVTable};
12use crate::pipeline::bits::BitView;
13use crate::pipeline::operators::{BindContext, Operator};
14use crate::pipeline::view::ViewMut;
15use crate::pipeline::{Element, Kernel, KernelContext, N, PipelineVTable, VType};
16use crate::vtable::ValidityHelper;
17
18impl PipelineVTable<PrimitiveVTable> for PrimitiveVTable {
19 fn to_operator(array: &PrimitiveArray) -> VortexResult<Option<Rc<dyn Operator>>> {
20 if !array.validity().all_valid()? {
21 vortex_bail!(
22 "PipelineVTable::to_operator is not supported for arrays with invalid values"
23 );
24 }
25 Ok(Some(Rc::new(PrimitiveOperator::new(
26 array.ptype(),
27 array.byte_buffer().clone(),
28 ))))
29 }
30}
31
32#[derive(Debug, Clone, Hash)]
34pub struct PrimitiveOperator {
35 ptype: PType,
36 byte_buffer: ByteBuffer,
37}
38
39impl PrimitiveOperator {
40 pub fn new(ptype: PType, byte_buffer: ByteBuffer) -> Self {
41 Self { ptype, byte_buffer }
42 }
43}
44
45impl Operator for PrimitiveOperator {
46 fn as_any(&self) -> &dyn Any {
47 self
48 }
49
50 fn vtype(&self) -> VType {
51 VType::Primitive(self.ptype)
52 }
53
54 fn children(&self) -> &[Rc<dyn Operator>] {
55 &[]
56 }
57
58 fn with_children(&self, _children: Vec<Rc<dyn Operator>>) -> Rc<dyn Operator> {
59 Rc::new(self.clone())
60 }
61
62 fn bind(&self, _ctx: &dyn BindContext) -> VortexResult<Box<dyn Kernel>> {
63 match_each_native_ptype!(self.ptype, |T| {
64 Ok(Box::new(PrimitiveKernel::<T> {
65 buffer: Buffer::from_byte_buffer(self.byte_buffer.clone()),
66 offset: 0,
67 }) as Box<dyn Kernel>)
68 })
69 }
70}
71
72pub struct PrimitiveKernel<T: NativePType> {
75 buffer: Buffer<T>,
76 offset: usize,
77}
78
79impl<T: Element + NativePType> Kernel for PrimitiveKernel<T> {
80 fn seek(&mut self, chunk_idx: usize) -> VortexResult<()> {
81 self.offset = chunk_idx * N;
82 Ok(())
83 }
84
85 fn step(&mut self, _ctx: &KernelContext, mask: BitView, out: &mut ViewMut) -> VortexResult<()> {
86 let buffer = &self.buffer;
87 let remaining = buffer.len() - self.offset;
88
89 let out_slice = out.as_slice_mut::<T>();
90
91 if remaining > N {
92 out_slice.copy_from_slice(&buffer[self.offset..][..N]);
93 self.offset += N;
94 } else {
95 out_slice[..remaining].copy_from_slice(&buffer[self.offset..]);
96 self.offset += remaining;
97 }
98
99 out.select_mask::<T>(&mask);
101
102 Ok(())
103 }
104}
105
#[cfg(test)]
mod tests {
    use itertools::Itertools;
    use vortex_buffer::BufferMut;
    use vortex_mask::Mask;

    use super::*;
    use crate::pipeline::export_canonical_pipeline;
    use crate::{IntoArray, ToCanonical};

    /// Builds an `i32` kernel over `array`'s values, positioned at the start.
    fn kernel_over(array: &PrimitiveArray) -> PrimitiveKernel<i32> {
        PrimitiveKernel::<i32> {
            buffer: array.buffer(),
            offset: 0,
        }
    }

    /// With an all-true mask, every input value is exported unchanged.
    #[test]
    fn test_primitive_kernel_basic_operation() {
        let size = 16;
        let values = (0..i32::try_from(size).unwrap()).collect::<BufferMut<_>>();
        let primitive_array = values.into_array().to_primitive().unwrap();

        let mut kernel = kernel_over(&primitive_array);
        let select_all = Mask::AllTrue(size);

        let exported = export_canonical_pipeline(
            primitive_array.dtype(),
            size,
            &mut kernel,
            &select_all,
        )
        .unwrap()
        .into_primitive()
        .unwrap();

        let output = exported.as_slice::<i32>();

        for idx in 0..size {
            let expected = i32::try_from(idx).unwrap();
            let actual = output[idx];
            assert_eq!(
                actual, expected,
                "Mismatch at position {}: expected {}, got {}",
                idx, idx, actual
            );
        }
    }

    /// With every other position selected, half of the values are exported
    /// and the first exported value is the first input value.
    #[test]
    fn test_primitive_kernel_with_mask() {
        let size = 16;
        let primitive_array = (0i32..i32::try_from(size).unwrap()).collect::<PrimitiveArray>();

        let mut kernel = kernel_over(&primitive_array);

        let even_positions = (0..size).step_by(2).collect_vec();
        let mask = Mask::from_indices(size, even_positions);

        let exported =
            export_canonical_pipeline(primitive_array.dtype(), size, &mut kernel, &mask)
                .unwrap()
                .into_primitive()
                .unwrap();

        let output = exported.as_slice::<i32>();
        assert_eq!(output[0], 0, "First element should be 0 since bit 0 is set");

        let expected_len = size / 2;
        assert_eq!(
            exported.len(),
            expected_len,
            "Selected element count should match true_count"
        )
    }
}