1use scirs2_core::ndarray::{Array2, ArrayView2};
7use scirs2_core::numeric::{Float, FromPrimitive};
8use std::collections::{HashMap, VecDeque};
9use std::fmt::Debug;
10
11use crate::error::{NdimageError, NdimageResult};
12
/// Flow network used by [`graph_cuts`]: one node per pixel plus a source and a
/// sink terminal, with a capacity stored for every directed arc.
struct Graph {
    /// Adjacency lists, indexed by node id.
    nodes: Vec<Node>,
    /// Capacity of each directed arc `(from, to)`. [`Graph::add_edge`] always
    /// makes sure the reverse arc exists too (possibly with capacity 0.0) so the
    /// residual network can push flow back along it.
    edges: HashMap<(usize, usize), f64>,
    /// Index of the source terminal (equal to the number of pixel nodes).
    source: usize,
    /// Index of the sink terminal (number of pixel nodes + 1).
    sink: usize,
}

/// Adjacency information for one node of a [`Graph`].
#[derive(Clone, Debug)]
struct Node {
    /// Nodes sharing an arc with this one (in either direction), each listed once.
    neighbors: Vec<usize>,
}

impl Graph {
    /// Creates a network with `num_nodes` pixel nodes (ids `0..num_nodes`),
    /// followed by the source (`num_nodes`) and the sink (`num_nodes + 1`).
    fn new(num_nodes: usize) -> Self {
        let nodes = (0..num_nodes + 2)
            .map(|_| Node {
                neighbors: Vec::new(),
            })
            .collect();

        Self {
            nodes,
            edges: HashMap::new(),
            source: num_nodes,
            sink: num_nodes + 1,
        }
    }

    /// Adds `capacity` to the directed arc `from -> to`.
    ///
    /// Self-loops and capacities that are not strictly positive (including NaN)
    /// are ignored. Repeated calls for the same arc accumulate, and the reverse
    /// arc is only created (with zero capacity) when it is missing, so adding
    /// `b -> a` after `a -> b` no longer erases the capacity of `a -> b`.
    fn add_edge(&mut self, from: usize, to: usize, capacity: f64) {
        if from == to || capacity.is_nan() || capacity <= 0.0 {
            return;
        }

        // Register the adjacency once per node pair; duplicates would only make
        // every BFS scan the same arc several times.
        let known_pair =
            self.edges.contains_key(&(from, to)) || self.edges.contains_key(&(to, from));
        if !known_pair {
            self.nodes[from].neighbors.push(to);
            self.nodes[to].neighbors.push(from);
        }

        *self.edges.entry((from, to)).or_insert(0.0) += capacity;
        self.edges.entry((to, from)).or_insert(0.0);
    }

    /// Residual capacity of the arc `from -> to`; arcs never added count as 0.0.
    fn residual_capacity(residual: &HashMap<(usize, usize), f64>, from: usize, to: usize) -> f64 {
        residual.get(&(from, to)).copied().unwrap_or(0.0)
    }

    /// Breadth-first search from the source over arcs with positive residual
    /// capacity.
    ///
    /// Stores the predecessor of every node it discovers in `parent` and returns
    /// `true` as soon as the sink is reached, `false` if the sink is unreachable.
    fn bfs(&self, parent: &mut [Option<usize>], residual: &HashMap<(usize, usize), f64>) -> bool {
        let mut visited = vec![false; self.nodes.len()];
        let mut queue = VecDeque::new();

        visited[self.source] = true;
        parent[self.source] = None;
        queue.push_back(self.source);

        while let Some(u) = queue.pop_front() {
            for &v in &self.nodes[u].neighbors {
                if !visited[v] && Self::residual_capacity(residual, u, v) > 0.0 {
                    visited[v] = true;
                    parent[v] = Some(u);

                    if v == self.sink {
                        return true;
                    }

                    queue.push_back(v);
                }
            }
        }

        false
    }

