1use std::fmt::Debug;
2
3use crate::linalg::basic::arrays::Array2;
4use crate::metrics::distance::euclidian::*;
5use crate::numbers::basenum::Number;
6
/// Tree of axis-aligned bounding boxes built over the rows of a data matrix.
///
/// Nodes are kept in one flat vector and refer to each other by position.
/// `clustering` walks the tree to assign whole groups of rows to a centroid at once.
#[derive(Debug)]
pub struct BBDTree {
    /// Every node of the tree; `build_node` stores children before their parent.
    nodes: Vec<BBDTreeNode>,
    /// Row indices of the data matrix, permuted during construction so that each
    /// node covers a contiguous range of this vector.
    index: Vec<usize>,
    /// Position of the root node inside `nodes`.
    root: usize,
}
13
/// One node of the tree: a bounding box plus aggregate statistics of the rows it covers.
#[derive(Debug)]
struct BBDTreeNode {
    /// Number of rows covered by this node.
    count: usize,
    /// Offset into the tree's `index` vector where this node's rows start.
    index: usize,
    /// Per-dimension midpoint of the bounding box.
    center: Vec<f64>,
    /// Per-dimension half-width of the bounding box.
    radius: Vec<f64>,
    /// Per-dimension sum of the covered rows.
    sum: Vec<f64>,
    /// Cost of the covered rows about this node's own mean; zero for a leaf.
    cost: f64,
    /// Position of the lower child in the tree's node vector, `None` for a leaf.
    lower: Option<usize>,
    /// Position of the upper child in the tree's node vector, `None` for a leaf.
    upper: Option<usize>,
}

