1use crate::error::{Result, VisionError};
17use scirs2_core::ndarray::{Array2, Ix2};
18
/// Per-pixel label produced by the background models in this module.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ForegroundLabel {
    /// The pixel agrees with the background model.
    Background,
    /// The pixel differs from the background model.
    Foreground,
    /// The pixel differs from the background, but its brightness ratio to the
    /// background lies inside the band configured by [`ShadowParams`].
    Shadow,
}
36
37pub fn mask_to_binary(mask: &Array2<ForegroundLabel>) -> Array2<f64> {
39 mask.mapv(|l| match l {
40 ForegroundLabel::Foreground => 1.0,
41 _ => 0.0,
42 })
43}
44
45#[derive(Debug, Clone)]
47pub struct BackgroundConfig {
48 pub learning_rate: f64,
50 pub fg_threshold: f64,
53 pub detect_shadows: bool,
55 pub shadow_params: ShadowParams,
57}
58
59impl Default for BackgroundConfig {
60 fn default() -> Self {
61 Self {
62 learning_rate: 0.05,
63 fg_threshold: 0.15,
64 detect_shadows: false,
65 shadow_params: ShadowParams::default(),
66 }
67 }
68}
69
/// Brightness-ratio band used for shadow detection.
///
/// A pixel that is not background is treated as a shadow when the background
/// value is positive and `pixel / background` lies in `[tau_lo, tau_hi]`
/// (both ends inclusive).
#[derive(Clone, Debug)]
pub struct ShadowParams {
    /// Lower bound of the shadow ratio (inclusive).
    pub tau_lo: f64,
    /// Upper bound of the shadow ratio (inclusive).
    pub tau_hi: f64,
}
83
84impl Default for ShadowParams {
85 fn default() -> Self {
86 Self {
87 tau_lo: 0.4,
88 tau_hi: 0.9,
89 }
90 }
91}
92
93fn classify_pixel(
98 pixel: f64,
99 bg_value: f64,
100 threshold: f64,
101 detect_shadows: bool,
102 shadow_params: &ShadowParams,
103) -> ForegroundLabel {
104 let diff = (pixel - bg_value).abs();
105 if diff < threshold {
106 return ForegroundLabel::Background;
107 }
108 if detect_shadows && bg_value > 1e-9 {
109 let ratio = pixel / bg_value;
110 if ratio >= shadow_params.tau_lo && ratio <= shadow_params.tau_hi {
111 return ForegroundLabel::Shadow;
112 }
113 }
114 ForegroundLabel::Foreground
115}
116
117#[derive(Debug, Clone)]
132pub struct RunningAverageBackground {
133 background: Option<Array2<f64>>,
135 config: BackgroundConfig,
137 frame_count: u64,
139}
140
141impl RunningAverageBackground {
142 pub fn new(config: BackgroundConfig) -> Result<Self> {
144 if config.learning_rate <= 0.0 || config.learning_rate > 1.0 {
145 return Err(VisionError::InvalidParameter(
146 "learning_rate must be in (0, 1]".into(),
147 ));
148 }
149 if config.fg_threshold <= 0.0 {
150 return Err(VisionError::InvalidParameter(
151 "fg_threshold must be positive".into(),
152 ));
153 }
154 Ok(Self {
155 background: None,
156 config,
157 frame_count: 0,
158 })
159 }
160
161 pub fn default_config() -> Result<Self> {
163 Self::new(BackgroundConfig::default())
164 }
165
166 pub fn apply(&mut self, frame: &Array2<f64>) -> Result<Array2<ForegroundLabel>> {
168 let shape = frame.raw_dim();
169 match &mut self.background {
170 None => {
171 self.background = Some(frame.clone());
173 self.frame_count = 1;
174 Ok(Array2::from_elem(shape, ForegroundLabel::Background))
175 }
176 Some(bg) => {
177 if bg.raw_dim() != shape {
178 return Err(VisionError::DimensionMismatch(format!(
179 "Frame shape {:?} does not match background {:?}",
180 shape,
181 bg.raw_dim()
182 )));
183 }
184 self.frame_count += 1;
185 let alpha = self.config.learning_rate;
186 let threshold = self.config.fg_threshold;
187 let detect_shadows = self.config.detect_shadows;
188 let shadow_params = &self.config.shadow_params;
189
190 let rows = shape[0];
191 let cols = shape[1];
192 let mut mask = Array2::from_elem(shape, ForegroundLabel::Background);
193
194 for r in 0..rows {
195 for c in 0..cols {
196 let p = frame[[r, c]];
197 let b = bg[[r, c]];
198 mask[[r, c]] =
199 classify_pixel(p, b, threshold, detect_shadows, shadow_params);
200 bg[[r, c]] = (1.0 - alpha) * b + alpha * p;
202 }
203 }
204 Ok(mask)
205 }
206 }
207 }
208
209 pub fn background(&self) -> Option<&Array2<f64>> {
211 self.background.as_ref()
212 }
213
214 pub fn frame_count(&self) -> u64 {
216 self.frame_count
217 }
218
219 pub fn set_learning_rate(&mut self, rate: f64) -> Result<()> {
221 if rate <= 0.0 || rate > 1.0 {
222 return Err(VisionError::InvalidParameter(
223 "learning_rate must be in (0, 1]".into(),
224 ));
225 }
226 self.config.learning_rate = rate;
227 Ok(())
228 }
229}
230
/// One Gaussian in a pixel's mixture model.
#[derive(Clone, Debug)]
struct GaussianComponent {
    /// Mean intensity of the component.
    mean: f64,
    /// Variance of the component.
    variance: f64,
    /// Mixing weight. Weights are renormalised to sum to one after each update.
    weight: f64,
}
242
243#[derive(Debug, Clone)]
251pub struct GmmBackground {
252 max_components: usize,
254 models: Option<Vec<Vec<Vec<GaussianComponent>>>>,
256 config: BackgroundConfig,
258 match_threshold: f64,
260 bg_weight_threshold: f64,
263 min_variance: f64,
265 frame_count: u64,
267 frame_rows: usize,
269 frame_cols: usize,
271}
272
273impl GmmBackground {
274 pub fn new(max_components: usize, config: BackgroundConfig) -> Result<Self> {
280 if max_components == 0 {
281 return Err(VisionError::InvalidParameter(
282 "max_components must be >= 1".into(),
283 ));
284 }
285 if config.learning_rate <= 0.0 || config.learning_rate > 1.0 {
286 return Err(VisionError::InvalidParameter(
287 "learning_rate must be in (0, 1]".into(),
288 ));
289 }
290 Ok(Self {
291 max_components,
292 models: None,
293 config,
294 match_threshold: 2.5,
295 bg_weight_threshold: 0.7,
296 min_variance: 0.001,
297 frame_count: 0,
298 frame_rows: 0,
299 frame_cols: 0,
300 })
301 }
302
303 pub fn default_config() -> Result<Self> {
305 Self::new(5, BackgroundConfig::default())
306 }
307
308 pub fn set_match_threshold(&mut self, t: f64) -> Result<()> {
310 if t <= 0.0 {
311 return Err(VisionError::InvalidParameter(
312 "match_threshold must be positive".into(),
313 ));
314 }
315 self.match_threshold = t;
316 Ok(())
317 }
318
319 pub fn apply(&mut self, frame: &Array2<f64>) -> Result<Array2<ForegroundLabel>> {
321 let rows = frame.nrows();
322 let cols = frame.ncols();
323 let shape: Ix2 = frame.raw_dim();
324
325 if self.models.is_none() {
326 let mut models = Vec::with_capacity(rows);
328 for r in 0..rows {
329 let mut row_models = Vec::with_capacity(cols);
330 for c in 0..cols {
331 let comp = GaussianComponent {
332 mean: frame[[r, c]],
333 variance: 0.02,
334 weight: 1.0,
335 };
336 row_models.push(vec![comp]);
337 }
338 models.push(row_models);
339 }
340 self.models = Some(models);
341 self.frame_rows = rows;
342 self.frame_cols = cols;
343 self.frame_count = 1;
344 return Ok(Array2::from_elem(shape, ForegroundLabel::Background));
345 }
346
347 if rows != self.frame_rows || cols != self.frame_cols {
348 return Err(VisionError::DimensionMismatch(format!(
349 "Frame ({rows}x{cols}) differs from model ({}x{})",
350 self.frame_rows, self.frame_cols,
351 )));
352 }
353
354 self.frame_count += 1;
355 let alpha = self.config.learning_rate;
356 let fg_thresh = self.config.fg_threshold;
357 let detect_shadows = self.config.detect_shadows;
358 let shadow_params = self.config.shadow_params.clone();
359 let match_thresh = self.match_threshold;
360 let max_k = self.max_components;
361 let bg_wt = self.bg_weight_threshold;
362 let min_var = self.min_variance;
363
364 let models = self
365 .models
366 .as_mut()
367 .ok_or_else(|| VisionError::OperationError("Models not initialised".into()))?;
368
369 let mut mask = Array2::from_elem(shape, ForegroundLabel::Background);
370
371 for r in 0..rows {
372 for c in 0..cols {
373 let pixel = frame[[r, c]];
374 let comps = &mut models[r][c];
375
376 let mut matched = false;
378 let mut matched_bg = false;
379
380 comps.sort_by(|a, b| {
382 let ra = a.weight / a.variance.sqrt().max(1e-12);
383 let rb = b.weight / b.variance.sqrt().max(1e-12);
384 rb.partial_cmp(&ra).unwrap_or(std::cmp::Ordering::Equal)
385 });
386
387 let mut cum_weight = 0.0;
389 let mut bg_count = 0;
390 for comp in comps.iter() {
391 cum_weight += comp.weight;
392 bg_count += 1;
393 if cum_weight >= bg_wt {
394 break;
395 }
396 }
397
398 for (i, comp) in comps.iter_mut().enumerate() {
399 let sigma = comp.variance.sqrt().max(1e-12);
400 let d = (pixel - comp.mean).abs() / sigma;
401 if d < match_thresh {
402 comp.mean = (1.0 - alpha) * comp.mean + alpha * pixel;
404 let diff = pixel - comp.mean;
405 comp.variance =
406 ((1.0 - alpha) * comp.variance + alpha * diff * diff).max(min_var);
407 comp.weight = (1.0 - alpha) * comp.weight + alpha;
408 matched = true;
409 if i < bg_count {
410 matched_bg = true;
411 }
412 break;
413 }
414 }
415
416 let mut total_w = 0.0;
418 for comp in comps.iter_mut() {
419 comp.weight *= 1.0 - alpha;
420 total_w += comp.weight;
421 }
422
423 if !matched {
424 let new_comp = GaussianComponent {
426 mean: pixel,
427 variance: 0.02,
428 weight: alpha,
429 };
430 total_w += alpha;
431 if comps.len() < max_k {
432 comps.push(new_comp);
433 } else {
434 if let Some(last) = comps.last_mut() {
436 total_w -= last.weight;
437 *last = new_comp;
438 total_w += last.weight;
439 }
440 }
441 }
442
443 if total_w > 0.0 {
445 for comp in comps.iter_mut() {
446 comp.weight /= total_w;
447 }
448 }
449
450 if matched_bg {
452 if detect_shadows {
454 let bg_mean = comps.first().map(|c| c.mean).unwrap_or(0.0);
456 mask[[r, c]] = classify_pixel(
457 pixel,
458 bg_mean,
459 fg_thresh,
460 detect_shadows,
461 &shadow_params,
462 );
463 } else {
464 mask[[r, c]] = ForegroundLabel::Background;
465 }
466 } else {
467 mask[[r, c]] = ForegroundLabel::Foreground;
468 }
469 }
470 }
471
472 Ok(mask)
473 }
474
475 pub fn background_image(&self) -> Option<Array2<f64>> {
477 let models = self.models.as_ref()?;
478 let mut bg = Array2::zeros((self.frame_rows, self.frame_cols));
479 for r in 0..self.frame_rows {
480 for c in 0..self.frame_cols {
481 if let Some(comp) = models[r][c].first() {
482 bg[[r, c]] = comp.mean;
483 }
484 }
485 }
486 Some(bg)
487 }
488
489 pub fn frame_count(&self) -> u64 {
491 self.frame_count
492 }
493}
494
495#[derive(Debug, Clone)]
506pub struct MedianBackground {
507 median: Option<Array2<f64>>,
509 step: f64,
511 config: BackgroundConfig,
513 frame_count: u64,
515}
516
517impl MedianBackground {
518 pub fn new(step: f64, config: BackgroundConfig) -> Result<Self> {
522 if step <= 0.0 {
523 return Err(VisionError::InvalidParameter(
524 "step must be positive".into(),
525 ));
526 }
527 Ok(Self {
528 median: None,
529 step,
530 config,
531 frame_count: 0,
532 })
533 }
534
535 pub fn default_config() -> Result<Self> {
537 Self::new(0.005, BackgroundConfig::default())
538 }
539
540 pub fn apply(&mut self, frame: &Array2<f64>) -> Result<Array2<ForegroundLabel>> {
542 let shape = frame.raw_dim();
543 match &mut self.median {
544 None => {
545 self.median = Some(frame.clone());
546 self.frame_count = 1;
547 Ok(Array2::from_elem(shape, ForegroundLabel::Background))
548 }
549 Some(med) => {
550 if med.raw_dim() != shape {
551 return Err(VisionError::DimensionMismatch(format!(
552 "Frame shape {:?} vs median {:?}",
553 shape,
554 med.raw_dim()
555 )));
556 }
557 self.frame_count += 1;
558 let step = self.step;
559 let threshold = self.config.fg_threshold;
560 let detect_shadows = self.config.detect_shadows;
561 let shadow_params = &self.config.shadow_params;
562
563 let rows = shape[0];
564 let cols = shape[1];
565 let mut mask = Array2::from_elem(shape, ForegroundLabel::Background);
566
567 for r in 0..rows {
568 for c in 0..cols {
569 let p = frame[[r, c]];
570 let m = med[[r, c]];
571 mask[[r, c]] =
572 classify_pixel(p, m, threshold, detect_shadows, shadow_params);
573
574 if p > m {
576 med[[r, c]] = (m + step).min(1.0);
577 } else if p < m {
578 med[[r, c]] = (m - step).max(0.0);
579 }
580 }
581 }
582 Ok(mask)
583 }
584 }
585 }
586
587 pub fn background(&self) -> Option<&Array2<f64>> {
589 self.median.as_ref()
590 }
591
592 pub fn frame_count(&self) -> u64 {
594 self.frame_count
595 }
596}
597
598pub fn adaptive_learning_rate(
612 mask: &Array2<ForegroundLabel>,
613 base_rate: f64,
614 min_rate: f64,
615) -> f64 {
616 let total = mask.len() as f64;
617 if total == 0.0 {
618 return base_rate;
619 }
620 let fg_count = mask
621 .iter()
622 .filter(|&&l| l == ForegroundLabel::Foreground)
623 .count() as f64;
624 let fg_fraction = fg_count / total;
625 let rate = base_rate * (1.0 - fg_fraction);
627 rate.max(min_rate)
628}
629
#[cfg(test)]
mod tests {
    // Unit tests for the three background models and the free helper functions.
    use super::*;
    use scirs2_core::ndarray::Array2;

