vecstore/
clustering.rs

//! Vector clustering algorithms
//!
//! This module provides clustering algorithms for grouping similar vectors:
//! - K-means: Partitioning into K clusters based on centroids (k-means++ seeding)
//! - DBSCAN: Density-based clustering with automatic cluster detection
//! - Hierarchical: Agglomerative clustering with single, complete, or average linkage
//!
//! # Features
//!
//! - Euclidean distance (SIMD-accelerated) used by every algorithm
//! - Automatic cluster number detection (for DBSCAN)
//! - Cluster quality metrics for K-means (silhouette score, inertia)
//! - Outlier detection (DBSCAN labels noise points with `-1`)
//! - Serializable results (`serde`) for downstream analysis or visualization
//!
//! # Example
//!
//! ```rust
//! use vecstore::clustering::{KMeansClustering, ClusteringConfig};
//!
//! let vectors = vec![
//!     vec![1.0, 2.0],
//!     vec![1.5, 1.8],
//!     vec![5.0, 8.0],
//!     vec![8.0, 8.0],
//! ];
//!
//! let config = ClusteringConfig {
//!     k: 2,
//!     max_iterations: 100,
//!     tolerance: 0.001,
//! };
//!
//! let kmeans = KMeansClustering::new(config);
//! let result = kmeans
//!     .fit(&vectors)
//!     .expect("4 vectors with k = 2 can always be clustered");
//!
//! for (i, label) in result.labels.iter().enumerate() {
//!     println!("Vector {} belongs to cluster {}", i, label);
//! }
//! ```
41
42use crate::error::{Result, VecStoreError};
43use crate::simd::euclidean_distance_simd;
44use serde::{Deserialize, Serialize};
45use std::collections::{HashMap, HashSet};
46
47// Helper function for euclidean distance
48fn euclidean_distance(a: &[f32], b: &[f32]) -> f32 {
49    euclidean_distance_simd(a, b)
50}
51
52/// Clustering configuration
53#[derive(Debug, Clone, Serialize, Deserialize)]
54pub struct ClusteringConfig {
55    /// Number of clusters
56    pub k: usize,
57    /// Maximum iterations
58    pub max_iterations: usize,
59    /// Convergence tolerance
60    pub tolerance: f32,
61}
62
63impl Default for ClusteringConfig {
64    fn default() -> Self {
65        Self {
66            k: 3,
67            max_iterations: 100,
68            tolerance: 0.001,
69        }
70    }
71}
72
73/// DBSCAN configuration
74#[derive(Debug, Clone, Serialize, Deserialize)]
75pub struct DBSCANConfig {
76    /// Epsilon (neighborhood radius)
77    pub eps: f32,
78    /// Minimum points to form a cluster
79    pub min_points: usize,
80}
81
82impl Default for DBSCANConfig {
83    fn default() -> Self {
84        Self {
85            eps: 0.5,
86            min_points: 5,
87        }
88    }
89}
90
91/// Hierarchical clustering configuration
92#[derive(Debug, Clone, Serialize, Deserialize)]
93pub struct HierarchicalConfig {
94    /// Number of clusters to form
95    pub n_clusters: usize,
96    /// Linkage method
97    pub linkage: LinkageMethod,
98}
99
100/// Linkage method for hierarchical clustering
101#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq)]
102pub enum LinkageMethod {
103    /// Single linkage (minimum distance)
104    Single,
105    /// Complete linkage (maximum distance)
106    Complete,
107    /// Average linkage
108    Average,
109}
110
111impl Default for HierarchicalConfig {
112    fn default() -> Self {
113        Self {
114            n_clusters: 3,
115            linkage: LinkageMethod::Average,
116        }
117    }
118}
119
120/// Clustering result
121#[derive(Debug, Clone, Serialize, Deserialize)]
122pub struct ClusteringResult {
123    /// Cluster labels for each vector (-1 for noise in DBSCAN)
124    pub labels: Vec<i32>,
125    /// Cluster centroids (for K-means)
126    pub centroids: Option<Vec<Vec<f32>>>,
127    /// Inertia (sum of squared distances to centroids)
128    pub inertia: f32,
129    /// Number of iterations (for K-means)
130    pub iterations: usize,
131    /// Silhouette score (cluster quality metric)
132    pub silhouette_score: Option<f32>,
133}
134
135/// K-means clustering
136pub struct KMeansClustering {
137    config: ClusteringConfig,
138}
139
140impl KMeansClustering {
141    /// Create new K-means clusterer
142    pub fn new(config: ClusteringConfig) -> Self {
143        Self { config }
144    }
145
146    /// Fit the model and return clustering result
147    pub fn fit(&self, vectors: &[Vec<f32>]) -> Result<ClusteringResult> {
148        if vectors.is_empty() {
149            return Err(VecStoreError::Other(
150                "Cannot cluster empty vector set".to_string(),
151            ));
152        }
153
154        if vectors.len() < self.config.k {
155            return Err(VecStoreError::Other(format!(
156                "Number of vectors ({}) must be >= k ({})",
157                vectors.len(),
158                self.config.k
159            )));
160        }
161
162        let dim = vectors[0].len();
163
164        // Initialize centroids using k-means++
165        let mut centroids = self.initialize_centroids_plus_plus(vectors);
166
167        let mut labels = vec![0; vectors.len()];
168        let mut iterations = 0;
169
170        for iter in 0..self.config.max_iterations {
171            iterations = iter + 1;
172
173            // Assign points to nearest centroid
174            let mut changed = false;
175            for (i, vector) in vectors.iter().enumerate() {
176                let nearest = self.find_nearest_centroid(vector, &centroids);
177                if labels[i] != nearest {
178                    labels[i] = nearest;
179                    changed = true;
180                }
181            }
182
183            if !changed {
184                break; // Converged
185            }
186
187            // Update centroids
188            let old_centroids = centroids.clone();
189            centroids = self.update_centroids(vectors, &labels, dim)?;
190
191            // Check convergence
192            let max_shift = centroids
193                .iter()
194                .zip(old_centroids.iter())
195                .map(|(new, old)| euclidean_distance(new, old))
196                .fold(0.0_f32, f32::max);
197
198            if max_shift < self.config.tolerance {
199                break; // Converged
200            }
201        }
202
203        // Calculate inertia
204        let inertia = self.calculate_inertia(vectors, &labels, &centroids);
205
206        // Calculate silhouette score
207        let silhouette_score = if vectors.len() > 1 {
208            Some(self.calculate_silhouette_score(vectors, &labels))
209        } else {
210            None
211        };
212
213        Ok(ClusteringResult {
214            labels: labels.iter().map(|&l| l as i32).collect(),
215            centroids: Some(centroids),
216            inertia,
217            iterations,
218            silhouette_score,
219        })
220    }
221
222    /// Initialize centroids using k-means++ algorithm
223    fn initialize_centroids_plus_plus(&self, vectors: &[Vec<f32>]) -> Vec<Vec<f32>> {
224        let mut centroids = Vec::with_capacity(self.config.k);
225        let mut rng = rand::thread_rng();
226        use rand::seq::SliceRandom;
227
228        // Choose first centroid randomly
229        centroids.push(vectors.choose(&mut rng).unwrap().clone());
230
231        // Choose remaining centroids
232        for _ in 1..self.config.k {
233            let distances: Vec<f32> = vectors
234                .iter()
235                .map(|v| {
236                    centroids
237                        .iter()
238                        .map(|c| euclidean_distance(v, c))
239                        .fold(f32::INFINITY, f32::min)
240                        .powi(2)
241                })
242                .collect();
243
244            let total: f32 = distances.iter().sum();
245            let mut threshold = rand::random::<f32>() * total;
246
247            for (i, &dist) in distances.iter().enumerate() {
248                threshold -= dist;
249                if threshold <= 0.0 {
250                    centroids.push(vectors[i].clone());
251                    break;
252                }
253            }
254        }
255
256        centroids
257    }
258
259    /// Find nearest centroid for a vector
260    fn find_nearest_centroid(&self, vector: &[f32], centroids: &[Vec<f32>]) -> usize {
261        centroids
262            .iter()
263            .enumerate()
264            .map(|(i, c)| (i, euclidean_distance(vector, c)))
265            .min_by(|a, b| a.1.partial_cmp(&b.1).unwrap())
266            .unwrap()
267            .0
268    }
269
270    /// Update centroids based on current assignment
271    fn update_centroids(
272        &self,
273        vectors: &[Vec<f32>],
274        labels: &[usize],
275        dim: usize,
276    ) -> Result<Vec<Vec<f32>>> {
277        let mut centroids = vec![vec![0.0; dim]; self.config.k];
278        let mut counts = vec![0; self.config.k];
279
280        for (vector, &label) in vectors.iter().zip(labels.iter()) {
281            for (i, &val) in vector.iter().enumerate() {
282                centroids[label][i] += val;
283            }
284            counts[label] += 1;
285        }
286
287        // Average
288        for (centroid, &count) in centroids.iter_mut().zip(counts.iter()) {
289            if count > 0 {
290                for val in centroid.iter_mut() {
291                    *val /= count as f32;
292                }
293            }
294        }
295
296        Ok(centroids)
297    }
298
299    /// Calculate inertia (within-cluster sum of squares)
300    fn calculate_inertia(
301        &self,
302        vectors: &[Vec<f32>],
303        labels: &[usize],
304        centroids: &[Vec<f32>],
305    ) -> f32 {
306        vectors
307            .iter()
308            .zip(labels.iter())
309            .map(|(v, &l)| euclidean_distance(v, &centroids[l]).powi(2))
310            .sum()
311    }
312
313    /// Calculate silhouette score
314    fn calculate_silhouette_score(&self, vectors: &[Vec<f32>], labels: &[usize]) -> f32 {
315        let n = vectors.len();
316        let mut scores = vec![0.0; n];
317
318        for i in 0..n {
319            let own_cluster = labels[i];
320
321            // Calculate a(i): average distance to points in same cluster
322            let same_cluster: Vec<usize> = labels
323                .iter()
324                .enumerate()
325                .filter(|(_, &l)| l == own_cluster)
326                .map(|(idx, _)| idx)
327                .collect();
328
329            let a = if same_cluster.len() > 1 {
330                same_cluster
331                    .iter()
332                    .filter(|&&idx| idx != i)
333                    .map(|&idx| euclidean_distance(&vectors[i], &vectors[idx]))
334                    .sum::<f32>()
335                    / (same_cluster.len() - 1) as f32
336            } else {
337                0.0
338            };
339
340            // Calculate b(i): min average distance to points in other clusters
341            let mut min_b = f32::INFINITY;
342            for cluster in 0..self.config.k {
343                if cluster == own_cluster {
344                    continue;
345                }
346
347                let other_cluster: Vec<usize> = labels
348                    .iter()
349                    .enumerate()
350                    .filter(|(_, &l)| l == cluster)
351                    .map(|(idx, _)| idx)
352                    .collect();
353
354                if !other_cluster.is_empty() {
355                    let b = other_cluster
356                        .iter()
357                        .map(|&idx| euclidean_distance(&vectors[i], &vectors[idx]))
358                        .sum::<f32>()
359                        / other_cluster.len() as f32;
360
361                    min_b = min_b.min(b);
362                }
363            }
364
365            scores[i] = if a < min_b {
366                1.0 - a / min_b
367            } else if a > min_b {
368                min_b / a - 1.0
369            } else {
370                0.0
371            };
372        }
373
374        scores.iter().sum::<f32>() / n as f32
375    }
376}
377
378/// DBSCAN clustering (Density-Based Spatial Clustering of Applications with Noise)
379pub struct DBSCANClustering {
380    config: DBSCANConfig,
381}
382
383impl DBSCANClustering {
384    /// Create new DBSCAN clusterer
385    pub fn new(config: DBSCANConfig) -> Self {
386        Self { config }
387    }
388
389    /// Fit the model and return clustering result
390    pub fn fit(&self, vectors: &[Vec<f32>]) -> Result<ClusteringResult> {
391        if vectors.is_empty() {
392            return Err(VecStoreError::Other(
393                "Cannot cluster empty vector set".to_string(),
394            ));
395        }
396
397        let n = vectors.len();
398        let mut labels = vec![-1; n]; // -1 = unvisited
399        let mut cluster_id = 0;
400
401        for i in 0..n {
402            if labels[i] != -1 {
403                continue; // Already visited
404            }
405
406            let neighbors = self.region_query(vectors, i);
407
408            if neighbors.len() < self.config.min_points {
409                labels[i] = -1; // Mark as noise
410                continue;
411            }
412
413            // Start new cluster
414            self.expand_cluster(vectors, i, neighbors, cluster_id, &mut labels);
415            cluster_id += 1;
416        }
417
418        // Calculate inertia (for noise points, use distance to nearest cluster)
419        let inertia = 0.0; // DBSCAN doesn't have centroids, so inertia is not meaningful
420
421        Ok(ClusteringResult {
422            labels,
423            centroids: None,
424            inertia,
425            iterations: 1, // DBSCAN doesn't iterate
426            silhouette_score: None,
427        })
428    }
429
430    /// Find neighbors within eps radius
431    fn region_query(&self, vectors: &[Vec<f32>], point_idx: usize) -> Vec<usize> {
432        vectors
433            .iter()
434            .enumerate()
435            .filter(|(i, v)| {
436                *i != point_idx && euclidean_distance(&vectors[point_idx], v) <= self.config.eps
437            })
438            .map(|(i, _)| i)
439            .collect()
440    }
441
442    /// Expand cluster from seed point
443    fn expand_cluster(
444        &self,
445        vectors: &[Vec<f32>],
446        seed_idx: usize,
447        mut neighbors: Vec<usize>,
448        cluster_id: i32,
449        labels: &mut [i32],
450    ) {
451        labels[seed_idx] = cluster_id;
452
453        let mut i = 0;
454        while i < neighbors.len() {
455            let neighbor_idx = neighbors[i];
456
457            if labels[neighbor_idx] == -1 {
458                // Was noise, add to cluster
459                labels[neighbor_idx] = cluster_id;
460            }
461
462            if labels[neighbor_idx] != -1 && labels[neighbor_idx] != cluster_id {
463                i += 1;
464                continue; // Already in another cluster
465            }
466
467            labels[neighbor_idx] = cluster_id;
468
469            let neighbor_neighbors = self.region_query(vectors, neighbor_idx);
470
471            if neighbor_neighbors.len() >= self.config.min_points {
472                // Add new neighbors to explore
473                for &nn in &neighbor_neighbors {
474                    if !neighbors.contains(&nn) {
475                        neighbors.push(nn);
476                    }
477                }
478            }
479
480            i += 1;
481        }
482    }
483}
484
485/// Hierarchical clustering
486pub struct HierarchicalClustering {
487    config: HierarchicalConfig,
488}
489
490impl HierarchicalClustering {
491    /// Create new hierarchical clusterer
492    pub fn new(config: HierarchicalConfig) -> Self {
493        Self { config }
494    }
495
496    /// Fit the model and return clustering result
497    pub fn fit(&self, vectors: &[Vec<f32>]) -> Result<ClusteringResult> {
498        if vectors.is_empty() {
499            return Err(VecStoreError::Other(
500                "Cannot cluster empty vector set".to_string(),
501            ));
502        }
503
504        if vectors.len() < self.config.n_clusters {
505            return Err(VecStoreError::Other(format!(
506                "Number of vectors ({}) must be >= n_clusters ({})",
507                vectors.len(),
508                self.config.n_clusters
509            )));
510        }
511
512        let n = vectors.len();
513
514        // Initialize: each point is its own cluster
515        let mut clusters: Vec<Vec<usize>> = (0..n).map(|i| vec![i]).collect();
516
517        // Build distance matrix
518        let mut distances = self.build_distance_matrix(vectors);
519
520        // Agglomerative clustering
521        while clusters.len() > self.config.n_clusters {
522            // Find closest pair of clusters
523            let (i, j) = self.find_closest_clusters(&clusters, &distances);
524
525            // Merge clusters i and j
526            let merged = self.merge_clusters(&mut clusters, i, j);
527
528            // Update distances
529            self.update_distances(&clusters, &merged, &mut distances, vectors);
530        }
531
532        // Convert cluster assignments to labels
533        let mut labels = vec![0; n];
534        for (cluster_id, cluster) in clusters.iter().enumerate() {
535            for &point_idx in cluster {
536                labels[point_idx] = cluster_id as i32;
537            }
538        }
539
540        Ok(ClusteringResult {
541            labels,
542            centroids: None,
543            inertia: 0.0,
544            iterations: n - self.config.n_clusters,
545            silhouette_score: None,
546        })
547    }
548
549    /// Build initial distance matrix
550    fn build_distance_matrix(&self, vectors: &[Vec<f32>]) -> Vec<Vec<f32>> {
551        let n = vectors.len();
552        let mut distances = vec![vec![0.0; n]; n];
553
554        for i in 0..n {
555            for j in (i + 1)..n {
556                let dist = euclidean_distance(&vectors[i], &vectors[j]);
557                distances[i][j] = dist;
558                distances[j][i] = dist;
559            }
560        }
561
562        distances
563    }
564
565    /// Find closest pair of clusters
566    fn find_closest_clusters(
567        &self,
568        clusters: &[Vec<usize>],
569        distances: &[Vec<f32>],
570    ) -> (usize, usize) {
571        let mut min_dist = f32::INFINITY;
572        let mut best_pair = (0, 0);
573
574        for i in 0..clusters.len() {
575            for j in (i + 1)..clusters.len() {
576                let dist = self.cluster_distance(&clusters[i], &clusters[j], distances);
577                if dist < min_dist {
578                    min_dist = dist;
579                    best_pair = (i, j);
580                }
581            }
582        }
583
584        best_pair
585    }
586
587    /// Calculate distance between two clusters based on linkage method
588    fn cluster_distance(
589        &self,
590        cluster1: &[usize],
591        cluster2: &[usize],
592        distances: &[Vec<f32>],
593    ) -> f32 {
594        match self.config.linkage {
595            LinkageMethod::Single => {
596                // Minimum distance
597                cluster1
598                    .iter()
599                    .flat_map(|&i| cluster2.iter().map(move |&j| distances[i][j]))
600                    .fold(f32::INFINITY, f32::min)
601            }
602            LinkageMethod::Complete => {
603                // Maximum distance
604                cluster1
605                    .iter()
606                    .flat_map(|&i| cluster2.iter().map(move |&j| distances[i][j]))
607                    .fold(0.0, f32::max)
608            }
609            LinkageMethod::Average => {
610                // Average distance
611                let sum: f32 = cluster1
612                    .iter()
613                    .flat_map(|&i| cluster2.iter().map(move |&j| distances[i][j]))
614                    .sum();
615                sum / (cluster1.len() * cluster2.len()) as f32
616            }
617        }
618    }
619
620    /// Merge two clusters
621    fn merge_clusters(&self, clusters: &mut Vec<Vec<usize>>, i: usize, j: usize) -> Vec<usize> {
622        let (smaller, larger) = if i < j { (i, j) } else { (j, i) };
623
624        let mut merged = clusters.remove(larger);
625        merged.extend(clusters.remove(smaller));
626
627        clusters.push(merged.clone());
628        merged
629    }
630
631    /// Update distance matrix after merge (placeholder - simplified)
632    fn update_distances(
633        &self,
634        _clusters: &[Vec<usize>],
635        _merged: &[usize],
636        _distances: &mut [Vec<f32>],
637        _vectors: &[Vec<f32>],
638    ) {
639        // In a full implementation, we would update the distance matrix
640        // For now, we recalculate on demand in cluster_distance
641    }
642}
643
#[cfg(test)]
mod tests {
    use super::*;

