smartcore/algorithm/neighbour/
linear_search.rs

1//! # Brute Force Linear Search
2//!
3//! see [KNN algorithms](../index.html)
4//! ```
5//! use smartcore::algorithm::neighbour::linear_search::*;
6//! use smartcore::metrics::distance::Distance;
7//!
8//! #[derive(Clone)]
9//! struct SimpleDistance {} // Our distance function
10//!
11//! impl Distance<i32> for SimpleDistance {
//!   fn distance(&self, a: &i32, b: &i32) -> f64 { // simple symmetric scalar distance
13//!     (a - b).abs() as f64
14//!   }
15//! }
16//!
17//! let data = vec![1, 2, 3, 4, 5, 6, 7, 8, 9]; // data points
18//!
19//! let knn = LinearKNNSearch::new(data, SimpleDistance {}).unwrap();
20//!
21//! knn.find(&5, 3); // find 3 knn points from 5
22//!
23//! ```
24
25#[cfg(feature = "serde")]
26use serde::{Deserialize, Serialize};
27use std::cmp::{Ordering, PartialOrd};
28
29use crate::algorithm::sort::heap_select::HeapSelection;
30use crate::error::{Failed, FailedError};
31use crate::metrics::distance::Distance;
32
33/// Implements Linear Search algorithm, see [KNN algorithms](../index.html)
34#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
35#[derive(Debug)]
36pub struct LinearKNNSearch<T, D: Distance<T>> {
37    distance: D,
38    data: Vec<T>,
39}
40
41impl<T, D: Distance<T>> LinearKNNSearch<T, D> {
42    /// Initializes algorithm.
43    /// * `data` - vector of data points to search for.
44    /// * `distance` - distance metric to use for searching. This function should extend [`Distance`](../../../math/distance/index.html) interface.
45    pub fn new(data: Vec<T>, distance: D) -> Result<LinearKNNSearch<T, D>, Failed> {
46        Ok(LinearKNNSearch { data, distance })
47    }
48
49    /// Find k nearest neighbors
50    /// * `from` - look for k nearest points to `from`
51    /// * `k` - the number of nearest neighbors to return
52    pub fn find(&self, from: &T, k: usize) -> Result<Vec<(usize, f64, &T)>, Failed> {
53        if k < 1 || k > self.data.len() {
54            return Err(Failed::because(
55                FailedError::FindFailed,
56                "k should be >= 1 and <= length(data)",
57            ));
58        }
59
60        let mut heap = HeapSelection::<KNNPoint>::with_capacity(k);
61
62        for _ in 0..k {
63            heap.add(KNNPoint {
64                distance: f64::INFINITY,
65                index: None,
66            });
67        }
68
69        for i in 0..self.data.len() {
70            let d = self.distance.distance(from, &self.data[i]);
71            let datum = heap.peek_mut();
72            if d < datum.distance {
73                datum.distance = d;
74                datum.index = Some(i);
75                heap.heapify();
76            }
77        }
78
79        Ok(heap
80            .get()
81            .into_iter()
82            .flat_map(|x| x.index.map(|i| (i, x.distance, &self.data[i])))
83            .collect())
84    }
85
86    /// Find all nearest neighbors within radius `radius` from `p`
87    /// * `p` - look for k nearest points to `p`
88    /// * `radius` - radius of the search
89    pub fn find_radius(&self, from: &T, radius: f64) -> Result<Vec<(usize, f64, &T)>, Failed> {
90        if radius <= 0f64 {
91            return Err(Failed::because(
92                FailedError::FindFailed,
93                "radius should be > 0",
94            ));
95        }
96
97        let mut neighbors: Vec<(usize, f64, &T)> = Vec::new();
98
99        for i in 0..self.data.len() {
100            let d = self.distance.distance(from, &self.data[i]);
101
102            if d <= radius {
103                neighbors.push((i, d, &self.data[i]));
104            }
105        }
106
107        Ok(neighbors)
108    }
109}
110
/// Candidate neighbour held in the selection heap used by `LinearKNNSearch::find`.
///
/// Ordering and equality look only at `distance`; `index` never takes part in a comparison.
#[derive(Debug)]
struct KNNPoint {
    /// Distance between the query point and this candidate.
    distance: f64,
    /// Position of the candidate in the searched data, or `None` for a placeholder that has
    /// not been replaced by a real point.
    index: Option<usize>,
}

impl PartialOrd for KNNPoint {
    /// Compares candidates by distance; yields `None` when either distance is NaN.
    fn partial_cmp(&self, rhs: &Self) -> Option<Ordering> {
        PartialOrd::partial_cmp(&self.distance, &rhs.distance)
    }
}

