1use crate::{knn::KNeighborsClassifier, NeighborsError, NeighborsResult};
42use scirs2_core::ndarray::{Array1, Array2, ArrayView1, Axis};
43use scirs2_core::random::thread_rng;
44use sklears_core::traits::{Fit, Predict, Trained};
45use sklears_core::types::{Features, Float, Int};
46use std::collections::HashMap;
47
48#[cfg(feature = "serde")]
49use serde::{Deserialize, Serialize};
50
/// Kinds of image features the search components in this module can be configured with.
///
/// Only `ColorHistogram`, `LocalBinaryPattern` and `HistogramOfGradients` have a
/// dedicated extractor in this file; the constructors that consume this enum
/// substitute one of those extractors for the remaining variants.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum FeatureType {
    /// Per-channel intensity histogram (`ColorHistogramExtractor`).
    ColorHistogram,
    /// Local binary pattern histogram (`LocalBinaryPatternExtractor`).
    LocalBinaryPattern,
    /// Histogram of oriented gradients (`HistogramOfGradientsExtractor`).
    HistogramOfGradients,
    /// Gabor filter responses; no dedicated extractor is defined in this file.
    GaborFilters,
    /// Edge-based features; no dedicated extractor is defined in this file.
    EdgeFeatures,
    /// Texture features; no dedicated extractor is defined in this file.
    TextureFeatures,
    /// Combination of several feature kinds; no dedicated extractor is defined in this file.
    MultiModal,
}
70
/// Configuration for [`ImageSimilaritySearch`].
///
/// NOTE(review): in the code visible in this file only `feature_type` is read
/// (when the search object picks its extractor); the other fields are stored
/// but not consulted here.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct ImageSearchConfig {
    /// Which feature extractor to instantiate.
    pub feature_type: FeatureType,
    /// Default number of neighbors; `search` takes its own explicit `k`.
    pub k_neighbors: usize,
    /// Name of the distance metric (default `"euclidean"`).
    pub distance_metric: String,
    /// Whether extracted features should be normalized.
    pub normalize_features: bool,
    /// Optional target dimensionality for the feature vectors.
    pub target_dimensions: Option<usize>,
    /// Whether approximate neighbor search is requested.
    pub use_approximate: bool,
}
88
89impl Default for ImageSearchConfig {
90 fn default() -> Self {
91 Self {
92 feature_type: FeatureType::ColorHistogram,
93 k_neighbors: 5,
94 distance_metric: "euclidean".to_string(),
95 normalize_features: true,
96 target_dimensions: None,
97 use_approximate: false,
98 }
99 }
100}
101
/// Brute-force image similarity search over a matrix of feature vectors.
pub struct ImageSimilaritySearch {
    /// Configuration supplied at construction time.
    config: ImageSearchConfig,
    /// Extractor used by `add_image` and `search_by_image`.
    feature_extractor: Box<dyn FeatureExtractor>,
    /// Optional trained k-NN index; the code in this file only ever sets it to `None`.
    neighbor_index: Option<KNeighborsClassifier<Trained>>,
    /// One feature vector per row; `None` until an index is built or an image is added.
    feature_database: Option<Array2<f64>>,
    /// Metadata looked up by database row index when building search results.
    image_metadata: Vec<ImageMetadata>,
}
110
/// Descriptive information attached to an indexed image.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct ImageMetadata {
    /// Identifier of the image.
    pub id: String,
    /// Image size as a pair of pixel counts — presumably (width, height); confirm with callers.
    pub dimensions: (usize, usize),
    /// Number of channels per pixel.
    pub channels: usize,
    /// Optional location the image was loaded from.
    pub path: Option<String>,
    /// Free-form key/value annotations.
    pub tags: HashMap<String, String>,
}
126
/// A single hit returned by [`ImageSimilaritySearch::search`].
#[derive(Debug, Clone)]
pub struct ImageSearchResult {
    /// Metadata of the matched database row (a placeholder if none was registered).
    pub metadata: ImageMetadata,
    /// Euclidean distance between the query and the matched feature vector.
    pub distance: f64,
    /// Copy of the matched feature vector.
    pub features: Array1<f64>,
    /// `1 - distance / max_distance`, where `max_distance` is the largest
    /// distance among the returned hits; `1.0` when that maximum is not positive.
    pub confidence: f64,
}
139
/// Converts raw image data into a fixed-purpose feature vector.
pub trait FeatureExtractor: Send + Sync {
    /// Extracts a feature vector from `image_data`.
    ///
    /// `dimensions` is destructured as `(width, height, channels)` by the
    /// implementations in this file, which index the data as channel-interleaved
    /// pixels (`pixel_index * channels + channel`).
    fn extract_features(
        &self,
        image_data: &[f64],
        dimensions: (usize, usize, usize),
    ) -> NeighborsResult<Array1<f64>>;

    /// Nominal length of the feature vectors produced by this extractor.
    fn feature_dimension(&self) -> usize;

    /// Short human-readable name of the extractor.
    fn name(&self) -> &str;
}
159
/// Builds a per-channel intensity histogram of an image.
pub struct ColorHistogramExtractor {
    /// Number of histogram bins allotted to each channel.
    bins_per_channel: usize,
    /// When true, the histogram is divided by its total count.
    normalize: bool,
}