    /// Six 2-D points: a tight group near the origin (indices 0, 1, 4) and a
    /// looser group far from it (indices 2, 3, 5).
    fn two_groups() -> Vec<Vec<f32>> {
        vec![
            vec![1.0, 2.0],
            vec![1.5, 1.8],
            vec![5.0, 8.0],
            vec![8.0, 8.0],
            vec![1.0, 0.6],
            vec![9.0, 11.0],
        ]
    }

    #[test]
    fn test_kmeans_simple() -> Result<()> {
        let vectors = two_groups();

        let config = ClusteringConfig {
            k: 2,
            max_iterations: 100,
            tolerance: 0.001,
        };

        let kmeans = KMeansClustering::new(config);
        let result = kmeans.fit(&vectors)?;

        assert_eq!(result.labels.len(), 6);
        assert_eq!(result.centroids.as_ref().unwrap().len(), 2);

        // Members of each group share a label; the two groups do not.
        assert_eq!(result.labels[0], result.labels[1]);
        assert_eq!(result.labels[0], result.labels[4]);
        assert_eq!(result.labels[2], result.labels[3]);
        assert_eq!(result.labels[2], result.labels[5]);
        assert_ne!(result.labels[0], result.labels[2]);

        Ok(())
    }

    #[test]
    fn test_dbscan() -> Result<()> {
        let vectors = two_groups();

        let config = DBSCANConfig {
            eps: 2.0,
            min_points: 2,
        };

        let dbscan = DBSCANClustering::new(config);
        let result = dbscan.fit(&vectors)?;

        // The tight group forms cluster 0; the loose group's points are more
        // than `eps` apart from each other and are all noise.
        assert_eq!(result.labels, vec![0, 0, -1, -1, 0, -1]);
        assert!(result.centroids.is_none());

        Ok(())
    }