impl BBDTreeNode {
    /// Creates an empty, childless node whose vectors hold `d` zeros each.
    fn new(d: usize) -> BBDTreeNode {
        let zeros = vec![0f64; d];
        Self {
            count: 0,
            index: 0,
            center: zeros.clone(),
            radius: zeros.clone(),
            sum: zeros,
            cost: 0f64,
            lower: None,
            upper: None,
        }
    }
}
40
41impl BBDTree {
42 pub fn new<T: Number, M: Array2<T>>(data: &M) -> BBDTree {
43 let nodes: Vec<BBDTreeNode> = Vec::new();
44
45 let (n, _) = data.shape();
46
47 let index = (0..n).collect::<Vec<usize>>();
48
49 let mut tree = BBDTree {
50 nodes,
51 index,
52 root: 0,
53 };
54
55 let root = tree.build_node(data, 0, n);
56
57 tree.root = root;
58
59 tree
60 }
61
62 pub(crate) fn clustering(
63 &self,
64 centroids: &[Vec<f64>],
65 sums: &mut Vec<Vec<f64>>,
66 counts: &mut Vec<usize>,
67 membership: &mut Vec<usize>,
68 ) -> f64 {
69 let k = centroids.len();
70
71 counts.iter_mut().for_each(|v| *v = 0);
72 let mut candidates = vec![0; k];
73 for i in 0..k {
74 candidates[i] = i;
75 sums[i].iter_mut().for_each(|v| *v = 0f64);
76 }
77
78 self.filter(
79 self.root,
80 centroids,
81 &candidates,
82 k,
83 sums,
84 counts,
85 membership,
86 )
87 }
88
89 fn filter(
90 &self,
91 node: usize,
92 centroids: &[Vec<f64>],
93 candidates: &[usize],
94 k: usize,
95 sums: &mut Vec<Vec<f64>>,
96 counts: &mut Vec<usize>,
97 membership: &mut Vec<usize>,
98 ) -> f64 {
99 let d = centroids[0].len();
100
101 let mut min_dist =
102 Euclidian::squared_distance(&self.nodes[node].center, ¢roids[candidates[0]]);
103 let mut closest = candidates[0];
104 for i in 1..k {
105 let dist =
106 Euclidian::squared_distance(&self.nodes[node].center, ¢roids[candidates[i]]);
107 if dist < min_dist {
108 min_dist = dist;
109 closest = candidates[i];
110 }
111 }
112
113 if self.nodes[node].lower.is_some() {
114 let mut new_candidates = vec![0; k];
115 let mut newk = 0;
116
117 for candidate in candidates.iter().take(k) {
118 if !BBDTree::prune(
119 &self.nodes[node].center,
120 &self.nodes[node].radius,
121 centroids,
122 closest,
123 *candidate,
124 ) {
125 new_candidates[newk] = *candidate;
126 newk += 1;
127 }
128 }
129
130 if newk > 1 {
131 return self.filter(
132 self.nodes[node].lower.unwrap(),
133 centroids,
134 &new_candidates,
135 newk,
136 sums,
137 counts,
138 membership,
139 ) + self.filter(
140 self.nodes[node].upper.unwrap(),
141 centroids,
142 &new_candidates,
143 newk,
144 sums,
145 counts,
146 membership,
147 );
148 }
149 }
150
151 for i in 0..d {
152 sums[closest][i] += self.nodes[node].sum[i];
153 }
154
155 counts[closest] += self.nodes[node].count;
156
157 let last = self.nodes[node].index + self.nodes[node].count;
158 for i in self.nodes[node].index..last {
159 membership[self.index[i]] = closest;
160 }
161
162 BBDTree::node_cost(&self.nodes[node], ¢roids[closest])
163 }
164
165 fn prune(
166 center: &[f64],
167 radius: &[f64],
168 centroids: &[Vec<f64>],
169 best_index: usize,
170 test_index: usize,
171 ) -> bool {
172 if best_index == test_index {
173 return false;
174 }
175
176 let d = centroids[0].len();
177
178 let best = ¢roids[best_index];
179 let test = ¢roids[test_index];
180 let mut lhs = 0f64;
181 let mut rhs = 0f64;
182 for i in 0..d {
183 let diff = test[i] - best[i];
184 lhs += diff * diff;
185 if diff > 0f64 {
186 rhs += (center[i] + radius[i] - best[i]) * diff;
187 } else {
188 rhs += (center[i] - radius[i] - best[i]) * diff;
189 }
190 }
191
192 lhs >= 2f64 * rhs
193 }
194
195 fn build_node<T: Number, M: Array2<T>>(&mut self, data: &M, begin: usize, end: usize) -> usize {
196 let (_, d) = data.shape();
197
198 let mut node = BBDTreeNode::new(d);
199
200 node.count = end - begin;
201 node.index = begin;
202
203 let mut lower_bound = vec![0f64; d];
204 let mut upper_bound = vec![0f64; d];
205
206 for i in 0..d {
207 lower_bound[i] = data.get((self.index[begin], i)).to_f64().unwrap();
208 upper_bound[i] = data.get((self.index[begin], i)).to_f64().unwrap();
209 }
210
211 for i in begin..end {
212 for j in 0..d {
213 let c = data.get((self.index[i], j)).to_f64().unwrap();
214 if lower_bound[j] > c {
215 lower_bound[j] = c;
216 }
217 if upper_bound[j] < c {
218 upper_bound[j] = c;
219 }
220 }
221 }
222
223 let mut max_radius = -1f64;
224 let mut split_index = 0;
225 for i in 0..d {
226 node.center[i] = (lower_bound[i] + upper_bound[i]) / 2f64;
227 node.radius[i] = (upper_bound[i] - lower_bound[i]) / 2f64;
228 if node.radius[i] > max_radius {
229 max_radius = node.radius[i];
230 split_index = i;
231 }
232 }
233
234 if max_radius < 1E-10 {
235 node.lower = Option::None;
236 node.upper = Option::None;
237 for i in 0..d {
238 node.sum[i] = data.get((self.index[begin], i)).to_f64().unwrap();
239 }
240
241 if end > begin + 1 {
242 let len = end - begin;
243 for i in 0..d {
244 node.sum[i] *= len as f64;
245 }
246 }
247
248 node.cost = 0f64;
249 return self.add_node(node);
250 }
251
252 let split_cutoff = node.center[split_index];
253 let mut i1 = begin;
254 let mut i2 = end - 1;
255 let mut size = 0;
256 while i1 <= i2 {
257 let mut i1_good =
258 data.get((self.index[i1], split_index)).to_f64().unwrap() < split_cutoff;
259 let mut i2_good =
260 data.get((self.index[i2], split_index)).to_f64().unwrap() >= split_cutoff;
261
262 if !i1_good && !i2_good {
263 self.index.swap(i1, i2);
264 i1_good = true;
265 i2_good = true;
266 }
267
268 if i1_good {
269 i1 += 1;
270 size += 1;
271 }
272
273 if i2_good {
274 i2 -= 1;
275 }
276 }
277
278 node.lower = Option::Some(self.build_node(data, begin, begin + size));
279 node.upper = Option::Some(self.build_node(data, begin + size, end));
280
281 for i in 0..d {
282 node.sum[i] =
283 self.nodes[node.lower.unwrap()].sum[i] + self.nodes[node.upper.unwrap()].sum[i];
284 }
285
286 let mut mean = vec![0f64; d];
287 for (i, mean_i) in mean.iter_mut().enumerate().take(d) {
288 *mean_i = node.sum[i] / node.count as f64;
289 }
290
291 node.cost = BBDTree::node_cost(&self.nodes[node.lower.unwrap()], &mean)
292 + BBDTree::node_cost(&self.nodes[node.upper.unwrap()], &mean);
293
294 self.add_node(node)
295 }
296
297 fn node_cost(node: &BBDTreeNode, center: &[f64]) -> f64 {
298 let d = center.len();
299 let mut scatter = 0f64;
300 for (i, center_i) in center.iter().enumerate().take(d) {
301 let x = (node.sum[i] / node.count as f64) - *center_i;
302 scatter += x * x;
303 }
304 node.cost + node.count as f64 * scatter
305 }
306
307 fn add_node(&mut self, new_node: BBDTreeNode) -> usize {
308 let idx = self.nodes.len();
309 self.nodes.push(new_node);
310 idx
311 }
312}
313
#[cfg(test)]
mod tests {
    use super::*;
    use crate::linalg::basic::matrix::DenseMatrix;

