Skip to main content

scirs2_vision/video/
tracking.rs

1//! Object tracking algorithms for video processing.
2//!
3//! Provides classical tracking approaches that work on grayscale
4//! `Array2<f64>` frames (pixel values in `[0, 1]`).
5//!
6//! # Algorithms
7//!
8//! - **Mean Shift** -- iterative mode-seeking on a colour/intensity histogram
9//! - **CamShift** -- Continuously Adaptive Mean Shift with automatic window sizing
10//! - **Kalman filter tracker** -- constant-velocity and constant-acceleration models
11//! - **Multi-object tracking** -- Hungarian (Munkres) assignment + track management
12//!
13//! # Track Management
14//!
15//! `MultiObjectTracker` handles creation, maintenance, and deletion of tracks
16//! with configurable hit/miss thresholds and unique ID assignment.
17
18use crate::error::{Result, VisionError};
19use scirs2_core::ndarray::Array2;
20
21// ---------------------------------------------------------------------------
22// Bounding box (internal, lightweight)
23// ---------------------------------------------------------------------------
24
/// Lightweight axis-aligned rectangle in pixel coordinates.
///
/// `top`/`left` locate the upper-left corner as (row, column); `height` and
/// `width` give the extent along the row and column axes.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct BBox {
    /// Row of the upper-left corner.
    pub top: f64,
    /// Column of the upper-left corner.
    pub left: f64,
    /// Extent along the row axis.
    pub height: f64,
    /// Extent along the column axis.
    pub width: f64,
}
37
38impl BBox {
39    /// Create a new bounding box.
40    pub fn new(top: f64, left: f64, height: f64, width: f64) -> Self {
41        Self {
42            top,
43            left,
44            height,
45            width,
46        }
47    }
48
49    /// Centre row.
50    pub fn center_row(&self) -> f64 {
51        self.top + self.height / 2.0
52    }
53
54    /// Centre column.
55    pub fn center_col(&self) -> f64 {
56        self.left + self.width / 2.0
57    }
58
59    /// Area.
60    pub fn area(&self) -> f64 {
61        self.height * self.width
62    }
63
64    /// IoU with another box.
65    pub fn iou(&self, other: &BBox) -> f64 {
66        let r1 = self.top;
67        let r2 = self.top + self.height;
68        let c1 = self.left;
69        let c2 = self.left + self.width;
70
71        let or1 = other.top;
72        let or2 = other.top + other.height;
73        let oc1 = other.left;
74        let oc2 = other.left + other.width;
75
76        let inter_r1 = r1.max(or1);
77        let inter_r2 = r2.min(or2);
78        let inter_c1 = c1.max(oc1);
79        let inter_c2 = c2.min(oc2);
80
81        if inter_r2 <= inter_r1 || inter_c2 <= inter_c1 {
82            return 0.0;
83        }
84        let inter_area = (inter_r2 - inter_r1) * (inter_c2 - inter_c1);
85        let union_area = self.area() + other.area() - inter_area;
86        if union_area > 0.0 {
87            inter_area / union_area
88        } else {
89            0.0
90        }
91    }
92
93    /// Euclidean distance between centres.
94    pub fn center_distance(&self, other: &BBox) -> f64 {
95        let dr = self.center_row() - other.center_row();
96        let dc = self.center_col() - other.center_col();
97        (dr * dr + dc * dc).sqrt()
98    }
99}
100
101// ---------------------------------------------------------------------------
102// Mean Shift Tracking
103// ---------------------------------------------------------------------------
104
/// Tuning parameters shared by the mean-shift and CamShift trackers.
#[derive(Clone, Debug)]
pub struct MeanShiftConfig {
    /// Upper bound on mean-shift iterations performed per frame.
    pub max_iterations: usize,
    /// Iteration stops once the window centre moves less than this many pixels.
    pub epsilon: f64,
    /// Number of intensity bins in the target histogram (must be non-zero).
    pub num_bins: usize,
}
115
116impl Default for MeanShiftConfig {
117    fn default() -> Self {
118        Self {
119            max_iterations: 30,
120            epsilon: 1.0,
121            num_bins: 16,
122        }
123    }
124}
125
126/// Mean Shift tracker.
127///
128/// Tracks a rectangular region by iteratively shifting its centre towards the
129/// mode of a back-projection map derived from a target histogram.
130#[derive(Debug, Clone)]
131pub struct MeanShiftTracker {
132    /// Current tracking window.
133    window: BBox,
134    /// Target histogram (normalised).
135    target_hist: Vec<f64>,
136    /// Configuration.
137    config: MeanShiftConfig,
138}
139
140impl MeanShiftTracker {
141    /// Initialise the tracker on the first frame with a given window.
142    pub fn new(frame: &Array2<f64>, window: BBox, config: MeanShiftConfig) -> Result<Self> {
143        if config.num_bins == 0 {
144            return Err(VisionError::InvalidParameter("num_bins must be > 0".into()));
145        }
146        let hist = compute_histogram(frame, &window, config.num_bins);
147        Ok(Self {
148            window,
149            target_hist: hist,
150            config,
151        })
152    }
153
154    /// Update the tracker with a new frame.  Returns the updated bounding box.
155    pub fn update(&mut self, frame: &Array2<f64>) -> Result<BBox> {
156        let rows = frame.nrows();
157        let cols = frame.ncols();
158        let bins = self.config.num_bins;
159
160        for _ in 0..self.config.max_iterations {
161            // Compute back-projection weights inside current window.
162            let (mean_r, mean_c, _total_w) =
163                compute_mean_shift_center(frame, &self.window, &self.target_hist, bins, rows, cols);
164
165            let old_cr = self.window.center_row();
166            let old_cc = self.window.center_col();
167
168            let new_top = (mean_r - self.window.height / 2.0).max(0.0);
169            let new_left = (mean_c - self.window.width / 2.0).max(0.0);
170
171            self.window.top = new_top.min((rows as f64) - self.window.height);
172            self.window.left = new_left.min((cols as f64) - self.window.width);
173
174            let dr = self.window.center_row() - old_cr;
175            let dc = self.window.center_col() - old_cc;
176            if (dr * dr + dc * dc).sqrt() < self.config.epsilon {
177                break;
178            }
179        }
180
181        Ok(self.window)
182    }
183
184    /// Current window.
185    pub fn window(&self) -> BBox {
186        self.window
187    }
188}
189
190// ---------------------------------------------------------------------------
191// CamShift Tracking
192// ---------------------------------------------------------------------------
193
194/// CamShift (Continuously Adaptive Mean Shift) tracker.
195///
196/// Extends Mean Shift by adapting the window size and orientation based on the
197/// zeroth and second moments of the back-projection.
198#[derive(Debug, Clone)]
199pub struct CamShiftTracker {
200    /// Current tracking window.
201    window: BBox,
202    /// Target histogram.
203    target_hist: Vec<f64>,
204    /// Configuration.
205    config: MeanShiftConfig,
206    /// Estimated orientation angle (radians).
207    angle: f64,
208}
209
210impl CamShiftTracker {
211    /// Initialise on first frame.
212    pub fn new(frame: &Array2<f64>, window: BBox, config: MeanShiftConfig) -> Result<Self> {
213        if config.num_bins == 0 {
214            return Err(VisionError::InvalidParameter("num_bins must be > 0".into()));
215        }
216        let hist = compute_histogram(frame, &window, config.num_bins);
217        Ok(Self {
218            window,
219            target_hist: hist,
220            config,
221            angle: 0.0,
222        })
223    }
224
225    /// Update with a new frame.  Returns the updated bounding box and the
226    /// estimated orientation angle in radians.
227    pub fn update(&mut self, frame: &Array2<f64>) -> Result<(BBox, f64)> {
228        let rows = frame.nrows();
229        let cols = frame.ncols();
230        let bins = self.config.num_bins;
231
232        // Mean-shift iterations.
233        for _ in 0..self.config.max_iterations {
234            let (mean_r, mean_c, _) =
235                compute_mean_shift_center(frame, &self.window, &self.target_hist, bins, rows, cols);
236
237            let old_cr = self.window.center_row();
238            let old_cc = self.window.center_col();
239
240            self.window.top = (mean_r - self.window.height / 2.0)
241                .max(0.0)
242                .min((rows as f64) - self.window.height);
243            self.window.left = (mean_c - self.window.width / 2.0)
244                .max(0.0)
245                .min((cols as f64) - self.window.width);
246
247            let dr = self.window.center_row() - old_cr;
248            let dc = self.window.center_col() - old_cc;
249            if (dr * dr + dc * dc).sqrt() < self.config.epsilon {
250                break;
251            }
252        }
253
254        // Compute moments and adapt window size.
255        let (m00, m10, m01, m20, m02, m11) =
256            compute_moments(frame, &self.window, &self.target_hist, bins, rows, cols);
257
258        if m00 > 1e-9 {
259            let xc = m10 / m00;
260            let yc = m01 / m00;
261
262            // Second central moments.
263            let mu20 = m20 / m00 - xc * xc;
264            let mu02 = m02 / m00 - yc * yc;
265            let mu11 = m11 / m00 - xc * yc;
266
267            // Orientation.
268            self.angle = 0.5 * (2.0 * mu11).atan2(mu20 - mu02);
269
270            // Adapt window size based on zeroth moment.
271            let s = (m00 / 256.0).sqrt().max(2.0);
272            self.window.width = s * 2.0;
273            self.window.height = s * 2.0;
274
275            // Re-centre.
276            self.window.top = (yc - self.window.height / 2.0)
277                .max(0.0)
278                .min((rows as f64) - self.window.height);
279            self.window.left = (xc - self.window.width / 2.0)
280                .max(0.0)
281                .min((cols as f64) - self.window.width);
282        }
283
284        Ok((self.window, self.angle))
285    }
286
287    /// Current window.
288    pub fn window(&self) -> BBox {
289        self.window
290    }
291
292    /// Current orientation angle in radians.
293    pub fn angle(&self) -> f64 {
294        self.angle
295    }
296}
297
298// ---------------------------------------------------------------------------
299// Histogram / mean-shift helpers
300// ---------------------------------------------------------------------------
301
302fn compute_histogram(frame: &Array2<f64>, bbox: &BBox, bins: usize) -> Vec<f64> {
303    let mut hist = vec![0.0; bins];
304    let rows = frame.nrows();
305    let cols = frame.ncols();
306    let r_start = (bbox.top as usize).min(rows);
307    let r_end = ((bbox.top + bbox.height) as usize).min(rows);
308    let c_start = (bbox.left as usize).min(cols);
309    let c_end = ((bbox.left + bbox.width) as usize).min(cols);
310    let mut total = 0.0;
311    for r in r_start..r_end {
312        for c in c_start..c_end {
313            let val = frame[[r, c]].clamp(0.0, 1.0);
314            let bin = ((val * (bins as f64 - 1.0)).round() as usize).min(bins - 1);
315            hist[bin] += 1.0;
316            total += 1.0;
317        }
318    }
319    if total > 0.0 {
320        for h in hist.iter_mut() {
321            *h /= total;
322        }
323    }
324    hist
325}
326
327fn compute_mean_shift_center(
328    frame: &Array2<f64>,
329    bbox: &BBox,
330    target_hist: &[f64],
331    bins: usize,
332    rows: usize,
333    cols: usize,
334) -> (f64, f64, f64) {
335    // Compute candidate histogram.
336    let candidate_hist = compute_histogram(frame, bbox, bins);
337
338    // Back-projection weights.
339    let weights: Vec<f64> = (0..bins)
340        .map(|i| {
341            if candidate_hist[i] > 1e-12 {
342                (target_hist[i] / candidate_hist[i]).sqrt().min(10.0)
343            } else {
344                0.0
345            }
346        })
347        .collect();
348
349    let r_start = (bbox.top as usize).min(rows);
350    let r_end = ((bbox.top + bbox.height) as usize).min(rows);
351    let c_start = (bbox.left as usize).min(cols);
352    let c_end = ((bbox.left + bbox.width) as usize).min(cols);
353
354    let mut sum_r = 0.0;
355    let mut sum_c = 0.0;
356    let mut sum_w = 0.0;
357
358    for r in r_start..r_end {
359        for c in c_start..c_end {
360            let val = frame[[r, c]].clamp(0.0, 1.0);
361            let bin = ((val * (bins as f64 - 1.0)).round() as usize).min(bins - 1);
362            let w = weights[bin];
363            sum_r += r as f64 * w;
364            sum_c += c as f64 * w;
365            sum_w += w;
366        }
367    }
368
369    if sum_w > 0.0 {
370        (sum_r / sum_w, sum_c / sum_w, sum_w)
371    } else {
372        (bbox.center_row(), bbox.center_col(), 0.0)
373    }
374}
375
376fn compute_moments(
377    frame: &Array2<f64>,
378    bbox: &BBox,
379    target_hist: &[f64],
380    bins: usize,
381    rows: usize,
382    cols: usize,
383) -> (f64, f64, f64, f64, f64, f64) {
384    let candidate_hist = compute_histogram(frame, bbox, bins);
385    let weights: Vec<f64> = (0..bins)
386        .map(|i| {
387            if candidate_hist[i] > 1e-12 {
388                (target_hist[i] / candidate_hist[i]).sqrt().min(10.0)
389            } else {
390                0.0
391            }
392        })
393        .collect();
394
395    let r_start = (bbox.top as usize).min(rows);
396    let r_end = ((bbox.top + bbox.height) as usize).min(rows);
397    let c_start = (bbox.left as usize).min(cols);
398    let c_end = ((bbox.left + bbox.width) as usize).min(cols);
399
400    let mut m00 = 0.0;
401    let mut m10 = 0.0; // sum x*w
402    let mut m01 = 0.0; // sum y*w
403    let mut m20 = 0.0;
404    let mut m02 = 0.0;
405    let mut m11 = 0.0;
406
407    for r in r_start..r_end {
408        for c in c_start..c_end {
409            let val = frame[[r, c]].clamp(0.0, 1.0);
410            let bin = ((val * (bins as f64 - 1.0)).round() as usize).min(bins - 1);
411            let w = weights[bin];
412            let x = c as f64;
413            let y = r as f64;
414            m00 += w;
415            m10 += x * w;
416            m01 += y * w;
417            m20 += x * x * w;
418            m02 += y * y * w;
419            m11 += x * y * w;
420        }
421    }
422
423    (m00, m10, m01, m20, m02, m11)
424}
425
426// ---------------------------------------------------------------------------
427// Kalman Filter Tracker
428// ---------------------------------------------------------------------------
429
/// Motion model used by the Kalman tracker.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum KalmanModel {
    /// Four-element state `[x, y, vx, vy]`; velocity is carried unchanged between steps.
    ConstantVelocity,
    /// Six-element state `[x, y, vx, vy, ax, ay]`; acceleration is carried unchanged between steps.
    ConstantAcceleration,
}
438
439/// A simple Kalman filter for 2-D point tracking.
440#[derive(Debug, Clone)]
441pub struct KalmanTracker {
442    /// State vector.
443    state: Vec<f64>,
444    /// Covariance matrix (flattened row-major).
445    cov: Vec<f64>,
446    /// State dimension.
447    dim: usize,
448    /// Process noise scale.
449    process_noise: f64,
450    /// Measurement noise scale.
451    measurement_noise: f64,
452    /// Motion model.
453    model: KalmanModel,
454}
455
456impl KalmanTracker {
457    /// Create a new Kalman tracker initialised at position `(x, y)`.
458    pub fn new(x: f64, y: f64, model: KalmanModel) -> Self {
459        let dim = match model {
460            KalmanModel::ConstantVelocity => 4,
461            KalmanModel::ConstantAcceleration => 6,
462        };
463        let mut state = vec![0.0; dim];
464        state[0] = x;
465        state[1] = y;
466        // Large initial covariance.
467        let mut cov = vec![0.0; dim * dim];
468        for i in 0..dim {
469            cov[i * dim + i] = 1000.0;
470        }
471        Self {
472            state,
473            cov,
474            dim,
475            process_noise: 1.0,
476            measurement_noise: 1.0,
477            model,
478        }
479    }
480
481    /// Set process and measurement noise scales.
482    pub fn set_noise(&mut self, process: f64, measurement: f64) {
483        self.process_noise = process;
484        self.measurement_noise = measurement;
485    }
486
487    /// Predict step -- advance the state by one time step.
488    pub fn predict(&mut self) {
489        let n = self.dim;
490        // State transition.
491        let f = self.transition_matrix();
492        let new_state = mat_vec_mul(&f, &self.state, n);
493        self.state = new_state;
494
495        // P = F P F^T + Q
496        let fp = mat_mat_mul(&f, &self.cov, n);
497        let ft = transpose(&f, n);
498        let fp_ft = mat_mat_mul(&fp, &ft, n);
499        let q = self.process_noise_matrix();
500        self.cov = mat_add(&fp_ft, &q, n);
501    }
502
503    /// Update (correct) step with a measurement `(mx, my)`.
504    pub fn update(&mut self, mx: f64, my: f64) {
505        let n = self.dim;
506        let h = self.observation_matrix();
507        let m = 2; // measurement dimension
508
509        // Innovation y = z - H x
510        let hx = mat_vec_mul_rect(&h, &self.state, m, n);
511        let z = [mx, my];
512        let innovation = vec![z[0] - hx[0], z[1] - hx[1]];
513
514        // S = H P H^T + R
515        let hp = mat_mat_mul_rect(&h, &self.cov, m, n, n);
516        let ht = transpose_rect(&h, m, n);
517        let hp_ht = mat_mat_mul_rect(&hp, &ht, m, n, m);
518        let r = self.measurement_noise_matrix();
519        let s = mat_add_small(&hp_ht, &r, m);
520
521        // K = P H^T S^{-1}
522        let p_ht = mat_mat_mul_rect(&self.cov, &ht, n, n, m);
523        let s_inv = invert_2x2(&s);
524        let k = mat_mat_mul_rect(&p_ht, &s_inv, n, m, m);
525
526        // x = x + K y
527        let ky = mat_vec_mul_rect(&k, &innovation, n, m);
528        for (i, &ky_i) in ky.iter().enumerate().take(n) {
529            self.state[i] += ky_i;
530        }
531
532        // P = (I - K H) P
533        let kh = mat_mat_mul_rect(&k, &h, n, m, n);
534        let mut eye = vec![0.0; n * n];
535        for i in 0..n {
536            eye[i * n + i] = 1.0;
537        }
538        let i_kh = mat_sub(&eye, &kh, n);
539        self.cov = mat_mat_mul(&i_kh, &self.cov, n);
540    }
541
542    /// Current estimated position `(x, y)`.
543    pub fn position(&self) -> (f64, f64) {
544        (self.state[0], self.state[1])
545    }
546
547    /// Current estimated velocity `(vx, vy)` (if modelled).
548    pub fn velocity(&self) -> (f64, f64) {
549        if self.dim >= 4 {
550            (self.state[2], self.state[3])
551        } else {
552            (0.0, 0.0)
553        }
554    }
555
556    /// Predicted position after one time step (without modifying state).
557    pub fn predicted_position(&self) -> (f64, f64) {
558        let f = self.transition_matrix();
559        let pred = mat_vec_mul(&f, &self.state, self.dim);
560        (pred[0], pred[1])
561    }
562
563    // ---- Internal matrices ----
564
565    fn transition_matrix(&self) -> Vec<f64> {
566        let n = self.dim;
567        let mut f = vec![0.0; n * n];
568        for i in 0..n {
569            f[i * n + i] = 1.0;
570        }
571        match self.model {
572            KalmanModel::ConstantVelocity => {
573                // x += vx, y += vy
574                f[2] = 1.0;
575                f[n + 3] = 1.0;
576            }
577            KalmanModel::ConstantAcceleration => {
578                // x += vx, y += vy, vx += ax, vy += ay
579                f[2] = 1.0;
580                f[n + 3] = 1.0;
581                f[2 * n + 4] = 1.0;
582                f[3 * n + 5] = 1.0;
583                // x += 0.5*ax, y += 0.5*ay
584                f[4] = 0.5;
585                f[n + 5] = 0.5;
586            }
587        }
588        f
589    }
590
591    fn observation_matrix(&self) -> Vec<f64> {
592        let n = self.dim;
593        let m = 2;
594        let mut h = vec![0.0; m * n];
595        h[0] = 1.0; // observe x
596        h[n + 1] = 1.0; // observe y
597        h
598    }
599
600    fn process_noise_matrix(&self) -> Vec<f64> {
601        let n = self.dim;
602        let mut q = vec![0.0; n * n];
603        for i in 0..n {
604            q[i * n + i] = self.process_noise;
605        }
606        q
607    }
608
609    fn measurement_noise_matrix(&self) -> Vec<f64> {
610        vec![self.measurement_noise, 0.0, 0.0, self.measurement_noise]
611    }
612}
613
614// ---------------------------------------------------------------------------
615// Tiny linear algebra helpers (no external dependency)
616// ---------------------------------------------------------------------------
617
/// Product of a row-major `n x n` matrix with a length-`n` vector.
fn mat_vec_mul(mat: &[f64], vec_in: &[f64], n: usize) -> Vec<f64> {
    (0..n)
        .map(|row| (0..n).fold(0.0, |acc, col| acc + mat[row * n + col] * vec_in[col]))
        .collect()
}
627
/// Product of a row-major `m x n` matrix with a length-`n` vector (result has length `m`).
fn mat_vec_mul_rect(mat: &[f64], vec_in: &[f64], m: usize, n: usize) -> Vec<f64> {
    (0..m)
        .map(|row| (0..n).fold(0.0, |acc, col| acc + mat[row * n + col] * vec_in[col]))
        .collect()
}
637
/// Product `a * b` of two row-major `n x n` matrices.
fn mat_mat_mul(a: &[f64], b: &[f64], n: usize) -> Vec<f64> {
    (0..n * n)
        .map(|idx| {
            let (i, j) = (idx / n, idx % n);
            (0..n).fold(0.0, |acc, k| acc + a[i * n + k] * b[k * n + j])
        })
        .collect()
}
650
/// Product of a row-major `m x k` matrix `a` and a row-major `k x n` matrix `b` (result `m x n`).
fn mat_mat_mul_rect(a: &[f64], b: &[f64], m: usize, k: usize, n: usize) -> Vec<f64> {
    let mut out = Vec::with_capacity(m * n);
    for row in 0..m {
        for col in 0..n {
            let dot = (0..k).fold(0.0, |acc, t| acc + a[row * k + t] * b[t * n + col]);
            out.push(dot);
        }
    }
    out
}
663
/// Transpose of a row-major `n x n` matrix.
fn transpose(a: &[f64], n: usize) -> Vec<f64> {
    // Output element (row, col) = input element (col, row).
    (0..n * n).map(|idx| a[(idx % n) * n + idx / n]).collect()
}
673
/// Transpose of a row-major `m x n` matrix, producing a row-major `n x m` matrix.
fn transpose_rect(a: &[f64], m: usize, n: usize) -> Vec<f64> {
    // Output index idx = j * m + i holds input element (i, j).
    (0..n * m).map(|idx| a[(idx % m) * n + idx / m]).collect()
}
683
/// Element-wise sum of two `n x n` matrices.
fn mat_add(a: &[f64], b: &[f64], n: usize) -> Vec<f64> {
    let len = n * n;
    a[..len].iter().zip(&b[..len]).map(|(x, y)| x + y).collect()
}
691
/// Element-wise sum of two small `n x n` matrices (used for the 2x2 innovation covariance).
fn mat_add_small(a: &[f64], b: &[f64], n: usize) -> Vec<f64> {
    (0..n * n).map(|idx| a[idx] + b[idx]).collect()
}
700
/// Element-wise difference `a - b` of two `n x n` matrices.
fn mat_sub(a: &[f64], b: &[f64], n: usize) -> Vec<f64> {
    let len = n * n;
    a[..len].iter().zip(&b[..len]).map(|(x, y)| x - y).collect()
}
708
/// Inverse of a row-major 2x2 matrix.
///
/// If `|det| < 1e-30` the matrix is treated as singular and the identity is
/// returned instead.
fn invert_2x2(m: &[f64]) -> Vec<f64> {
    let (a, b, c, d) = (m[0], m[1], m[2], m[3]);
    let det = a * d - b * c;
    if det.abs() < 1e-30 {
        vec![1.0, 0.0, 0.0, 1.0]
    } else {
        let inv = 1.0 / det;
        vec![d * inv, -b * inv, -c * inv, a * inv]
    }
}
723
724// ---------------------------------------------------------------------------
725// Multi-Object Tracking
726// ---------------------------------------------------------------------------
727
/// Where a track is in its lifecycle.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum TrackStatus {
    /// Newly created; not yet matched often enough to be reported.
    Tentative,
    /// Matched enough times to be reported to callers.
    Confirmed,
    /// Missed too many frames; scheduled for removal.
    Lost,
}
738
739/// A tracked object with a unique ID.
740#[derive(Debug, Clone)]
741pub struct TrackedObject {
742    /// Unique track ID.
743    pub id: u64,
744    /// Current bounding box.
745    pub bbox: BBox,
746    /// Kalman tracker for motion prediction.
747    pub kalman: KalmanTracker,
748    /// Number of consecutive hits (matched detections).
749    pub hits: usize,
750    /// Number of consecutive misses.
751    pub misses: usize,
752    /// Total age in frames.
753    pub age: usize,
754    /// Track status.
755    pub status: TrackStatus,
756}
757
758/// Multi-object tracker configuration.
759#[derive(Debug, Clone)]
760pub struct MultiTrackerConfig {
761    /// IoU threshold for assignment.
762    pub iou_threshold: f64,
763    /// Hits required to confirm a track.
764    pub min_hits_to_confirm: usize,
765    /// Consecutive misses before a track is deleted.
766    pub max_misses: usize,
767    /// Kalman filter motion model.
768    pub kalman_model: KalmanModel,
769    /// Process noise.
770    pub process_noise: f64,
771    /// Measurement noise.
772    pub measurement_noise: f64,
773}
774
775impl Default for MultiTrackerConfig {
776    fn default() -> Self {
777        Self {
778            iou_threshold: 0.3,
779            min_hits_to_confirm: 3,
780            max_misses: 5,
781            kalman_model: KalmanModel::ConstantVelocity,
782            process_noise: 1.0,
783            measurement_noise: 1.0,
784        }
785    }
786}
787
788/// Multi-object tracker with Hungarian assignment and track management.
789#[derive(Debug, Clone)]
790pub struct MultiObjectTracker {
791    /// Active tracks.
792    tracks: Vec<TrackedObject>,
793    /// Next unique track ID.
794    next_id: u64,
795    /// Configuration.
796    config: MultiTrackerConfig,
797}
798
799impl MultiObjectTracker {
800    /// Create a new multi-object tracker.
801    pub fn new(config: MultiTrackerConfig) -> Self {
802        Self {
803            tracks: Vec::new(),
804            next_id: 1,
805            config,
806        }
807    }
808
809    /// Update with a set of detections for the current frame.
810    ///
811    /// Returns a list of currently active (confirmed) tracks.
812    pub fn update(&mut self, detections: &[BBox]) -> Vec<TrackedObject> {
813        // 1. Predict all existing tracks.
814        for track in self.tracks.iter_mut() {
815            track.kalman.predict();
816            let (px, py) = track.kalman.position();
817            track.bbox.left = px - track.bbox.width / 2.0;
818            track.bbox.top = py - track.bbox.height / 2.0;
819        }
820
821        // 2. Build IoU cost matrix and run Hungarian assignment.
822        let n_tracks = self.tracks.len();
823        let n_dets = detections.len();
824
825        let (matched, unmatched_tracks, unmatched_dets) = if n_tracks > 0 && n_dets > 0 {
826            let mut cost = vec![vec![0.0; n_dets]; n_tracks];
827            for (i, cost_row) in cost.iter_mut().enumerate().take(n_tracks) {
828                for (j, cost_val) in cost_row.iter_mut().enumerate().take(n_dets) {
829                    *cost_val = 1.0 - self.tracks[i].bbox.iou(&detections[j]);
830                }
831            }
832            hungarian_assignment(&cost, self.config.iou_threshold)
833        } else {
834            (Vec::new(), (0..n_tracks).collect(), (0..n_dets).collect())
835        };
836
837        // 3. Update matched tracks.
838        for &(ti, di) in &matched {
839            let det = &detections[di];
840            let track = &mut self.tracks[ti];
841            let cx = det.left + det.width / 2.0;
842            let cy = det.top + det.height / 2.0;
843            track.kalman.update(cx, cy);
844            track.bbox = *det;
845            track.hits += 1;
846            track.misses = 0;
847            track.age += 1;
848            if track.hits >= self.config.min_hits_to_confirm {
849                track.status = TrackStatus::Confirmed;
850            }
851        }
852
853        // 4. Increment misses for unmatched tracks.
854        for &ti in &unmatched_tracks {
855            self.tracks[ti].misses += 1;
856            self.tracks[ti].age += 1;
857            if self.tracks[ti].misses > self.config.max_misses {
858                self.tracks[ti].status = TrackStatus::Lost;
859            }
860        }
861
862        // 5. Create new tracks for unmatched detections.
863        for &di in &unmatched_dets {
864            let det = &detections[di];
865            let cx = det.left + det.width / 2.0;
866            let cy = det.top + det.height / 2.0;
867            let mut kf = KalmanTracker::new(cx, cy, self.config.kalman_model);
868            kf.set_noise(self.config.process_noise, self.config.measurement_noise);
869            let initial_status = if 1 >= self.config.min_hits_to_confirm {
870                TrackStatus::Confirmed
871            } else {
872                TrackStatus::Tentative
873            };
874            let track = TrackedObject {
875                id: self.next_id,
876                bbox: *det,
877                kalman: kf,
878                hits: 1,
879                misses: 0,
880                age: 1,
881                status: initial_status,
882            };
883            self.tracks.push(track);
884            self.next_id += 1;
885        }
886
887        // 6. Remove lost tracks.
888        self.tracks.retain(|t| t.status != TrackStatus::Lost);
889
890        // 7. Return confirmed tracks.
891        self.tracks
892            .iter()
893            .filter(|t| t.status == TrackStatus::Confirmed)
894            .cloned()
895            .collect()
896    }
897
898    /// All tracks (including tentative).
899    pub fn all_tracks(&self) -> &[TrackedObject] {
900        &self.tracks
901    }
902
903    /// Number of active tracks.
904    pub fn num_tracks(&self) -> usize {
905        self.tracks.len()
906    }
907}
908
909// ---------------------------------------------------------------------------
910// Track/detection assignment (min-cost matching followed by IoU gating)
911// ---------------------------------------------------------------------------
912
913/// Greedy assignment with IoU threshold filtering.
914///
915/// For each detection, find the track with the lowest cost that is below
916/// `1 - iou_threshold`, ensuring one-to-one mapping.
917fn hungarian_assignment(
918    cost: &[Vec<f64>],
919    iou_threshold: f64,
920) -> (Vec<(usize, usize)>, Vec<usize>, Vec<usize>) {
921    let n_tracks = cost.len();
922    let n_dets = if n_tracks > 0 { cost[0].len() } else { 0 };
923    let cost_threshold = 1.0 - iou_threshold;
924
925    // Use the Munkres/Hungarian algorithm for correct optimal assignment.
926    let assignments = munkres_assign(cost, n_tracks, n_dets);
927
928    let mut matched = Vec::new();
929    let mut matched_tracks = vec![false; n_tracks];
930    let mut matched_dets = vec![false; n_dets];
931
932    for (ti, di) in assignments {
933        if cost[ti][di] <= cost_threshold {
934            matched.push((ti, di));
935            matched_tracks[ti] = true;
936            matched_dets[di] = true;
937        }
938    }
939
940    let unmatched_tracks: Vec<usize> = (0..n_tracks).filter(|&i| !matched_tracks[i]).collect();
941    let unmatched_dets: Vec<usize> = (0..n_dets).filter(|&i| !matched_dets[i]).collect();
942
943    (matched, unmatched_tracks, unmatched_dets)
944}
945
/// Optimal min-cost assignment (Hungarian / Kuhn–Munkres algorithm).
///
/// Solves the rectangular assignment problem on the `n x m` matrix `cost`.
/// It returns `min(n, m)` `(row, col)` pairs, sorted by row, whose total cost
/// is minimal. Each row and each column appears at most once. Runs in
/// `O(k² · l)` with `k = min(n, m)` and `l = max(n, m)`.
///
/// Non-finite entries (NaN / ±inf) are replaced by a finite penalty larger
/// than any achievable finite total, so they are chosen only when unavoidable.
/// Callers should gate the returned pairs on the true cost.
fn munkres_assign(cost: &[Vec<f64>], n: usize, m: usize) -> Vec<(usize, usize)> {
    if n == 0 || m == 0 {
        return Vec::new();
    }

    // The potential-based algorithm below needs rows <= cols; transpose otherwise.
    let transposed = n > m;
    let (rows, cols) = if transposed { (m, n) } else { (n, m) };
    let raw = |i: usize, j: usize| if transposed { cost[j][i] } else { cost[i][j] };

    // Finite stand-in for non-finite costs.
    let max_abs = (0..rows)
        .flat_map(|i| (0..cols).map(move |j| (i, j)))
        .map(|(i, j)| raw(i, j))
        .filter(|c| c.is_finite())
        .fold(0.0_f64, |acc, c| acc.max(c.abs()));
    let penalty = (max_abs + 1.0) * (rows as f64 + 1.0);
    let at = |i: usize, j: usize| {
        let c = raw(i, j);
        if c.is_finite() {
            c
        } else {
            penalty
        }
    };

    // 1-based shortest-augmenting-path formulation with row/column potentials.
    // p[j] = row (1-based) currently assigned to column j, 0 = free.
    let inf = f64::INFINITY;
    let mut u = vec![0.0; rows + 1];
    let mut v = vec![0.0; cols + 1];
    let mut p = vec![0usize; cols + 1];
    let mut way = vec![0usize; cols + 1];

    for i in 1..=rows {
        p[0] = i;
        let mut j0 = 0;
        let mut minv = vec![inf; cols + 1];
        let mut used = vec![false; cols + 1];

        // Grow the alternating tree until a free column is reached.
        loop {
            used[j0] = true;
            let i0 = p[j0];
            let mut delta = inf;
            let mut j1 = 0;
            for j in 1..=cols {
                if !used[j] {
                    let cur = at(i0 - 1, j - 1) - u[i0] - v[j];
                    if cur < minv[j] {
                        minv[j] = cur;
                        way[j] = j0;
                    }
                    if minv[j] < delta {
                        delta = minv[j];
                        j1 = j;
                    }
                }
            }
            for j in 0..=cols {
                if used[j] {
                    u[p[j]] += delta;
                    v[j] -= delta;
                } else {
                    minv[j] -= delta;
                }
            }
            j0 = j1;
            if p[j0] == 0 {
                break;
            }
        }

        // Flip the assignment along the augmenting path.
        loop {
            let j1 = way[j0];
            p[j0] = p[j1];
            j0 = j1;
            if j0 == 0 {
                break;
            }
        }
    }

    let mut pairs: Vec<(usize, usize)> = (1..=cols)
        .filter(|&j| p[j] != 0)
        .map(|j| {
            let (r, c) = (p[j] - 1, j - 1);
            if transposed {
                (c, r)
            } else {
                (r, c)
            }
        })
        .collect();
    pairs.sort_unstable();
    pairs
}
980
981// ===================================================================
982// Tests
983// ===================================================================
984
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array2;

    /// `h x w` frame filled with the single intensity `val`.
    fn uniform_frame(val: f64, h: usize, w: usize) -> Array2<f64> {
        Array2::from_elem((h, w), val)
    }

    /// `h x w` frame of intensity `bg` containing an `rh x rw` rectangle of
    /// intensity `fg` whose top-left corner is `(top, left)`. The rectangle
    /// is clipped to the frame.
    // A synthetic frame's geometry is naturally eight scalars; bundling them
    // into a struct would only obscure the test call sites.
    #[allow(clippy::too_many_arguments)]
    fn frame_with_bright_region(
        bg: f64,
        fg: f64,
        h: usize,
        w: usize,
        top: usize,
        left: usize,
        rh: usize,
        rw: usize,
    ) -> Array2<f64> {
        let mut f = Array2::from_elem((h, w), bg);
        for r in top..(top + rh).min(h) {
            for c in left..(left + rw).min(w) {
                f[[r, c]] = fg;
            }
        }
        f
    }

    // ---- BBox ----

    #[test]
    fn test_bbox_basics() {
        let b = BBox::new(10.0, 20.0, 30.0, 40.0);
        assert!((b.center_row() - 25.0).abs() < 1e-9);
        assert!((b.center_col() - 40.0).abs() < 1e-9);
        assert!((b.area() - 1200.0).abs() < 1e-9);
    }

    #[test]
    fn test_bbox_iou_identical() {
        let b = BBox::new(0.0, 0.0, 10.0, 10.0);
        assert!((b.iou(&b) - 1.0).abs() < 1e-9);
    }

    #[test]
    fn test_bbox_iou_no_overlap() {
        let a = BBox::new(0.0, 0.0, 10.0, 10.0);
        let b = BBox::new(20.0, 20.0, 10.0, 10.0);
        assert!(a.iou(&b).abs() < 1e-9);
    }

    #[test]
    fn test_bbox_center_distance() {
        let a = BBox::new(0.0, 0.0, 10.0, 10.0);
        let b = BBox::new(3.0, 4.0, 10.0, 10.0);
        let d = a.center_distance(&b);
        assert!((d - 5.0).abs() < 1e-9);
    }

    // ---- Mean Shift ----

    #[test]
    fn test_mean_shift_static_target() {
        let frame = frame_with_bright_region(0.0, 1.0, 32, 32, 10, 10, 10, 10);
        let window = BBox::new(10.0, 10.0, 10.0, 10.0);
        let mut tracker =
            MeanShiftTracker::new(&frame, window, MeanShiftConfig::default()).expect("ok");
        let result = tracker.update(&frame).expect("ok");
        // Window should stay near the bright region.
        assert!(
            (result.center_row() - 15.0).abs() < 5.0,
            "centre row should be near 15"
        );
        assert!(
            (result.center_col() - 15.0).abs() < 5.0,
            "centre col should be near 15"
        );
    }

    #[test]
    fn test_mean_shift_moving_target() {
        let h = 32;
        let w = 32;
        let init = frame_with_bright_region(0.0, 1.0, h, w, 5, 5, 8, 8);
        let window = BBox::new(5.0, 5.0, 8.0, 8.0);
        let mut tracker =
            MeanShiftTracker::new(&init, window, MeanShiftConfig::default()).expect("ok");

        // Move object slightly.
        let moved = frame_with_bright_region(0.0, 1.0, h, w, 8, 8, 8, 8);
        let result = tracker.update(&moved).expect("ok");
        // NOTE(review): the initial window centre is already 9.0, so this
        // bound does not by itself prove that the window moved; consider
        // tightening it once the expected convergence is pinned down.
        assert!(
            result.center_row() > 7.0,
            "Should track toward new position"
        );
    }

    #[test]
    fn test_mean_shift_invalid_bins() {
        let frame = uniform_frame(0.5, 16, 16);
        let window = BBox::new(0.0, 0.0, 8.0, 8.0);
        let config = MeanShiftConfig {
            num_bins: 0,
            ..Default::default()
        };
        assert!(MeanShiftTracker::new(&frame, window, config).is_err());
    }

    // ---- CamShift ----

    #[test]
    fn test_camshift_static_target() {
        let frame = frame_with_bright_region(0.0, 1.0, 32, 32, 10, 10, 10, 10);
        let window = BBox::new(10.0, 10.0, 10.0, 10.0);
        let mut tracker =
            CamShiftTracker::new(&frame, window, MeanShiftConfig::default()).expect("ok");
        let (result, angle) = tracker.update(&frame).expect("ok");
        assert!(result.area() > 0.0);
        assert!(angle.is_finite());
    }

    #[test]
    fn test_camshift_adapts_size() {
        let h = 64;
        let w = 64;
        let frame = frame_with_bright_region(0.0, 1.0, h, w, 10, 10, 20, 20);
        let small_window = BBox::new(12.0, 12.0, 5.0, 5.0);
        let mut tracker =
            CamShiftTracker::new(&frame, small_window, MeanShiftConfig::default()).expect("ok");
        let (result, _) = tracker.update(&frame).expect("ok");
        // CamShift should grow the window to encompass the bright region.
        // NOTE(review): the bound below only checks that the window did not
        // collapse; it does not require actual growth.
        assert!(result.area() > small_window.area() * 0.5);
    }

    // ---- Kalman Filter ----

    #[test]
    fn test_kalman_constant_velocity() {
        let mut kf = KalmanTracker::new(0.0, 0.0, KalmanModel::ConstantVelocity);
        kf.set_noise(0.1, 1.0);
        // Simulate object moving at constant velocity.
        for t in 1..=10 {
            let mx = t as f64 * 2.0;
            let my = t as f64 * 1.0;
            kf.predict();
            kf.update(mx, my);
        }
        let (x, y) = kf.position();
        assert!((x - 20.0).abs() < 2.0, "x should be near 20, got {x}");
        assert!((y - 10.0).abs() < 2.0, "y should be near 10, got {y}");
    }

    #[test]
    fn test_kalman_constant_acceleration() {
        let mut kf = KalmanTracker::new(0.0, 0.0, KalmanModel::ConstantAcceleration);
        kf.set_noise(0.1, 1.0);
        for t in 1..=10 {
            let mx = 0.5 * (t as f64) * (t as f64); // x = 0.5*t^2
            let my = 0.0;
            kf.predict();
            kf.update(mx, my);
        }
        let (x, _y) = kf.position();
        assert!(
            (x - 50.0).abs() < 10.0,
            "x should be near 50 (0.5*10^2), got {x}"
        );
    }

    #[test]
    fn test_kalman_prediction() {
        let mut kf = KalmanTracker::new(10.0, 5.0, KalmanModel::ConstantVelocity);
        kf.set_noise(0.01, 0.1);
        // Give a few updates at constant velocity.
        for t in 1..=5 {
            kf.predict();
            kf.update(10.0 + t as f64, 5.0 + t as f64 * 0.5);
        }
        let (px, py) = kf.predicted_position();
        assert!(px > 15.0, "predicted x should be >15, got {px}");
        assert!(py > 7.0, "predicted y should be >7, got {py}");
    }

    // ---- Multi-Object Tracker ----

    #[test]
    fn test_mot_single_object() {
        let mut tracker = MultiObjectTracker::new(MultiTrackerConfig {
            min_hits_to_confirm: 2,
            ..Default::default()
        });
        let det = BBox::new(10.0, 20.0, 30.0, 40.0);
        // First update: tentative.
        let confirmed = tracker.update(&[det]);
        assert!(confirmed.is_empty(), "Should be tentative after 1 frame");
        // Second update: confirmed.
        let confirmed = tracker.update(&[det]);
        assert_eq!(confirmed.len(), 1, "Should be confirmed after 2 frames");
        assert_eq!(confirmed[0].id, 1);
    }

    #[test]
    fn test_mot_two_objects() {
        let mut tracker = MultiObjectTracker::new(MultiTrackerConfig {
            min_hits_to_confirm: 1,
            ..Default::default()
        });
        let d1 = BBox::new(10.0, 10.0, 20.0, 20.0);
        let d2 = BBox::new(100.0, 100.0, 20.0, 20.0);
        let confirmed = tracker.update(&[d1, d2]);
        assert_eq!(confirmed.len(), 2);
        assert_ne!(confirmed[0].id, confirmed[1].id);
    }

    #[test]
    fn test_mot_track_deletion() {
        let mut tracker = MultiObjectTracker::new(MultiTrackerConfig {
            min_hits_to_confirm: 1,
            max_misses: 2,
            ..Default::default()
        });
        let det = BBox::new(10.0, 10.0, 20.0, 20.0);
        tracker.update(&[det]);
        assert_eq!(tracker.num_tracks(), 1);
        // No detections for 3 frames => track should be deleted.
        tracker.update(&[]);
        tracker.update(&[]);
        tracker.update(&[]);
        assert_eq!(
            tracker.num_tracks(),
            0,
            "Track should be deleted after max_misses"
        );
    }

    #[test]
    fn test_mot_id_persistence() {
        let mut tracker = MultiObjectTracker::new(MultiTrackerConfig {
            min_hits_to_confirm: 1,
            iou_threshold: 0.1,
            ..Default::default()
        });
        let det = BBox::new(10.0, 10.0, 20.0, 20.0);
        let first = tracker.update(&[det]);
        let id = first[0].id;
        // Slightly moved detection should keep the same ID.
        let det_moved = BBox::new(12.0, 12.0, 20.0, 20.0);
        let second = tracker.update(&[det_moved]);
        assert_eq!(second.len(), 1);
        assert_eq!(second[0].id, id, "ID should be preserved");
    }

    // ---- Hungarian Assignment ----

    #[test]
    fn test_hungarian_simple() {
        let cost = vec![vec![0.1, 0.9], vec![0.9, 0.2]];
        let (matched, unmatched_t, unmatched_d) = hungarian_assignment(&cost, 0.0);
        assert_eq!(matched.len(), 2);
        // The diagonal is the cheapest pairing, so it must be the one chosen.
        assert!(matched.contains(&(0, 0)), "got {matched:?}");
        assert!(matched.contains(&(1, 1)), "got {matched:?}");
        assert!(unmatched_t.is_empty());
        assert!(unmatched_d.is_empty());
    }

    #[test]
    fn test_hungarian_threshold_filtering() {
        let cost = vec![vec![0.8, 0.9]]; // both too high for iou_threshold=0.5 => cost_threshold=0.5
        let (matched, unmatched_t, unmatched_d) = hungarian_assignment(&cost, 0.5);
        assert!(matched.is_empty(), "High cost should be filtered out");
        // A rejected pair leaves both sides unmatched.
        assert_eq!(unmatched_t, vec![0]);
        assert_eq!(unmatched_d, vec![0, 1]);
    }

    #[test]
    fn test_hungarian_rectangular() {
        // More tracks than detections: only the cheapest track is matched.
        let tall = vec![vec![0.5], vec![0.1], vec![0.3]];
        let (matched, unmatched_t, unmatched_d) = hungarian_assignment(&tall, 0.0);
        assert_eq!(matched, vec![(1, 0)]);
        assert_eq!(unmatched_t, vec![0, 2]);
        assert!(unmatched_d.is_empty());

        // More detections than tracks: surplus detections stay unmatched.
        let wide = vec![vec![0.4, 0.1, 0.3]];
        let (matched, unmatched_t, unmatched_d) = hungarian_assignment(&wide, 0.0);
        assert_eq!(matched, vec![(0, 1)]);
        assert!(unmatched_t.is_empty());
        assert_eq!(unmatched_d, vec![0, 2]);
    }

    #[test]
    fn test_hungarian_empty() {
        let cost: Vec<Vec<f64>> = Vec::new();
        let (matched, unmatched_t, unmatched_d) = hungarian_assignment(&cost, 0.5);
        assert!(matched.is_empty());
        assert!(unmatched_t.is_empty());
        assert!(unmatched_d.is_empty());
    }
}
1250        let cost = vec![vec![0.8, 0.9]]; // both too high for iou_threshold=0.5 => cost_threshold=0.5
1251        let (matched, _, _) = hungarian_assignment(&cost, 0.5);
1252        assert!(matched.is_empty(), "High cost should be filtered out");
1253    }
1254}