vortex_array/arrays/primitive/compute/
pipeline.rs1use std::any::Any;
5use std::sync::Arc;
6
7use log;
8use vortex_buffer::{Buffer, ByteBuffer};
9use vortex_dtype::{NativePType, PType, match_each_native_ptype};
10use vortex_error::VortexResult;
11
12use crate::arrays::{PrimitiveArray, PrimitiveVTable};
13use crate::pipeline::bits::BitView;
14use crate::pipeline::operators::{BindContext, Operator, OperatorRef};
15use crate::pipeline::view::ViewMut;
16use crate::pipeline::{Element, Kernel, KernelContext, N, PipelineVTable, VType};
17use crate::vtable::ValidityHelper;
18
19impl PipelineVTable<PrimitiveVTable> for PrimitiveVTable {
20 fn to_operator(array: &PrimitiveArray) -> VortexResult<Option<OperatorRef>> {
21 if !array.validity().all_valid() {
22 log::debug!(
23 "PipelineVTable::to_operator is not supported for arrays with invalid values"
24 );
25 return Ok(None);
26 }
27 Ok(Some(Arc::new(PrimitiveOperator::new(
28 array.ptype(),
29 array.byte_buffer().clone(),
30 ))))
31 }
32}
33
34#[derive(Debug, Clone, Hash)]
36pub struct PrimitiveOperator {
37 ptype: PType,
38 byte_buffer: ByteBuffer,
39}
40
41impl PrimitiveOperator {
42 pub fn new(ptype: PType, byte_buffer: ByteBuffer) -> Self {
43 Self { ptype, byte_buffer }
44 }
45}
46
47impl Operator for PrimitiveOperator {
48 fn as_any(&self) -> &dyn Any {
49 self
50 }
51
52 fn vtype(&self) -> VType {
53 VType::Primitive(self.ptype)
54 }
55
56 fn children(&self) -> &[OperatorRef] {
57 &[]
58 }
59
60 fn with_children(&self, _children: Vec<OperatorRef>) -> OperatorRef {
61 Arc::new(self.clone())
62 }
63
64 fn bind(&self, _ctx: &dyn BindContext) -> VortexResult<Box<dyn Kernel>> {
65 match_each_native_ptype!(self.ptype, |T| {
66 Ok(Box::new(PrimitiveKernel::<T> {
67 buffer: Buffer::from_byte_buffer(self.byte_buffer.clone()),
68 offset: 0,
69 }) as Box<dyn Kernel>)
70 })
71 }
72}
73
74pub struct PrimitiveKernel<T: NativePType> {
77 buffer: Buffer<T>,
78 offset: usize,
79}
80
81impl<T: Element + NativePType> Kernel for PrimitiveKernel<T> {
82 fn seek(&mut self, chunk_idx: usize) -> VortexResult<()> {
83 self.offset = chunk_idx * N;
84 Ok(())
85 }
86
87 fn step(&mut self, _ctx: &KernelContext, mask: BitView, out: &mut ViewMut) -> VortexResult<()> {
88 let buffer = &self.buffer;
89 let remaining = buffer.len() - self.offset;
90
91 let out_slice = out.as_slice_mut::<T>();
92
93 if remaining > N {
94 out_slice.copy_from_slice(&buffer[self.offset..][..N]);
95 self.offset += N;
96 } else {
97 out_slice[..remaining].copy_from_slice(&buffer[self.offset..]);
98 self.offset += remaining;
99 }
100
101 out.select_mask::<T>(&mask);
103
104 Ok(())
105 }
106}
107
#[cfg(test)]
mod tests {
    use itertools::Itertools;
    use vortex_buffer::BufferMut;
    use vortex_mask::Mask;

    use super::*;
    use crate::pipeline::export_canonical_pipeline;
    use crate::{IntoArray, ToCanonical};

    /// With an all-true mask, every input value must come out, in order.
    #[test]
    fn test_primitive_kernel_basic_operation() {
        let size = 16;
        let values = (0..i32::try_from(size).unwrap()).collect::<BufferMut<_>>();
        let primitive_array = values.into_array().to_primitive();

        let mut kernel = PrimitiveKernel::<i32> {
            buffer: primitive_array.buffer(),
            offset: 0,
        };

        let out = export_canonical_pipeline(
            primitive_array.dtype(),
            size,
            &mut kernel,
            &Mask::AllTrue(size),
        )
        .unwrap()
        .into_primitive();

        let expected = (0..i32::try_from(size).unwrap()).collect_vec();
        assert_eq!(out.len(), size, "All elements should be selected");
        assert_eq!(out.as_slice::<i32>(), expected.as_slice());
    }

    /// Selecting every even position must yield exactly the even-indexed values.
    #[test]
    fn test_primitive_kernel_with_mask() {
        let size = 16;
        let primitive_array = (0i32..i32::try_from(size).unwrap()).collect::<PrimitiveArray>();

        let mut kernel = PrimitiveKernel::<i32> {
            buffer: primitive_array.buffer(),
            offset: 0,
        };

        let mask = Mask::from_indices(size, (0..size).step_by(2).collect_vec());
        let out = export_canonical_pipeline(primitive_array.dtype(), size, &mut kernel, &mask)
            .unwrap()
            .into_primitive();

        assert_eq!(
            out.len(),
            size / 2,
            "Selected element count should match true_count"
        );

        // Values equal their positions, so the selected values are the even numbers.
        let expected = (0..i32::try_from(size).unwrap()).step_by(2).collect_vec();
        assert_eq!(out.as_slice::<i32>(), expected.as_slice());
    }
}