    /// Computes the maximum source-to-sink flow with the Edmonds–Karp algorithm
    /// (shortest augmenting paths found by BFS).
    ///
    /// Returns the flow value and a mask over all nodes (terminals included)
    /// that is `true` for nodes still reachable from the source in the final
    /// residual network, i.e. the source side of a minimum cut.
    fn max_flow(&self) -> (f64, Vec<bool>) {
        let mut residual = self.edges.clone();
        let mut parent = vec![None; self.nodes.len()];
        let mut total_flow = 0.0;

        while self.bfs(&mut parent, &residual) {
            // Bottleneck capacity along the path just found. Every arc on it has
            // positive residual capacity, so this is strictly positive.
            let mut path_flow = f64::INFINITY;
            let mut v = self.sink;
            while v != self.source {
                let u = parent[v].expect("BFS records a parent for every node on the path");
                path_flow = path_flow.min(Self::residual_capacity(&residual, u, v));
                v = u;
            }

            // Push the bottleneck flow and credit the reverse arcs.
            let mut v = self.sink;
            while v != self.source {
                let u = parent[v].expect("BFS records a parent for every node on the path");
                *residual.entry((u, v)).or_insert(0.0) -= path_flow;
                *residual.entry((v, u)).or_insert(0.0) += path_flow;
                v = u;
            }

            total_flow += path_flow;
        }

        (total_flow, self.reachable_from_source(&residual))
    }

    /// Marks every node reachable from the source through arcs with positive
    /// residual capacity.
    fn reachable_from_source(&self, residual: &HashMap<(usize, usize), f64>) -> Vec<bool> {
        let mut reached = vec![false; self.nodes.len()];
        let mut queue = VecDeque::new();

        reached[self.source] = true;
        queue.push_back(self.source);

        while let Some(u) = queue.pop_front() {
            for &v in &self.nodes[u].neighbors {
                if !reached[v] && Self::residual_capacity(residual, u, v) > 0.0 {
                    reached[v] = true;
                    queue.push_back(v);
                }
            }
        }

        reached
    }
}
143
/// Tuning parameters for [`graph_cuts`].
#[derive(Clone, Debug, PartialEq)]
pub struct GraphCutsParams {
    /// Scale of the smoothness term: the link between neighbouring pixels with
    /// intensities `a` and `b` has capacity `lambda * exp(-(a - b)^2 / (2 * sigma^2))`.
    pub lambda: f64,
    /// Intensity-difference scale of the smoothness term (see `lambda`).
    pub sigma: f64,
    /// Pixel neighbourhood: `4` or `8`; any other value is treated as `4`.
    pub connectivity: u8,
}

