// smartcore/algorithm/neighbour/bbd_tree.rs

1use std::fmt::Debug;
2
3use crate::linalg::basic::arrays::Array2;
4use crate::metrics::distance::euclidian::*;
5use crate::numbers::basenum::Number;
6
7#[derive(Debug)]
8pub struct BBDTree {
9    nodes: Vec<BBDTreeNode>,
10    index: Vec<usize>,
11    root: usize,
12}
13
/// One node of the tree: a bounding box plus summary statistics of the rows
/// it covers.
#[derive(Debug)]
struct BBDTreeNode {
    /// Number of rows covered by this node.
    count: usize,
    /// Offset into the tree's row-number vector where this node's rows start.
    index: usize,
    /// Midpoint of the bounding box in every dimension.
    center: Vec<f64>,
    /// Half-width of the bounding box in every dimension.
    radius: Vec<f64>,
    /// Per-dimension sum of the covered rows.
    sum: Vec<f64>,
    /// Cost term stored for this node; it is added to the count-weighted
    /// squared distance between the node mean and a candidate center when the
    /// node's total cost is evaluated.
    cost: f64,
    /// Position of the lower child in the node vector, `None` for a leaf.
    lower: Option<usize>,
    /// Position of the upper child in the node vector, `None` for a leaf.
    upper: Option<usize>,
}

