1use crate::error::{Result, VisionError};
19use scirs2_core::ndarray::Array2;
20
/// Axis-aligned bounding box in continuous image coordinates.
///
/// `top`/`left` locate the corner with the smallest row/column; `height`
/// extends along the row axis and `width` along the column axis.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct BBox {
    /// Row coordinate of the top edge.
    pub top: f64,
    /// Column coordinate of the left edge.
    pub left: f64,
    /// Extent along the row axis.
    pub height: f64,
    /// Extent along the column axis.
    pub width: f64,
}

impl BBox {
    /// Builds a box from its top-left corner and its size.
    pub fn new(top: f64, left: f64, height: f64, width: f64) -> Self {
        BBox {
            top,
            left,
            height,
            width,
        }
    }

    /// Row coordinate of the box centre (`top + height / 2`).
    pub fn center_row(&self) -> f64 {
        self.top + self.height / 2.0
    }

    /// Column coordinate of the box centre (`left + width / 2`).
    pub fn center_col(&self) -> f64 {
        self.left + self.width / 2.0
    }

    /// Area of the box (`height * width`).
    pub fn area(&self) -> f64 {
        self.height * self.width
    }

    /// Intersection-over-union with `other`.
    ///
    /// Returns 0 when the overlap has no positive extent along either axis
    /// (boxes that merely touch count as disjoint) or when the union area is
    /// not positive.
    pub fn iou(&self, other: &BBox) -> f64 {
        let overlap_top = self.top.max(other.top);
        let overlap_bottom = (self.top + self.height).min(other.top + other.height);
        let overlap_left = self.left.max(other.left);
        let overlap_right = (self.left + self.width).min(other.left + other.width);

        let disjoint = overlap_bottom <= overlap_top || overlap_right <= overlap_left;
        if disjoint {
            return 0.0;
        }

        let intersection = (overlap_bottom - overlap_top) * (overlap_right - overlap_left);
        let union = self.area() + other.area() - intersection;
        match union > 0.0 {
            true => intersection / union,
            false => 0.0,
        }
    }

    /// Euclidean distance between the centres of `self` and `other`.
    pub fn center_distance(&self, other: &BBox) -> f64 {
        let d_row = self.center_row() - other.center_row();
        let d_col = self.center_col() - other.center_col();
        (d_row * d_row + d_col * d_col).sqrt()
    }
}
100
/// Tuning parameters shared by the mean-shift and CamShift trackers.
#[derive(Debug, Clone)]
pub struct MeanShiftConfig {
    /// Upper bound on mean-shift iterations per `update` call.
    pub max_iterations: usize,
    /// Convergence threshold: iteration stops once the window centre moves
    /// by less than this distance (in pixel units).
    pub epsilon: f64,
    /// Number of intensity histogram bins; the trackers reject 0.
    pub num_bins: usize,
}

