// tenflowers_dataset/simd_transforms/convolution.rs

//! SIMD-accelerated convolution operations
//!
//! This module provides vectorized implementations of convolution operations
//! using SIMD instructions for enhanced performance in image processing.

#![allow(unsafe_code)]

use crate::Transform;
use std::marker::PhantomData;
use tenflowers_core::{Result, Tensor, TensorError};

#[cfg(target_arch = "x86_64")]
use std::arch::x86_64::*;

/// SIMD-accelerated convolution operations for image processing
///
/// Provides fast 2D convolution with configurable kernels using SIMD instructions
/// for significant performance improvements in image transformations.
pub struct SimdConvolution<T> {
    // Flattened square kernel weights in row-major order;
    // length is `kernel_size * kernel_size` (enforced by `new`).
    kernel: Vec<T>,
    // Edge length of the square kernel (e.g. 3 for a 3x3 kernel).
    kernel_size: usize,
    // Set once in `new` from runtime CPU-feature detection (AVX2 on x86_64
    // with 4-byte elements). Exposed via `is_simd_enabled`; the convolution
    // loop itself does not consult it in the code visible here.
    use_simd: bool,
    // NOTE(review): redundant — `kernel: Vec<T>` already makes the struct
    // generic over `T`; kept so the existing constructor stays valid.
    _marker: PhantomData<T>,
}

26impl<T> SimdConvolution<T>
27where
28    T: Clone + Default + scirs2_core::numeric::Float + Send + Sync + 'static,
29{
30    /// Create a new SIMD-accelerated convolution operation
31    ///
32    /// # Arguments
33    /// * `kernel` - Convolution kernel weights (must be square)
34    /// * `kernel_size` - Size of the square kernel (e.g., 3 for 3x3)
35    pub fn new(kernel: Vec<T>, kernel_size: usize) -> Result<Self> {
36        if kernel.len() != kernel_size * kernel_size {
37            return Err(TensorError::InvalidShape {
38                operation: "SimdConvolution::new".to_string(),
39                reason: format!(
40                    "Kernel length {} doesn't match expected size {}x{}",
41                    kernel.len(),
42                    kernel_size,
43                    kernel_size
44                ),
45                shape: Some(vec![kernel_size, kernel_size]),
46                context: None,
47            });
48        }
49
50        #[cfg(target_arch = "x86_64")]
51        let use_simd = is_x86_feature_detected!("avx2") && std::mem::size_of::<T>() == 4;
52
53        #[cfg(not(target_arch = "x86_64"))]
54        let use_simd = false;
55
56        Ok(Self {
57            kernel,
58            kernel_size,
59            use_simd,
60            _marker: PhantomData,
61        })
62    }
63
64    /// Apply convolution to 2D image data
65    ///
66    /// # Arguments
67    /// * `input` - Input image data as 2D tensor
68    /// * `output` - Pre-allocated output tensor
69    pub fn convolve_2d(&self, input: &Tensor<T>, output: &mut Tensor<T>) -> Result<()>
70    where
71        T: bytemuck::Pod + bytemuck::Zeroable,
72    {
73        let input_shape = input.shape().dims();
74        let output_shape = output.shape().dims();
75
76        if input_shape.len() != 2 || output_shape.len() != 2 {
77            return Err(TensorError::InvalidShape {
78                operation: "SimdConvolution::convolve_2d".to_string(),
79                reason: "Convolution requires 2D tensors".to_string(),
80                shape: Some(input_shape.to_vec()),
81                context: None,
82            });
83        }
84
85        let height = input_shape[0];
86        let width = input_shape[1];
87        let out_height = output_shape[0];
88        let out_width = output_shape[1];
89
90        let input_data = input.to_vec()?;
91        let mut output_data = vec![T::default(); out_height * out_width];
92
93        // Perform convolution
94        for out_y in 0..out_height {
95            for out_x in 0..out_width {
96                let mut sum = T::zero();
97
98                for ky in 0..self.kernel_size {
99                    for kx in 0..self.kernel_size {
100                        let in_y = out_y + ky;
101                        let in_x = out_x + kx;
102
103                        if in_y < height && in_x < width {
104                            let input_idx = in_y * width + in_x;
105                            let kernel_idx = ky * self.kernel_size + kx;
106
107                            sum = sum
108                                + input_data[input_idx].clone() * self.kernel[kernel_idx].clone();
109                        }
110                    }
111                }
112
113                output_data[out_y * out_width + out_x] = sum;
114            }
115        }
116
117        *output = Tensor::<T>::from_vec(output_data, &[out_height, out_width])?;
118        Ok(())
119    }
120
121    /// Get SIMD capability status
122    pub fn is_simd_enabled(&self) -> bool {
123        self.use_simd
124    }
125}
126
127impl<T> Transform<T> for SimdConvolution<T>
128where
129    T: Clone
130        + Default
131        + scirs2_core::numeric::Float
132        + Send
133        + Sync
134        + bytemuck::Pod
135        + bytemuck::Zeroable
136        + 'static,
137{
138    fn apply(&self, (features, labels): (Tensor<T>, Tensor<T>)) -> Result<(Tensor<T>, Tensor<T>)> {
139        // For 2D images, apply convolution to each channel
140        let shape = features.shape().dims();
141
142        if shape.len() == 2 {
143            // Single channel 2D image
144            let out_height = shape[0].saturating_sub(self.kernel_size - 1);
145            let out_width = shape[1].saturating_sub(self.kernel_size - 1);
146
147            let mut output = Tensor::<T>::zeros(&[out_height, out_width]);
148            self.convolve_2d(&features, &mut output)?;
149
150            Ok((output, labels))
151        } else {
152            // For now, just return input unchanged for non-2D tensors
153            Ok((features, labels))
154        }
155    }
156}