impl Default for GraphCutsParams {
    /// `lambda = 1.0`, `sigma = 50.0`, 8-connectivity.
    fn default() -> Self {
        Self {
            lambda: 1.0,
            sigma: 50.0,
            connectivity: 8,
        }
    }
}
164
165#[allow(dead_code)]
176pub fn graph_cuts<T>(
177 image: &ArrayView2<T>,
178 foreground_seeds: &ArrayView2<bool>,
179 background_seeds: &ArrayView2<bool>,
180 params: Option<GraphCutsParams>,
181) -> NdimageResult<Array2<bool>>
182where
183 T: Float + FromPrimitive + Debug + std::ops::AddAssign + std::ops::DivAssign + 'static,
184{
185 let params = params.unwrap_or_default();
186 let (height, width) = image.dim();
187 let num_pixels = height * width;
188
189 if foreground_seeds.dim() != image.dim() || background_seeds.dim() != image.dim() {
191 return Err(NdimageError::DimensionError(
192 "Seed masks must have same dimensions as image".into(),
193 ));
194 }
195
196 for i in 0..height {
198 for j in 0..width {
199 if foreground_seeds[[i, j]] && background_seeds[[i, j]] {
200 return Err(NdimageError::InvalidInput(
201 "Foreground and background _seeds cannot overlap".into(),
202 ));
203 }
204 }
205 }
206
207 let mut graph = Graph::new(num_pixels);
209
210 let coord_to_idx = |y: usize, x: usize| -> usize { y * width + x };
212
213 let k = compute_k_constant(image);
215
216 for i in 0..height {
217 for j in 0..width {
218 let idx = coord_to_idx(i, j);
219
220 if foreground_seeds[[i, j]] {
221 graph.add_edge(graph.source, idx, k);
223 graph.add_edge(idx, graph.sink, 0.0);
224 } else if background_seeds[[i, j]] {
225 graph.add_edge(graph.source, idx, 0.0);
227 graph.add_edge(idx, graph.sink, k);
228 } else {
229 let (fg_weight, bg_weight) =
231 compute_data_weights(image, i, j, foreground_seeds, background_seeds);
232 graph.add_edge(graph.source, idx, fg_weight);
233 graph.add_edge(idx, graph.sink, bg_weight);
234 }
235 }
236 }
237
238 let neighbors = get_neighbors(params.connectivity);
240
241 for i in 0..height {
242 for j in 0..width {
243 let idx1 = coord_to_idx(i, j);
244 let val1 = image[[i, j]];
245
246 for (di, dj) in &neighbors {
247 let ni = i as i32 + di;
248 let nj = j as i32 + dj;
249
250 if ni >= 0 && ni < height as i32 && nj >= 0 && nj < width as i32 {
251 let ni = ni as usize;
252 let nj = nj as usize;
253 let idx2 = coord_to_idx(ni, nj);
254
255 if idx1 < idx2 {
256 let val2 = image[[ni, nj]];
258 let weight =
259 compute_smoothness_weight(val1, val2, params.lambda, params.sigma);
260 graph.add_edge(idx1, idx2, weight);
261 }
262 }
263 }
264 }
265 }
266
267 let (_, cut) = graph.max_flow();
269
270 let mut result = Array2::default((height, width));
272 for i in 0..height {
273 for j in 0..width {
274 let idx = coord_to_idx(i, j);
275 result[[i, j]] = cut[idx];
276 }
277 }
278
279 Ok(result)
280}
281
282#[allow(dead_code)]
284fn compute_k_constant<T: Float>(image: &ArrayView2<T>) -> f64 {
285 let max_val = image
287 .iter()
288 .map(|&v| v.to_f64().unwrap_or(0.0))
289 .max_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
290 .unwrap_or(0.0);
291
292 1.0 + max_val * 8.0 }
294
295#[allow(dead_code)]
297fn compute_data_weights<T: Float>(
298 image: &ArrayView2<T>,
299 y: usize,
300 x: usize,
301 foreground_seeds: &ArrayView2<bool>,
302 background_seeds: &ArrayView2<bool>,
303) -> (f64, f64) {
304 let pixel_val = image[[y, x]].to_f64().unwrap_or(0.0);
305 let (height, width) = image.dim();
306
307 let mut fg_sum = 0.0;
309 let mut fg_count = 0;
310 let mut bg_sum = 0.0;
311 let mut bg_count = 0;
312
313 for i in 0..height {
314 for j in 0..width {
315 if foreground_seeds[[i, j]] {
316 fg_sum += image[[i, j]].to_f64().unwrap_or(0.0);
317 fg_count += 1;
318 } else if background_seeds[[i, j]] {
319 bg_sum += image[[i, j]].to_f64().unwrap_or(0.0);
320 bg_count += 1;
321 }
322 }
323 }
324
325 let fg_mean = if fg_count > 0 {
326 fg_sum / fg_count as f64
327 } else {
328 0.0
329 };
330 let bg_mean = if bg_count > 0 {
331 bg_sum / bg_count as f64
332 } else {
333 255.0
334 };
335
336 let fg_diff = pixel_val - fg_mean;
338 let bg_diff = pixel_val - bg_mean;
339
340 let fg_prob = (-fg_diff * fg_diff / 100.0).exp();
341 let bg_prob = (-bg_diff * bg_diff / 100.0).exp();
342
343 let epsilon = 1e-10;
344 let fg_weight = -((bg_prob + epsilon).ln());
345 let bg_weight = -((fg_prob + epsilon).ln());
346
347 (fg_weight.max(0.0), bg_weight.max(0.0))
348}
349
350#[allow(dead_code)]
352fn compute_smoothness_weight<T: Float>(val1: T, val2: T, lambda: f64, sigma: f64) -> f64 {
353 let diff = (val1 - val2).to_f64().unwrap_or(0.0);
354 let weight = lambda * (-diff * diff / (2.0 * sigma * sigma)).exp();
355 weight
356}
357
/// Relative `(dy, dx)` offsets of a pixel's neighbours.
///
/// `8` yields the four axis-aligned offsets followed by the four diagonals.
/// `4`, and any other value, yields only the axis-aligned offsets.
#[allow(dead_code)]
fn get_neighbors(connectivity: u8) -> Vec<(i32, i32)> {
    const AXIS_ALIGNED: [(i32, i32); 4] = [(0, 1), (1, 0), (0, -1), (-1, 0)];
    const DIAGONAL: [(i32, i32); 4] = [(1, 1), (1, -1), (-1, 1), (-1, -1)];

    let mut offsets = AXIS_ALIGNED.to_vec();
    if connectivity == 8 {
        offsets.extend_from_slice(&DIAGONAL);
    }
    offsets
}
376
377pub struct InteractiveGraphCuts<T> {
379 image: Array2<T>,
380 foreground_seeds: Array2<bool>,
381 background_seeds: Array2<bool>,
382 current_segmentation: Option<Array2<bool>>,
383 params: GraphCutsParams,
384}
385
386impl<T: Float + FromPrimitive + Debug + std::ops::AddAssign + std::ops::DivAssign + 'static>
387 InteractiveGraphCuts<T>
388{
389 pub fn new(image: Array2<T>, params: Option<GraphCutsParams>) -> Self {
391 let shape = image.dim();
392 Self {
393 image,
394 foreground_seeds: Array2::default(shape),
395 background_seeds: Array2::default(shape),
396 current_segmentation: None,
397 params: params.unwrap_or_default(),
398 }
399 }
400
401 pub fn add_foreground_seeds(&mut self, seeds: &[(usize, usize)]) {
403 for &(y, x) in seeds {
404 if y < self.foreground_seeds.dim().0 && x < self.foreground_seeds.dim().1 {
405 self.foreground_seeds[[y, x]] = true;
406 self.background_seeds[[y, x]] = false; }
408 }
409 }
410
411 pub fn add_background_seeds(&mut self, seeds: &[(usize, usize)]) {
413 for &(y, x) in seeds {
414 if y < self.background_seeds.dim().0 && x < self.background_seeds.dim().1 {
415 self.background_seeds[[y, x]] = true;
416 self.foreground_seeds[[y, x]] = false; }
418 }
419 }
420
421 pub fn clear_seeds(&mut self) {
423 self.foreground_seeds.fill(false);
424 self.background_seeds.fill(false);
425 }
426
427 pub fn segment(&mut self) -> NdimageResult<&Array2<bool>> {
429 let result = graph_cuts(
430 &self.image.view(),
431 &self.foreground_seeds.view(),
432 &self.background_seeds.view(),
433 Some(self.params.clone()),
434 )?;
435
436 self.current_segmentation = Some(result);
437 Ok(self.current_segmentation.as_ref().unwrap())
438 }
439
440 pub fn get_segmentation(&self) -> Option<&Array2<bool>> {
442 self.current_segmentation.as_ref()
443 }
444}
445
446impl GraphCutsParams {
447 pub fn for_grayscale() -> Self {
449 Self {
450 lambda: 10.0,
451 sigma: 30.0,
452 connectivity: 8,
453 }
454 }
455
456 pub fn for_color() -> Self {
458 Self {
459 lambda: 5.0,
460 sigma: 50.0,
461 connectivity: 8,
462 }
463 }
464}
465
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::arr2;

    /// Dark left half vs. bright right half, with two seeds on each side: the
    /// cut should follow the vertical intensity edge.
    #[test]
    fn test_graph_cuts_simple() {
        let image = arr2(&[
            [0.0, 0.0, 100.0, 100.0],
            [0.0, 0.0, 100.0, 100.0],
            [0.0, 0.0, 100.0, 100.0],
            [0.0, 0.0, 100.0, 100.0],
        ]);

        let mut fg_mask = Array2::default((4, 4));
        let mut bg_mask = Array2::default((4, 4));
        for &(y, x) in &[(1_usize, 2_usize), (2, 3)] {
            fg_mask[[y, x]] = true;
        }
        for &(y, x) in &[(1_usize, 0_usize), (2, 1)] {
            bg_mask[[y, x]] = true;
        }

        let segmentation =
            graph_cuts(&image.view(), &fg_mask.view(), &bg_mask.view(), None).unwrap();

        // Bright columns belong to the foreground...
        for row in 0..2_usize {
            assert!(segmentation[[row, 2]] || segmentation[[row, 3]]);
        }
        // ...and dark columns to the background.
        for row in 0..2_usize {
            assert!(!segmentation[[row, 0]] && !segmentation[[row, 1]]);
        }
    }

    /// Segmenting, adding a seed and segmenting again keeps the output shape.
    #[test]
    fn test_interactive_graph_cuts() {
        let image = arr2(&[
            [10.0, 20.0, 80.0, 90.0],
            [15.0, 25.0, 85.0, 95.0],
            [12.0, 22.0, 82.0, 92.0],
            [18.0, 28.0, 88.0, 98.0],
        ]);

        let mut session = InteractiveGraphCuts::new(image, None);
        session.add_foreground_seeds(&[(0, 3), (1, 2)]);
        session.add_background_seeds(&[(0, 0), (1, 1)]);

        assert_eq!(session.segment().unwrap().dim(), (4, 4));

        // Refining with an extra foreground seed re-runs the segmentation.
        session.add_foreground_seeds(&[(2, 3)]);
        assert_eq!(session.segment().unwrap().dim(), (4, 4));
    }
}