1use crate::error::{Result, VisionError};
21use crate::feature::KeyPoint;
22use crate::gpu_ops::GpuVisionContext;
23use scirs2_core::ndarray::ArrayStatCompat;
24use scirs2_core::ndarray::{s, Array1, Array2, Array3, ArrayView2};
25use statrs::statistics::Statistics;
26
/// Bundled weights, optional GPU context, and configuration shared by the
/// neural feature extractors in this module.
pub struct NeuralFeatureNetwork {
    /// Detection-head parameters (populated but not yet read by inference).
    #[allow(dead_code)]
    detection_weights: ModelWeights,
    /// Descriptor-head parameters (populated but not yet read by inference).
    #[allow(dead_code)]
    descriptor_weights: ModelWeights,
    /// GPU context; `None` selects the CPU code paths.
    gpu_context: Option<GpuVisionContext>,
    /// Detection/description hyper-parameters.
    config: NeuralFeatureConfig,
}
40
/// Raw parameter tensors for one network head.
///
/// All fields are currently initialized with random or constant placeholder
/// values (see `create_detection_weights` / `create_descriptor_weights`) and
/// are not consumed by the forward passes yet.
#[derive(Clone)]
pub struct ModelWeights {
    // Convolution kernels, one per layer; shapes like (64, 1, 3) — presumably
    // (out_channels, in_channels, kernel_width), TODO confirm when wired up.
    #[allow(dead_code)]
    conv_weights: Vec<Array3<f32>>,
    // Per-layer convolution biases.
    #[allow(dead_code)]
    conv_biases: Vec<Array1<f32>>,
    // Batch-norm scale parameters per layer (initialized to ones).
    #[allow(dead_code)]
    bn_weights: Vec<Array1<f32>>,
    // Batch-norm shift parameters per layer (initialized to zeros).
    #[allow(dead_code)]
    bn_biases: Vec<Array1<f32>>,
    // Fully connected layer weight matrices.
    #[allow(dead_code)]
    fc_weights: Vec<Array2<f32>>,
    // Fully connected layer biases.
    #[allow(dead_code)]
    fc_biases: Vec<Array1<f32>>,
}
63
/// Hyper-parameters for neural feature detection and description.
#[derive(Clone)]
pub struct NeuralFeatureConfig {
    /// Expected input size as (height, width); other sizes are resized first.
    pub input_size: (usize, usize),
    /// Upper bound on the number of keypoints returned per image.
    pub max_keypoints: usize,
    /// Minimum heatmap score for a cell to become a keypoint candidate.
    pub detection_threshold: f32,
    /// Chebyshev radius (in heatmap cells) for non-maximum suppression.
    pub nms_radius: usize,
    /// Length of each descriptor vector.
    pub descriptor_dim: usize,
    /// Width of the heatmap border band excluded from detection (in cells).
    pub border_remove: usize,
    /// Attempt GPU acceleration; silently falls back to CPU when unavailable.
    pub use_gpu: bool,
}
82
83impl Default for NeuralFeatureConfig {
84 fn default() -> Self {
85 Self {
86 input_size: (480, 640),
87 max_keypoints: 1024,
88 detection_threshold: 0.015,
89 nms_radius: 4,
90 descriptor_dim: 256,
91 border_remove: 4,
92 use_gpu: true,
93 }
94 }
95}
96
/// SuperPoint-style joint keypoint detector and descriptor extractor.
pub struct SuperPointNet {
    // Weights, configuration, and optional GPU context for this instance.
    network: NeuralFeatureNetwork,
}
101
102impl SuperPointNet {
103 pub fn new(config: Option<NeuralFeatureConfig>) -> Result<Self> {
105 let config = config.unwrap_or_default();
106
107 let detection_weights = Self::create_detection_weights(&config)?;
110 let descriptor_weights = Self::create_descriptor_weights(&config)?;
111
112 let gpu_context = if config.use_gpu {
113 GpuVisionContext::new().ok()
114 } else {
115 None
116 };
117
118 let network = NeuralFeatureNetwork {
119 detection_weights,
120 descriptor_weights,
121 gpu_context,
122 config,
123 };
124
125 Ok(Self { network })
126 }
127
128 #[allow(dead_code)]
130 pub fn from_file(_modelpath: &str, config: Option<NeuralFeatureConfig>) -> Result<Self> {
131 let config = config.unwrap_or_default();
132
133 Self::new(Some(config))
136 }
137
138 pub fn detect_and_describe(
140 &self,
141 image: &ArrayView2<f32>,
142 ) -> Result<(Vec<KeyPoint>, Array2<f32>)> {
143 let (height, width) = image.dim();
145 if height % 8 != 0 || width % 8 != 0 {
146 return Err(VisionError::InvalidInput(
147 "Input image dimensions must be multiples of 8 for neural feature detection"
148 .to_string(),
149 ));
150 }
151
152 let processed_image = if (height, width) != self.network.config.input_size {
154 self.resize_image(image, self.network.config.input_size)?
155 } else {
156 image.to_owned()
157 };
158
159 if let Some(ref gpu_ctx) = self.network.gpu_context {
161 self.gpu_inference(gpu_ctx, &processed_image.view())
162 } else {
163 self.cpu_inference(&processed_image.view())
164 }
165 }
166
167 fn gpu_inference(
169 &self,
170 gpu_ctx: &GpuVisionContext,
171 image: &ArrayView2<f32>,
172 ) -> Result<(Vec<KeyPoint>, Array2<f32>)> {
173 let featuremap = self.gpu_forward_detection(gpu_ctx, image)?;
175 let descriptor_map = self.gpu_forward_descriptors(gpu_ctx, image)?;
176
177 self.post_process_features(&featuremap, &descriptor_map)
179 }
180
181 fn cpu_inference(&self, image: &ArrayView2<f32>) -> Result<(Vec<KeyPoint>, Array2<f32>)> {
183 let featuremap = self.cpu_forward_detection(image)?;
185 let descriptor_map = self.cpu_forward_descriptors(image)?;
186
187 self.post_process_features(&featuremap, &descriptor_map)
189 }
190
191 fn gpu_forward_detection(
193 &self,
194 gpu_ctx: &GpuVisionContext,
195 image: &ArrayView2<f32>,
196 ) -> Result<Array2<f32>> {
197 let conv1_kernel =
202 Array2::from_shape_vec((3, 3), vec![-1.0, 0.0, 1.0, -2.0, 0.0, 2.0, -1.0, 0.0, 1.0])?;
203
204 let conv1_result = crate::gpu_ops::gpu_convolve_2d(gpu_ctx, image, &conv1_kernel.view())?;
205
206 let activated = conv1_result.mapv(|x| x.max(0.0));
208
209 let pooled = crate::gpu_ops::gpu_gaussian_blur(gpu_ctx, &activated.view(), 2.0)?;
211
212 let (height, width) = pooled.dim();
214 let out_height = height / 8;
215 let out_width = width / 8;
216
217 let mut detection_map = Array2::zeros((out_height, out_width));
218 for y in 0..out_height {
219 for x in 0..out_width {
220 let src_y = (y * 8).min(height - 1);
221 let src_x = (x * 8).min(width - 1);
222 detection_map[[y, x]] = pooled[[src_y, src_x]].abs();
223 }
224 }
225
226 Ok(detection_map)
227 }
228
229 fn cpu_forward_detection(&self, image: &ArrayView2<f32>) -> Result<Array2<f32>> {
231 let (_, _, magnitude) = self.compute_simple_gradients(image)?;
233
234 let (height, width) = magnitude.dim();
236 let out_height = height / 8;
237 let out_width = width / 8;
238
239 let mut detection_map = Array2::zeros((out_height, out_width));
240 for y in 0..out_height {
241 for x in 0..out_width {
242 let mut max_val = 0.0f32;
243 for dy in 0..8 {
244 for dx in 0..8 {
245 let src_y = (y * 8 + dy).min(height - 1);
246 let src_x = (x * 8 + dx).min(width - 1);
247 max_val = max_val.max(magnitude[[src_y, src_x]]);
248 }
249 }
250 detection_map[[y, x]] = max_val;
251 }
252 }
253
254 Ok(detection_map)
255 }
256
    /// GPU descriptor head: one L2-normalized descriptor per 8x8 image cell,
    /// sampled directly from a blurred copy of the image.
    fn gpu_forward_descriptors(
        &self,
        gpu_ctx: &GpuVisionContext,
        image: &ArrayView2<f32>,
    ) -> Result<Array3<f32>> {
        let blurred = crate::gpu_ops::gpu_gaussian_blur(gpu_ctx, image, 1.0)?;
        let (height, width) = blurred.dim();

        // Descriptor grid resolution: one cell per 8x8 pixel block.
        let desc_height = height / 8;
        let desc_width = width / 8;
        let desc_dim = self.network.config.descriptor_dim;

        let mut descriptor_map = Array3::zeros((desc_height, desc_width, desc_dim));

        for y in 0..desc_height {
            for x in 0..desc_width {
                let patch_y = y * 8;
                let patch_x = x * 8;

                let mut descriptor = Array1::zeros(desc_dim);

                // Raster-samples a 16-row column-major window (dy = i % 16,
                // dx = i / 16), clamped to the image edges; with desc_dim = 256
                // this covers a 16x16 region, i.e. it extends past the 8x8
                // cell. NOTE(review): confirm the overlap is intentional.
                for i in 0..desc_dim {
                    let dy = i % 16;
                    let dx = i / 16;
                    let sample_y = (patch_y + dy).min(height - 1);
                    let sample_x = (patch_x + dx).min(width - 1);
                    descriptor[i] = blurred[[sample_y, sample_x]];
                }

                // L2-normalize; near-zero descriptors are left unnormalized
                // to avoid dividing by ~0.
                let norm = descriptor.dot(&descriptor).sqrt();
                if norm > 1e-6 {
                    descriptor.mapv_inplace(|x| x / norm);
                }

                descriptor_map.slice_mut(s![y, x, ..]).assign(&descriptor);
            }
        }

        Ok(descriptor_map)
    }
303
304 fn compute_simple_gradients(
306 &self,
307 image: &ArrayView2<f32>,
308 ) -> Result<(Array2<f32>, Array2<f32>, Array2<f32>)> {
309 let (height, width) = image.dim();
310 let mut gx = Array2::zeros((height, width));
311 let mut gy = Array2::zeros((height, width));
312 let mut magnitude = Array2::zeros((height, width));
313
314 for y in 1..height - 1 {
315 for x in 1..width - 1 {
316 let dx = image[[y, x + 1]] - image[[y, x - 1]];
317 let dy = image[[y + 1, x]] - image[[y - 1, x]];
318 gx[[y, x]] = dx;
319 gy[[y, x]] = dy;
320 magnitude[[y, x]] = (dx * dx + dy * dy).sqrt();
321 }
322 }
323
324 Ok((gx, gy, magnitude))
325 }
326
327 fn simple_gaussian_blur(&self, image: &ArrayView2<f32>, sigma: f32) -> Result<Array2<f32>> {
329 let (height, width) = image.dim();
331 let mut blurred = Array2::zeros((height, width));
332
333 for y in 1..height - 1 {
334 for x in 1..width - 1 {
335 let avg = (image[[y - 1, x - 1]]
336 + image[[y - 1, x]]
337 + image[[y - 1, x + 1]]
338 + image[[y, x - 1]]
339 + image[[y, x]]
340 + image[[y, x + 1]]
341 + image[[y + 1, x - 1]]
342 + image[[y + 1, x]]
343 + image[[y + 1, x + 1]])
344 / 9.0;
345 blurred[[y, x]] = avg;
346 }
347 }
348
349 for y in 0..height {
351 blurred[[y, 0]] = image[[y, 0]];
352 if width > 1 {
353 blurred[[y, width - 1]] = image[[y, width - 1]];
354 }
355 }
356 for x in 0..width {
357 blurred[[0, x]] = image[[0, x]];
358 if height > 1 {
359 blurred[[height - 1, x]] = image[[height - 1, x]];
360 }
361 }
362
363 Ok(blurred)
364 }
365
    /// CPU descriptor head: per 8x8 cell, fills up to 64 descriptor components
    /// with cosine-weighted patch sums at 64 orientations, then L2-normalizes.
    fn cpu_forward_descriptors(&self, image: &ArrayView2<f32>) -> Result<Array3<f32>> {
        let blurred = self.simple_gaussian_blur(image, 1.0)?;
        let (height, width) = blurred.dim();

        // One descriptor per 8x8 pixel block.
        let desc_height = height / 8;
        let desc_width = width / 8;
        let desc_dim = self.network.config.descriptor_dim;

        let mut descriptor_map = Array3::zeros((desc_height, desc_width, desc_dim));

        for y in 0..desc_height {
            for x in 0..desc_width {
                let patch_y = y * 8;
                let patch_x = x * 8;

                let mut descriptor = Array1::zeros(desc_dim);

                // Only the first min(desc_dim, 64) components are computed;
                // the remainder keeps its zero initialization.
                for i in 0..desc_dim.min(64) {
                    // Component orientation: i * pi/32, i.e. 64 steps over [0, 2pi).
                    let angle = i as f32 * std::f32::consts::PI / 32.0;
                    let cos_a = angle.cos();
                    let sin_a = angle.sin();

                    // Sum patch values weighted by the cosine of each pixel's
                    // projection onto the orientation direction.
                    let mut sum = 0.0f32;
                    for dy in 0..8 {
                        for dx in 0..8 {
                            let sample_y = (patch_y + dy).min(height - 1);
                            let sample_x = (patch_x + dx).min(width - 1);
                            let value = blurred[[sample_y, sample_x]];
                            let weight = (cos_a * dx as f32 + sin_a * dy as f32).cos();
                            sum += value * weight;
                        }
                    }
                    descriptor[i] = sum;
                }

                // L2 normalization, skipping near-zero vectors.
                let norm = descriptor.dot(&descriptor).sqrt();
                if norm > 1e-6 {
                    descriptor.mapv_inplace(|x| x / norm);
                }

                descriptor_map.slice_mut(s![y, x, ..]).assign(&descriptor);
            }
        }

        Ok(descriptor_map)
    }
416
417 fn post_process_features(
419 &self,
420 featuremap: &Array2<f32>,
421 descriptor_map: &Array3<f32>,
422 ) -> Result<(Vec<KeyPoint>, Array2<f32>)> {
423 let nms_result = self.non_maximum_suppression(featuremap)?;
425
426 let mut candidates: Vec<(f32, usize, usize)> = Vec::new();
428 let (height, width) = nms_result.dim();
429
430 for y in self.network.config.border_remove..height - self.network.config.border_remove {
431 for x in self.network.config.border_remove..width - self.network.config.border_remove {
432 let score = nms_result[[y, x]];
433 if score > self.network.config.detection_threshold {
434 candidates.push((score, y, x));
435 }
436 }
437 }
438
439 candidates.sort_by(|a, b| b.0.partial_cmp(&a.0).expect("Operation failed"));
441 candidates.truncate(self.network.config.max_keypoints);
442
443 let mut keypoints = Vec::new();
445 let mut descriptors = Array2::zeros((candidates.len(), self.network.config.descriptor_dim));
446
447 for (i, &(score, y, x)) in candidates.iter().enumerate() {
448 let orig_x = (x * 8) as f32;
450 let orig_y = (y * 8) as f32;
451
452 keypoints.push(KeyPoint {
453 x: orig_x,
454 y: orig_y,
455 response: score,
456 scale: 1.0,
457 orientation: 0.0, });
459
460 if y < descriptor_map.shape()[0] && x < descriptor_map.shape()[1] {
462 let desc = descriptor_map.slice(s![y, x, ..]);
463 descriptors.slice_mut(s![i, ..]).assign(&desc);
464 }
465 }
466
467 Ok((keypoints, descriptors))
468 }
469
470 fn non_maximum_suppression(&self, featuremap: &Array2<f32>) -> Result<Array2<f32>> {
472 let (height, width) = featuremap.dim();
473 let mut nms_result = Array2::zeros((height, width));
474 let radius = self.network.config.nms_radius;
475
476 for y in radius..height - radius {
477 for x in radius..width - radius {
478 let center_val = featuremap[[y, x]];
479 let mut is_maximum = true;
480
481 for dy in -(radius as isize)..=(radius as isize) {
483 for dx in -(radius as isize)..=(radius as isize) {
484 if dy == 0 && dx == 0 {
485 continue;
486 }
487
488 let ny = (y as isize + dy) as usize;
489 let nx = (x as isize + dx) as usize;
490
491 if featuremap[[ny, nx]] >= center_val {
492 is_maximum = false;
493 break;
494 }
495 }
496 if !is_maximum {
497 break;
498 }
499 }
500
501 if is_maximum {
502 nms_result[[y, x]] = center_val;
503 }
504 }
505 }
506
507 Ok(nms_result)
508 }
509
510 fn resize_image(
512 &self,
513 image: &ArrayView2<f32>,
514 target_size: (usize, usize),
515 ) -> Result<Array2<f32>> {
516 let (src_height, src_width) = image.dim();
517 let (dst_height, dst_width) = target_size;
518
519 let mut resized = Array2::zeros((dst_height, dst_width));
520
521 let scale_y = src_height as f32 / dst_height as f32;
522 let scale_x = src_width as f32 / dst_width as f32;
523
524 for y in 0..dst_height {
525 for x in 0..dst_width {
526 let src_y = (y as f32 * scale_y) as usize;
527 let src_x = (x as f32 * scale_x) as usize;
528
529 let src_y = src_y.min(src_height - 1);
530 let src_x = src_x.min(src_width - 1);
531
532 resized[[y, x]] = image[[src_y, src_x]];
533 }
534 }
535
536 Ok(resized)
537 }
538
539 fn create_detection_weights(config: &NeuralFeatureConfig) -> Result<ModelWeights> {
541 let conv_weights = vec![
545 Array3::from_shape_fn((64, 1, 3), |___| scirs2_core::random::random::<f32>() * 0.1),
546 Array3::from_shape_fn((64, 64, 3), |___| {
547 scirs2_core::random::random::<f32>() * 0.1
548 }),
549 Array3::from_shape_fn((128, 64, 3), |___| {
550 scirs2_core::random::random::<f32>() * 0.1
551 }),
552 Array3::from_shape_fn((128, 128, 3), |___| {
553 scirs2_core::random::random::<f32>() * 0.1
554 }),
555 ];
556
557 let conv_biases = vec![
558 Array1::zeros(64),
559 Array1::zeros(64),
560 Array1::zeros(128),
561 Array1::zeros(128),
562 ];
563
564 let bn_weights = vec![
565 Array1::ones(64),
566 Array1::ones(64),
567 Array1::ones(128),
568 Array1::ones(128),
569 ];
570
571 let bn_biases = vec![
572 Array1::zeros(64),
573 Array1::zeros(64),
574 Array1::zeros(128),
575 Array1::zeros(128),
576 ];
577
578 let fc_weights = vec![Array2::from_shape_fn((65, 128), |_| {
580 scirs2_core::random::random::<f32>() * 0.1
581 })];
582
583 let fc_biases = vec![
584 Array1::zeros(65), ];
586
587 Ok(ModelWeights {
588 conv_weights,
589 conv_biases,
590 bn_weights,
591 bn_biases,
592 fc_weights,
593 fc_biases,
594 })
595 }
596
597 fn create_descriptor_weights(config: &NeuralFeatureConfig) -> Result<ModelWeights> {
599 let fc_weights = vec![Array2::from_shape_fn((config.descriptor_dim, 128), |_| {
601 scirs2_core::random::random::<f32>() * 0.1
602 })];
603
604 let fc_biases = vec![Array1::zeros(config.descriptor_dim)];
605
606 Ok(ModelWeights {
607 conv_weights: Vec::new(),
608 conv_biases: Vec::new(),
609 bn_weights: Vec::new(),
610 bn_biases: Vec::new(),
611 fc_weights,
612 fc_biases,
613 })
614 }
615}
616
/// Nearest-neighbor descriptor matcher with absolute-distance and
/// Lowe's-ratio rejection tests.
pub struct NeuralFeatureMatcher {
    /// Maximum cosine distance (1 - similarity) for an accepted match.
    distance_threshold: f32,
    /// Best / second-best distance ratio that an accepted match must stay under.
    ratio_threshold: f32,
    /// Reserved for a GPU matching path; not consulted yet.
    #[allow(dead_code)]
    use_gpu: bool,
}
627
impl Default for NeuralFeatureMatcher {
    /// Delegates to [`NeuralFeatureMatcher::new`] (thresholds 0.7 / 0.8).
    fn default() -> Self {
        Self::new()
    }
}
633
634impl NeuralFeatureMatcher {
635 pub fn new() -> Self {
637 Self {
638 distance_threshold: 0.7,
639 ratio_threshold: 0.8,
640 use_gpu: true,
641 }
642 }
643
644 pub fn with_params(mut self, distance_threshold: f32, ratiothreshold: f32) -> Self {
646 self.distance_threshold = distance_threshold;
647 self.ratio_threshold = ratiothreshold;
648 self
649 }
650
651 pub fn match_descriptors(
653 &self,
654 desc1: &ArrayView2<f32>,
655 desc2: &ArrayView2<f32>,
656 ) -> Result<Vec<(usize, usize)>> {
657 let n1 = desc1.shape()[0];
658 let n2 = desc2.shape()[0];
659
660 if n1 == 0 || n2 == 0 {
661 return Ok(Vec::new());
662 }
663
664 let distances = self.compute_pairwise_distances(desc1, desc2)?;
666
667 let mut matches = Vec::new();
669
670 for i in 0..n1 {
671 let mut best_dist = f32::INFINITY;
672 let mut second_best_dist = f32::INFINITY;
673 let mut best_idx = 0;
674
675 for j in 0..n2 {
676 let dist = distances[[i, j]];
677 if dist < best_dist {
678 second_best_dist = best_dist;
679 best_dist = dist;
680 best_idx = j;
681 } else if dist < second_best_dist {
682 second_best_dist = dist;
683 }
684 }
685
686 if best_dist < self.distance_threshold
688 && best_dist / second_best_dist < self.ratio_threshold
689 {
690 matches.push((i, best_idx));
691 }
692 }
693
694 Ok(matches)
695 }
696
697 fn compute_pairwise_distances(
699 &self,
700 desc1: &ArrayView2<f32>,
701 desc2: &ArrayView2<f32>,
702 ) -> Result<Array2<f32>> {
703 let n1 = desc1.shape()[0];
704 let n2 = desc2.shape()[0];
705 let mut distances = Array2::zeros((n1, n2));
706
707 for i in 0..n1 {
709 for j in 0..n2 {
710 let desc1_row = desc1.slice(s![i, ..]);
711 let desc2_row = desc2.slice(s![j, ..]);
712
713 let dot_product = desc1_row.dot(&desc2_row);
715 let norm1 = desc1_row.dot(&desc1_row).sqrt();
716 let norm2 = desc2_row.dot(&desc2_row).sqrt();
717
718 let cosine_sim = if norm1 > 1e-6 && norm2 > 1e-6 {
719 dot_product / (norm1 * norm2)
720 } else {
721 0.0
722 };
723
724 distances[[i, j]] = 1.0 - cosine_sim;
725 }
726 }
727
728 Ok(distances)
729 }
730}
731
/// Attention-based matcher that augments descriptors with positional
/// encodings before computing softmax match scores.
pub struct AttentionFeatureMatcher {
    /// Intended attention embedding size; stored but not used by scoring yet.
    #[allow(dead_code)]
    attention_dim: usize,
    /// Intended number of attention heads; stored but not used yet.
    #[allow(dead_code)]
    numheads: usize,
    /// Reserved flag for a GPU path; not consulted yet.
    #[allow(dead_code)]
    use_gpu: bool,
}
744
745impl AttentionFeatureMatcher {
746 pub fn new(_attention_dim: usize, numheads: usize) -> Self {
748 Self {
749 attention_dim: _attention_dim,
750 numheads,
751 use_gpu: true,
752 }
753 }
754
755 pub fn match_with_attention(
757 &self,
758 keypoints1: &[KeyPoint],
759 descriptors1: &ArrayView2<f32>,
760 keypoints2: &[KeyPoint],
761 descriptors2: &ArrayView2<f32>,
762 ) -> Result<Vec<(usize, usize)>> {
763 let n1 = descriptors1.shape()[0];
767 let n2 = descriptors2.shape()[0];
768
769 if n1 == 0 || n2 == 0 {
770 return Ok(Vec::new());
771 }
772
773 let pos_enc1 = self.compute_positional_encoding(keypoints1)?;
775 let pos_enc2 = self.compute_positional_encoding(keypoints2)?;
776
777 let enhanced_desc1 = self.enhance_descriptors(descriptors1, &pos_enc1)?;
779 let enhanced_desc2 = self.enhance_descriptors(descriptors2, &pos_enc2)?;
780
781 let attention_scores = self.compute_attention_scores(&enhanced_desc1, &enhanced_desc2)?;
783
784 self.extract_matches_from_attention(&attention_scores)
786 }
787
788 fn compute_positional_encoding(&self, keypoints: &[KeyPoint]) -> Result<Array2<f32>> {
790 let n = keypoints.len();
791 let mut pos_encoding = Array2::zeros((n, 4)); for (i, kp) in keypoints.iter().enumerate() {
794 pos_encoding[[i, 0]] = kp.x / 1000.0; pos_encoding[[i, 1]] = kp.y / 1000.0;
796 pos_encoding[[i, 2]] = (kp.x * 0.01).cos();
797 pos_encoding[[i, 3]] = (kp.y * 0.01).sin();
798 }
799
800 Ok(pos_encoding)
801 }
802
803 fn enhance_descriptors(
805 &self,
806 descriptors: &ArrayView2<f32>,
807 pos_encoding: &Array2<f32>,
808 ) -> Result<Array2<f32>> {
809 let n = descriptors.shape()[0];
810 let desc_dim = descriptors.shape()[1];
811 let pos_dim = pos_encoding.shape()[1];
812
813 let mut enhanced = Array2::zeros((n, desc_dim + pos_dim));
814
815 for i in 0..n {
817 enhanced
818 .slice_mut(s![i, ..desc_dim])
819 .assign(&descriptors.slice(s![i, ..]));
820 enhanced
821 .slice_mut(s![i, desc_dim..])
822 .assign(&pos_encoding.slice(s![i, ..]));
823 }
824
825 Ok(enhanced)
826 }
827
    /// Scaled dot-product attention scores between the two descriptor sets,
    /// followed by a numerically stabilized row-wise softmax.
    fn compute_attention_scores(
        &self,
        desc1: &Array2<f32>,
        desc2: &Array2<f32>,
    ) -> Result<Array2<f32>> {
        let n1 = desc1.shape()[0];
        let n2 = desc2.shape()[0];
        let dim = desc1.shape()[1];

        let mut attention_scores = Array2::zeros((n1, n2));
        // Standard 1/sqrt(d) attention scaling.
        let scale = 1.0 / (dim as f32).sqrt();

        for i in 0..n1 {
            for j in 0..n2 {
                let query = desc1.slice(s![i, ..]);
                let key = desc2.slice(s![j, ..]);

                let score = query.dot(&key) * scale;
                attention_scores[[i, j]] = score;
            }
        }

        // Row-wise softmax; subtracting the row max prevents exp() overflow.
        for i in 0..n1 {
            let mut row = attention_scores.slice_mut(s![i, ..]);
            let max_val = row.iter().cloned().fold(f32::NEG_INFINITY, f32::max);

            row.mapv_inplace(|x| (x - max_val).exp());
            let sum = row.sum();
            // Guard against dividing a (numerically) all-zero row.
            if sum > 1e-8 {
                row.mapv_inplace(|x| x / sum);
            }
        }

        Ok(attention_scores)
    }
867
868 fn extract_matches_from_attention(
870 &self,
871 attention_scores: &Array2<f32>,
872 ) -> Result<Vec<(usize, usize)>> {
873 let n1 = attention_scores.shape()[0];
874 let n2 = attention_scores.shape()[1];
875 let mut matches = Vec::new();
876
877 let mut used_j = vec![false; n2];
880
881 for i in 0..n1 {
882 let mut best_score = 0.0;
883 let mut best_j = None;
884
885 for j in 0..n2 {
886 if !used_j[j] && attention_scores[[i, j]] > best_score {
887 best_score = attention_scores[[i, j]];
888 best_j = Some(j);
889 }
890 }
891
892 if let Some(j) = best_j {
894 if best_score > 0.1 {
895 matches.push((i, j));
897 used_j[j] = true;
898 }
899 }
900 }
901
902 Ok(matches)
903 }
904}
905
/// Classical SIFT-style pipeline with an optional neural descriptor
/// enhancement stage.
pub struct LearnedSIFT {
    /// Scale-space and thresholding parameters.
    siftconfig: SIFTConfig,
    /// Optional network for descriptor post-processing; `None` disables it.
    enhancement_network: Option<NeuralFeatureNetwork>,
}
913
/// Parameters controlling the SIFT scale space and keypoint selection.
#[derive(Clone)]
pub struct SIFTConfig {
    /// Number of pyramid octaves; each octave halves resolution.
    pub num_octaves: usize,
    /// Scales sampled per octave (three auxiliary blur levels are added).
    pub num_scales: usize,
    /// Base Gaussian sigma for the first scale of each octave.
    pub sigma: f32,
    /// Edge-response rejection threshold (reserved; not applied yet).
    pub edge_threshold: f32,
    /// Minimum |DoG| response for a keypoint candidate.
    pub peak_threshold: f32,
}
928
impl Default for SIFTConfig {
    /// Standard Lowe-style defaults: sigma 1.6, 3 scales per octave.
    fn default() -> Self {
        Self {
            num_octaves: 4,
            num_scales: 3,
            sigma: 1.6,
            edge_threshold: 10.0,
            peak_threshold: 0.03,
        }
    }
}
940
941impl LearnedSIFT {
942 pub fn new(config: Option<SIFTConfig>) -> Self {
944 Self {
945 siftconfig: config.unwrap_or_default(),
946 enhancement_network: None,
947 }
948 }
949
950 fn simple_gaussian_blur(&self, image: &ArrayView2<f32>, sigma: f32) -> Result<Array2<f32>> {
952 let (height, width) = image.dim();
954 let mut blurred = Array2::zeros((height, width));
955
956 for y in 1..height - 1 {
957 for x in 1..width - 1 {
958 let avg = (image[[y - 1, x - 1]]
959 + image[[y - 1, x]]
960 + image[[y - 1, x + 1]]
961 + image[[y, x - 1]]
962 + image[[y, x]]
963 + image[[y, x + 1]]
964 + image[[y + 1, x - 1]]
965 + image[[y + 1, x]]
966 + image[[y + 1, x + 1]])
967 / 9.0;
968 blurred[[y, x]] = avg;
969 }
970 }
971
972 for y in 0..height {
974 blurred[[y, 0]] = image[[y, 0]];
975 if width > 1 {
976 blurred[[y, width - 1]] = image[[y, width - 1]];
977 }
978 }
979 for x in 0..width {
980 blurred[[0, x]] = image[[0, x]];
981 if height > 1 {
982 blurred[[height - 1, x]] = image[[height - 1, x]];
983 }
984 }
985
986 Ok(blurred)
987 }
988
989 pub fn detect_keypoints(&self, image: &ArrayView2<f32>) -> Result<Vec<KeyPoint>> {
991 let scalespace = self.build_scale_space(image)?;
993
994 let dogspace = self.compute_dog_space(&scalespace)?;
996 let extrema = self.detect_extrema(&dogspace)?;
997
998 let refined_keypoints = self.refine_keypoints(&extrema, &dogspace)?;
1000
1001 let filtered_keypoints = self.filter_keypoints(&refined_keypoints, &dogspace)?;
1003
1004 Ok(filtered_keypoints)
1005 }
1006
1007 pub fn compute_descriptors(
1009 &self,
1010 image: &ArrayView2<f32>,
1011 keypoints: &[KeyPoint],
1012 ) -> Result<Array2<f32>> {
1013 let mut descriptors = Array2::zeros((keypoints.len(), 128));
1014
1015 for (i, kp) in keypoints.iter().enumerate() {
1016 let descriptor = self.compute_sift_descriptor(image, kp)?;
1017 descriptors.slice_mut(s![i, ..]).assign(&descriptor);
1018 }
1019
1020 if let Some(ref network) = self.enhancement_network {
1022 self.enhance_descriptors_neural(&mut descriptors, network)?;
1023 }
1024
1025 Ok(descriptors)
1026 }
1027
1028 fn build_scale_space(&self, image: &ArrayView2<f32>) -> Result<Vec<Vec<Array2<f32>>>> {
1030 let mut scalespace = Vec::new();
1031 let mut current_image = image.to_owned();
1032
1033 for octave in 0..self.siftconfig.num_octaves {
1034 let mut octave_images = Vec::new();
1035
1036 for scale in 0..self.siftconfig.num_scales + 3 {
1037 let sigma = self.siftconfig.sigma
1038 * 2.0_f32.powf(scale as f32 / self.siftconfig.num_scales as f32);
1039 let blurred = self.simple_gaussian_blur(¤t_image.view(), sigma)?;
1040 octave_images.push(blurred);
1041 }
1042
1043 scalespace.push(octave_images);
1044
1045 if octave < self.siftconfig.num_octaves - 1 {
1047 current_image = self.downsample(¤t_image)?;
1048 }
1049 }
1050
1051 Ok(scalespace)
1052 }
1053
1054 fn compute_dog_space(&self, scalespace: &[Vec<Array2<f32>>]) -> Result<Vec<Vec<Array2<f32>>>> {
1056 let mut dogspace = Vec::new();
1057
1058 for octave_images in scalespace {
1059 let mut dog_octave = Vec::new();
1060
1061 for i in 0..octave_images.len() - 1 {
1062 let dog = &octave_images[i + 1] - &octave_images[i];
1063 dog_octave.push(dog);
1064 }
1065
1066 dogspace.push(dog_octave);
1067 }
1068
1069 Ok(dogspace)
1070 }
1071
1072 fn detect_extrema(&self, dogspace: &[Vec<Array2<f32>>]) -> Result<Vec<KeyPoint>> {
1074 let mut extrema = Vec::new();
1075
1076 for (octave, dog_octave) in dogspace.iter().enumerate() {
1077 for (scale, dog_image) in dog_octave
1078 .iter()
1079 .enumerate()
1080 .skip(1)
1081 .take(dog_octave.len() - 2)
1082 {
1083 let (height, width) = dog_image.dim();
1084
1085 for y in 1..height - 1 {
1086 for x in 1..width - 1 {
1087 let center_val = dog_image[[y, x]];
1088
1089 if center_val.abs() < self.siftconfig.peak_threshold {
1090 continue;
1091 }
1092
1093 if self.is_extremum(dog_octave, scale, y, x, center_val) {
1095 extrema.push(KeyPoint {
1096 x: x as f32 * 2.0_f32.powi(octave as i32),
1097 y: y as f32 * 2.0_f32.powi(octave as i32),
1098 response: center_val.abs(),
1099 scale: 2.0_f32.powi(octave as i32),
1100 orientation: 0.0,
1101 });
1102 }
1103 }
1104 }
1105 }
1106 }
1107
1108 Ok(extrema)
1109 }
1110
    /// Tests whether `center_val` is a strict local extremum in the 3x3x3
    /// neighborhood spanning scales `scale - 1 ..= scale + 1`.
    ///
    /// Preconditions (upheld by `detect_extrema`): `1 <= scale <= len - 2`,
    /// `y >= 1`, `x >= 1`, and the three neighboring images are large enough —
    /// indices are not range-checked here beyond the arrays' own panics.
    fn is_extremum(
        &self,
        dog_octave: &[Array2<f32>],
        scale: usize,
        y: usize,
        x: usize,
        center_val: f32,
    ) -> bool {
        // The sign of the center decides whether we test for a maximum
        // (positive) or a minimum (non-positive).
        let is_max = center_val > 0.0;

        for s_offset in -1_isize..=1_isize {
            let s = (scale as isize + s_offset) as usize;
            for dy in -1_isize..=1_isize {
                for dx in -1_isize..=1_isize {
                    // Skip the center sample itself.
                    if s_offset == 0 && dy == 0 && dx == 0 {
                        continue;
                    }

                    let ny = (y as isize + dy) as usize;
                    let nx = (x as isize + dx) as usize;

                    let neighbor_val = dog_octave[s][[ny, nx]];

                    // Ties disqualify the candidate (strict extremum).
                    if is_max && neighbor_val >= center_val {
                        return false;
                    }
                    if !is_max && neighbor_val <= center_val {
                        return false;
                    }
                }
            }
        }

        true
    }
1148
1149 fn refine_keypoints(
1151 &self,
1152 keypoints: &[KeyPoint],
1153 _dog_space: &[Vec<Array2<f32>>],
1154 ) -> Result<Vec<KeyPoint>> {
1155 Ok(keypoints.to_vec())
1158 }
1159
1160 fn filter_keypoints(
1162 &self,
1163 keypoints: &[KeyPoint],
1164 _dog_space: &[Vec<Array2<f32>>],
1165 ) -> Result<Vec<KeyPoint>> {
1166 let mut filtered = Vec::new();
1167
1168 for kp in keypoints {
1169 if kp.response > self.siftconfig.peak_threshold {
1171 filtered.push(kp.clone());
1172 }
1173 }
1174
1175 Ok(filtered)
1176 }
1177
1178 fn compute_sift_descriptor(
1180 &self,
1181 image: &ArrayView2<f32>,
1182 keypoint: &KeyPoint,
1183 ) -> Result<Array1<f32>> {
1184 let mut descriptor = Array1::zeros(128);
1188 let (height, width) = image.dim();
1189
1190 let x = keypoint.x as usize;
1191 let y = keypoint.y as usize;
1192
1193 for i in 0..128 {
1195 let angle = i as f32 * 2.0 * std::f32::consts::PI / 128.0;
1196 let radius = 8.0 + (i % 16) as f32;
1197
1198 let sample_x = x as f32 + radius * angle.cos();
1199 let sample_y = y as f32 + radius * angle.sin();
1200
1201 if sample_x >= 0.0
1202 && sample_x < width as f32
1203 && sample_y >= 0.0
1204 && sample_y < height as f32
1205 {
1206 let sx = sample_x as usize;
1207 let sy = sample_y as usize;
1208 descriptor[i] = image[[sy.min(height - 1), sx.min(width - 1)]];
1209 }
1210 }
1211
1212 let norm = descriptor.dot(&descriptor).sqrt();
1214 if norm > 1e-6 {
1215 descriptor.mapv_inplace(|x| x / norm);
1216 }
1217
1218 Ok(descriptor)
1219 }
1220
    /// Placeholder "neural" enhancement: standardizes each descriptor row to
    /// zero mean and unit variance. `_network` is accepted but not consulted.
    fn enhance_descriptors_neural(
        &self,
        descriptors: &mut Array2<f32>,
        _network: &NeuralFeatureNetwork,
    ) -> Result<()> {
        for mut row in descriptors.rows_mut() {
            // `mean_or(0.0)` comes from `ArrayStatCompat` — presumably returns
            // the fallback for empty rows; TODO confirm against that trait.
            let mean = row.mean_or(0.0);
            // Floor the std-dev so constant rows do not divide by zero.
            let std = ((row.mapv(|x| (x - mean).powi(2)).mean_or(0.0)).sqrt()).max(1e-6);
            row.mapv_inplace(|x| (x - mean) / std);
        }

        Ok(())
    }
1239
1240 fn downsample(&self, image: &Array2<f32>) -> Result<Array2<f32>> {
1242 let (height, width) = image.dim();
1243 let new_height = height / 2;
1244 let new_width = width / 2;
1245
1246 let mut downsampled = Array2::zeros((new_height, new_width));
1247
1248 for y in 0..new_height {
1249 for x in 0..new_width {
1250 downsampled[[y, x]] = image[[y * 2, x * 2]];
1251 }
1252 }
1253
1254 Ok(downsampled)
1255 }
1256}
1257
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::arr2;

    // Construction must succeed with a CPU-only configuration.
    #[test]
    fn test_superpoint_creation() {
        let config = NeuralFeatureConfig {
            input_size: (480, 640),
            max_keypoints: 512,
            use_gpu: false,
            ..Default::default()
        };

        let result = SuperPointNet::new(Some(config));
        assert!(result.is_ok());
    }

    // End-to-end detection on a synthetic sinusoidal image (CPU path).
    #[test]
    fn test_superpoint_detection() {
        let config = NeuralFeatureConfig {
            input_size: (480, 640),
            max_keypoints: 100,
            use_gpu: false,
            ..Default::default()
        };

        if let Ok(superpoint) = SuperPointNet::new(Some(config)) {
            let image = Array2::from_shape_fn((480, 640), |(y, x)| {
                ((x as f32 / 10.0).sin() + (y as f32 / 10.0).cos()) * 0.5 + 0.5
            });

            let result = superpoint.detect_and_describe(&image.view());
            assert!(result.is_ok());

            let (keypoints, descriptors) = result.expect("Operation failed");
            assert!(!keypoints.is_empty());
            // One descriptor row per keypoint.
            assert_eq!(descriptors.shape()[0], keypoints.len());
        }
    }

    // Near-identical unit descriptors should survive the ratio test.
    #[test]
    fn test_neural_feature_matcher() {
        let matcher = NeuralFeatureMatcher::new();

        let desc1 = arr2(&[
            [1.0, 0.0, 0.0, 0.0],
            [0.0, 1.0, 0.0, 0.0],
            [0.0, 0.0, 1.0, 0.0],
        ]);

        let desc2 = arr2(&[
            [0.9, 0.1, 0.0, 0.0],
            [0.0, 0.0, 0.9, 0.1],
            [0.1, 0.9, 0.0, 0.0],
        ]);

        let matches = matcher
            .match_descriptors(&desc1.view(), &desc2.view())
            .expect("Operation failed");
        assert!(!matches.is_empty());
    }

    // A bright square should produce detectable keypoints with 128-dim descriptors.
    #[test]
    fn test_learned_sift() {
        let sift = LearnedSIFT::new(None);
        let image = Array2::from_shape_fn((100, 100), |(y, x)| {
            if (x as i32 - 50).abs() < 5 && (y as i32 - 50).abs() < 5 {
                1.0
            } else {
                0.0
            }
        });

        let keypoints = sift
            .detect_keypoints(&image.view())
            .expect("Operation failed");
        if !keypoints.is_empty() {
            let descriptors = sift
                .compute_descriptors(&image.view(), &keypoints)
                .expect("Operation failed");
            assert_eq!(descriptors.shape()[0], keypoints.len());
            assert_eq!(descriptors.shape()[1], 128);
        }
    }

    // Attention matching on random descriptors must at least not error.
    #[test]
    fn test_attention_matcher() {
        let matcher = AttentionFeatureMatcher::new(64, 4);

        let keypoints1 = vec![
            KeyPoint {
                x: 10.0,
                y: 10.0,
                response: 1.0,
                scale: 1.0,
                orientation: 0.0,
            },
            KeyPoint {
                x: 20.0,
                y: 20.0,
                response: 1.0,
                scale: 1.0,
                orientation: 0.0,
            },
        ];

        let keypoints2 = vec![
            KeyPoint {
                x: 12.0,
                y: 11.0,
                response: 1.0,
                scale: 1.0,
                orientation: 0.0,
            },
            KeyPoint {
                x: 50.0,
                y: 50.0,
                response: 1.0,
                scale: 1.0,
                orientation: 0.0,
            },
        ];

        let desc1 = Array2::from_shape_fn((2, 64), |__| scirs2_core::random::random::<f32>());
        let desc2 = Array2::from_shape_fn((2, 64), |__| scirs2_core::random::random::<f32>());

        let result =
            matcher.match_with_attention(&keypoints1, &desc1.view(), &keypoints2, &desc2.view());
        assert!(result.is_ok());
    }
}