smartcore/algorithm/neighbour/
linear_search.rs1#[cfg(feature = "serde")]
26use serde::{Deserialize, Serialize};
27use std::cmp::{Ordering, PartialOrd};
28
29use crate::algorithm::sort::heap_select::HeapSelection;
30use crate::error::{Failed, FailedError};
31use crate::metrics::distance::Distance;
32
33#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
35#[derive(Debug)]
36pub struct LinearKNNSearch<T, D: Distance<T>> {
37 distance: D,
38 data: Vec<T>,
39}
40
41impl<T, D: Distance<T>> LinearKNNSearch<T, D> {
42 pub fn new(data: Vec<T>, distance: D) -> Result<LinearKNNSearch<T, D>, Failed> {
46 Ok(LinearKNNSearch { data, distance })
47 }
48
49 pub fn find(&self, from: &T, k: usize) -> Result<Vec<(usize, f64, &T)>, Failed> {
53 if k < 1 || k > self.data.len() {
54 return Err(Failed::because(
55 FailedError::FindFailed,
56 "k should be >= 1 and <= length(data)",
57 ));
58 }
59
60 let mut heap = HeapSelection::<KNNPoint>::with_capacity(k);
61
62 for _ in 0..k {
63 heap.add(KNNPoint {
64 distance: f64::INFINITY,
65 index: None,
66 });
67 }
68
69 for i in 0..self.data.len() {
70 let d = self.distance.distance(from, &self.data[i]);
71 let datum = heap.peek_mut();
72 if d < datum.distance {
73 datum.distance = d;
74 datum.index = Some(i);
75 heap.heapify();
76 }
77 }
78
79 Ok(heap
80 .get()
81 .into_iter()
82 .flat_map(|x| x.index.map(|i| (i, x.distance, &self.data[i])))
83 .collect())
84 }
85
86 pub fn find_radius(&self, from: &T, radius: f64) -> Result<Vec<(usize, f64, &T)>, Failed> {
90 if radius <= 0f64 {
91 return Err(Failed::because(
92 FailedError::FindFailed,
93 "radius should be > 0",
94 ));
95 }
96
97 let mut neighbors: Vec<(usize, f64, &T)> = Vec::new();
98
99 for i in 0..self.data.len() {
100 let d = self.distance.distance(from, &self.data[i]);
101
102 if d <= radius {
103 neighbors.push((i, d, &self.data[i]));
104 }
105 }
106
107 Ok(neighbors)
108 }
109}
110
/// A candidate neighbour kept in the heap used by `LinearKNNSearch::find`.
#[derive(Debug)]
struct KNNPoint {
    /// Distance from the query point to this candidate.
    distance: f64,
    /// Position of the candidate in the searched data. `find` seeds the heap with `None`
    /// placeholders (at infinite distance) that have not yet been replaced by a real point.
    index: Option<usize>,
}
116
117impl PartialOrd for KNNPoint {
118 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
119 self.distance.partial_cmp(&other.distance)
120 }
121}
122
123impl PartialEq for KNNPoint {
124 fn eq(&self, other: &Self) -> bool {
125 self.distance == other.distance
126 }
127}
128
// NOTE(review): `Eq` promises reflexivity (`a == a`). That does not hold when `distance` is
// NaN, because equality is delegated to `f64 ==`. Presumably this marker impl only exists to
// satisfy a trait bound elsewhere; confirm before relying on full `Eq` semantics.
impl Eq for KNNPoint {}
130
#[cfg(test)]
mod tests {
    use super::*;
    use crate::metrics::distance::Distances;

    /// Toy 1-D metric: absolute difference between two integers.
    #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
    #[derive(Debug, Clone)]
    struct SimpleDistance {}

    impl Distance<i32> for SimpleDistance {
        fn distance(&self, a: &i32, b: &i32) -> f64 {
            (a - b).abs() as f64
        }
    }

