1use ndarray::{s, Array2, Array3, Array4, ArrayD, IxDyn};
13use rand::prelude::*;
14use rand::rngs::SmallRng;
15use scirs2_neural::layers::{
16 AdaptiveMaxPool2D, BatchNorm, Conv2D, Dense, Dropout, MaxPool2D, PaddingMode, Sequential,
17};
18use scirs2_neural::losses::{CrossEntropyLoss, MeanSquaredError};
19use scirs2_neural::prelude::*;
20use std::collections::HashMap;
21
/// Result alias used throughout this example: the error side is a boxed
/// `std::error::Error` trait object, so `?` can forward any error type that
/// implements that trait.
///
/// The paths are spelled absolutely so the alias does not depend on whatever
/// `Result` name a glob import may bring into scope.
type StdResult<T> = ::std::result::Result<T, Box<dyn ::std::error::Error>>;
24
/// Settings shared by the synthetic dataset, the model and the target encoder.
#[derive(Clone, Debug)]
pub struct DetectionConfig {
    /// Number of class labels each object slot distinguishes.
    pub num_classes: usize,
    /// Number of object slots predicted (and encoded as targets) per image.
    pub max_objects: usize,
    /// Image size, destructured elsewhere in this file as `(height, width)`.
    pub input_size: (usize, usize),
    /// Anchor sizes; carried in the configuration but not read by the code in this file.
    pub anchor_sizes: Vec<f32>,
    /// Spatial size the backbone output is pooled to before the dense heads.
    pub feature_map_size: (usize, usize),
}

impl Default for DetectionConfig {
    /// Small demo configuration: 3 classes, 5 object slots, 64x64 images,
    /// anchors of 16/32/48 and an 8x8 pooled feature map.
    fn default() -> Self {
        let image_side = 64;
        let grid_side = 8;

        DetectionConfig {
            num_classes: 3,
            max_objects: 5,
            input_size: (image_side, image_side),
            anchor_sizes: [16.0, 32.0, 48.0].to_vec(),
            feature_map_size: (grid_side, grid_side),
        }
    }
}
46
/// Axis-aligned box carrying a class label and a confidence score.
///
/// The box covers `x..x + width` horizontally and `y..y + height` vertically.
#[derive(Clone, Debug)]
pub struct BoundingBox {
    pub x: f32,
    pub y: f32,
    pub width: f32,
    pub height: f32,
    pub class_id: usize,
    pub confidence: f32,
}

impl BoundingBox {
    /// Builds a box from its origin, extent, class label and confidence.
    pub fn new(x: f32, y: f32, width: f32, height: f32, class_id: usize, confidence: f32) -> Self {
        BoundingBox {
            x,
            y,
            width,
            height,
            class_id,
            confidence,
        }
    }

