Skip to main content

tenflowers_dataset/simd_transforms/
image_processing.rs

1//! SIMD-accelerated image processing operations
2//!
3//! This module provides vectorized implementations of image processing operations
4//! such as color space conversion and histogram computation using SIMD instructions.
5
6#![allow(unsafe_code)]
7
8use crate::Transform;
9use std::marker::PhantomData;
10use tenflowers_core::{Result, Tensor, TensorError};
11
12#[cfg(target_arch = "x86_64")]
13use std::arch::x86_64::*;
14
/// In-place RGB → HSV colour-space converter with an optional `f32` fast path.
///
/// Input is interleaved `[r, g, b, r, g, b, ...]` data; whether the fast path is
/// used is decided once, when the converter is constructed.
pub struct SimdColorConvert<T> {
    /// Zero-sized marker binding the converter to its element type.
    _phantom: PhantomData<T>,
    /// Set at construction when the specialised fast path may be taken.
    use_simd: bool,
}
20
21impl<T> SimdColorConvert<T>
22where
23    T: Clone + Default + scirs2_core::numeric::Float + Send + Sync + 'static,
24{
25    pub fn new() -> Self {
26        #[cfg(target_arch = "x86_64")]
27        let use_simd = is_x86_feature_detected!("avx2") && std::mem::size_of::<T>() == 4;
28
29        #[cfg(not(target_arch = "x86_64"))]
30        let use_simd = false;
31
32        Self {
33            use_simd,
34            _phantom: PhantomData,
35        }
36    }
37
38    /// Convert RGB to HSV using SIMD acceleration
39    pub fn rgb_to_hsv(&self, rgb_data: &mut [T]) {
40        if self.use_simd && std::mem::size_of::<T>() == 4 && rgb_data.len() % 3 == 0 {
41            #[cfg(target_arch = "x86_64")]
42            unsafe {
43                self.rgb_to_hsv_f32_simd(std::mem::transmute::<&mut [T], &mut [f32]>(rgb_data));
44                return;
45            }
46        }
47
48        // Fallback to scalar implementation
49        self.rgb_to_hsv_scalar(rgb_data);
50    }
51
52    /// SIMD-accelerated RGB to HSV conversion for f32 data
53    #[cfg(target_arch = "x86_64")]
54    unsafe fn rgb_to_hsv_f32_simd(&self, rgb_data: &mut [f32]) {
55        let pixels = rgb_data.len() / 3;
56
57        for i in 0..pixels {
58            let base = i * 3;
59            let r = rgb_data[base];
60            let g = rgb_data[base + 1];
61            let b = rgb_data[base + 2];
62
63            let max_val = r.max(g.max(b));
64            let min_val = r.min(g.min(b));
65            let delta = max_val - min_val;
66
67            // Value
68            let v = max_val;
69
70            // Saturation
71            let s = if max_val == 0.0 { 0.0 } else { delta / max_val };
72
73            // Hue
74            let h = if delta == 0.0 {
75                0.0
76            } else if max_val == r {
77                60.0 * (((g - b) / delta) % 6.0)
78            } else if max_val == g {
79                60.0 * ((b - r) / delta + 2.0)
80            } else {
81                60.0 * ((r - g) / delta + 4.0)
82            };
83
84            let h_normalized = if h < 0.0 { h + 360.0 } else { h };
85
86            rgb_data[base] = h_normalized / 360.0; // Normalize H to [0,1]
87            rgb_data[base + 1] = s;
88            rgb_data[base + 2] = v;
89        }
90    }
91
92    /// Scalar fallback for RGB to HSV conversion
93    fn rgb_to_hsv_scalar(&self, rgb_data: &mut [T]) {
94        let pixels = rgb_data.len() / 3;
95
96        for i in 0..pixels {
97            let base = i * 3;
98            let r = rgb_data[base];
99            let g = rgb_data[base + 1];
100            let b = rgb_data[base + 2];
101
102            let max_val = r.max(g.max(b));
103            let min_val = r.min(g.min(b));
104            let delta = max_val - min_val;
105
106            // Value
107            let v = max_val;
108
109            // Saturation
110            let s = if max_val == T::zero() {
111                T::zero()
112            } else {
113                delta / max_val
114            };
115
116            // Hue
117            let h = if delta == T::zero() {
118                T::zero()
119            } else if max_val == r {
120                let six = T::from(6.0).unwrap_or_else(|| T::from(6).unwrap_or(T::zero()));
121                let sixty = T::from(60.0).unwrap_or_else(|| T::from(60).unwrap_or(T::zero()));
122                sixty * (((g - b) / delta) % six)
123            } else if max_val == g {
124                let two = T::from(2.0).unwrap_or_else(|| T::from(2).unwrap_or(T::zero()));
125                let sixty = T::from(60.0).unwrap_or_else(|| T::from(60).unwrap_or(T::zero()));
126                sixty * ((b - r) / delta + two)
127            } else {
128                let four = T::from(4.0).unwrap_or_else(|| T::from(4).unwrap_or(T::zero()));
129                let sixty = T::from(60.0).unwrap_or_else(|| T::from(60).unwrap_or(T::zero()));
130                sixty * ((r - g) / delta + four)
131            };
132
133            let three_sixty = T::from(360.0).unwrap_or_else(|| T::from(360).unwrap_or(T::one()));
134            let h_normalized = if h < T::zero() { h + three_sixty } else { h };
135
136            rgb_data[base] = h_normalized / three_sixty; // Normalize H to [0,1]
137            rgb_data[base + 1] = s;
138            rgb_data[base + 2] = v;
139        }
140    }
141}
142
143impl<T> Default for SimdColorConvert<T>
144where
145    T: Clone + Default + scirs2_core::numeric::Float + Send + Sync + 'static,
146{
147    fn default() -> Self {
148        Self::new()
149    }
150}
151
152impl<T> Transform<T> for SimdColorConvert<T>
153where
154    T: Clone + Default + scirs2_core::numeric::Float + Send + Sync + 'static,
155{
156    fn apply(&self, sample: (Tensor<T>, Tensor<T>)) -> Result<(Tensor<T>, Tensor<T>)> {
157        let (features, labels) = sample;
158        let mut data = features
159            .as_slice()
160            .ok_or_else(|| {
161                TensorError::invalid_argument(
162                    "Unable to access tensor data for color conversion".to_string(),
163                )
164            })?
165            .to_vec();
166        self.rgb_to_hsv(&mut data);
167        let converted_features = Tensor::from_vec(data, features.shape().dims())?;
168        Ok((converted_features, labels))
169    }
170}
171
/// Equal-width histogram over a fixed `[min_val, max_val]` range.
///
/// On x86_64 CPUs with AVX2, bin indices are computed eight `f32` values at a
/// time; otherwise a scalar loop is used. Intended for quick dataset statistics.
pub struct SimdHistogram {
    /// Lower edge of the first bin.
    min_val: f32,
    /// Upper edge of the last bin.
    max_val: f32,
    /// Number of equal-width bins.
    bins: usize,
    /// Whether the AVX2 path was detected as available at construction.
    use_simd: bool,
}
182
183impl SimdHistogram {
184    /// Create a new SIMD-accelerated histogram calculator
185    pub fn new(bins: usize, min_val: f32, max_val: f32) -> Self {
186        #[cfg(target_arch = "x86_64")]
187        let use_simd = is_x86_feature_detected!("avx2");
188
189        #[cfg(not(target_arch = "x86_64"))]
190        let use_simd = false;
191
192        Self {
193            bins,
194            min_val,
195            max_val,
196            use_simd,
197        }
198    }
199
200    /// Get SIMD capability status
201    pub fn is_simd_enabled(&self) -> bool {
202        self.use_simd
203    }
204
205    /// Compute histogram of tensor data with SIMD acceleration
206    pub fn compute(&self, tensor: &Tensor<f32>) -> Result<Vec<u32>> {
207        let data = tensor
208            .as_slice()
209            .ok_or_else(|| TensorError::InvalidOperation {
210                operation: "histogram_compute".to_string(),
211                reason: "Cannot get tensor slice".to_string(),
212                context: None,
213            })?;
214
215        let mut histogram = vec![0u32; self.bins];
216        let bin_width = (self.max_val - self.min_val) / self.bins as f32;
217
218        #[cfg(target_arch = "x86_64")]
219        if self.use_simd && data.len() >= 8 {
220            self.compute_simd_f32(data, &mut histogram, bin_width);
221        } else {
222            self.compute_scalar(data, &mut histogram, bin_width);
223        }
224
225        #[cfg(not(target_arch = "x86_64"))]
226        self.compute_scalar(data, &mut histogram, bin_width);
227
228        Ok(histogram)
229    }
230
231    /// SIMD-accelerated histogram computation for f32 data
232    #[cfg(target_arch = "x86_64")]
233    fn compute_simd_f32(&self, data: &[f32], histogram: &mut [u32], bin_width: f32) {
234        unsafe {
235            let min_vec = _mm256_set1_ps(self.min_val);
236            let max_vec = _mm256_set1_ps(self.max_val);
237            let bin_width_vec = _mm256_set1_ps(bin_width);
238            let bins_minus_one = _mm256_set1_epi32((self.bins - 1) as i32);
239            let zero_vec = _mm256_setzero_si256();
240
241            let chunks = data.chunks_exact(8);
242            let remainder = chunks.remainder();
243
244            // Process 8 elements at a time with SIMD
245            for chunk in chunks {
246                let values = _mm256_loadu_ps(chunk.as_ptr());
247
248                // Clamp values to [min_val, max_val]
249                let clamped = _mm256_max_ps(_mm256_min_ps(values, max_vec), min_vec);
250
251                // Calculate bin indices: (value - min_val) / bin_width
252                let normalized = _mm256_sub_ps(clamped, min_vec);
253                let bin_indices_f = _mm256_div_ps(normalized, bin_width_vec);
254
255                // Use truncation toward zero instead of rounding
256                let bin_indices = _mm256_cvttps_epi32(bin_indices_f);
257
258                // Clamp bin indices to valid range [0, bins-1]
259                let clamped_indices =
260                    _mm256_max_epi32(_mm256_min_epi32(bin_indices, bins_minus_one), zero_vec);
261
262                // Extract indices and increment histogram bins
263                let indices: [i32; 8] = std::mem::transmute(clamped_indices);
264                for &idx in &indices {
265                    histogram[idx as usize] += 1;
266                }
267            }
268
269            // Process remaining elements with scalar code
270            self.compute_scalar(remainder, histogram, bin_width);
271        }
272    }
273
274    /// Fallback scalar implementation for non-SIMD hardware
275    fn compute_scalar(&self, data: &[f32], histogram: &mut [u32], bin_width: f32) {
276        for &value in data {
277            let clamped = value.clamp(self.min_val, self.max_val);
278            // Handle the edge case where clamped equals max_val
279            let bin_idx = if clamped == self.max_val {
280                self.bins - 1
281            } else {
282                ((clamped - self.min_val) / bin_width) as usize
283            };
284            let bin_idx = bin_idx.min(self.bins - 1);
285            histogram[bin_idx] += 1;
286        }
287    }
288}
289
290impl Transform<f32> for SimdHistogram {
291    fn apply(&self, sample: (Tensor<f32>, Tensor<f32>)) -> Result<(Tensor<f32>, Tensor<f32>)> {
292        // Return the original sample unchanged - this transform is meant for analysis
293        Ok(sample)
294    }
295}
296
297/// Specialized histogram transform that computes histogram alongside the data
298pub struct SimdHistogramTransform {
299    histogram_computer: SimdHistogram,
300}
301
302impl SimdHistogramTransform {
303    pub fn new(bins: usize, min_val: f32, max_val: f32) -> Self {
304        Self {
305            histogram_computer: SimdHistogram::new(bins, min_val, max_val),
306        }
307    }
308
309    pub fn apply_with_histogram(&self, input: &Tensor<f32>) -> Result<Vec<u32>> {
310        self.histogram_computer.compute(input)
311    }
312}