scirs2_ndimage/segmentation/
graph_cuts.rs

1//! Graph cuts segmentation algorithm
2//!
3//! This module implements the graph cuts segmentation algorithm, which formulates
4//! image segmentation as a min-cut/max-flow problem on a graph.
5
6use scirs2_core::ndarray::{Array2, ArrayView2};
7use scirs2_core::numeric::{Float, FromPrimitive};
8use std::collections::{HashMap, VecDeque};
9use std::fmt::Debug;
10
11use crate::error::{NdimageError, NdimageResult};
12
13/// Graph structure for max-flow/min-cut algorithm
14struct Graph {
15    nodes: Vec<Node>,
16    edges: HashMap<(usize, usize), f64>,
17    source: usize,
18    sink: usize,
19}
20
/// A vertex of the flow network together with its adjacency list.
#[derive(Debug, Clone)]
struct Node {
    /// Position of this node in the owning graph's node list.
    id: usize,
    /// Indices of nodes that share an edge with this one, in either direction.
    neighbors: Vec<usize>,
}
27
28impl Graph {
29    /// Create a new graph with specified number of nodes
30    fn new(_numnodes: usize) -> Self {
31        let mut _nodes = Vec::with_capacity(_numnodes + 2);
32        for i in 0.._numnodes + 2 {
33            _nodes.push(Node {
34                id: i,
35                neighbors: Vec::new(),
36            });
37        }
38
39        Self {
40            nodes: _nodes,
41            edges: HashMap::new(),
42            source: _numnodes,
43            sink: _numnodes + 1,
44        }
45    }
46
47    /// Add an edge between two nodes
48    fn add_edge(&mut self, from: usize, to: usize, capacity: f64) {
49        if from != to && capacity > 0.0 {
50            self.nodes[from].neighbors.push(to);
51            self.nodes[to].neighbors.push(from);
52            self.edges.insert((from, to), capacity);
53            self.edges.insert((to, from), 0.0); // Reverse edge with 0 capacity
54        }
55    }
56
57    /// Find augmenting path using BFS
58    fn bfs(
59        &self,
60        parent: &mut Vec<Option<usize>>,
61        residual: &HashMap<(usize, usize), f64>,
62    ) -> bool {
63        let mut visited = vec![false; self.nodes.len()];
64        let mut queue = VecDeque::new();
65
66        queue.push_back(self.source);
67        visited[self.source] = true;
68        parent[self.source] = None;
69
70        while let Some(u) = queue.pop_front() {
71            for &v in &self.nodes[u].neighbors {
72                let capacity = residual.get(&(u, v)).unwrap_or(&0.0);
73                if !visited[v] && *capacity > 0.0 {
74                    visited[v] = true;
75                    parent[v] = Some(u);
76
77                    if v == self.sink {
78                        return true;
79                    }
80
81                    queue.push_back(v);
82                }
83            }
84        }
85
86        false
87    }
88
89    /// Compute maximum flow using Ford-Fulkerson algorithm
90    fn max_flow(&mut self) -> (f64, Vec<bool>) {
91        let mut residual = self.edges.clone();
92        let mut parent = vec![None; self.nodes.len()];
93        let mut max_flow = 0.0;
94
95        // Find augmenting paths
96        while self.bfs(&mut parent, &residual) {
97            // Find minimum capacity along the path
98            let mut path_flow = f64::INFINITY;
99            let mut v = self.sink;
100
101            while v != self.source {
102                let u = parent[v].unwrap();
103                let capacity = residual.get(&(u, v)).unwrap_or(&0.0);
104                path_flow = path_flow.min(*capacity);
105                v = u;
106            }
107
108            // Update residual capacities
109            v = self.sink;
110            while v != self.source {
111                let u = parent[v].unwrap();
112                *residual.get_mut(&(u, v)).unwrap() -= path_flow;
113                *residual.get_mut(&(v, u)).unwrap() += path_flow;
114                v = u;
115            }
116
117            max_flow += path_flow;
118        }
119
120        // Find minimum cut
121        let mut cut = vec![false; self.nodes.len()];
122        let mut visited = vec![false; self.nodes.len()];
123        let mut queue = VecDeque::new();
124
125        queue.push_back(self.source);
126        visited[self.source] = true;
127        cut[self.source] = true;
128
129        while let Some(u) = queue.pop_front() {
130            for &v in &self.nodes[u].neighbors {
131                let capacity = residual.get(&(u, v)).unwrap_or(&0.0);
132                if !visited[v] && *capacity > 0.0 {
133                    visited[v] = true;
134                    cut[v] = true;
135                    queue.push_back(v);
136                }
137            }
138        }
139
140        (max_flow, cut)
141    }
142}
143
/// Parameters for graph cuts segmentation
///
/// Use `GraphCutsParams::default()` for general-purpose settings, or the `for_grayscale` /
/// `for_color` presets.
#[derive(Clone, Debug, PartialEq)]
pub struct GraphCutsParams {
    /// Weight for smoothness term (pairwise potentials)
    ///
    /// Multiplies every neighbour-link weight; larger values make label boundaries more
    /// expensive and therefore favour smoother segmentations.
    pub lambda: f64,
    /// Sigma for Gaussian similarity in smoothness term
    ///
    /// Neighbour links between pixels whose intensity difference is large compared to `sigma`
    /// are weak, so boundaries are cheap there.
    pub sigma: f64,
    /// Neighborhood system: 4 or 8 connectivity
    ///
    /// Any other value is treated as 4-connectivity.
    pub connectivity: u8,
}
154
155impl Default for GraphCutsParams {
156    fn default() -> Self {
157        Self {
158            lambda: 1.0,
159            sigma: 50.0,
160            connectivity: 8,
161        }
162    }
163}
164
165/// Perform graph cuts segmentation on an image
166///
167/// # Arguments
168/// * `image` - Input image
169/// * `foreground_seeds` - Mask indicating definite foreground pixels
170/// * `background_seeds` - Mask indicating definite background pixels
171/// * `params` - Segmentation parameters
172///
173/// # Returns
174/// Binary segmentation mask where true indicates foreground
175#[allow(dead_code)]
176pub fn graph_cuts<T>(
177    image: &ArrayView2<T>,
178    foreground_seeds: &ArrayView2<bool>,
179    background_seeds: &ArrayView2<bool>,
180    params: Option<GraphCutsParams>,
181) -> NdimageResult<Array2<bool>>
182where
183    T: Float + FromPrimitive + Debug + std::ops::AddAssign + std::ops::DivAssign + 'static,
184{
185    let params = params.unwrap_or_default();
186    let (height, width) = image.dim();
187    let num_pixels = height * width;
188
189    // Validate inputs
190    if foreground_seeds.dim() != image.dim() || background_seeds.dim() != image.dim() {
191        return Err(NdimageError::DimensionError(
192            "Seed masks must have same dimensions as image".into(),
193        ));
194    }
195
196    // Check for overlapping _seeds
197    for i in 0..height {
198        for j in 0..width {
199            if foreground_seeds[[i, j]] && background_seeds[[i, j]] {
200                return Err(NdimageError::InvalidInput(
201                    "Foreground and background _seeds cannot overlap".into(),
202                ));
203            }
204        }
205    }
206
207    // Create graph
208    let mut graph = Graph::new(num_pixels);
209
210    // Helper function to convert 2D coordinates to node index
211    let coord_to_idx = |y: usize, x: usize| -> usize { y * width + x };
212
213    // Add terminal edges (data term)
214    let k = compute_k_constant(image);
215
216    for i in 0..height {
217        for j in 0..width {
218            let idx = coord_to_idx(i, j);
219
220            if foreground_seeds[[i, j]] {
221                // Definite foreground
222                graph.add_edge(graph.source, idx, k);
223                graph.add_edge(idx, graph.sink, 0.0);
224            } else if background_seeds[[i, j]] {
225                // Definite background
226                graph.add_edge(graph.source, idx, 0.0);
227                graph.add_edge(idx, graph.sink, k);
228            } else {
229                // Unknown - use data-driven weights
230                let (fg_weight, bg_weight) =
231                    compute_data_weights(image, i, j, foreground_seeds, background_seeds);
232                graph.add_edge(graph.source, idx, fg_weight);
233                graph.add_edge(idx, graph.sink, bg_weight);
234            }
235        }
236    }
237
238    // Add neighbor edges (smoothness term)
239    let neighbors = get_neighbors(params.connectivity);
240
241    for i in 0..height {
242        for j in 0..width {
243            let idx1 = coord_to_idx(i, j);
244            let val1 = image[[i, j]];
245
246            for (di, dj) in &neighbors {
247                let ni = i as i32 + di;
248                let nj = j as i32 + dj;
249
250                if ni >= 0 && ni < height as i32 && nj >= 0 && nj < width as i32 {
251                    let ni = ni as usize;
252                    let nj = nj as usize;
253                    let idx2 = coord_to_idx(ni, nj);
254
255                    if idx1 < idx2 {
256                        // Avoid duplicate edges
257                        let val2 = image[[ni, nj]];
258                        let weight =
259                            compute_smoothness_weight(val1, val2, params.lambda, params.sigma);
260                        graph.add_edge(idx1, idx2, weight);
261                    }
262                }
263            }
264        }
265    }
266
267    // Solve max-flow/min-cut
268    let (_, cut) = graph.max_flow();
269
270    // Convert cut to segmentation mask
271    let mut result = Array2::default((height, width));
272    for i in 0..height {
273        for j in 0..width {
274            let idx = coord_to_idx(i, j);
275            result[[i, j]] = cut[idx];
276        }
277    }
278
279    Ok(result)
280}
281
282/// Compute K constant for terminal edges
283#[allow(dead_code)]
284fn compute_k_constant<T: Float>(image: &ArrayView2<T>) -> f64 {
285    // K should be larger than any possible sum of edge weights
286    let max_val = image
287        .iter()
288        .map(|&v| v.to_f64().unwrap_or(0.0))
289        .max_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
290        .unwrap_or(0.0);
291
292    1.0 + max_val * 8.0 // Conservative estimate
293}
294
295/// Compute data weights for a pixel
296#[allow(dead_code)]
297fn compute_data_weights<T: Float>(
298    image: &ArrayView2<T>,
299    y: usize,
300    x: usize,
301    foreground_seeds: &ArrayView2<bool>,
302    background_seeds: &ArrayView2<bool>,
303) -> (f64, f64) {
304    let pixel_val = image[[y, x]].to_f64().unwrap_or(0.0);
305    let (height, width) = image.dim();
306
307    // Compute mean intensity of seed regions
308    let mut fg_sum = 0.0;
309    let mut fg_count = 0;
310    let mut bg_sum = 0.0;
311    let mut bg_count = 0;
312
313    for i in 0..height {
314        for j in 0..width {
315            if foreground_seeds[[i, j]] {
316                fg_sum += image[[i, j]].to_f64().unwrap_or(0.0);
317                fg_count += 1;
318            } else if background_seeds[[i, j]] {
319                bg_sum += image[[i, j]].to_f64().unwrap_or(0.0);
320                bg_count += 1;
321            }
322        }
323    }
324
325    let fg_mean = if fg_count > 0 {
326        fg_sum / fg_count as f64
327    } else {
328        0.0
329    };
330    let bg_mean = if bg_count > 0 {
331        bg_sum / bg_count as f64
332    } else {
333        255.0
334    };
335
336    // Simple Gaussian model
337    let fg_diff = pixel_val - fg_mean;
338    let bg_diff = pixel_val - bg_mean;
339
340    let fg_prob = (-fg_diff * fg_diff / 100.0).exp();
341    let bg_prob = (-bg_diff * bg_diff / 100.0).exp();
342
343    let epsilon = 1e-10;
344    let fg_weight = -((bg_prob + epsilon).ln());
345    let bg_weight = -((fg_prob + epsilon).ln());
346
347    (fg_weight.max(0.0), bg_weight.max(0.0))
348}
349
350/// Compute smoothness weight between neighboring pixels
351#[allow(dead_code)]
352fn compute_smoothness_weight<T: Float>(val1: T, val2: T, lambda: f64, sigma: f64) -> f64 {
353    let diff = (val1 - val2).to_f64().unwrap_or(0.0);
354    let weight = lambda * (-diff * diff / (2.0 * sigma * sigma)).exp();
355    weight
356}
357
/// Get neighbor offsets based on connectivity
///
/// Returns `(row, col)` offsets. The four orthogonal offsets always come first. For
/// `connectivity == 8` the four diagonal offsets follow. Any value other than 8 (including
/// invalid ones) yields the 4-connected set.
#[allow(dead_code)]
fn get_neighbors(connectivity: u8) -> Vec<(i32, i32)> {
    const ORTHOGONAL: [(i32, i32); 4] = [(0, 1), (1, 0), (0, -1), (-1, 0)];
    const DIAGONAL: [(i32, i32); 4] = [(1, 1), (1, -1), (-1, 1), (-1, -1)];

    let mut offsets = ORTHOGONAL.to_vec();
    if connectivity == 8 {
        offsets.extend_from_slice(&DIAGONAL);
    }
    offsets
}
376
377/// Interactive graph cuts segmentation with iterative refinement
378pub struct InteractiveGraphCuts<T> {
379    image: Array2<T>,
380    foreground_seeds: Array2<bool>,
381    background_seeds: Array2<bool>,
382    current_segmentation: Option<Array2<bool>>,
383    params: GraphCutsParams,
384}
385
386impl<T: Float + FromPrimitive + Debug + std::ops::AddAssign + std::ops::DivAssign + 'static>
387    InteractiveGraphCuts<T>
388{
389    /// Create new interactive segmentation session
390    pub fn new(image: Array2<T>, params: Option<GraphCutsParams>) -> Self {
391        let shape = image.dim();
392        Self {
393            image,
394            foreground_seeds: Array2::default(shape),
395            background_seeds: Array2::default(shape),
396            current_segmentation: None,
397            params: params.unwrap_or_default(),
398        }
399    }
400
401    /// Add foreground seeds
402    pub fn add_foreground_seeds(&mut self, seeds: &[(usize, usize)]) {
403        for &(y, x) in seeds {
404            if y < self.foreground_seeds.dim().0 && x < self.foreground_seeds.dim().1 {
405                self.foreground_seeds[[y, x]] = true;
406                self.background_seeds[[y, x]] = false; // Ensure no overlap
407            }
408        }
409    }
410
411    /// Add background seeds
412    pub fn add_background_seeds(&mut self, seeds: &[(usize, usize)]) {
413        for &(y, x) in seeds {
414            if y < self.background_seeds.dim().0 && x < self.background_seeds.dim().1 {
415                self.background_seeds[[y, x]] = true;
416                self.foreground_seeds[[y, x]] = false; // Ensure no overlap
417            }
418        }
419    }
420
421    /// Clear all seeds
422    pub fn clear_seeds(&mut self) {
423        self.foreground_seeds.fill(false);
424        self.background_seeds.fill(false);
425    }
426
427    /// Run segmentation with current seeds
428    pub fn segment(&mut self) -> NdimageResult<&Array2<bool>> {
429        let result = graph_cuts(
430            &self.image.view(),
431            &self.foreground_seeds.view(),
432            &self.background_seeds.view(),
433            Some(self.params.clone()),
434        )?;
435
436        self.current_segmentation = Some(result);
437        Ok(self.current_segmentation.as_ref().unwrap())
438    }
439
440    /// Get current segmentation result
441    pub fn get_segmentation(&self) -> Option<&Array2<bool>> {
442        self.current_segmentation.as_ref()
443    }
444}
445
446impl GraphCutsParams {
447    /// Create parameters optimized for grayscale images
448    pub fn for_grayscale() -> Self {
449        Self {
450            lambda: 10.0,
451            sigma: 30.0,
452            connectivity: 8,
453        }
454    }
455
456    /// Create parameters optimized for color images
457    pub fn for_color() -> Self {
458        Self {
459            lambda: 5.0,
460            sigma: 50.0,
461            connectivity: 8,
462        }
463    }
464}
465
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::arr2;

    #[test]
    fn test_graph_cuts_simple() {
        // Two flat halves: left columns are 0, right columns are 100.
        let image = arr2(&[
            [0.0, 0.0, 100.0, 100.0],
            [0.0, 0.0, 100.0, 100.0],
            [0.0, 0.0, 100.0, 100.0],
            [0.0, 0.0, 100.0, 100.0],
        ]);

        let mut fg_seeds = Array2::default((4, 4));
        let mut bg_seeds = Array2::default((4, 4));

        // Foreground seeds on the right half
        fg_seeds[[1, 2]] = true;
        fg_seeds[[2, 3]] = true;

        // Background seeds on the left half
        bg_seeds[[1, 0]] = true;
        bg_seeds[[2, 1]] = true;

        let result = graph_cuts(&image.view(), &fg_seeds.view(), &bg_seeds.view(), None).unwrap();

        // Right side is segmented as foreground
        assert!(result[[0, 2]] || result[[0, 3]]);
        assert!(result[[1, 2]] || result[[1, 3]]);

        // Left side is segmented as background
        assert!(!result[[0, 0]] && !result[[0, 1]]);
        assert!(!result[[1, 0]] && !result[[1, 1]]);
    }

    #[test]
    fn test_interactive_graph_cuts() {
        let image = arr2(&[
            [10.0, 20.0, 80.0, 90.0],
            [15.0, 25.0, 85.0, 95.0],
            [12.0, 22.0, 82.0, 92.0],
            [18.0, 28.0, 88.0, 98.0],
        ]);

        let mut interactive = InteractiveGraphCuts::new(image, None);

        interactive.add_foreground_seeds(&[(0, 3), (1, 2)]);
        interactive.add_background_seeds(&[(0, 0), (1, 1)]);

        let result = interactive.segment().unwrap();
        assert_eq!(result.dim(), (4, 4));

        // Add more seeds and re-segment
        interactive.add_foreground_seeds(&[(2, 3)]);
        let result2 = interactive.segment().unwrap();
        assert_eq!(result2.dim(), (4, 4));
    }

    #[test]
    fn test_graph_cuts_rejects_mismatched_seed_shape() {
        let image = arr2(&[[0.0, 1.0], [2.0, 3.0]]);
        let fg_seeds: Array2<bool> = Array2::default((2, 3));
        let bg_seeds: Array2<bool> = Array2::default((2, 2));

        let result = graph_cuts(&image.view(), &fg_seeds.view(), &bg_seeds.view(), None);
        assert!(matches!(result, Err(NdimageError::DimensionError(_))));
    }

    #[test]
    fn test_graph_cuts_rejects_overlapping_seeds() {
        let image = arr2(&[[0.0, 1.0], [2.0, 3.0]]);
        let mut fg_seeds = Array2::default((2, 2));
        let mut bg_seeds = Array2::default((2, 2));
        fg_seeds[[0, 1]] = true;
        bg_seeds[[0, 1]] = true;

        let result = graph_cuts(&image.view(), &fg_seeds.view(), &bg_seeds.view(), None);
        assert!(matches!(result, Err(NdimageError::InvalidInput(_))));
    }

    #[test]
    fn test_graph_cuts_empty_image() {
        let image: Array2<f64> = Array2::zeros((0, 0));
        let seeds: Array2<bool> = Array2::default((0, 0));

        let result = graph_cuts(&image.view(), &seeds.view(), &seeds.view(), None).unwrap();
        assert_eq!(result.dim(), (0, 0));
    }
}