1use vpsearch::{BestCandidate, MetricSpace};
2
3use std::collections::HashSet;
4use num_traits::Bounded;
5
/// A point in n-dimensional space, stored as its list of coordinates.
#[derive(Debug, Clone)]
struct PointN {
    /// Coordinate values, one entry per dimension.
    data: Vec<f32>,
}
10
11impl PointN {
13 pub fn new(data: impl Into<Vec<f32>>) -> Self {
14 Self { data: data.into() }
15 }
16}
17
18impl MetricSpace for PointN {
20 type UserData = ();
21 type Distance = f32;
22
23 fn distance(&self, other: &Self, _: &Self::UserData) -> Self::Distance {
24 self.data
25 .iter()
26 .zip(other.data.iter())
27 .map(|(s, o)| (s - o).powi(2))
28 .sum::<f32>()
29 .sqrt()
30 }
31}
32
33struct CountBasedNeighborhood<Item: MetricSpace<Impl>, Impl> {
35 max_item_count: usize,
37 max_observed_distance: Item::Distance,
39 distance_x_index: Vec<(Item::Distance, usize)>,
41}
42
43impl<Item: MetricSpace<Impl>, Impl> CountBasedNeighborhood<Item, Impl> {
44 fn new(item_count: usize) -> Self {
47 Self {
48 max_item_count: item_count,
49 max_observed_distance: <Item::Distance as Bounded>::min_value(),
50 distance_x_index: Vec::<(Item::Distance, usize)>::new(),
51 }
52 }
53
54 fn insert_index(&mut self, index: usize, distance: Item::Distance) {
57 self.distance_x_index.push((distance, index));
59 if self.distance_x_index.len() > 1 {
61 let mut n = self.distance_x_index.len() - 1;
63 while n > 0 && self.distance_x_index[n].0 < self.distance_x_index[n - 1].0 {
66 self.distance_x_index.swap(n, n - 1);
67 n -= 1;
68 }
69 self.distance_x_index.truncate(self.max_item_count);
70 }
71 self.max_observed_distance = self.distance_x_index.last().unwrap().0;
74 }
75}
76
77impl<Item: MetricSpace<Impl> + Clone, Impl> BestCandidate<Item, Impl>
80 for CountBasedNeighborhood<Item, Impl>
81{
82 type Output = HashSet<usize>;
83
84 #[inline]
85 fn consider(
86 &mut self,
87 _: &Item,
88 distance: Item::Distance,
89 candidate_index: usize,
90 _: &Item::UserData,
91 ) {
92 if self.max_item_count == 0 {
94 return;
95 }
96
97 if distance < self.max_observed_distance
105 || self.distance_x_index.len() < self.max_item_count
106 {
107 self.insert_index(candidate_index, distance);
108 }
109 }
110
111 #[inline]
112 fn distance(&self) -> Item::Distance {
113 self.max_observed_distance
116 }
117
118 fn result(self, _: &Item::UserData) -> Self::Output {
119 self.distance_x_index
121 .into_iter()
122 .map(|(_, index)| index)
123 .collect::<HashSet<usize>>()
124 }
125}
126
127fn main() {
128 let points = vec![
129 PointN::new([2.0, 3.0]),
130 PointN::new([0.0, 1.0]),
131 PointN::new([4.0, 5.0]),
132 ];
133 let tree = vpsearch::Tree::new(&points);
134
135 let actual = tree.find_nearest_custom(
137 &PointN::new([1.0, 2.0]),
138 &(),
139 CountBasedNeighborhood::new(1),
140 );
141 assert_eq!(actual.len(), 1);
142
143 let expected = [0, 1].iter().copied().collect::<HashSet<usize>>();
145 let actual = tree.find_nearest_custom(
146 &PointN::new([1.0, 2.0]),
147 &(),
148 CountBasedNeighborhood::new(2),
149 );
150 assert_eq!(actual, expected);
151
152 let expected = [0, 1, 2].iter().copied().collect::<HashSet<usize>>();
154 let actual = tree.find_nearest_custom(
155 &PointN::new([1.0, 2.0]),
156 &(),
157 CountBasedNeighborhood::new(10),
158 );
159 assert_eq!(actual, expected);
160}