impl Default for MeanShiftConfig {
    /// 30 iterations, a 1-pixel convergence threshold and 16 bins.
    fn default() -> Self {
        let max_iterations = 30;
        let epsilon = 1.0;
        let num_bins = 16;
        MeanShiftConfig {
            max_iterations,
            epsilon,
            num_bins,
        }
    }
}
125
126#[derive(Debug, Clone)]
131pub struct MeanShiftTracker {
132 window: BBox,
134 target_hist: Vec<f64>,
136 config: MeanShiftConfig,
138}
139
140impl MeanShiftTracker {
141 pub fn new(frame: &Array2<f64>, window: BBox, config: MeanShiftConfig) -> Result<Self> {
143 if config.num_bins == 0 {
144 return Err(VisionError::InvalidParameter("num_bins must be > 0".into()));
145 }
146 let hist = compute_histogram(frame, &window, config.num_bins);
147 Ok(Self {
148 window,
149 target_hist: hist,
150 config,
151 })
152 }
153
154 pub fn update(&mut self, frame: &Array2<f64>) -> Result<BBox> {
156 let rows = frame.nrows();
157 let cols = frame.ncols();
158 let bins = self.config.num_bins;
159
160 for _ in 0..self.config.max_iterations {
161 let (mean_r, mean_c, _total_w) =
163 compute_mean_shift_center(frame, &self.window, &self.target_hist, bins, rows, cols);
164
165 let old_cr = self.window.center_row();
166 let old_cc = self.window.center_col();
167
168 let new_top = (mean_r - self.window.height / 2.0).max(0.0);
169 let new_left = (mean_c - self.window.width / 2.0).max(0.0);
170
171 self.window.top = new_top.min((rows as f64) - self.window.height);
172 self.window.left = new_left.min((cols as f64) - self.window.width);
173
174 let dr = self.window.center_row() - old_cr;
175 let dc = self.window.center_col() - old_cc;
176 if (dr * dr + dc * dc).sqrt() < self.config.epsilon {
177 break;
178 }
179 }
180
181 Ok(self.window)
182 }
183
184 pub fn window(&self) -> BBox {
186 self.window
187 }
188}
189
190#[derive(Debug, Clone)]
199pub struct CamShiftTracker {
200 window: BBox,
202 target_hist: Vec<f64>,
204 config: MeanShiftConfig,
206 angle: f64,
208}
209
210impl CamShiftTracker {
211 pub fn new(frame: &Array2<f64>, window: BBox, config: MeanShiftConfig) -> Result<Self> {
213 if config.num_bins == 0 {
214 return Err(VisionError::InvalidParameter("num_bins must be > 0".into()));
215 }
216 let hist = compute_histogram(frame, &window, config.num_bins);
217 Ok(Self {
218 window,
219 target_hist: hist,
220 config,
221 angle: 0.0,
222 })
223 }
224
225 pub fn update(&mut self, frame: &Array2<f64>) -> Result<(BBox, f64)> {
228 let rows = frame.nrows();
229 let cols = frame.ncols();
230 let bins = self.config.num_bins;
231
232 for _ in 0..self.config.max_iterations {
234 let (mean_r, mean_c, _) =
235 compute_mean_shift_center(frame, &self.window, &self.target_hist, bins, rows, cols);
236
237 let old_cr = self.window.center_row();
238 let old_cc = self.window.center_col();
239
240 self.window.top = (mean_r - self.window.height / 2.0)
241 .max(0.0)
242 .min((rows as f64) - self.window.height);
243 self.window.left = (mean_c - self.window.width / 2.0)
244 .max(0.0)
245 .min((cols as f64) - self.window.width);
246
247 let dr = self.window.center_row() - old_cr;
248 let dc = self.window.center_col() - old_cc;
249 if (dr * dr + dc * dc).sqrt() < self.config.epsilon {
250 break;
251 }
252 }
253
254 let (m00, m10, m01, m20, m02, m11) =
256 compute_moments(frame, &self.window, &self.target_hist, bins, rows, cols);
257
258 if m00 > 1e-9 {
259 let xc = m10 / m00;
260 let yc = m01 / m00;
261
262 let mu20 = m20 / m00 - xc * xc;
264 let mu02 = m02 / m00 - yc * yc;
265 let mu11 = m11 / m00 - xc * yc;
266
267 self.angle = 0.5 * (2.0 * mu11).atan2(mu20 - mu02);
269
270 let s = (m00 / 256.0).sqrt().max(2.0);
272 self.window.width = s * 2.0;
273 self.window.height = s * 2.0;
274
275 self.window.top = (yc - self.window.height / 2.0)
277 .max(0.0)
278 .min((rows as f64) - self.window.height);
279 self.window.left = (xc - self.window.width / 2.0)
280 .max(0.0)
281 .min((cols as f64) - self.window.width);
282 }
283
284 Ok((self.window, self.angle))
285 }
286
287 pub fn window(&self) -> BBox {
289 self.window
290 }
291
292 pub fn angle(&self) -> f64 {
294 self.angle
295 }
296}
297
298fn compute_histogram(frame: &Array2<f64>, bbox: &BBox, bins: usize) -> Vec<f64> {
303 let mut hist = vec![0.0; bins];
304 let rows = frame.nrows();
305 let cols = frame.ncols();
306 let r_start = (bbox.top as usize).min(rows);
307 let r_end = ((bbox.top + bbox.height) as usize).min(rows);
308 let c_start = (bbox.left as usize).min(cols);
309 let c_end = ((bbox.left + bbox.width) as usize).min(cols);
310 let mut total = 0.0;
311 for r in r_start..r_end {
312 for c in c_start..c_end {
313 let val = frame[[r, c]].clamp(0.0, 1.0);
314 let bin = ((val * (bins as f64 - 1.0)).round() as usize).min(bins - 1);
315 hist[bin] += 1.0;
316 total += 1.0;
317 }
318 }
319 if total > 0.0 {
320 for h in hist.iter_mut() {
321 *h /= total;
322 }
323 }
324 hist
325}
326
327fn compute_mean_shift_center(
328 frame: &Array2<f64>,
329 bbox: &BBox,
330 target_hist: &[f64],
331 bins: usize,
332 rows: usize,
333 cols: usize,
334) -> (f64, f64, f64) {
335 let candidate_hist = compute_histogram(frame, bbox, bins);
337
338 let weights: Vec<f64> = (0..bins)
340 .map(|i| {
341 if candidate_hist[i] > 1e-12 {
342 (target_hist[i] / candidate_hist[i]).sqrt().min(10.0)
343 } else {
344 0.0
345 }
346 })
347 .collect();
348
349 let r_start = (bbox.top as usize).min(rows);
350 let r_end = ((bbox.top + bbox.height) as usize).min(rows);
351 let c_start = (bbox.left as usize).min(cols);
352 let c_end = ((bbox.left + bbox.width) as usize).min(cols);
353
354 let mut sum_r = 0.0;
355 let mut sum_c = 0.0;
356 let mut sum_w = 0.0;
357
358 for r in r_start..r_end {
359 for c in c_start..c_end {
360 let val = frame[[r, c]].clamp(0.0, 1.0);
361 let bin = ((val * (bins as f64 - 1.0)).round() as usize).min(bins - 1);
362 let w = weights[bin];
363 sum_r += r as f64 * w;
364 sum_c += c as f64 * w;
365 sum_w += w;
366 }
367 }
368
369 if sum_w > 0.0 {
370 (sum_r / sum_w, sum_c / sum_w, sum_w)
371 } else {
372 (bbox.center_row(), bbox.center_col(), 0.0)
373 }
374}
375
376fn compute_moments(
377 frame: &Array2<f64>,
378 bbox: &BBox,
379 target_hist: &[f64],
380 bins: usize,
381 rows: usize,
382 cols: usize,
383) -> (f64, f64, f64, f64, f64, f64) {
384 let candidate_hist = compute_histogram(frame, bbox, bins);
385 let weights: Vec<f64> = (0..bins)
386 .map(|i| {
387 if candidate_hist[i] > 1e-12 {
388 (target_hist[i] / candidate_hist[i]).sqrt().min(10.0)
389 } else {
390 0.0
391 }
392 })
393 .collect();
394
395 let r_start = (bbox.top as usize).min(rows);
396 let r_end = ((bbox.top + bbox.height) as usize).min(rows);
397 let c_start = (bbox.left as usize).min(cols);
398 let c_end = ((bbox.left + bbox.width) as usize).min(cols);
399
400 let mut m00 = 0.0;
401 let mut m10 = 0.0; let mut m01 = 0.0; let mut m20 = 0.0;
404 let mut m02 = 0.0;
405 let mut m11 = 0.0;
406
407 for r in r_start..r_end {
408 for c in c_start..c_end {
409 let val = frame[[r, c]].clamp(0.0, 1.0);
410 let bin = ((val * (bins as f64 - 1.0)).round() as usize).min(bins - 1);
411 let w = weights[bin];
412 let x = c as f64;
413 let y = r as f64;
414 m00 += w;
415 m10 += x * w;
416 m01 += y * w;
417 m20 += x * x * w;
418 m02 += y * y * w;
419 m11 += x * y * w;
420 }
421 }
422
423 (m00, m10, m01, m20, m02, m11)
424}
425
/// Motion model selecting the state layout and transition matrix of a `KalmanTracker`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum KalmanModel {
    /// 4-D state `[x, y, vx, vy]`; position advances by velocity each step.
    ConstantVelocity,
    /// 6-D state `[x, y, vx, vy, ax, ay]`; velocity advances by acceleration
    /// and position by `v + a / 2` each step.
    ConstantAcceleration,
}
438
439#[derive(Debug, Clone)]
441pub struct KalmanTracker {
442 state: Vec<f64>,
444 cov: Vec<f64>,
446 dim: usize,
448 process_noise: f64,
450 measurement_noise: f64,
452 model: KalmanModel,
454}
455
456impl KalmanTracker {
457 pub fn new(x: f64, y: f64, model: KalmanModel) -> Self {
459 let dim = match model {
460 KalmanModel::ConstantVelocity => 4,
461 KalmanModel::ConstantAcceleration => 6,
462 };
463 let mut state = vec![0.0; dim];
464 state[0] = x;
465 state[1] = y;
466 let mut cov = vec![0.0; dim * dim];
468 for i in 0..dim {
469 cov[i * dim + i] = 1000.0;
470 }
471 Self {
472 state,
473 cov,
474 dim,
475 process_noise: 1.0,
476 measurement_noise: 1.0,
477 model,
478 }
479 }
480
481 pub fn set_noise(&mut self, process: f64, measurement: f64) {
483 self.process_noise = process;
484 self.measurement_noise = measurement;
485 }
486
487 pub fn predict(&mut self) {
489 let n = self.dim;
490 let f = self.transition_matrix();
492 let new_state = mat_vec_mul(&f, &self.state, n);
493 self.state = new_state;
494
495 let fp = mat_mat_mul(&f, &self.cov, n);
497 let ft = transpose(&f, n);
498 let fp_ft = mat_mat_mul(&fp, &ft, n);
499 let q = self.process_noise_matrix();
500 self.cov = mat_add(&fp_ft, &q, n);
501 }
502
503 pub fn update(&mut self, mx: f64, my: f64) {
505 let n = self.dim;
506 let h = self.observation_matrix();
507 let m = 2; let hx = mat_vec_mul_rect(&h, &self.state, m, n);
511 let z = [mx, my];
512 let innovation = vec![z[0] - hx[0], z[1] - hx[1]];
513
514 let hp = mat_mat_mul_rect(&h, &self.cov, m, n, n);
516 let ht = transpose_rect(&h, m, n);
517 let hp_ht = mat_mat_mul_rect(&hp, &ht, m, n, m);
518 let r = self.measurement_noise_matrix();
519 let s = mat_add_small(&hp_ht, &r, m);
520
521 let p_ht = mat_mat_mul_rect(&self.cov, &ht, n, n, m);
523 let s_inv = invert_2x2(&s);
524 let k = mat_mat_mul_rect(&p_ht, &s_inv, n, m, m);
525
526 let ky = mat_vec_mul_rect(&k, &innovation, n, m);
528 for (i, &ky_i) in ky.iter().enumerate().take(n) {
529 self.state[i] += ky_i;
530 }
531
532 let kh = mat_mat_mul_rect(&k, &h, n, m, n);
534 let mut eye = vec![0.0; n * n];
535 for i in 0..n {
536 eye[i * n + i] = 1.0;
537 }
538 let i_kh = mat_sub(&eye, &kh, n);
539 self.cov = mat_mat_mul(&i_kh, &self.cov, n);
540 }
541
542 pub fn position(&self) -> (f64, f64) {
544 (self.state[0], self.state[1])
545 }
546
547 pub fn velocity(&self) -> (f64, f64) {
549 if self.dim >= 4 {
550 (self.state[2], self.state[3])
551 } else {
552 (0.0, 0.0)
553 }
554 }
555
556 pub fn predicted_position(&self) -> (f64, f64) {
558 let f = self.transition_matrix();
559 let pred = mat_vec_mul(&f, &self.state, self.dim);
560 (pred[0], pred[1])
561 }
562
563 fn transition_matrix(&self) -> Vec<f64> {
566 let n = self.dim;
567 let mut f = vec![0.0; n * n];
568 for i in 0..n {
569 f[i * n + i] = 1.0;
570 }
571 match self.model {
572 KalmanModel::ConstantVelocity => {
573 f[2] = 1.0;
575 f[n + 3] = 1.0;
576 }
577 KalmanModel::ConstantAcceleration => {
578 f[2] = 1.0;
580 f[n + 3] = 1.0;
581 f[2 * n + 4] = 1.0;
582 f[3 * n + 5] = 1.0;
583 f[4] = 0.5;
585 f[n + 5] = 0.5;
586 }
587 }
588 f
589 }
590
591 fn observation_matrix(&self) -> Vec<f64> {
592 let n = self.dim;
593 let m = 2;
594 let mut h = vec![0.0; m * n];
595 h[0] = 1.0; h[n + 1] = 1.0; h
598 }
599
600 fn process_noise_matrix(&self) -> Vec<f64> {
601 let n = self.dim;
602 let mut q = vec![0.0; n * n];
603 for i in 0..n {
604 q[i * n + i] = self.process_noise;
605 }
606 q
607 }
608
609 fn measurement_noise_matrix(&self) -> Vec<f64> {
610 vec![self.measurement_noise, 0.0, 0.0, self.measurement_noise]
611 }
612}
613
/// Multiplies the row-major `n × n` matrix `mat` by the length-`n` vector `vec_in`.
fn mat_vec_mul(mat: &[f64], vec_in: &[f64], n: usize) -> Vec<f64> {
    (0..n)
        .map(|i| {
            let row = &mat[i * n..(i + 1) * n];
            row.iter()
                .zip(&vec_in[..n])
                .fold(0.0, |acc, (coeff, x)| acc + coeff * x)
        })
        .collect()
}
627
/// Multiplies the row-major `m × n` matrix `mat` by the length-`n` vector `vec_in`,
/// returning a length-`m` vector.
fn mat_vec_mul_rect(mat: &[f64], vec_in: &[f64], m: usize, n: usize) -> Vec<f64> {
    let mut result = Vec::with_capacity(m);
    for i in 0..m {
        let row = &mat[i * n..(i + 1) * n];
        let dot = row
            .iter()
            .zip(&vec_in[..n])
            .fold(0.0, |acc, (coeff, x)| acc + coeff * x);
        result.push(dot);
    }
    result
}
637
/// Product `a · b` of two row-major `n × n` matrices.
fn mat_mat_mul(a: &[f64], b: &[f64], n: usize) -> Vec<f64> {
    let mut product = Vec::with_capacity(n * n);
    for i in 0..n {
        for j in 0..n {
            // Accumulate over the inner index in ascending order.
            let entry = (0..n).fold(0.0, |acc, k| acc + a[i * n + k] * b[k * n + j]);
            product.push(entry);
        }
    }
    product
}
650
/// Product of a row-major `m × k` matrix `a` and a row-major `k × n` matrix `b`
/// (result is `m × n`, row-major).
fn mat_mat_mul_rect(a: &[f64], b: &[f64], m: usize, k: usize, n: usize) -> Vec<f64> {
    let mut product = Vec::with_capacity(m * n);
    for i in 0..m {
        let a_row = &a[i * k..(i + 1) * k];
        for j in 0..n {
            let entry = a_row
                .iter()
                .enumerate()
                .fold(0.0, |acc, (p, &a_ip)| acc + a_ip * b[p * n + j]);
            product.push(entry);
        }
    }
    product
}
663
/// Transpose of the row-major `n × n` matrix `a`.
fn transpose(a: &[f64], n: usize) -> Vec<f64> {
    // Output index idx = j * n + i holds a[i][j].
    (0..n * n).map(|idx| a[(idx % n) * n + idx / n]).collect()
}
673
/// Transpose of the row-major `m × n` matrix `a` (result is `n × m`, row-major).
fn transpose_rect(a: &[f64], m: usize, n: usize) -> Vec<f64> {
    let mut transposed = Vec::with_capacity(n * m);
    // Emit output rows (= input columns) in order.
    for col in 0..n {
        for row in 0..m {
            transposed.push(a[row * n + col]);
        }
    }
    transposed
}
683
/// Element-wise sum of two row-major `n × n` matrices.
fn mat_add(a: &[f64], b: &[f64], n: usize) -> Vec<f64> {
    let len = n * n;
    a[..len].iter().zip(&b[..len]).map(|(x, y)| x + y).collect()
}
691
/// Element-wise sum of two row-major `n × n` matrices.
///
/// The computation is the same as `mat_add`; this name is used for the small
/// (`2 × 2`) innovation covariance in the Kalman update.
fn mat_add_small(a: &[f64], b: &[f64], n: usize) -> Vec<f64> {
    (0..n * n).map(|i| a[i] + b[i]).collect()
}
700
/// Element-wise difference `a − b` of two row-major `n × n` matrices.
fn mat_sub(a: &[f64], b: &[f64], n: usize) -> Vec<f64> {
    let len = n * n;
    a[..len].iter().zip(&b[..len]).map(|(x, y)| x - y).collect()
}
708
/// Inverse of the row-major 2 × 2 matrix `m` via the adjugate formula.
///
/// If `|det| < 1e-30` the matrix is treated as singular and the identity is
/// returned instead of dividing by a near-zero determinant.
fn invert_2x2(m: &[f64]) -> Vec<f64> {
    let (a, b, c, d) = (m[0], m[1], m[2], m[3]);
    let det = a * d - b * c;
    if det.abs() < 1e-30 {
        return vec![1.0, 0.0, 0.0, 1.0];
    }
    let scale = 1.0 / det;
    vec![d * scale, -b * scale, -c * scale, a * scale]
}
723
/// Lifecycle state of a track in the multi-object tracker.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum TrackStatus {
    /// Newly created and not yet matched `min_hits_to_confirm` times.
    Tentative,
    /// Matched at least `min_hits_to_confirm` times.
    Confirmed,
    /// Consecutive misses exceeded `max_misses`; removed at the end of the update.
    Lost,
}
738
/// A single track maintained by the multi-object tracker.
#[derive(Debug, Clone)]
pub struct TrackedObject {
    /// Track identifier, assigned sequentially starting at 1.
    pub id: u64,
    /// Box of the last matched detection, or the Kalman-predicted box
    /// (same size, re-centred) while unmatched.
    pub bbox: BBox,
    /// Kalman filter over the box centre, with x = column centre and y = row centre.
    pub kalman: KalmanTracker,
    /// Number of detections matched to this track, including the one that created it.
    pub hits: usize,
    /// Consecutive updates without a matching detection (reset to 0 on a match).
    pub misses: usize,
    /// Number of updates this track has existed for (starts at 1).
    pub age: usize,
    /// Current lifecycle state.
    pub status: TrackStatus,
}
757
758#[derive(Debug, Clone)]
760pub struct MultiTrackerConfig {
761 pub iou_threshold: f64,
763 pub min_hits_to_confirm: usize,
765 pub max_misses: usize,
767 pub kalman_model: KalmanModel,
769 pub process_noise: f64,
771 pub measurement_noise: f64,
773}
774
775impl Default for MultiTrackerConfig {
776 fn default() -> Self {
777 Self {
778 iou_threshold: 0.3,
779 min_hits_to_confirm: 3,
780 max_misses: 5,
781 kalman_model: KalmanModel::ConstantVelocity,
782 process_noise: 1.0,
783 measurement_noise: 1.0,
784 }
785 }
786}
787
788#[derive(Debug, Clone)]
790pub struct MultiObjectTracker {
791 tracks: Vec<TrackedObject>,
793 next_id: u64,
795 config: MultiTrackerConfig,
797}
798
799impl MultiObjectTracker {
800 pub fn new(config: MultiTrackerConfig) -> Self {
802 Self {
803 tracks: Vec::new(),
804 next_id: 1,
805 config,
806 }
807 }
808
809 pub fn update(&mut self, detections: &[BBox]) -> Vec<TrackedObject> {
813 for track in self.tracks.iter_mut() {
815 track.kalman.predict();
816 let (px, py) = track.kalman.position();
817 track.bbox.left = px - track.bbox.width / 2.0;
818 track.bbox.top = py - track.bbox.height / 2.0;
819 }
820
821 let n_tracks = self.tracks.len();
823 let n_dets = detections.len();
824
825 let (matched, unmatched_tracks, unmatched_dets) = if n_tracks > 0 && n_dets > 0 {
826 let mut cost = vec![vec![0.0; n_dets]; n_tracks];
827 for (i, cost_row) in cost.iter_mut().enumerate().take(n_tracks) {
828 for (j, cost_val) in cost_row.iter_mut().enumerate().take(n_dets) {
829 *cost_val = 1.0 - self.tracks[i].bbox.iou(&detections[j]);
830 }
831 }
832 hungarian_assignment(&cost, self.config.iou_threshold)
833 } else {
834 (Vec::new(), (0..n_tracks).collect(), (0..n_dets).collect())
835 };
836
837 for &(ti, di) in &matched {
839 let det = &detections[di];
840 let track = &mut self.tracks[ti];
841 let cx = det.left + det.width / 2.0;
842 let cy = det.top + det.height / 2.0;
843 track.kalman.update(cx, cy);
844 track.bbox = *det;
845 track.hits += 1;
846 track.misses = 0;
847 track.age += 1;
848 if track.hits >= self.config.min_hits_to_confirm {
849 track.status = TrackStatus::Confirmed;
850 }
851 }
852
853 for &ti in &unmatched_tracks {
855 self.tracks[ti].misses += 1;
856 self.tracks[ti].age += 1;
857 if self.tracks[ti].misses > self.config.max_misses {
858 self.tracks[ti].status = TrackStatus::Lost;
859 }
860 }
861
862 for &di in &unmatched_dets {
864 let det = &detections[di];
865 let cx = det.left + det.width / 2.0;
866 let cy = det.top + det.height / 2.0;
867 let mut kf = KalmanTracker::new(cx, cy, self.config.kalman_model);
868 kf.set_noise(self.config.process_noise, self.config.measurement_noise);
869 let initial_status = if 1 >= self.config.min_hits_to_confirm {
870 TrackStatus::Confirmed
871 } else {
872 TrackStatus::Tentative
873 };
874 let track = TrackedObject {
875 id: self.next_id,
876 bbox: *det,
877 kalman: kf,
878 hits: 1,
879 misses: 0,
880 age: 1,
881 status: initial_status,
882 };
883 self.tracks.push(track);
884 self.next_id += 1;
885 }
886
887 self.tracks.retain(|t| t.status != TrackStatus::Lost);
889
890 self.tracks
892 .iter()
893 .filter(|t| t.status == TrackStatus::Confirmed)
894 .cloned()
895 .collect()
896 }
897
898 pub fn all_tracks(&self) -> &[TrackedObject] {
900 &self.tracks
901 }
902
903 pub fn num_tracks(&self) -> usize {
905 self.tracks.len()
906 }
907}
908
/// Matches tracks to detections from a `tracks × detections` cost matrix.
///
/// Despite the name, the assignment is greedy, not an optimal Hungarian
/// solution. Pairs are taken in increasing cost order, and each track and
/// each detection is used at most once. A selected pair becomes a match only
/// if `cost <= 1 − iou_threshold`, so with `cost = 1 − IoU` a match requires
/// `IoU >= iou_threshold`. NaN costs never match.
///
/// The detection count is taken from the first row. Returns
/// `(matched (track, detection) pairs in selection order, unmatched track
/// indices ascending, unmatched detection indices ascending)`.
fn hungarian_assignment(
    cost: &[Vec<f64>],
    iou_threshold: f64,
) -> (Vec<(usize, usize)>, Vec<usize>, Vec<usize>) {
    let n_tracks = cost.len();
    let n_dets = cost.first().map_or(0, Vec::len);
    let cost_threshold = 1.0 - iou_threshold;

    let mut track_matched = vec![false; n_tracks];
    let mut det_matched = vec![false; n_dets];
    let mut matched = Vec::new();

    for (ti, di) in munkres_assign(cost, n_tracks, n_dets) {
        // `NaN <= x` is false, so NaN costs are rejected here.
        if cost[ti][di] <= cost_threshold {
            matched.push((ti, di));
            track_matched[ti] = true;
            det_matched[di] = true;
        }
    }

    let unmatched_tracks: Vec<usize> = (0..n_tracks).filter(|&i| !track_matched[i]).collect();
    let unmatched_dets: Vec<usize> = (0..n_dets).filter(|&j| !det_matched[j]).collect();

    (matched, unmatched_tracks, unmatched_dets)
}

