// sklears_neighbors/computer_vision.rs

1//! Computer Vision Neighbor-Based Methods
2//!
3//! This module provides specialized neighbor-based algorithms for computer vision applications,
4//! including image similarity search, patch-based matching, feature descriptor analysis,
5//! and content-based image retrieval. These methods leverage the existing neighbor infrastructure
6//! with computer vision specific optimizations and feature extractors.
7//!
8//! # Key Features
9//!
10//! - **Image Similarity Search**: Efficient similarity search using various image features
11//! - **Patch-Based Neighbors**: Local patch matching for texture analysis and object detection
12//! - **Feature Descriptor Matching**: SIFT, SURF, ORB, and other descriptor-based matching
13//! - **Visual Word Recognition**: Bag-of-visual-words for image categorization
14//! - **Content-Based Image Retrieval**: Full CBIR system with multiple feature types
15//!
16//! # Examples
17//!
18//! ```rust
19//! use sklears_neighbors::computer_vision::{ImageSimilaritySearch, ImageSearchConfig, FeatureType};
20//! use scirs2_core::ndarray::{Array1, Array2};
21//!
22//! # fn example() -> Result<(), Box<dyn std::error::Error>> {
23//! // Create image similarity search
24//! let config = ImageSearchConfig {
25//!     feature_type: FeatureType::ColorHistogram,
26//!     ..Default::default()
27//! };
28//! let mut search = ImageSimilaritySearch::new(config);
29//!
30//! // Add images to index
31//! let features = Array2::zeros((100, 512)); // 100 images with 512-dim features
32//! search.build_index(&features)?;
33//!
34//! // Search for similar images
35//! let query = Array1::zeros(512);
36//! let results = search.search(&query, 5)?;
37//! # Ok(())
38//! # }
39//! ```
40
41use crate::{knn::KNeighborsClassifier, NeighborsError, NeighborsResult};
42use scirs2_core::ndarray::{Array1, Array2, ArrayView1, Axis};
43use scirs2_core::random::thread_rng;
44use sklears_core::traits::{Fit, Predict, Trained};
45use sklears_core::types::{Features, Float, Int};
46use std::collections::HashMap;
47
48#[cfg(feature = "serde")]
49use serde::{Deserialize, Serialize};
50
/// Image feature types for similarity computation.
///
/// Selects which [`FeatureExtractor`] implementation is constructed by
/// `ImageSimilaritySearch::new`. Note that the Gabor, edge, GLCM and
/// multi-modal variants are currently mapped to placeholder extractors.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum FeatureType {
    /// RGB color histogram features
    ColorHistogram,
    /// Local Binary Pattern (LBP) texture features
    LocalBinaryPattern,
    /// Histogram of Oriented Gradients (HOG) features
    HistogramOfGradients,
    /// Gabor filter bank responses (currently a placeholder extractor)
    GaborFilters,
    /// Edge density and orientation features (currently a placeholder extractor)
    EdgeFeatures,
    /// GLCM (Gray-Level Co-occurrence Matrix) features (currently a placeholder extractor)
    TextureFeatures,
    /// Combined multi-modal features (currently a placeholder extractor)
    MultiModal,
}
70
/// Configuration for image similarity search.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct ImageSearchConfig {
    /// Feature type to use (selects the extractor in `ImageSimilaritySearch::new`)
    pub feature_type: FeatureType,
    /// Number of neighbors to retrieve
    pub k_neighbors: usize,
    /// Distance metric for similarity computation
    // NOTE(review): the current brute-force `search` always computes
    // Euclidean distance regardless of this setting — confirm intent.
    pub distance_metric: String,
    /// Enable feature normalization
    // NOTE(review): not consulted by any code visible in this module — confirm.
    pub normalize_features: bool,
    /// Feature dimensionality reduction target
    // NOTE(review): not consulted by any code visible in this module — confirm.
    pub target_dimensions: Option<usize>,
    /// Use approximate search for large datasets
    // NOTE(review): not consulted by any code visible in this module — confirm.
    pub use_approximate: bool,
}
88
89impl Default for ImageSearchConfig {
90    fn default() -> Self {
91        Self {
92            feature_type: FeatureType::ColorHistogram,
93            k_neighbors: 5,
94            distance_metric: "euclidean".to_string(),
95            normalize_features: true,
96            target_dimensions: None,
97            use_approximate: false,
98        }
99    }
100}
101
/// Image similarity search engine.
///
/// Pairs a pluggable [`FeatureExtractor`] with a brute-force feature
/// database; the database is rebuilt on every `add_image` call.
pub struct ImageSimilaritySearch {
    /// Search configuration (feature type, k, metric, ...)
    config: ImageSearchConfig,
    /// Extractor chosen from `config.feature_type`
    feature_extractor: Box<dyn FeatureExtractor>,
    /// Currently unused: `build_index` always resets this to `None`
    neighbor_index: Option<KNeighborsClassifier<Trained>>,
    /// One feature row per indexed image, used by the brute-force `search`
    feature_database: Option<Array2<f64>>,
    /// Per-row metadata; may be shorter than the database when images were
    /// indexed via `build_index` on a bare feature matrix
    image_metadata: Vec<ImageMetadata>,
}
110
/// Metadata for an indexed image.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct ImageMetadata {
    /// Unique image identifier
    pub id: String,
    /// Image dimensions (width, height)
    pub dimensions: (usize, usize),
    /// Number of color channels
    pub channels: usize,
    /// Optional image path or URL
    pub path: Option<String>,
    /// Custom metadata tags (free-form key/value pairs)
    pub tags: HashMap<String, String>,
}
126
/// Search result for image similarity.
#[derive(Debug, Clone)]
pub struct ImageSearchResult {
    /// Metadata of the matched image (synthesized as `image_<idx>` when the
    /// row was indexed without metadata)
    pub metadata: ImageMetadata,
    /// Similarity distance (lower = more similar; Euclidean in the current
    /// brute-force search)
    pub distance: f64,
    /// Feature vector of the matched database row
    pub features: Array1<f64>,
    /// Match confidence in [0, 1], computed relative to the worst distance
    /// among the returned results
    pub confidence: f64,
}
139
/// Feature extractor trait for different image feature types.
///
/// Implementations in this module read pixel data laid out row-major with
/// interleaved channels, i.e. index = `(y * width + x) * channels + c`.
pub trait FeatureExtractor: Send + Sync {
    /// Extract features from image data
    ///
    /// # Arguments
    /// * `image_data` - Flattened image data (height * width * channels),
    ///   channel-interleaved
    /// * `dimensions` - Image dimensions (width, height, channels)
    ///
    /// # Errors
    /// Implementations return `NeighborsError::InvalidInput` when the data
    /// length is inconsistent with `dimensions`.
    fn extract_features(
        &self,
        image_data: &[f64],
        dimensions: (usize, usize, usize),
    ) -> NeighborsResult<Array1<f64>>;

