// yscv_imgproc/ops/histogram.rs
use yscv_tensor::Tensor;

use super::super::ImgProcError;
use super::super::shape::hwc_shape;
6/// Computes a 256-bin histogram for a single-channel `[H, W, 1]` image.
7/// Assumes values in `[0, 1]` range.
8pub fn histogram_256(input: &Tensor) -> Result<[u32; 256], ImgProcError> {
9    let (_h, _w, c) = hwc_shape(input)?;
10    if c != 1 {
11        return Err(ImgProcError::InvalidChannelCount {
12            expected: 1,
13            got: c,
14        });
15    }
16
17    let data = input.data();
18    let len = data.len();
19
20    // Pre-convert to u8 for branch-free binning
21    let mut u8_buf = vec![0u8; len];
22    for (dst, &v) in u8_buf.iter_mut().zip(data.iter()) {
23        *dst = (v.clamp(0.0, 1.0) * 255.0) as u8;
24    }
25
26    // Parallel binning with thread-local histograms for large images
27    if len >= 65536 {
28        let num_chunks = rayon::current_num_threads().clamp(1, 8);
29        let chunk_size = len.div_ceil(num_chunks);
30        let mut local_hists = vec![0u32; num_chunks * 256];
31
32        rayon::scope(|s| {
33            for (chunk_idx, chunk) in u8_buf.chunks(chunk_size).enumerate() {
34                let hist_slice = &mut local_hists[chunk_idx * 256..(chunk_idx + 1) * 256];
35                // SAFETY: each thread writes to its own disjoint slice
36                let hist_ptr = hist_slice.as_mut_ptr() as usize;
37                s.spawn(move |_| {
38                    let hist = unsafe { std::slice::from_raw_parts_mut(hist_ptr as *mut u32, 256) };
39                    histogram_bin_chunk(chunk, hist);
40                });
41            }
42        });
43
44        // Merge
45        let mut hist = [0u32; 256];
46        for chunk_idx in 0..num_chunks {
47            for i in 0..256 {
48                hist[i] += local_hists[chunk_idx * 256 + i];
49            }
50        }
51        Ok(hist)
52    } else {
53        let mut hist = [0u32; 256];
54        histogram_bin_chunk(&u8_buf, &mut hist);
55        Ok(hist)
56    }
57}
58
/// Bins a slice of `u8` pixel values into a 256-bin histogram.
///
/// Counts are *added* into `hist` (it is not zeroed first), so callers can
/// accumulate across multiple chunks. `hist` must hold at least 256 entries
/// or the indexing below panics.
///
/// On aarch64 with NEON, loads 16 bytes per iteration with `vld1q_u8` and
/// extracts individual lanes for scatter-add, reducing load overhead.
#[allow(unsafe_code)]
#[inline]
fn histogram_bin_chunk(data: &[u8], hist: &mut [u32]) {
    // Number of leading bytes consumed by the SIMD fast path; the scalar
    // tail loop resumes from this offset.
    let mut i = 0;

    #[cfg(target_arch = "aarch64")]
    if !cfg!(miri) {
        // SAFETY: each `vld1q_u8(ptr.add(i))` reads exactly 16 bytes, and the
        // loop guard `i + 16 <= len` keeps every read inside `data`. Lane
        // values are `u8`, so every histogram index is < 256.
        unsafe {
            use std::arch::aarch64::*;
            let ptr = data.as_ptr();
            let len = data.len();
            // Use two sub-histograms to reduce store-to-load forwarding stalls
            // when consecutive pixels map to the same bin.
            let mut hist2 = [0u32; 256];
            while i + 16 <= len {
                let v = vld1q_u8(ptr.add(i));
                // Extract each byte and increment the corresponding bin.
                // Alternate between hist and hist2 to reduce dependency stalls.
                hist[vgetq_lane_u8::<0>(v) as usize] += 1;
                hist2[vgetq_lane_u8::<1>(v) as usize] += 1;
                hist[vgetq_lane_u8::<2>(v) as usize] += 1;
                hist2[vgetq_lane_u8::<3>(v) as usize] += 1;
                hist[vgetq_lane_u8::<4>(v) as usize] += 1;
                hist2[vgetq_lane_u8::<5>(v) as usize] += 1;
                hist[vgetq_lane_u8::<6>(v) as usize] += 1;
                hist2[vgetq_lane_u8::<7>(v) as usize] += 1;
                hist[vgetq_lane_u8::<8>(v) as usize] += 1;
                hist2[vgetq_lane_u8::<9>(v) as usize] += 1;
                hist[vgetq_lane_u8::<10>(v) as usize] += 1;
                hist2[vgetq_lane_u8::<11>(v) as usize] += 1;
                hist[vgetq_lane_u8::<12>(v) as usize] += 1;
                hist2[vgetq_lane_u8::<13>(v) as usize] += 1;
                hist[vgetq_lane_u8::<14>(v) as usize] += 1;
                hist2[vgetq_lane_u8::<15>(v) as usize] += 1;
                i += 16;
            }
            // Merge sub-histogram
            for j in 0..256 {
                hist[j] += hist2[j];
            }
        }
    }

    // Scalar tail (also used as fallback on non-aarch64 / Miri)
    while i < data.len() {
        hist[data[i] as usize] += 1;
        i += 1;
    }
}
112
113/// Histogram equalization for single-channel `[H, W, 1]` images with values in `[0, 1]`.
114pub fn histogram_equalize(input: &Tensor) -> Result<Tensor, ImgProcError> {
115    let (h, w, c) = hwc_shape(input)?;
116    if c != 1 {
117        return Err(ImgProcError::InvalidChannelCount {
118            expected: 1,
119            got: c,
120        });
121    }
122
123    let hist = histogram_256(input)?;
124    let total = (h * w) as f32;
125
126    // Build CDF
127    let mut cdf = [0.0f32; 256];
128    let mut running = 0u32;
129    for (i, &count) in hist.iter().enumerate() {
130        running += count;
131        cdf[i] = running as f32 / total;
132    }
133
134    let out: Vec<f32> = input
135        .data()
136        .iter()
137        .map(|&v| {
138            let bin = (v.clamp(0.0, 1.0) * 255.0) as usize;
139            cdf[bin.min(255)]
140        })
141        .collect();
142
143    Tensor::from_vec(vec![h, w, 1], out).map_err(Into::into)
144}
145
146/// Computes the integral (summed-area table) image from a `[H, W, 1]` tensor.
147///
148/// Uses a two-pass approach:
149/// 1. Horizontal prefix sums per row (parallelized for large images)
150/// 2. Vertical accumulation using SIMD
151#[allow(unsafe_code)]
152pub fn integral_image(input: &Tensor) -> Result<Tensor, ImgProcError> {
153    let (h, w, c) = hwc_shape(input)?;
154    if c != 1 {
155        return Err(ImgProcError::InvalidChannelCount {
156            expected: 1,
157            got: c,
158        });
159    }
160    let data = input.data();
161    let mut sat = vec![0.0f32; h * w];
162
163    // Pass 1: horizontal prefix sums per row
164    if h > 1 && w * h >= 4096 {
165        // Parallel per-row prefix sums for large images
166        use rayon::prelude::*;
167        sat.par_chunks_mut(w).enumerate().for_each(|(y, row)| {
168            let src_off = y * w;
169            if w > 0 {
170                row[0] = data[src_off];
171                for x in 1..w {
172                    row[x] = row[x - 1] + data[src_off + x];
173                }
174            }
175        });
176    } else {
177        for y in 0..h {
178            let off = y * w;
179            if w > 0 {
180                sat[off] = data[off];
181                for x in 1..w {
182                    sat[off + x] = sat[off + x - 1] + data[off + x];
183                }
184            }
185        }
186    }
187
188    // Pass 2: vertical accumulation (serial across rows, SIMD within row)
189    for y in 1..h {
190        let prev = (y - 1) * w;
191        let cur = y * w;
192        integral_add_row(&mut sat, prev, cur, w);
193    }
194
195    Tensor::from_vec(vec![h, w, 1], sat).map_err(Into::into)
196}
197
/// Adds `sat[prev_off..prev_off+w]` into `sat[cur_off..cur_off+w]` using SIMD.
///
/// Callers must ensure both `w`-element ranges lie fully within `sat`
/// (in `integral_image` they are consecutive, non-overlapping rows).
#[allow(unsafe_code)]
#[inline]
fn integral_add_row(sat: &mut [f32], prev_off: usize, cur_off: usize, w: usize) {
    // Elements consumed by a SIMD fast path; the scalar tail finishes the rest.
    let mut x = 0;

    #[cfg(target_arch = "aarch64")]
    if !cfg!(miri) {
        // SAFETY: the loop guard `x + 4 <= w` plus the caller's guarantee that
        // both offset ranges are in bounds keeps every 4-lane load/store inside
        // `sat`. All accesses go through the same base pointer `p`, so no
        // provenance/aliasing rules are violated.
        unsafe {
            use std::arch::aarch64::*;
            let p = sat.as_mut_ptr();
            while x + 4 <= w {
                let a = vld1q_f32(p.add(cur_off + x));
                let b = vld1q_f32(p.add(prev_off + x));
                vst1q_f32(p.add(cur_off + x), vaddq_f32(a, b));
                x += 4;
            }
        }
    }

    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    if !cfg!(miri) {
        // SAFETY: same bounds argument as the NEON path. `_mm_loadu_ps` /
        // `_mm_storeu_ps` tolerate unaligned addresses, and SSE availability
        // is checked at runtime before entering the vector loop.
        unsafe {
            #[cfg(target_arch = "x86")]
            use std::arch::x86::*;
            #[cfg(target_arch = "x86_64")]
            use std::arch::x86_64::*;
            if std::is_x86_feature_detected!("sse") {
                let p = sat.as_mut_ptr();
                while x + 4 <= w {
                    let a = _mm_loadu_ps(p.add(cur_off + x));
                    let b = _mm_loadu_ps(p.add(prev_off + x));
                    _mm_storeu_ps(p.add(cur_off + x), _mm_add_ps(a, b));
                    x += 4;
                }
            }
        }
    }

    // Scalar tail
    while x < w {
        sat[cur_off + x] += sat[prev_off + x];
        x += 1;
    }
}
243
/// Applies CLAHE (Contrast-Limited Adaptive Histogram Equalization) to a
/// grayscale `[H, W, 1]` image with values in `[0, 1]`.
///
/// `tile_h` and `tile_w` define the grid size. `clip_limit` caps bin counts
/// before redistribution (typical: 2.0-4.0).
///
/// # Errors
/// - [`ImgProcError::InvalidChannelCount`] if the image is not single-channel.
/// - [`ImgProcError::InvalidBlockSize`] if `tile_h` or `tile_w` is zero.
pub fn clahe(
    input: &Tensor,
    tile_h: usize,
    tile_w: usize,
    clip_limit: f32,
) -> Result<Tensor, ImgProcError> {
    let (h, w, c) = hwc_shape(input)?;
    if c != 1 {
        return Err(ImgProcError::InvalidChannelCount {
            expected: 1,
            got: c,
        });
    }
    if tile_h == 0 || tile_w == 0 {
        return Err(ImgProcError::InvalidBlockSize { block_size: 0 });
    }

    let src = input.data();
    let mut out = vec![0.0f32; h * w];

    // Clamp the requested grid to the image so every tile spans >= 1 pixel.
    // NOTE(review): if h or w can be 0 here these divisions panic —
    // presumably hwc_shape rejects empty images; confirm.
    let grid_rows = tile_h.min(h);
    let grid_cols = tile_w.min(w);
    let cell_h = h / grid_rows;
    let cell_w = w / grid_cols;
    let n_tiles = grid_rows * grid_cols;

    // Pre-convert entire image to u8 once (avoid per-pixel float→u8 in hot loops)
    let mut src_u8 = vec![0u8; h * w];
    for (dst, &v) in src_u8.iter_mut().zip(src.iter()) {
        *dst = (v.clamp(0.0, 1.0) * 255.0) as u8;
    }

    // Flat map storage: maps[tile_idx * 256 + val] is the equalized output
    // level for input level `val` in tile `tile_idx`.
    let mut maps = vec![0u8; n_tiles * 256];

    // Compute tile maps — each tile is independent, parallelize across tiles.
    {
        use super::u8ops::gcd;
        // NOTE(review): pointers are smuggled as usize, presumably because
        // gcd::parallel_for cannot accept a closure borrowing local slices —
        // confirm against its signature.
        let maps_ptr = maps.as_mut_ptr() as usize;
        let src_u8_ptr = src_u8.as_ptr() as usize;
        gcd::parallel_for(n_tiles, |tile_idx| {
            // Tile bounds; the last row/column of tiles absorbs the remainder
            // of the integer division h/grid_rows (resp. w/grid_cols).
            let gr = tile_idx / grid_cols;
            let gc = tile_idx % grid_cols;
            let y0 = gr * cell_h;
            let x0 = gc * cell_w;
            let y1 = if gr == grid_rows - 1 { h } else { y0 + cell_h };
            let x1 = if gc == grid_cols - 1 { w } else { x0 + cell_w };
            let n_pixels = (y1 - y0) * (x1 - x0);

            // SAFETY: each tile writes to its own non-overlapping 256-byte slice.
            let map = unsafe {
                std::slice::from_raw_parts_mut((maps_ptr as *mut u8).add(tile_idx * 256), 256)
            };
            let src_u8 = unsafe { std::slice::from_raw_parts(src_u8_ptr as *const u8, h * w) };

            // Histogram of quantized levels within this tile.
            let mut hist = [0u32; 256];
            for y in y0..y1 {
                for x in x0..x1 {
                    hist[src_u8[y * w + x] as usize] += 1;
                }
            }

            // Clip each bin at clip_limit × the mean bin height, then
            // redistribute the clipped excess uniformly; the first `leftover`
            // bins take one extra count so no counts are lost.
            let clip = (clip_limit * n_pixels as f32 / 256.0).max(1.0) as u32;
            let mut excess = 0u32;
            for h_bin in hist.iter_mut() {
                if *h_bin > clip {
                    excess += *h_bin - clip;
                    *h_bin = clip;
                }
            }
            let bonus = excess / 256;
            let leftover = (excess % 256) as usize;
            for h_bin in hist.iter_mut() {
                *h_bin += bonus;
            }
            for i in 0..leftover {
                hist[i] += 1;
            }

            // CDF of the clipped histogram, scaled to u8 output levels.
            let scale = 255.0 / n_pixels.max(1) as f32;
            let mut csum = 0u32;
            for (i, &count) in hist.iter().enumerate() {
                csum += count;
                map[i] = (csum as f32 * scale).round().min(255.0) as u8;
            }
        });
    }

    // Pre-compute reciprocals for interpolation
    let inv_cell_h = 1.0 / cell_h as f32;
    let inv_cell_w = 1.0 / cell_w as f32;
    // Upper bounds for the top/left tile index so the +1 neighbor below
    // stays inside the grid.
    let gr_max = grid_rows.saturating_sub(2);
    let gc_max = grid_cols.saturating_sub(2);

    // Interpolation pass — rows are independent, parallelize across rows.
    // Each pixel bilinearly blends the maps of the four surrounding tiles;
    // the -0.5 shift converts tile-origin to tile-center coordinates.
    {
        use super::u8ops::gcd;
        let out_ptr = out.as_mut_ptr() as usize;
        let maps_ptr = maps.as_ptr() as usize;
        let src_u8_ptr = src_u8.as_ptr() as usize;
        gcd::parallel_for(h, |y| {
            // SAFETY: each row writes to non-overlapping out[y*w..(y+1)*w].
            let out_row =
                unsafe { std::slice::from_raw_parts_mut((out_ptr as *mut f32).add(y * w), w) };
            let maps = unsafe { std::slice::from_raw_parts(maps_ptr as *const u8, n_tiles * 256) };
            let src_u8 = unsafe { std::slice::from_raw_parts(src_u8_ptr as *const u8, h * w) };

            // Vertical blend coordinates for this row.
            // NOTE(review): near the bottom/right edges fy/fx can exceed 1,
            // extrapolating past the last tile center rather than clamping —
            // verify this matches the intended edge behavior.
            let gy = (y as f32 * inv_cell_h - 0.5).max(0.0);
            let gr0 = (gy as usize).min(gr_max);
            let gr1 = (gr0 + 1).min(grid_rows - 1);
            let fy = gy - gr0 as f32;
            let m00_base = gr0 * grid_cols;
            let m10_base = gr1 * grid_cols;

            for x in 0..w {
                let val = src_u8[y * w + x] as usize;
                let gx = (x as f32 * inv_cell_w - 0.5).max(0.0);
                let gc0 = (gx as usize).min(gc_max);
                let gc1 = (gc0 + 1).min(grid_cols - 1);
                let fx = gx - gc0 as f32;

                // The four neighboring tile maps evaluated at this pixel's level.
                let v00 = maps[(m00_base + gc0) * 256 + val] as f32;
                let v01 = maps[(m00_base + gc1) * 256 + val] as f32;
                let v10 = maps[(m10_base + gc0) * 256 + val] as f32;
                let v11 = maps[(m10_base + gc1) * 256 + val] as f32;

                // Bilinear blend, then rescale u8 levels back to [0, 1].
                let top = v00 + fx * (v01 - v00);
                let bot = v10 + fx * (v11 - v10);
                out_row[x] = (top + fy * (bot - top)) * (1.0 / 255.0);
            }
        });
    }

    Tensor::from_vec(vec![h, w, 1], out).map_err(Into::into)
}