/// Greedy approximation to the linear assignment problem on an `n × m` cost matrix.
///
/// All entries are sorted by cost (stable, so ties keep row-major order), and
/// each entry whose row and column are both unused is taken. The result has
/// `min(n, m)` pairs in selection order. NaN costs are ordered after every
/// number, so they are only chosen for otherwise-unassignable rows and columns.
fn munkres_assign(cost: &[Vec<f64>], n: usize, m: usize) -> Vec<(usize, usize)> {
    if n == 0 || m == 0 {
        return Vec::new();
    }

    let mut entries: Vec<(f64, usize, usize)> = Vec::with_capacity(n * m);
    for (i, row) in cost.iter().enumerate().take(n) {
        for (j, &c) in row.iter().enumerate().take(m) {
            entries.push((c, i, j));
        }
    }
    // The comparator must be a total order. `partial_cmp(..).unwrap_or(Equal)`
    // is not one when NaNs are present: it could put NaN pairs first and
    // starve valid matches, and `sort_by` may panic on such comparators.
    entries.sort_by(|a, b| cmp_nan_last(a.0, b.0));

    let mut row_used = vec![false; n];
    let mut col_used = vec![false; m];
    let mut assignments = Vec::new();
    for (_, row, col) in entries {
        if !row_used[row] && !col_used[col] {
            assignments.push((row, col));
            row_used[row] = true;
            col_used[col] = true;
        }
    }

    assignments
}

