1use crate::dataset::Dataset;
26use crate::distance::euclidean_sq;
27use crate::error::{Result, ScryLearnError};
28use std::cmp::Ordering;
29use std::collections::BinaryHeap;
30
/// Rule used to measure how far apart two clusters are when choosing the next merge.
///
/// The default is [`Linkage::Ward`].
#[derive(Debug, Default, Clone, Copy)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub enum Linkage {
    /// Smallest distance between any member of one cluster and any member of the other.
    Single,
    /// Largest distance between any member of one cluster and any member of the other.
    Complete,
    /// Mean of the distances over all cross-cluster member pairs.
    Average,
    /// Size-weighted squared distance between the cluster centroids, i.e. the
    /// growth in within-cluster variance a merge would cause.
    #[default]
    Ward,
}
57
/// One entry of the merge history recorded while fitting: which two clusters
/// were joined, at what distance, and how large the result is.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub struct MergeStep {
    /// Id of the first merged cluster. Ids below the sample count denote single
    /// samples; larger ids denote clusters created by earlier merges.
    pub cluster_a: usize,
    /// Id of the second merged cluster (same id scheme as `cluster_a`).
    pub cluster_b: usize,
    /// Linkage distance between the two clusters at the moment they were merged.
    pub distance: f64,
    /// Number of samples contained in the cluster produced by this merge.
    pub size: usize,
}
74
75#[derive(Clone, Debug)]
98#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
99#[non_exhaustive]
100pub struct AgglomerativeClustering {
101 n_clusters: usize,
102 linkage: Linkage,
103 labels: Vec<usize>,
104 children: Vec<MergeStep>,
105 fitted: bool,
106 #[cfg_attr(feature = "serde", serde(default))]
107 _schema_version: u32,
108}
109
110impl AgglomerativeClustering {
111 pub fn new(n_clusters: usize) -> Self {
117 Self {
118 n_clusters,
119 linkage: Linkage::Ward,
120 labels: Vec::new(),
121 children: Vec::new(),
122 fitted: false,
123 _schema_version: crate::version::SCHEMA_VERSION,
124 }
125 }
126
127 pub fn linkage(mut self, l: Linkage) -> Self {
129 self.linkage = l;
130 self
131 }
132
133 pub fn fit(&mut self, data: &Dataset) -> Result<()> {
139 data.validate_finite()?;
140 let n = data.n_samples();
141 if n == 0 {
142 return Err(ScryLearnError::EmptyDataset);
143 }
144 if self.n_clusters == 0 || self.n_clusters > n {
145 return Err(ScryLearnError::InvalidParameter(format!(
146 "n_clusters must be between 1 and n_samples ({}), got {}",
147 n, self.n_clusters
148 )));
149 }
150
151 let rows = data.feature_matrix();
152 let n_features = data.n_features();
153
154 let mut dist = vec![vec![0.0_f64; n]; n];
156 for i in 0..n {
157 for j in (i + 1)..n {
158 let d = euclidean_sq(&rows[i], &rows[j]);
159 dist[i][j] = d;
160 dist[j][i] = d;
161 }
162 }
163
164 let mut cluster_of = (0..n).collect::<Vec<usize>>();
167
168 let mut members: Vec<Vec<usize>> = (0..n).map(|i| vec![i]).collect();
170
171 let mut centroids: Vec<Vec<f64>> = rows.clone();
173
174 let mut heap: BinaryHeap<MergeCandidate> = BinaryHeap::new();
177
178 for i in 0..n {
180 for j in (i + 1)..n {
181 let d = self.linkage_distance(i, j, &dist, &members, ¢roids, n_features);
182 heap.push(MergeCandidate {
183 neg_dist: -d,
184 a: i,
185 b: j,
186 });
187 }
188 }
189
190 let mut active: Vec<bool> = vec![true; n];
191 let mut n_active = n;
192 let mut next_cluster_id = n; let mut children = Vec::new();
194
195 while n_active > self.n_clusters {
196 let merge = loop {
198 let Some(candidate) = heap.pop() else {
199 break None;
200 };
201 if active[candidate.a] && active[candidate.b] {
202 break Some(candidate);
203 }
204 };
205
206 let Some(merge) = merge else { break };
207
208 let ca = merge.a;
209 let cb = merge.b;
210 let merge_dist = -merge.neg_dist;
211
212 let new_id = next_cluster_id;
214 next_cluster_id += 1;
215
216 let mut new_members = std::mem::take(&mut members[ca]);
218 new_members.extend(std::mem::take(&mut members[cb]));
219 let new_size = new_members.len();
220
221 children.push(MergeStep {
222 cluster_a: ca,
223 cluster_b: cb,
224 distance: merge_dist.sqrt(),
225 size: new_size,
226 });
227
228 let new_centroid = if matches!(self.linkage, Linkage::Ward) {
230 let mut c = vec![0.0; n_features];
231 for &idx in &new_members {
232 for (j, &v) in rows[idx].iter().enumerate() {
233 c[j] += v;
234 }
235 }
236 for v in &mut c {
237 *v /= new_size as f64;
238 }
239 c
240 } else {
241 Vec::new()
242 };
243
244 active[ca] = false;
246 active[cb] = false;
247
248 while active.len() <= new_id {
250 active.push(false);
251 members.push(Vec::new());
252 centroids.push(Vec::new());
253 for row in &mut dist {
255 row.push(f64::INFINITY);
256 }
257 dist.push(vec![f64::INFINITY; dist[0].len()]);
258 }
259
260 active[new_id] = true;
261 members[new_id] = new_members;
262 centroids[new_id] = new_centroid;
263
264 for other in 0..active.len() {
266 if !active[other] || other == new_id {
267 continue;
268 }
269 let d = self.compute_merged_distance(
270 ca, cb, other, &dist, &members, ¢roids, n_features, &rows,
271 );
272 dist[new_id][other] = d;
273 dist[other][new_id] = d;
274 heap.push(MergeCandidate {
275 neg_dist: -d,
276 a: new_id.min(other),
277 b: new_id.max(other),
278 });
279 }
280
281 for &idx in &members[new_id] {
283 cluster_of[idx] = new_id;
284 }
285
286 n_active -= 1;
287 }
288
289 let active_ids: Vec<usize> = active
291 .iter()
292 .enumerate()
293 .filter(|(_, &a)| a)
294 .map(|(i, _)| i)
295 .collect();
296
297 let mut labels = vec![0usize; n];
298 for (label, &cid) in active_ids.iter().enumerate() {
299 for &sample in &members[cid] {
300 labels[sample] = label;
301 }
302 }
303
304 self.labels = labels;
305 self.children = children;
306 self.fitted = true;
307 Ok(())
308 }
309
310 fn linkage_distance(
312 &self,
313 a: usize,
314 b: usize,
315 dist: &[Vec<f64>],
316 members: &[Vec<usize>],
317 centroids: &[Vec<f64>],
318 _n_features: usize,
319 ) -> f64 {
320 match self.linkage {
321 Linkage::Single => {
322 let mut min_d = f64::INFINITY;
323 for &i in &members[a] {
324 for &j in &members[b] {
325 let d = dist[i][j];
326 if d < min_d {
327 min_d = d;
328 }
329 }
330 }
331 min_d
332 }
333 Linkage::Complete => {
334 let mut max_d = 0.0_f64;
335 for &i in &members[a] {
336 for &j in &members[b] {
337 let d = dist[i][j];
338 if d > max_d {
339 max_d = d;
340 }
341 }
342 }
343 max_d
344 }
345 Linkage::Average => {
346 let mut sum = 0.0;
347 let count = members[a].len() * members[b].len();
348 for &i in &members[a] {
349 for &j in &members[b] {
350 sum += dist[i][j];
351 }
352 }
353 if count > 0 {
354 sum / count as f64
355 } else {
356 0.0
357 }
358 }
359 Linkage::Ward => {
360 let sa = members[a].len() as f64;
362 let sb = members[b].len() as f64;
363 let d: f64 = centroids[a]
364 .iter()
365 .zip(centroids[b].iter())
366 .map(|(ca, cb)| (ca - cb).powi(2))
367 .sum();
368 sa * sb / (sa + sb) * d
369 }
370 }
371 }
372
373 #[allow(clippy::too_many_arguments)]
375 fn compute_merged_distance(
376 &self,
377 ca: usize,
378 cb: usize,
379 other: usize,
380 dist: &[Vec<f64>],
381 members: &[Vec<usize>],
382 _centroids: &[Vec<f64>],
383 _n_features: usize,
384 _rows: &[Vec<f64>],
385 ) -> f64 {
386 match self.linkage {
387 Linkage::Single => dist[ca][other].min(dist[cb][other]),
388 Linkage::Complete => dist[ca][other].max(dist[cb][other]),
389 Linkage::Average => {
390 let na = members[ca].len() as f64;
391 let nb = members[cb].len() as f64;
392 (na * dist[ca][other] + nb * dist[cb][other]) / (na + nb)
393 }
394 Linkage::Ward => {
395 let na = members[ca].len() as f64;
397 let nb = members[cb].len() as f64;
398 let nc = members[other].len() as f64;
399 let total = na + nb + nc;
400 ((na + nc) * dist[ca][other] + (nb + nc) * dist[cb][other] - nc * dist[ca][cb])
401 / total
402 }
403 }
404 }
405
406 pub fn labels(&self) -> &[usize] {
408 &self.labels
409 }
410
411 pub fn n_clusters(&self) -> usize {
413 self.n_clusters
414 }
415
416 pub fn children(&self) -> &[MergeStep] {
420 &self.children
421 }
422}
423
/// A pending merge of clusters `a` and `b`, ordered for use in a max-heap.
///
/// The key is the *negated* distance, so the heap's maximum is the closest
/// pair. Ordering and equality look at the key only; `a` and `b` do not
/// participate. A NaN key sorts below every number (and equal to any other
/// NaN), which keeps the order total — as `Ord`/`Eq` require — and makes such
/// candidates surface last.
#[derive(Clone, Copy, Debug)]
struct MergeCandidate {
    /// Negated linkage distance between the two clusters.
    neg_dist: f64,
    /// Id of the first cluster.
    a: usize,
    /// Id of the second cluster.
    b: usize,
}

