vp_avl/lib.rs
1#![feature(test)]
2#![feature(let_chains)]
3
4mod iter;
5mod metric;
6
7use iter::*;
8
9use metric::*;
10
11extern crate test;
12
13use replace_with::replace_with_or_abort;
14use std::{collections::BinaryHeap, marker::PhantomData};
15
/// A distance function over points of type [`Metric::PointType`].
///
/// The tree's bulk build sorts distances with `partial_cmp(..).unwrap()`, so a
/// metric that returns NaN makes bulk building panic.
// NOTE(review): the tree presumably relies on the usual metric axioms
// (non-negativity, symmetry, triangle inequality) — confirm against the
// search code in the `iter` module.
pub trait Metric {
    /// The point type this metric measures distances between.
    type PointType;
    /// Returns the distance between `p1` and `p2`.
    fn distance(&self, p1: &Self::PointType, p2: &Self::PointType) -> f64;
}
20
/// An object that can be stored in the vantage-point tree.
pub trait VpTreeObject: Sized {
    /// The location type handed to the metric. It must be comparable with
    /// `==` so that removal can identify an exact match.
    type PointType: PartialEq;

    /// Returns the location of this object.
    fn location(&self) -> &Self::PointType;

    /// Dimension of the space the object lives in.
    ///
    /// Used to roughly estimate the size of n-d balls. This is essentially an
    /// optimization: the tree will still work if the value is wrong.
    fn dimension(&self) -> usize {
        10
    }

