scirs2_spatial/
kdtree_optimized.rs1use crate::error::SpatialResult;
8use crate::kdtree::KDTree;
9use scirs2_core::ndarray::{Array1, ArrayView2};
10use scirs2_core::numeric::Float;
11use std::marker::{Send, Sync};
12
13pub trait KDTreeOptimized<T: Float + Send + Sync + 'static, D> {
15 fn directed_hausdorff_distance(
28 &self,
29 points: &ArrayView2<T>,
30 seed: Option<u64>,
31 ) -> SpatialResult<(T, usize, usize)>;
32
33 fn hausdorff_distance(&self, points: &ArrayView2<T>, seed: Option<u64>) -> SpatialResult<T>;
44
45 fn batch_nearest_neighbor(
55 &self,
56 points: &ArrayView2<T>,
57 ) -> SpatialResult<(Array1<usize>, Array1<T>)>;
58}
59
60impl<T: Float + Send + Sync + 'static, D: crate::distance::Distance<T> + 'static>
61 KDTreeOptimized<T, D> for KDTree<T, D>
62{
63 fn directed_hausdorff_distance(
64 &self,
65 points: &ArrayView2<T>,
66 _seed: Option<u64>,
67 ) -> SpatialResult<(T, usize, usize)> {
68 let tree_dims = self.ndim();
74 let points_dims = points.shape()[1];
75
76 if tree_dims != points_dims {
77 return Err(crate::error::SpatialError::DimensionError(format!(
78 "Point dimensions ({points_dims}) do not match tree dimensions ({tree_dims})"
79 )));
80 }
81
82 let n_points = points.shape()[0];
83
84 if n_points == 0 {
85 return Err(crate::error::SpatialError::ValueError(
86 "Empty point set".to_string(),
87 ));
88 }
89
90 let mut max_dist = T::zero();
93 let mut max_i = 0; let mut max_j = 0; for j in 0..n_points {
97 let query_point = points.row(j).to_vec();
98
99 let (indices, distances) = self.query(&query_point, 1)?;
101 if indices.is_empty() {
102 continue;
103 }
104
105 let min_dist = distances[0];
106 let min_idx = indices[0];
107
108 if min_dist > max_dist {
110 max_dist = min_dist;
111 max_i = min_idx;
112 max_j = j;
113 }
114 }
115
116 Ok((max_dist, max_i, max_j))
117 }
118
119 fn hausdorff_distance(&self, points: &ArrayView2<T>, seed: Option<u64>) -> SpatialResult<T> {
120 let (dist_forward__, _, _) = self.directed_hausdorff_distance(points, seed)?;
122
123 let points_owned = points.to_owned();
125 let points_tree = KDTree::new(&points_owned)?;
126
127 let (dist_backward__, _, _) = points_tree.directed_hausdorff_distance(points, seed)?;
132
133 Ok(if dist_forward__ > dist_backward__ {
135 dist_forward__
136 } else {
137 dist_backward__
138 })
139 }
140
141 fn batch_nearest_neighbor(
142 &self,
143 points: &ArrayView2<T>,
144 ) -> SpatialResult<(Array1<usize>, Array1<T>)> {
145 let tree_dims = self.ndim();
147 let points_dims = points.shape()[1];
148
149 if tree_dims != points_dims {
150 return Err(crate::error::SpatialError::DimensionError(format!(
151 "Point dimensions ({points_dims}) do not match tree dimensions ({tree_dims})"
152 )));
153 }
154
155 let n_points = points.shape()[0];
156 let mut indices = Array1::<usize>::zeros(n_points);
157 let mut distances = Array1::<T>::zeros(n_points);
158
159 const BATCH_SIZE: usize = 32;
161
162 for batch_start in (0..n_points).step_by(BATCH_SIZE) {
163 let batch_end = std::cmp::min(batch_start + BATCH_SIZE, n_points);
164
165 #[cfg(feature = "parallel")]
167 {
168 use scirs2_core::parallel_ops::*;
169
170 let batch_results: Vec<_> = (batch_start..batch_end)
171 .into_par_iter()
172 .map(|i| {
173 let point = points.row(i).to_vec();
174 let (idx, dist) = self.query(&point, 1).expect("Operation failed");
175 (i, idx[0], dist[0])
176 })
177 .collect();
178
179 for (i, idx, dist) in batch_results {
180 indices[i] = idx;
181 distances[i] = dist;
182 }
183 }
184
185 #[cfg(not(feature = "parallel"))]
187 {
188 for i in batch_start..batch_end {
189 let point = points.row(i).to_vec();
190 let (idx, dist) = self.query(&point, 1)?;
191 indices[i] = idx[0];
192 distances[i] = dist[0];
193 }
194 }
195 }
196
197 Ok((indices, distances))
198 }
199}
200
#[cfg(test)]
mod tests {
    use super::*;
    use crate::error::SpatialError;
    use crate::kdtree::KDTree;
    use scirs2_core::ndarray::{array, Array2};