/// Total order on costs: numbers compare normally (`-0.0 == 0.0`), and NaN of
/// either sign compares equal to other NaNs and greater than every number.
fn cmp_nan_last(a: f64, b: f64) -> std::cmp::Ordering {
    match (a.is_nan(), b.is_nan()) {
        (false, false) => a.partial_cmp(&b).unwrap_or(std::cmp::Ordering::Equal),
        (a_nan, b_nan) => a_nan.cmp(&b_nan),
    }
}
980
// Unit tests for the bounding-box utilities, the mean-shift/CamShift
// trackers, the Kalman filter, the multi-object tracker and the
// assignment helper defined in this module.
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array2;

    // Frame of shape (h, w) filled with `val`.
    fn uniform_frame(val: f64, h: usize, w: usize) -> Array2<f64> {
        Array2::from_elem((h, w), val)
    }

    // Frame of shape (h, w) filled with `bg`, plus an `rh × rw` rectangle of
    // `fg` whose top-left pixel is (top, left); the rectangle is clipped to
    // the frame.
    fn frame_with_bright_region(
        bg: f64,
        fg: f64,
        h: usize,
        w: usize,
        top: usize,
        left: usize,
        rh: usize,
        rw: usize,
    ) -> Array2<f64> {
        let mut f = Array2::from_elem((h, w), bg);
        for r in top..(top + rh).min(h) {
            for c in left..(left + rw).min(w) {
                f[[r, c]] = fg;
            }
        }
        f
    }

    // Centre and area accessors of BBox.
    #[test]
    fn test_bbox_basics() {
        let b = BBox::new(10.0, 20.0, 30.0, 40.0);
        assert!((b.center_row() - 25.0).abs() < 1e-9);
        assert!((b.center_col() - 40.0).abs() < 1e-9);
        assert!((b.area() - 1200.0).abs() < 1e-9);
    }

    // A box has IoU 1 with itself.
    #[test]
    fn test_bbox_iou_identical() {
        let b = BBox::new(0.0, 0.0, 10.0, 10.0);
        assert!((b.iou(&b) - 1.0).abs() < 1e-9);
    }

    // Disjoint boxes have IoU 0.
    #[test]
    fn test_bbox_iou_no_overlap() {
        let a = BBox::new(0.0, 0.0, 10.0, 10.0);
        let b = BBox::new(20.0, 20.0, 10.0, 10.0);
        assert!(a.iou(&b).abs() < 1e-9);
    }

    // Centres (5, 5) and (8, 9) form a 3-4-5 right triangle.
    #[test]
    fn test_bbox_center_distance() {
        let a = BBox::new(0.0, 0.0, 10.0, 10.0);
        let b = BBox::new(3.0, 4.0, 10.0, 10.0);
        let d = a.center_distance(&b);
        assert!((d - 5.0).abs() < 1e-9);
    }

    // Window placed exactly on a static 10×10 bright square (centre (15, 15));
    // after one update the centre must stay within 5 px of it.
    #[test]
    fn test_mean_shift_static_target() {
        let frame = frame_with_bright_region(0.0, 1.0, 32, 32, 10, 10, 10, 10);
        let window = BBox::new(10.0, 10.0, 10.0, 10.0);
        let mut tracker =
            MeanShiftTracker::new(&frame, window, MeanShiftConfig::default()).expect("ok");
        let result = tracker.update(&frame).expect("ok");
        assert!(
            (result.center_row() - 15.0).abs() < 5.0,
            "centre row should be near 15"
        );
        assert!(
            (result.center_col() - 15.0).abs() < 5.0,
            "centre col should be near 15"
        );
    }

    // Target model from an 8×8 square at (5, 5); in the next frame the square
    // sits at (8, 8). Only checks that the tracked centre row ends above 7.
    #[test]
    fn test_mean_shift_moving_target() {
        let h = 32;
        let w = 32;
        let init = frame_with_bright_region(0.0, 1.0, h, w, 5, 5, 8, 8);
        let window = BBox::new(5.0, 5.0, 8.0, 8.0);
        let mut tracker =
            MeanShiftTracker::new(&init, window, MeanShiftConfig::default()).expect("ok");

        let moved = frame_with_bright_region(0.0, 1.0, h, w, 8, 8, 8, 8);
        let result = tracker.update(&moved).expect("ok");
        assert!(
            result.center_row() > 7.0,
            "Should track toward new position"
        );
    }

    // A configuration with zero histogram bins is rejected by the constructor.
    #[test]
    fn test_mean_shift_invalid_bins() {
        let frame = uniform_frame(0.5, 16, 16);
        let window = BBox::new(0.0, 0.0, 8.0, 8.0);
        let config = MeanShiftConfig {
            num_bins: 0,
            ..Default::default()
        };
        assert!(MeanShiftTracker::new(&frame, window, config).is_err());
    }

    // On a static target CamShift must return a window with positive area
    // and a finite orientation.
    #[test]
    fn test_camshift_static_target() {
        let frame = frame_with_bright_region(0.0, 1.0, 32, 32, 10, 10, 10, 10);
        let window = BBox::new(10.0, 10.0, 10.0, 10.0);
        let mut tracker =
            CamShiftTracker::new(&frame, window, MeanShiftConfig::default()).expect("ok");
        let (result, angle) = tracker.update(&frame).expect("ok");
        assert!(result.area() > 0.0);
        assert!(angle.is_finite());
    }

    // A 5×5 window inside a 20×20 bright square: after one update the adapted
    // window's area must exceed half of the initial 25 px².
    #[test]
    fn test_camshift_adapts_size() {
        let h = 64;
        let w = 64;
        let frame = frame_with_bright_region(0.0, 1.0, h, w, 10, 10, 20, 20);
        let small_window = BBox::new(12.0, 12.0, 5.0, 5.0);
        let mut tracker =
            CamShiftTracker::new(&frame, small_window, MeanShiftConfig::default()).expect("ok");
        let (result, _) = tracker.update(&frame).expect("ok");
        assert!(result.area() > small_window.area() * 0.5);
    }

    // Noise-free measurements along x = 2t, y = t for t = 1..=10; the
    // filtered position must end within 2 of (20, 10).
    #[test]
    fn test_kalman_constant_velocity() {
        let mut kf = KalmanTracker::new(0.0, 0.0, KalmanModel::ConstantVelocity);
        kf.set_noise(0.1, 1.0);
        for t in 1..=10 {
            let mx = t as f64 * 2.0;
            let my = t as f64 * 1.0;
            kf.predict();
            kf.update(mx, my);
        }
        let (x, y) = kf.position();
        assert!((x - 20.0).abs() < 2.0, "x should be near 20, got {x}");
        assert!((y - 10.0).abs() < 2.0, "y should be near 10, got {y}");
    }

    // Measurements x = t²/2 (unit acceleration), y = 0 for t = 1..=10; the
    // filtered x must end within 10 of 50.
    #[test]
    fn test_kalman_constant_acceleration() {
        let mut kf = KalmanTracker::new(0.0, 0.0, KalmanModel::ConstantAcceleration);
        kf.set_noise(0.1, 1.0);
        for t in 1..=10 {
            let mx = 0.5 * (t as f64) * (t as f64);
            let my = 0.0;
            kf.predict();
            kf.update(mx, my);
        }
        let (x, _y) = kf.position();
        assert!(
            (x - 50.0).abs() < 10.0,
            "x should be near 50 (0.5*10^2), got {x}"
        );
    }

    // Five measurements moving +1 in x and +0.5 in y per step (last one is
    // (15, 7.5)); the one-step prediction must lie beyond x = 15 and y = 7.
    #[test]
    fn test_kalman_prediction() {
        let mut kf = KalmanTracker::new(10.0, 5.0, KalmanModel::ConstantVelocity);
        kf.set_noise(0.01, 0.1);
        for t in 1..=5 {
            kf.predict();
            kf.update(10.0 + t as f64, 5.0 + t as f64 * 0.5);
        }
        let (px, py) = kf.predicted_position();
        assert!(px > 15.0, "predicted x should be >15, got {px}");
        assert!(py > 7.0, "predicted y should be >7, got {py}");
    }

    // With min_hits_to_confirm = 2 a track is not reported after its first
    // detection and is reported (with id 1) after the second.
    #[test]
    fn test_mot_single_object() {
        let mut tracker = MultiObjectTracker::new(MultiTrackerConfig {
            min_hits_to_confirm: 2,
            ..Default::default()
        });
        let det = BBox::new(10.0, 20.0, 30.0, 40.0);
        let confirmed = tracker.update(&[det]);
        assert!(confirmed.is_empty(), "Should be tentative after 1 frame");
        let confirmed = tracker.update(&[det]);
        assert_eq!(confirmed.len(), 2 - 1, "Should be confirmed after 2 frames");
        assert_eq!(confirmed[0].id, 1);
    }

    // Two far-apart detections create two confirmed tracks with distinct ids
    // when a single hit suffices for confirmation.
    #[test]
    fn test_mot_two_objects() {
        let mut tracker = MultiObjectTracker::new(MultiTrackerConfig {
            min_hits_to_confirm: 1,
            ..Default::default()
        });
        let d1 = BBox::new(10.0, 10.0, 20.0, 20.0);
        let d2 = BBox::new(100.0, 100.0, 20.0, 20.0);
        let confirmed = tracker.update(&[d1, d2]);
        assert_eq!(confirmed.len(), 2);
        assert_ne!(confirmed[0].id, confirmed[1].id);
    }

    // With max_misses = 2, the third consecutive empty frame raises misses to
    // 3 (> 2), so the track is marked lost and removed.
    #[test]
    fn test_mot_track_deletion() {
        let mut tracker = MultiObjectTracker::new(MultiTrackerConfig {
            min_hits_to_confirm: 1,
            max_misses: 2,
            ..Default::default()
        });
        let det = BBox::new(10.0, 10.0, 20.0, 20.0);
        tracker.update(&[det]);
        assert_eq!(tracker.num_tracks(), 1);
        tracker.update(&[]);
        tracker.update(&[]);
        tracker.update(&[]);
        assert_eq!(
            tracker.num_tracks(),
            0,
            "Track should be deleted after max_misses"
        );
    }

    // A detection shifted by 2 px still overlaps the predicted box enough
    // (iou_threshold = 0.1) to be matched, so the track keeps its id.
    #[test]
    fn test_mot_id_persistence() {
        let mut tracker = MultiObjectTracker::new(MultiTrackerConfig {
            min_hits_to_confirm: 1,
            iou_threshold: 0.1,
            ..Default::default()
        });
        let det = BBox::new(10.0, 10.0, 20.0, 20.0);
        let first = tracker.update(&[det]);
        let id = first[0].id;
        let det_moved = BBox::new(12.0, 12.0, 20.0, 20.0);
        let second = tracker.update(&[det_moved]);
        assert_eq!(second.len(), 1);
        assert_eq!(second[0].id, id, "ID should be preserved");
    }

    // iou_threshold = 0 accepts every cost ≤ 1; the two diagonal pairs are
    // the cheapest, so both tracks and both detections are matched.
    #[test]
    fn test_hungarian_simple() {
        let cost = vec![vec![0.1, 0.9], vec![0.9, 0.2]];
        let (matched, unmatched_t, unmatched_d) = hungarian_assignment(&cost, 0.0);
        assert_eq!(matched.len(), 2);
        assert!(unmatched_t.is_empty());
        assert!(unmatched_d.is_empty());
    }

    // The cheapest pair costs 0.8 > 1 − 0.5, so no match is accepted.
    #[test]
    fn test_hungarian_threshold_filtering() {
        let cost = vec![vec![0.8, 0.9]];
        let (matched, _, _) = hungarian_assignment(&cost, 0.5);
        assert!(matched.is_empty(), "High cost should be filtered out");
    }
}