    /// Get the dimensionality of extracted features.
    // NOTE(review): some implementations (e.g. HOG) return an approximation
    // that depends on image size — confirm before relying on this value.
    fn feature_dimension(&self) -> usize;

    /// Get feature extractor name
    fn name(&self) -> &str;
}
159
/// Color histogram feature extractor.
///
/// Bins each channel's intensities (expected already scaled to `[0, 1]`)
/// into `bins_per_channel` buckets and concatenates the per-channel
/// histograms.
#[derive(Debug, Clone)]
pub struct ColorHistogramExtractor {
    /// Number of histogram buckets per color channel
    bins_per_channel: usize,
    /// If `true`, the histogram is L1-normalized so it sums to 1
    normalize: bool,
}

impl ColorHistogramExtractor {
    /// Create a new extractor.
    ///
    /// * `bins_per_channel` - histogram buckets per channel (must be > 0)
    /// * `normalize` - if `true`, the histogram is L1-normalized
    pub fn new(bins_per_channel: usize, normalize: bool) -> Self {
        Self {
            bins_per_channel,
            normalize,
        }
    }
}
174
175impl FeatureExtractor for ColorHistogramExtractor {
176    fn extract_features(
177        &self,
178        image_data: &[f64],
179        dimensions: (usize, usize, usize),
180    ) -> NeighborsResult<Array1<f64>> {
181        let (width, height, channels) = dimensions;
182
183        if image_data.len() != width * height * channels {
184            return Err(NeighborsError::InvalidInput(format!(
185                "Image data length {} doesn't match dimensions {}x{}x{}",
186                image_data.len(),
187                width,
188                height,
189                channels
190            )));
191        }
192
193        let total_bins = self.bins_per_channel * channels;
194        let mut histogram = Array1::zeros(total_bins);
195
196        // Compute histogram for each channel
197        for c in 0..channels {
198            for i in 0..(width * height) {
199                let pixel_idx = i * channels + c;
200                let pixel_value = image_data[pixel_idx];
201
202                // Normalize pixel value to [0, 1] range and compute bin
203                let normalized = pixel_value.clamp(0.0, 1.0);
204                let bin = ((normalized * self.bins_per_channel as f64) as usize)
205                    .min(self.bins_per_channel - 1);
206
207                histogram[c * self.bins_per_channel + bin] += 1.0;
208            }
209        }
210
211        // Normalize histogram if requested
212        if self.normalize {
213            let sum: f64 = histogram.sum();
214            if sum > 0.0 {
215                histogram /= sum;
216            }
217        }
218
219        Ok(histogram)
220    }
221
222    fn feature_dimension(&self) -> usize {
223        self.bins_per_channel * 3 // Assuming RGB images
224    }
225
226    fn name(&self) -> &str {
227        "ColorHistogram"
228    }
229}
230
/// Local Binary Pattern (LBP) texture feature extractor.
///
/// Thresholds a circle of sampled neighbors against the center pixel to
/// form a binary code per pixel; codes are later accumulated into a
/// histogram.
pub struct LocalBinaryPatternExtractor {
    /// Sampling circle radius in pixels
    radius: usize,
    /// Number of points sampled on the circle
    neighbors: usize,
    /// Collapse raw codes into the "uniform pattern" code space
    uniform_patterns: bool,
}

impl LocalBinaryPatternExtractor {
    /// Create a new LBP extractor with the given sampling geometry.
    pub fn new(radius: usize, neighbors: usize, uniform_patterns: bool) -> Self {
        Self {
            radius,
            neighbors,
            uniform_patterns,
        }
    }

    /// Pack neighbor/center comparisons into a binary code: bit `i` is set
    /// when neighbor `i` is at least as bright as the center pixel.
    fn compute_lbp_value(&self, center: f64, neighbors: &[f64]) -> u32 {
        let code = neighbors
            .iter()
            .enumerate()
            .filter(|&(_, &value)| value >= center)
            .fold(0u32, |acc, (i, _)| acc | (1 << i));

        // Optionally collapse into the uniform-pattern code space.
        if self.uniform_patterns {
            self.to_uniform_pattern(code)
        } else {
            code
        }
    }