impl PartialEq for MergeCandidate {
    fn eq(&self, other: &Self) -> bool {
        // Derived from `cmp` so that equality and ordering can never disagree.
        self.cmp(other) == Ordering::Equal
    }
}

impl Eq for MergeCandidate {}

impl PartialOrd for MergeCandidate {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for MergeCandidate {
    fn cmp(&self, other: &Self) -> Ordering {
        match self.neg_dist.partial_cmp(&other.neg_dist) {
            Some(ordering) => ordering,
            // At least one key is NaN: the NaN side is the smaller one, and
            // two NaNs compare equal.
            None => other.neg_dist.is_nan().cmp(&self.neg_dist.is_nan()),
        }
    }
}
453
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a two-feature dataset (columns `x` and `y`) whose targets are all zero.
    fn xy_dataset(xs: Vec<f64>, ys: Vec<f64>) -> Dataset {
        let targets = vec![0.0; xs.len()];
        Dataset::new(
            vec![xs, ys],
            targets,
            vec!["x".into(), "y".into()],
            "label",
        )
    }

    /// Three well-separated blobs of ten points each must come back as three
    /// internally consistent, pairwise distinct clusters.
    #[test]
    fn test_agglomerative_three_clusters() {
        let mut rng = crate::rng::FastRng::new(0);
        let mut xs = Vec::with_capacity(30);
        let mut ys = Vec::with_capacity(30);
        // Blobs sit near (0, 0), (50, 50) and (100, 100); x is drawn before y
        // for every point.
        for offset in [0.0, 50.0, 100.0] {
            for _ in 0..10 {
                xs.push(offset + rng.f64() * 2.0);
                ys.push(offset + rng.f64() * 2.0);
            }
        }

        let data = xy_dataset(xs, ys);

        let mut model = AgglomerativeClustering::new(3);
        model.fit(&data).unwrap();

        let labels = model.labels();
        assert_eq!(labels.len(), 30);

        let blobs = [
            (&labels[..10], "Cluster A inconsistent"),
            (&labels[10..20], "Cluster B inconsistent"),
            (&labels[20..], "Cluster C inconsistent"),
        ];
        for (blob, msg) in blobs {
            let expected = blob[0];
            assert!(blob.iter().all(|&l| l == expected), "{msg}");
        }

        let (label_a, label_b, label_c) = (labels[0], labels[10], labels[20]);
        assert_ne!(label_a, label_b);
        assert_ne!(label_a, label_c);
        assert_ne!(label_b, label_c);
    }

