Skip to main content

scirs2_ndimage/features/
ml_detection.rs

1//! Machine learning-based feature detection
2//!
3//! This module provides ML-powered feature detection algorithms including
4//! learned edge detectors, keypoint detectors, and semantic feature extraction.
5
6use scirs2_core::ndarray::{Array1, Array2, Array3, Array4, ArrayView2};
7use scirs2_core::numeric::{Float, FromPrimitive};
8use std::collections::HashMap;
9use std::fmt::Debug;
10
11use crate::error::{NdimageError, NdimageResult};
12use crate::filters::{convolve, gaussian_filter};
13use crate::interpolation::{zoom, InterpolationOrder};
14use statrs::statistics::Statistics;
15
/// Pre-trained model weights for feature detection
///
/// `LearnedEdgeDetector::detect_edges` walks `conv_kernels` and `biases`
/// in lock-step, so the entry at index `n` of each vector describes layer `n`.
#[derive(Clone, Debug)]
pub struct FeatureDetectorWeights {
    /// Convolutional kernels for each layer
    ///
    /// `LearnedEdgeDetector::detect_edges` indexes every kernel as
    /// `[row, col, in_channel, out_channel]`.
    pub conv_kernels: Vec<Array4<f64>>,
    /// Bias terms for each layer (indexed by output channel)
    pub biases: Vec<Array1<f64>>,
    /// Batch normalization parameters
    ///
    /// NOTE(review): no code visible in this file reads this field.
    pub bn_params: Option<Vec<BatchNormParams>>,
}
26
/// Batch normalization parameters
///
/// NOTE(review): nothing visible in this file reads these fields, so the
/// per-field descriptions below are inferred from the names only — confirm
/// against whatever code produces or consumes them.
#[derive(Clone, Debug)]
pub struct BatchNormParams {
    /// Presumably the mean used to centre activations
    pub mean: Array1<f64>,
    /// Presumably the variance used to scale activations
    pub variance: Array1<f64>,
    /// Presumably the learned scale applied after normalization
    pub gamma: Array1<f64>,
    /// Presumably the learned shift applied after normalization
    pub beta: Array1<f64>,
}
35
/// Configuration for ML-based feature detection
///
/// Derives `PartialEq` so that configurations can be compared directly
/// (for example against [`MLDetectorConfig::default`]).
#[derive(Clone, Debug, PartialEq)]
pub struct MLDetectorConfig {
    /// Number of pyramid levels for multi-scale detection
    /// (the number of scales `ObjectProposalGenerator::generate_proposals` visits)
    pub pyramid_levels: usize,
    /// Scale factor between pyramid levels; level `n` shrinks the edge map by
    /// `scale_factor.powi(n)`
    pub scale_factor: f64,
    /// Non-maximum suppression threshold: proposals whose IoU with an
    /// already-kept proposal exceeds this value are dropped
    pub nms_threshold: f64,
    /// Minimum confidence score for detections (minimum edge strength in the
    /// edge detector, minimum objectness in the proposal generator)
    pub confidence_threshold: f64,
    /// Use GPU acceleration if available
    ///
    /// NOTE(review): no code visible in this file reads this flag.
    pub use_gpu: bool,
}