    /// Intersection-over-union of `self` and `other`.
    ///
    /// Returns `0.0` when the boxes do not overlap (boxes that merely touch
    /// count as non-overlapping) or when the union area is not positive.
    pub fn iou(&self, other: &BoundingBox) -> f32 {
        // Edges of the overlap rectangle.
        let left = self.x.max(other.x);
        let top = self.y.max(other.y);
        let right = (self.x + self.width).min(other.x + other.width);
        let bottom = (self.y + self.height).min(other.y + other.height);

        let disjoint = right <= left || bottom <= top;
        if disjoint {
            return 0.0;
        }

        let overlap_area = (right - left) * (bottom - top);
        let own_area = self.width * self.height;
        let other_area = other.width * other.height;
        let combined_area = own_area + other_area - overlap_area;

        if combined_area <= 0.0 {
            return 0.0;
        }
        overlap_area / combined_area
    }
}
91
92pub struct DetectionDataset {
94 config: DetectionConfig,
95 rng: SmallRng,
96}
97
98impl DetectionDataset {
99 pub fn new(config: DetectionConfig, seed: u64) -> Self {
100 Self {
101 config,
102 rng: SmallRng::seed_from_u64(seed),
103 }
104 }
105
106 pub fn generate_sample(&mut self) -> (Array3<f32>, Vec<BoundingBox>) {
108 let (height, width) = self.config.input_size;
109 let mut image = Array3::<f32>::zeros((3, height, width)); for c in 0..3 {
113 for i in 0..height {
114 for j in 0..width {
115 image[[c, i, j]] = self.rng.random_range(0.0..0.3);
116 }
117 }
118 }
119
120 let mut objects = Vec::new();
121 let num_objects = self.rng.random_range(1..=self.config.max_objects.min(3));
122
123 for _ in 0..num_objects {
124 let obj_width = self.rng.random_range(8..24) as f32;
125 let obj_height = self.rng.random_range(8..24) as f32;
126 let obj_x = self.rng.random_range(0.0..(width as f32 - obj_width));
127 let obj_y = self.rng.random_range(0.0..(height as f32 - obj_height));
128 let class_id = self.rng.random_range(1..self.config.num_classes); let color_intensity = match class_id {
132 1 => [0.8, 0.2, 0.2], 2 => [0.2, 0.8, 0.2], _ => [0.2, 0.2, 0.8], };
136
137 for c in 0..3 {
138 for i in (obj_y as usize)..((obj_y + obj_height) as usize).min(height) {
139 for j in (obj_x as usize)..((obj_x + obj_width) as usize).min(width) {
140 image[[c, i, j]] = color_intensity[c] + self.rng.random_range(-0.1..0.1);
141 }
142 }
143 }
144
145 objects.push(BoundingBox::new(
146 obj_x, obj_y, obj_width, obj_height, class_id, 1.0,
147 ));
148 }
149
150 (image, objects)
151 }
152
153 pub fn generate_batch(&mut self, batch_size: usize) -> (Array4<f32>, Vec<Vec<BoundingBox>>) {
155 let (height, width) = self.config.input_size;
156 let mut images = Array4::<f32>::zeros((batch_size, 3, height, width));
157 let mut all_objects = Vec::new();
158
159 for i in 0..batch_size {
160 let (image, objects) = self.generate_sample();
161 images.slice_mut(s![i, .., .., ..]).assign(&image);
162 all_objects.push(objects);
163 }
164
165 (images, all_objects)
166 }
167}
168
169pub struct ObjectDetectionModel {
171 feature_extractor: Sequential<f32>,
172 classifier_head: Sequential<f32>,
173 bbox_regressor: Sequential<f32>,
174 config: DetectionConfig,
175}
176
177impl ObjectDetectionModel {
178 pub fn new(config: DetectionConfig, rng: &mut SmallRng) -> StdResult<Self> {
179 let mut feature_extractor = Sequential::new();
181
182 feature_extractor.add(Conv2D::new(3, 64, (7, 7), (2, 2), PaddingMode::Same, rng)?);
184 feature_extractor.add(BatchNorm::new(64, 0.1, 1e-5, rng)?);
185 feature_extractor.add(MaxPool2D::new((2, 2), (2, 2), None)?);
186
187 feature_extractor.add(Conv2D::new(
189 64,
190 128,
191 (3, 3),
192 (2, 2),
193 PaddingMode::Same,
194 rng,
195 )?);
196 feature_extractor.add(BatchNorm::new(128, 0.1, 1e-5, rng)?);
197
198 feature_extractor.add(Conv2D::new(
199 128,
200 256,
201 (3, 3),
202 (2, 2),
203 PaddingMode::Same,
204 rng,
205 )?);
206 feature_extractor.add(BatchNorm::new(256, 0.1, 1e-5, rng)?);
207
208 feature_extractor.add(AdaptiveMaxPool2D::new(config.feature_map_size, None)?);
210
211 let mut classifier_head = Sequential::new();
213 let feature_dim = 256 * config.feature_map_size.0 * config.feature_map_size.1;
214 classifier_head.add(Dense::new(feature_dim, 512, Some("relu"), rng)?);
215 classifier_head.add(Dropout::new(0.5, rng)?);
216 classifier_head.add(Dense::new(512, 256, Some("relu"), rng)?);
217 classifier_head.add(Dropout::new(0.3, rng)?);
218 classifier_head.add(Dense::new(
219 256,
220 config.num_classes * config.max_objects,
221 Some("softmax"),
222 rng,
223 )?);
224
225 let mut bbox_regressor = Sequential::new();
227 bbox_regressor.add(Dense::new(feature_dim, 512, Some("relu"), rng)?);
228 bbox_regressor.add(Dropout::new(0.5, rng)?);
229 bbox_regressor.add(Dense::new(512, 256, Some("relu"), rng)?);
230 bbox_regressor.add(Dropout::new(0.3, rng)?);
231 bbox_regressor.add(Dense::new(256, 4 * config.max_objects, None, rng)?); Ok(Self {
234 feature_extractor,
235 classifier_head,
236 bbox_regressor,
237 config,
238 })
239 }
240
241 pub fn forward(&self, input: &ArrayD<f32>) -> StdResult<(ArrayD<f32>, ArrayD<f32>)> {
243 let features = self.feature_extractor.forward(input)?;
245
246 let batch_size = features.shape()[0];
248 let feature_dim = features.len() / batch_size;
249 let flattened = features
250 .into_shape_with_order(IxDyn(&[batch_size, feature_dim]))
251 .map_err(|e| Box::new(e) as Box<dyn std::error::Error>)?;
252
253 let classifications = self.classifier_head.forward(&flattened)?;
255 let bbox_predictions = self.bbox_regressor.forward(&flattened)?;
256
257 Ok((classifications, bbox_predictions))
258 }
259
260 pub fn extract_detections(
262 &self,
263 classifications: &ArrayD<f32>,
264 bbox_predictions: &ArrayD<f32>,
265 confidence_threshold: f32,
266 ) -> Vec<Vec<BoundingBox>> {
267 let batch_size = classifications.shape()[0];
268 let mut detections = Vec::new();
269
270 for b in 0..batch_size {
271 let mut batch_detections = Vec::new();
272
273 for obj_idx in 0..self.config.max_objects {
274 let mut best_class = 0;
276 let mut best_score = 0.0f32;
277
278 for class_idx in 0..self.config.num_classes {
279 let score_idx = obj_idx * self.config.num_classes + class_idx;
280 if score_idx < classifications.shape()[1] {
281 let score = classifications[[b, score_idx]];
282 if score > best_score {
283 best_score = score;
284 best_class = class_idx;
285 }
286 }
287 }
288
289 if best_class > 0 && best_score > confidence_threshold {
291 let bbox_start = obj_idx * 4;
293 if bbox_start + 3 < bbox_predictions.shape()[1] {
294 let x = bbox_predictions[[b, bbox_start]].max(0.0);
295 let y = bbox_predictions[[b, bbox_start + 1]].max(0.0);
296 let width = bbox_predictions[[b, bbox_start + 2]].max(1.0);
297 let height = bbox_predictions[[b, bbox_start + 3]].max(1.0);
298
299 batch_detections.push(BoundingBox::new(
300 x, y, width, height, best_class, best_score,
301 ));
302 }
303 }
304 }
305
306 detections.push(batch_detections);
307 }
308
309 detections
310 }
311}
312
313pub struct DetectionLoss {
315 classification_loss: CrossEntropyLoss,
316 regression_loss: MeanSquaredError,
317 classification_weight: f32,
318 regression_weight: f32,
319}
320
321impl DetectionLoss {
322 pub fn new(classification_weight: f32, regression_weight: f32) -> Self {
323 Self {
324 classification_loss: CrossEntropyLoss::new(1e-7),
325 regression_loss: MeanSquaredError,
326 classification_weight,
327 regression_weight,
328 }
329 }
330
331 pub fn compute_loss(
333 &self,
334 pred_classes: &ArrayD<f32>,
335 pred_boxes: &ArrayD<f32>,
336 target_classes: &ArrayD<f32>,
337 target_boxes: &ArrayD<f32>,
338 ) -> StdResult<f32> {
339 let class_loss = self
340 .classification_loss
341 .forward(pred_classes, target_classes)?;
342 let bbox_loss = self.regression_loss.forward(pred_boxes, target_boxes)?;
343
344 Ok(self.classification_weight * class_loss + self.regression_weight * bbox_loss)
345 }
346}
347
348pub struct DetectionMetrics {
350 iou_threshold: f32,
351 confidence_threshold: f32,
352}
353
354impl DetectionMetrics {
355 pub fn new(iou_threshold: f32, confidence_threshold: f32) -> Self {
356 Self {
357 iou_threshold,
358 confidence_threshold,
359 }
360 }
361
362 pub fn calculate_map(
364 &self,
365 predictions: &[Vec<BoundingBox>],
366 ground_truth: &[Vec<BoundingBox>],
367 ) -> f32 {
368 if predictions.is_empty() || ground_truth.is_empty() {
369 return 0.0;
370 }
371
372 let mut total_precision = 0.0;
373 let mut total_samples = 0;
374
375 for (pred_batch, gt_batch) in predictions.iter().zip(ground_truth.iter()) {
376 let precision = self.calculate_precision(pred_batch, gt_batch);
377 total_precision += precision;
378 total_samples += 1;
379 }
380
381 if total_samples > 0 {
382 total_precision / total_samples as f32
383 } else {
384 0.0
385 }
386 }
387
388 fn calculate_precision(
390 &self,
391 predictions: &[BoundingBox],
392 ground_truth: &[BoundingBox],
393 ) -> f32 {
394 if predictions.is_empty() {
395 return if ground_truth.is_empty() { 1.0 } else { 0.0 };
396 }
397
398 let mut true_positives = 0;
399 let mut used_gt = vec![false; ground_truth.len()];
400
401 for pred in predictions {
402 if pred.confidence < self.confidence_threshold {
403 continue;
404 }
405
406 let mut best_iou = 0.0;
407 let mut best_gt_idx = None;
408
409 for (gt_idx, gt) in ground_truth.iter().enumerate() {
410 if used_gt[gt_idx] || pred.class_id != gt.class_id {
411 continue;
412 }
413
414 let iou = pred.iou(gt);
415 if iou > best_iou {
416 best_iou = iou;
417 best_gt_idx = Some(gt_idx);
418 }
419 }
420
421 if let Some(gt_idx) = best_gt_idx {
422 if best_iou >= self.iou_threshold {
423 true_positives += 1;
424 used_gt[gt_idx] = true;
425 }
426 }
427 }
428
429 if predictions.is_empty() {
430 0.0
431 } else {
432 true_positives as f32 / predictions.len() as f32
433 }
434 }
435}
436
437fn prepare_targets(
439 ground_truth: &[Vec<BoundingBox>],
440 config: &DetectionConfig,
441) -> (ArrayD<f32>, ArrayD<f32>) {
442 let batch_size = ground_truth.len();
443
444 let mut class_targets =
446 Array2::<f32>::zeros((batch_size, config.max_objects * config.num_classes));
447
448 let mut bbox_targets = Array2::<f32>::zeros((batch_size, config.max_objects * 4));
450
451 for (batch_idx, objects) in ground_truth.iter().enumerate() {
452 for (obj_idx, obj) in objects.iter().enumerate().take(config.max_objects) {
453 let class_start = obj_idx * config.num_classes;
455 if class_start + obj.class_id < class_targets.shape()[1] {
456 class_targets[[batch_idx, class_start + obj.class_id]] = 1.0;
457 }
458
459 let bbox_start = obj_idx * 4;
461 if bbox_start + 3 < bbox_targets.shape()[1] {
462 bbox_targets[[batch_idx, bbox_start]] = obj.x;
463 bbox_targets[[batch_idx, bbox_start + 1]] = obj.y;
464 bbox_targets[[batch_idx, bbox_start + 2]] = obj.width;
465 bbox_targets[[batch_idx, bbox_start + 3]] = obj.height;
466 }
467 }
468 }
469
470 (class_targets.into_dyn(), bbox_targets.into_dyn())
471}
472
473fn train_detection_model() -> StdResult<()> {
475 println!("šÆ Starting Object Detection Training");
476
477 let mut rng = SmallRng::seed_from_u64(42);
478 let config = DetectionConfig::default();
479
480 println!("š Starting model training...");
481
482 println!("šļø Building object detection model...");
484 let model = ObjectDetectionModel::new(config.clone(), &mut rng)?;
485 println!(
486 "ā
Model created with {} classes and {} max objects",
487 config.num_classes, config.max_objects
488 );
489
490 let mut dataset = DetectionDataset::new(config.clone(), 123);
492
493 let loss_fn = DetectionLoss::new(1.0, 1.0); let metrics = DetectionMetrics::new(0.5, 0.5); println!("š Training configuration:");
500 println!(" - Input size: {:?}", config.input_size);
501 println!(" - Feature map size: {:?}", config.feature_map_size);
502 println!(" - Max objects per image: {}", config.max_objects);
503 println!(" - Number of classes: {}", config.num_classes);
504
505 let num_epochs = 10;
507 let batch_size = 4;
508 let _learning_rate = 0.001;
509
510 for epoch in 0..num_epochs {
511 println!("\nš Epoch {}/{}", epoch + 1, num_epochs);
512
513 let mut epoch_loss = 0.0;
514 let num_batches = 8; for batch_idx in 0..num_batches {
517 let (images, ground_truth) = dataset.generate_batch(batch_size);
519 let images_dyn = images.into_dyn();
520
521 let (pred_classes, pred_boxes) = model.forward(&images_dyn)?;
523
524 let (target_classes, target_boxes) = prepare_targets(&ground_truth, &config);
526
527 let batch_loss =
529 loss_fn.compute_loss(&pred_classes, &pred_boxes, &target_classes, &target_boxes)?;
530 epoch_loss += batch_loss;
531
532 if batch_idx % 4 == 0 {
533 print!(
534 "š Batch {}/{} - Loss: {:.4} \r",
535 batch_idx + 1,
536 num_batches,
537 batch_loss
538 );
539 }
540 }
541
542 let avg_loss = epoch_loss / num_batches as f32;
543 println!(
544 "ā
Epoch {} completed - Average Loss: {:.4}",
545 epoch + 1,
546 avg_loss
547 );
548
549 if (epoch + 1) % 3 == 0 {
551 println!("š Running evaluation...");
552
553 let (val_images, val_ground_truth) = dataset.generate_batch(batch_size);
555 let val_images_dyn = val_images.into_dyn();
556
557 let (pred_classes, pred_boxes) = model.forward(&val_images_dyn)?;
559 let detections = model.extract_detections(&pred_classes, &pred_boxes, 0.5);
560
561 let map = metrics.calculate_map(&detections, &val_ground_truth);
563 println!("š Validation mAP: {:.4}", map);
564
565 if !detections.is_empty() && !detections[0].is_empty() {
567 println!("šÆ Sample detections:");
568 for (i, detection) in detections[0].iter().enumerate().take(3) {
569 println!(
570 " Detection {}: class={}, conf={:.3}, bbox=({:.1}, {:.1}, {:.1}, {:.1})",
571 i + 1,
572 detection.class_id,
573 detection.confidence,
574 detection.x,
575 detection.y,
576 detection.width,
577 detection.height
578 );
579 }
580 }
581 }
582 }
583
584 println!("\nš Object detection training completed!");
585
586 println!("š¬ Final evaluation...");
588 let (test_images, test_ground_truth) = dataset.generate_batch(8);
589 let test_images_dyn = test_images.into_dyn();
590
591 let (pred_classes, pred_boxes) = model.forward(&test_images_dyn)?;
592 let final_detections = model.extract_detections(&pred_classes, &pred_boxes, 0.3);
593
594 let final_map = metrics.calculate_map(&final_detections, &test_ground_truth);
595 println!("š Final mAP: {:.4}", final_map);
596
597 println!("\nš Performance Analysis:");
599 println!(
600 " - Total parameters: ~{:.1}K",
601 (256 * 8 * 8 * 512
602 + 512 * 256
603 + 256 * config.num_classes * config.max_objects
604 + 256 * 4 * config.max_objects)
605 / 1000
606 );
607 println!(" - Memory efficient: ā
(adaptive pooling used)");
608 println!(" - JIT optimized: ā
");
609
610 let mut detection_stats = HashMap::new();
611 for detections in &final_detections {
612 for detection in detections {
613 *detection_stats.entry(detection.class_id).or_insert(0) += 1;
614 }
615 }
616
617 println!(" - Detections by class:");
618 for (class_id, count) in detection_stats {
619 println!(" Class {}: {} detections", class_id, count);
620 }
621
622 Ok(())
623}
624
625fn main() -> StdResult<()> {
626 println!("šÆ Object Detection Complete Example");
627 println!("=====================================");
628 println!();
629 println!("This example demonstrates:");
630 println!("⢠Building an object detection model with CNN backbone");
631 println!("⢠Synthetic dataset generation with multiple objects");
632 println!("⢠Combined classification and bounding box regression");
633 println!("⢠Object detection specific metrics (IoU, mAP)");
634 println!("⢠JIT compilation for performance optimization");
635 println!();
636
637 train_detection_model()?;
638
639 println!("\nš” Key Concepts Demonstrated:");
640 println!(" š¹ Feature extraction with CNN backbone");
641 println!(" š¹ Multi-task learning (classification + regression)");
642 println!(" š¹ Object detection loss functions");
643 println!(" š¹ IoU-based evaluation metrics");
644 println!(" š¹ Non-maximum suppression concepts");
645 println!(" š¹ Bounding box post-processing");
646 println!();
647 println!("š For production use:");
648 println!(" ⢠Implement anchor-based detection (YOLO, SSD)");
649 println!(" ⢠Add data augmentation (rotation, scaling, cropping)");
650 println!(" ⢠Use pre-trained backbones (ResNet, EfficientNet)");
651 println!(" ⢠Implement proper NMS (Non-Maximum Suppression)");
652 println!(" ⢠Add multi-scale training and testing");
653 println!(" ⢠Use real datasets (COCO, Pascal VOC)");
654
655 Ok(())
656}
657
#[cfg(test)]
mod tests {
    use super::*;

    /// The default configuration exposes the documented demo values.
    #[test]
    fn test_detection_config() {
        let config = DetectionConfig::default();
        assert_eq!(config.num_classes, 3);
        assert_eq!(config.max_objects, 5);
        assert_eq!(config.input_size, (64, 64));
    }

    /// IoU is strictly between 0 and 1 for partial overlap, 0 for disjoint boxes
    /// and 1 for identical boxes.
    #[test]
    fn test_bounding_box_iou() {
        let box1 = BoundingBox::new(0.0, 0.0, 10.0, 10.0, 1, 1.0);
        let box2 = BoundingBox::new(5.0, 5.0, 10.0, 10.0, 1, 1.0);

        let iou = box1.iou(&box2);
        assert!(iou > 0.0 && iou < 1.0);

        let box3 = BoundingBox::new(20.0, 20.0, 10.0, 10.0, 1, 1.0);
        assert_eq!(box1.iou(&box3), 0.0);

        let box4 = BoundingBox::new(0.0, 0.0, 10.0, 10.0, 1, 1.0);
        assert_eq!(box1.iou(&box4), 1.0);
    }

    /// A generated sample has the configured image shape and well-formed objects.
    #[test]
    fn test_dataset_generation() {
        let config = DetectionConfig::default();
        let mut dataset = DetectionDataset::new(config.clone(), 42);

        let (image, objects) = dataset.generate_sample();
        assert_eq!(
            image.shape(),
            &[3, config.input_size.0, config.input_size.1]
        );
        assert!(!objects.is_empty());
        assert!(objects.len() <= config.max_objects);

        for obj in &objects {
            assert!(obj.class_id > 0 && obj.class_id < config.num_classes);
            assert!(obj.x >= 0.0 && obj.y >= 0.0);
            assert!(obj.width > 0.0 && obj.height > 0.0);
        }
    }

    /// A perfect same-class match scores 1.0; a class mismatch scores 0.0.
    #[test]
    fn test_detection_metrics() {
        let metrics = DetectionMetrics::new(0.5, 0.5);

        let pred = vec![BoundingBox::new(0.0, 0.0, 10.0, 10.0, 1, 0.9)];
        let truth = vec![BoundingBox::new(0.0, 0.0, 10.0, 10.0, 1, 1.0)];

        let precision = metrics.calculate_precision(&pred, &truth);
        assert_eq!(precision, 1.0);

        let pred2 = vec![BoundingBox::new(0.0, 0.0, 10.0, 10.0, 1, 0.9)];
        let truth2 = vec![BoundingBox::new(0.0, 0.0, 10.0, 10.0, 2, 1.0)];

        let precision2 = metrics.calculate_precision(&pred2, &truth2);
        assert_eq!(precision2, 0.0);
    }

    /// The model can be built and produces one output row per input image.
    #[test]
    fn test_model_creation() -> StdResult<()> {
        let mut rng = SmallRng::seed_from_u64(42);
        let config = DetectionConfig::default();

        let model = ObjectDetectionModel::new(config, &mut rng)?;

        let batch_size = 2;
        let input = Array4::<f32>::ones((batch_size, 3, 64, 64)).into_dyn();
        let (classes, boxes) = model.forward(&input)?;

        assert_eq!(classes.shape()[0], batch_size);
        assert_eq!(boxes.shape()[0], batch_size);

        Ok(())
    }
}