    /// Every linkage criterion fits without error and labels every sample.
    #[test]
    fn test_agglomerative_linkage_variants() {
        let data = xy_dataset(vec![0.0, 1.0, 5.0, 6.0], vec![0.0, 0.0, 0.0, 0.0]);

        let all_linkages = [
            Linkage::Single,
            Linkage::Complete,
            Linkage::Average,
            Linkage::Ward,
        ];
        for linkage in all_linkages {
            let mut model = AgglomerativeClustering::new(2).linkage(linkage);
            model.fit(&data).unwrap();
            assert_eq!(model.labels().len(), 4, "Failed for {linkage:?}");
        }
    }

    /// Ward and single linkage both label all six samples and record the four
    /// merges needed to go from six clusters to two.
    #[test]
    fn test_agglomerative_ward_vs_single() {
        let data = xy_dataset(
            vec![0.0, 1.0, 3.0, 10.0, 11.0, 13.0],
            vec![0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
        );

        let fit_with = |linkage: Linkage| {
            let mut model = AgglomerativeClustering::new(2).linkage(linkage);
            model.fit(&data).unwrap();
            model
        };
        let ward = fit_with(Linkage::Ward);
        let single = fit_with(Linkage::Single);

        assert_eq!(ward.labels().len(), 6);
        assert_eq!(single.labels().len(), 6);

        assert_eq!(ward.children().len(), 4);
        assert_eq!(single.children().len(), 4);
    }

    /// Asking for a single cluster puts every sample under label 0.
    #[test]
    fn test_agglomerative_single_cluster() {
        let data = xy_dataset(vec![0.0, 1.0, 2.0], vec![0.0, 1.0, 2.0]);

        let mut model = AgglomerativeClustering::new(1);
        model.fit(&data).unwrap();

        let has_other_label = model.labels().iter().any(|&l| l != 0);
        assert!(!has_other_label, "All should be cluster 0");
    }
}