Skip to main content

scirs2_ndimage/segmentation/
thresholding.rs

//! Thresholding algorithms for image segmentation
//!
//! This module provides functions for thresholding images to create binary masks or segmentations.
4
5use crate::error::{NdimageError, NdimageResult};
6use crate::utils::safe_f64_to_float;
7use scirs2_core::ndarray::{Array, Dimension, Ix2};
8use scirs2_core::numeric::{Float, FromPrimitive, NumAssign};
9
10/// Helper function for safe conversion from usize to float
11#[allow(dead_code)]
12fn safe_usize_to_float<T: Float + FromPrimitive>(value: usize) -> NdimageResult<T> {
13    T::from_usize(value).ok_or_else(|| {
14        NdimageError::ComputationError(format!("Failed to convert usize {} to float type", value))
15    })
16}
17
18/// Apply a threshold to an image to create a binary image
19///
20/// # Arguments
21///
22/// * `image` - Input array
23/// * `threshold` - Threshold value
24///
25/// # Returns
26///
27/// * Binary mask where values equal to or above the threshold are set to true
28///
29/// # Example
30///
31/// ```
32/// use scirs2_core::ndarray::array;
33/// use scirs2_ndimage::segmentation::threshold_binary;
34///
35/// let image = array![
36///     [0.0, 0.2, 0.5],
37///     [0.3, 0.8, 0.1],
38///     [0.7, 0.4, 0.6],
39/// ];
40///
41/// let mask = threshold_binary(&image, 0.5).unwrap();
42/// ```
43#[allow(dead_code)]
44pub fn threshold_binary<T, D>(image: &Array<T, D>, threshold: T) -> NdimageResult<Array<T, D>>
45where
46    T: Float + NumAssign + std::fmt::Debug + std::ops::DivAssign + 'static,
47    D: Dimension + 'static,
48{
49    // Apply threshold by mapping over the input array
50    let result = image.mapv(|val| if val > threshold { T::one() } else { T::zero() });
51
52    Ok(result)
53}
54
55/// Apply Otsu's thresholding method
56///
57/// Otsu's method determines an optimal threshold by maximizing
58/// the variance between foreground and background classes.
59///
60/// # Arguments
61///
62/// * `image` - Input array
63/// * `bins` - Number of bins for the histogram
64///
65/// # Returns
66///
67/// * Tuple containing (binary_image, threshold_value)
68///
69/// # Example
70///
71/// ```
72/// use scirs2_core::ndarray::array;
73/// use scirs2_ndimage::segmentation::otsu_threshold;
74///
75/// let image = array![
76///     [0.1, 0.2, 0.3],
77///     [0.4, 0.5, 0.6],
78///     [0.7, 0.8, 0.9],
79/// ];
80///
81/// let (binary, threshold) = otsu_threshold(&image, 256).unwrap();
82/// ```
83#[allow(dead_code)]
84pub fn otsu_threshold<T, D>(image: &Array<T, D>, bins: usize) -> NdimageResult<(Array<T, D>, T)>
85where
86    T: Float + NumAssign + std::fmt::Debug + std::ops::DivAssign + FromPrimitive + 'static,
87    D: Dimension + 'static,
88{
89    let nbins = bins;
90
91    // Get min and max values
92    let mut min_val = Float::infinity();
93    let mut max_val = Float::neg_infinity();
94
95    for &val in image.iter() {
96        if val < min_val {
97            min_val = val;
98        }
99        if val > max_val {
100            max_val = val;
101        }
102    }
103
104    // Handle edge case of flat image
105    if min_val == max_val {
106        // Create a binary image with all zeros (as all values == threshold)
107        let binary = threshold_binary(image, min_val)?;
108        return Ok((binary, min_val));
109    }
110
111    // Calculate histogram
112    let mut hist = vec![0; nbins];
113    let bin_width = (max_val - min_val) / safe_usize_to_float(nbins)?;
114
115    for &val in image.iter() {
116        let bin = ((val - min_val) / bin_width).to_usize().unwrap_or(0);
117        let bin_index = std::cmp::min(bin, nbins - 1);
118        hist[bin_index] += 1;
119    }
120
121    // Calculate total pixels
122    let total_pixels = image.len();
123
124    // Compute cumulative sums
125    let mut cum_sum = vec![0; nbins];
126    cum_sum[0] = hist[0];
127    for i in 1..nbins {
128        cum_sum[i] = cum_sum[i - 1] + hist[i];
129    }
130
131    // Compute cumulative means
132    let mut cum_val = vec![T::zero(); nbins];
133    for i in 0..nbins {
134        if i > 0 {
135            cum_val[i] = cum_val[i - 1] + safe_usize_to_float(i * hist[i])?;
136        } else {
137            cum_val[i] = safe_usize_to_float(i * hist[i])?
138        }
139    }
140
141    // Compute maximum inter-class variance
142    let mut max_var = T::zero();
143    let mut threshold_idx = 0;
144
145    for i in 0..(nbins - 1) {
146        let bg_pixels = cum_sum[i];
147        let fg_pixels = total_pixels - bg_pixels;
148
149        // Skip cases where all pixels are in one class
150        if bg_pixels == 0 || fg_pixels == 0 {
151            continue;
152        }
153
154        let bg_mean = cum_val[i] / safe_usize_to_float::<T>(bg_pixels)?;
155        let fg_mean = (cum_val[nbins - 1] - cum_val[i]) / safe_usize_to_float::<T>(fg_pixels)?;
156
157        // Calculate inter-class variance
158        let variance = safe_usize_to_float::<T>(bg_pixels * fg_pixels)?
159            * (bg_mean - fg_mean)
160            * (bg_mean - fg_mean);
161
162        // Update threshold if variance is higher
163        if variance > max_var {
164            max_var = variance;
165            threshold_idx = i;
166        }
167    }
168
169    // Convert threshold index back to intensity value
170    let threshold = min_val + safe_usize_to_float::<T>(threshold_idx)? * bin_width;
171
172    // Create binary image using the threshold
173    let binary = threshold_binary(image, threshold)?;
174
175    Ok((binary, threshold))
176}
177
178/// Apply adaptive thresholding
179///
180/// Adaptive thresholding computes a local threshold for each pixel based on
181/// the statistics of its neighborhood.
182///
183/// # Arguments
184///
185/// * `image` - Input 2D array
186/// * `block_size` - Size of the neighborhood for calculating local threshold
187/// * `method` - Thresholding method ('mean' or 'gaussian')
188/// * `c` - Constant subtracted from the local threshold
189///
190/// # Returns
191///
192/// * Result containing the binary mask
193///
194/// # Example
195///
196/// ```
197/// use scirs2_core::ndarray::array;
198/// use scirs2_ndimage::segmentation::{adaptive_threshold, AdaptiveMethod};
199///
200/// let image = array![
201///     [0.1, 0.2, 0.7],
202///     [0.3, 0.8, 0.1],
203///     [0.7, 0.4, 0.2],
204/// ];
205///
206/// let mask = adaptive_threshold(&image, 3, AdaptiveMethod::Mean, 0.05).unwrap();
207/// ```
208#[derive(Debug, Clone, Copy)]
209pub enum AdaptiveMethod {
210    Mean,
211    Gaussian,
212}
213
214#[allow(dead_code)]
215pub fn adaptive_threshold<T>(
216    image: &Array<T, Ix2>,
217    block_size: usize,
218    method: AdaptiveMethod,
219    c: T,
220) -> NdimageResult<Array<bool, Ix2>>
221where
222    T: Float + NumAssign + std::fmt::Debug + FromPrimitive,
223{
224    // Check block _size (must be odd)
225    if block_size % 2 == 0 || block_size < 3 {
226        return Err(NdimageError::InvalidInput(
227            "block_size must be odd and at least 3".to_string(),
228        ));
229    }
230
231    let shape = image.raw_dim();
232    let (rows, cols) = (shape[0], shape[1]);
233    let mut result = Array::from_elem(shape, false);
234    let radius = block_size / 2;
235
236    // For each pixel, compute local threshold based on its neighborhood
237    for i in 0..rows {
238        for j in 0..cols {
239            // Define neighborhood bounds with padding at the edges
240            let start_row = i.saturating_sub(radius);
241            let end_row = std::cmp::min(i + radius + 1, rows);
242            let start_col = j.saturating_sub(radius);
243            let end_col = std::cmp::min(j + radius + 1, cols);
244
245            // Slice the neighborhood
246            let neighborhood = image.slice(scirs2_core::ndarray::s![
247                start_row..end_row,
248                start_col..end_col
249            ]);
250
251            // Compute local threshold based on method
252            let threshold = match method {
253                AdaptiveMethod::Mean => {
254                    // Simple mean of neighborhood
255                    let sum = neighborhood.iter().fold(T::zero(), |acc, &x| acc + x);
256                    sum / safe_usize_to_float(neighborhood.len())? - c
257                }
258                AdaptiveMethod::Gaussian => {
259                    // Gaussian weighted mean
260                    // Simplified implementation with distance-based weighting
261                    let center_row = i - start_row;
262                    let center_col = j - start_col;
263
264                    let mut weighted_sum = T::zero();
265                    let mut weight_sum = T::zero();
266
267                    for (idx, &val) in neighborhood.indexed_iter() {
268                        let dist_sq = (idx.0 as isize - center_row as isize).pow(2)
269                            + (idx.1 as isize - center_col as isize).pow(2);
270                        let dist = safe_usize_to_float::<T>(dist_sq as usize)?.sqrt();
271
272                        // Gaussian weight
273                        let sigma =
274                            safe_usize_to_float::<T>(radius)? / safe_f64_to_float::<T>(2.0)?;
275                        let weight =
276                            (-dist * dist / (safe_f64_to_float::<T>(2.0)? * sigma * sigma)).exp();
277
278                        weighted_sum += val * weight;
279                        weight_sum += weight;
280                    }
281
282                    weighted_sum / weight_sum - c
283                }
284            };
285
286            // Apply threshold
287            result[(i, j)] = image[(i, j)] > threshold;
288        }
289    }
290
291    Ok(result)
292}