    /// Radius of the ball that holds half the volume of the unit ball in
    /// `dimension()`-dimensional Euclidean space.
    ///
    /// The volume of a ball scales as `r^dim`, so the halving radius is
    /// exactly `0.5^(1/dim)`. Evaluating this closed form directly (instead of
    /// multiplying out gamma-function approximations that cancel anyway)
    /// avoids the floating-point underflow that drove the result to `0` for
    /// dimensions in the hundreds, and the NaN produced for dimension `0`.
    fn approx_halving_radius(&self) -> f64 {
        let dim = self.dimension() as f64;
        0.5_f64.powf(1.0 / dim)
    }
}
46
47impl VpTreeObject for Vec<f64> {
48 type PointType = Self;
49 fn location(&self) -> &Self {
50 self
51 }
52}
53
/// Index-addressable, growable storage used for both tree nodes and data points.
///
/// Implementors must also be buildable from, and consumable as, an iterator of
/// their element type.
pub trait Storage: FromIterator<Self::DType> + IntoIterator<Item = Self::DType> {
    /// Element type held by the storage.
    type DType;
    /// Borrows the element at `index`.
    fn read(&self, index: usize) -> &Self::DType;
    /// Overwrites the element at `index` with `value`.
    fn write(&mut self, index: usize, value: Self::DType);
    /// Stores `value` at `index` and returns the element previously there.
    fn replace(&mut self, index: usize, value: Self::DType) -> Self::DType;
    /// Replaces the element at `index` with the result of applying `op` to it.
    fn map_i<F: FnOnce(Self::DType) -> Self::DType>(&mut self, index: usize, op: F);
    /// Number of elements currently stored.
    fn size(&self) -> usize;
    /// Appends `value` at the end.
    fn push(&mut self, value: Self::DType);
    /// Removes and returns the last element, or `None` when empty.
    fn pop(&mut self) -> Option<Self::DType>;
    /// Iterates over all elements by reference.
    fn iter(&self) -> impl Iterator<Item = &Self::DType>;
}
65
66impl<T> Storage for Vec<T> {
67 type DType = T;
68
69 fn read(&self, index: usize) -> &Self::DType {
70 &self[index]
71 }
72
73 fn write(&mut self, index: usize, value: Self::DType) {
74 self[index] = value;
75 }
76
77 fn replace(&mut self, index: usize, mut value: Self::DType) -> Self::DType {
78 std::mem::swap(&mut value, &mut self[index]);
79 value
80 }
81
82 fn map_i<F: FnOnce(Self::DType) -> Self::DType>(&mut self, index: usize, op: F) {
83 replace_with_or_abort(&mut self[index], op)
84 }
85
86 fn size(&self) -> usize {
87 self.len()
88 }
89
90 fn push(&mut self, value: Self::DType) {
91 self.push(value)
92 }
93
94 fn pop(&mut self) -> Option<Self::DType> {
95 self.pop()
96 }
97
98 fn iter(&self) -> impl Iterator<Item = &Self::DType> {
99 self.into_iter()
100 }
101}
102
103pub trait VpAvl: Sized {
    /// The type of object stored in the tree.
    type Point: VpTreeObject;
    /// The metric used to compare the locations of stored objects.
    type PointMetric: Metric<PointType = <Self::Point as VpTreeObject>::PointType>;
    /// Backing storage for the tree nodes.
    type NodeStorage: Storage<DType = Node>;
    /// Backing storage for the stored objects.
    type DataStorage: Storage<DType = Self::Point>;

    /// Shared access to the node storage.
    fn nodes(&self) -> &Self::NodeStorage;
    /// Mutable access to the node storage.
    fn nodes_mut(&mut self) -> &mut Self::NodeStorage;
    /// Shared access to the data storage.
    fn data(&self) -> &Self::DataStorage;
    /// Mutable access to the data storage.
    fn data_mut(&mut self) -> &mut Self::DataStorage;
    /// The metric used by this tree.
    fn metric(&self) -> &Self::PointMetric;
    /// Index of the root node.
    fn root(&self) -> usize;
    /// Sets the index of the root node.
    fn set_root(&mut self, new: usize);
116
117 fn node_index_data(&self, node_index: usize) -> &Self::Point {
118 &self.data().read(self.nodes().read(node_index).center)
119 }
120
121 fn bulk_build_indices(&mut self, root: usize, mut indices: Vec<usize>) {
122 // #[cfg(test)]
123 // let mut prior_indices = indices.clone();
124
125 if indices.len() < 2 {
126 // simpler case
127 match indices.len() {
128 0 => {
129 // leaf node
130 self.nodes_mut().map_i(root, |mut node| {
131 node.height = 0;
132 node.interior = None;
133 node.exterior = None;
134 node
135 });
136 }
137 1 => {
138 // still has one child
139
140 let exterior = indices.pop().unwrap();
141
142 let radius = self.metric().distance(
143 self.node_index_data(root).location(),
144 self.node_index_data(exterior).location(),
145 );
146 self.nodes_mut().map_i(root, |mut node| {
147 node.exterior = Some(exterior);
148 node.radius = radius;
149 node.height = 1;
150 node.interior = None;
151 node.parent = Some(root);
152 node
153 });
154
155 self.bulk_build_indices(self.nodes().read(root).exterior.unwrap(), indices)
156 }
157 _ => unreachable!(),
158 }
159 return;
160 }
161
162 let mut distances = Vec::with_capacity(indices.len());
163 for index in indices.iter() {
164 distances.push((
165 *index,
166 self.metric().distance(
167 self.node_index_data(root).location(),
168 self.node_index_data(*index).location(),
169 ),
170 ));
171 }
172 // sort indices by distance from root
173 distances.sort_unstable_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
174
175 let partitions: Vec<&[(usize, f64)]> = distances
176 .chunks(distances.len() / 2 + distances.len() % 2)
177 .collect();
178 let mut interior_indices: Vec<usize> = partitions[0].iter().map(|x| x.0).collect();
179 let mut exterior_indices: Vec<usize> = partitions[1].iter().map(|x| x.0).collect();
180
181 let min_exterior_distance = partitions[1].first().unwrap().1;
182
183 let interior_center = interior_indices.pop();
184 let exterior_center = exterior_indices.pop();
185
186 self.nodes_mut().map_i(root, |mut node| {
187 node.radius = min_exterior_distance;
188 node.interior = interior_center;
189 node.exterior = exterior_center;
190 node
191 });
192
193 if let Some(interior) = interior_center {
194 self.nodes_mut().map_i(interior, |mut node| {
195 node.parent = Some(root);
196 node
197 });
198 }
199
200 if let Some(exterior) = exterior_center {
201 self.nodes_mut().map_i(exterior, |mut node| {
202 node.parent = Some(root);
203 node
204 });
205 }
206
207 let mut height = 0;
208 // recurse
209 if let Some(x) = interior_center {
210 self.bulk_build_indices(x, interior_indices);
211 height = height.max(self.nodes().read(x).height);
212 }
213
214 if let Some(x) = exterior_center {
215 self.bulk_build_indices(x, exterior_indices);
216 height = height.max(self.nodes().read(x).height);
217 }
218 self.nodes_mut().map_i(root, |mut node| {
219 node.height = height + 1;
220 node
221 });
222
223 // #[cfg(test)]
224 // {
225 // let mut post_indices = self.child_indices(root);
226 // prior_indices.sort();
227 // post_indices.sort();
228 // assert_eq!(prior_indices, post_indices);
229 // }
230 }
231
232 fn set_height(&mut self, root: usize) {
233 let interior_height = self
234 .nodes()
235 .read(root)
236 .interior
237 .map(|i| self.nodes().read(i).height + 1)
238 .unwrap_or(0);
239 let exterior_height = self
240 .nodes()
241 .read(root)
242 .exterior
243 .map(|i| self.nodes().read(i).height + 1)
244 .unwrap_or(0);
245
246 self.nodes_mut().map_i(root, |mut node| {
247 node.height = interior_height.max(exterior_height);
248 node
249 });
250 }
251
252 fn insert_root(&mut self, root: usize, value: Self::Point) {
253 // #[cfg(test)]
254 // let mut prior_children = self.child_indices(root);
255
256 let distance = self.metric().distance(
257 self.node_index_data(self.nodes().read(root).center)
258 .location(),
259 value.location(),
260 );
261 let root_radius = self.nodes().read(root).radius;
262
263 if distance < root_radius {
264 // in the interior
265 if let Some(ind) = self.nodes().read(root).interior {
266 // recurse
267 self.insert_root(ind, value);
268 } else {
269 let new_radius = root_radius * value.approx_halving_radius();
270 // new leaf node
271 self.data_mut().push(value);
272 let new_index = self.n_nodes();
273
274 self.nodes_mut().push(Node {
275 height: 0,
276 center: new_index,
277 radius: new_radius,
278 parent: Some(root),
279 interior: None,
280 exterior: None,
281 });
282
283 self.nodes_mut().map_i(root, |mut node| {
284 node.interior = Some(new_index);
285 node
286 });
287 }
288 } else {
289 if let Some(ind) = self.nodes().read(root).exterior {
290 // recurse
291 self.insert_root(ind, value);
292 } else {
293 let new_radius = root_radius * value.approx_halving_radius();
294
295 // new leaf node
296 self.data_mut().push(value);
297 let new_index = self.n_nodes();
298
299 self.nodes_mut().push(Node {
300 height: 0,
301 center: new_index,
302 radius: new_radius,
303 parent: Some(root),
304 interior: None,
305 exterior: None,
306 });
307
308 self.nodes_mut().map_i(root, |mut node| {
309 node.exterior = Some(new_index);
310 node
311 });
312 }
313 }
314 // update the height
315 self.set_height(root);
316 // inserted!
317 // rebalance?
318 // will be called again at each successively higher level
319
320 self.rebalance(root);
321
322 // #[cfg(test)]
323 // {
324 // let mut final_children = self.child_indices(root);
325 // prior_children.push(self.n_nodes() - 1);
326 // prior_children.sort();
327 // final_children.sort();
328 //
329 // assert_eq!(prior_children, final_children);
330 // }
331 }
332
333 fn insert(&mut self, value: Self::Point) {
334 if self.n_nodes() > 1 {
335 self.insert_root(self.root(), value)
336 } else if self.n_nodes() == 0 {
337 self.data_mut().push(value);
338 self.nodes_mut().push(Node::new_leaf(0, None))
339 } else {
340 let root = self.root();
341 let root_dist = self.metric().distance(
342 self.node_index_data(self.nodes().read(root).center)
343 .location(),
344 value.location(),
345 );
346
347 self.insert_root(root, value)
348 }
349 }
350
351 // insert an orphaned node
352 // fn insert_existing(&mut self, root: usize, graft: usize) {
353 // let distance = self.metric().distance(
354 // self.node_index_data(self.nodes().read(root).center)
355 // .location(),
356 // self.node_index_data(graft).location(),
357 // );
358 // let root_radius = self.nodes().read(root).radius;
359 //
360 // if distance < root_radius {
361 // // in the interior
362 // if let Some(ind) = self.nodes().read(root).interior {
363 // // recurse
364 // self.insert_existing(ind, graft)
365 // } else {
366 // // leaf node
367 // self.nodes_mut().map_i(root, |mut node| {
368 // node.interior = Some(graft);
369 // node
370 // });
371 // self.nodes_mut().map_i(graft, |mut node| {
372 // node.radius = distance.clamp(root_radius / 2.0, root_radius);
373 // node.parent = Some(root);
374 // node
375 // });
376 // }
377 // } else {
378 // if let Some(ind) = self.nodes().read(root).exterior {
379 // // recurse
380 // self.insert_existing(ind, graft)
381 // } else {
382 // // leaf node
383 // self.nodes_mut().map_i(root, |mut node| {
384 // node.exterior = Some(graft);
385 // node
386 // });
387 // self.nodes_mut().map_i(graft, |mut node| {
388 // node.radius = distance.clamp(root_radius / 2.0, root_radius);
389 // node.parent = Some(root);
390 // node
391 // });
392 // }
393 // }
394 // // update the height
395 // self.set_height(root);
396 //
397 // // inserted!
398 // // rebalance?
399 // // will be called again at each successively higher level
400 // self.rebalance(self.root())
401 // }
402
403 fn rebalance(&mut self, root: usize) {
404 // #[cfg(test)]
405 // let mut prior_children = self.child_indices(root);
406
407 let interior_height = self
408 .nodes()
409 .read(root)
410 .interior
411 .map(|ind| self.nodes().read(ind).height)
412 .unwrap_or(0);
413 let exterior_height = self
414 .nodes()
415 .read(root)
416 .exterior
417 .map(|ind| self.nodes().read(ind).height)
418 .unwrap_or(0);
419
420 if interior_height > (exterior_height + 1) {
421 // interior is too big, it must be rebalanced
422 self.rebalance_interior(root)
423 } else if exterior_height > (interior_height + 1) {
424 // exterior is too big, must be rebalanced
425 self.rebalance_exterior(root)
426 }
427
428 // #[cfg(test)]
429 // {
430 // let mut final_children = self.child_indices(root);
431 // prior_children.sort();
432 // final_children.sort();
433 //
434 // assert_eq!(prior_children, final_children);
435 // }
436 }
437
438 fn child_indices_impl(&self, root: usize, progress: &mut Vec<usize>) {
439 if let Some(int) = self.nodes().read(root).interior {
440 self.child_indices_impl(int, progress)
441 }
442
443 if let Some(ext) = self.nodes().read(root).exterior {
444 self.child_indices_impl(ext, progress)
445 }
446
447 progress.push(root);
448 }
449
450 fn child_indices(&self, root: usize) -> Vec<usize> {
451 let mut chillum = vec![];
452 self.child_indices_impl(root, &mut chillum);
453 chillum.pop();
454
455 chillum
456 }
457
458 // make the interior shorter
459 fn rebalance_interior(&mut self, root: usize) {
460 let mut children = self.child_indices(root);
461
462 // let root = children.pop().unwrap();
463 self.bulk_build_indices(root, children);
464 }
465
466 // make the exterior shorter
467 fn rebalance_exterior(&mut self, root: usize) {
468 // honestly I don't see a way to be clever about this case yet.
469 // rebuilding the whole dang thing
470 // TODO: be good
471 let mut children = self.child_indices(root);
472
473 // let root = children.pop().unwrap();
474 self.bulk_build_indices(root, children)
475 }
476
477 fn nn_iter<'a>(
478 &'a self,
479 query_point: &'a <Self::Point as VpTreeObject>::PointType,
480 ) -> impl Iterator<Item = &'a Self::Point> {
481 KnnIterator::new(query_point, self).map(|(p, _d)| p)
482 }
483
484 fn nn_dist_iter<'a>(
485 &'a self,
486 query_point: &'a <Self::Point as VpTreeObject>::PointType,
487 ) -> KnnIterator<'a, Self> {
488 KnnIterator::new(query_point, self)
489 }
490
491 fn nn_index_iter<'a>(
492 &'a self,
493 query_point: &'a <Self::Point as VpTreeObject>::PointType,
494 ) -> KnnIndexIterator<'a, Self> {
495 KnnIndexIterator::new(query_point, self)
496 }
497
498 fn check_validity_node(&self, root: usize) {
499 if let Some(interior) = self.nodes().read(root).interior {
500 let distance = self.metric().distance(
501 self.node_index_data(root).location(),
502 self.node_index_data(interior).location(),
503 );
504
505 assert!(
506 distance < self.nodes().read(root).radius,
507 "interior {} of {} not within radius: {} >= {}",
508 interior,
509 root,
510 distance,
511 self.nodes().read(root).radius
512 );
513 }
514
515 if let Some(exterior) = self.nodes().read(root).exterior {
516 let distance = self.metric().distance(
517 self.node_index_data(root).location(),
518 self.node_index_data(exterior).location(),
519 );
520
521 assert!(
522 distance >= self.nodes().read(root).radius,
523 "exterior {} of {} not outside radius: {} < {}",
524 exterior,
525 root,
526 distance,
527 self.nodes().read(root).radius
528 );
529 }
530 }
531
532 fn check_validity_root(&self, root: usize) {
533 self.check_validity_node(root);
534
535 if let Some(interior) = self.nodes().read(root).interior {
536 self.check_validity_root(interior)
537 }
538
539 if let Some(exterior) = self.nodes().read(root).exterior {
540 self.check_validity_root(exterior)
541 }
542 }
543
544 fn remove(&mut self, value: &<Self::Point as VpTreeObject>::PointType) -> Option<Self::Point> {
545 let mut to_remove = None;
546 for (nn, _) in self
547 .nn_index_iter(value)
548 .take_while(|nn| nn.1 <= 0.0)
549 .filter(|nn| self.data().read(nn.0).location() == value)
550 {
551 to_remove = Some(nn);
552 break;
553 }
554
555 Some(self.remove_index(to_remove?))
556 }
557
558 // TODO: DONT BE DUM
559 fn remove_index(&mut self, index: usize) -> Self::Point {
560 let end = self.data_mut().pop().unwrap();
561 self.nodes_mut().pop();
562 let old = if index == self.n_nodes() {
563 end
564 } else {
565 self.data_mut().replace(index, end)
566 };
567
568 let indices: Vec<usize> = (1..self.n_nodes()).collect();
569
570 if indices.len() > 0 {
571 self.bulk_build_indices(0, indices);
572 } else if self.n_nodes() == 1 {
573 self.nodes_mut().write(0, Node::new_leaf(0, None));
574 }
575
576 old
577 }
578
579 fn iter<'a>(&'a self) -> impl Iterator<Item = &'a Self::Point> {
580 self.data().iter()
581 }
582
583 // fn iter_mut<'a>(&'a mut self) -> impl Iterator<Item = &'a mut Self::Point> {
584 // self.data().iter_mut()
585 // }
586
587 fn check_validity(&self) {
588 if self.n_nodes() > 0 {
589 self.check_validity_root(self.root())
590 }
591 }
592
593 fn n_nodes(&self) -> usize {
594 self.nodes().size()
595 }
596}
597
/// A single node of the vantage-point tree. Nodes refer to each other and to
/// their data by index into the tree's storages.
#[derive(Clone, Debug)]
pub struct Node {
    /// Height of the subtree rooted here; a node without children has height 0.
    height: usize,
    /// Index into the data storage of the point this node is centered on.
    center: usize,
    /// Split distance: on insertion, points strictly closer than this to the
    /// center go to the interior side, all others to the exterior side.
    radius: f64,
    /// Index of the parent node, if any.
    parent: Option<usize>,
    /// Index of the root of the interior subtree, if any.
    interior: Option<usize>,
    /// Index of the root of the exterior subtree, if any.
    exterior: Option<usize>,
}
607
/// Concrete tree state: node and data storages plus the metric, generic over
/// the storage backends.
#[derive(Default, Debug, Clone)]
pub struct VpAvlData<
    Point,
    PointMetric,
    NodeStorage: Storage<DType = Node>,
    DataStorage: Storage<DType = Point>,