    #[test]
    fn test_dbscan_all_noise_when_density_unreachable() -> Result<()> {
        let config = DBSCANConfig {
            eps: 2.0,
            min_points: 10,
        };

        let result = DBSCANClustering::new(config).fit(&two_groups())?;
        assert!(result.labels.iter().all(|&l| l == -1));

        Ok(())
    }

    #[test]
    fn test_dbscan_noise_point_becomes_border() -> Result<()> {
        // Point 0 has a single neighbour, so it is first classified as noise.
        // Point 1 is a core point whose cluster must then absorb it.
        let vectors = vec![vec![0.0, 0.0], vec![1.0, 0.0], vec![2.0, 0.0]];
        let config = DBSCANConfig {
            eps: 1.0,
            min_points: 2,
        };

        let result = DBSCANClustering::new(config).fit(&vectors)?;
        assert_eq!(result.labels, vec![0, 0, 0]);

        Ok(())
    }

    #[test]
    fn test_hierarchical() -> Result<()> {
        let vectors = vec![
            vec![1.0, 2.0],
            vec![1.5, 1.8],
            vec![5.0, 8.0],
            vec![8.0, 8.0],
        ];

        for &linkage in &[
            LinkageMethod::Single,
            LinkageMethod::Complete,
            LinkageMethod::Average,
        ] {
            let config = HierarchicalConfig {
                n_clusters: 2,
                linkage,
            };

            let hierarchical = HierarchicalClustering::new(config);
            let result = hierarchical.fit(&vectors)?;

            assert_eq!(result.labels.len(), 4);
            assert_eq!(result.iterations, 2);

            // Check we have exactly 2 clusters, grouped as expected.
            let unique_labels: HashSet<_> = result.labels.iter().collect();
            assert_eq!(unique_labels.len(), 2);
            assert_eq!(result.labels[0], result.labels[1]);
            assert_eq!(result.labels[2], result.labels[3]);
            assert_ne!(result.labels[0], result.labels[2]);
        }

        Ok(())
    }

    #[test]
    fn test_silhouette_score() -> Result<()> {
        let vectors = vec![
            vec![1.0, 1.0],
            vec![1.5, 1.5],
            vec![10.0, 10.0],
            vec![10.5, 10.5],
        ];

        let config = ClusteringConfig {
            k: 2,
            max_iterations: 100,
            tolerance: 0.001,
        };

        let kmeans = KMeansClustering::new(config);
        let result = kmeans.fit(&vectors)?;

        assert_eq!(result.labels[0], result.labels[1]);
        assert_eq!(result.labels[2], result.labels[3]);

        // Each point lies sqrt(0.125) from its centroid: inertia = 4 * 0.125.
        assert!((result.inertia - 0.5).abs() < 1e-4);

        // Silhouette score should be high for well-separated clusters
        let score = result.silhouette_score.unwrap();
        assert!(score > 0.5 && score <= 1.0);

        Ok(())
    }
}