impl Default for MLDetectorConfig {
    /// Three pyramid levels with a 1.5x step, NMS threshold 0.3, confidence
    /// threshold 0.5, GPU disabled.
    fn default() -> Self {
        Self {
            pyramid_levels: 3,
            scale_factor: 1.5,
            nms_threshold: 0.3,
            confidence_threshold: 0.5,
            use_gpu: false,
        }
    }
}
62
63/// Learned edge detector using convolutional filters
64pub struct LearnedEdgeDetector {
65    weights: FeatureDetectorWeights,
66    config: MLDetectorConfig,
67}
68
69impl LearnedEdgeDetector {
70    /// Create a new learned edge detector with pre-trained weights
71    pub fn new(weights: Option<FeatureDetectorWeights>, config: Option<MLDetectorConfig>) -> Self {
72        let weights = weights.unwrap_or_else(|| Self::default_weights());
73        let config = config.unwrap_or_default();
74
75        Self { weights, config }
76    }
77
78    /// Get default pre-trained weights (simplified example)
79    fn default_weights() -> FeatureDetectorWeights {
80        // Create learned filters that combine multiple edge detection approaches
81        let mut kernels = Vec::new();
82
83        // Layer 1: Basic edge filters (3x3x1x8)
84        let mut layer1 = Array4::zeros((3, 3, 1, 8));
85
86        // Sobel-like filters
87        layer1
88            .slice_mut(scirs2_core::ndarray::s![.., .., 0, 0])
89            .assign(&scirs2_core::ndarray::arr2(&[
90                [-1.0, 0.0, 1.0],
91                [-2.0, 0.0, 2.0],
92                [-1.0, 0.0, 1.0],
93            ]));
94
95        layer1
96            .slice_mut(scirs2_core::ndarray::s![.., .., 0, 1])
97            .assign(&scirs2_core::ndarray::arr2(&[
98                [-1.0, -2.0, -1.0],
99                [0.0, 0.0, 0.0],
100                [1.0, 2.0, 1.0],
101            ]));
102
103        // Diagonal edge filters
104        layer1
105            .slice_mut(scirs2_core::ndarray::s![.., .., 0, 2])
106            .assign(&scirs2_core::ndarray::arr2(&[
107                [-2.0, -1.0, 0.0],
108                [-1.0, 0.0, 1.0],
109                [0.0, 1.0, 2.0],
110            ]));
111
112        layer1
113            .slice_mut(scirs2_core::ndarray::s![.., .., 0, 3])
114            .assign(&scirs2_core::ndarray::arr2(&[
115                [0.0, -1.0, -2.0],
116                [1.0, 0.0, -1.0],
117                [2.0, 1.0, 0.0],
118            ]));
119
120        // Laplacian-like filters
121        layer1
122            .slice_mut(scirs2_core::ndarray::s![.., .., 0, 4])
123            .assign(&scirs2_core::ndarray::arr2(&[
124                [0.0, -1.0, 0.0],
125                [-1.0, 4.0, -1.0],
126                [0.0, -1.0, 0.0],
127            ]));
128
129        // Corner-like filters
130        layer1
131            .slice_mut(scirs2_core::ndarray::s![.., .., 0, 5])
132            .assign(&scirs2_core::ndarray::arr2(&[
133                [1.0, -2.0, 1.0],
134                [-2.0, 4.0, -2.0],
135                [1.0, -2.0, 1.0],
136            ]));
137
138        // Texture filters
139        layer1
140            .slice_mut(scirs2_core::ndarray::s![.., .., 0, 6])
141            .assign(&scirs2_core::ndarray::arr2(&[
142                [1.0, 0.0, -1.0],
143                [0.0, 0.0, 0.0],
144                [-1.0, 0.0, 1.0],
145            ]));
146
147        layer1
148            .slice_mut(scirs2_core::ndarray::s![.., .., 0, 7])
149            .assign(&scirs2_core::ndarray::arr2(&[
150                [0.0, 1.0, 0.0],
151                [1.0, -4.0, 1.0],
152                [0.0, 1.0, 0.0],
153            ]));
154
155        kernels.push(layer1);
156
157        // Layer 2: Combination filters (3x3x8x4)
158        let mut layer2 = Array4::zeros((3, 3, 8, 4));
159
160        // Learned combinations of previous features
161        for i in 0..4 {
162            for j in 0..8 {
163                let weight = if i == j / 2 { 1.0 } else { 0.1 };
164                layer2
165                    .slice_mut(scirs2_core::ndarray::s![1, 1, j, i])
166                    .fill(weight);
167            }
168        }
169
170        kernels.push(layer2);
171
172        // Biases
173        let biases = vec![Array1::zeros(8), Array1::from_vec(vec![0.1, 0.1, 0.1, 0.1])];
174
175        FeatureDetectorWeights {
176            conv_kernels: kernels,
177            biases,
178            bn_params: None,
179        }
180    }
181
182    /// Detect edges using learned filters
183    pub fn detect_edges<T>(&self, image: &ArrayView2<T>) -> NdimageResult<Array2<f64>>
184    where
185        T: Float + FromPrimitive + Debug + Send + Sync + 'static,
186    {
187        let (height, width) = image.dim();
188
189        // Convert to f64 and normalize
190        let mut features = image.mapv(|x| x.to_f64().unwrap_or(0.0));
191        let max_val = features.iter().cloned().fold(0.0, f64::max);
192        if max_val > 0.0 {
193            features /= max_val;
194        }
195
196        // Add channel dimension
197        let mut features_3d = Array3::zeros((height, width, 1));
198        features_3d
199            .slice_mut(scirs2_core::ndarray::s![.., .., 0])
200            .assign(&features);
201
202        // Apply convolutional layers
203        for (layer_idx, (kernels, bias)) in self
204            .weights
205            .conv_kernels
206            .iter()
207            .zip(self.weights.biases.iter())
208            .enumerate()
209        {
210            let in_channels = features_3d.dim().2;
211            let out_channels = kernels.dim().3;
212            let mut output = Array3::zeros((height, width, out_channels));
213
214            // Apply convolutions
215            for out_ch in 0..out_channels {
216                let mut channel_sum = Array2::zeros((height, width));
217
218                for in_ch in 0..in_channels {
219                    let kernel = kernels.slice(scirs2_core::ndarray::s![.., .., in_ch, out_ch]);
220                    let input = features_3d.slice(scirs2_core::ndarray::s![.., .., in_ch]);
221
222                    // Simple 2D convolution
223                    let conv_result = self.convolve_2d(&input, &kernel)?;
224                    channel_sum += &conv_result;
225                }
226
227                // Add bias and apply ReLU activation
228                channel_sum += bias[out_ch];
229                channel_sum.mapv_inplace(|x| x.max(0.0));
230
231                output
232                    .slice_mut(scirs2_core::ndarray::s![.., .., out_ch])
233                    .assign(&channel_sum);
234            }
235
236            features_3d = output;
237        }
238
239        // Combine all output channels into edge strength
240        let mut edge_map = Array2::zeros((height, width));
241        for ch in 0..features_3d.dim().2 {
242            let channel = features_3d.slice(scirs2_core::ndarray::s![.., .., ch]);
243            edge_map += &(&channel * &channel);
244        }
245        edge_map.mapv_inplace(|x| x.sqrt());
246
247        // Apply non-maximum suppression
248        let suppressed = self.non_max_suppression(&edge_map.view())?;
249
250        Ok(suppressed)
251    }
252
253    /// Simple 2D convolution
254    fn convolve_2d(
255        &self,
256        input: &ArrayView2<f64>,
257        kernel: &ArrayView2<f64>,
258    ) -> NdimageResult<Array2<f64>> {
259        let (h, w) = input.dim();
260        let (kh, kw) = kernel.dim();
261        let pad_h = kh / 2;
262        let pad_w = kw / 2;
263
264        let mut output = Array2::zeros((h, w));
265
266        for i in pad_h..h - pad_h {
267            for j in pad_w..w - pad_w {
268                let mut sum = 0.0;
269
270                for ki in 0..kh {
271                    for kj in 0..kw {
272                        let ii = i + ki - pad_h;
273                        let jj = j + kj - pad_w;
274                        sum += input[[ii, jj]] * kernel[[ki, kj]];
275                    }
276                }
277
278                output[[i, j]] = sum;
279            }
280        }
281
282        Ok(output)
283    }
284
285    /// Non-maximum suppression for edge thinning
286    fn non_max_suppression(&self, edgemap: &ArrayView2<f64>) -> NdimageResult<Array2<f64>> {
287        let (height, width) = edgemap.dim();
288        let mut suppressed = Array2::zeros((height, width));
289
290        // Compute gradients
291        let gx = crate::filters::sobel(&edgemap.to_owned(), 1, None)?; // axis 1 for x-direction
292        let gy = crate::filters::sobel(&edgemap.to_owned(), 0, None)?; // axis 0 for y-direction
293
294        for i in 1..height - 1 {
295            for j in 1..width - 1 {
296                let mag = edgemap[[i, j]];
297
298                if mag < self.config.confidence_threshold {
299                    continue;
300                }
301
302                // Compute gradient direction
303                let angle = gy[[i, j]].atan2(gx[[i, j]]);
304
305                // Discretize to 8 directions
306                let direction = ((angle + std::f64::consts::PI) * 4.0 / std::f64::consts::PI)
307                    .round() as i32
308                    % 8;
309
310                // Check neighbors based on gradient direction
311                let (di1, dj1, di2, dj2) = match direction {
312                    0 | 4 => (0, -1, 0, 1),  // Horizontal
313                    1 | 5 => (-1, -1, 1, 1), // Diagonal
314                    2 | 6 => (-1, 0, 1, 0),  // Vertical
315                    3 | 7 => (-1, 1, 1, -1), // Anti-diagonal
316                    _ => (0, -1, 0, 1),
317                };
318
319                let neighbor1 = edgemap[[(i as i32 + di1) as usize, (j as i32 + dj1) as usize]];
320                let neighbor2 = edgemap[[(i as i32 + di2) as usize, (j as i32 + dj2) as usize]];
321
322                // Keep only if local maximum
323                if mag >= neighbor1 && mag >= neighbor2 {
324                    suppressed[[i, j]] = mag;
325                }
326            }
327        }
328
329        Ok(suppressed)
330    }
331}
332
333/// Keypoint descriptor using learned features
334pub struct LearnedKeypointDescriptor {
335    patch_size: usize,
336    descriptor_size: usize,
337    weights: Array2<f64>,
338}
339
340impl LearnedKeypointDescriptor {
341    /// Create a new learned keypoint descriptor
342    pub fn new(patch_size: usize, descriptor_size: usize) -> Self {
343        // Initialize with random projection matrix (simplified)
344        let weights =
345            Array2::from_shape_fn((descriptor_size, patch_size * patch_size), |(i, j)| {
346                // Simple deterministic "random" weights
347                ((i * 7 + j * 13) % 11) as f64 / 11.0 - 0.5
348            });
349
350        Self {
351            patch_size,
352            descriptor_size,
353            weights,
354        }
355    }
356
357    /// Extract descriptors for given keypoints
358    pub fn extract_descriptors<T>(
359        &self,
360        image: &ArrayView2<T>,
361        keypoints: &[(f64, f64)],
362    ) -> NdimageResult<Vec<Array1<f64>>>
363    where
364        T: Float + FromPrimitive + Debug,
365    {
366        let mut descriptors = Vec::new();
367        let half_patch = self.patch_size / 2;
368
369        for &(x, y) in keypoints {
370            let xi = x.round() as i32;
371            let yi = y.round() as i32;
372
373            // Extract patch
374            let mut patch = Array1::zeros(self.patch_size * self.patch_size);
375            let mut idx = 0;
376
377            for dy in -(half_patch as i32)..=(half_patch as i32) {
378                for dx in -(half_patch as i32)..=(half_patch as i32) {
379                    let px = xi + dx;
380                    let py = yi + dy;
381
382                    if px >= 0 && px < image.dim().1 as i32 && py >= 0 && py < image.dim().0 as i32
383                    {
384                        patch[idx] = image[[py as usize, px as usize]].to_f64().unwrap_or(0.0);
385                    }
386                    idx += 1;
387                }
388            }
389
390            // Normalize patch
391            let mean = patch.clone().mean();
392            let std = patch.std(0.0);
393            if std > 0.0 {
394                patch = (patch - mean) / std;
395            }
396
397            // Apply learned projection
398            let descriptor = self.weights.dot(&patch);
399
400            // L2 normalize descriptor
401            let norm = descriptor.dot(&descriptor).sqrt();
402            let descriptor = if norm > 0.0 {
403                descriptor / norm
404            } else {
405                descriptor
406            };
407
408            descriptors.push(descriptor);
409        }
410
411        Ok(descriptors)
412    }
413}
414
415/// Semantic feature extractor using pre-trained deep features
416pub struct SemanticFeatureExtractor {
417    feature_maps: HashMap<String, Array4<f64>>,
418    config: MLDetectorConfig,
419}
420
421impl SemanticFeatureExtractor {
422    /// Create a new semantic feature extractor
423    pub fn new(config: Option<MLDetectorConfig>) -> Self {
424        Self {
425            feature_maps: HashMap::new(),
426            config: config.unwrap_or_default(),
427        }
428    }
429
430    /// Extract semantic features at multiple scales
431    pub fn extractfeatures<T>(
432        &mut self,
433        image: &ArrayView2<T>,
434        feature_types: &[&str],
435    ) -> NdimageResult<HashMap<String, Array3<f64>>>
436    where
437        T: Float + FromPrimitive + Debug + Send + Sync + 'static,
438    {
439        let mut results = HashMap::new();
440
441        for &feature_type in feature_types {
442            match feature_type {
443                "texture" => {
444                    let texturefeatures = self.extracttexturefeatures(image)?;
445                    results.insert("texture".to_string(), texturefeatures);
446                }
447                "shape" => {
448                    let shapefeatures = self.extractshapefeatures(image)?;
449                    results.insert("shape".to_string(), shapefeatures);
450                }
451                "color" => {
452                    let colorfeatures = self.extract_colorfeatures(image)?;
453                    results.insert("color".to_string(), colorfeatures);
454                }
455                _ => {
456                    return Err(NdimageError::InvalidInput(format!(
457                        "Unknown feature type: {}",
458                        feature_type
459                    )));
460                }
461            }
462        }
463
464        Ok(results)
465    }
466
467    /// Extract texture features using Gabor-like filters
468    fn extracttexturefeatures<T>(&self, image: &ArrayView2<T>) -> NdimageResult<Array3<f64>>
469    where
470        T: Float + FromPrimitive + Debug + Send + Sync + 'static,
471    {
472        let (height, width) = image.dim();
473        let num_orientations = 4;
474        let num_scales = 3;
475        let num_features = num_orientations * num_scales;
476
477        let mut features = Array3::zeros((height, width, num_features));
478        let img_f64 = image.mapv(|x| x.to_f64().unwrap_or(0.0));
479
480        let mut feature_idx = 0;
481        for scale in 0..num_scales {
482            let sigma = 2.0 * (scale + 1) as f64;
483
484            for orientation in 0..num_orientations {
485                let angle = orientation as f64 * std::f64::consts::PI / num_orientations as f64;
486
487                // Create oriented filter (simplified Gabor-like)
488                let filter_size = (sigma * 3.0) as usize | 1; // Ensure odd
489                let mut filter = Array2::zeros((filter_size, filter_size));
490                let center = filter_size / 2;
491
492                for i in 0..filter_size {
493                    for j in 0..filter_size {
494                        let x = (j as f64 - center as f64) * angle.cos()
495                            + (i as f64 - center as f64) * angle.sin();
496                        let y = -(j as f64 - center as f64) * angle.sin()
497                            + (i as f64 - center as f64) * angle.cos();
498
499                        let gaussian = (-0.5 * (x * x + y * y) / (sigma * sigma)).exp();
500                        let sinusoid = (2.0 * std::f64::consts::PI * x / (sigma * 2.0)).cos();
501
502                        filter[[i, j]] = gaussian * sinusoid;
503                    }
504                }
505
506                // Normalize filter
507                let sum: f64 = filter.iter().map(|x| x.abs()).sum();
508                if sum > 0.0 {
509                    filter /= sum;
510                }
511
512                // Apply filter
513                let response = convolve(&img_f64, &filter, None)?;
514                features
515                    .slice_mut(scirs2_core::ndarray::s![.., .., feature_idx])
516                    .assign(&response);
517
518                feature_idx += 1;
519            }
520        }
521
522        Ok(features)
523    }
524
525    /// Extract shape features using morphological operations
526    fn extractshapefeatures<T>(&self, image: &ArrayView2<T>) -> NdimageResult<Array3<f64>>
527    where
528        T: Float + FromPrimitive + Debug + Send + Sync + 'static,
529    {
530        let (height, width) = image.dim();
531        let mut features = Array3::zeros((height, width, 4));
532
533        // Convert to binary for shape analysis
534        let img_f64 = image.mapv(|x| x.to_f64().unwrap_or(0.0));
535        let threshold = img_f64.clone().mean();
536        let binary = img_f64.mapv(|x| if x > threshold { 1.0 } else { 0.0 });
537
538        // Feature 1: Distance to nearest edge
539        let edges_x = crate::filters::sobel(&binary, 1, None)?; // x-direction gradient
540        let edges_y = crate::filters::sobel(&binary, 0, None)?; // y-direction gradient
541        let edge_magnitude = (edges_x.mapv(|x| x * x) + edges_y.mapv(|x| x * x)).mapv(|x| x.sqrt());
542        features
543            .slice_mut(scirs2_core::ndarray::s![.., .., 0])
544            .assign(&edge_magnitude);
545
546        // Feature 2: Local curvature (using Laplacian)
547        let curvature = crate::filters::laplace(&binary, None, None)?;
548        features
549            .slice_mut(scirs2_core::ndarray::s![.., .., 1])
550            .assign(&curvature.mapv(|x| x.abs()));
551
552        // Feature 3: Local thickness (simplified)
553        let smoothed = gaussian_filter(&binary, 3.0, None, None)?;
554        features
555            .slice_mut(scirs2_core::ndarray::s![.., .., 2])
556            .assign(&smoothed);
557
558        // Feature 4: Orientation strength
559        let (gx, gy) = (&edges_x, &edges_y);
560        let orientation_strength = gx.mapv(|x| x.abs()) + gy.mapv(|x| x.abs());
561        features
562            .slice_mut(scirs2_core::ndarray::s![.., .., 3])
563            .assign(&orientation_strength);
564
565        Ok(features)
566    }
567
568    /// Extract color-based features (for grayscale, extract intensity features)
569    fn extract_colorfeatures<T>(&self, image: &ArrayView2<T>) -> NdimageResult<Array3<f64>>
570    where
571        T: Float + FromPrimitive + Debug + Send + Sync + 'static,
572    {
573        let (height, width) = image.dim();
574        let mut features = Array3::zeros((height, width, 3));
575
576        let img_f64 = image.mapv(|x| x.to_f64().unwrap_or(0.0));
577
578        // Feature 1: Normalized intensity
579        let max_val = img_f64.iter().cloned().fold(0.0, f64::max);
580        let normalized = if max_val > 0.0 {
581            &img_f64 / max_val
582        } else {
583            img_f64.clone()
584        };
585        features
586            .slice_mut(scirs2_core::ndarray::s![.., .., 0])
587            .assign(&normalized);
588
589        // Feature 2: Local contrast
590        let window_size = 5;
591        let mut contrast = Array2::zeros((height, width));
592
593        for i in window_size / 2..height - window_size / 2 {
594            for j in window_size / 2..width - window_size / 2 {
595                let window = img_f64.slice(scirs2_core::ndarray::s![
596                    i - window_size / 2..=i + window_size / 2,
597                    j - window_size / 2..=j + window_size / 2
598                ]);
599
600                let local_mean = window.mean();
601                let local_std = window.std(0.0);
602                let epsilon = T::from_f64(1e-6).unwrap_or_else(|| T::zero());
603                contrast[[i, j]] = local_std / (local_mean + epsilon.to_f64().unwrap_or(1e-6));
604            }
605        }
606        features
607            .slice_mut(scirs2_core::ndarray::s![.., .., 1])
608            .assign(&contrast);
609
610        // Feature 3: Local entropy (simplified)
611        let mut entropy = Array2::zeros((height, width));
612
613        for i in window_size / 2..height - window_size / 2 {
614            for j in window_size / 2..width - window_size / 2 {
615                let window = img_f64.slice(scirs2_core::ndarray::s![
616                    i - window_size / 2..=i + window_size / 2,
617                    j - window_size / 2..=j + window_size / 2
618                ]);
619
620                // Simple entropy approximation
621                let variance = window.variance();
622                entropy[[i, j]] = (1.0 + variance).ln();
623            }
624        }
625        features
626            .slice_mut(scirs2_core::ndarray::s![.., .., 2])
627            .assign(&entropy);
628
629        Ok(features)
630    }
631}
632
633/// Object proposal generator using learned objectness scores
634pub struct ObjectProposalGenerator {
635    min_size: usize,
636    max_size: usize,
637    stride: usize,
638    aspect_ratios: Vec<f64>,
639    config: MLDetectorConfig,
640}
641
642impl ObjectProposalGenerator {
643    /// Create a new object proposal generator
644    pub fn new(config: Option<MLDetectorConfig>) -> Self {
645        Self {
646            min_size: 16,
647            max_size: 256,
648            stride: 8,
649            aspect_ratios: vec![0.5, 1.0, 2.0],
650            config: config.unwrap_or_default(),
651        }
652    }
653
654    /// Generate object proposals with objectness scores
655    pub fn generate_proposals<T>(
656        &self,
657        image: &ArrayView2<T>,
658        edge_map: Option<&ArrayView2<f64>>,
659    ) -> NdimageResult<Vec<ObjectProposal>>
660    where
661        T: Float + FromPrimitive + Debug + Send + Sync + 'static,
662    {
663        let (height, width) = image.dim();
664        let mut proposals = Vec::new();
665
666        // Compute edge _map if not provided
667        let edge_detector = LearnedEdgeDetector::new(None, None);
668        let edges = if let Some(e) = edge_map {
669            e.to_owned()
670        } else {
671            edge_detector.detect_edges(image)?
672        };
673
674        // Generate proposals at multiple scales
675        for scale in 0..self.config.pyramid_levels {
676            let scale_factor = self.config.scale_factor.powi(scale as i32);
677
678            // Resize edge _map for current scale
679            let scaled_height = ((height as f64) / scale_factor) as usize;
680            let scaled_width = ((width as f64) / scale_factor) as usize;
681
682            if scaled_height < self.min_size || scaled_width < self.min_size {
683                continue;
684            }
685
686            let scaled_edges = zoom(
687                &edges,
688                1.0 / scale_factor,
689                Some(InterpolationOrder::Linear),
690                None,
691                None,
692                None,
693            )?;
694
695            // Sliding window with multiple sizes and aspect ratios
696            for box_size in (self.min_size..=self.max_size).step_by(self.stride * 2) {
697                for &aspect_ratio in &self.aspect_ratios {
698                    let box_width = (box_size as f64 * aspect_ratio.sqrt()) as usize;
699                    let box_height = (box_size as f64 / aspect_ratio.sqrt()) as usize;
700
701                    if box_width > scaled_width || box_height > scaled_height {
702                        continue;
703                    }
704
705                    for y in (0..=scaled_height - box_height).step_by(self.stride) {
706                        for x in (0..=scaled_width - box_width).step_by(self.stride) {
707                            // Compute objectness score
708                            let roi = scaled_edges.slice(scirs2_core::ndarray::s![
709                                y..y + box_height,
710                                x..x + box_width
711                            ]);
712
713                            let objectness = self.compute_objectness(&roi);
714
715                            if objectness > self.config.confidence_threshold {
716                                // Convert back to original coordinates
717                                let proposal = ObjectProposal {
718                                    x: (x as f64 * scale_factor) as usize,
719                                    y: (y as f64 * scale_factor) as usize,
720                                    width: (box_width as f64 * scale_factor) as usize,
721                                    height: (box_height as f64 * scale_factor) as usize,
722                                    score: objectness,
723                                    scale,
724                                };
725
726                                proposals.push(proposal);
727                            }
728                        }
729                    }
730                }
731            }
732        }
733
734        // Apply non-maximum suppression
735        let filtered_proposals = self.non_max_suppression_boxes(&mut proposals);
736
737        Ok(filtered_proposals)
738    }
739
740    /// Compute objectness score for a region
741    fn compute_objectness(&self, roi: &ArrayView2<f64>) -> f64 {
742        // Simple objectness based on edge density and distribution
743        let edge_sum: f64 = roi.sum();
744        let num_pixels = (roi.dim().0 * roi.dim().1) as f64;
745        let edge_density = edge_sum / num_pixels;
746
747        // Check edge distribution (prefer edges near boundaries)
748        let (h, w) = roi.dim();
749        let border_width = 3;
750
751        let mut border_sum = 0.0;
752        let mut border_pixels = 0;
753
754        for i in 0..h {
755            for j in 0..w {
756                if i < border_width
757                    || i >= h - border_width
758                    || j < border_width
759                    || j >= w - border_width
760                {
761                    border_sum += roi[[i, j]];
762                    border_pixels += 1;
763                }
764            }
765        }
766
767        let border_density = if border_pixels > 0 {
768            border_sum / border_pixels as f64
769        } else {
770            0.0
771        };
772
773        // Combine scores
774        let objectness = edge_density * 0.3 + border_density * 0.7;
775
776        objectness.min(1.0)
777    }
778
779    /// Non-maximum suppression for object proposals
780    fn non_max_suppression_boxes(
781        &self,
782        proposals: &mut Vec<ObjectProposal>,
783    ) -> Vec<ObjectProposal> {
784        // Sort by score in descending order
785        proposals.sort_by(|a, b| b.score.partial_cmp(&a.score).expect("Operation failed"));
786
787        let mut keep = Vec::new();
788        let mut suppressed = vec![false; proposals.len()];
789
790        for i in 0..proposals.len() {
791            if suppressed[i] {
792                continue;
793            }
794
795            keep.push(proposals[i].clone());
796
797            // Suppress overlapping proposals
798            for j in i + 1..proposals.len() {
799                if suppressed[j] {
800                    continue;
801                }
802
803                let iou = self.compute_iou(&proposals[i], &proposals[j]);
804                if iou > self.config.nms_threshold {
805                    suppressed[j] = true;
806                }
807            }
808        }
809
810        keep
811    }
812
813    /// Compute intersection over union for two boxes
814    fn compute_iou(&self, box1: &ObjectProposal, box2: &ObjectProposal) -> f64 {
815        let x1 = box1.x.max(box2.x);
816        let y1 = box1.y.max(box2.y);
817        let x2 = (box1.x + box1.width).min(box2.x + box2.width);
818        let y2 = (box1.y + box1.height).min(box2.y + box2.height);
819
820        if x2 <= x1 || y2 <= y1 {
821            return 0.0;
822        }
823
824        let intersection = (x2 - x1) * (y2 - y1);
825        let area1 = box1.width * box1.height;
826        let area2 = box2.width * box2.height;
827        let union = area1 + area2 - intersection;
828
829        intersection as f64 / union as f64
830    }
831}
832
/// Object proposal with location and score
///
/// Coordinates and sizes are expressed in pixels of the original
/// (unscaled) image. Derives `PartialEq` so proposals can be compared
/// directly, e.g. in tests.
#[derive(Clone, Debug, PartialEq)]
pub struct ObjectProposal {
    /// Column of the box's left edge
    pub x: usize,
    /// Row of the box's top edge
    pub y: usize,
    /// Box width in pixels
    pub width: usize,
    /// Box height in pixels
    pub height: usize,
    /// Objectness score of the box (capped at 1.0 by the generator)
    pub score: f64,
    /// Index of the pyramid level at which the box was found
    pub scale: usize,
}
843
844#[cfg(test)]
845mod tests {
846    use super::*;
847    use scirs2_core::ndarray::arr2;
848
849    #[test]
850    fn test_learned_edge_detector() {
851        // Create a simple test image
852        let image = arr2(&[
853            [0.0, 0.0, 1.0, 1.0],
854            [0.0, 0.0, 1.0, 1.0],
855            [1.0, 1.0, 0.0, 0.0],
856            [1.0, 1.0, 0.0, 0.0],
857        ]);
858
859        let detector = LearnedEdgeDetector::new(None, None);
860        let edges = detector
861            .detect_edges(&image.view())
862            .expect("Operation failed");
863
864        assert_eq!(edges.dim(), image.dim());
865        // Should detect edges at boundaries
866        assert!(edges[[1, 2]] > 0.0 || edges[[2, 1]] > 0.0);
867    }
868
869    #[test]
870    fn test_keypoint_descriptor() {
871        let image = arr2(&[
872            [0.1, 0.2, 0.3, 0.4, 0.5],
873            [0.2, 0.3, 0.4, 0.5, 0.6],
874            [0.3, 0.4, 0.5, 0.6, 0.7],
875            [0.4, 0.5, 0.6, 0.7, 0.8],
876            [0.5, 0.6, 0.7, 0.8, 0.9],
877        ]);
878
879        let descriptor = LearnedKeypointDescriptor::new(3, 16);
880        let keypoints = vec![(2.0, 2.0)];
881
882        let descriptors = descriptor
883            .extract_descriptors(&image.view(), &keypoints)
884            .expect("Operation failed");
885
886        assert_eq!(descriptors.len(), 1);
887        assert_eq!(descriptors[0].len(), 16);
888
889        // Check normalization
890        let norm = descriptors[0].dot(&descriptors[0]).sqrt();
891        assert!((norm - 1.0).abs() < 1e-6);
892    }
893
894    #[test]
895    fn test_semantic_feature_extractor() {
896        let image = arr2(&[
897            [0.0, 0.0, 1.0, 1.0],
898            [0.0, 0.0, 1.0, 1.0],
899            [1.0, 1.0, 0.0, 0.0],
900            [1.0, 1.0, 0.0, 0.0],
901        ]);
902
903        let mut extractor = SemanticFeatureExtractor::new(None);
904        let features = extractor
905            .extractfeatures(&image.view(), &["texture", "shape", "color"])
906            .expect("Operation failed");
907
908        assert_eq!(features.len(), 3);
909        assert!(features.contains_key("texture"));
910        assert!(features.contains_key("shape"));
911        assert!(features.contains_key("color"));
912
913        let texturefeatures = &features["texture"];
914        assert_eq!(texturefeatures.dim().0, 4);
915        assert_eq!(texturefeatures.dim().1, 4);
916        assert!(texturefeatures.dim().2 > 0);
917    }
918
919    #[test]
920    fn test_object_proposal_generator() {
921        let mut image = Array2::zeros((50, 50));
922
923        // Create a simple rectangle
924        for i in 10..30 {
925            for j in 15..35 {
926                image[[i, j]] = 1.0;
927            }
928        }
929
930        let generator = ObjectProposalGenerator::new(None);
931        let proposals = generator
932            .generate_proposals(&image.view(), None)
933            .expect("Operation failed");
934
935        assert!(!proposals.is_empty());
936
937        // Check that proposals have valid dimensions
938        for proposal in &proposals {
939            assert!(proposal.x + proposal.width <= 50);
940            assert!(proposal.y + proposal.height <= 50);
941            assert!(proposal.score >= 0.0 && proposal.score <= 1.0);
942        }
943    }
944}