    /// Map a raw LBP code into uniform-pattern bins: codes with at most two
    /// circular 0/1 transitions keep their popcount as the bin index, while
    /// every other code shares one extra "non-uniform" bin.
    fn to_uniform_pattern(&self, code: u32) -> u32 {
        // Count 0/1 transitions around the circular bit pattern.
        let transitions = (0..self.neighbors)
            .filter(|&i| {
                let current = (code >> i) & 1;
                let next = (code >> ((i + 1) % self.neighbors)) & 1;
                current != next
            })
            .count();

        if transitions <= 2 {
            code.count_ones()
        } else {
            // All non-uniform patterns share a single bin.
            self.neighbors as u32 + 1
        }
    }
}
285
286impl FeatureExtractor for LocalBinaryPatternExtractor {
287    fn extract_features(
288        &self,
289        image_data: &[f64],
290        dimensions: (usize, usize, usize),
291    ) -> NeighborsResult<Array1<f64>> {
292        let (width, height, channels) = dimensions;
293
294        if image_data.len() != width * height * channels {
295            return Err(NeighborsError::InvalidInput(
296                "Image data length doesn't match dimensions".to_string(),
297            ));
298        }
299
300        // Convert to grayscale if needed
301        let grayscale: Vec<f64> = if channels == 1 {
302            image_data.to_vec()
303        } else {
304            // Convert RGB to grayscale using luminance formula
305            (0..(width * height))
306                .map(|i| {
307                    let r = image_data[i * channels];
308                    let g = image_data[i * channels + 1];
309                    let b = image_data[i * channels + 2];
310                    0.299 * r + 0.587 * g + 0.114 * b
311                })
312                .collect()
313        };
314
315        let max_pattern = if self.uniform_patterns {
316            self.neighbors + 2
317        } else {
318            1 << self.neighbors
319        };
320
321        let mut histogram = Array1::zeros(max_pattern);
322
323        // Process each pixel (excluding border pixels)
324        for y in self.radius..(height - self.radius) {
325            for x in self.radius..(width - self.radius) {
326                let center_idx = y * width + x;
327                let center_value = grayscale[center_idx];
328
329                // Sample neighbors in circular pattern
330                let mut neighbor_values = Vec::with_capacity(self.neighbors);
331                for i in 0..self.neighbors {
332                    let angle = 2.0 * std::f64::consts::PI * i as f64 / self.neighbors as f64;
333                    let nx = x as f64 + self.radius as f64 * angle.cos();
334                    let ny = y as f64 + self.radius as f64 * angle.sin();
335
336                    // Bilinear interpolation for non-integer coordinates
337                    let x1 = nx.floor() as usize;
338                    let y1 = ny.floor() as usize;
339                    let x2 = (x1 + 1).min(width - 1);
340                    let y2 = (y1 + 1).min(height - 1);
341
342                    let fx = nx - x1 as f64;
343                    let fy = ny - y1 as f64;
344
345                    let v1 = grayscale[y1 * width + x1];
346                    let v2 = grayscale[y1 * width + x2];
347                    let v3 = grayscale[y2 * width + x1];
348                    let v4 = grayscale[y2 * width + x2];
349
350                    let interpolated = v1 * (1.0 - fx) * (1.0 - fy)
351                        + v2 * fx * (1.0 - fy)
352                        + v3 * (1.0 - fx) * fy
353                        + v4 * fx * fy;
354
355                    neighbor_values.push(interpolated);
356                }
357
358                // Compute LBP code
359                let lbp_code = self.compute_lbp_value(center_value, &neighbor_values);
360                histogram[lbp_code as usize] += 1.0;
361            }
362        }
363
364        // Normalize histogram
365        let sum: f64 = histogram.sum();
366        if sum > 0.0 {
367            histogram /= sum;
368        }
369
370        Ok(histogram)
371    }
372
373    fn feature_dimension(&self) -> usize {
374        if self.uniform_patterns {
375            self.neighbors + 2
376        } else {
377            1 << self.neighbors
378        }
379    }
380
381    fn name(&self) -> &str {
382        "LocalBinaryPattern"
383    }
384}
385
/// Histogram of Oriented Gradients (HOG) feature extractor.
pub struct HistogramOfGradientsExtractor {
    /// Side length in pixels of one histogram cell
    cell_size: usize,
    /// Side length in cells of one normalization block
    block_size: usize,
    /// Orientation bins per cell over [0, π]
    num_bins: usize,
    /// L2-normalize each block descriptor
    normalize_blocks: bool,
}

impl HistogramOfGradientsExtractor {
    /// Create a HOG extractor with the given cell/block geometry.
    pub fn new(
        cell_size: usize,
        block_size: usize,
        num_bins: usize,
        normalize_blocks: bool,
    ) -> Self {
        Self {
            cell_size,
            block_size,
            num_bins,
            normalize_blocks,
        }
    }