    /// 8x8 frame with every pixel set to `val`.
    fn uniform_frame(val: f64) -> Array2<f64> {
        Array2::from_elem((8, 8), val)
    }

    /// 8x8 frame filled with `bg`, plus a `size` x `size` square of `fg` whose
    /// top-left corner is at (`top`, `left`). The square is clipped to the frame.
    fn frame_with_object(bg: f64, fg: f64, top: usize, left: usize, size: usize) -> Array2<f64> {
        let mut f = Array2::from_elem((8, 8), bg);
        for r in top..(top + size).min(8) {
            for c in left..(left + size).min(8) {
                f[[r, c]] = fg;
            }
        }
        f
    }

    // ---- RunningAverageBackground ----

    #[test]
    fn test_running_avg_first_frame_all_bg() {
        // The first frame only initialises the model, so nothing is foreground.
        let mut model =
            RunningAverageBackground::default_config().expect("default config should succeed");
        let frame = uniform_frame(0.5);
        let mask = model.apply(&frame).expect("apply should succeed");
        assert!(mask.iter().all(|&l| l == ForegroundLabel::Background));
        assert_eq!(model.frame_count(), 1);
    }

    #[test]
    fn test_running_avg_detects_foreground() {
        let mut model = RunningAverageBackground::new(BackgroundConfig {
            learning_rate: 0.01,
            fg_threshold: 0.1,
            ..Default::default()
        })
        .expect("config ok");
        let blank = uniform_frame(0.2);
        // Warm the model up on an empty scene.
        for _ in 0..10 {
            model.apply(&blank).expect("apply");
        }
        let obj = frame_with_object(0.2, 0.9, 2, 2, 3);
        // A 3x3 bright object covering rows/cols 2..5 differs by 0.7 > fg_threshold.
        let mask = model.apply(&obj).expect("apply");
        for r in 2..5 {
            for c in 2..5 {
                assert_eq!(mask[[r, c]], ForegroundLabel::Foreground);
            }
        }
        // A pixel outside the object stays background.
        assert_eq!(mask[[0, 0]], ForegroundLabel::Background);
    }

