tenflowers_dataset/simd_transforms/
image_processing.rs1#![allow(unsafe_code)]
7
8use crate::Transform;
9use std::marker::PhantomData;
10use tenflowers_core::{Result, Tensor, TensorError};
11
12#[cfg(target_arch = "x86_64")]
13use std::arch::x86_64::*;
14
/// In-place RGB → HSV converter, generic over the element type `T`.
///
/// The struct stores no pixel data: it only remembers which code path its
/// methods should take.
pub struct SimdColorConvert<T> {
    /// `true` when the dedicated `f32` code path may be taken.
    use_simd: bool,
    /// Binds the converter to its element type without storing a `T`.
    _phantom: PhantomData<T>,
}
20
21impl<T> SimdColorConvert<T>
22where
23 T: Clone + Default + scirs2_core::numeric::Float + Send + Sync + 'static,
24{
25 pub fn new() -> Self {
26 #[cfg(target_arch = "x86_64")]
27 let use_simd = is_x86_feature_detected!("avx2") && std::mem::size_of::<T>() == 4;
28
29 #[cfg(not(target_arch = "x86_64"))]
30 let use_simd = false;
31
32 Self {
33 use_simd,
34 _phantom: PhantomData,
35 }
36 }
37
38 pub fn rgb_to_hsv(&self, rgb_data: &mut [T]) {
40 if self.use_simd && std::mem::size_of::<T>() == 4 && rgb_data.len() % 3 == 0 {
41 #[cfg(target_arch = "x86_64")]
42 unsafe {
43 self.rgb_to_hsv_f32_simd(std::mem::transmute::<&mut [T], &mut [f32]>(rgb_data));
44 return;
45 }
46 }
47
48 self.rgb_to_hsv_scalar(rgb_data);
50 }
51
52 #[cfg(target_arch = "x86_64")]
54 unsafe fn rgb_to_hsv_f32_simd(&self, rgb_data: &mut [f32]) {
55 let pixels = rgb_data.len() / 3;
56
57 for i in 0..pixels {
58 let base = i * 3;
59 let r = rgb_data[base];
60 let g = rgb_data[base + 1];
61 let b = rgb_data[base + 2];
62
63 let max_val = r.max(g.max(b));
64 let min_val = r.min(g.min(b));
65 let delta = max_val - min_val;
66
67 let v = max_val;
69
70 let s = if max_val == 0.0 { 0.0 } else { delta / max_val };
72
73 let h = if delta == 0.0 {
75 0.0
76 } else if max_val == r {
77 60.0 * (((g - b) / delta) % 6.0)
78 } else if max_val == g {
79 60.0 * ((b - r) / delta + 2.0)
80 } else {
81 60.0 * ((r - g) / delta + 4.0)
82 };
83
84 let h_normalized = if h < 0.0 { h + 360.0 } else { h };
85
86 rgb_data[base] = h_normalized / 360.0; rgb_data[base + 1] = s;
88 rgb_data[base + 2] = v;
89 }
90 }
91
92 fn rgb_to_hsv_scalar(&self, rgb_data: &mut [T]) {
94 let pixels = rgb_data.len() / 3;
95
96 for i in 0..pixels {
97 let base = i * 3;
98 let r = rgb_data[base];
99 let g = rgb_data[base + 1];
100 let b = rgb_data[base + 2];
101
102 let max_val = r.max(g.max(b));
103 let min_val = r.min(g.min(b));
104 let delta = max_val - min_val;
105
106 let v = max_val;
108
109 let s = if max_val == T::zero() {
111 T::zero()
112 } else {
113 delta / max_val
114 };
115
116 let h = if delta == T::zero() {
118 T::zero()
119 } else if max_val == r {
120 let six = T::from(6.0).unwrap_or_else(|| T::from(6).unwrap_or(T::zero()));
121 let sixty = T::from(60.0).unwrap_or_else(|| T::from(60).unwrap_or(T::zero()));
122 sixty * (((g - b) / delta) % six)
123 } else if max_val == g {
124 let two = T::from(2.0).unwrap_or_else(|| T::from(2).unwrap_or(T::zero()));
125 let sixty = T::from(60.0).unwrap_or_else(|| T::from(60).unwrap_or(T::zero()));
126 sixty * ((b - r) / delta + two)
127 } else {
128 let four = T::from(4.0).unwrap_or_else(|| T::from(4).unwrap_or(T::zero()));
129 let sixty = T::from(60.0).unwrap_or_else(|| T::from(60).unwrap_or(T::zero()));
130 sixty * ((r - g) / delta + four)
131 };
132
133 let three_sixty = T::from(360.0).unwrap_or_else(|| T::from(360).unwrap_or(T::one()));
134 let h_normalized = if h < T::zero() { h + three_sixty } else { h };
135
136 rgb_data[base] = h_normalized / three_sixty; rgb_data[base + 1] = s;
138 rgb_data[base + 2] = v;
139 }
140 }
141}
142
143impl<T> Default for SimdColorConvert<T>
144where
145 T: Clone + Default + scirs2_core::numeric::Float + Send + Sync + 'static,
146{
147 fn default() -> Self {
148 Self::new()
149 }
150}
151
152impl<T> Transform<T> for SimdColorConvert<T>
153where
154 T: Clone + Default + scirs2_core::numeric::Float + Send + Sync + 'static,
155{
156 fn apply(&self, sample: (Tensor<T>, Tensor<T>)) -> Result<(Tensor<T>, Tensor<T>)> {
157 let (features, labels) = sample;
158 let mut data = features
159 .as_slice()
160 .ok_or_else(|| {
161 TensorError::invalid_argument(
162 "Unable to access tensor data for color conversion".to_string(),
163 )
164 })?
165 .to_vec();
166 self.rgb_to_hsv(&mut data);
167 let converted_features = Tensor::from_vec(data, features.shape().dims())?;
168 Ok((converted_features, labels))
169 }
170}
171
/// Configuration for a fixed-range histogram over `f32` samples.
pub struct SimdHistogram {
    /// Number of bins in the produced histogram.
    bins: usize,
    /// Lower edge of the histogram range.
    min_val: f32,
    /// Upper edge of the histogram range.
    max_val: f32,
    /// `true` when the AVX2 code path may be used.
    use_simd: bool,
}
182
183impl SimdHistogram {
184 pub fn new(bins: usize, min_val: f32, max_val: f32) -> Self {
186 #[cfg(target_arch = "x86_64")]
187 let use_simd = is_x86_feature_detected!("avx2");
188
189 #[cfg(not(target_arch = "x86_64"))]
190 let use_simd = false;
191
192 Self {
193 bins,
194 min_val,
195 max_val,
196 use_simd,
197 }
198 }
199
200 pub fn is_simd_enabled(&self) -> bool {
202 self.use_simd
203 }
204
205 pub fn compute(&self, tensor: &Tensor<f32>) -> Result<Vec<u32>> {
207 let data = tensor
208 .as_slice()
209 .ok_or_else(|| TensorError::InvalidOperation {
210 operation: "histogram_compute".to_string(),
211 reason: "Cannot get tensor slice".to_string(),
212 context: None,
213 })?;
214
215 let mut histogram = vec![0u32; self.bins];
216 let bin_width = (self.max_val - self.min_val) / self.bins as f32;
217
218 #[cfg(target_arch = "x86_64")]
219 if self.use_simd && data.len() >= 8 {
220 self.compute_simd_f32(data, &mut histogram, bin_width);
221 } else {
222 self.compute_scalar(data, &mut histogram, bin_width);
223 }
224
225 #[cfg(not(target_arch = "x86_64"))]
226 self.compute_scalar(data, &mut histogram, bin_width);
227
228 Ok(histogram)
229 }
230
231 #[cfg(target_arch = "x86_64")]
233 fn compute_simd_f32(&self, data: &[f32], histogram: &mut [u32], bin_width: f32) {
234 unsafe {
235 let min_vec = _mm256_set1_ps(self.min_val);
236 let max_vec = _mm256_set1_ps(self.max_val);
237 let bin_width_vec = _mm256_set1_ps(bin_width);
238 let bins_minus_one = _mm256_set1_epi32((self.bins - 1) as i32);
239 let zero_vec = _mm256_setzero_si256();
240
241 let chunks = data.chunks_exact(8);
242 let remainder = chunks.remainder();
243
244 for chunk in chunks {
246 let values = _mm256_loadu_ps(chunk.as_ptr());
247
248 let clamped = _mm256_max_ps(_mm256_min_ps(values, max_vec), min_vec);
250
251 let normalized = _mm256_sub_ps(clamped, min_vec);
253 let bin_indices_f = _mm256_div_ps(normalized, bin_width_vec);
254
255 let bin_indices = _mm256_cvttps_epi32(bin_indices_f);
257
258 let clamped_indices =
260 _mm256_max_epi32(_mm256_min_epi32(bin_indices, bins_minus_one), zero_vec);
261
262 let indices: [i32; 8] = std::mem::transmute(clamped_indices);
264 for &idx in &indices {
265 histogram[idx as usize] += 1;
266 }
267 }
268
269 self.compute_scalar(remainder, histogram, bin_width);
271 }
272 }
273
274 fn compute_scalar(&self, data: &[f32], histogram: &mut [u32], bin_width: f32) {
276 for &value in data {
277 let clamped = value.clamp(self.min_val, self.max_val);
278 let bin_idx = if clamped == self.max_val {
280 self.bins - 1
281 } else {
282 ((clamped - self.min_val) / bin_width) as usize
283 };
284 let bin_idx = bin_idx.min(self.bins - 1);
285 histogram[bin_idx] += 1;
286 }
287 }
288}
289
290impl Transform<f32> for SimdHistogram {
291 fn apply(&self, sample: (Tensor<f32>, Tensor<f32>)) -> Result<(Tensor<f32>, Tensor<f32>)> {
292 Ok(sample)
294 }
295}
296
297pub struct SimdHistogramTransform {
299 histogram_computer: SimdHistogram,
300}
301
302impl SimdHistogramTransform {
303 pub fn new(bins: usize, min_val: f32, max_val: f32) -> Self {
304 Self {
305 histogram_computer: SimdHistogram::new(bins, min_val, max_val),
306 }
307 }
308
309 pub fn apply_with_histogram(&self, input: &Tensor<f32>) -> Result<Vec<u32>> {
310 self.histogram_computer.compute(input)
311 }
312}