    /// Builds a tree over 20 four-dimensional rows and checks one clustering pass
    /// against two fixed centroids.
    #[cfg_attr(
        all(target_arch = "wasm32", not(target_os = "wasi")),
        wasm_bindgen_test::wasm_bindgen_test
    )]
    #[test]
    fn bbdtree_iris() {
        let data = DenseMatrix::from_2d_array(&[
            &[5.1, 3.5, 1.4, 0.2],
            &[4.9, 3.0, 1.4, 0.2],
            &[4.7, 3.2, 1.3, 0.2],
            &[4.6, 3.1, 1.5, 0.2],
            &[5.0, 3.6, 1.4, 0.2],
            &[5.4, 3.9, 1.7, 0.4],
            &[4.6, 3.4, 1.4, 0.3],
            &[5.0, 3.4, 1.5, 0.2],
            &[4.4, 2.9, 1.4, 0.2],
            &[4.9, 3.1, 1.5, 0.1],
            &[7.0, 3.2, 4.7, 1.4],
            &[6.4, 3.2, 4.5, 1.5],
            &[6.9, 3.1, 4.9, 1.5],
            &[5.5, 2.3, 4.0, 1.3],
            &[6.5, 2.8, 4.6, 1.5],
            &[5.7, 2.8, 4.5, 1.3],
            &[6.3, 3.3, 4.7, 1.6],
            &[4.9, 2.4, 3.3, 1.0],
            &[6.6, 2.9, 4.6, 1.3],
            &[5.2, 2.7, 3.9, 1.4],
        ])
        .unwrap();

        let tree = BBDTree::new(&data);

        let centroids = vec![vec![4.86, 3.22, 1.61, 0.29], vec![6.23, 2.92, 4.48, 1.42]];
        let mut sums = vec![vec![0f64; 4]; 2];
        let mut counts = vec![11, 9];
        // Rows 0..10 and row 17 start in cluster 0, all other rows in cluster 1.
        let mut membership: Vec<usize> = (0..20usize)
            .map(|row| usize::from(row >= 10 && row != 17))
            .collect();

        let dist = tree.clustering(&centroids, &mut sums, &mut counts, &mut membership);

        let close = |actual: f64, expected: f64| (actual - expected).abs() < 1e-2;
        assert!(close(dist, 10.68));
        assert!(close(sums[0][0], 48.6));
        assert!(close(sums[1][3], 13.8));
        assert_eq!(membership[17], 1);
    }
}