1use crate::error::{Result, VisionError};
15use scirs2_core::ndarray::Array2;
16
/// A 2-D displacement in pixels, stored as a (row, column) offset pair.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct MotionVector {
    /// Vertical (row-axis) displacement.
    pub dy: f64,
    /// Horizontal (column-axis) displacement.
    pub dx: f64,
}

impl MotionVector {
    /// Builds a vector from its vertical and horizontal components.
    pub fn new(dy: f64, dx: f64) -> Self {
        MotionVector { dy, dx }
    }

    /// The null displacement `(0, 0)`.
    pub fn zero() -> Self {
        Self::new(0.0, 0.0)
    }

    /// Euclidean length, `sqrt(dy^2 + dx^2)`.
    pub fn magnitude(&self) -> f64 {
        let Self { dy, dx } = *self;
        (dx * dx + dy * dy).sqrt()
    }

    /// Direction in radians as `atan2(dy, dx)`, lying in `[-pi, pi]`.
    pub fn angle(&self) -> f64 {
        f64::atan2(self.dy, self.dx)
    }
}
51
52#[derive(Debug, Clone)]
54pub struct MotionField {
55 pub vectors: Vec<Vec<MotionVector>>,
57 pub block_size: usize,
59 pub rows: usize,
61 pub cols: usize,
63 pub frame_height: usize,
65 pub frame_width: usize,
67}
68
69impl MotionField {
70 pub fn average_magnitude(&self) -> f64 {
72 let total: f64 = self
73 .vectors
74 .iter()
75 .flat_map(|row| row.iter())
76 .map(|v| v.magnitude())
77 .sum();
78 let count = (self.rows * self.cols) as f64;
79 if count > 0.0 {
80 total / count
81 } else {
82 0.0
83 }
84 }
85
86 pub fn max_magnitude(&self) -> f64 {
88 self.vectors
89 .iter()
90 .flat_map(|row| row.iter())
91 .map(|v| v.magnitude())
92 .fold(0.0_f64, f64::max)
93 }
94
95 pub fn to_hsv_visualization(&self) -> (Array2<f64>, Array2<f64>, Array2<f64>) {
100 let max_mag = self.max_magnitude().max(1e-9);
101 let h = self.frame_height;
102 let w = self.frame_width;
103 let mut hue = Array2::zeros((h, w));
104 let mut sat = Array2::zeros((h, w));
105 let mut val = Array2::zeros((h, w));
106
107 for br in 0..self.rows {
108 for bc in 0..self.cols {
109 let mv = &self.vectors[br][bc];
110 let mag = mv.magnitude() / max_mag;
111 let ang = (mv.angle() + std::f64::consts::PI) / (2.0 * std::f64::consts::PI);
112
113 let r_start = br * self.block_size;
114 let c_start = bc * self.block_size;
115 let r_end = (r_start + self.block_size).min(h);
116 let c_end = (c_start + self.block_size).min(w);
117
118 for r in r_start..r_end {
119 for c in c_start..c_end {
120 hue[[r, c]] = ang;
121 sat[[r, c]] = mag;
122 val[[r, c]] = mag;
123 }
124 }
125 }
126 }
127 (hue, sat, val)
128 }
129}
130
131pub fn block_match_full(
147 reference: &Array2<f64>,
148 current: &Array2<f64>,
149 block_size: usize,
150 search_range: usize,
151) -> Result<MotionField> {
152 validate_frames(reference, current)?;
153 if block_size == 0 {
154 return Err(VisionError::InvalidParameter(
155 "block_size must be > 0".into(),
156 ));
157 }
158
159 let rows = reference.nrows();
160 let cols = reference.ncols();
161 let grid_rows = rows / block_size;
162 let grid_cols = cols / block_size;
163
164 let mut vectors = Vec::with_capacity(grid_rows);
165
166 for br in 0..grid_rows {
167 let mut row_vecs = Vec::with_capacity(grid_cols);
168 for bc in 0..grid_cols {
169 let r0 = br * block_size;
170 let c0 = bc * block_size;
171
172 let mut best_dy: i64 = 0;
173 let mut best_dx: i64 = 0;
174 let mut best_sad = f64::MAX;
175 let sr = search_range as i64;
176
177 for dy in -sr..=sr {
178 for dx in -sr..=sr {
179 let sad =
180 compute_sad(reference, current, r0, c0, block_size, dy, dx, rows, cols);
181 if sad < best_sad {
182 best_sad = sad;
183 best_dy = dy;
184 best_dx = dx;
185 }
186 }
187 }
188
189 row_vecs.push(MotionVector::new(best_dy as f64, best_dx as f64));
190 }
191 vectors.push(row_vecs);
192 }
193
194 Ok(MotionField {
195 vectors,
196 block_size,
197 rows: grid_rows,
198 cols: grid_cols,
199 frame_height: rows,
200 frame_width: cols,
201 })
202}
203
204pub fn block_match_tss(
210 reference: &Array2<f64>,
211 current: &Array2<f64>,
212 block_size: usize,
213 search_range: usize,
214) -> Result<MotionField> {
215 validate_frames(reference, current)?;
216 if block_size == 0 {
217 return Err(VisionError::InvalidParameter(
218 "block_size must be > 0".into(),
219 ));
220 }
221 if search_range == 0 {
222 return Err(VisionError::InvalidParameter(
223 "search_range must be > 0".into(),
224 ));
225 }
226
227 let rows = reference.nrows();
228 let cols = reference.ncols();
229 let grid_rows = rows / block_size;
230 let grid_cols = cols / block_size;
231
232 let mut vectors = Vec::with_capacity(grid_rows);
233
234 let initial_step = ((search_range as f64) / 2.0).ceil().max(1.0) as i64;
236
237 for br in 0..grid_rows {
238 let mut row_vecs = Vec::with_capacity(grid_cols);
239 for bc in 0..grid_cols {
240 let r0 = br * block_size;
241 let c0 = bc * block_size;
242
243 let mut center_dy: i64 = 0;
244 let mut center_dx: i64 = 0;
245 let mut step = initial_step;
246
247 while step >= 1 {
248 let mut best_dy = center_dy;
249 let mut best_dx = center_dx;
250 let mut best_sad = f64::MAX;
251
252 for ddy in [-step, 0, step] {
253 for ddx in [-step, 0, step] {
254 let dy = center_dy + ddy;
255 let dx = center_dx + ddx;
256 let sad =
257 compute_sad(reference, current, r0, c0, block_size, dy, dx, rows, cols);
258 if sad < best_sad {
259 best_sad = sad;
260 best_dy = dy;
261 best_dx = dx;
262 }
263 }
264 }
265
266 center_dy = best_dy;
267 center_dx = best_dx;
268 step /= 2;
269 }
270
271 row_vecs.push(MotionVector::new(center_dy as f64, center_dx as f64));
272 }
273 vectors.push(row_vecs);
274 }
275
276 Ok(MotionField {
277 vectors,
278 block_size,
279 rows: grid_rows,
280 cols: grid_cols,
281 frame_height: rows,
282 frame_width: cols,
283 })
284}
285
286fn compute_sad(
291 reference: &Array2<f64>,
292 current: &Array2<f64>,
293 r0: usize,
294 c0: usize,
295 block_size: usize,
296 dy: i64,
297 dx: i64,
298 rows: usize,
299 cols: usize,
300) -> f64 {
301 let mut sad = 0.0;
302 for r in 0..block_size {
303 for c in 0..block_size {
304 let cr = r0 + r;
305 let cc = c0 + c;
306 let rr = (cr as i64 + dy) as isize;
307 let rc = (cc as i64 + dx) as isize;
308 if cr < rows
309 && cc < cols
310 && rr >= 0
311 && (rr as usize) < rows
312 && rc >= 0
313 && (rc as usize) < cols
314 {
315 sad += (current[[cr, cc]] - reference[[rr as usize, rc as usize]]).abs();
316 } else {
317 sad += 1.0; }
319 }
320 }
321 sad
322}
323
324pub fn phase_correlation(reference: &Array2<f64>, current: &Array2<f64>) -> Result<MotionVector> {
337 validate_frames(reference, current)?;
338 let rows = reference.nrows();
339 let cols = reference.ncols();
340
341 if rows == 0 || cols == 0 {
342 return Ok(MotionVector::zero());
343 }
344
345 let max_dy = (rows / 4).max(1) as i64;
355 let max_dx = (cols / 4).max(1) as i64;
356
357 let mut best_dy: i64 = 0;
358 let mut best_dx: i64 = 0;
359 let mut best_corr = f64::NEG_INFINITY;
360
361 for dy in -max_dy..=max_dy {
362 for dx in -max_dx..=max_dx {
363 let mut corr = 0.0;
364 let mut count = 0u64;
365 for r in 0..rows {
366 for c in 0..cols {
367 let rr = r as i64 + dy;
368 let rc = c as i64 + dx;
369 if rr >= 0 && (rr as usize) < rows && rc >= 0 && (rc as usize) < cols {
370 corr += reference[[r, c]] * current[[rr as usize, rc as usize]];
371 count += 1;
372 }
373 }
374 }
375 if count > 0 {
376 corr /= count as f64;
377 }
378 if corr > best_corr {
379 best_corr = corr;
380 best_dy = dy;
381 best_dx = dx;
382 }
383 }
384 }
385
386 let refined_dy = subpixel_refine_1d(
388 |d| cross_corr_at(reference, current, d, best_dx, rows, cols),
389 best_dy,
390 max_dy,
391 );
392 let refined_dx = subpixel_refine_1d(
393 |d| cross_corr_at(reference, current, best_dy, d, rows, cols),
394 best_dx,
395 max_dx,
396 );
397
398 Ok(MotionVector::new(refined_dy, refined_dx))
399}
400
401fn cross_corr_at(
402 reference: &Array2<f64>,
403 current: &Array2<f64>,
404 dy: i64,
405 dx: i64,
406 rows: usize,
407 cols: usize,
408) -> f64 {
409 let mut corr = 0.0;
410 let mut count = 0u64;
411 for r in 0..rows {
412 for c in 0..cols {
413 let rr = r as i64 + dy;
414 let rc = c as i64 + dx;
415 if rr >= 0 && (rr as usize) < rows && rc >= 0 && (rc as usize) < cols {
416 corr += reference[[r, c]] * current[[rr as usize, rc as usize]];
417 count += 1;
418 }
419 }
420 }
421 if count > 0 {
422 corr / count as f64
423 } else {
424 0.0
425 }
426}
427
/// Refines an integer correlation peak to sub-pixel precision.
///
/// A parabola is fitted through `corr_fn` at `best - 1`, `best` and `best + 1`.
/// For samples `a`, `b`, `c` at `-1`, `0`, `+1`, its vertex lies at
/// `(a - c) / (2 (a - 2b + c))`. The offset is clamped to `[-0.5, 0.5]`, so the
/// result stays within half a sample of `best`.
///
/// Returns `best` unchanged when:
/// * `best` lies on the search boundary `±limit`, so one neighbour is outside
///   the searched window; or
/// * the three samples do not form a strict peak (flat, convex or NaN
///   curvature), in which case a vertex would be meaningless.
fn subpixel_refine_1d<F: Fn(i64) -> f64>(corr_fn: F, best: i64, limit: i64) -> f64 {
    if best <= -limit || best >= limit {
        return best as f64;
    }
    let c_minus = corr_fn(best - 1);
    let c_center = corr_fn(best);
    let c_plus = corr_fn(best + 1);

    // Second difference, which is negative at a maximum. It must be used with
    // this sign in the denominator so the offset points toward the larger
    // neighbour.
    let curvature = c_minus - 2.0 * c_center + c_plus;
    if curvature.is_nan() || curvature >= -1e-12 {
        return best as f64;
    }

    let offset = (c_minus - c_plus) / (2.0 * curvature);
    best as f64 + offset.clamp(-0.5, 0.5)
}
442
443pub fn motion_compensate(reference: &Array2<f64>, field: &MotionField) -> Result<Array2<f64>> {
453 let rows = reference.nrows();
454 let cols = reference.ncols();
455 if rows != field.frame_height || cols != field.frame_width {
456 return Err(VisionError::DimensionMismatch(format!(
457 "Reference ({}x{}) does not match field frame size ({}x{})",
458 rows, cols, field.frame_height, field.frame_width,
459 )));
460 }
461
462 let bs = field.block_size;
463 let mut output = Array2::zeros((rows, cols));
464
465 for br in 0..field.rows {
466 for bc in 0..field.cols {
467 let mv = &field.vectors[br][bc];
468 let r0 = br * bs;
469 let c0 = bc * bs;
470
471 for r in 0..bs {
472 for c in 0..bs {
473 let dst_r = r0 + r;
474 let dst_c = c0 + c;
475 if dst_r >= rows || dst_c >= cols {
476 continue;
477 }
478 let src_r = (dst_r as f64 + mv.dy).round() as isize;
479 let src_c = (dst_c as f64 + mv.dx).round() as isize;
480 if src_r >= 0
481 && (src_r as usize) < rows
482 && src_c >= 0
483 && (src_c as usize) < cols
484 {
485 output[[dst_r, dst_c]] = reference[[src_r as usize, src_c as usize]];
486 }
487 }
488 }
489 }
490 }
491
492 Ok(output)
493}
494
495pub fn prediction_error(actual: &Array2<f64>, predicted: &Array2<f64>) -> Result<Array2<f64>> {
498 validate_frames(actual, predicted)?;
499 Ok(actual - predicted)
500}
501
502fn validate_frames(a: &Array2<f64>, b: &Array2<f64>) -> Result<()> {
507 if a.nrows() != b.nrows() || a.ncols() != b.ncols() {
508 return Err(VisionError::DimensionMismatch(format!(
509 "Frame dimensions do not match: ({},{}) vs ({},{})",
510 a.nrows(),
511 a.ncols(),
512 b.nrows(),
513 b.ncols(),
514 )));
515 }
516 Ok(())
517}
518
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array2;

    /// Frame of shape `(h, w)` filled with the constant `val`.
    fn uniform_frame(val: f64, h: usize, w: usize) -> Array2<f64> {
        Array2::from_elem((h, w), val)
    }

    /// Frame of `bg` with a `size` x `size` square of `fg` whose top-left
    /// corner is `(top, left)`, clipped at the frame edge.
    fn frame_with_square(
        bg: f64,
        fg: f64,
        h: usize,
        w: usize,
        top: usize,
        left: usize,
        size: usize,
    ) -> Array2<f64> {
        Array2::from_shape_fn((h, w), |(r, c)| {
            let inside = (top..top + size).contains(&r) && (left..left + size).contains(&c);
            if inside {
                fg
            } else {
                bg
            }
        })
    }

    #[test]
    fn test_motion_vector_basics() {
        let v = MotionVector::new(3.0, 4.0);
        assert!((v.magnitude() - 5.0).abs() < 1e-9);
        assert!(v.angle().is_finite());

        assert!(MotionVector::zero().magnitude().abs() < 1e-12);
    }

    #[test]
    fn test_full_search_no_motion() {
        let frame = frame_with_square(0.0, 1.0, 16, 16, 2, 2, 8);
        let field = block_match_full(&frame, &frame, 8, 4).expect("ok");
        assert_eq!(field.rows, 2);
        assert_eq!(field.cols, 2);
        for v in field.vectors.iter().flatten() {
            assert!(v.dy.abs() < 1e-9, "expected dy=0, got {}", v.dy);
            assert!(v.dx.abs() < 1e-9, "expected dx=0, got {}", v.dx);
        }
    }

    #[test]
    fn test_full_search_detects_horizontal_shift() {
        let (h, w, bs) = (16, 32, 8);
        let ref_frame = frame_with_square(0.0, 1.0, h, w, 4, 4, 8);
        // Square moved 4 px to the right, so the matching vector points 4 px left.
        let cur_frame = frame_with_square(0.0, 1.0, h, w, 4, 8, 8);
        let field = block_match_full(&ref_frame, &cur_frame, bs, 6).expect("ok");
        let has_shift = field
            .vectors
            .iter()
            .flatten()
            .any(|v| (v.dx + 4.0).abs() < 1.5);
        assert!(has_shift, "Should detect ~4px horizontal shift");
    }

    #[test]
    fn test_full_search_dimension_mismatch() {
        let tall = uniform_frame(0.5, 16, 16);
        let short = uniform_frame(0.5, 8, 16);
        assert!(block_match_full(&tall, &short, 8, 4).is_err());
    }

    #[test]
    fn test_full_search_zero_block() {
        let frame = uniform_frame(0.5, 16, 16);
        assert!(block_match_full(&frame, &frame, 0, 4).is_err());
    }

    #[test]
    fn test_tss_no_motion() {
        let frame = frame_with_square(0.0, 1.0, 16, 16, 2, 2, 8);
        let field = block_match_tss(&frame, &frame, 8, 4).expect("ok");
        for v in field.vectors.iter().flatten() {
            let mag = v.magnitude();
            assert!(mag < 1e-9, "expected zero motion, got mag={}", mag);
        }
    }

    #[test]
    fn test_tss_detects_vertical_shift() {
        let (h, w) = (16, 16);
        let ref_frame = frame_with_square(0.0, 1.0, h, w, 2, 2, 4);
        // Square moved 3 px down, so the matching vector points 3 px up.
        let cur_frame = frame_with_square(0.0, 1.0, h, w, 5, 2, 4);
        let field = block_match_tss(&ref_frame, &cur_frame, 4, 8).expect("ok");
        let has_shift = field
            .vectors
            .iter()
            .flatten()
            .any(|v| (v.dy + 3.0).abs() < 2.0);
        assert!(has_shift, "Should detect ~3px vertical shift");
    }

    #[test]
    fn test_tss_invalid_search_range() {
        let frame = uniform_frame(0.5, 16, 16);
        assert!(block_match_tss(&frame, &frame, 8, 0).is_err());
    }

    #[test]
    fn test_phase_corr_no_motion() {
        let frame = frame_with_square(0.0, 1.0, 16, 16, 4, 4, 8);
        let mv = phase_correlation(&frame, &frame).expect("ok");
        let mag = mv.magnitude();
        assert!(mag < 1.0, "No motion expected, got mag={}", mag);
    }

    #[test]
    fn test_phase_corr_detects_shift() {
        let (h, w) = (16, 16);
        let ref_frame = frame_with_square(0.0, 1.0, h, w, 2, 2, 6);
        let cur_frame = frame_with_square(0.0, 1.0, h, w, 2, 4, 6);
        let mv = phase_correlation(&ref_frame, &cur_frame).expect("ok");
        assert!(
            mv.dx.abs() <= 4.0,
            "Expected dx magnitude near 2, got {}",
            mv.dx
        );
    }

    #[test]
    fn test_phase_corr_dimension_mismatch() {
        let narrow = uniform_frame(0.5, 8, 8);
        let wide = uniform_frame(0.5, 8, 16);
        assert!(phase_correlation(&narrow, &wide).is_err());
    }

    #[test]
    fn test_motion_compensate_zero_field() {
        let frame = frame_with_square(0.1, 0.9, 16, 16, 2, 2, 4);
        let field = MotionField {
            vectors: vec![vec![MotionVector::zero(); 2]; 2],
            block_size: 8,
            rows: 2,
            cols: 2,
            frame_height: 16,
            frame_width: 16,
        };
        let comp = motion_compensate(&frame, &field).expect("ok");
        assert_eq!(comp.dim(), frame.dim());
        for (got, want) in comp.iter().zip(frame.iter()) {
            assert!(
                (got - want).abs() < 1e-9,
                "Zero motion should reproduce the reference"
            );
        }
    }

    #[test]
    fn test_motion_compensate_dimension_mismatch() {
        let frame = uniform_frame(0.5, 8, 8);
        let field = MotionField {
            vectors: vec![vec![MotionVector::zero(); 2]; 2],
            block_size: 8,
            rows: 2,
            cols: 2,
            frame_height: 16,
            frame_width: 16,
        };
        assert!(motion_compensate(&frame, &field).is_err());
    }

    #[test]
    fn test_prediction_error_zero() {
        let frame = uniform_frame(0.5, 8, 8);
        let err = prediction_error(&frame, &frame).expect("ok");
        assert!(err.iter().all(|v| v.abs() < 1e-12));
    }

    #[test]
    fn test_prediction_error_nonzero() {
        let bright = uniform_frame(0.8, 4, 4);
        let dim = uniform_frame(0.3, 4, 4);
        let err = prediction_error(&bright, &dim).expect("ok");
        assert!(err.iter().all(|v| (v - 0.5).abs() < 1e-9));
    }

    #[test]
    fn test_motion_field_average_and_max() {
        let field = MotionField {
            vectors: vec![
                vec![MotionVector::new(3.0, 4.0), MotionVector::new(0.0, 0.0)],
                vec![MotionVector::new(1.0, 0.0), MotionVector::new(0.0, 1.0)],
            ],
            block_size: 4,
            rows: 2,
            cols: 2,
            frame_height: 8,
            frame_width: 8,
        };
        assert!((field.max_magnitude() - 5.0).abs() < 1e-9);
        assert!(field.average_magnitude() > 0.0);
    }

    #[test]
    fn test_hsv_visualization() {
        let field = MotionField {
            vectors: vec![vec![MotionVector::new(1.0, 0.0); 2]; 2],
            block_size: 4,
            rows: 2,
            cols: 2,
            frame_height: 8,
            frame_width: 8,
        };
        let (hue, sat, val) = field.to_hsv_visualization();
        assert_eq!(hue.nrows(), 8);
        assert_eq!(sat.ncols(), 8);

        // Identical vectors everywhere, so saturation must be constant.
        let first_s = sat[[0, 0]];
        assert!(sat.iter().all(|&s| (s - first_s).abs() < 1e-9));
        assert!(val[[0, 0]] > 0.0);
    }
}