impl BBDTreeNode {
    /// Creates a childless node for `d`-dimensional data with every counter
    /// and every coordinate set to zero.
    fn new(d: usize) -> Self {
        Self {
            count: 0,
            index: 0,
            center: vec![0f64; d],
            radius: vec![0f64; d],
            sum: vec![0f64; d],
            cost: 0f64,
            lower: None,
            upper: None,
        }
    }
}
40
41impl BBDTree {
42    pub fn new<T: Number, M: Array2<T>>(data: &M) -> BBDTree {
43        let nodes: Vec<BBDTreeNode> = Vec::new();
44
45        let (n, _) = data.shape();
46
47        let index = (0..n).collect::<Vec<usize>>();
48
49        let mut tree = BBDTree {
50            nodes,
51            index,
52            root: 0,
53        };
54
55        let root = tree.build_node(data, 0, n);
56
57        tree.root = root;
58
59        tree
60    }
61
62    pub(crate) fn clustering(
63        &self,
64        centroids: &[Vec<f64>],
65        sums: &mut Vec<Vec<f64>>,
66        counts: &mut Vec<usize>,
67        membership: &mut Vec<usize>,
68    ) -> f64 {
69        let k = centroids.len();
70
71        counts.iter_mut().for_each(|v| *v = 0);
72        let mut candidates = vec![0; k];
73        for i in 0..k {
74            candidates[i] = i;
75            sums[i].iter_mut().for_each(|v| *v = 0f64);
76        }
77
78        self.filter(
79            self.root,
80            centroids,
81            &candidates,
82            k,
83            sums,
84            counts,
85            membership,
86        )
87    }
88
89    fn filter(
90        &self,
91        node: usize,
92        centroids: &[Vec<f64>],
93        candidates: &[usize],
94        k: usize,
95        sums: &mut Vec<Vec<f64>>,
96        counts: &mut Vec<usize>,
97        membership: &mut Vec<usize>,
98    ) -> f64 {
99        let d = centroids[0].len();
100
101        let mut min_dist =
102            Euclidian::squared_distance(&self.nodes[node].center, &centroids[candidates[0]]);
103        let mut closest = candidates[0];
104        for i in 1..k {
105            let dist =
106                Euclidian::squared_distance(&self.nodes[node].center, &centroids[candidates[i]]);
107            if dist < min_dist {
108                min_dist = dist;
109                closest = candidates[i];
110            }
111        }
112
113        if self.nodes[node].lower.is_some() {
114            let mut new_candidates = vec![0; k];
115            let mut newk = 0;
116
117            for candidate in candidates.iter().take(k) {
118                if !BBDTree::prune(
119                    &self.nodes[node].center,
120                    &self.nodes[node].radius,
121                    centroids,
122                    closest,
123                    *candidate,
124                ) {
125                    new_candidates[newk] = *candidate;
126                    newk += 1;
127                }
128            }
129
130            if newk > 1 {
131                return self.filter(
132                    self.nodes[node].lower.unwrap(),
133                    centroids,
134                    &new_candidates,
135                    newk,
136                    sums,
137                    counts,
138                    membership,
139                ) + self.filter(
140                    self.nodes[node].upper.unwrap(),
141                    centroids,
142                    &new_candidates,
143                    newk,
144                    sums,
145                    counts,
146                    membership,
147                );
148            }
149        }
150
151        for i in 0..d {
152            sums[closest][i] += self.nodes[node].sum[i];
153        }
154
155        counts[closest] += self.nodes[node].count;
156
157        let last = self.nodes[node].index + self.nodes[node].count;
158        for i in self.nodes[node].index..last {
159            membership[self.index[i]] = closest;
160        }
161
162        BBDTree::node_cost(&self.nodes[node], &centroids[closest])
163    }
164
165    fn prune(
166        center: &[f64],
167        radius: &[f64],
168        centroids: &[Vec<f64>],
169        best_index: usize,
170        test_index: usize,
171    ) -> bool {
172        if best_index == test_index {
173            return false;
174        }
175
176        let d = centroids[0].len();
177
178        let best = &centroids[best_index];
179        let test = &centroids[test_index];
180        let mut lhs = 0f64;
181        let mut rhs = 0f64;
182        for i in 0..d {
183            let diff = test[i] - best[i];
184            lhs += diff * diff;
185            if diff > 0f64 {
186                rhs += (center[i] + radius[i] - best[i]) * diff;
187            } else {
188                rhs += (center[i] - radius[i] - best[i]) * diff;
189            }
190        }
191
192        lhs >= 2f64 * rhs
193    }
194
195    fn build_node<T: Number, M: Array2<T>>(&mut self, data: &M, begin: usize, end: usize) -> usize {
196        let (_, d) = data.shape();
197
198        let mut node = BBDTreeNode::new(d);
199
200        node.count = end - begin;
201        node.index = begin;
202
203        let mut lower_bound = vec![0f64; d];
204        let mut upper_bound = vec![0f64; d];
205
206        for i in 0..d {
207            lower_bound[i] = data.get((self.index[begin], i)).to_f64().unwrap();
208            upper_bound[i] = data.get((self.index[begin], i)).to_f64().unwrap();
209        }
210
211        for i in begin..end {
212            for j in 0..d {
213                let c = data.get((self.index[i], j)).to_f64().unwrap();
214                if lower_bound[j] > c {
215                    lower_bound[j] = c;
216                }
217                if upper_bound[j] < c {
218                    upper_bound[j] = c;
219                }
220            }
221        }
222
223        let mut max_radius = -1f64;
224        let mut split_index = 0;
225        for i in 0..d {
226            node.center[i] = (lower_bound[i] + upper_bound[i]) / 2f64;
227            node.radius[i] = (upper_bound[i] - lower_bound[i]) / 2f64;
228            if node.radius[i] > max_radius {
229                max_radius = node.radius[i];
230                split_index = i;
231            }
232        }
233
234        if max_radius < 1E-10 {
235            node.lower = Option::None;
236            node.upper = Option::None;
237            for i in 0..d {
238                node.sum[i] = data.get((self.index[begin], i)).to_f64().unwrap();
239            }
240
241            if end > begin + 1 {
242                let len = end - begin;
243                for i in 0..d {
244                    node.sum[i] *= len as f64;
245                }
246            }
247
248            node.cost = 0f64;
249            return self.add_node(node);
250        }
251
252        let split_cutoff = node.center[split_index];
253        let mut i1 = begin;
254        let mut i2 = end - 1;
255        let mut size = 0;
256        while i1 <= i2 {
257            let mut i1_good =
258                data.get((self.index[i1], split_index)).to_f64().unwrap() < split_cutoff;
259            let mut i2_good =
260                data.get((self.index[i2], split_index)).to_f64().unwrap() >= split_cutoff;
261
262            if !i1_good && !i2_good {
263                self.index.swap(i1, i2);
264                i1_good = true;
265                i2_good = true;
266            }
267
268            if i1_good {
269                i1 += 1;
270                size += 1;
271            }
272
273            if i2_good {
274                i2 -= 1;
275            }
276        }
277
278        node.lower = Option::Some(self.build_node(data, begin, begin + size));
279        node.upper = Option::Some(self.build_node(data, begin + size, end));
280
281        for i in 0..d {
282            node.sum[i] =
283                self.nodes[node.lower.unwrap()].sum[i] + self.nodes[node.upper.unwrap()].sum[i];
284        }
285
286        let mut mean = vec![0f64; d];
287        for (i, mean_i) in mean.iter_mut().enumerate().take(d) {
288            *mean_i = node.sum[i] / node.count as f64;
289        }
290
291        node.cost = BBDTree::node_cost(&self.nodes[node.lower.unwrap()], &mean)
292            + BBDTree::node_cost(&self.nodes[node.upper.unwrap()], &mean);
293
294        self.add_node(node)
295    }
296
297    fn node_cost(node: &BBDTreeNode, center: &[f64]) -> f64 {
298        let d = center.len();
299        let mut scatter = 0f64;
300        for (i, center_i) in center.iter().enumerate().take(d) {
301            let x = (node.sum[i] / node.count as f64) - *center_i;
302            scatter += x * x;
303        }
304        node.cost + node.count as f64 * scatter
305    }
306
307    fn add_node(&mut self, new_node: BBDTreeNode) -> usize {
308        let idx = self.nodes.len();
309        self.nodes.push(new_node);
310        idx
311    }
312}
313
#[cfg(test)]
mod tests {
    use super::*;
    use crate::linalg::basic::matrix::DenseMatrix;