> {
    /// Index of the root node.
    root: usize,
    /// Tree nodes.
    nodes: NodeStorage,
    /// Stored objects, referenced by each node's `center` index.
    data: DataStorage,
    /// Metric used to compare object locations.
    metric: PointMetric,
}
620
621impl<Point, PointMetric, NodeStorage, DataStorage> VpAvl
622 for VpAvlData<Point, PointMetric, NodeStorage, DataStorage>
623where
624 PointMetric: Metric<PointType = Point::PointType>,
625 Point: VpTreeObject,
626 NodeStorage: Storage<DType = Node> + Default,
627 DataStorage: Storage<DType = Point> + Default,
628{
629 type Point = Point;
630 type PointMetric = PointMetric;
631 type NodeStorage = NodeStorage;
632 type DataStorage = DataStorage;
633
634 fn nodes(&self) -> &Self::NodeStorage {
635 &self.nodes
636 }
637
638 fn nodes_mut(&mut self) -> &mut Self::NodeStorage {
639 &mut self.nodes
640 }
641
642 fn data(&self) -> &Self::DataStorage {
643 &self.data
644 }
645
646 fn data_mut(&mut self) -> &mut Self::DataStorage {
647 &mut self.data
648 }
649
650 fn metric(&self) -> &Self::PointMetric {
651 &self.metric
652 }
653
654 fn root(&self) -> usize {
655 self.root
656 }
657
658 fn set_root(&mut self, new: usize) {
659 self.root = new;
660 }
661}
662
/// A [`VpAvlData`] tree that keeps both its nodes and its points in `Vec`s.
pub type VpAvlVec<Point, PointMetric> = VpAvlData<Point, PointMetric, Vec<Node>, Vec<Point>>;
664impl<Point, PointMetric, NodeStorage, DataStorage>
665 VpAvlData<Point, PointMetric, NodeStorage, DataStorage>
666where
667 PointMetric: Metric<PointType = Point::PointType>,
668 Point: VpTreeObject,
669 NodeStorage: Storage<DType = Node> + Default,
670 DataStorage: Storage<DType = Point> + Default,
671{
672 // fn node_index_data(&self, node_index: usize) -> &Point {
673 // &self.data().read(self.nodes().read(node_index).center)
674 // }
675 //
676 pub fn new(metric: PointMetric) -> Self {
677 VpAvlData {
678 root: 0,
679 nodes: Default::default(),
680 data: Default::default(),
681 metric,
682 }
683 }
684
685 fn bulk_insert(metric: PointMetric, data: impl IntoIterator<Item = Point>) -> Self {
686 let data: DataStorage = data.into_iter().collect();
687 let indices: Vec<usize> = (1..data.size()).collect();
688 let nodes = (0..data.size())
689 .map(|ind| Node::new_leaf(ind, None))
690 .collect();
691 let mut rv = VpAvlData {
692 root: 0,
693 nodes,
694 data,
695 metric,
696 };
697
698 rv.bulk_build_indices(0, indices);
699
700 rv
701 }
702
703 pub fn update_metric<NewMetric: Metric<PointType = <Point as VpTreeObject>::PointType>>(
704 self,
705 metric: NewMetric,
706 ) -> VpAvlData<Point, NewMetric, NodeStorage, DataStorage> {
707 VpAvlData::bulk_insert(metric, self.data)
708 }
709
710 //
711 // pub fn update_metric<NewMetric: Metric<PointType = Point::PointType>>(
712 // self,
713 // metric: NewMetric,
714 // ) -> VpAvl<Point, NewMetric, NodeStorage, DataStorage> {
715 // VpAvl::bulk_insert(metric, self.data)
716 // }
717 //
718 // pub fn bulk_insert(metric: PointMetric, data: Vec<Point>) -> Self {
719 // let indices: Vec<usize> = (1..data.len()).collect();
720 // let nodes = (0..data.len())
721 // .map(|ind| Node::new_leaf(ind, None))
722 // .collect();
723 // let mut rv = VpAvl {
724 // root: 0,
725 // nodes,
726 // data,
727 // metric,
728 // };
729 //
730 // rv.bulk_build_indices(0, indices);
731 //
732 // rv
733 // }
734 //
735 // fn bulk_build_indices(&mut self, root: usize, mut indices: Vec<usize>) {
736 // if indices.len() < 2 {
737 // // simpler case
738 // match indices.len() {
739 // 0 => {
740 // // leaf node
741 // self.nodes().read(root).height = 0;
742 // self.nodes().read(root).interior = None;
743 // self.nodes().read(root).exterior = None;
744 // }
745 // 1 => {
746 // // still has one child
747 //
748 // let exterior = indices.pop().unwrap();
749 //
750 // self.nodes().read(root).exterior = Some(exterior);
751 // self.nodes().read(root).radius = self.metric.distance(
752 // self.node_index_data(root).location(),
753 // self.node_index_data(exterior).location(),
754 // );
755 // self.nodes().read(root).height = 1;
756 // self.nodes().read(root).interior = None;
757 // self.nodes().read(exterior).parent = Some(root);
758 //
759 // self.bulk_build_indices(self.nodes().read(root).exterior.unwrap(), indices)
760 // }
761 // _ => unreachable!(),
762 // }
763 // return;
764 // }
765 //
766 // let mut distances = Vec::with_capacity(indices.len());
767 // for index in indices.iter() {
768 // distances.push((
769 // *index,
770 // self.metric.distance(
771 // self.node_index_data(root).location(),
772 // self.node_index_data(*index).location(),
773 // ),
774 // ));
775 // }
776 // // sort indices by distance from root
777 // distances.sort_unstable_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
778 //
779 // let partitions: Vec<&[(usize, f64)]> = distances
780 // .chunks(distances.len() / 2 + distances.len() % 2)
781 // .collect();
782 // let mut interior_indices: Vec<usize> = partitions[0].iter().map(|x| x.0).collect();
783 // let mut exterior_indices: Vec<usize> = partitions[1].iter().map(|x| x.0).collect();
784 //
785 // let min_exterior_distance = partitions[1].first().unwrap().1;
786 //
787 // self.nodes().read(root).radius = min_exterior_distance;
788 //
789 // let interior_center = interior_indices.pop();
790 // let exterior_center = exterior_indices.pop();
791 //
792 // self.nodes().read(root).interior = interior_center;
793 // self.nodes().read(root).exterior = exterior_center;
794 //
795 // if let Some(interior) = interior_center {
796 // self.nodes().read(interior).parent = Some(root);
797 // }
798 //
799 // if let Some(exterior) = exterior_center {
800 // self.nodes().read(exterior).parent = Some(root);
801 // }
802 //
803 // let mut height = 0;
804 // // recurse
805 // if let Some(x) = interior_center {
806 // self.bulk_build_indices(x, interior_indices);
807 // height = height.max(self.nodes().read(x).height);
808 // }
809 //
810 // if let Some(x) = exterior_center {
811 // self.bulk_build_indices(x, exterior_indices);
812 // height = height.max(self.nodes().read(x).height);
813 // }
814 //
815 // self.nodes().read(root).height = height + 1;
816 // }
817 //
818 // fn set_height(&mut self, root: usize) {
819 // let interior_height = self.nodes().read(root)
820 // .interior
821 // .map(|i| self.nodes().read(i).height + 1)
822 // .unwrap_or(0);
823 // let exterior_height = self.nodes().read(root)
824 // .exterior
825 // .map(|i| self.nodes().read(i).height + 1)
826 // .unwrap_or(0);
827 //
828 // self.nodes().read(root).height = interior_height.max(exterior_height);
829 // }
830 //
831 // fn insert_root(&mut self, root: usize, value: Point) {
832 // let distance = self.metric.distance(
833 // self.node_index_data(self.nodes().read(root).center).location(),
834 // value.location(),
835 // );
836 // let root_radius = self.nodes().read(root).radius;
837 //
838 // if distance < root_radius {
839 // // in the interior
840 // if let Some(ind) = self.nodes().read(root).interior {
841 // // recurse
842 // self.insert_root(ind, value);
843 // } else {
844 // // new leaf node
845 // self.data.push(value);
846 // let new_index = self.data.len() - 1;
847 //
848 // self.nodes.push(Node::new_leaf(new_index, Some(root)));
849 //
850 // self.nodes().read(new_index).radius = distance.clamp(root_radius / 2.0, root_radius);
851 //
852 // self.nodes().read(root).interior = Some(new_index);
853 // }
854 // } else {
855 // if let Some(ind) = self.nodes().read(root).exterior {
856 // // recurse
857 // self.insert_root(ind, value);
858 // } else {
859 // // new leaf node
860 // self.data.push(value);
861 // let new_index = self.data.len() - 1;
862 //
863 // self.nodes.push(Node::new_leaf(new_index, Some(root)));
864 //
865 // self.nodes().read(new_index).radius = distance.clamp(root_radius / 2.0, root_radius);
866 //
867 // self.nodes().read(root).exterior = Some(new_index);
868 // }
869 // }
870 // // update the height
871 // self.set_height(root);
872 // // inserted!
873 // // rebalance?
874 // // will be called again at each successively higher level
875 // self.rebalance(root);
876 // }
877 //
878 // pub fn insert(&mut self, value: Point) {
879 // if self.data.len() > 1 {
880 // self.insert_root(self.root, value)
881 // } else if self.data.len() == 0 {
882 // self.data.push(value);
883 // self.nodes.push(Node::new_leaf(0, None))
884 // } else {
885 // let root_dist = self.metric.distance(
886 // self.node_index_data(self.nodes().read(self.root).center)
887 // .location(),
888 // value.location(),
889 // );
890 // self.nodes().read(self.root).radius = root_dist / 2.0;
891 // self.insert_root(self.root, value)
892 // }
893 // }
894 //
895 // // insert an orphaned node
896 // fn insert_existing(&mut self, root: usize, graft: usize) {
897 // let distance = self.metric.distance(
898 // self.node_index_data(self.nodes().read(root).center).location(),
899 // self.node_index_data(graft).location(),
900 // );
901 // let root_radius = self.nodes().read(root).radius;
902 //
903 // if distance < root_radius {
904 // // in the interior
905 // if let Some(ind) = self.nodes().read(root).interior {
906 // // recurse
907 // self.insert_existing(ind, graft)
908 // } else {
909 // // leaf node
910 // self.nodes().read(root).interior = Some(graft);
911 // self.nodes().read(graft).radius = distance.clamp(root_radius / 2.0, root_radius);
912 // self.nodes().read(graft).parent = Some(root);
913 // }
914 // } else {
915 // if let Some(ind) = self.nodes().read(root).exterior {
916 // // recurse
917 // self.insert_existing(ind, graft)
918 // } else {
919 // // leaf node
920 // self.nodes().read(root).exterior = Some(graft);
921 // self.nodes().read(graft).radius = distance.clamp(root_radius / 2.0, root_radius);
922 // self.nodes().read(graft).parent = Some(root);
923 // }
924 // }
925 // // update the height
926 // self.set_height(root);
927 //
928 // // inserted!
929 // // rebalance?
930 // // will be called again at each successively higher level
931 // self.rebalance(root)
932 // }
933 //
934 // fn rebalance(&mut self, root: usize) {
935 // let interior_height = self.nodes().read(root)
936 // .interior
937 // .map(|ind| self.nodes().read(ind).height)
938 // .unwrap_or(0);
939 // let exterior_height = self.nodes().read(root)
940 // .exterior
941 // .map(|ind| self.nodes().read(ind).height)
942 // .unwrap_or(0);
943 //
944 // if interior_height > (exterior_height + 1) {
945 // // interior is too big, it must be rebalanced
946 // self.rebalance_interior(root)
947 // } else if exterior_height > (interior_height + 1) {
948 // // exterior is too big, must be rebalanced
949 // self.rebalance_exterior(root)
950 // }
951 // }
952 //
953 // fn child_indices_impl(&self, root: usize, progress: &mut Vec<usize>) {
954 // if let Some(int) = self.nodes().read(root).interior {
955 // self.child_indices_impl(int, progress)
956 // }
957 //
958 // if let Some(ext) = self.nodes().read(root).exterior {
959 // self.child_indices_impl(ext, progress)
960 // }
961 //
962 // progress.push(root);
963 // }
964 //
965 // fn child_indices(&self, root: usize) -> Vec<usize> {
966 // let mut chillum = vec![];
967 // self.child_indices_impl(root, &mut chillum);
968 //
969 // chillum
970 // }
971 //
972 // // make the interior shorter
973 // fn rebalance_interior(&mut self, root: usize) {
974 // let mut children = self.child_indices(root);
975 //
976 // let root = children.pop().unwrap();
977 // self.bulk_build_indices(root, children);
978 //
979 // //
980 // // The following doesn't work, because it's possible, following a bulk reindex, for the radius of a child to be larger than that of its parent.
981 // // Consequently I haven't figured out materially more efficient way of grafting the subtrees here.
982 // // I keep this in place as an inspiration to figure out how to do this properly in the future
983 // //
984 //
985 // // // moves nodes as:
986 // // // interior -> root
987 // // // exterior -> new root exterior
988 // // // old root -> reinsert
989 //
990 // // // there must be an interior in this case, but maybe no exterior
991 // // let old_interior = self.nodes().read(root).interior.unwrap();
992 // // let old_exterior = self.nodes().read(root).exterior;
993 //
994 // // let old_root_data = self.nodes().read(root).center;
995 //
996 // // println!(
997 // // "swapping {}: {:?} <> {}: {:?}",
998 // // root, self.nodes().read(root], old_interior, self.nodes[old_interior)
999 // // );
1000 //
1001 // // // if there is no graft node, no children....
1002 // // let mut old_exterior_children = old_exterior
1003 // // .map(|ind| self.child_indices(ind))
1004 // // .unwrap_or(vec![]);
1005 //
1006 // // // transplant the old interior to the root
1007 // // self.nodes().read(root] = self.nodes[old_interior).clone();
1008 // // self.nodes().read(old_interior).center = old_root_data;
1009 //
1010 // // let old_root_distance = self.metric.distance(
1011 // // self.node_index_data(root),
1012 // // self.node_index_data(old_interior),
1013 // // );
1014 //
1015 // // let root_radius = self.nodes().read(root).radius;
1016 //
1017 // // // make the old root data located in the old interior node
1018 // // self.nodes().read(old_interior).center = old_root_data;
1019 // // self.nodes().read(old_interior).interior = None;
1020 // // self.nodes().read(old_interior).exterior = None;
1021 // // self.nodes().read(old_interior).height = 0;
1022 // // self.nodes().read(old_interior).radius = old_root_distance.clamp(root_radius / 2.0, root_radius);
1023 //
1024 // // // collect the new exterior nodes
1025 // // let new_exterior_node = self.nodes().read(root).exterior;
1026 // // // this could be empty
1027 // // let mut new_exterior_children = new_exterior_node
1028 // // .map(|ind| self.child_indices(ind))
1029 // // .unwrap_or(vec![]);
1030 //
1031 // // println!(
1032 // // "new exterior children {:?}: {:?}",
1033 // // new_exterior_node, new_exterior_children
1034 // // );
1035 //
1036 // // println!(
1037 // // "old exterior children {:?}: {:?}",
1038 // // old_exterior, old_exterior_children
1039 // // );
1040 //
1041 // // // aggregate all children...
1042 // // // either or both could be empty
1043 // // new_exterior_children.append(&mut old_exterior_children);
1044 //
1045 // // // check where the old root should go...
1046 // // println!(
1047 // // "swapped {}: {:?} <> {}: {:?} distance: {}/{}",
1048 // // root,
1049 // // self.nodes().read(root),
1050 // // old_interior,
1051 // // self.nodes().read(old_interior),
1052 // // old_root_distance,
1053 // // self.nodes().read(root).radius
1054 // // );
1055 //
1056 // // if old_root_distance < root_radius {
1057 // // println!(
1058 // // "old root in interior {} < {}",
1059 // // old_root_distance, root_radius
1060 // // );
1061 // // // old root is within the new root interior
1062 // // match self.nodes().read(root).interior {
1063 // // Some(interior) => self.insert_existing(interior, old_interior),
1064 // // None => {
1065 // // self.nodes().read(root).interior = Some(old_interior);
1066 // // self.nodes().read(root).radius =
1067 // // old_root_distance.clamp(root_radius / 2.0, root_radius)
1068 // // }
1069 // // }
1070 // // } else {
1071 // // println!("old root in exterior");
1072 // // // old root can be handled along with all of the other new exterior points
1073 // // new_exterior_children.push(old_interior)
1074 // // }
1075 //
1076 // // let new_exterior_root = new_exterior_children.pop();
1077 // // self.nodes().read(root).exterior = new_exterior_root;
1078 //
1079 // // // now reindex the new exterior nodes
1080 // // if let Some(exterior) = new_exterior_root {
1081 // // println!(
1082 // // "new exterior nodes {}: {:?}",
1083 // // exterior, new_exterior_children
1084 // // );
1085 // // self.bulk_build_indices(exterior, new_exterior_children);
1086 // // }
1087 //
1088 // // self.set_height(root);
1089 //
1090 // // println!(
1091 // // "finally {}: {:?} int: {:?} ext {:?}",
1092 // // root,
1093 // // self.nodes().read(root),
1094 // // self.nodes().read(root].interior.map(|i| &self.nodes[i)),
1095 // // self.nodes().read(root].exterior.map(|i| &self.nodes[i)),
1096 // // );
1097 //
1098 // // self.check_validity_root(root);
1099 // }
1100 //
1101 // // make the exterior shorter
1102 // fn rebalance_exterior(&mut self, root: usize) {
1103 // // honestly I don't see a way to be clever about this case yet.
1104 // // rebuilding the whole dang thing
1105 // // TODO: be good
1106 // let mut children = self.child_indices(root);
1107 //
1108 // let root = children.pop().unwrap();
1109 // self.bulk_build_indices(root, children)
1110 // }
1111 //
1112 // pub fn nn_iter<'a>(
1113 // &'a self,
1114 // query_point: &'a Point::PointType,
1115 // ) -> impl Iterator<Item = &'a Point> {
1116 // KnnIterator::new(query_point, self).map(|(p, _d)| p)
1117 // }
1118 //
1119 // pub fn nn_dist_iter<'a>(
1120 // &'a self,
1121 // query_point: &'a Point::PointType,
1122 // ) -> KnnIterator<'a, Point, PointMetric> {
1123 // KnnIterator::new(query_point, self)
1124 // }
1125 //
1126 // pub fn nn_iter_mut<'a>(
1127 // &'a mut self,
1128 // query_point: &'a Point::PointType,
1129 // ) -> impl Iterator<Item = &'a mut Point> {
1130 // self.nn_dist_iter_mut(query_point).map(|(p, _d)| p)
1131 // }
1132 //
1133 // pub fn nn_dist_iter_mut<'a>(
1134 // &'a mut self,
1135 // query_point: &'a Point::PointType,
1136 // ) -> KnnMutIterator<'a, Point, PointMetric> {
1137 // KnnMutIterator::new(query_point, self)
1138 // }
1139 //
1140 // fn nn_index_iter<'a>(
1141 // &'a self,
1142 // query_point: &'a Point::PointType,
1143 // ) -> KnnIndexIterator<'a, Point, PointMetric> {
1144 // KnnIndexIterator::new(query_point, self)
1145 // }
1146 //
1147 // fn check_validity_node(&self, root: usize) {
1148 // if let Some(interior) = self.nodes().read(root).interior {
1149 // let distance = self.metric.distance(
1150 // self.node_index_data(root).location(),
1151 // self.node_index_data(interior).location(),
1152 // );
1153 //
1154 // assert!(
1155 // distance < self.nodes().read(root).radius,
1156 // "interior {} of {} not within radius: {} >= {}",
1157 // interior,
1158 // root,
1159 // distance,
1160 // self.nodes().read(root).radius
1161 // );
1162 // }
1163 //
1164 // if let Some(exterior) = self.nodes().read(root).exterior {
1165 // let distance = self.metric.distance(
1166 // self.node_index_data(root).location(),
1167 // self.node_index_data(exterior).location(),
1168 // );
1169 //
1170 // assert!(
1171 // distance >= self.nodes().read(root).radius,
1172 // "exterior {} of {} not outside radius: {} < {}",
1173 // exterior,
1174 // root,
1175 // distance,
1176 // self.nodes().read(root).radius
1177 // );
1178 // }
1179 // }
1180 //
1181 // fn check_validity_root(&self, root: usize) {
1182 // self.check_validity_node(root);
1183 //
1184 // if let Some(interior) = self.nodes().read(root).interior {
1185 // self.check_validity_root(interior)
1186 // }
1187 //
1188 // if let Some(exterior) = self.nodes().read(root).exterior {
1189 // self.check_validity_root(exterior)
1190 // }
1191 // }
1192 //
1193 // pub fn iter<'a>(&'a self) -> impl Iterator<Item = &'a Point> {
1194 // self.data.iter()
1195 // }
1196 //
1197 // pub fn iter_mut<'a>(&'a mut self) -> impl Iterator<Item = &'a mut Point> {
1198 // self.data.iter_mut()
1199 // }
1200 //
1201 // fn check_validity(&self) {
1202 // if self.len()() > 0 {
1203 // self.check_validity_root(self.root)
1204 // }
1205 // }
1206 //
1207 // pub fn size(&self) -> usize {
1208 // self.data.len()
1209 // }
1210}
1211
1212impl Node {
1213 fn new_leaf(center: usize, parent: Option<usize>) -> Self {
1214 Node {
1215 height: 0,
1216 center,
1217 radius: 0.0,
1218 interior: None,
1219 exterior: None,
1220 parent,
1221 }
1222 }
1223}
1224
1225#[cfg(test)]
1226mod tests {
1227 use super::*;
1228 use rand::distributions::Uniform;
1229 use rand::Rng;
1230 use test::Bencher;
1231
1232 fn check_ordering<T: VpAvl>(tree: &T, test_points: &[<T::Point as VpTreeObject>::PointType]) {
1233 for p in test_points.iter() {
1234 let mut d = 0.0;
1235 for (point, dist) in tree.nn_dist_iter(p) {
1236 assert!(dist >= d);
1237 d = dist;
1238 }
1239 }
1240 }
1241
/// Bulk-builds a tree from 10 000 random points, checks its bookkeeping and
/// validity, then compares the first nearest neighbour of 1 000 random
/// queries against a brute-force linear scan.
#[test]
fn test_vp() {
    let random_points = k_random(10000);
    let query_set = k_random(1000);

    let avl = VpAvlVec::bulk_insert(EuclideanMetric::default(), random_points.clone());

    assert_eq!(avl.data.len(), 10000);
    assert_eq!(avl.nodes.len(), 10000);

    // every node except the root itself must be reachable from the root
    let descendants = avl.child_indices(avl.root).len();
    assert_eq!(
        descendants, 9999,
        "children: {} != 10000 - 1",
        descendants
    );

    avl.check_validity();

    let metric = EuclideanMetric::default();
    for query in &query_set {
        let nn = avl.nn_iter(query).next().unwrap();
        let avl_min_dist = metric.distance(query, &nn);

        // brute-force reference: smallest distance over every stored point
        let linear_min_dist = random_points
            .iter()
            .map(|candidate| metric.distance(query, candidate))
            .fold(f64::INFINITY, f64::min);

        assert!(
            linear_min_dist == avl_min_dist,
            "linear = {}, avl = {}",
            linear_min_dist,
            avl_min_dist
        );
    }
}
1281
/// Builds a tree by inserting 10 000 random points one at a time, checks its
/// bookkeeping and validity, then compares the first nearest neighbour of
/// 1 000 random queries against a brute-force linear scan.
#[test]
fn test_vp_incremental() {
    let random_points = k_random(10000);
    let query_set = k_random(1000);

    let mut avl = VpAvlVec::new(EuclideanMetric::default());
    for point in random_points.iter() {
        avl.insert(point.clone());
    }

    // assert_eq! (as in the sibling tests) reports both values on failure
    assert_eq!(avl.data.len(), 10000);
    assert_eq!(avl.nodes.len(), 10000);

    // verify all nodes are children of the root
    assert_eq!(
        avl.child_indices(avl.root).len(),
        10000 - 1,
        "children: {} != 10000 - 1",
        avl.child_indices(avl.root).len()
    );

    avl.check_validity();

    // the tree owns its metric, so use a fresh one for the reference scan
    let metric = EuclideanMetric::default();
    for q in query_set {
        let nn = avl.nn_iter(&q).next().unwrap();
        let avl_min_dist = metric.distance(&q, &nn);

        // linear search
        let linear_min_dist = random_points.iter().fold(f64::INFINITY, |acc, x| {
            let dist = metric.distance(&q, x);
            acc.min(dist)
        });

        assert_eq!(
            linear_min_dist, avl_min_dist,
            "linear = {}, avl = {}",
            linear_min_dist, avl_min_dist
        );
    }
}
1326
/// Inserts 1 000 random points, then removes them one by one, checking after
/// each removal that the node count shrinks, the tree is still valid, and
/// nearest-neighbour iteration is still ordered by distance.
#[test]
fn test_vp_remove() {
    // smaller because this is slow
    let random_points = k_random(1000);

    let metric = EuclideanMetric::default();

    let mut avl = VpAvlVec::new(metric);
    for point in random_points.iter() {
        avl.insert(point.clone());
    }

    assert_eq!(avl.data.len(), 1000);
    assert_eq!(avl.nodes.len(), 1000);

    // verify all nodes are children of the root
    assert_eq!(
        avl.child_indices(avl.root()).len(),
        1000 - 1,
        "children: {} != 1000 - 1",
        avl.child_indices(avl.root()).len()
    );

    avl.check_validity();

    for (ind, removal) in random_points.iter().enumerate() {
        assert!(
            avl.remove(&removal).is_some(),
            "point #{} was not found for removal",
            ind
        );
        // exactly one node must disappear per removed point
        assert_eq!(
            avl.nodes.len(),
            1000 - ind - 1,
            "unexpected node count after removing point #{}",
            ind
        );
        avl.check_validity();
        check_ordering(&avl, random_points.as_slice());
    }
}
1359
/// Bulk-builds a tree under the plain Euclidean metric, switches it to a
/// weighted metric via `update_metric`, and checks that the first nearest
/// neighbour under the new metric matches a brute-force linear scan.
#[test]
fn test_reweight() {
    let random_points = k_random(10000);
    let query_set = k_random(1000);

    let avl = VpAvlVec::bulk_insert(EuclideanMetric::default(), random_points.clone());

    let weighted_metric = WeightedEuclideanMetric::new(vec![0.1, 0.2, 0.3, 0.4, 0.5]);
    let new_avl = avl.update_metric(weighted_metric.clone());

    for query in &query_set {
        let nn = new_avl.nn_iter(query).next().unwrap();
        let avl_min_dist = weighted_metric.distance(query, &nn);

        // brute-force reference under the weighted metric
        let linear_min_dist = random_points
            .iter()
            .map(|candidate| weighted_metric.distance(query, candidate))
            .fold(f64::INFINITY, f64::min);

        assert_eq!(
            linear_min_dist, avl_min_dist,
            "linear = {}, avl = {}",
            linear_min_dist, avl_min_dist
        );
    }
}
1387
/// Walks the full nearest-neighbour iterator for 100 random queries on a
/// bulk-built 10 000-point tree and checks that distances never decrease.
#[test]
fn test_iter() {
    let random_points = k_random(10000);
    let query_set = k_random(100);
    let metric = EuclideanMetric::default();
    let avl = VpAvlVec::bulk_insert(metric.clone(), random_points.clone());

    for query in &query_set {
        // distance of the previously yielded neighbour; starts at 0.0
        let mut last = 0.0;
        for pt in avl.nn_iter(query) {
            let dist = metric.distance(query, pt);
            assert!(dist >= last, "distance went down! {} < {}", dist, last);
            last = dist;
        }
    }
}
1403
/// Queries a tree that contains no points and checks that nearest-neighbour
/// iteration neither panics nor yields anything.
#[test]
fn test_empty() {
    let query_set = k_random(1);
    let metric = EuclideanMetric::default();
    let avl = VpAvlVec::<Vec<f64>, _>::new(metric.clone());

    for q in query_set {
        // An ordering check over an empty iterator is vacuous, so assert
        // emptiness explicitly instead.
        if let Some(pt) = avl.nn_iter(&q).next() {
            panic!(
                "empty tree yielded a neighbour at distance {}",
                metric.distance(&q, pt)
            );
        }
    }
}
1418
1419 fn k_random(k: usize) -> Vec<Vec<f64>> {
1420 let range = Uniform::new(-1.0, 1.0);
1421 (0..k)
1422 .map(|_| rand::thread_rng().sample_iter(range).take(5).collect())
1423 .collect()
1424 }
1425
1426 fn random_k(k: usize) {
1427 // so this is a little messy because it also generates the points, but I want to make sure the bench uses new points each time
1428 let points = k_random(k);
1429 }
1430
1431 fn bench_bulk_k(k: usize) {
1432 // so this is a little messy because it also generates the points, but I want to make sure the bench uses new points each time
1433 let points = k_random(k);
1434 let metric = EuclideanMetric::default();
1435 let avl = VpAvlVec::bulk_insert(metric, points);
1436 }
1437
1438 fn bench_incremental_k(k: usize) {
1439 // so this is a little messy because it also generates the points, but I want to make sure the bench uses new points each time
1440 let points = k_random(k);
1441 let metric = EuclideanMetric::default();
1442 let mut avl = VpAvlVec::new(metric);
1443 for point in points {
1444 avl.insert(point);
1445 }
1446 }
1447
1448 #[bench]
1449 fn bench_random_1000(b: &mut Bencher) {
1450 b.iter(|| random_k(1000));
1451 }
1452
1453 #[bench]
1454 fn bench_random_10000(b: &mut Bencher) {
1455 b.iter(|| random_k(10000));
1456 }
1457
1458 #[bench]
1459 fn bench_random_100000(b: &mut Bencher) {
1460 b.iter(|| random_k(100000));
1461 }
1462
1463 #[bench]
1464 fn bench_random_1000000(b: &mut Bencher) {
1465 b.iter(|| random_k(1000000));
1466 }
1467
1468 #[bench]
1469 fn bench_build_vp_bulk_1000(b: &mut Bencher) {
1470 b.iter(|| bench_bulk_k(1000));
1471 }
1472
1473 #[bench]
1474 fn bench_build_vp_incremental_1000(b: &mut Bencher) {
1475 b.iter(|| bench_incremental_k(1000));
1476 }
1477
1478 #[bench]
1479 fn bench_build_vp_bulk_10000(b: &mut Bencher) {
1480 b.iter(|| bench_bulk_k(10000));
1481 }
1482
1483 #[bench]
1484 fn bench_build_vp_incremental_10000(b: &mut Bencher) {
1485 b.iter(|| bench_incremental_k(10000));
1486 }
1487
1488 #[bench]
1489 fn bench_build_vp_bulk_100000(b: &mut Bencher) {
1490 b.iter(|| bench_bulk_k(100000));
1491 }
1492
1493 #[bench]
1494 fn bench_build_vp_incremental_100000(b: &mut Bencher) {
1495 b.iter(|| bench_incremental_k(100000));
1496 }
1497
1498 // #[bench]
1499 // fn bench_build_vp_bulk_1000000(b: &mut Bencher) {
1500 // b.iter(|| bench_bulk_k(1000000));
1501 // }
1502
1503 // #[bench]
1504 // fn bench_build_vp_incremental_1000000(b: &mut Bencher) {
1505 // b.iter(|| bench_incremental_k(1000000));
1506 // }
1507}