    #[test]
    fn test_running_avg_shape_mismatch() {
        let mut model = RunningAverageBackground::default_config().expect("ok");
        model.apply(&uniform_frame(0.5)).expect("first apply");
        let wrong = Array2::from_elem((4, 4), 0.5);
        let res = model.apply(&wrong);
        assert!(res.is_err());
    }

    #[test]
    fn test_running_avg_invalid_lr() {
        // Both ends outside (0, 1] must be rejected.
        let res = RunningAverageBackground::new(BackgroundConfig {
            learning_rate: 0.0,
            ..Default::default()
        });
        assert!(res.is_err());
        let res2 = RunningAverageBackground::new(BackgroundConfig {
            learning_rate: 1.5,
            ..Default::default()
        });
        assert!(res2.is_err());
    }

    #[test]
    fn test_running_avg_shadow_detection() {
        let mut model = RunningAverageBackground::new(BackgroundConfig {
            learning_rate: 0.01,
            fg_threshold: 0.05,
            detect_shadows: true,
            shadow_params: ShadowParams {
                tau_lo: 0.4,
                tau_hi: 0.9,
            },
        })
        .expect("ok");
        let bg_val = 0.8;
        let blank = uniform_frame(bg_val);
        for _ in 0..20 {
            model.apply(&blank).expect("ok");
        }
        // 0.55 / 0.8 = 0.6875 lies in [0.4, 0.9], and |0.55 - 0.8| = 0.25 >= 0.05,
        // so the darkened square should be labelled as shadow.
        let shadow_val = 0.55;
        let shadow_frame = frame_with_object(bg_val, shadow_val, 1, 1, 2);
        let mask = model.apply(&shadow_frame).expect("ok");
        for r in 1..3 {
            for c in 1..3 {
                assert_eq!(mask[[r, c]], ForegroundLabel::Shadow);
            }
        }
    }

