ruvector_scipix/optimize/
simd.rs

1//! SIMD-accelerated image processing operations
2//!
3//! Provides optimized implementations for common image operations using
4//! AVX2, AVX-512, and ARM NEON intrinsics.
5
6use super::{get_features, simd_enabled};
7
8/// Convert RGBA image to grayscale using optimized SIMD operations
9pub fn simd_grayscale(rgba: &[u8], gray: &mut [u8]) {
10    if !simd_enabled() {
11        return scalar_grayscale(rgba, gray);
12    }
13
14    let features = get_features();
15
16    #[cfg(target_arch = "x86_64")]
17    {
18        if features.avx2 {
19            unsafe { avx2_grayscale(rgba, gray) }
20        } else if features.sse4_2 {
21            unsafe { sse_grayscale(rgba, gray) }
22        } else {
23            scalar_grayscale(rgba, gray)
24        }
25    }
26
27    #[cfg(target_arch = "aarch64")]
28    {
29        if features.neon {
30            unsafe { neon_grayscale(rgba, gray) }
31        } else {
32            scalar_grayscale(rgba, gray)
33        }
34    }
35
36    #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
37    {
38        scalar_grayscale(rgba, gray)
39    }
40}
41
/// Portable grayscale conversion used when no SIMD kernel applies.
///
/// Reads 4-byte pixels from `rgba` (the fourth byte is ignored) and writes one
/// luma byte per pixel into `gray`.
///
/// # Panics
///
/// Panics if `rgba.len() / 4 != gray.len()`.
fn scalar_grayscale(rgba: &[u8], gray: &mut [u8]) {
    assert_eq!(rgba.len() / 4, gray.len(), "RGBA length must be 4x grayscale length");

    for (luma, pixel) in gray.iter_mut().zip(rgba.chunks_exact(4)) {
        let red = u32::from(pixel[0]);
        let green = u32::from(pixel[1]);
        let blue = u32::from(pixel[2]);

        // ITU-R BT.601 luma (0.299 R + 0.587 G + 0.114 B) in 8.8 fixed point:
        // the weights 77 + 150 + 29 add up to 256, hence the shift by 8.
        let weighted = red * 77 + green * 150 + blue * 29;
        *luma = (weighted >> 8) as u8;
    }
}
55
/// Grayscale conversion compiled with AVX2 enabled.
///
/// The loop body is the same fixed-point BT.601 formula as the scalar
/// fallback; it contains no hand-written intrinsics and relies on the compiler
/// to vectorise it with the AVX2 instructions this function is allowed to use.
///
/// # Safety
///
/// The CPU executing this function must support AVX2.
///
/// # Panics
///
/// Panics if `rgba.len() / 4 != gray.len()`.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn avx2_grayscale(rgba: &[u8], gray: &mut [u8]) {
    // Checked up front so a short `rgba` buffer can never be read past its end.
    assert_eq!(rgba.len() / 4, gray.len(), "RGBA length must be 4x grayscale length");

    // `chunks_exact` yields exactly `gray.len()` pixels, so the zip covers the
    // whole image without any unchecked indexing.
    for (luma, pixel) in gray.iter_mut().zip(rgba.chunks_exact(4)) {
        let r = u32::from(pixel[0]);
        let g = u32::from(pixel[1]);
        let b = u32::from(pixel[2]);

        // ITU-R BT.601 weights scaled by 256 (77 + 150 + 29 = 256).
        *luma = ((r * 77 + g * 150 + b * 29) >> 8) as u8;
    }
}
88
/// Grayscale conversion compiled with SSE4.2 enabled.
///
/// The loop body is the same fixed-point BT.601 formula as the scalar
/// fallback; it contains no hand-written intrinsics and relies on the compiler
/// to vectorise it with the SSE4.2 instructions this function is allowed to
/// use.
///
/// # Safety
///
/// The CPU executing this function must support SSE4.2.
///
/// # Panics
///
/// Panics if `rgba.len() / 4 != gray.len()`.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "sse4.2")]
unsafe fn sse_grayscale(rgba: &[u8], gray: &mut [u8]) {
    // Checked up front so a short `rgba` buffer can never be read past its end.
    assert_eq!(rgba.len() / 4, gray.len(), "RGBA length must be 4x grayscale length");

    // `chunks_exact` yields exactly `gray.len()` pixels, so the zip covers the
    // whole image without any unchecked indexing.
    for (luma, pixel) in gray.iter_mut().zip(rgba.chunks_exact(4)) {
        let r = u32::from(pixel[0]);
        let g = u32::from(pixel[1]);
        let b = u32::from(pixel[2]);

        // ITU-R BT.601 weights scaled by 256 (77 + 150 + 29 = 256).
        *luma = ((r * 77 + g * 150 + b * 29) >> 8) as u8;
    }
}
112
/// Grayscale conversion for AArch64 targets.
///
/// The loop body is the same fixed-point BT.601 formula as the scalar
/// fallback; it contains no hand-written NEON intrinsics and leaves
/// vectorisation to the compiler.
///
/// # Safety
///
/// The body performs no unchecked memory access. The function stays `unsafe`
/// so its signature is unchanged for existing callers, which are expected to
/// have confirmed NEON support first.
///
/// # Panics
///
/// Panics if `rgba.len() / 4 != gray.len()`.
#[cfg(target_arch = "aarch64")]
unsafe fn neon_grayscale(rgba: &[u8], gray: &mut [u8]) {
    // Checked up front so a short `rgba` buffer can never be read past its end.
    assert_eq!(rgba.len() / 4, gray.len(), "RGBA length must be 4x grayscale length");

    // `chunks_exact` yields exactly `gray.len()` pixels, so the zip covers the
    // whole image without any unchecked indexing.
    for (luma, pixel) in gray.iter_mut().zip(rgba.chunks_exact(4)) {
        let r = u32::from(pixel[0]);
        let g = u32::from(pixel[1]);
        let b = u32::from(pixel[2]);

        // ITU-R BT.601 weights scaled by 256 (77 + 150 + 29 = 256).
        *luma = ((r * 77 + g * 150 + b * 29) >> 8) as u8;
    }
}
134
135/// Apply threshold to grayscale image using SIMD
136pub fn simd_threshold(gray: &[u8], thresh: u8, out: &mut [u8]) {
137    if !simd_enabled() {
138        return scalar_threshold(gray, thresh, out);
139    }
140
141    let features = get_features();
142
143    #[cfg(target_arch = "x86_64")]
144    {
145        if features.avx2 {
146            unsafe { avx2_threshold(gray, thresh, out) }
147        } else {
148            scalar_threshold(gray, thresh, out)
149        }
150    }
151
152    #[cfg(not(target_arch = "x86_64"))]
153    {
154        scalar_threshold(gray, thresh, out)
155    }
156}
157
/// Portable binary threshold: writes 255 where `gray > thresh`, otherwise 0.
///
/// Only the first `min(gray.len(), out.len())` elements are processed; bytes
/// of `out` beyond that are left untouched.
fn scalar_threshold(gray: &[u8], thresh: u8, out: &mut [u8]) {
    for (g, o) in gray.iter().zip(out.iter_mut()) {
        // Strictly greater-than: a pixel equal to the threshold maps to 0.
        // This is the rule `test_threshold` in this file expects, and it
        // matches the greater-than compare used by the AVX2 kernel.
        *o = if *g > thresh { 255 } else { 0 };
    }
}
163
/// AVX2 binary threshold: writes 255 where `gray > thresh`, otherwise 0.
///
/// Processes 32 pixels per iteration and finishes the remainder with a scalar
/// loop that applies the same rule. Only the first
/// `min(gray.len(), out.len())` elements are processed.
///
/// # Safety
///
/// The CPU executing this function must support AVX2.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn avx2_threshold(gray: &[u8], thresh: u8, out: &mut [u8]) {
    use std::arch::x86_64::*;

    // Never read or write past the shorter of the two buffers.
    let len = gray.len().min(out.len());
    let mut i = 0;

    // `_mm256_cmpgt_epi8` compares *signed* bytes, which would treat every
    // pixel >= 128 as negative. Flipping the top bit of both operands maps the
    // unsigned range 0..=255 onto the signed range -128..=127 while preserving
    // order, so the signed compare yields the unsigned result.
    let sign_flip = _mm256_set1_epi8(i8::MIN);
    let thresh_vec = _mm256_set1_epi8((thresh ^ 0x80) as i8);

    // Process 32 bytes at a time
    while i + 32 <= len {
        // SAFETY: `i + 32 <= len <= gray.len()` and `len <= out.len()`, so the
        // unaligned 32-byte load and store both stay inside their slices.
        let gray_vec = _mm256_loadu_si256(gray.as_ptr().add(i) as *const __m256i);
        let biased = _mm256_xor_si256(gray_vec, sign_flip);
        // The compare already yields 0xFF / 0x00 per byte, i.e. 255 / 0.
        let mask = _mm256_cmpgt_epi8(biased, thresh_vec);
        _mm256_storeu_si256(out.as_mut_ptr().add(i) as *mut __m256i, mask);
        i += 32;
    }

    // Handle remaining bytes with the same strictly-greater rule.
    for (g, o) in gray[i..len].iter().zip(out[i..len].iter_mut()) {
        *o = if *g > thresh { 255 } else { 0 };
    }
}
187
188/// Normalize f32 tensor data using SIMD
189pub fn simd_normalize(data: &mut [f32]) {
190    if !simd_enabled() {
191        return scalar_normalize(data);
192    }
193
194    let features = get_features();
195
196    #[cfg(target_arch = "x86_64")]
197    {
198        if features.avx2 {
199            unsafe { avx2_normalize(data) }
200        } else {
201            scalar_normalize(data)
202        }
203    }
204
205    #[cfg(not(target_arch = "x86_64"))]
206    {
207        scalar_normalize(data)
208    }
209}
210
/// Scalar in-place standardisation: subtract the mean, divide by the
/// population standard deviation.
///
/// A small epsilon is added to the standard deviation so constant input does
/// not divide by zero.
fn scalar_normalize(data: &mut [f32]) {
    let count = data.len() as f32;

    let total: f32 = data.iter().sum();
    let mean = total / count;

    let squared_deviations = data.iter().map(|x| (x - mean).powi(2)).sum::<f32>();
    let variance: f32 = squared_deviations / count;
    // Epsilon keeps the division finite when the variance is zero.
    let std_dev = variance.sqrt() + 1e-8;

    data.iter_mut().for_each(|value| {
        *value = (*value - mean) / std_dev;
    });
}
222
/// AVX2 in-place standardisation (zero mean, unit variance).
///
/// Makes three passes over `data`, eight lanes at a time: one to accumulate
/// the sum, one to accumulate squared deviations from the mean, and one to
/// write the normalised values. Elements past the last full group of eight are
/// handled with scalar arithmetic in each pass.
///
/// # Safety
///
/// The CPU executing this function must support AVX2.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn avx2_normalize(data: &mut [f32]) {
    use std::arch::x86_64::*;

    const LANES: usize = 8;

    let count = data.len();
    // First index that is not covered by a full 8-lane vector.
    let vector_end = count - count % LANES;
    // Scratch array used to read the accumulator lanes back out.
    let mut spill = [0.0f32; LANES];

    // Pass 1: mean.
    let mut lane_sums = _mm256_setzero_ps();
    for start in (0..vector_end).step_by(LANES) {
        // SAFETY: `start + 8 <= vector_end <= data.len()`.
        lane_sums = _mm256_add_ps(lane_sums, _mm256_loadu_ps(data.as_ptr().add(start)));
    }
    _mm256_storeu_ps(spill.as_mut_ptr(), lane_sums);
    let total = spill.iter().sum::<f32>() + data[vector_end..].iter().sum::<f32>();

    let mean = total / count as f32;
    let mean_lanes = _mm256_set1_ps(mean);

    // Pass 2: sum of squared deviations.
    let mut lane_squares = _mm256_setzero_ps();
    for start in (0..vector_end).step_by(LANES) {
        // SAFETY: `start + 8 <= vector_end <= data.len()`.
        let centered = _mm256_sub_ps(_mm256_loadu_ps(data.as_ptr().add(start)), mean_lanes);
        lane_squares = _mm256_add_ps(lane_squares, _mm256_mul_ps(centered, centered));
    }
    _mm256_storeu_ps(spill.as_mut_ptr(), lane_squares);
    let tail_squares = data[vector_end..]
        .iter()
        .map(|x| (x - mean).powi(2))
        .sum::<f32>();
    let squares_total = spill.iter().sum::<f32>() + tail_squares;

    // Epsilon keeps the division finite when the variance is zero.
    let std_dev = (squares_total / count as f32).sqrt() + 1e-8;
    let std_lanes = _mm256_set1_ps(std_dev);

    // Pass 3: write back (x - mean) / std_dev.
    for start in (0..vector_end).step_by(LANES) {
        // SAFETY: `start + 8 <= vector_end <= data.len()`.
        let slot = data.as_mut_ptr().add(start);
        let centered = _mm256_sub_ps(_mm256_loadu_ps(slot), mean_lanes);
        _mm256_storeu_ps(slot, _mm256_div_ps(centered, std_lanes));
    }
    data[vector_end..]
        .iter_mut()
        .for_each(|x| *x = (*x - mean) / std_dev);
}
284
285/// Fast bilinear resize using SIMD - optimized for preprocessing
286/// This is significantly faster than the image crate's resize for typical OCR sizes
287pub fn simd_resize_bilinear(
288    src: &[u8],
289    src_width: usize,
290    src_height: usize,
291    dst_width: usize,
292    dst_height: usize,
293) -> Vec<u8> {
294    if !simd_enabled() {
295        return scalar_resize_bilinear(src, src_width, src_height, dst_width, dst_height);
296    }
297
298    let features = get_features();
299
300    #[cfg(target_arch = "x86_64")]
301    {
302        if features.avx2 {
303            unsafe {
304                avx2_resize_bilinear(src, src_width, src_height, dst_width, dst_height)
305            }
306        } else {
307            scalar_resize_bilinear(src, src_width, src_height, dst_width, dst_height)
308        }
309    }
310
311    #[cfg(not(target_arch = "x86_64"))]
312    {
313        scalar_resize_bilinear(src, src_width, src_height, dst_width, dst_height)
314    }
315}
316
/// Reference bilinear resize for single-channel 8-bit images.
///
/// Each destination pixel maps to the source coordinate
/// `(x * src_width / dst_width, y * src_height / dst_height)`; the four
/// surrounding source pixels (clamped to the image edge) are blended by the
/// fractional parts of that coordinate and the result is rounded to a byte.
fn scalar_resize_bilinear(
    src: &[u8],
    src_width: usize,
    src_height: usize,
    dst_width: usize,
    dst_height: usize,
) -> Vec<u8> {
    // Blend `a` towards `b` by weight `t`.
    let lerp = |a: f32, b: f32, t: f32| a * (1.0 - t) + b * t;

    let mut dst = vec![0u8; dst_width * dst_height];

    let step_x = src_width as f32 / dst_width as f32;
    let step_y = src_height as f32 / dst_height as f32;

    for row in 0..dst_height {
        let fy = row as f32 * step_y;
        let upper = (fy.floor() as usize).min(src_height - 1);
        let lower = (upper + 1).min(src_height - 1);
        let weight_y = fy - fy.floor();

        for col in 0..dst_width {
            let fx = col as f32 * step_x;
            let left = (fx.floor() as usize).min(src_width - 1);
            let right = (left + 1).min(src_width - 1);
            let weight_x = fx - fx.floor();

            // The four neighbours around the sampling point.
            let upper_left = src[upper * src_width + left] as f32;
            let upper_right = src[upper * src_width + right] as f32;
            let lower_left = src[lower * src_width + left] as f32;
            let lower_right = src[lower * src_width + right] as f32;

            // Blend horizontally on both rows, then vertically between them.
            let upper_mix = lerp(upper_left, upper_right, weight_x);
            let lower_mix = lerp(lower_left, lower_right, weight_x);
            let blended = lerp(upper_mix, lower_mix, weight_y);

            dst[row * dst_width + col] = blended.round() as u8;
        }
    }

    dst
}
358
/// Bilinear resize for single-channel 8-bit images, compiled with AVX2 enabled.
///
/// Each destination pixel maps to the source coordinate
/// `(x * src_width / dst_width, y * src_height / dst_height)`; the four
/// surrounding source pixels (clamped to the image edge) are blended by the
/// fractional parts of that coordinate and the result is rounded to a byte.
///
/// The body contains no hand-written intrinsics. It validates the source
/// buffer once and then reads it without per-pixel bounds checks, leaving
/// vectorisation to the compiler.
///
/// # Safety
///
/// The CPU executing this function must support AVX2.
///
/// # Panics
///
/// When the destination is non-empty, panics if either source dimension is
/// zero or if `src` holds fewer than `src_width * src_height` bytes.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn avx2_resize_bilinear(
    src: &[u8],
    src_width: usize,
    src_height: usize,
    dst_width: usize,
    dst_height: usize,
) -> Vec<u8> {
    let mut dst = vec![0u8; dst_width * dst_height];
    if dst.is_empty() {
        return dst;
    }

    // Everything below indexes `src` without bounds checks, so the geometry
    // must be proven consistent first. `checked_mul` keeps an overflowing
    // `src_width * src_height` from wrapping into a value that passes.
    assert!(
        src_width > 0 && src_height > 0,
        "source dimensions must be non-zero"
    );
    assert!(
        src_width
            .checked_mul(src_height)
            .map_or(false, |pixels| src.len() >= pixels),
        "source buffer is smaller than src_width * src_height"
    );

    let x_scale = src_width as f32 / dst_width as f32;
    let y_scale = src_height as f32 / dst_height as f32;
    let max_x = src_width - 1;
    let max_y = src_height - 1;

    for y in 0..dst_height {
        // Row-invariant values are computed once per output row.
        let src_y = y as f32 * y_scale;
        let y0 = (src_y.floor() as usize).min(max_y);
        let y1 = (y0 + 1).min(max_y);
        let y_frac = src_y - src_y.floor();
        let row0 = y0 * src_width;
        let row1 = y1 * src_width;

        for x in 0..dst_width {
            let src_x = x as f32 * x_scale;
            let x0 = (src_x.floor() as usize).min(max_x);
            let x1 = (x0 + 1).min(max_x);
            let x_frac = src_x - src_x.floor();

            // SAFETY: `x0, x1 <= src_width - 1` and
            // `row0, row1 <= (src_height - 1) * src_width`, so every index is
            // below `src_width * src_height`, which was asserted to be no
            // larger than `src.len()`.
            let p00 = *src.get_unchecked(row0 + x0) as f32;
            let p10 = *src.get_unchecked(row0 + x1) as f32;
            let p01 = *src.get_unchecked(row1 + x0) as f32;
            let p11 = *src.get_unchecked(row1 + x1) as f32;

            // Blend horizontally on both rows, then vertically between them.
            let top = p00 * (1.0 - x_frac) + p10 * x_frac;
            let bottom = p01 * (1.0 - x_frac) + p11 * x_frac;
            let value = top * (1.0 - y_frac) + bottom * y_frac;

            // SAFETY: `y < dst_height` and `x < dst_width`, so the index is
            // below `dst_width * dst_height == dst.len()`.
            *dst.get_unchecked_mut(y * dst_width + x) = value.round() as u8;
        }
    }

    dst
}
443
/// Row-parallel bilinear resize for large single-channel images.
///
/// Outputs with fewer than 64 rows or fewer than 100,000 pixels are delegated
/// to [`simd_resize_bilinear`]; larger outputs are split into rows that are
/// interpolated independently on the rayon thread pool.
#[cfg(feature = "rayon")]
pub fn parallel_simd_resize(
    src: &[u8],
    src_width: usize,
    src_height: usize,
    dst_width: usize,
    dst_height: usize,
) -> Vec<u8> {
    use rayon::prelude::*;

    // Below this size the single-threaded path is used instead.
    let small_output = dst_height < 64 || dst_width * dst_height < 100_000;
    if small_output {
        return simd_resize_bilinear(src, src_width, src_height, dst_width, dst_height);
    }

    let mut dst = vec![0u8; dst_width * dst_height];
    let step_x = src_width as f32 / dst_width as f32;
    let step_y = src_height as f32 / dst_height as f32;

    // One task per output row.
    dst.par_chunks_mut(dst_width)
        .enumerate()
        .for_each(|(row_index, row)| {
            // Blend `a` towards `b` by weight `t`.
            let lerp = |a: f32, b: f32, t: f32| a * (1.0 - t) + b * t;

            let fy = row_index as f32 * step_y;
            let upper = (fy.floor() as usize).min(src_height - 1);
            let lower = (upper + 1).min(src_height - 1);
            let weight_y = fy - fy.floor();

            for (col, out) in row.iter_mut().enumerate() {
                let fx = col as f32 * step_x;
                let left = (fx.floor() as usize).min(src_width - 1);
                let right = (left + 1).min(src_width - 1);
                let weight_x = fx - fx.floor();

                let upper_left = src[upper * src_width + left] as f32;
                let upper_right = src[upper * src_width + right] as f32;
                let lower_left = src[lower * src_width + left] as f32;
                let lower_right = src[lower * src_width + right] as f32;

                let upper_mix = lerp(upper_left, upper_right, weight_x);
                let lower_mix = lerp(lower_left, lower_right, weight_x);

                *out = lerp(upper_mix, lower_mix, weight_y).round() as u8;
            }
        });

    dst
}
494
495/// Ultra-fast area average downscaling for preprocessing
496/// Best for large images being scaled down significantly
497pub fn fast_area_resize(
498    src: &[u8],
499    src_width: usize,
500    src_height: usize,
501    dst_width: usize,
502    dst_height: usize,
503) -> Vec<u8> {
504    // Only use area averaging for downscaling
505    if dst_width >= src_width || dst_height >= src_height {
506        return simd_resize_bilinear(src, src_width, src_height, dst_width, dst_height);
507    }
508
509    let mut dst = vec![0u8; dst_width * dst_height];
510
511    let x_ratio = src_width as f32 / dst_width as f32;
512    let y_ratio = src_height as f32 / dst_height as f32;
513
514    for y in 0..dst_height {
515        let y_start = (y as f32 * y_ratio) as usize;
516        let y_end = (((y + 1) as f32 * y_ratio) as usize).min(src_height);
517
518        for x in 0..dst_width {
519            let x_start = (x as f32 * x_ratio) as usize;
520            let x_end = (((x + 1) as f32 * x_ratio) as usize).min(src_width);
521
522            // Calculate area average
523            let mut sum: u32 = 0;
524            let mut count: u32 = 0;
525
526            for sy in y_start..y_end {
527                for sx in x_start..x_end {
528                    sum += src[sy * src_width + sx] as u32;
529                    count += 1;
530                }
531            }
532
533            dst[y * dst_width + x] = if count > 0 { (sum / count) as u8 } else { 0 };
534        }
535    }
536
537    dst
538}
539
#[cfg(test)]
mod tests {
    use super::*;

