//! Complete Object Detection Example
//!
//! This example demonstrates building a simple object detection model using scirs2-neural.
//! It includes:
//! - Feature extraction backbone (simplified CNN)
//! - Object detection head for bounding box regression and classification
//! - Synthetic dataset generation with multiple objects per image
//! - Training loop with object-detection-specific losses
//! - Evaluation metrics (IoU, mAP approximation)
//! - Text output of sample detection results

use ndarray::{s, Array2, Array3, Array4, ArrayD, IxDyn};
use rand::prelude::*;
use rand::rngs::SmallRng;
use scirs2_neural::layers::{
    AdaptiveMaxPool2D, BatchNorm, Conv2D, Dense, Dropout, MaxPool2D, PaddingMode, Sequential,
};
use scirs2_neural::losses::{CrossEntropyLoss, MeanSquaredError};
use scirs2_neural::prelude::*;
use std::collections::HashMap;

// Type alias to avoid conflicts with scirs2-neural's Result
type StdResult<T> = std::result::Result<T, Box<dyn std::error::Error>>;

/// Object detection model configuration
#[derive(Debug, Clone)]
pub struct DetectionConfig {
    pub num_classes: usize,
    pub max_objects: usize,
    pub input_size: (usize, usize),
    pub anchor_sizes: Vec<f32>,
    pub feature_map_size: (usize, usize),
}

impl Default for DetectionConfig {
    fn default() -> Self {
        Self {
            num_classes: 3, // background + 2 object classes
            max_objects: 5,
            input_size: (64, 64),
            anchor_sizes: vec![16.0, 32.0, 48.0],
            feature_map_size: (8, 8),
        }
    }
}
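
// NOTE: `anchor_sizes` and `feature_map_size` are carried in the config, but the
// simplified model below never expands them into anchor boxes. The following is
// a minimal sketch (an illustrative assumption, not part of the model) of how an
// anchor-based detector would typically turn them into candidate boxes: one
// square anchor per size, centered on each feature-map cell.
#[allow(dead_code)]
fn generate_anchors(config: &DetectionConfig) -> Vec<BoundingBox> {
    let (fh, fw) = config.feature_map_size;
    let (ih, iw) = config.input_size;
    // Stride maps feature-map cells back to input-image coordinates.
    let (stride_y, stride_x) = (ih as f32 / fh as f32, iw as f32 / fw as f32);
    let mut anchors = Vec::with_capacity(fh * fw * config.anchor_sizes.len());
    for gy in 0..fh {
        for gx in 0..fw {
            // Center the anchors on the middle of each cell.
            let cy = (gy as f32 + 0.5) * stride_y;
            let cx = (gx as f32 + 0.5) * stride_x;
            for &size in &config.anchor_sizes {
                anchors.push(BoundingBox::new(
                    cx - size / 2.0,
                    cy - size / 2.0,
                    size,
                    size,
                    0,   // class is unknown until the detection head scores it
                    0.0, // confidence likewise comes from the head
                ));
            }
        }
    }
    anchors
}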

/// Bounding box representation
#[derive(Debug, Clone)]
pub struct BoundingBox {
    pub x: f32,
    pub y: f32,
    pub width: f32,
    pub height: f32,
    pub class_id: usize,
    pub confidence: f32,
}

impl BoundingBox {
    pub fn new(x: f32, y: f32, width: f32, height: f32, class_id: usize, confidence: f32) -> Self {
        Self {
            x,
            y,
            width,
            height,
            class_id,
            confidence,
        }
    }

    /// Calculate Intersection over Union (IoU) with another bounding box.
    ///
    /// For example, two 10x10 boxes offset from each other by (5, 5) intersect
    /// in a 5x5 region, so IoU = 25 / (100 + 100 - 25) = 1/7 ā‰ˆ 0.143.
    pub fn iou(&self, other: &BoundingBox) -> f32 {
        let x1 = self.x.max(other.x);
        let y1 = self.y.max(other.y);
        let x2 = (self.x + self.width).min(other.x + other.width);
        let y2 = (self.y + self.height).min(other.y + other.height);

        if x2 <= x1 || y2 <= y1 {
            return 0.0;
        }

        let intersection = (x2 - x1) * (y2 - y1);
        let union = self.width * self.height + other.width * other.height - intersection;

        if union <= 0.0 {
            0.0
        } else {
            intersection / union
        }
    }
}
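
/// Greedy non-maximum suppression (NMS) over one image's detections.
///
/// A minimal sketch of the NMS concept referenced in `main` below; it is not
/// wired into the pipeline here. Detections are sorted by confidence, and any
/// box whose IoU with an already-kept box of the same class exceeds
/// `iou_threshold` is suppressed.
#[allow(dead_code)]
fn nms(mut detections: Vec<BoundingBox>, iou_threshold: f32) -> Vec<BoundingBox> {
    // Highest-confidence detections first.
    detections.sort_by(|a, b| {
        b.confidence
            .partial_cmp(&a.confidence)
            .unwrap_or(std::cmp::Ordering::Equal)
    });
    let mut kept: Vec<BoundingBox> = Vec::new();
    for det in detections {
        let suppressed = kept
            .iter()
            .any(|k| k.class_id == det.class_id && k.iou(&det) > iou_threshold);
        if !suppressed {
            kept.push(det);
        }
    }
    kept
}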

/// Object detection dataset generator
pub struct DetectionDataset {
    config: DetectionConfig,
    rng: SmallRng,
}

impl DetectionDataset {
    pub fn new(config: DetectionConfig, seed: u64) -> Self {
        Self {
            config,
            rng: SmallRng::seed_from_u64(seed),
        }
    }

    /// Generate a synthetic image with objects and their labels
    pub fn generate_sample(&mut self) -> (Array3<f32>, Vec<BoundingBox>) {
        let (height, width) = self.config.input_size;
        let mut image = Array3::<f32>::zeros((3, height, width)); // RGB channels

        // Generate background pattern
        for c in 0..3 {
            for i in 0..height {
                for j in 0..width {
                    image[[c, i, j]] = self.rng.random_range(0.0..0.3);
                }
            }
        }

        let mut objects = Vec::new();
        let num_objects = self.rng.random_range(1..=self.config.max_objects.min(3));

        for _ in 0..num_objects {
            let obj_width = self.rng.random_range(8..24) as f32;
            let obj_height = self.rng.random_range(8..24) as f32;
            let obj_x = self.rng.random_range(0.0..(width as f32 - obj_width));
            let obj_y = self.rng.random_range(0.0..(height as f32 - obj_height));
            let class_id = self.rng.random_range(1..self.config.num_classes); // Skip background class 0

            // Draw rectangular object
            let color_intensity = match class_id {
                1 => [0.8, 0.2, 0.2], // Red-ish for class 1
                2 => [0.2, 0.8, 0.2], // Green-ish for class 2
                _ => [0.2, 0.2, 0.8], // Blue-ish for other classes
            };

            for c in 0..3 {
                for i in (obj_y as usize)..((obj_y + obj_height) as usize).min(height) {
                    for j in (obj_x as usize)..((obj_x + obj_width) as usize).min(width) {
                        image[[c, i, j]] = color_intensity[c] + self.rng.random_range(-0.1..0.1);
                    }
                }
            }

            objects.push(BoundingBox::new(
                obj_x, obj_y, obj_width, obj_height, class_id, 1.0,
            ));
        }

        (image, objects)
    }

    /// Generate a batch of samples
    pub fn generate_batch(&mut self, batch_size: usize) -> (Array4<f32>, Vec<Vec<BoundingBox>>) {
        let (height, width) = self.config.input_size;
        let mut images = Array4::<f32>::zeros((batch_size, 3, height, width));
        let mut all_objects = Vec::new();

        for i in 0..batch_size {
            let (image, objects) = self.generate_sample();
            images.slice_mut(s![i, .., .., ..]).assign(&image);
            all_objects.push(objects);
        }

        (images, all_objects)
    }
}

/// Object detection model combining feature extraction and detection heads
pub struct ObjectDetectionModel {
    feature_extractor: Sequential<f32>,
    classifier_head: Sequential<f32>,
    bbox_regressor: Sequential<f32>,
    config: DetectionConfig,
}

impl ObjectDetectionModel {
    pub fn new(config: DetectionConfig, rng: &mut SmallRng) -> StdResult<Self> {
        // Feature extraction backbone (simplified ResNet-like)
        let mut feature_extractor = Sequential::new();

        // Initial conv block
        feature_extractor.add(Conv2D::new(3, 64, (7, 7), (2, 2), PaddingMode::Same, rng)?);
        feature_extractor.add(BatchNorm::new(64, 0.1, 1e-5, rng)?);
        feature_extractor.add(MaxPool2D::new((2, 2), (2, 2), None)?);

        // Feature blocks
        feature_extractor.add(Conv2D::new(
            64,
            128,
            (3, 3),
            (2, 2),
            PaddingMode::Same,
            rng,
        )?);
        feature_extractor.add(BatchNorm::new(128, 0.1, 1e-5, rng)?);

        feature_extractor.add(Conv2D::new(
            128,
            256,
            (3, 3),
            (2, 2),
            PaddingMode::Same,
            rng,
        )?);
        feature_extractor.add(BatchNorm::new(256, 0.1, 1e-5, rng)?);

        // Global pooling to fixed size
        feature_extractor.add(AdaptiveMaxPool2D::new(config.feature_map_size, None)?);

        // Classification head
        let mut classifier_head = Sequential::new();
        let feature_dim = 256 * config.feature_map_size.0 * config.feature_map_size.1;
        classifier_head.add(Dense::new(feature_dim, 512, Some("relu"), rng)?);
        classifier_head.add(Dropout::new(0.5, rng)?);
        classifier_head.add(Dense::new(512, 256, Some("relu"), rng)?);
        classifier_head.add(Dropout::new(0.3, rng)?);
        // Simplification: the softmax covers the whole `num_classes * max_objects`
        // output at once rather than normalizing each object slot separately, as a
        // production detector would.
        classifier_head.add(Dense::new(
            256,
            config.num_classes * config.max_objects,
            Some("softmax"),
            rng,
        )?);

        // Bounding box regression head
        let mut bbox_regressor = Sequential::new();
        bbox_regressor.add(Dense::new(feature_dim, 512, Some("relu"), rng)?);
        bbox_regressor.add(Dropout::new(0.5, rng)?);
        bbox_regressor.add(Dense::new(512, 256, Some("relu"), rng)?);
        bbox_regressor.add(Dropout::new(0.3, rng)?);
        bbox_regressor.add(Dense::new(256, 4 * config.max_objects, None, rng)?); // 4 coordinates per object

        Ok(Self {
            feature_extractor,
            classifier_head,
            bbox_regressor,
            config,
        })
    }

    /// Forward pass through the entire detection model
    pub fn forward(&self, input: &ArrayD<f32>) -> StdResult<(ArrayD<f32>, ArrayD<f32>)> {
        // Extract features
        let features = self.feature_extractor.forward(input)?;

        // Flatten features for dense layers
        let batch_size = features.shape()[0];
        let feature_dim = features.len() / batch_size;
        let flattened = features
            .into_shape_with_order(IxDyn(&[batch_size, feature_dim]))
            .map_err(|e| Box::new(e) as Box<dyn std::error::Error>)?;

        // Get classifications and bounding box predictions
        let classifications = self.classifier_head.forward(&flattened)?;
        let bbox_predictions = self.bbox_regressor.forward(&flattened)?;

        Ok((classifications, bbox_predictions))
    }

    /// Post-process predictions to extract bounding boxes
    pub fn extract_detections(
        &self,
        classifications: &ArrayD<f32>,
        bbox_predictions: &ArrayD<f32>,
        confidence_threshold: f32,
    ) -> Vec<Vec<BoundingBox>> {
        let batch_size = classifications.shape()[0];
        let mut detections = Vec::new();

        for b in 0..batch_size {
            let mut batch_detections = Vec::new();

            for obj_idx in 0..self.config.max_objects {
                // Get classification scores for this object
                let mut best_class = 0;
                let mut best_score = 0.0f32;

                for class_idx in 0..self.config.num_classes {
                    let score_idx = obj_idx * self.config.num_classes + class_idx;
                    if score_idx < classifications.shape()[1] {
                        let score = classifications[[b, score_idx]];
                        if score > best_score {
                            best_score = score;
                            best_class = class_idx;
                        }
                    }
                }

                // Skip background class (0) and low confidence predictions
                if best_class > 0 && best_score > confidence_threshold {
                    // Get bounding box coordinates
                    let bbox_start = obj_idx * 4;
                    if bbox_start + 3 < bbox_predictions.shape()[1] {
                        let x = bbox_predictions[[b, bbox_start]].max(0.0);
                        let y = bbox_predictions[[b, bbox_start + 1]].max(0.0);
                        let width = bbox_predictions[[b, bbox_start + 2]].max(1.0);
                        let height = bbox_predictions[[b, bbox_start + 3]].max(1.0);

                        batch_detections.push(BoundingBox::new(
                            x, y, width, height, best_class, best_score,
                        ));
                    }
                }
            }

            detections.push(batch_detections);
        }

        detections
    }
}
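
// In a fuller pipeline, the raw per-slot boxes from `extract_detections` would
// be deduplicated with NMS. A sketch using the `nms` helper defined above
// (illustrative only; the training demo below skips this step):
//
//     let (classes, boxes) = model.forward(&images)?;
//     let raw = model.extract_detections(&classes, &boxes, 0.5);
//     let deduplicated: Vec<Vec<BoundingBox>> =
//         raw.into_iter().map(|dets| nms(dets, 0.5)).collect();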

/// Object detection loss combining classification and regression losses
pub struct DetectionLoss {
    classification_loss: CrossEntropyLoss,
    regression_loss: MeanSquaredError,
    classification_weight: f32,
    regression_weight: f32,
}

impl DetectionLoss {
    pub fn new(classification_weight: f32, regression_weight: f32) -> Self {
        Self {
            classification_loss: CrossEntropyLoss::new(1e-7),
            regression_loss: MeanSquaredError,
            classification_weight,
            regression_weight,
        }
    }

    /// Compute combined detection loss
    pub fn compute_loss(
        &self,
        pred_classes: &ArrayD<f32>,
        pred_boxes: &ArrayD<f32>,
        target_classes: &ArrayD<f32>,
        target_boxes: &ArrayD<f32>,
    ) -> StdResult<f32> {
        let class_loss = self
            .classification_loss
            .forward(pred_classes, target_classes)?;
        let bbox_loss = self.regression_loss.forward(pred_boxes, target_boxes)?;

        Ok(self.classification_weight * class_loss + self.regression_weight * bbox_loss)
    }
}
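
/// Smooth L1 (Huber) loss over box coordinates.
///
/// Detection frameworks in the Fast R-CNN family typically regress boxes with
/// smooth L1 rather than the plain MSE used by `DetectionLoss` above. This is
/// a hand-rolled sketch offered as an alternative, not part of the example's
/// training path.
#[allow(dead_code)]
fn smooth_l1(pred: &ArrayD<f32>, target: &ArrayD<f32>) -> f32 {
    let n = pred.len().max(1) as f32;
    pred.iter()
        .zip(target.iter())
        .map(|(p, t)| {
            let d = (p - t).abs();
            // Quadratic near zero, linear in the tails: less sensitive to
            // outlier boxes than MSE.
            if d < 1.0 {
                0.5 * d * d
            } else {
                d - 0.5
            }
        })
        .sum::<f32>()
        / n
}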

/// Metrics for object detection evaluation
pub struct DetectionMetrics {
    iou_threshold: f32,
    confidence_threshold: f32,
}

impl DetectionMetrics {
    pub fn new(iou_threshold: f32, confidence_threshold: f32) -> Self {
        Self {
            iou_threshold,
            confidence_threshold,
        }
    }

    /// Calculate a simplified mean Average Precision: precision averaged over
    /// samples, rather than a true precision-recall curve integration.
    pub fn calculate_map(
        &self,
        predictions: &[Vec<BoundingBox>],
        ground_truth: &[Vec<BoundingBox>],
    ) -> f32 {
        if predictions.is_empty() || ground_truth.is_empty() {
            return 0.0;
        }

        let mut total_precision = 0.0;
        let mut total_samples = 0;

        for (pred_batch, gt_batch) in predictions.iter().zip(ground_truth.iter()) {
            let precision = self.calculate_precision(pred_batch, gt_batch);
            total_precision += precision;
            total_samples += 1;
        }

        if total_samples > 0 {
            total_precision / total_samples as f32
        } else {
            0.0
        }
    }

    /// Calculate precision for a single sample
    fn calculate_precision(
        &self,
        predictions: &[BoundingBox],
        ground_truth: &[BoundingBox],
    ) -> f32 {
        if predictions.is_empty() {
            return if ground_truth.is_empty() { 1.0 } else { 0.0 };
        }

        let mut true_positives = 0;
        let mut used_gt = vec![false; ground_truth.len()];

        for pred in predictions {
            if pred.confidence < self.confidence_threshold {
                continue;
            }

            let mut best_iou = 0.0;
            let mut best_gt_idx = None;

            for (gt_idx, gt) in ground_truth.iter().enumerate() {
                if used_gt[gt_idx] || pred.class_id != gt.class_id {
                    continue;
                }

                let iou = pred.iou(gt);
                if iou > best_iou {
                    best_iou = iou;
                    best_gt_idx = Some(gt_idx);
                }
            }

            if let Some(gt_idx) = best_gt_idx {
                if best_iou >= self.iou_threshold {
                    true_positives += 1;
                    used_gt[gt_idx] = true;
                }
            }
        }

        // `predictions` is non-empty here (checked above), so divide directly.
        // Note that the denominator counts every prediction, including those
        // below the confidence threshold.
        true_positives as f32 / predictions.len() as f32
    }
}
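
/// Recall counterpart to `DetectionMetrics::calculate_precision`: the fraction
/// of ground-truth boxes matched by some sufficiently confident prediction.
/// A sketch to complement the simplified mAP above; `calculate_map` does not
/// call it.
#[allow(dead_code)]
fn calculate_recall(
    metrics: &DetectionMetrics,
    predictions: &[BoundingBox],
    ground_truth: &[BoundingBox],
) -> f32 {
    if ground_truth.is_empty() {
        return 1.0;
    }
    let matched = ground_truth
        .iter()
        .filter(|gt| {
            predictions.iter().any(|p| {
                p.confidence >= metrics.confidence_threshold
                    && p.class_id == gt.class_id
                    && p.iou(gt) >= metrics.iou_threshold
            })
        })
        .count();
    matched as f32 / ground_truth.len() as f32
}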

/// Convert ground truth bounding boxes to target tensors
fn prepare_targets(
    ground_truth: &[Vec<BoundingBox>],
    config: &DetectionConfig,
) -> (ArrayD<f32>, ArrayD<f32>) {
    let batch_size = ground_truth.len();

    // Classification targets: [batch_size, max_objects * num_classes]
    let mut class_targets =
        Array2::<f32>::zeros((batch_size, config.max_objects * config.num_classes));

    // Bounding box targets: [batch_size, max_objects * 4]
    let mut bbox_targets = Array2::<f32>::zeros((batch_size, config.max_objects * 4));

    for (batch_idx, objects) in ground_truth.iter().enumerate() {
        for (obj_idx, obj) in objects.iter().enumerate().take(config.max_objects) {
            // Set class target (one-hot encoding)
            let class_start = obj_idx * config.num_classes;
            if class_start + obj.class_id < class_targets.shape()[1] {
                class_targets[[batch_idx, class_start + obj.class_id]] = 1.0;
            }

            // Set bounding box targets
            let bbox_start = obj_idx * 4;
            if bbox_start + 3 < bbox_targets.shape()[1] {
                bbox_targets[[batch_idx, bbox_start]] = obj.x;
                bbox_targets[[batch_idx, bbox_start + 1]] = obj.y;
                bbox_targets[[batch_idx, bbox_start + 2]] = obj.width;
                bbox_targets[[batch_idx, bbox_start + 3]] = obj.height;
            }
        }
    }

    (class_targets.into_dyn(), bbox_targets.into_dyn())
}

/// Training function for object detection
fn train_detection_model() -> StdResult<()> {
    println!("šŸŽÆ Starting Object Detection Training");

    let mut rng = SmallRng::seed_from_u64(42);
    let config = DetectionConfig::default();

    // Create model
    println!("šŸ—ļø Building object detection model...");
    let model = ObjectDetectionModel::new(config.clone(), &mut rng)?;
    println!(
        "āœ… Model created with {} classes and {} max objects",
        config.num_classes, config.max_objects
    );

    // Create dataset
    let mut dataset = DetectionDataset::new(config.clone(), 123);

    // Create loss function
    let loss_fn = DetectionLoss::new(1.0, 1.0); // Equal weights for classification and regression

    // Create metrics
    let metrics = DetectionMetrics::new(0.5, 0.5); // IoU threshold 0.5, confidence threshold 0.5

    println!("šŸ“Š Training configuration:");
    println!("   - Input size: {:?}", config.input_size);
    println!("   - Feature map size: {:?}", config.feature_map_size);
    println!("   - Max objects per image: {}", config.max_objects);
    println!("   - Number of classes: {}", config.num_classes);

    // Training loop.
    //
    // NOTE: this demo only runs the forward pass and reports the loss; no
    // backward pass or optimizer step is performed, so the weights never
    // actually update. A real training run would wire in an optimizer here.
    let num_epochs = 10;
    let batch_size = 4;
    let _learning_rate = 0.001; // placeholder for a real optimizer

    for epoch in 0..num_epochs {
        println!("\nšŸ“ˆ Epoch {}/{}", epoch + 1, num_epochs);

        let mut epoch_loss = 0.0;
        let num_batches = 8; // Small number of batches for demo

        for batch_idx in 0..num_batches {
            // Generate training batch
            let (images, ground_truth) = dataset.generate_batch(batch_size);
            let images_dyn = images.into_dyn();

            // Forward pass
            let (pred_classes, pred_boxes) = model.forward(&images_dyn)?;

            // Prepare targets
            let (target_classes, target_boxes) = prepare_targets(&ground_truth, &config);

            // Compute loss
            let batch_loss =
                loss_fn.compute_loss(&pred_classes, &pred_boxes, &target_classes, &target_boxes)?;
            epoch_loss += batch_loss;

            if batch_idx % 4 == 0 {
                print!(
                    "šŸ”„ Batch {}/{} - Loss: {:.4}                \r",
                    batch_idx + 1,
                    num_batches,
                    batch_loss
                );
                // Flush so the carriage-return progress line renders immediately.
                use std::io::Write as _;
                std::io::stdout().flush().ok();
            }
        }

        let avg_loss = epoch_loss / num_batches as f32;
        println!(
            "āœ… Epoch {} completed - Average Loss: {:.4}",
            epoch + 1,
            avg_loss
        );

        // Evaluation every few epochs
        if (epoch + 1) % 3 == 0 {
            println!("šŸ” Running evaluation...");

            // Generate validation batch
            let (val_images, val_ground_truth) = dataset.generate_batch(batch_size);
            let val_images_dyn = val_images.into_dyn();

            // Get predictions
            let (pred_classes, pred_boxes) = model.forward(&val_images_dyn)?;
            let detections = model.extract_detections(&pred_classes, &pred_boxes, 0.5);

            // Calculate metrics
            let map = metrics.calculate_map(&detections, &val_ground_truth);
            println!("šŸ“Š Validation mAP: {:.4}", map);

            // Print sample detection results
            if !detections.is_empty() && !detections[0].is_empty() {
                println!("šŸŽÆ Sample detections:");
                for (i, detection) in detections[0].iter().enumerate().take(3) {
                    println!(
                        "   Detection {}: class={}, conf={:.3}, bbox=({:.1}, {:.1}, {:.1}, {:.1})",
                        i + 1,
                        detection.class_id,
                        detection.confidence,
                        detection.x,
                        detection.y,
                        detection.width,
                        detection.height
                    );
                }
            }
        }
    }

    println!("\nšŸŽ‰ Object detection training completed!");

    // Final evaluation
    println!("šŸ”¬ Final evaluation...");
    let (test_images, test_ground_truth) = dataset.generate_batch(8);
    let test_images_dyn = test_images.into_dyn();

    let (pred_classes, pred_boxes) = model.forward(&test_images_dyn)?;
    let final_detections = model.extract_detections(&pred_classes, &pred_boxes, 0.3);

    let final_map = metrics.calculate_map(&final_detections, &test_ground_truth);
    println!("šŸ“ˆ Final mAP: {:.4}", final_map);

    // Performance analysis
    println!("\nšŸ“Š Performance Analysis:");
    // Rough parameter estimate: the dense layers of the two heads dominate;
    // conv layers and biases are ignored.
    let feature_dim = 256 * config.feature_map_size.0 * config.feature_map_size.1;
    let approx_params = 2 * (feature_dim * 512 + 512 * 256)
        + 256 * config.num_classes * config.max_objects
        + 256 * 4 * config.max_objects;
    println!(
        "   - Total parameters: ~{:.1}K",
        approx_params as f32 / 1000.0
    );
    println!("   - Memory efficient: āœ… (adaptive pooling used)");

    let mut detection_stats = HashMap::new();
    for detections in &final_detections {
        for detection in detections {
            *detection_stats.entry(detection.class_id).or_insert(0) += 1;
        }
    }

    println!("   - Detections by class:");
    for (class_id, count) in detection_stats {
        println!("     Class {}: {} detections", class_id, count);
    }

    Ok(())
}

fn main() -> StdResult<()> {
    println!("šŸŽÆ Object Detection Complete Example");
    println!("=====================================");
    println!();
    println!("This example demonstrates:");
    println!("• Building an object detection model with CNN backbone");
    println!("• Synthetic dataset generation with multiple objects");
    println!("• Combined classification and bounding box regression");
    println!("• Object detection specific metrics (IoU, mAP)");
    println!();

    train_detection_model()?;

    println!("\nšŸ’” Key Concepts Demonstrated:");
    println!("   šŸ”¹ Feature extraction with CNN backbone");
    println!("   šŸ”¹ Multi-task learning (classification + regression)");
    println!("   šŸ”¹ Object detection loss functions");
    println!("   šŸ”¹ IoU-based evaluation metrics");
    println!("   šŸ”¹ Non-maximum suppression concepts");
    println!("   šŸ”¹ Bounding box post-processing");
    println!();
    println!("šŸš€ For production use:");
    println!("   • Implement anchor-based detection (YOLO, SSD)");
    println!("   • Add data augmentation (rotation, scaling, cropping)");
    println!("   • Use pre-trained backbones (ResNet, EfficientNet)");
    println!("   • Implement proper NMS (Non-Maximum Suppression)");
    println!("   • Add multi-scale training and testing");
    println!("   • Use real datasets (COCO, Pascal VOC)");

    Ok(())
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_detection_config() {
        let config = DetectionConfig::default();
        assert_eq!(config.num_classes, 3);
        assert_eq!(config.max_objects, 5);
        assert_eq!(config.input_size, (64, 64));
    }

    #[test]
    fn test_bounding_box_iou() {
        let box1 = BoundingBox::new(0.0, 0.0, 10.0, 10.0, 1, 1.0);
        let box2 = BoundingBox::new(5.0, 5.0, 10.0, 10.0, 1, 1.0);

        let iou = box1.iou(&box2);
        assert!(iou > 0.0 && iou < 1.0);

        // Test no overlap
        let box3 = BoundingBox::new(20.0, 20.0, 10.0, 10.0, 1, 1.0);
        assert_eq!(box1.iou(&box3), 0.0);

        // Test complete overlap
        let box4 = BoundingBox::new(0.0, 0.0, 10.0, 10.0, 1, 1.0);
        assert_eq!(box1.iou(&box4), 1.0);
    }

    #[test]
    fn test_dataset_generation() {
        let config = DetectionConfig::default();
        let mut dataset = DetectionDataset::new(config.clone(), 42);

        let (image, objects) = dataset.generate_sample();
        assert_eq!(
            image.shape(),
            &[3, config.input_size.0, config.input_size.1]
        );
        assert!(!objects.is_empty());
        assert!(objects.len() <= config.max_objects);

        for obj in &objects {
            assert!(obj.class_id > 0 && obj.class_id < config.num_classes);
            assert!(obj.x >= 0.0 && obj.y >= 0.0);
            assert!(obj.width > 0.0 && obj.height > 0.0);
        }
    }

    #[test]
    fn test_detection_metrics() {
        let metrics = DetectionMetrics::new(0.5, 0.5);

        // Test perfect match
        let pred = vec![BoundingBox::new(0.0, 0.0, 10.0, 10.0, 1, 0.9)];
        let gt = vec![BoundingBox::new(0.0, 0.0, 10.0, 10.0, 1, 1.0)];

        let precision = metrics.calculate_precision(&pred, &gt);
        assert_eq!(precision, 1.0);

        // Test no match (different class)
        let pred2 = vec![BoundingBox::new(0.0, 0.0, 10.0, 10.0, 1, 0.9)];
        let gt2 = vec![BoundingBox::new(0.0, 0.0, 10.0, 10.0, 2, 1.0)];

        let precision2 = metrics.calculate_precision(&pred2, &gt2);
        assert_eq!(precision2, 0.0);
    }

    #[test]
    fn test_model_creation() -> StdResult<()> {
        let mut rng = SmallRng::seed_from_u64(42);
        let config = DetectionConfig::default();

        let model = ObjectDetectionModel::new(config, &mut rng)?;

        // Test forward pass shape
        let batch_size = 2;
        let input = Array4::<f32>::ones((batch_size, 3, 64, 64)).into_dyn();
        let (classes, boxes) = model.forward(&input)?;

        assert_eq!(classes.shape()[0], batch_size);
        assert_eq!(boxes.shape()[0], batch_size);

        Ok(())
    }
}