    #[test]
    fn test_running_avg_background_converges() {
        // A constant input is a fixed point of the running average.
        let mut model = RunningAverageBackground::new(BackgroundConfig {
            learning_rate: 0.5,
            fg_threshold: 0.05,
            ..Default::default()
        })
        .expect("ok");
        let frame = uniform_frame(0.6);
        for _ in 0..50 {
            model.apply(&frame).expect("ok");
        }
        let bg = model.background().expect("should exist");
        for &val in bg.iter() {
            assert!((val - 0.6).abs() < 0.01, "bg should converge to 0.6");
        }
    }

    #[test]
    fn test_running_avg_set_lr() {
        let mut model = RunningAverageBackground::default_config().expect("ok");
        assert!(model.set_learning_rate(0.1).is_ok());
        assert!(model.set_learning_rate(0.0).is_err());
        assert!(model.set_learning_rate(1.5).is_err());
    }

    // ---- GmmBackground ----

    #[test]
    fn test_gmm_first_frame() {
        let mut model = GmmBackground::default_config().expect("ok");
        let frame = uniform_frame(0.5);
        let mask = model.apply(&frame).expect("apply");
        assert!(mask.iter().all(|&l| l == ForegroundLabel::Background));
    }

    #[test]
    fn test_gmm_detects_foreground() {
        let mut model = GmmBackground::new(
            3,
            BackgroundConfig {
                learning_rate: 0.1,
                fg_threshold: 0.1,
                ..Default::default()
            },
        )
        .expect("ok");
        let blank = uniform_frame(0.3);
        for _ in 0..15 {
            model.apply(&blank).expect("ok");
        }
        // 0.95 is many standard deviations from the learned 0.3 component.
        let obj = frame_with_object(0.3, 0.95, 3, 3, 2);
        let mask = model.apply(&obj).expect("ok");
        for r in 3..5 {
            for c in 3..5 {
                assert_eq!(mask[[r, c]], ForegroundLabel::Foreground);
            }
        }
    }