    /// Primary colours land in their expected luma bands; white stays white.
    #[test]
    fn test_grayscale_conversion() {
        let rgba: [u8; 16] = [
            255, 0, 0, 255, // Red
            0, 255, 0, 255, // Green
            0, 0, 255, 255, // Blue
            255, 255, 255, 255, // White
        ];
        let mut gray = [0u8; 4];

        simd_grayscale(&rgba, &mut gray);

        let [red, green, blue, white] = gray;
        assert!((51..100).contains(&red));
        assert!((131..160).contains(&green));
        assert!((21..50).contains(&blue));
        assert_eq!(white, 255);
    }

    /// Values above the threshold become 255, the rest become 0.
    #[test]
    fn test_threshold() {
        let gray = [0u8, 50, 100, 150, 200, 255];
        let mut out = [0u8; 6];

        simd_threshold(&gray, 100, &mut out);

        assert_eq!(out, [0, 0, 0, 255, 255, 255]);
    }

    /// Normalised data has a mean of approximately zero.
    #[test]
    fn test_normalize() {
        let mut data = [1.0f32, 2.0, 3.0, 4.0, 5.0];
        simd_normalize(&mut data);

        let mean = data.iter().sum::<f32>() / data.len() as f32;
        assert!(mean.abs() < 1e-6);
    }

    /// The dispatched kernel and the scalar fallback agree byte for byte.
    #[cfg(target_arch = "x86_64")]
    #[test]
    fn test_simd_vs_scalar_grayscale() {
        // 256 pixels whose bytes cycle through 0..=255.
        let rgba: Vec<u8> = (0..1024u32).map(|i| i as u8).collect();
        let mut via_dispatch = vec![0u8; 256];
        let mut via_scalar = vec![0u8; 256];

        simd_grayscale(&rgba, &mut via_dispatch);
        scalar_grayscale(&rgba, &mut via_scalar);

        assert_eq!(via_dispatch, via_scalar);
    }
}
595}