Skip to main content

yscv_imgproc/ops/
normalize.rs

1#![allow(unsafe_op_in_unsafe_fn)]
2
3use yscv_tensor::Tensor;
4
5use super::super::ImgProcError;
6use super::super::shape::hwc_shape;
7
8/// Per-channel normalize in HWC layout: `(x - mean[c]) / std[c]`.
9///
10/// Optimized: precomputes `inv_std = 1/std` to replace division with multiply,
11/// and iterates by pixel (row-major) to avoid per-element modulo.
12/// Uses SIMD (NEON/AVX/SSE) where available for all channel counts.
13#[allow(unsafe_code)]
14pub fn normalize(input: &Tensor, mean: &[f32], std: &[f32]) -> Result<Tensor, ImgProcError> {
15    let (h, w, channels) = hwc_shape(input)?;
16    if mean.len() != channels || std.len() != channels {
17        return Err(ImgProcError::InvalidNormalizationParams {
18            expected_channels: channels,
19            mean_len: mean.len(),
20            std_len: std.len(),
21        });
22    }
23    for (channel, value) in std.iter().enumerate() {
24        if *value == 0.0 {
25            return Err(ImgProcError::ZeroStdAtChannel { channel });
26        }
27    }
28
29    // Precompute reciprocal of std to replace division with multiplication
30    let inv_std: Vec<f32> = std.iter().map(|&s| 1.0 / s).collect();
31
32    let len = h * w * channels;
33    let mut out = vec![0.0f32; len];
34
35    let src = input.data();
36    let num_pixels = h * w;
37
38    // SAFETY: all pointer arithmetic stays in bounds (validated by shape).
39    unsafe {
40        let src_ptr = src.as_ptr();
41        let dst_ptr = out.as_mut_ptr();
42
43        // Fast path for common channel counts: avoid inner loop overhead
44        match channels {
45            3 => {
46                normalize_3ch(src_ptr, dst_ptr, mean, &inv_std, num_pixels);
47            }
48            1 => {
49                normalize_1ch(src_ptr, dst_ptr, mean[0], inv_std[0], len);
50            }
51            _ => {
52                normalize_generic(src_ptr, dst_ptr, mean, &inv_std, channels, num_pixels);
53            }
54        }
55    }
56
57    Tensor::from_vec(vec![h, w, channels], out).map_err(Into::into)
58}
59
60#[allow(unsafe_code)]
61unsafe fn normalize_3ch(
62    src_ptr: *const f32,
63    dst_ptr: *mut f32,
64    mean: &[f32],
65    inv_std: &[f32],
66    num_pixels: usize,
67) {
68    let (m0, m1, m2) = (mean[0], mean[1], mean[2]);
69    let (s0, s1, s2) = (inv_std[0], inv_std[1], inv_std[2]);
70
71    #[cfg(target_arch = "aarch64")]
72    if !cfg!(miri) && std::arch::is_aarch64_feature_detected!("neon") {
73        use std::arch::aarch64::*;
74        let vm = vld1q_f32([m0, m1, m2, 0.0].as_ptr());
75        let vs = vld1q_f32([s0, s1, s2, 0.0].as_ptr());
76        let full_quads = num_pixels / 4;
77        for q in 0..full_quads {
78            let base = q * 12;
79            for p in 0..4 {
80                let off = base + p * 3;
81                let v = vld1q_f32(
82                    [
83                        *src_ptr.add(off),
84                        *src_ptr.add(off + 1),
85                        *src_ptr.add(off + 2),
86                        0.0,
87                    ]
88                    .as_ptr(),
89                );
90                let r = vmulq_f32(vsubq_f32(v, vm), vs);
91                *dst_ptr.add(off) = vgetq_lane_f32::<0>(r);
92                *dst_ptr.add(off + 1) = vgetq_lane_f32::<1>(r);
93                *dst_ptr.add(off + 2) = vgetq_lane_f32::<2>(r);
94            }
95        }
96        for i in (full_quads * 4)..num_pixels {
97            let off = i * 3;
98            *dst_ptr.add(off) = (*src_ptr.add(off) - m0) * s0;
99            *dst_ptr.add(off + 1) = (*src_ptr.add(off + 1) - m1) * s1;
100            *dst_ptr.add(off + 2) = (*src_ptr.add(off + 2) - m2) * s2;
101        }
102        return;
103    }
104
105    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
106    if !cfg!(miri) && std::is_x86_feature_detected!("avx") {
107        normalize_3ch_avx(src_ptr, dst_ptr, m0, m1, m2, s0, s1, s2, num_pixels);
108        return;
109    }
110
111    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
112    if !cfg!(miri) && std::is_x86_feature_detected!("sse2") {
113        normalize_3ch_sse(src_ptr, dst_ptr, m0, m1, m2, s0, s1, s2, num_pixels);
114        return;
115    }
116
117    // Scalar fallback
118    for i in 0..num_pixels {
119        let off = i * 3;
120        *dst_ptr.add(off) = (*src_ptr.add(off) - m0) * s0;
121        *dst_ptr.add(off + 1) = (*src_ptr.add(off + 1) - m1) * s1;
122        *dst_ptr.add(off + 2) = (*src_ptr.add(off + 2) - m2) * s2;
123    }
124}
125
/// AVX kernel for interleaved 3-channel normalization:
/// `dst[3p + c] = (src[3p + c] - m_c) * s_c` for every pixel `p < num_pixels`.
///
/// Eight pixels (24 floats) are processed per iteration as three 256-bit
/// vectors; pixels left over after the last full group use scalar code.
///
/// # Safety
///
/// * The CPU must support AVX.
/// * `src_ptr` must be valid for reads and `dst_ptr` valid for writes of
///   `num_pixels * 3` consecutive `f32` values. No alignment is required
///   (unaligned loads/stores are used).
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "avx")]
#[allow(unsafe_code)]
unsafe fn normalize_3ch_avx(
    src_ptr: *const f32,
    dst_ptr: *mut f32,
    m0: f32,
    m1: f32,
    m2: f32,
    s0: f32,
    s1: f32,
    s2: f32,
    num_pixels: usize,
) {
    #[cfg(target_arch = "x86")]
    use std::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use std::arch::x86_64::*;

    // Float `i` of a 24-float group belongs to channel `i % 3`, so the three
    // vectors start at channel phase 0, 2 (8 % 3) and 1 (16 % 3) respectively.
    // `_mm256_setr_ps` takes its arguments in memory order (lane 0 first),
    // which keeps the patterns readable and avoids the reversed-argument
    // mistake that previously swapped the second and third vectors.
    let vm_a = _mm256_setr_ps(m0, m1, m2, m0, m1, m2, m0, m1);
    let vm_b = _mm256_setr_ps(m2, m0, m1, m2, m0, m1, m2, m0);
    let vm_c = _mm256_setr_ps(m1, m2, m0, m1, m2, m0, m1, m2);
    let vs_a = _mm256_setr_ps(s0, s1, s2, s0, s1, s2, s0, s1);
    let vs_b = _mm256_setr_ps(s2, s0, s1, s2, s0, s1, s2, s0);
    let vs_c = _mm256_setr_ps(s1, s2, s0, s1, s2, s0, s1, s2);

    let full_groups = num_pixels / 8;
    for g in 0..full_groups {
        let base = g * 24;
        // Load 24 floats as 3 x __m256.
        let a = _mm256_loadu_ps(src_ptr.add(base));
        let b = _mm256_loadu_ps(src_ptr.add(base + 8));
        let c = _mm256_loadu_ps(src_ptr.add(base + 16));

        let ra = _mm256_mul_ps(_mm256_sub_ps(a, vm_a), vs_a);
        let rb = _mm256_mul_ps(_mm256_sub_ps(b, vm_b), vs_b);
        let rc = _mm256_mul_ps(_mm256_sub_ps(c, vm_c), vs_c);

        _mm256_storeu_ps(dst_ptr.add(base), ra);
        _mm256_storeu_ps(dst_ptr.add(base + 8), rb);
        _mm256_storeu_ps(dst_ptr.add(base + 16), rc);
    }

    // Pixels that do not fill a complete group of eight.
    for i in (full_groups * 8)..num_pixels {
        let off = i * 3;
        *dst_ptr.add(off) = (*src_ptr.add(off) - m0) * s0;
        *dst_ptr.add(off + 1) = (*src_ptr.add(off + 1) - m1) * s1;
        *dst_ptr.add(off + 2) = (*src_ptr.add(off + 2) - m2) * s2;
    }
}
179
/// SSE2 kernel for interleaved 3-channel normalization:
/// `dst[3p + c] = (src[3p + c] - m_c) * s_c` for every pixel `p < num_pixels`.
///
/// Four pixels (12 floats) are handled per iteration as three 128-bit
/// vectors; the remaining pixels are processed with scalar code.
///
/// # Safety
///
/// * The CPU must support SSE2.
/// * `src_ptr` must be valid for reads and `dst_ptr` valid for writes of
///   `num_pixels * 3` consecutive `f32` values (no alignment requirement).
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "sse2")]
#[allow(unsafe_code)]
unsafe fn normalize_3ch_sse(
    src_ptr: *const f32,
    dst_ptr: *mut f32,
    m0: f32,
    m1: f32,
    m2: f32,
    s0: f32,
    s1: f32,
    s2: f32,
    num_pixels: usize,
) {
    #[cfg(target_arch = "x86")]
    use std::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use std::arch::x86_64::*;

    // Channel pattern of the 12 floats in a group, written in memory order
    // (lane 0 first): vectors start at channel 0, 1 (4 % 3) and 2 (8 % 3).
    let mean_lo = _mm_setr_ps(m0, m1, m2, m0);
    let mean_mid = _mm_setr_ps(m1, m2, m0, m1);
    let mean_hi = _mm_setr_ps(m2, m0, m1, m2);
    let scale_lo = _mm_setr_ps(s0, s1, s2, s0);
    let scale_mid = _mm_setr_ps(s1, s2, s0, s1);
    let scale_hi = _mm_setr_ps(s2, s0, s1, s2);

    let simd_pixels = (num_pixels / 4) * 4;
    let mut px = 0usize;
    while px < simd_pixels {
        let base = px * 3;

        let lo = _mm_loadu_ps(src_ptr.add(base));
        let mid = _mm_loadu_ps(src_ptr.add(base + 4));
        let hi = _mm_loadu_ps(src_ptr.add(base + 8));

        let out_lo = _mm_mul_ps(_mm_sub_ps(lo, mean_lo), scale_lo);
        let out_mid = _mm_mul_ps(_mm_sub_ps(mid, mean_mid), scale_mid);
        let out_hi = _mm_mul_ps(_mm_sub_ps(hi, mean_hi), scale_hi);

        _mm_storeu_ps(dst_ptr.add(base), out_lo);
        _mm_storeu_ps(dst_ptr.add(base + 4), out_mid);
        _mm_storeu_ps(dst_ptr.add(base + 8), out_hi);

        px += 4;
    }

    // Pixels that do not fill a complete group of four.
    while px < num_pixels {
        let off = px * 3;
        *dst_ptr.add(off) = (*src_ptr.add(off) - m0) * s0;
        *dst_ptr.add(off + 1) = (*src_ptr.add(off + 1) - m1) * s1;
        *dst_ptr.add(off + 2) = (*src_ptr.add(off + 2) - m2) * s2;
        px += 1;
    }
}
229
230#[allow(unsafe_code)]
231unsafe fn normalize_1ch(
232    src_ptr: *const f32,
233    dst_ptr: *mut f32,
234    mean: f32,
235    inv_std: f32,
236    len: usize,
237) {
238    #[cfg(target_arch = "aarch64")]
239    if !cfg!(miri) && std::arch::is_aarch64_feature_detected!("neon") {
240        use std::arch::aarch64::*;
241        let vm = vdupq_n_f32(mean);
242        let vs = vdupq_n_f32(inv_std);
243        let mut i = 0usize;
244        while i + 4 <= len {
245            let v = vld1q_f32(src_ptr.add(i));
246            let r = vmulq_f32(vsubq_f32(v, vm), vs);
247            vst1q_f32(dst_ptr.add(i), r);
248            i += 4;
249        }
250        while i < len {
251            *dst_ptr.add(i) = (*src_ptr.add(i) - mean) * inv_std;
252            i += 1;
253        }
254        return;
255    }
256
257    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
258    if !cfg!(miri) && std::is_x86_feature_detected!("avx") {
259        normalize_1ch_avx(src_ptr, dst_ptr, mean, inv_std, len);
260        return;
261    }
262
263    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
264    if !cfg!(miri) && std::is_x86_feature_detected!("sse2") {
265        normalize_1ch_sse(src_ptr, dst_ptr, mean, inv_std, len);
266        return;
267    }
268
269    // Scalar fallback
270    for i in 0..len {
271        *dst_ptr.add(i) = (*src_ptr.add(i) - mean) * inv_std;
272    }
273}
274
/// AVX kernel for single-channel normalization:
/// `dst[i] = (src[i] - mean) * inv_std` for `i < len`, eight floats at a time
/// with a scalar loop for the tail.
///
/// # Safety
///
/// * The CPU must support AVX.
/// * `src_ptr` must be valid for reads and `dst_ptr` valid for writes of `len`
///   consecutive `f32` values (no alignment requirement).
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "avx")]
#[allow(unsafe_code)]
unsafe fn normalize_1ch_avx(
    src_ptr: *const f32,
    dst_ptr: *mut f32,
    mean: f32,
    inv_std: f32,
    len: usize,
) {
    #[cfg(target_arch = "x86")]
    use std::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use std::arch::x86_64::*;

    let mean_v = _mm256_set1_ps(mean);
    let scale_v = _mm256_set1_ps(inv_std);

    // Largest multiple of eight that fits in `len`.
    let simd_len = len - len % 8;
    for i in (0..simd_len).step_by(8) {
        let x = _mm256_loadu_ps(src_ptr.add(i));
        let y = _mm256_mul_ps(_mm256_sub_ps(x, mean_v), scale_v);
        _mm256_storeu_ps(dst_ptr.add(i), y);
    }
    for i in simd_len..len {
        *dst_ptr.add(i) = (*src_ptr.add(i) - mean) * inv_std;
    }
}
304
/// SSE2 kernel for single-channel normalization:
/// `dst[i] = (src[i] - mean) * inv_std` for `i < len`, four floats at a time
/// with a scalar loop for the tail.
///
/// # Safety
///
/// * The CPU must support SSE2.
/// * `src_ptr` must be valid for reads and `dst_ptr` valid for writes of `len`
///   consecutive `f32` values (no alignment requirement).
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "sse2")]
#[allow(unsafe_code)]
unsafe fn normalize_1ch_sse(
    src_ptr: *const f32,
    dst_ptr: *mut f32,
    mean: f32,
    inv_std: f32,
    len: usize,
) {
    #[cfg(target_arch = "x86")]
    use std::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use std::arch::x86_64::*;

    let mean_v = _mm_set1_ps(mean);
    let scale_v = _mm_set1_ps(inv_std);

    // Largest multiple of four that fits in `len`.
    let simd_len = len - len % 4;
    for i in (0..simd_len).step_by(4) {
        let x = _mm_loadu_ps(src_ptr.add(i));
        let y = _mm_mul_ps(_mm_sub_ps(x, mean_v), scale_v);
        _mm_storeu_ps(dst_ptr.add(i), y);
    }
    for i in simd_len..len {
        *dst_ptr.add(i) = (*src_ptr.add(i) - mean) * inv_std;
    }
}
334
335#[allow(unsafe_code)]
336unsafe fn normalize_generic(
337    src_ptr: *const f32,
338    dst_ptr: *mut f32,
339    mean: &[f32],
340    inv_std: &[f32],
341    channels: usize,
342    num_pixels: usize,
343) {
344    #[cfg(target_arch = "aarch64")]
345    if !cfg!(miri) && std::arch::is_aarch64_feature_detected!("neon") {
346        use std::arch::aarch64::*;
347        let simd_end = channels & !3;
348        for px in 0..num_pixels {
349            let base = px * channels;
350            let mut c = 0usize;
351            while c < simd_end {
352                let off = base + c;
353                let v = vld1q_f32(src_ptr.add(off));
354                let vm = vld1q_f32(mean.as_ptr().add(c));
355                let vs = vld1q_f32(inv_std.as_ptr().add(c));
356                let r = vmulq_f32(vsubq_f32(v, vm), vs);
357                vst1q_f32(dst_ptr.add(off), r);
358                c += 4;
359            }
360            while c < channels {
361                let off = base + c;
362                *dst_ptr.add(off) = (*src_ptr.add(off) - mean[c]) * inv_std[c];
363                c += 1;
364            }
365        }
366        return;
367    }
368
369    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
370    if !cfg!(miri) && std::is_x86_feature_detected!("avx") {
371        normalize_generic_avx(src_ptr, dst_ptr, mean, inv_std, channels, num_pixels);
372        return;
373    }
374
375    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
376    if !cfg!(miri) && std::is_x86_feature_detected!("sse2") {
377        normalize_generic_sse(src_ptr, dst_ptr, mean, inv_std, channels, num_pixels);
378        return;
379    }
380
381    // Scalar fallback
382    for px in 0..num_pixels {
383        let base = px * channels;
384        for c in 0..channels {
385            let off = base + c;
386            *dst_ptr.add(off) = (*src_ptr.add(off) - mean[c]) * inv_std[c];
387        }
388    }
389}
390
/// AVX kernel for interleaved normalization with an arbitrary channel count:
/// `dst[p * channels + c] = (src[p * channels + c] - mean[c]) * inv_std[c]`.
///
/// Within each pixel, channels are processed eight at a time; channels beyond
/// the last full group of eight use scalar code.
///
/// # Safety
///
/// * The CPU must support AVX.
/// * `src_ptr` must be valid for reads and `dst_ptr` valid for writes of
///   `num_pixels * channels` consecutive `f32` values.
/// * `mean` and `inv_std` must each hold at least `channels` elements; the
///   vector loop reads them through raw pointers without bounds checks.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "avx")]
#[allow(unsafe_code)]
unsafe fn normalize_generic_avx(
    src_ptr: *const f32,
    dst_ptr: *mut f32,
    mean: &[f32],
    inv_std: &[f32],
    channels: usize,
    num_pixels: usize,
) {
    #[cfg(target_arch = "x86")]
    use std::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use std::arch::x86_64::*;

    // Channels covered by whole 8-lane vectors.
    let vec_channels = channels - channels % 8;
    for px in 0..num_pixels {
        let pixel_src = src_ptr.add(px * channels);
        let pixel_dst = dst_ptr.add(px * channels);

        for c in (0..vec_channels).step_by(8) {
            let x = _mm256_loadu_ps(pixel_src.add(c));
            let mean_v = _mm256_loadu_ps(mean.as_ptr().add(c));
            let scale_v = _mm256_loadu_ps(inv_std.as_ptr().add(c));
            let y = _mm256_mul_ps(_mm256_sub_ps(x, mean_v), scale_v);
            _mm256_storeu_ps(pixel_dst.add(c), y);
        }
        for c in vec_channels..channels {
            *pixel_dst.add(c) = (*pixel_src.add(c) - mean[c]) * inv_std[c];
        }
    }
}
427
/// SSE2 kernel for interleaved normalization with an arbitrary channel count:
/// `dst[p * channels + c] = (src[p * channels + c] - mean[c]) * inv_std[c]`.
///
/// Within each pixel, channels are processed four at a time; channels beyond
/// the last full group of four use scalar code.
///
/// # Safety
///
/// * The CPU must support SSE2.
/// * `src_ptr` must be valid for reads and `dst_ptr` valid for writes of
///   `num_pixels * channels` consecutive `f32` values.
/// * `mean` and `inv_std` must each hold at least `channels` elements; the
///   vector loop reads them through raw pointers without bounds checks.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "sse2")]
#[allow(unsafe_code)]
unsafe fn normalize_generic_sse(
    src_ptr: *const f32,
    dst_ptr: *mut f32,
    mean: &[f32],
    inv_std: &[f32],
    channels: usize,
    num_pixels: usize,
) {
    #[cfg(target_arch = "x86")]
    use std::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use std::arch::x86_64::*;

    // Channels covered by whole 4-lane vectors.
    let vec_channels = channels - channels % 4;
    for px in 0..num_pixels {
        let pixel_src = src_ptr.add(px * channels);
        let pixel_dst = dst_ptr.add(px * channels);

        for c in (0..vec_channels).step_by(4) {
            let x = _mm_loadu_ps(pixel_src.add(c));
            let mean_v = _mm_loadu_ps(mean.as_ptr().add(c));
            let scale_v = _mm_loadu_ps(inv_std.as_ptr().add(c));
            let y = _mm_mul_ps(_mm_sub_ps(x, mean_v), scale_v);
            _mm_storeu_ps(pixel_dst.add(c), y);
        }
        for c in vec_channels..channels {
            *pixel_dst.add(c) = (*pixel_src.add(c) - mean[c]) * inv_std[c];
        }
    }
}