    #[test]
    fn test_gmm_shape_mismatch() {
        let mut model = GmmBackground::default_config().expect("ok");
        model.apply(&uniform_frame(0.5)).expect("first ok");
        let wrong = Array2::from_elem((4, 4), 0.5);
        assert!(model.apply(&wrong).is_err());
    }

    #[test]
    fn test_gmm_background_image() {
        let mut model = GmmBackground::default_config().expect("ok");
        let frame = uniform_frame(0.5);
        model.apply(&frame).expect("ok");
        let bg = model.background_image().expect("should have bg");
        assert_eq!(bg.nrows(), 8);
        assert_eq!(bg.ncols(), 8);
    }

    #[test]
    fn test_gmm_invalid_params() {
        // Zero components and a negative learning rate are both rejected.
        assert!(GmmBackground::new(0, BackgroundConfig::default()).is_err());
        assert!(GmmBackground::new(
            3,
            BackgroundConfig {
                learning_rate: -0.1,
                ..Default::default()
            }
        )
        .is_err());
    }

    #[test]
    fn test_gmm_shadow_mode() {
        let mut model = GmmBackground::new(
            3,
            BackgroundConfig {
                learning_rate: 0.1,
                fg_threshold: 0.05,
                detect_shadows: true,
                shadow_params: ShadowParams {
                    tau_lo: 0.4,
                    tau_hi: 0.9,
                },
            },
        )
        .expect("ok");
        let blank = uniform_frame(0.8);
        for _ in 0..20 {
            model.apply(&blank).expect("ok");
        }
        let shadow_frame = frame_with_object(0.8, 0.55, 0, 0, 2);
        let mask = model.apply(&shadow_frame).expect("ok");
        // Only checks that the darkened square is flagged at all; it does not
        // assert whether the label is Shadow or Foreground.
        let non_bg: usize = mask
            .iter()
            .filter(|&&l| l != ForegroundLabel::Background)
            .count();
        assert!(non_bg > 0, "Expected some non-background pixels");
    }