impl ColorHistogramExtractor {
    /// Creates an extractor with `bins_per_channel` bins for every channel.
    pub fn new(bins_per_channel: usize, normalize: bool) -> Self {
        Self {
            bins_per_channel,
            normalize,
        }
    }
}
174
175impl FeatureExtractor for ColorHistogramExtractor {
176 fn extract_features(
177 &self,
178 image_data: &[f64],
179 dimensions: (usize, usize, usize),
180 ) -> NeighborsResult<Array1<f64>> {
181 let (width, height, channels) = dimensions;
182
183 if image_data.len() != width * height * channels {
184 return Err(NeighborsError::InvalidInput(format!(
185 "Image data length {} doesn't match dimensions {}x{}x{}",
186 image_data.len(),
187 width,
188 height,
189 channels
190 )));
191 }
192
193 let total_bins = self.bins_per_channel * channels;
194 let mut histogram = Array1::zeros(total_bins);
195
196 for c in 0..channels {
198 for i in 0..(width * height) {
199 let pixel_idx = i * channels + c;
200 let pixel_value = image_data[pixel_idx];
201
202 let normalized = pixel_value.clamp(0.0, 1.0);
204 let bin = ((normalized * self.bins_per_channel as f64) as usize)
205 .min(self.bins_per_channel - 1);
206
207 histogram[c * self.bins_per_channel + bin] += 1.0;
208 }
209 }
210
211 if self.normalize {
213 let sum: f64 = histogram.sum();
214 if sum > 0.0 {
215 histogram /= sum;
216 }
217 }
218
219 Ok(histogram)
220 }
221
222 fn feature_dimension(&self) -> usize {
223 self.bins_per_channel * 3 }
225
226 fn name(&self) -> &str {
227 "ColorHistogram"
228 }
229}
230
/// Settings for a local binary pattern (LBP) texture descriptor.
pub struct LocalBinaryPatternExtractor {
    /// Radius of the sampling circle, in pixels.
    radius: usize,
    /// Number of sample points on the circle (one code bit per point).
    neighbors: usize,
    /// When true, raw codes are collapsed into uniform-pattern labels.
    uniform_patterns: bool,
}

impl LocalBinaryPatternExtractor {
    /// Creates an extractor sampling `neighbors` points at distance `radius`.
    pub fn new(radius: usize, neighbors: usize, uniform_patterns: bool) -> Self {
        Self {
            radius,
            neighbors,
            uniform_patterns,
        }
    }

    /// Encodes one pixel: bit `i` is set when sample `i` is at least as
    /// bright as `center`. In uniform mode the code is mapped to its label.
    fn compute_lbp_value(&self, center: f64, neighbors: &[f64]) -> u32 {
        let raw_code = neighbors
            .iter()
            .enumerate()
            .filter(|&(_, &sample)| sample >= center)
            .fold(0u32, |code, (bit, _)| code | (1u32 << bit));

        match self.uniform_patterns {
            true => self.to_uniform_pattern(raw_code),
            false => raw_code,
        }
    }

    /// Maps a raw code to its uniform-pattern label.
    ///
    /// Codes with at most two circular 0/1 transitions are labelled by their
    /// number of set bits; every other code shares the label `neighbors + 1`.
    fn to_uniform_pattern(&self, code: u32) -> u32 {
        let bit_count = self.neighbors;
        let bit_at = |pos: usize| (code >> pos) & 1;

        let transitions = (0..bit_count)
            .filter(|&pos| bit_at(pos) != bit_at((pos + 1) % bit_count))
            .count();

        if transitions > 2 {
            bit_count as u32 + 1
        } else {
            code.count_ones()
        }
    }
}
285
286impl FeatureExtractor for LocalBinaryPatternExtractor {
287 fn extract_features(
288 &self,
289 image_data: &[f64],
290 dimensions: (usize, usize, usize),
291 ) -> NeighborsResult<Array1<f64>> {
292 let (width, height, channels) = dimensions;
293
294 if image_data.len() != width * height * channels {
295 return Err(NeighborsError::InvalidInput(
296 "Image data length doesn't match dimensions".to_string(),
297 ));
298 }
299
300 let grayscale: Vec<f64> = if channels == 1 {
302 image_data.to_vec()
303 } else {
304 (0..(width * height))
306 .map(|i| {
307 let r = image_data[i * channels];
308 let g = image_data[i * channels + 1];
309 let b = image_data[i * channels + 2];
310 0.299 * r + 0.587 * g + 0.114 * b
311 })
312 .collect()
313 };
314
315 let max_pattern = if self.uniform_patterns {
316 self.neighbors + 2
317 } else {
318 1 << self.neighbors
319 };
320
321 let mut histogram = Array1::zeros(max_pattern);
322
323 for y in self.radius..(height - self.radius) {
325 for x in self.radius..(width - self.radius) {
326 let center_idx = y * width + x;
327 let center_value = grayscale[center_idx];
328
329 let mut neighbor_values = Vec::with_capacity(self.neighbors);
331 for i in 0..self.neighbors {
332 let angle = 2.0 * std::f64::consts::PI * i as f64 / self.neighbors as f64;
333 let nx = x as f64 + self.radius as f64 * angle.cos();
334 let ny = y as f64 + self.radius as f64 * angle.sin();
335
336 let x1 = nx.floor() as usize;
338 let y1 = ny.floor() as usize;
339 let x2 = (x1 + 1).min(width - 1);
340 let y2 = (y1 + 1).min(height - 1);
341
342 let fx = nx - x1 as f64;
343 let fy = ny - y1 as f64;
344
345 let v1 = grayscale[y1 * width + x1];
346 let v2 = grayscale[y1 * width + x2];
347 let v3 = grayscale[y2 * width + x1];
348 let v4 = grayscale[y2 * width + x2];
349
350 let interpolated = v1 * (1.0 - fx) * (1.0 - fy)
351 + v2 * fx * (1.0 - fy)
352 + v3 * (1.0 - fx) * fy
353 + v4 * fx * fy;
354
355 neighbor_values.push(interpolated);
356 }
357
358 let lbp_code = self.compute_lbp_value(center_value, &neighbor_values);
360 histogram[lbp_code as usize] += 1.0;
361 }
362 }
363
364 let sum: f64 = histogram.sum();
366 if sum > 0.0 {
367 histogram /= sum;
368 }
369
370 Ok(histogram)
371 }
372
373 fn feature_dimension(&self) -> usize {
374 if self.uniform_patterns {
375 self.neighbors + 2
376 } else {
377 1 << self.neighbors
378 }
379 }
380
381 fn name(&self) -> &str {
382 "LocalBinaryPattern"
383 }
384}
385
/// Settings for a histogram-of-oriented-gradients (HOG) descriptor.
pub struct HistogramOfGradientsExtractor {
    /// Side length of a square cell, in pixels.
    cell_size: usize,
    /// Side length of a square block, in cells.
    block_size: usize,
    /// Number of orientation bins covering `[0, π)`.
    num_bins: usize,
    /// When true, each block is L2-normalized.
    normalize_blocks: bool,
}