    #[cfg_attr(
        all(target_arch = "wasm32", not(target_os = "wasi")),
        wasm_bindgen_test::wasm_bindgen_test
    )]
    #[test]
    fn knn_find() {
        let data1 = vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10];

        let algorithm1 = LinearKNNSearch::new(data1, SimpleDistance {}).unwrap();

        let mut found_idxs1: Vec<usize> = algorithm1
            .find(&2, 3)
            .unwrap()
            .iter()
            .map(|v| v.0)
            .collect();
        found_idxs1.sort_unstable();

        assert_eq!(vec!(0, 1, 2), found_idxs1);

        let mut found_idxs1: Vec<i32> = algorithm1
            .find_radius(&5, 3.0)
            .unwrap()
            .iter()
            .map(|v| *v.2)
            .collect();
        found_idxs1.sort_unstable();

        assert_eq!(vec!(2, 3, 4, 5, 6, 7, 8), found_idxs1);

        let data2 = vec![
            vec![1., 1.],
            vec![2., 2.],
            vec![3., 3.],
            vec![4., 4.],
            vec![5., 5.],
        ];

        let algorithm2 = LinearKNNSearch::new(data2, Distances::euclidian()).unwrap();

        let mut found_idxs2: Vec<usize> = algorithm2
            .find(&vec![3., 3.], 3)
            .unwrap()
            .iter()
            .map(|v| v.0)
            .collect();
        found_idxs2.sort_unstable();

        assert_eq!(vec!(1, 2, 3), found_idxs2);
    }

    #[cfg_attr(
        all(target_arch = "wasm32", not(target_os = "wasi")),
        wasm_bindgen_test::wasm_bindgen_test
    )]
    #[test]
    fn knn_find_k_bounds() {
        let algorithm = LinearKNNSearch::new(vec![1, 2, 3], SimpleDistance {}).unwrap();

        // k must lie in 1..=len(data).
        assert!(algorithm.find(&2, 0).is_err());
        assert!(algorithm.find(&2, 4).is_err());

        // The single nearest neighbour carries its index, distance and point.
        let nearest = algorithm.find(&2, 1).unwrap();
        assert_eq!(1, nearest.len());
        assert_eq!(1, nearest[0].0);
        assert_eq!(0.0, nearest[0].1);
        assert_eq!(2, *nearest[0].2);

        // k == len(data) is the largest valid value and returns every point.
        let mut all: Vec<usize> = algorithm
            .find(&2, 3)
            .unwrap()
            .iter()
            .map(|v| v.0)
            .collect();
        all.sort_unstable();
        assert_eq!(vec!(0, 1, 2), all);
    }

    #[cfg_attr(
        all(target_arch = "wasm32", not(target_os = "wasi")),
        wasm_bindgen_test::wasm_bindgen_test
    )]
    #[test]
    fn knn_find_radius_bounds() {
        let algorithm = LinearKNNSearch::new(vec![1, 2, 3], SimpleDistance {}).unwrap();

        // Non-positive radii are rejected.
        assert!(algorithm.find_radius(&2, 0.0).is_err());
        assert!(algorithm.find_radius(&2, -1.0).is_err());

        // The radius is inclusive: points exactly `radius` away are returned.
        let mut hits: Vec<usize> = algorithm
            .find_radius(&2, 1.0)
            .unwrap()
            .iter()
            .map(|v| v.0)
            .collect();
        hits.sort_unstable();
        assert_eq!(vec!(0, 1, 2), hits);

        // A query far from every point yields an empty, successful result.
        assert!(algorithm.find_radius(&100, 1.0).unwrap().is_empty());
    }

    #[cfg_attr(
        all(target_arch = "wasm32", not(target_os = "wasi")),
        wasm_bindgen_test::wasm_bindgen_test
    )]
    #[test]
    fn knn_point_eq() {
        let point1 = KNNPoint {
            distance: 10.,
            index: Some(0),
        };

        let point2 = KNNPoint {
            distance: 100.,
            index: Some(1),
        };

        let point3 = KNNPoint {
            distance: 10.,
            index: Some(2),
        };

        let point_inf = KNNPoint {
            distance: f64::INFINITY,
            index: Some(3),
        };

        assert!(point2 > point1);
        assert_eq!(point3, point1);
        assert_ne!(point3, point2);
        assert!(point_inf > point3 && point_inf > point2 && point_inf > point1);
    }
}