    // ---- MedianBackground ----

    #[test]
    fn test_median_first_frame() {
        let mut model = MedianBackground::default_config().expect("ok");
        let frame = uniform_frame(0.5);
        let mask = model.apply(&frame).expect("ok");
        assert!(mask.iter().all(|&l| l == ForegroundLabel::Background));
    }

    #[test]
    fn test_median_detects_foreground() {
        let mut model = MedianBackground::new(
            0.01,
            BackgroundConfig {
                fg_threshold: 0.1,
                ..Default::default()
            },
        )
        .expect("ok");
        let blank = uniform_frame(0.3);
        for _ in 0..20 {
            model.apply(&blank).expect("ok");
        }
        let obj = frame_with_object(0.3, 0.9, 1, 1, 3);
        let mask = model.apply(&obj).expect("ok");
        for r in 1..4 {
            for c in 1..4 {
                assert_eq!(mask[[r, c]], ForegroundLabel::Foreground);
            }
        }
    }

    #[test]
    fn test_median_shape_mismatch() {
        let mut model = MedianBackground::default_config().expect("ok");
        model.apply(&uniform_frame(0.5)).expect("ok");
        assert!(model.apply(&Array2::from_elem((4, 4), 0.5)).is_err());
    }

    #[test]
    fn test_median_invalid_step() {
        assert!(MedianBackground::new(0.0, BackgroundConfig::default()).is_err());
        assert!(MedianBackground::new(-1.0, BackgroundConfig::default()).is_err());
    }