impl HistogramOfGradientsExtractor {
    /// Creates a HOG extractor with the given cell, block and bin layout.
    pub fn new(
        cell_size: usize,
        block_size: usize,
        num_bins: usize,
        normalize_blocks: bool,
    ) -> Self {
        Self {
            cell_size,
            block_size,
            num_bins,
            normalize_blocks,
        }
    }

    /// Computes central-difference gradients of a row-major grayscale image.
    ///
    /// Returns `(magnitudes, orientations)`, each of length `width * height`.
    /// Border pixels keep magnitude and orientation 0. Orientations are
    /// unsigned, i.e. folded into `[0, π)`. Images with fewer than three rows
    /// or columns have no interior and produce all-zero outputs.
    ///
    /// `image` must hold at least `width * height` values.
    fn compute_gradients(
        &self,
        image: &[f64],
        width: usize,
        height: usize,
    ) -> (Vec<f64>, Vec<f64>) {
        let pi = std::f64::consts::PI;
        let mut magnitudes = vec![0.0; width * height];
        let mut orientations = vec![0.0; width * height];

        // `saturating_sub` avoids underflow for zero-sized images.
        for y in 1..height.saturating_sub(1) {
            for x in 1..width.saturating_sub(1) {
                let idx = y * width + x;

                let gx = image[idx + 1] - image[idx - 1];
                let gy = image[idx + width] - image[idx - width];

                let mut orientation = gy.atan2(gx);
                if orientation < 0.0 {
                    orientation += pi;
                }
                // atan2 can return exactly π (gx < 0, gy == +0.0), and adding π
                // to a tiny negative angle can round up to π. As an unsigned
                // direction π is the same as 0, so fold it back.
                if orientation >= pi {
                    orientation -= pi;
                }

                magnitudes[idx] = (gx * gx + gy * gy).sqrt();
                orientations[idx] = orientation;
            }
        }

        (magnitudes, orientations)
    }
}
440
441impl FeatureExtractor for HistogramOfGradientsExtractor {
442 fn extract_features(
443 &self,
444 image_data: &[f64],
445 dimensions: (usize, usize, usize),
446 ) -> NeighborsResult<Array1<f64>> {
447 let (width, height, channels) = dimensions;
448
449 let grayscale: Vec<f64> = if channels == 1 {
451 image_data.to_vec()
452 } else {
453 (0..(width * height))
454 .map(|i| {
455 let r = image_data[i * channels];
456 let g = image_data[i * channels + 1];
457 let b = image_data[i * channels + 2];
458 0.299 * r + 0.587 * g + 0.114 * b
459 })
460 .collect()
461 };
462
463 let (magnitudes, orientations) = self.compute_gradients(&grayscale, width, height);
465
466 let cells_x = width / self.cell_size;
468 let cells_y = height / self.cell_size;
469
470 let mut cell_histograms = vec![Array1::zeros(self.num_bins); cells_x * cells_y];
472
473 for cell_y in 0..cells_y {
474 for cell_x in 0..cells_x {
475 let cell_idx = cell_y * cells_x + cell_x;
476
477 for y in (cell_y * self.cell_size)..((cell_y + 1) * self.cell_size) {
479 for x in (cell_x * self.cell_size)..((cell_x + 1) * self.cell_size) {
480 if y < height && x < width {
481 let pixel_idx = y * width + x;
482 let magnitude = magnitudes[pixel_idx];
483 let orientation = orientations[pixel_idx];
484
485 let bin_width = std::f64::consts::PI / self.num_bins as f64;
487 let bin = ((orientation / bin_width) as usize).min(self.num_bins - 1);
488
489 cell_histograms[cell_idx][bin] += magnitude;
491 }
492 }
493 }
494 }
495 }
496
497 let blocks_x = cells_x.saturating_sub(self.block_size - 1);
499 let blocks_y = cells_y.saturating_sub(self.block_size - 1);
500 let descriptor_size =
501 blocks_x * blocks_y * self.block_size * self.block_size * self.num_bins;
502
503 let mut descriptor = Array1::zeros(descriptor_size);
504 let mut desc_idx = 0;
505
506 for block_y in 0..blocks_y {
507 for block_x in 0..blocks_x {
508 let mut block_hist =
510 Array1::zeros(self.block_size * self.block_size * self.num_bins);
511 let mut hist_idx = 0;
512
513 for by in 0..self.block_size {
514 for bx in 0..self.block_size {
515 let cell_x = block_x + bx;
516 let cell_y = block_y + by;
517 let cell_idx = cell_y * cells_x + cell_x;
518
519 for bin in 0..self.num_bins {
520 block_hist[hist_idx * self.num_bins + bin] =
521 cell_histograms[cell_idx][bin];
522 }
523 hist_idx += 1;
524 }
525 }
526
527 if self.normalize_blocks {
529 let norm = (block_hist.mapv(|x: f64| x * x).sum()).sqrt();
530 if norm > 1e-8 {
531 block_hist /= norm;
532 }
533 }
534
535 for i in 0..block_hist.len() {
537 descriptor[desc_idx + i] = block_hist[i];
538 }
539 desc_idx += block_hist.len();
540 }
541 }
542
543 Ok(descriptor)
544 }
545
546 fn feature_dimension(&self) -> usize {
547 self.num_bins * self.block_size * self.block_size * 16
549 }
550
551 fn name(&self) -> &str {
552 "HistogramOfGradients"
553 }
554}
555
556impl ImageSimilaritySearch {
557 pub fn new(config: ImageSearchConfig) -> Self {
559 let feature_extractor: Box<dyn FeatureExtractor> = match config.feature_type {
560 FeatureType::ColorHistogram => Box::new(ColorHistogramExtractor::new(32, true)),
561 FeatureType::LocalBinaryPattern => {
562 Box::new(LocalBinaryPatternExtractor::new(1, 8, true))
563 }
564 FeatureType::HistogramOfGradients => {
565 Box::new(HistogramOfGradientsExtractor::new(8, 2, 9, true))
566 }
567 FeatureType::GaborFilters => Box::new(ColorHistogramExtractor::new(32, true)), FeatureType::EdgeFeatures => Box::new(ColorHistogramExtractor::new(32, true)), FeatureType::TextureFeatures => Box::new(ColorHistogramExtractor::new(32, true)), FeatureType::MultiModal => Box::new(ColorHistogramExtractor::new(32, true)), };
572
573 Self {
574 config,
575 feature_extractor,
576 neighbor_index: None,
577 feature_database: None,
578 image_metadata: Vec::new(),
579 }
580 }
581
582 pub fn build_index(&mut self, features: &Array2<f64>) -> NeighborsResult<()> {
584 if features.is_empty() {
585 return Err(NeighborsError::EmptyInput);
586 }
587
588 self.feature_database = Some(features.clone());
590 self.neighbor_index = None; Ok(())
593 }
594
595 pub fn add_image(
597 &mut self,
598 image_data: &[f64],
599 dimensions: (usize, usize, usize),
600 metadata: ImageMetadata,
601 ) -> NeighborsResult<()> {
602 let features = self
604 .feature_extractor
605 .extract_features(image_data, dimensions)?;
606
607 self.image_metadata.push(metadata);
609
610 if let Some(ref mut db) = self.feature_database {
612 let mut new_db = Array2::zeros((db.nrows() + 1, db.ncols()));
614 new_db
615 .slice_mut(scirs2_core::ndarray::s![..db.nrows(), ..])
616 .assign(db);
617 new_db.row_mut(db.nrows()).assign(&features);
618 *db = new_db;
619
620 let db_clone = db.clone();
622 self.build_index(&db_clone)?;
623 } else {
624 let new_db = features.insert_axis(Axis(0));
626 let db_clone = new_db.clone();
627 self.feature_database = Some(new_db);
628 self.build_index(&db_clone)?;
629 }
630
631 Ok(())
632 }
633
634 pub fn search(
636 &self,
637 query_features: &Array1<f64>,
638 k: usize,
639 ) -> NeighborsResult<Vec<ImageSearchResult>> {
640 let database = self
641 .feature_database
642 .as_ref()
643 .ok_or(NeighborsError::InvalidInput(
644 "Feature database not available".to_string(),
645 ))?;
646
647 if database.is_empty() {
648 return Ok(Vec::new());
649 }
650
651 let mut distances_with_indices: Vec<(f64, usize)> = Vec::new();
653
654 for (idx, db_features) in database.rows().into_iter().enumerate() {
655 let distance: f64 = query_features
657 .iter()
658 .zip(db_features.iter())
659 .map(|(a, b)| (a - b).powi(2))
660 .sum::<f64>()
661 .sqrt();
662
663 distances_with_indices.push((distance, idx));
664 }
665
666 distances_with_indices.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());
668 let k_results = k.min(distances_with_indices.len());
669
670 let mut results = Vec::new();
671 let max_distance = if !distances_with_indices.is_empty() {
672 distances_with_indices[distances_with_indices.len().min(k) - 1].0
673 } else {
674 1.0
675 };
676
677 for i in 0..k_results {
678 let (distance, idx) = distances_with_indices[i];
679
680 let features = database.row(idx).to_owned();
681
682 let confidence = if max_distance > 0.0 {
684 1.0 - (distance / max_distance)
685 } else {
686 1.0
687 };
688
689 let metadata = if idx < self.image_metadata.len() {
691 self.image_metadata[idx].clone()
692 } else {
693 ImageMetadata {
694 id: format!("image_{}", idx),
695 dimensions: (0, 0),
696 channels: 0,
697 path: None,
698 tags: HashMap::new(),
699 }
700 };
701
702 results.push(ImageSearchResult {
703 metadata,
704 distance,
705 features,
706 confidence,
707 });
708 }
709
710 Ok(results)
711 }
712
713 pub fn search_by_image(
715 &self,
716 image_data: &[f64],
717 dimensions: (usize, usize, usize),
718 k: usize,
719 ) -> NeighborsResult<Vec<ImageSearchResult>> {
720 let features = self
721 .feature_extractor
722 .extract_features(image_data, dimensions)?;
723 self.search(&features, k)
724 }
725
726 pub fn get_stats(&self) -> (usize, usize, String) {
728 let num_images = self.image_metadata.len();
729 let feature_dim = self
730 .feature_database
731 .as_ref()
732 .map(|db| db.ncols())
733 .unwrap_or(0);
734 let extractor_name = self.feature_extractor.name().to_string();
735
736 (num_images, feature_dim, extractor_name)
737 }
738}
739
/// Matches square image patches by comparing their feature vectors.
pub struct PatchBasedMatching {
    /// Side length of a patch, in pixels.
    patch_size: usize,
    /// Step between consecutive patch origins, in pixels.
    stride: usize,
    /// Extractor applied to every patch.
    feature_extractor: Box<dyn FeatureExtractor>,
    /// Search structure holding the patch feature database.
    neighbor_search: Option<ImageSimilaritySearch>,
}
747
748impl PatchBasedMatching {
749 pub fn new(patch_size: usize, stride: usize, feature_type: FeatureType) -> Self {
751 let config = ImageSearchConfig {
752 feature_type,
753 k_neighbors: 10,
754 ..Default::default()
755 };
756
757 let feature_extractor: Box<dyn FeatureExtractor> = match feature_type {
758 FeatureType::ColorHistogram => Box::new(ColorHistogramExtractor::new(16, true)),
759 FeatureType::LocalBinaryPattern => {
760 Box::new(LocalBinaryPatternExtractor::new(1, 8, true))
761 }
762 FeatureType::HistogramOfGradients => {
763 Box::new(HistogramOfGradientsExtractor::new(4, 1, 9, true))
764 }
765 _ => Box::new(LocalBinaryPatternExtractor::new(1, 8, true)),
766 };
767
768 Self {
769 patch_size,
770 stride,
771 feature_extractor,
772 neighbor_search: Some(ImageSimilaritySearch::new(config)),
773 }
774 }
775
776 pub fn extract_patches(
778 &self,
779 image_data: &[f64],
780 dimensions: (usize, usize, usize),
781 ) -> NeighborsResult<Vec<(Array1<f64>, (usize, usize))>> {
782 let (width, height, channels) = dimensions;
783 let mut patches = Vec::new();
784
785 for y in (0..(height.saturating_sub(self.patch_size))).step_by(self.stride) {
787 for x in (0..(width.saturating_sub(self.patch_size))).step_by(self.stride) {
788 let mut patch_data = Vec::new();
790
791 for py in y..(y + self.patch_size).min(height) {
792 for px in x..(x + self.patch_size).min(width) {
793 for c in 0..channels {
794 let idx = py * width * channels + px * channels + c;
795 patch_data.push(image_data[idx]);
796 }
797 }
798 }
799
800 let patch_dims = (self.patch_size, self.patch_size, channels);
802 let features = self
803 .feature_extractor
804 .extract_features(&patch_data, patch_dims)?;
805 patches.push((features, (x, y)));
806 }
807 }
808
809 Ok(patches)
810 }
811
812 pub fn build_patch_database(
814 &mut self,
815 images: &[(Vec<f64>, (usize, usize, usize))],
816 ) -> NeighborsResult<()> {
817 let mut all_patch_features = Vec::new();
818
819 for (image_data, dimensions) in images {
820 let patches = self.extract_patches(image_data, *dimensions)?;
821 for (features, _position) in patches {
822 all_patch_features.push(features);
823 }
824 }
825
826 if all_patch_features.is_empty() {
827 return Err(NeighborsError::EmptyInput);
828 }
829
830 let num_patches = all_patch_features.len();
832 let feature_dim = all_patch_features[0].len();
833 let mut feature_matrix = Array2::zeros((num_patches, feature_dim));
834
835 for (i, features) in all_patch_features.into_iter().enumerate() {
836 feature_matrix.row_mut(i).assign(&features);
837 }
838
839 if let Some(ref mut search) = self.neighbor_search {
841 search.build_index(&feature_matrix)?;
842 }
843
844 Ok(())
845 }
846
847 pub fn find_similar_patches(
849 &self,
850 query_patch: &[f64],
851 patch_dimensions: (usize, usize, usize),
852 k: usize,
853 ) -> NeighborsResult<Vec<ImageSearchResult>> {
854 let features = self
855 .feature_extractor
856 .extract_features(query_patch, patch_dimensions)?;
857
858 if let Some(ref search) = self.neighbor_search {
859 search.search(&features, k)
860 } else {
861 Err(NeighborsError::InvalidInput(
862 "Patch database not built".to_string(),
863 ))
864 }
865 }
866}
867
/// Matches keypoint descriptors between a query set and a train set.
pub struct FeatureDescriptorMatcher {
    /// Number of leading descriptor elements compared for each keypoint.
    descriptor_dimension: usize,
    /// Maximum accepted distance when only one train descriptor exists.
    match_threshold: f64,
    /// Maximum accepted best/second-best distance ratio.
    ratio_test_threshold: f64,
    /// When true, matches must also be the best match in the reverse direction.
    use_cross_check: bool,
}
875
/// A detected image keypoint together with its descriptor.
#[derive(Debug, Clone)]
pub struct Keypoint {
    /// Keypoint location — presumably (x, y) in pixels; confirm with the producer.
    pub position: (f64, f64),
    /// Scale at which the keypoint was detected.
    pub scale: f64,
    /// Dominant orientation of the keypoint (units not established in this file).
    pub orientation: f64,
    /// Descriptor vector compared by `FeatureDescriptorMatcher`.
    pub descriptor: Array1<f64>,
    /// Detector response (strength) of the keypoint.
    pub response: f64,
}
890
/// A correspondence between one query keypoint and one train keypoint.
#[derive(Debug, Clone)]
pub struct DescriptorMatch {
    /// Index into the query keypoint slice.
    pub query_idx: usize,
    /// Index into the train keypoint slice.
    pub train_idx: usize,
    /// Euclidean distance between the two descriptors.
    pub distance: f64,
    /// `1 - best/second_best` under the ratio test, or `1 - distance` when
    /// only a single train descriptor exists.
    pub confidence: f64,
}
903
904impl FeatureDescriptorMatcher {
905 pub fn new(descriptor_dimension: usize) -> Self {
907 Self {
908 descriptor_dimension,
909 match_threshold: 0.8,
910 ratio_test_threshold: 0.7,
911 use_cross_check: true,
912 }
913 }
914
915 pub fn match_descriptors(
917 &self,
918 query_keypoints: &[Keypoint],
919 train_keypoints: &[Keypoint],
920 ) -> NeighborsResult<Vec<DescriptorMatch>> {
921 if query_keypoints.is_empty() || train_keypoints.is_empty() {
922 return Ok(Vec::new());
923 }
924
925 let mut matches = Vec::new();
926
927 let query_descriptors = Array2::from_shape_fn(
929 (query_keypoints.len(), self.descriptor_dimension),
930 |(i, j)| query_keypoints[i].descriptor[j],
931 );
932
933 let train_descriptors = Array2::from_shape_fn(
934 (train_keypoints.len(), self.descriptor_dimension),
935 |(i, j)| train_keypoints[i].descriptor[j],
936 );
937
938 for (query_idx, query_desc) in query_descriptors.rows().into_iter().enumerate() {
940 let mut distances: Vec<(usize, f64)> = Vec::new();
941
942 for (train_idx, train_desc) in train_descriptors.rows().into_iter().enumerate() {
944 let distance = self.compute_descriptor_distance(&query_desc, &train_desc);
945 distances.push((train_idx, distance));
946 }
947
948 distances.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
950
951 if distances.len() >= 2 {
953 let best_distance = distances[0].1;
954 let second_best_distance = distances[1].1;
955
956 if second_best_distance > 0.0
957 && (best_distance / second_best_distance) < self.ratio_test_threshold
958 {
959 let train_idx = distances[0].0;
960 let confidence = 1.0 - (best_distance / second_best_distance);
961
962 matches.push(DescriptorMatch {
963 query_idx,
964 train_idx,
965 distance: best_distance,
966 confidence,
967 });
968 }
969 } else if !distances.is_empty() && distances[0].1 < self.match_threshold {
970 let train_idx = distances[0].0;
972 let confidence = 1.0 - distances[0].1;
973
974 matches.push(DescriptorMatch {
975 query_idx,
976 train_idx,
977 distance: distances[0].1,
978 confidence,
979 });
980 }
981 }
982
983 if self.use_cross_check {
985 matches = self.apply_cross_check(matches, &query_descriptors, &train_descriptors)?;
986 }
987
988 Ok(matches)
989 }
990
991 fn compute_descriptor_distance(&self, desc1: &ArrayView1<f64>, desc2: &ArrayView1<f64>) -> f64 {
993 desc1
995 .iter()
996 .zip(desc2.iter())
997 .map(|(a, b)| (a - b).powi(2))
998 .sum::<f64>()
999 .sqrt()
1000 }
1001
1002 fn apply_cross_check(
1004 &self,
1005 matches: Vec<DescriptorMatch>,
1006 query_descriptors: &Array2<f64>,
1007 train_descriptors: &Array2<f64>,
1008 ) -> NeighborsResult<Vec<DescriptorMatch>> {
1009 let mut filtered_matches = Vec::new();
1010
1011 for m in matches {
1012 let train_desc = train_descriptors.row(m.train_idx);
1014 let mut best_query_idx = 0;
1015 let mut best_distance = f64::INFINITY;
1016
1017 for (query_idx, query_desc) in query_descriptors.rows().into_iter().enumerate() {
1018 let distance = self.compute_descriptor_distance(&train_desc, &query_desc);
1019 if distance < best_distance {
1020 best_distance = distance;
1021 best_query_idx = query_idx;
1022 }
1023 }
1024
1025 if best_query_idx == m.query_idx {
1027 filtered_matches.push(m);
1028 }
1029 }
1030
1031 Ok(filtered_matches)
1032 }
1033}
1034
/// Bag-of-visual-words image classifier.
pub struct VisualWordRecognizer {
    /// Visual word centroids, one per row; `None` until `build_vocabulary` succeeds.
    vocabulary: Option<Array2<f64>>,
    /// Number of visual words (rows of `vocabulary`, length of a BoW histogram).
    vocabulary_size: usize,
    /// Extractor applied to every patch.
    feature_extractor: Box<dyn FeatureExtractor>,
    /// Trained k-NN classifier; `None` until `train_classifier` succeeds.
    classifier: Option<KNeighborsClassifier<Trained>>,
}
1042
1043impl VisualWordRecognizer {
1044 pub fn new(vocabulary_size: usize) -> Self {
1046 Self {
1047 vocabulary: None,
1048 vocabulary_size,
1049 feature_extractor: Box::new(LocalBinaryPatternExtractor::new(1, 8, true)),
1050 classifier: None,
1051 }
1052 }
1053
1054 pub fn build_vocabulary(
1056 &mut self,
1057 training_patches: &[Vec<f64>],
1058 patch_dimensions: (usize, usize, usize),
1059 ) -> NeighborsResult<()> {
1060 if training_patches.is_empty() {
1061 return Err(NeighborsError::EmptyInput);
1062 }
1063
1064 let mut patch_features = Vec::new();
1066 for patch_data in training_patches {
1067 let features = self
1068 .feature_extractor
1069 .extract_features(patch_data, patch_dimensions)?;
1070 patch_features.push(features);
1071 }
1072
1073 let feature_dim = patch_features[0].len();
1075 let mut vocabulary = Array2::zeros((self.vocabulary_size, feature_dim));
1076
1077 let mut rng = thread_rng();
1079 for i in 0..self.vocabulary_size {
1080 let random_idx = rng.gen_range(0..patch_features.len());
1081 vocabulary.row_mut(i).assign(&patch_features[random_idx]);
1082 }
1083
1084 for _iteration in 0..10 {
1086 let mut cluster_sums = vec![Array1::zeros(feature_dim); self.vocabulary_size];
1087 let mut cluster_counts = vec![0; self.vocabulary_size];
1088
1089 for patch_feature in &patch_features {
1091 let mut best_cluster = 0;
1092 let mut best_distance = f64::INFINITY;
1093
1094 for (cluster_idx, centroid) in vocabulary.rows().into_iter().enumerate() {
1095 let distance: f64 = patch_feature
1096 .iter()
1097 .zip(centroid.iter())
1098 .map(|(a, b)| (a - b).powi(2))
1099 .sum::<f64>()
1100 .sqrt();
1101
1102 if distance < best_distance {
1103 best_distance = distance;
1104 best_cluster = cluster_idx;
1105 }
1106 }
1107
1108 cluster_sums[best_cluster] = &cluster_sums[best_cluster] + patch_feature;
1109 cluster_counts[best_cluster] += 1;
1110 }
1111
1112 for i in 0..self.vocabulary_size {
1114 if cluster_counts[i] > 0 {
1115 vocabulary
1116 .row_mut(i)
1117 .assign(&(&cluster_sums[i] / cluster_counts[i] as f64));
1118 }
1119 }
1120 }
1121
1122 self.vocabulary = Some(vocabulary);
1123 Ok(())
1124 }
1125
1126 pub fn compute_bow_histogram(
1128 &self,
1129 image_patches: &[Vec<f64>],
1130 patch_dimensions: (usize, usize, usize),
1131 ) -> NeighborsResult<Array1<f64>> {
1132 let vocabulary = self
1133 .vocabulary
1134 .as_ref()
1135 .ok_or(NeighborsError::InvalidInput(
1136 "Vocabulary not built".to_string(),
1137 ))?;
1138
1139 let mut histogram = Array1::zeros(self.vocabulary_size);
1140
1141 for patch_data in image_patches {
1142 let features = self
1143 .feature_extractor
1144 .extract_features(patch_data, patch_dimensions)?;
1145
1146 let mut best_word = 0;
1148 let mut best_distance = f64::INFINITY;
1149
1150 for (word_idx, word_features) in vocabulary.rows().into_iter().enumerate() {
1151 let distance: f64 = features
1152 .iter()
1153 .zip(word_features.iter())
1154 .map(|(a, b)| (a - b).powi(2))
1155 .sum::<f64>()
1156 .sqrt();
1157
1158 if distance < best_distance {
1159 best_distance = distance;
1160 best_word = word_idx;
1161 }
1162 }
1163
1164 histogram[best_word] += 1.0;
1165 }
1166
1167 let sum: f64 = histogram.sum();
1169 if sum > 0.0 {
1170 histogram /= sum;
1171 }
1172
1173 Ok(histogram)
1174 }
1175
1176 pub fn train_classifier(
1178 &mut self,
1179 bow_histograms: &Array2<f64>,
1180 labels: &Array1<i32>,
1181 ) -> NeighborsResult<()> {
1182 let classifier = KNeighborsClassifier::new(5);
1183
1184 let features: Features = bow_histograms.mapv(|x| x as Float);
1186 let target_labels: Array1<Int> = labels.mapv(|x| x as Int);
1187
1188 let trained = classifier
1189 .fit(&features, &target_labels)
1190 .map_err(|e| NeighborsError::InvalidInput(e.to_string()))?;
1191
1192 self.classifier = Some(trained);
1193 Ok(())
1194 }
1195
1196 pub fn classify_image(
1198 &self,
1199 image_patches: &[Vec<f64>],
1200 patch_dimensions: (usize, usize, usize),
1201 ) -> NeighborsResult<i32> {
1202 let classifier = self
1203 .classifier
1204 .as_ref()
1205 .ok_or(NeighborsError::InvalidInput(
1206 "Classifier not trained".to_string(),
1207 ))?;
1208
1209 let bow_histogram = self.compute_bow_histogram(image_patches, patch_dimensions)?;
1210 let bow_2d = bow_histogram.insert_axis(Axis(0));
1211
1212 let features: Features = bow_2d.mapv(|x| x as Float);
1214
1215 let predictions = classifier
1216 .predict(&features)
1217 .map_err(|e| NeighborsError::InvalidInput(e.to_string()))?;
1218
1219 Ok(predictions[0] as i32)
1220 }
1221}
1222
// Unit tests for the extractors, search, patch matching, descriptor matching
// and bag-of-words components defined above.
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;

    /// A normalized 8-bin histogram of a 3-channel image has 24 entries summing to 1.
    #[test]
    fn test_color_histogram_extractor() {
        let extractor = ColorHistogramExtractor::new(8, true);
        // Uniform mid-gray 32x32 RGB image.
        let image_data = vec![0.5; 32 * 32 * 3];
        let dimensions = (32, 32, 3);

        let features = extractor.extract_features(&image_data, dimensions).unwrap();
        assert_eq!(features.len(), 8 * 3);

        let sum: f64 = features.sum();
        assert!((sum - 1.0).abs() < 1e-6);
    }

    /// Uniform LBP with 8 neighbors yields `neighbors + 2` histogram bins.
    #[test]
    fn test_local_binary_pattern_extractor() {
        let extractor = LocalBinaryPatternExtractor::new(1, 8, true);
        // Uniform 16x16 single-channel image.
        let image_data = vec![0.5; 16 * 16];
        let dimensions = (16, 16, 1);

        let features = extractor.extract_features(&image_data, dimensions).unwrap();
        assert_eq!(features.len(), 8 + 2);
    }

    /// Building an index from raw features allows searching without any metadata.
    #[test]
    fn test_image_similarity_search() {
        let config = ImageSearchConfig::default();
        let mut search = ImageSimilaritySearch::new(config);

        // Ten 96-dimensional feature vectors; row 0 equals the query below.
        let features = Array2::from_shape_fn((10, 96), |(i, j)| (i + j) as f64);
        search.build_index(&features).unwrap();

        let (num_images, _feature_dim, _) = search.get_stats();
        // `build_index` registers no metadata, so the image count stays at zero.
        assert_eq!(num_images, 0);
        assert_eq!(search.feature_database.as_ref().unwrap().nrows(), 10);

        let query = Array1::from_vec((0..96).map(|i| i as f64).collect());
        let results = search.search(&query, 3).unwrap();

        assert_eq!(results.len(), 3);
        assert!(results[0].distance <= results[1].distance);
    }

    /// Patch extraction yields at least one patch and no more than the full grid.
    #[test]
    fn test_patch_based_matching() {
        let matcher = PatchBasedMatching::new(8, 4, FeatureType::LocalBinaryPattern);

        let image_data = vec![0.5; 32 * 32 * 3];
        let dimensions = (32, 32, 3);

        let patches = matcher.extract_patches(&image_data, dimensions).unwrap();
        assert!(!patches.is_empty());

        // Upper bound: number of origins per axis, squared.
        let expected_patches = ((32_usize - 8) / 4 + 1).pow(2);
        assert!(patches.len() <= expected_patches);
    }

    /// Matching one query keypoint against one train keypoint yields at most one match.
    #[test]
    fn test_feature_descriptor_matcher() {
        let matcher = FeatureDescriptorMatcher::new(128);

        let query_keypoints = vec![Keypoint {
            position: (10.0, 10.0),
            scale: 1.0,
            orientation: 0.0,
            descriptor: Array1::zeros(128),
            response: 0.5,
        }];

        let train_keypoints = vec![Keypoint {
            position: (11.0, 11.0),
            scale: 1.0,
            orientation: 0.0,
            descriptor: Array1::from_elem(128, 0.1),
            response: 0.6,
        }];

        let matches = matcher
            .match_descriptors(&query_keypoints, &train_keypoints)
            .unwrap();
        assert!(matches.len() <= 1);
    }

    /// A BoW histogram has one bin per visual word and sums to 1.
    #[test]
    fn test_visual_word_recognizer() {
        let mut recognizer = VisualWordRecognizer::new(16);

        // Ten uniform 8x8 RGB patches with increasing brightness.
        let training_patches: Vec<Vec<f64>> =
            (0..10).map(|i| vec![i as f64 * 0.1; 8 * 8 * 3]).collect();
        let patch_dimensions = (8, 8, 3);

        recognizer
            .build_vocabulary(&training_patches, patch_dimensions)
            .unwrap();

        let test_patches = vec![vec![0.5; 8 * 8 * 3]];
        let histogram = recognizer
            .compute_bow_histogram(&test_patches, patch_dimensions)
            .unwrap();

        assert_eq!(histogram.len(), 16);
        assert!((histogram.sum() - 1.0).abs() < 1e-6);
    }
}