impl PartialEq for KNNPoint {
    /// Two candidates are equal exactly when their distances compare as `Ordering::Equal`.
    /// This matches `f64` equality: `0.0 == -0.0`, and NaN equals nothing.
    fn eq(&self, rhs: &Self) -> bool {
        matches!(self.partial_cmp(rhs), Some(Ordering::Equal))
    }
}

// `Eq` promises reflexivity, which `f64` equality lacks for NaN. The placeholders `find`
// creates start at infinity, and `find` overwrites a distance only when the new one compares
// strictly less. A NaN therefore never enters the heap, so the contract holds for the values
// this module stores.
impl Eq for KNNPoint {}
130
#[cfg(test)]
mod tests {
    use super::*;
    use crate::metrics::distance::Distances;

    /// Absolute difference of two integers: the simplest symmetric metric for these tests.
    #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
    #[derive(Debug, Clone)]
    struct SimpleDistance {}

    impl Distance<i32> for SimpleDistance {
        fn distance(&self, a: &i32, b: &i32) -> f64 {
            (a - b).abs() as f64
        }
    }

    #[cfg_attr(
        all(target_arch = "wasm32", not(target_os = "wasi")),
        wasm_bindgen_test::wasm_bindgen_test
    )]
    #[test]
    fn knn_find() {
        let data1 = vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10];

        let algorithm1 = LinearKNNSearch::new(data1, SimpleDistance {}).unwrap();

        let mut found_idxs1: Vec<usize> = algorithm1
            .find(&2, 3)
            .unwrap()
            .iter()
            .map(|v| v.0)
            .collect();
        found_idxs1.sort_unstable();

        assert_eq!(vec!(0, 1, 2), found_idxs1);

        let mut found_values1: Vec<i32> = algorithm1
            .find_radius(&5, 3.0)
            .unwrap()
            .iter()
            .map(|v| *v.2)
            .collect();
        found_values1.sort_unstable();

        // The radius bound is inclusive: 2 and 8 are exactly 3.0 away from 5.
        assert_eq!(vec!(2, 3, 4, 5, 6, 7, 8), found_values1);

        let data2 = vec![
            vec![1., 1.],
            vec![2., 2.],
            vec![3., 3.],
            vec![4., 4.],
            vec![5., 5.],
        ];

        let algorithm2 = LinearKNNSearch::new(data2, Distances::euclidian()).unwrap();

        let mut found_idxs2: Vec<usize> = algorithm2
            .find(&vec![3., 3.], 3)
            .unwrap()
            .iter()
            .map(|v| v.0)
            .collect();
        found_idxs2.sort_unstable();

        assert_eq!(vec!(1, 2, 3), found_idxs2);
    }

    #[cfg_attr(
        all(target_arch = "wasm32", not(target_os = "wasi")),
        wasm_bindgen_test::wasm_bindgen_test
    )]
    #[test]
    fn knn_find_rejects_invalid_k() {
        let algorithm = LinearKNNSearch::new(vec![1, 2, 3], SimpleDistance {}).unwrap();

        // `k` must lie in `1..=data.len()`.
        assert!(algorithm.find(&2, 0).is_err());
        assert!(algorithm.find(&2, 4).is_err());

        // The upper bound itself is accepted and returns every stored point.
        let mut all_idxs: Vec<usize> = algorithm
            .find(&2, 3)
            .unwrap()
            .iter()
            .map(|v| v.0)
            .collect();
        all_idxs.sort_unstable();

        assert_eq!(vec!(0, 1, 2), all_idxs);
    }

    #[cfg_attr(
        all(target_arch = "wasm32", not(target_os = "wasi")),
        wasm_bindgen_test::wasm_bindgen_test
    )]
    #[test]
    fn knn_find_radius_rejects_non_positive_radius() {
        let algorithm = LinearKNNSearch::new(vec![1, 2, 3], SimpleDistance {}).unwrap();

        assert!(algorithm.find_radius(&2, 0.0).is_err());
        assert!(algorithm.find_radius(&2, -1.0).is_err());
    }

    #[cfg_attr(
        all(target_arch = "wasm32", not(target_os = "wasi")),
        wasm_bindgen_test::wasm_bindgen_test
    )]
    #[test]
    fn knn_point_eq() {
        let point1 = KNNPoint {
            distance: 10.,
            index: Some(0),
        };

        let point2 = KNNPoint {
            distance: 100.,
            index: Some(1),
        };

        let point3 = KNNPoint {
            distance: 10.,
            index: Some(2),
        };

        let point_inf = KNNPoint {
            distance: f64::INFINITY,
            index: Some(3),
        };

        assert!(point2 > point1);
        assert_eq!(point3, point1);
        assert_ne!(point3, point2);
        assert!(point_inf > point3 && point_inf > point2 && point_inf > point1);
    }
}
227}