    /// Per-pixel gradient magnitude and orientation via central differences.
    ///
    /// Border pixels are left at zero; orientations are folded into [0, π].
    fn compute_gradients(
        &self,
        image: &[f64],
        width: usize,
        height: usize,
    ) -> (Vec<f64>, Vec<f64>) {
        let mut magnitudes = vec![0.0; width * height];
        let mut orientations = vec![0.0; width * height];

        for row in 1..(height - 1) {
            for col in 1..(width - 1) {
                let idx = row * width + col;

                // Central differences in x and y.
                let dx = image[row * width + (col + 1)] - image[row * width + (col - 1)];
                let dy = image[(row + 1) * width + col] - image[(row - 1) * width + col];

                magnitudes[idx] = (dx * dx + dy * dy).sqrt();

                // Fold negative angles into [0, π] (unsigned gradients).
                let mut theta = dy.atan2(dx);
                if theta < 0.0 {
                    theta += std::f64::consts::PI;
                }
                orientations[idx] = theta;
            }
        }

        (magnitudes, orientations)
    }
}
440
441impl FeatureExtractor for HistogramOfGradientsExtractor {
442    fn extract_features(
443        &self,
444        image_data: &[f64],
445        dimensions: (usize, usize, usize),
446    ) -> NeighborsResult<Array1<f64>> {
447        let (width, height, channels) = dimensions;
448
449        // Convert to grayscale
450        let grayscale: Vec<f64> = if channels == 1 {
451            image_data.to_vec()
452        } else {
453            (0..(width * height))
454                .map(|i| {
455                    let r = image_data[i * channels];
456                    let g = image_data[i * channels + 1];
457                    let b = image_data[i * channels + 2];
458                    0.299 * r + 0.587 * g + 0.114 * b
459                })
460                .collect()
461        };
462
463        // Compute gradients
464        let (magnitudes, orientations) = self.compute_gradients(&grayscale, width, height);
465
466        // Calculate number of cells
467        let cells_x = width / self.cell_size;
468        let cells_y = height / self.cell_size;
469
470        // Compute cell histograms
471        let mut cell_histograms = vec![Array1::zeros(self.num_bins); cells_x * cells_y];
472
473        for cell_y in 0..cells_y {
474            for cell_x in 0..cells_x {
475                let cell_idx = cell_y * cells_x + cell_x;
476
477                // Process pixels in current cell
478                for y in (cell_y * self.cell_size)..((cell_y + 1) * self.cell_size) {
479                    for x in (cell_x * self.cell_size)..((cell_x + 1) * self.cell_size) {
480                        if y < height && x < width {
481                            let pixel_idx = y * width + x;
482                            let magnitude = magnitudes[pixel_idx];
483                            let orientation = orientations[pixel_idx];
484
485                            // Compute bin for orientation
486                            let bin_width = std::f64::consts::PI / self.num_bins as f64;
487                            let bin = ((orientation / bin_width) as usize).min(self.num_bins - 1);
488
489                            // Add weighted vote to histogram
490                            cell_histograms[cell_idx][bin] += magnitude;
491                        }
492                    }
493                }
494            }
495        }
496
497        // Create block descriptors
498        let blocks_x = cells_x.saturating_sub(self.block_size - 1);
499        let blocks_y = cells_y.saturating_sub(self.block_size - 1);
500        let descriptor_size =
501            blocks_x * blocks_y * self.block_size * self.block_size * self.num_bins;
502
503        let mut descriptor = Array1::zeros(descriptor_size);
504        let mut desc_idx = 0;
505
506        for block_y in 0..blocks_y {
507            for block_x in 0..blocks_x {
508                // Collect histograms from cells in current block
509                let mut block_hist =
510                    Array1::zeros(self.block_size * self.block_size * self.num_bins);
511                let mut hist_idx = 0;
512
513                for by in 0..self.block_size {
514                    for bx in 0..self.block_size {
515                        let cell_x = block_x + bx;
516                        let cell_y = block_y + by;
517                        let cell_idx = cell_y * cells_x + cell_x;
518
519                        for bin in 0..self.num_bins {
520                            block_hist[hist_idx * self.num_bins + bin] =
521                                cell_histograms[cell_idx][bin];
522                        }
523                        hist_idx += 1;
524                    }
525                }
526
527                // Normalize block if requested
528                if self.normalize_blocks {
529                    let norm = (block_hist.mapv(|x: f64| x * x).sum()).sqrt();
530                    if norm > 1e-8 {
531                        block_hist /= norm;
532                    }
533                }
534
535                // Copy to descriptor
536                for i in 0..block_hist.len() {
537                    descriptor[desc_idx + i] = block_hist[i];
538                }
539                desc_idx += block_hist.len();
540            }
541        }
542
543        Ok(descriptor)
544    }
545
546    fn feature_dimension(&self) -> usize {
547        // This is an approximation - actual size depends on image dimensions
548        self.num_bins * self.block_size * self.block_size * 16
549    }
550
551    fn name(&self) -> &str {
552        "HistogramOfGradients"
553    }
554}
555
556impl ImageSimilaritySearch {
557    /// Create a new image similarity search engine
558    pub fn new(config: ImageSearchConfig) -> Self {
559        let feature_extractor: Box<dyn FeatureExtractor> = match config.feature_type {
560            FeatureType::ColorHistogram => Box::new(ColorHistogramExtractor::new(32, true)),
561            FeatureType::LocalBinaryPattern => {
562                Box::new(LocalBinaryPatternExtractor::new(1, 8, true))
563            }
564            FeatureType::HistogramOfGradients => {
565                Box::new(HistogramOfGradientsExtractor::new(8, 2, 9, true))
566            }
567            FeatureType::GaborFilters => Box::new(ColorHistogramExtractor::new(32, true)), // Placeholder
568            FeatureType::EdgeFeatures => Box::new(ColorHistogramExtractor::new(32, true)), // Placeholder
569            FeatureType::TextureFeatures => Box::new(ColorHistogramExtractor::new(32, true)), // Placeholder
570            FeatureType::MultiModal => Box::new(ColorHistogramExtractor::new(32, true)), // Placeholder
571        };
572
573        Self {
574            config,
575            feature_extractor,
576            neighbor_index: None,
577            feature_database: None,
578            image_metadata: Vec::new(),
579        }
580    }
581
582    /// Build index from precomputed features
583    pub fn build_index(&mut self, features: &Array2<f64>) -> NeighborsResult<()> {
584        if features.is_empty() {
585            return Err(NeighborsError::EmptyInput);
586        }
587
588        // Simply store the feature database - we'll use brute-force search for now
589        self.feature_database = Some(features.clone());
590        self.neighbor_index = None; // Not using the classifier for now
591
592        Ok(())
593    }
594
595    /// Add image to the search index
596    pub fn add_image(
597        &mut self,
598        image_data: &[f64],
599        dimensions: (usize, usize, usize),
600        metadata: ImageMetadata,
601    ) -> NeighborsResult<()> {
602        // Extract features from image
603        let features = self
604            .feature_extractor
605            .extract_features(image_data, dimensions)?;
606
607        // Add metadata first
608        self.image_metadata.push(metadata);
609
610        // Add to feature database
611        if let Some(ref mut db) = self.feature_database {
612            // Extend existing database by creating a new matrix
613            let mut new_db = Array2::zeros((db.nrows() + 1, db.ncols()));
614            new_db
615                .slice_mut(scirs2_core::ndarray::s![..db.nrows(), ..])
616                .assign(db);
617            new_db.row_mut(db.nrows()).assign(&features);
618            *db = new_db;
619
620            // Clone the database to avoid borrowing issues
621            let db_clone = db.clone();
622            self.build_index(&db_clone)?;
623        } else {
624            // Create new database
625            let new_db = features.insert_axis(Axis(0));
626            let db_clone = new_db.clone();
627            self.feature_database = Some(new_db);
628            self.build_index(&db_clone)?;
629        }
630
631        Ok(())
632    }
633
634    /// Search for similar images using brute-force search
635    pub fn search(
636        &self,
637        query_features: &Array1<f64>,
638        k: usize,
639    ) -> NeighborsResult<Vec<ImageSearchResult>> {
640        let database = self
641            .feature_database
642            .as_ref()
643            .ok_or(NeighborsError::InvalidInput(
644                "Feature database not available".to_string(),
645            ))?;
646
647        if database.is_empty() {
648            return Ok(Vec::new());
649        }
650
651        // Compute distances to all database features
652        let mut distances_with_indices: Vec<(f64, usize)> = Vec::new();
653
654        for (idx, db_features) in database.rows().into_iter().enumerate() {
655            // Compute Euclidean distance
656            let distance: f64 = query_features
657                .iter()
658                .zip(db_features.iter())
659                .map(|(a, b)| (a - b).powi(2))
660                .sum::<f64>()
661                .sqrt();
662
663            distances_with_indices.push((distance, idx));
664        }
665
666        // Sort by distance and take top k
667        distances_with_indices.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());
668        let k_results = k.min(distances_with_indices.len());
669
670        let mut results = Vec::new();
671        let max_distance = if !distances_with_indices.is_empty() {
672            distances_with_indices[distances_with_indices.len().min(k) - 1].0
673        } else {
674            1.0
675        };
676
677        for i in 0..k_results {
678            let (distance, idx) = distances_with_indices[i];
679
680            let features = database.row(idx).to_owned();
681
682            // Compute confidence score (inverse of normalized distance)
683            let confidence = if max_distance > 0.0 {
684                1.0 - (distance / max_distance)
685            } else {
686                1.0
687            };
688
689            // Create metadata if not available
690            let metadata = if idx < self.image_metadata.len() {
691                self.image_metadata[idx].clone()
692            } else {
693                ImageMetadata {
694                    id: format!("image_{}", idx),
695                    dimensions: (0, 0),
696                    channels: 0,
697                    path: None,
698                    tags: HashMap::new(),
699                }
700            };
701
702            results.push(ImageSearchResult {
703                metadata,
704                distance,
705                features,
706                confidence,
707            });
708        }
709
710        Ok(results)
711    }
712
713    /// Search by image data
714    pub fn search_by_image(
715        &self,
716        image_data: &[f64],
717        dimensions: (usize, usize, usize),
718        k: usize,
719    ) -> NeighborsResult<Vec<ImageSearchResult>> {
720        let features = self
721            .feature_extractor
722            .extract_features(image_data, dimensions)?;
723        self.search(&features, k)
724    }
725
726    /// Get database statistics
727    pub fn get_stats(&self) -> (usize, usize, String) {
728        let num_images = self.image_metadata.len();
729        let feature_dim = self
730            .feature_database
731            .as_ref()
732            .map(|db| db.ncols())
733            .unwrap_or(0);
734        let extractor_name = self.feature_extractor.name().to_string();
735
736        (num_images, feature_dim, extractor_name)
737    }
738}
739
/// Patch-based neighbor matching for texture analysis.
///
/// Slides a window over images, extracts a feature vector per patch, and
/// answers nearest-patch queries through an internal `ImageSimilaritySearch`.
pub struct PatchBasedMatching {
    /// Side length of each (square) patch in pixels
    patch_size: usize,
    /// Step in pixels between successive patch origins
    stride: usize,
    /// Feature extractor applied to each patch
    feature_extractor: Box<dyn FeatureExtractor>,
    /// Brute-force search over the patch feature database
    neighbor_search: Option<ImageSimilaritySearch>,
}
747
748impl PatchBasedMatching {
749    /// Create new patch-based matching system
750    pub fn new(patch_size: usize, stride: usize, feature_type: FeatureType) -> Self {
751        let config = ImageSearchConfig {
752            feature_type,
753            k_neighbors: 10,
754            ..Default::default()
755        };
756
757        let feature_extractor: Box<dyn FeatureExtractor> = match feature_type {
758            FeatureType::ColorHistogram => Box::new(ColorHistogramExtractor::new(16, true)),
759            FeatureType::LocalBinaryPattern => {
760                Box::new(LocalBinaryPatternExtractor::new(1, 8, true))
761            }
762            FeatureType::HistogramOfGradients => {
763                Box::new(HistogramOfGradientsExtractor::new(4, 1, 9, true))
764            }
765            _ => Box::new(LocalBinaryPatternExtractor::new(1, 8, true)),
766        };
767
768        Self {
769            patch_size,
770            stride,
771            feature_extractor,
772            neighbor_search: Some(ImageSimilaritySearch::new(config)),
773        }
774    }
775
776    /// Extract patches from image
777    pub fn extract_patches(
778        &self,
779        image_data: &[f64],
780        dimensions: (usize, usize, usize),
781    ) -> NeighborsResult<Vec<(Array1<f64>, (usize, usize))>> {
782        let (width, height, channels) = dimensions;
783        let mut patches = Vec::new();
784
785        // Extract patches with sliding window
786        for y in (0..(height.saturating_sub(self.patch_size))).step_by(self.stride) {
787            for x in (0..(width.saturating_sub(self.patch_size))).step_by(self.stride) {
788                // Extract patch data
789                let mut patch_data = Vec::new();
790
791                for py in y..(y + self.patch_size).min(height) {
792                    for px in x..(x + self.patch_size).min(width) {
793                        for c in 0..channels {
794                            let idx = py * width * channels + px * channels + c;
795                            patch_data.push(image_data[idx]);
796                        }
797                    }
798                }
799
800                // Extract features from patch
801                let patch_dims = (self.patch_size, self.patch_size, channels);
802                let features = self
803                    .feature_extractor
804                    .extract_features(&patch_data, patch_dims)?;
805                patches.push((features, (x, y)));
806            }
807        }
808
809        Ok(patches)
810    }
811
812    /// Build patch database from multiple images
813    pub fn build_patch_database(
814        &mut self,
815        images: &[(Vec<f64>, (usize, usize, usize))],
816    ) -> NeighborsResult<()> {
817        let mut all_patch_features = Vec::new();
818
819        for (image_data, dimensions) in images {
820            let patches = self.extract_patches(image_data, *dimensions)?;
821            for (features, _position) in patches {
822                all_patch_features.push(features);
823            }
824        }
825
826        if all_patch_features.is_empty() {
827            return Err(NeighborsError::EmptyInput);
828        }
829
830        // Convert to Array2
831        let num_patches = all_patch_features.len();
832        let feature_dim = all_patch_features[0].len();
833        let mut feature_matrix = Array2::zeros((num_patches, feature_dim));
834
835        for (i, features) in all_patch_features.into_iter().enumerate() {
836            feature_matrix.row_mut(i).assign(&features);
837        }
838
839        // Build search index
840        if let Some(ref mut search) = self.neighbor_search {
841            search.build_index(&feature_matrix)?;
842        }
843
844        Ok(())
845    }
846
847    /// Find similar patches
848    pub fn find_similar_patches(
849        &self,
850        query_patch: &[f64],
851        patch_dimensions: (usize, usize, usize),
852        k: usize,
853    ) -> NeighborsResult<Vec<ImageSearchResult>> {
854        let features = self
855            .feature_extractor
856            .extract_features(query_patch, patch_dimensions)?;
857
858        if let Some(ref search) = self.neighbor_search {
859            search.search(&features, k)
860        } else {
861            Err(NeighborsError::InvalidInput(
862                "Patch database not built".to_string(),
863            ))
864        }
865    }
866}
867
/// Feature descriptor matching (SIFT-like).
pub struct FeatureDescriptorMatcher {
    /// Length of each descriptor vector
    descriptor_dimension: usize,
    /// Maximum distance for a match to be accepted
    // NOTE(review): where this threshold is applied is not visible in this
    // chunk — confirm against the full match_descriptors body.
    match_threshold: f64,
    /// Ratio-test threshold (presumably best/second-best distance, à la
    /// Lowe's SIFT test — confirm)
    ratio_test_threshold: f64,
    /// Whether to require mutual (query↔train) best matches
    // NOTE(review): consumption not visible in this chunk — confirm.
    use_cross_check: bool,
}
875
/// Keypoint with descriptor.
#[derive(Debug, Clone)]
pub struct Keypoint {
    /// 2D coordinates (x, y)
    pub position: (f64, f64),
    /// Scale/size of the keypoint
    pub scale: f64,
    /// Orientation in radians
    pub orientation: f64,
    /// Feature descriptor vector (length must match the matcher's
    /// `descriptor_dimension`)
    pub descriptor: Array1<f64>,
    /// Response strength (detector confidence)
    pub response: f64,
}
890
/// Descriptor match between two keypoints.
#[derive(Debug, Clone)]
pub struct DescriptorMatch {
    /// Index into the query keypoint set
    pub query_idx: usize,
    /// Index into the train keypoint set
    pub train_idx: usize,
    /// Distance between the two descriptors (lower = better)
    pub distance: f64,
    /// Match confidence
    pub confidence: f64,
}
903
904impl FeatureDescriptorMatcher {
905    /// Create new descriptor matcher
906    pub fn new(descriptor_dimension: usize) -> Self {
907        Self {
908            descriptor_dimension,
909            match_threshold: 0.8,
910            ratio_test_threshold: 0.7,
911            use_cross_check: true,
912        }
913    }
914
915    /// Match descriptors between two sets of keypoints
916    pub fn match_descriptors(
917        &self,
918        query_keypoints: &[Keypoint],
919        train_keypoints: &[Keypoint],
920    ) -> NeighborsResult<Vec<DescriptorMatch>> {
921        if query_keypoints.is_empty() || train_keypoints.is_empty() {
922            return Ok(Vec::new());
923        }
924
925        let mut matches = Vec::new();
926
927        // Build descriptor matrices
928        let query_descriptors = Array2::from_shape_fn(
929            (query_keypoints.len(), self.descriptor_dimension),
930            |(i, j)| query_keypoints[i].descriptor[j],
931        );
932
933        let train_descriptors = Array2::from_shape_fn(
934            (train_keypoints.len(), self.descriptor_dimension),
935            |(i, j)| train_keypoints[i].descriptor[j],
936        );
937
938        // For each query descriptor, find best matches
939        for (query_idx, query_desc) in query_descriptors.rows().into_iter().enumerate() {
940            let mut distances: Vec<(usize, f64)> = Vec::new();
941
942            // Compute distances to all train descriptors
943            for (train_idx, train_desc) in train_descriptors.rows().into_iter().enumerate() {
944                let distance = self.compute_descriptor_distance(&query_desc, &train_desc);
945                distances.push((train_idx, distance));
946            }
947
948            // Sort by distance
949            distances.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
950
951            // Apply ratio test if we have at least 2 matches
952            if distances.len() >= 2 {
953                let best_distance = distances[0].1;
954                let second_best_distance = distances[1].1;
955
956                if second_best_distance > 0.0
957                    && (best_distance / second_best_distance) < self.ratio_test_threshold
958                {
959                    let train_idx = distances[0].0;
960                    let confidence = 1.0 - (best_distance / second_best_distance);
961
962                    matches.push(DescriptorMatch {
963                        query_idx,
964                        train_idx,
965                        distance: best_distance,
966                        confidence,
967                    });
968                }
969            } else if !distances.is_empty() && distances[0].1 < self.match_threshold {
970                // Single match case
971                let train_idx = distances[0].0;
972                let confidence = 1.0 - distances[0].1;
973
974                matches.push(DescriptorMatch {
975                    query_idx,
976                    train_idx,
977                    distance: distances[0].1,
978                    confidence,
979                });
980            }
981        }
982
983        // Apply cross-check if enabled
984        if self.use_cross_check {
985            matches = self.apply_cross_check(matches, &query_descriptors, &train_descriptors)?;
986        }
987
988        Ok(matches)
989    }
990
991    /// Compute distance between two descriptors
992    fn compute_descriptor_distance(&self, desc1: &ArrayView1<f64>, desc2: &ArrayView1<f64>) -> f64 {
993        // Use Euclidean distance
994        desc1
995            .iter()
996            .zip(desc2.iter())
997            .map(|(a, b)| (a - b).powi(2))
998            .sum::<f64>()
999            .sqrt()
1000    }
1001
1002    /// Apply cross-check filtering
1003    fn apply_cross_check(
1004        &self,
1005        matches: Vec<DescriptorMatch>,
1006        query_descriptors: &Array2<f64>,
1007        train_descriptors: &Array2<f64>,
1008    ) -> NeighborsResult<Vec<DescriptorMatch>> {
1009        let mut filtered_matches = Vec::new();
1010
1011        for m in matches {
1012            // Check if train descriptor also matches back to query descriptor
1013            let train_desc = train_descriptors.row(m.train_idx);
1014            let mut best_query_idx = 0;
1015            let mut best_distance = f64::INFINITY;
1016
1017            for (query_idx, query_desc) in query_descriptors.rows().into_iter().enumerate() {
1018                let distance = self.compute_descriptor_distance(&train_desc, &query_desc);
1019                if distance < best_distance {
1020                    best_distance = distance;
1021                    best_query_idx = query_idx;
1022                }
1023            }
1024
1025            // If mutual best match, keep it
1026            if best_query_idx == m.query_idx {
1027                filtered_matches.push(m);
1028            }
1029        }
1030
1031        Ok(filtered_matches)
1032    }
1033}
1034
/// Visual word recognition using bag-of-visual-words
///
/// Builds a visual vocabulary from training patches via k-means clustering,
/// converts images into normalized visual-word histograms, and classifies
/// them with a k-nearest-neighbors classifier trained on those histograms.
pub struct VisualWordRecognizer {
    // Visual-word centroids, one row per word; `None` until `build_vocabulary` runs.
    vocabulary: Option<Array2<f64>>,
    // Number of visual words (k-means cluster count).
    vocabulary_size: usize,
    // Extractor that turns raw patch pixels into feature vectors.
    feature_extractor: Box<dyn FeatureExtractor>,
    // KNN classifier over BoW histograms; `None` until `train_classifier` runs.
    classifier: Option<KNeighborsClassifier<Trained>>,
}
1042
1043impl VisualWordRecognizer {
1044    /// Create new visual word recognizer
1045    pub fn new(vocabulary_size: usize) -> Self {
1046        Self {
1047            vocabulary: None,
1048            vocabulary_size,
1049            feature_extractor: Box::new(LocalBinaryPatternExtractor::new(1, 8, true)),
1050            classifier: None,
1051        }
1052    }
1053
1054    /// Build vocabulary from training patches
1055    pub fn build_vocabulary(
1056        &mut self,
1057        training_patches: &[Vec<f64>],
1058        patch_dimensions: (usize, usize, usize),
1059    ) -> NeighborsResult<()> {
1060        if training_patches.is_empty() {
1061            return Err(NeighborsError::EmptyInput);
1062        }
1063
1064        // Extract features from all patches
1065        let mut patch_features = Vec::new();
1066        for patch_data in training_patches {
1067            let features = self
1068                .feature_extractor
1069                .extract_features(patch_data, patch_dimensions)?;
1070            patch_features.push(features);
1071        }
1072
1073        // Use k-means clustering to build vocabulary (simplified version)
1074        let feature_dim = patch_features[0].len();
1075        let mut vocabulary = Array2::zeros((self.vocabulary_size, feature_dim));
1076
1077        // Initialize vocabulary with random patches
1078        let mut rng = thread_rng();
1079        for i in 0..self.vocabulary_size {
1080            let random_idx = rng.gen_range(0..patch_features.len());
1081            vocabulary.row_mut(i).assign(&patch_features[random_idx]);
1082        }
1083
1084        // Simple k-means iterations (in practice, would use more sophisticated clustering)
1085        for _iteration in 0..10 {
1086            let mut cluster_sums = vec![Array1::zeros(feature_dim); self.vocabulary_size];
1087            let mut cluster_counts = vec![0; self.vocabulary_size];
1088
1089            // Assign patches to clusters
1090            for patch_feature in &patch_features {
1091                let mut best_cluster = 0;
1092                let mut best_distance = f64::INFINITY;
1093
1094                for (cluster_idx, centroid) in vocabulary.rows().into_iter().enumerate() {
1095                    let distance: f64 = patch_feature
1096                        .iter()
1097                        .zip(centroid.iter())
1098                        .map(|(a, b)| (a - b).powi(2))
1099                        .sum::<f64>()
1100                        .sqrt();
1101
1102                    if distance < best_distance {
1103                        best_distance = distance;
1104                        best_cluster = cluster_idx;
1105                    }
1106                }
1107
1108                cluster_sums[best_cluster] = &cluster_sums[best_cluster] + patch_feature;
1109                cluster_counts[best_cluster] += 1;
1110            }
1111
1112            // Update centroids
1113            for i in 0..self.vocabulary_size {
1114                if cluster_counts[i] > 0 {
1115                    vocabulary
1116                        .row_mut(i)
1117                        .assign(&(&cluster_sums[i] / cluster_counts[i] as f64));
1118                }
1119            }
1120        }
1121
1122        self.vocabulary = Some(vocabulary);
1123        Ok(())
1124    }
1125
1126    /// Convert image to bag-of-visual-words histogram
1127    pub fn compute_bow_histogram(
1128        &self,
1129        image_patches: &[Vec<f64>],
1130        patch_dimensions: (usize, usize, usize),
1131    ) -> NeighborsResult<Array1<f64>> {
1132        let vocabulary = self
1133            .vocabulary
1134            .as_ref()
1135            .ok_or(NeighborsError::InvalidInput(
1136                "Vocabulary not built".to_string(),
1137            ))?;
1138
1139        let mut histogram = Array1::zeros(self.vocabulary_size);
1140
1141        for patch_data in image_patches {
1142            let features = self
1143                .feature_extractor
1144                .extract_features(patch_data, patch_dimensions)?;
1145
1146            // Find closest visual word
1147            let mut best_word = 0;
1148            let mut best_distance = f64::INFINITY;
1149
1150            for (word_idx, word_features) in vocabulary.rows().into_iter().enumerate() {
1151                let distance: f64 = features
1152                    .iter()
1153                    .zip(word_features.iter())
1154                    .map(|(a, b)| (a - b).powi(2))
1155                    .sum::<f64>()
1156                    .sqrt();
1157
1158                if distance < best_distance {
1159                    best_distance = distance;
1160                    best_word = word_idx;
1161                }
1162            }
1163
1164            histogram[best_word] += 1.0;
1165        }
1166
1167        // Normalize histogram
1168        let sum: f64 = histogram.sum();
1169        if sum > 0.0 {
1170            histogram /= sum;
1171        }
1172
1173        Ok(histogram)
1174    }
1175
1176    /// Train classifier on bag-of-visual-words histograms
1177    pub fn train_classifier(
1178        &mut self,
1179        bow_histograms: &Array2<f64>,
1180        labels: &Array1<i32>,
1181    ) -> NeighborsResult<()> {
1182        let classifier = KNeighborsClassifier::new(5);
1183
1184        // Convert to proper types
1185        let features: Features = bow_histograms.mapv(|x| x as Float);
1186        let target_labels: Array1<Int> = labels.mapv(|x| x as Int);
1187
1188        let trained = classifier
1189            .fit(&features, &target_labels)
1190            .map_err(|e| NeighborsError::InvalidInput(e.to_string()))?;
1191
1192        self.classifier = Some(trained);
1193        Ok(())
1194    }
1195
1196    /// Classify image using bag-of-visual-words
1197    pub fn classify_image(
1198        &self,
1199        image_patches: &[Vec<f64>],
1200        patch_dimensions: (usize, usize, usize),
1201    ) -> NeighborsResult<i32> {
1202        let classifier = self
1203            .classifier
1204            .as_ref()
1205            .ok_or(NeighborsError::InvalidInput(
1206                "Classifier not trained".to_string(),
1207            ))?;
1208
1209        let bow_histogram = self.compute_bow_histogram(image_patches, patch_dimensions)?;
1210        let bow_2d = bow_histogram.insert_axis(Axis(0));
1211
1212        // Convert to Features type
1213        let features: Features = bow_2d.mapv(|x| x as Float);
1214
1215        let predictions = classifier
1216            .predict(&features)
1217            .map_err(|e| NeighborsError::InvalidInput(e.to_string()))?;
1218
1219        Ok(predictions[0] as i32)
1220    }
1221}
1222
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    //! Unit tests for the computer-vision neighbor methods in this module.

    use super::*;

    #[test]
    fn test_color_histogram_extractor() {
        // 8 bins per channel with normalization enabled.
        let extractor = ColorHistogramExtractor::new(8, true);
        let image_data = vec![0.5; 32 * 32 * 3]; // 32x32 RGB image
        let dimensions = (32, 32, 3);

        let features = extractor.extract_features(&image_data, dimensions).unwrap();
        // One 8-bin histogram per RGB channel.
        assert_eq!(features.len(), 8 * 3);

        // Check normalization
        let sum: f64 = features.sum();
        assert!((sum - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_local_binary_pattern_extractor() {
        // Radius 1, 8 neighbors, uniform patterns.
        let extractor = LocalBinaryPatternExtractor::new(1, 8, true);
        let image_data = vec![0.5; 16 * 16]; // 16x16 grayscale image
        let dimensions = (16, 16, 1);

        let features = extractor.extract_features(&image_data, dimensions).unwrap();
        assert_eq!(features.len(), 8 + 2); // uniform patterns
    }

    #[test]
    fn test_image_similarity_search() {
        let config = ImageSearchConfig::default();
        let mut search = ImageSimilaritySearch::new(config);

        // Create dummy feature data
        let features = Array2::from_shape_fn((10, 96), |(i, j)| (i + j) as f64);
        search.build_index(&features).unwrap();

        // Verify database was built properly
        let (num_images, _feature_dim, _) = search.get_stats();
        assert_eq!(num_images, 0); // No metadata added yet
        assert_eq!(search.feature_database.as_ref().unwrap().nrows(), 10);

        // Test search: results must come back sorted by ascending distance.
        let query = Array1::from_vec((0..96).map(|i| i as f64).collect());
        let results = search.search(&query, 3).unwrap();

        assert_eq!(results.len(), 3);
        assert!(results[0].distance <= results[1].distance);
    }

    #[test]
    fn test_patch_based_matching() {
        // Patch size 8, stride 4, LBP features.
        let matcher = PatchBasedMatching::new(8, 4, FeatureType::LocalBinaryPattern);

        // Create dummy image
        let image_data = vec![0.5; 32 * 32 * 3];
        let dimensions = (32, 32, 3);

        let patches = matcher.extract_patches(&image_data, dimensions).unwrap();
        assert!(!patches.is_empty());

        // Test patch extraction positions
        // Upper bound on patch count for an 8px patch with stride 4 on a 32px side.
        let expected_patches = ((32_usize - 8) / 4 + 1).pow(2);
        assert!(patches.len() <= expected_patches);
    }

    #[test]
    fn test_feature_descriptor_matcher() {
        let matcher = FeatureDescriptorMatcher::new(128);

        // Create dummy keypoints
        let query_keypoints = vec![Keypoint {
            position: (10.0, 10.0),
            scale: 1.0,
            orientation: 0.0,
            descriptor: Array1::zeros(128),
            response: 0.5,
        }];

        let train_keypoints = vec![Keypoint {
            position: (11.0, 11.0),
            scale: 1.0,
            orientation: 0.0,
            descriptor: Array1::from_elem(128, 0.1),
            response: 0.6,
        }];

        let matches = matcher
            .match_descriptors(&query_keypoints, &train_keypoints)
            .unwrap();
        // With ratio test, may not find matches for identical descriptors
        assert!(matches.len() <= 1);
    }

    #[test]
    fn test_visual_word_recognizer() {
        // 16-word vocabulary.
        let mut recognizer = VisualWordRecognizer::new(16);

        // Create training patches
        let training_patches: Vec<Vec<f64>> =
            (0..10).map(|i| vec![i as f64 * 0.1; 8 * 8 * 3]).collect();
        let patch_dimensions = (8, 8, 3);

        recognizer
            .build_vocabulary(&training_patches, patch_dimensions)
            .unwrap();

        // Test BOW histogram computation: one entry per word, L1-normalized.
        let test_patches = vec![vec![0.5; 8 * 8 * 3]];
        let histogram = recognizer
            .compute_bow_histogram(&test_patches, patch_dimensions)
            .unwrap();

        assert_eq!(histogram.len(), 16);
        assert!((histogram.sum() - 1.0).abs() < 1e-6);
    }
}