tenflowers_dataset/simd_transforms/
convolution.rs1#![allow(unsafe_code)]
7
8use crate::Transform;
9use std::marker::PhantomData;
10use tenflowers_core::{Result, Tensor, TensorError};
11
12#[cfg(target_arch = "x86_64")]
13use std::arch::x86_64::*;
14
/// 2D convolution transform with optional SIMD acceleration.
///
/// Holds a square `kernel_size` x `kernel_size` convolution kernel stored
/// flattened in row-major order.
pub struct SimdConvolution<T> {
    // Flattened kernel weights, row-major; length == kernel_size * kernel_size
    // (enforced by `new`).
    kernel: Vec<T>,
    // Side length of the square kernel.
    kernel_size: usize,
    // Set at construction when AVX2 is detected and T is 4 bytes (f32).
    // NOTE(review): no SIMD code path is visible in this chunk — confirm the
    // flag is actually consulted elsewhere.
    use_simd: bool,
    // NOTE(review): looks redundant — `kernel: Vec<T>` already owns a `T`,
    // so the struct is parameterized over T without this marker. Confirm
    // before removing (removal would also touch the constructor).
    _marker: PhantomData<T>,
}
25
26impl<T> SimdConvolution<T>
27where
28 T: Clone + Default + scirs2_core::numeric::Float + Send + Sync + 'static,
29{
30 pub fn new(kernel: Vec<T>, kernel_size: usize) -> Result<Self> {
36 if kernel.len() != kernel_size * kernel_size {
37 return Err(TensorError::InvalidShape {
38 operation: "SimdConvolution::new".to_string(),
39 reason: format!(
40 "Kernel length {} doesn't match expected size {}x{}",
41 kernel.len(),
42 kernel_size,
43 kernel_size
44 ),
45 shape: Some(vec![kernel_size, kernel_size]),
46 context: None,
47 });
48 }
49
50 #[cfg(target_arch = "x86_64")]
51 let use_simd = is_x86_feature_detected!("avx2") && std::mem::size_of::<T>() == 4;
52
53 #[cfg(not(target_arch = "x86_64"))]
54 let use_simd = false;
55
56 Ok(Self {
57 kernel,
58 kernel_size,
59 use_simd,
60 _marker: PhantomData,
61 })
62 }
63
64 pub fn convolve_2d(&self, input: &Tensor<T>, output: &mut Tensor<T>) -> Result<()>
70 where
71 T: bytemuck::Pod + bytemuck::Zeroable,
72 {
73 let input_shape = input.shape().dims();
74 let output_shape = output.shape().dims();
75
76 if input_shape.len() != 2 || output_shape.len() != 2 {
77 return Err(TensorError::InvalidShape {
78 operation: "SimdConvolution::convolve_2d".to_string(),
79 reason: "Convolution requires 2D tensors".to_string(),
80 shape: Some(input_shape.to_vec()),
81 context: None,
82 });
83 }
84
85 let height = input_shape[0];
86 let width = input_shape[1];
87 let out_height = output_shape[0];
88 let out_width = output_shape[1];
89
90 let input_data = input.to_vec()?;
91 let mut output_data = vec![T::default(); out_height * out_width];
92
93 for out_y in 0..out_height {
95 for out_x in 0..out_width {
96 let mut sum = T::zero();
97
98 for ky in 0..self.kernel_size {
99 for kx in 0..self.kernel_size {
100 let in_y = out_y + ky;
101 let in_x = out_x + kx;
102
103 if in_y < height && in_x < width {
104 let input_idx = in_y * width + in_x;
105 let kernel_idx = ky * self.kernel_size + kx;
106
107 sum = sum
108 + input_data[input_idx].clone() * self.kernel[kernel_idx].clone();
109 }
110 }
111 }
112
113 output_data[out_y * out_width + out_x] = sum;
114 }
115 }
116
117 *output = Tensor::<T>::from_vec(output_data, &[out_height, out_width])?;
118 Ok(())
119 }
120
121 pub fn is_simd_enabled(&self) -> bool {
123 self.use_simd
124 }
125}
126
127impl<T> Transform<T> for SimdConvolution<T>
128where
129 T: Clone
130 + Default
131 + scirs2_core::numeric::Float
132 + Send
133 + Sync
134 + bytemuck::Pod
135 + bytemuck::Zeroable
136 + 'static,
137{
138 fn apply(&self, (features, labels): (Tensor<T>, Tensor<T>)) -> Result<(Tensor<T>, Tensor<T>)> {
139 let shape = features.shape().dims();
141
142 if shape.len() == 2 {
143 let out_height = shape[0].saturating_sub(self.kernel_size - 1);
145 let out_width = shape[1].saturating_sub(self.kernel_size - 1);
146
147 let mut output = Tensor::<T>::zeros(&[out_height, out_width]);
148 self.convolve_2d(&features, &mut output)?;
149
150 Ok((output, labels))
151 } else {
152 Ok((features, labels))
154 }
155 }
156}