    #[test]
    fn test_batch_nearest_neighbor() {
        let points = array![[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0],];
        let kdtree = KDTree::new(&points).expect("Operation failed");

        // Query i sits just inside corner i of the unit square.
        let query_points = array![[0.1, 0.1], [0.9, 0.1], [0.1, 0.9], [0.9, 0.9],];

        let (indices, distances) = kdtree
            .batch_nearest_neighbor(&query_points.view())
            .expect("Operation failed");

        assert_eq!(indices.len(), 4);
        assert_eq!(distances.len(), 4);

        for i in 0..4 {
            // The nearest indexed point is the matching corner.
            assert_eq!(indices[i], i);
            assert!(distances[i] <= 1.5);
            // By symmetry, every query is equally far from its corner.
            assert!((distances[i] - distances[0]).abs() < 1e-12);
        }
    }

    #[test]
    fn test_hausdorff_distance() {
        let points1 = array![[0.0, 0.0], [1.0, 0.0], [0.0, 1.0],];
        let points2 = array![[0.0, 0.5], [1.0, 0.5], [0.5, 1.0],];

        let kdtree = KDTree::new(&points1).expect("Operation failed");

        let dist = kdtree
            .hausdorff_distance(&points2.view(), Some(42))
            .expect("Operation failed");

        assert!(dist > 0.4 && dist < 1.2);
    }

    #[test]
    fn test_directed_hausdorff_indices_in_range() {
        let points1 = array![[0.0, 0.0], [1.0, 0.0], [0.0, 1.0],];
        let points2 = array![[0.0, 0.5], [1.0, 0.5], [0.5, 1.0],];

        let kdtree = KDTree::new(&points1).expect("Operation failed");

        let (dist, tree_idx, query_idx) = kdtree
            .directed_hausdorff_distance(&points2.view(), None)
            .expect("Operation failed");

        assert!(dist > 0.4 && dist < 1.2);
        assert!(tree_idx < 3);
        assert!(query_idx < 3);
    }

    #[test]
    fn test_empty_and_mismatched_inputs() {
        let points = array![[0.0, 0.0], [1.0, 1.0]];
        let kdtree = KDTree::new(&points).expect("Operation failed");

        // No query rows: directed Hausdorff is rejected, batch returns empty arrays.
        let empty = Array2::<f64>::zeros((0, 2));
        assert!(matches!(
            kdtree.directed_hausdorff_distance(&empty.view(), None),
            Err(SpatialError::ValueError(_))
        ));
        let (indices, distances) = kdtree
            .batch_nearest_neighbor(&empty.view())
            .expect("Operation failed");
        assert!(indices.is_empty());
        assert!(distances.is_empty());

        // Wrong number of columns is a dimension error for every method.
        let wrong_dims = array![[0.0, 0.0, 0.0]];
        assert!(matches!(
            kdtree.batch_nearest_neighbor(&wrong_dims.view()),
            Err(SpatialError::DimensionError(_))
        ));
        assert!(matches!(
            kdtree.directed_hausdorff_distance(&wrong_dims.view(), None),
            Err(SpatialError::DimensionError(_))
        ));
    }
}