    /// Clusters twenty iris measurements around two fixed centroids and checks
    /// the reported cost, the per-cluster sums and counts, and the membership
    /// of every row.
    #[cfg_attr(
        all(target_arch = "wasm32", not(target_os = "wasi")),
        wasm_bindgen_test::wasm_bindgen_test
    )]
    #[test]
    fn bbdtree_iris() {
        let data = DenseMatrix::from_2d_array(&[
            &[5.1, 3.5, 1.4, 0.2],
            &[4.9, 3.0, 1.4, 0.2],
            &[4.7, 3.2, 1.3, 0.2],
            &[4.6, 3.1, 1.5, 0.2],
            &[5.0, 3.6, 1.4, 0.2],
            &[5.4, 3.9, 1.7, 0.4],
            &[4.6, 3.4, 1.4, 0.3],
            &[5.0, 3.4, 1.5, 0.2],
            &[4.4, 2.9, 1.4, 0.2],
            &[4.9, 3.1, 1.5, 0.1],
            &[7.0, 3.2, 4.7, 1.4],
            &[6.4, 3.2, 4.5, 1.5],
            &[6.9, 3.1, 4.9, 1.5],
            &[5.5, 2.3, 4.0, 1.3],
            &[6.5, 2.8, 4.6, 1.5],
            &[5.7, 2.8, 4.5, 1.3],
            &[6.3, 3.3, 4.7, 1.6],
            &[4.9, 2.4, 3.3, 1.0],
            &[6.6, 2.9, 4.6, 1.3],
            &[5.2, 2.7, 3.9, 1.4],
        ])
        .unwrap();

        let tree = BBDTree::new(&data);

        let centroids = vec![vec![4.86, 3.22, 1.61, 0.29], vec![6.23, 2.92, 4.48, 1.42]];

        let mut sums = vec![vec![0f64; 4], vec![0f64; 4]];

        // Deliberately stale values: `clustering` has to overwrite them.
        let mut counts = vec![11, 9];

        // Row 17 starts in the wrong cluster and has to be moved.
        let mut membership = vec![0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1];

        let dist = tree.clustering(&centroids, &mut sums, &mut counts, &mut membership);
        assert!((dist - 10.68).abs() < 1e-2);
        assert!((sums[0][0] - 48.6).abs() < 1e-2);
        assert!((sums[1][3] - 13.8).abs() < 1e-2);
        assert_eq!(membership[17], 1);

        // The first ten rows are nearest to the first centroid, the last ten
        // to the second one.
        let expected_counts: Vec<usize> = vec![10, 10];
        assert_eq!(counts, expected_counts);
        let expected_membership: Vec<usize> = [vec![0; 10], vec![1; 10]].concat();
        assert_eq!(membership, expected_membership);
    }
}
365}