    #[test]
    fn test_median_background_converges() {
        // The estimate is seeded from the first frame, so a constant input must stay put.
        let mut model = MedianBackground::new(
            0.05,
            BackgroundConfig {
                fg_threshold: 0.1,
                ..Default::default()
            },
        )
        .expect("ok");
        let frame = uniform_frame(0.7);
        for _ in 0..100 {
            model.apply(&frame).expect("ok");
        }
        let bg = model.background().expect("bg");
        for &v in bg.iter() {
            assert!(
                (v - 0.7).abs() < 0.06,
                "median should approach 0.7, got {v}"
            );
        }
    }

    // ---- Free helpers ----

    #[test]
    fn test_mask_to_binary() {
        // Only Foreground maps to 1.0; Shadow and Background both map to 0.0.
        let mut mask = Array2::from_elem((3, 3), ForegroundLabel::Background);
        mask[[0, 0]] = ForegroundLabel::Foreground;
        mask[[1, 1]] = ForegroundLabel::Shadow;
        let bin = mask_to_binary(&mask);
        assert!((bin[[0, 0]] - 1.0).abs() < 1e-12);
        assert!(bin[[1, 1]].abs() < 1e-12);
        assert!(bin[[2, 2]].abs() < 1e-12);
    }

    #[test]
    fn test_adaptive_lr_static_scene() {
        let mask = Array2::from_elem((4, 4), ForegroundLabel::Background);
        let rate = adaptive_learning_rate(&mask, 0.05, 0.001);
        assert!((rate - 0.05).abs() < 1e-9, "all bg => full rate");
    }

    #[test]
    fn test_adaptive_lr_high_motion() {
        // An all-foreground mask scales the rate to 0, which is then clamped to min_rate.
        let mask = Array2::from_elem((4, 4), ForegroundLabel::Foreground);
        let rate = adaptive_learning_rate(&mask, 0.05, 0.001);
        assert!(
            (rate - 0.001).abs() < 1e-9,
            "all fg => rate should be at min_rate"
        );
    }

    #[test]
    fn test_adaptive_lr_partial_fg() {
        // 4 of 16 pixels are foreground, i.e. 25%.
        let mut mask = Array2::from_elem((4, 4), ForegroundLabel::Background);
        mask[[0, 0]] = ForegroundLabel::Foreground;
        mask[[0, 1]] = ForegroundLabel::Foreground;
        mask[[1, 0]] = ForegroundLabel::Foreground;
        mask[[1, 1]] = ForegroundLabel::Foreground;
        let rate = adaptive_learning_rate(&mask, 0.05, 0.001);
        let expected = 0.05 * (1.0 - 0.25);
        assert!(
            (rate - expected).abs() < 1e-9,
            "25% fg => rate = {expected}"
        );
    }
}