#![allow(unsafe_op_in_unsafe_fn)]

use yscv_tensor::Tensor;

use super::super::ImgProcError;
use super::super::shape::hwc_shape;

#[allow(unsafe_code)]
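/// Normalizes an HWC `f32` tensor per channel: `out = (x - mean[c]) / std[c]`.
///
/// Validates that `mean` and `std` match the channel count and that no `std`
/// entry is zero, then dispatches to SIMD kernels for the 1- and 3-channel
/// fast paths or to a generic kernel for any other channel count.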
pub fn normalize(input: &Tensor, mean: &[f32], std: &[f32]) -> Result<Tensor, ImgProcError> {
    let (h, w, channels) = hwc_shape(input)?;
    if mean.len() != channels || std.len() != channels {
        return Err(ImgProcError::InvalidNormalizationParams {
            expected_channels: channels,
            mean_len: mean.len(),
            std_len: std.len(),
        });
    }
    for (channel, value) in std.iter().enumerate() {
        if *value == 0.0 {
            return Err(ImgProcError::ZeroStdAtChannel { channel });
        }
    }

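    // Multiplying by a precomputed reciprocal replaces a division in every inner loop.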
    let inv_std: Vec<f32> = std.iter().map(|&s| 1.0 / s).collect();

    let len = h * w * channels;
    let mut out = vec![0.0f32; len];

    let src = input.data();
    let num_pixels = h * w;

    unsafe {
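        // SAFETY: `src` and `out` each hold exactly `len = h * w * channels`
        // elements, and every kernel below reads and writes exactly that many.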
        let src_ptr = src.as_ptr();
        let dst_ptr = out.as_mut_ptr();

        match channels {
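            // Dedicated kernels for the common RGB (3) and grayscale (1) layouts.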
            3 => {
                normalize_3ch(src_ptr, dst_ptr, mean, &inv_std, num_pixels);
            }
            1 => {
                normalize_1ch(src_ptr, dst_ptr, mean[0], inv_std[0], len);
            }
            _ => {
                normalize_generic(src_ptr, dst_ptr, mean, &inv_std, channels, num_pixels);
            }
        }
    }

    Tensor::from_vec(vec![h, w, channels], out).map_err(Into::into)
}

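/// # Safety
///
/// `src_ptr` must be valid for `num_pixels * 3` reads and `dst_ptr` for
/// `num_pixels * 3` writes; `mean` and `inv_std` must each hold at least 3
/// elements.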
#[allow(unsafe_code)]
unsafe fn normalize_3ch(
    src_ptr: *const f32,
    dst_ptr: *mut f32,
    mean: &[f32],
    inv_std: &[f32],
    num_pixels: usize,
) {
    let (m0, m1, m2) = (mean[0], mean[1], mean[2]);
    let (s0, s1, s2) = (inv_std[0], inv_std[1], inv_std[2]);

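    // Runtime feature detection selects the widest available SIMD path;
    // under Miri the guards fall through to the scalar loop.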
    #[cfg(target_arch = "aarch64")]
    if !cfg!(miri) && std::arch::is_aarch64_feature_detected!("neon") {
        use std::arch::aarch64::*;
        let vm = vld1q_f32([m0, m1, m2, 0.0].as_ptr());
        let vs = vld1q_f32([s0, s1, s2, 0.0].as_ptr());
        let full_quads = num_pixels / 4;
        for q in 0..full_quads {
            let base = q * 12;
            for p in 0..4 {
                let off = base + p * 3;
                // Gather the 3 channel values through a stack array and write
                // the 3 result lanes back individually: a straight 4-lane
                // load/store at `off` would touch one float past the buffer
                // end on the final pixel.
                let v = vld1q_f32(
                    [
                        *src_ptr.add(off),
                        *src_ptr.add(off + 1),
                        *src_ptr.add(off + 2),
                        0.0,
                    ]
                    .as_ptr(),
                );
                let r = vmulq_f32(vsubq_f32(v, vm), vs);
                *dst_ptr.add(off) = vgetq_lane_f32::<0>(r);
                *dst_ptr.add(off + 1) = vgetq_lane_f32::<1>(r);
                *dst_ptr.add(off + 2) = vgetq_lane_f32::<2>(r);
            }
        }
        for i in (full_quads * 4)..num_pixels {
            let off = i * 3;
            *dst_ptr.add(off) = (*src_ptr.add(off) - m0) * s0;
            *dst_ptr.add(off + 1) = (*src_ptr.add(off + 1) - m1) * s1;
            *dst_ptr.add(off + 2) = (*src_ptr.add(off + 2) - m2) * s2;
        }
        return;
    }

    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    if !cfg!(miri) && std::is_x86_feature_detected!("avx") {
        normalize_3ch_avx(src_ptr, dst_ptr, m0, m1, m2, s0, s1, s2, num_pixels);
        return;
    }

    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    if !cfg!(miri) && std::is_x86_feature_detected!("sse2") {
        normalize_3ch_sse(src_ptr, dst_ptr, m0, m1, m2, s0, s1, s2, num_pixels);
        return;
    }

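    // Portable scalar fallback.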
    for i in 0..num_pixels {
        let off = i * 3;
        *dst_ptr.add(off) = (*src_ptr.add(off) - m0) * s0;
        *dst_ptr.add(off + 1) = (*src_ptr.add(off + 1) - m1) * s1;
        *dst_ptr.add(off + 2) = (*src_ptr.add(off + 2) - m2) * s2;
    }
}

#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "avx")]
#[allow(unsafe_code)]
unsafe fn normalize_3ch_avx(
    src_ptr: *const f32,
    dst_ptr: *mut f32,
    m0: f32,
    m1: f32,
    m2: f32,
    s0: f32,
    s1: f32,
    s2: f32,
    num_pixels: usize,
) {
    #[cfg(target_arch = "x86")]
    use std::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use std::arch::x86_64::*;

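    // Each iteration handles 8 pixels = 24 interleaved floats as three 8-wide
    // vectors. The RGB pattern shifts by one channel per vector, so the mean
    // and inverse-std constants repeat in a correspondingly rotated order
    // (`_mm256_set_ps` lists lanes from high to low).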
    let vm_a = _mm256_set_ps(m1, m0, m2, m1, m0, m2, m1, m0);
    let vm_b = _mm256_set_ps(m0, m2, m1, m0, m2, m1, m0, m2);
    let vm_c = _mm256_set_ps(m2, m1, m0, m2, m1, m0, m2, m1);
    let vs_a = _mm256_set_ps(s1, s0, s2, s1, s0, s2, s1, s0);
    let vs_b = _mm256_set_ps(s0, s2, s1, s0, s2, s1, s0, s2);
    let vs_c = _mm256_set_ps(s2, s1, s0, s2, s1, s0, s2, s1);

    let full_groups = num_pixels / 8;
    for g in 0..full_groups {
        let base = g * 24;
        let a = _mm256_loadu_ps(src_ptr.add(base));
        let b = _mm256_loadu_ps(src_ptr.add(base + 8));
        let c = _mm256_loadu_ps(src_ptr.add(base + 16));

        let ra = _mm256_mul_ps(_mm256_sub_ps(a, vm_a), vs_a);
        let rb = _mm256_mul_ps(_mm256_sub_ps(b, vm_b), vs_b);
        let rc = _mm256_mul_ps(_mm256_sub_ps(c, vm_c), vs_c);

        _mm256_storeu_ps(dst_ptr.add(base), ra);
        _mm256_storeu_ps(dst_ptr.add(base + 8), rb);
        _mm256_storeu_ps(dst_ptr.add(base + 16), rc);
    }
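    // Scalar tail for the last `num_pixels % 8` pixels.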
    for i in (full_groups * 8)..num_pixels {
        let off = i * 3;
        *dst_ptr.add(off) = (*src_ptr.add(off) - m0) * s0;
        *dst_ptr.add(off + 1) = (*src_ptr.add(off + 1) - m1) * s1;
        *dst_ptr.add(off + 2) = (*src_ptr.add(off + 2) - m2) * s2;
    }
}

#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "sse2")]
#[allow(unsafe_code)]
unsafe fn normalize_3ch_sse(
    src_ptr: *const f32,
    dst_ptr: *mut f32,
    m0: f32,
    m1: f32,
    m2: f32,
    s0: f32,
    s1: f32,
    s2: f32,
    num_pixels: usize,
) {
    #[cfg(target_arch = "x86")]
    use std::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use std::arch::x86_64::*;

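    // Same rotated-constant scheme with 4-wide vectors: 4 pixels = 12 floats
    // split across three loads.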
    let vm_a = _mm_set_ps(m0, m2, m1, m0);
    let vm_b = _mm_set_ps(m1, m0, m2, m1);
    let vm_c = _mm_set_ps(m2, m1, m0, m2);
    let vs_a = _mm_set_ps(s0, s2, s1, s0);
    let vs_b = _mm_set_ps(s1, s0, s2, s1);
    let vs_c = _mm_set_ps(s2, s1, s0, s2);

    let full_groups = num_pixels / 4;
    for g in 0..full_groups {
        let base = g * 12;
        let a = _mm_loadu_ps(src_ptr.add(base));
        let b = _mm_loadu_ps(src_ptr.add(base + 4));
        let c = _mm_loadu_ps(src_ptr.add(base + 8));

        let ra = _mm_mul_ps(_mm_sub_ps(a, vm_a), vs_a);
        let rb = _mm_mul_ps(_mm_sub_ps(b, vm_b), vs_b);
        let rc = _mm_mul_ps(_mm_sub_ps(c, vm_c), vs_c);

        _mm_storeu_ps(dst_ptr.add(base), ra);
        _mm_storeu_ps(dst_ptr.add(base + 4), rb);
        _mm_storeu_ps(dst_ptr.add(base + 8), rc);
    }
    for i in (full_groups * 4)..num_pixels {
        let off = i * 3;
        *dst_ptr.add(off) = (*src_ptr.add(off) - m0) * s0;
        *dst_ptr.add(off + 1) = (*src_ptr.add(off + 1) - m1) * s1;
        *dst_ptr.add(off + 2) = (*src_ptr.add(off + 2) - m2) * s2;
    }
}

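/// # Safety
///
/// `src_ptr` must be valid for `len` reads and `dst_ptr` for `len` writes.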
#[allow(unsafe_code)]
unsafe fn normalize_1ch(
    src_ptr: *const f32,
    dst_ptr: *mut f32,
    mean: f32,
    inv_std: f32,
    len: usize,
) {
    #[cfg(target_arch = "aarch64")]
    if !cfg!(miri) && std::arch::is_aarch64_feature_detected!("neon") {
        use std::arch::aarch64::*;
        let vm = vdupq_n_f32(mean);
        let vs = vdupq_n_f32(inv_std);
        let mut i = 0usize;
        while i + 4 <= len {
            let v = vld1q_f32(src_ptr.add(i));
            let r = vmulq_f32(vsubq_f32(v, vm), vs);
            vst1q_f32(dst_ptr.add(i), r);
            i += 4;
        }
        while i < len {
            *dst_ptr.add(i) = (*src_ptr.add(i) - mean) * inv_std;
            i += 1;
        }
        return;
    }

    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    if !cfg!(miri) && std::is_x86_feature_detected!("avx") {
        normalize_1ch_avx(src_ptr, dst_ptr, mean, inv_std, len);
        return;
    }

    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    if !cfg!(miri) && std::is_x86_feature_detected!("sse2") {
        normalize_1ch_sse(src_ptr, dst_ptr, mean, inv_std, len);
        return;
    }

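    // Portable scalar fallback.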
    for i in 0..len {
        *dst_ptr.add(i) = (*src_ptr.add(i) - mean) * inv_std;
    }
}

#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "avx")]
#[allow(unsafe_code)]
unsafe fn normalize_1ch_avx(
    src_ptr: *const f32,
    dst_ptr: *mut f32,
    mean: f32,
    inv_std: f32,
    len: usize,
) {
    #[cfg(target_arch = "x86")]
    use std::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use std::arch::x86_64::*;

    let vm = _mm256_set1_ps(mean);
    let vs = _mm256_set1_ps(inv_std);
    let mut i = 0usize;
    while i + 8 <= len {
        let v = _mm256_loadu_ps(src_ptr.add(i));
        let r = _mm256_mul_ps(_mm256_sub_ps(v, vm), vs);
        _mm256_storeu_ps(dst_ptr.add(i), r);
        i += 8;
    }
    while i < len {
        *dst_ptr.add(i) = (*src_ptr.add(i) - mean) * inv_std;
        i += 1;
    }
}

#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "sse2")]
#[allow(unsafe_code)]
unsafe fn normalize_1ch_sse(
    src_ptr: *const f32,
    dst_ptr: *mut f32,
    mean: f32,
    inv_std: f32,
    len: usize,
) {
    #[cfg(target_arch = "x86")]
    use std::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use std::arch::x86_64::*;

    let vm = _mm_set1_ps(mean);
    let vs = _mm_set1_ps(inv_std);
    let mut i = 0usize;
    while i + 4 <= len {
        let v = _mm_loadu_ps(src_ptr.add(i));
        let r = _mm_mul_ps(_mm_sub_ps(v, vm), vs);
        _mm_storeu_ps(dst_ptr.add(i), r);
        i += 4;
    }
    while i < len {
        *dst_ptr.add(i) = (*src_ptr.add(i) - mean) * inv_std;
        i += 1;
    }
}

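/// # Safety
///
/// `src_ptr` must be valid for `num_pixels * channels` reads and `dst_ptr`
/// for `num_pixels * channels` writes; `mean` and `inv_std` must each hold at
/// least `channels` elements.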
#[allow(unsafe_code)]
unsafe fn normalize_generic(
    src_ptr: *const f32,
    dst_ptr: *mut f32,
    mean: &[f32],
    inv_std: &[f32],
    channels: usize,
    num_pixels: usize,
) {
    #[cfg(target_arch = "aarch64")]
    if !cfg!(miri) && std::arch::is_aarch64_feature_detected!("neon") {
        use std::arch::aarch64::*;
        // Vectorize along the channel axis in blocks of 4; `simd_end` is
        // `channels` rounded down to a multiple of 4.
        let simd_end = channels & !3;
        for px in 0..num_pixels {
            let base = px * channels;
            let mut c = 0usize;
            while c < simd_end {
                let off = base + c;
                let v = vld1q_f32(src_ptr.add(off));
                let vm = vld1q_f32(mean.as_ptr().add(c));
                let vs = vld1q_f32(inv_std.as_ptr().add(c));
                let r = vmulq_f32(vsubq_f32(v, vm), vs);
                vst1q_f32(dst_ptr.add(off), r);
                c += 4;
            }
            while c < channels {
                let off = base + c;
                *dst_ptr.add(off) = (*src_ptr.add(off) - mean[c]) * inv_std[c];
                c += 1;
            }
        }
        return;
    }

    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    if !cfg!(miri) && std::is_x86_feature_detected!("avx") {
        normalize_generic_avx(src_ptr, dst_ptr, mean, inv_std, channels, num_pixels);
        return;
    }

    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    if !cfg!(miri) && std::is_x86_feature_detected!("sse2") {
        normalize_generic_sse(src_ptr, dst_ptr, mean, inv_std, channels, num_pixels);
        return;
    }

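    // Portable scalar fallback.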
    for px in 0..num_pixels {
        let base = px * channels;
        for c in 0..channels {
            let off = base + c;
            *dst_ptr.add(off) = (*src_ptr.add(off) - mean[c]) * inv_std[c];
        }
    }
}

#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "avx")]
#[allow(unsafe_code)]
unsafe fn normalize_generic_avx(
    src_ptr: *const f32,
    dst_ptr: *mut f32,
    mean: &[f32],
    inv_std: &[f32],
    channels: usize,
    num_pixels: usize,
) {
    #[cfg(target_arch = "x86")]
    use std::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use std::arch::x86_64::*;

    let simd_end = channels & !7;
    for px in 0..num_pixels {
        let base = px * channels;
        let mut c = 0usize;
        while c < simd_end {
            let off = base + c;
            let v = _mm256_loadu_ps(src_ptr.add(off));
            let vm = _mm256_loadu_ps(mean.as_ptr().add(c));
            let vs = _mm256_loadu_ps(inv_std.as_ptr().add(c));
            let r = _mm256_mul_ps(_mm256_sub_ps(v, vm), vs);
            _mm256_storeu_ps(dst_ptr.add(off), r);
            c += 8;
        }
        while c < channels {
            let off = base + c;
            *dst_ptr.add(off) = (*src_ptr.add(off) - mean[c]) * inv_std[c];
            c += 1;
        }
    }
}

#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "sse2")]
#[allow(unsafe_code)]
unsafe fn normalize_generic_sse(
    src_ptr: *const f32,
    dst_ptr: *mut f32,
    mean: &[f32],
    inv_std: &[f32],
    channels: usize,
    num_pixels: usize,
) {
    #[cfg(target_arch = "x86")]
    use std::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use std::arch::x86_64::*;

    let simd_end = channels & !3;
    for px in 0..num_pixels {
        let base = px * channels;
        let mut c = 0usize;
        while c < simd_end {
            let off = base + c;
            let v = _mm_loadu_ps(src_ptr.add(off));
            let vm = _mm_loadu_ps(mean.as_ptr().add(c));
            let vs = _mm_loadu_ps(inv_std.as_ptr().add(c));
            let r = _mm_mul_ps(_mm_sub_ps(v, vm), vs);
            _mm_storeu_ps(dst_ptr.add(off), r);
            c += 4;
        }
        while c < channels {
            let off = base + c;
            *dst_ptr.add(off) = (*src_ptr.add(off) - mean[c]) * inv_std[c];
            c += 1;
        }
    }
}