Skip to main content

sci_form/
clustering.rs

//! Butina (Taylor-Butina) clustering for conformer ensembles.
//!
//! Groups conformers by RMSD similarity using greedy distance-based clustering.
4
5use crate::alignment::kabsch;
6use serde::{Deserialize, Serialize};
7
8/// Result of Butina clustering on a conformer ensemble.
9#[derive(Debug, Clone, Serialize, Deserialize)]
10pub struct ClusterResult {
11    /// Number of clusters found.
12    pub n_clusters: usize,
13    /// Cluster assignment for each conformer (0-indexed cluster ID).
14    pub assignments: Vec<usize>,
15    /// Indices of cluster centroids (representative conformers).
16    pub centroid_indices: Vec<usize>,
17    /// Number of members in each cluster.
18    pub cluster_sizes: Vec<usize>,
19    /// RMSD cutoff used.
20    pub rmsd_cutoff: f64,
21}
22
23/// Compute the all-pairs RMSD matrix for a set of conformers.
24///
25/// `conformers`: slice of flat coordinate arrays `[x0,y0,z0, x1,y1,z1, ...]`.
26/// All conformers must have the same number of atoms.
27pub fn compute_rmsd_matrix(conformers: &[Vec<f64>]) -> Vec<Vec<f64>> {
28    let n = conformers.len();
29    let mut matrix = vec![vec![0.0f64; n]; n];
30
31    for i in 0..n {
32        for j in (i + 1)..n {
33            let rmsd = kabsch::compute_rmsd(&conformers[i], &conformers[j]);
34            matrix[i][j] = rmsd;
35            matrix[j][i] = rmsd;
36        }
37    }
38
39    matrix
40}
41
/// Rayon-backed variant of the all-pairs RMSD matrix.
///
/// Every `(i, j)` pair with `i < j` is an independent Kabsch RMSD
/// evaluation, so the pair list is handed to a parallel iterator. The
/// collected values are then written into both halves of the symmetric
/// matrix on the calling thread; diagonal entries stay at `0.0`.
#[cfg(feature = "parallel")]
pub fn compute_rmsd_matrix_parallel(conformers: &[Vec<f64>]) -> Vec<Vec<f64>> {
    use rayon::prelude::*;

    let count = conformers.len();

    // Enumerate the strict upper triangle up front so rayon can split it.
    let mut index_pairs: Vec<(usize, usize)> = Vec::new();
    for row in 0..count {
        for col in (row + 1)..count {
            index_pairs.push((row, col));
        }
    }

    let computed: Vec<(usize, usize, f64)> = index_pairs
        .into_par_iter()
        .map(|(row, col)| {
            (
                row,
                col,
                kabsch::compute_rmsd(&conformers[row], &conformers[col]),
            )
        })
        .collect();

    let mut distances = vec![vec![0.0f64; count]; count];
    for &(row, col, value) in &computed {
        distances[row][col] = value;
        distances[col][row] = value;
    }

    distances
}
72
73/// Perform Butina (Taylor-Butina) clustering on a set of conformers.
74///
75/// Algorithm:
76/// 1. Compute the all-pairs RMSD matrix after Kabsch alignment
77/// 2. For each conformer, count neighbors within `rmsd_cutoff`
78/// 3. Select the conformer with the most neighbors as a cluster centroid
79/// 4. Remove all its neighbors from the pool; repeat until empty
80///
81/// # Arguments
82/// - `conformers`: slice of flat coordinate arrays
83/// - `rmsd_cutoff`: RMSD threshold for clustering (Å), typically 1.0
84///
85/// # Returns
86/// `ClusterResult` with cluster assignments, centroids, and sizes.
87pub fn butina_cluster(conformers: &[Vec<f64>], rmsd_cutoff: f64) -> ClusterResult {
88    let n = conformers.len();
89
90    if n == 0 {
91        return ClusterResult {
92            n_clusters: 0,
93            assignments: vec![],
94            centroid_indices: vec![],
95            cluster_sizes: vec![],
96            rmsd_cutoff,
97        };
98    }
99
100    if n == 1 {
101        return ClusterResult {
102            n_clusters: 1,
103            assignments: vec![0],
104            centroid_indices: vec![0],
105            cluster_sizes: vec![1],
106            rmsd_cutoff,
107        };
108    }
109
110    // 1. Compute RMSD matrix (auto-dispatch to parallel when available)
111    #[cfg(feature = "parallel")]
112    let rmsd_matrix = compute_rmsd_matrix_parallel(conformers);
113    #[cfg(not(feature = "parallel"))]
114    let rmsd_matrix = compute_rmsd_matrix(conformers);
115
116    // 2. Build neighbor lists
117    let mut neighbor_counts: Vec<(usize, usize)> = (0..n)
118        .map(|i| {
119            let count = (0..n)
120                .filter(|&j| j != i && rmsd_matrix[i][j] <= rmsd_cutoff)
121                .count();
122            (i, count)
123        })
124        .collect();
125
126    // 3. Greedy clustering
127    let mut assignments = vec![usize::MAX; n];
128    let mut centroid_indices = Vec::new();
129    let mut cluster_sizes = Vec::new();
130    let mut assigned = vec![false; n];
131
132    loop {
133        // Sort by neighbor count (descending)
134        neighbor_counts.sort_by(|a, b| b.1.cmp(&a.1));
135
136        // Find the unassigned conformer with the most neighbors
137        let centroid = neighbor_counts
138            .iter()
139            .find(|(idx, _)| !assigned[*idx])
140            .map(|(idx, _)| *idx);
141
142        let centroid = match centroid {
143            Some(c) => c,
144            None => break,
145        };
146
147        let cluster_id = centroid_indices.len();
148        centroid_indices.push(centroid);
149        assigned[centroid] = true;
150        assignments[centroid] = cluster_id;
151        let mut size = 1;
152
153        // Assign all unassigned neighbors to this cluster
154        for j in 0..n {
155            if !assigned[j] && rmsd_matrix[centroid][j] <= rmsd_cutoff {
156                assigned[j] = true;
157                assignments[j] = cluster_id;
158                size += 1;
159            }
160        }
161
162        cluster_sizes.push(size);
163
164        // Update neighbor counts for remaining unassigned
165        for entry in &mut neighbor_counts {
166            if assigned[entry.0] {
167                entry.1 = 0;
168            } else {
169                entry.1 = (0..n)
170                    .filter(|&j| {
171                        !assigned[j] && j != entry.0 && rmsd_matrix[entry.0][j] <= rmsd_cutoff
172                    })
173                    .count();
174            }
175        }
176
177        if assigned.iter().all(|&a| a) {
178            break;
179        }
180    }
181
182    ClusterResult {
183        n_clusters: centroid_indices.len(),
184        assignments,
185        centroid_indices,
186        cluster_sizes,
187        rmsd_cutoff,
188    }
189}
190
191/// Filter a conformer ensemble to keep only cluster centroids.
192///
193/// Returns the indices of representative conformers.
194pub fn filter_diverse_conformers(conformers: &[Vec<f64>], rmsd_cutoff: f64) -> Vec<usize> {
195    let result = butina_cluster(conformers, rmsd_cutoff);
196    result.centroid_indices
197}
198
#[cfg(test)]
mod tests {
    use super::*;

    /// A lone conformer forms exactly one cluster and is assigned to it.
    #[test]
    fn test_single_conformer() {
        let only = vec![0.0, 0.0, 0.0, 1.0, 0.0, 0.0];
        let outcome = butina_cluster(&[only], 1.0);
        assert_eq!(outcome.n_clusters, 1);
        assert_eq!(outcome.assignments, vec![0]);
    }

    /// Three copies of the same geometry collapse into a single cluster.
    #[test]
    fn test_identical_conformers() {
        let ensemble = vec![vec![0.0, 0.0, 0.0, 1.0, 0.0, 0.0]; 3];
        let outcome = butina_cluster(&ensemble, 0.5);
        assert_eq!(outcome.n_clusters, 1);
        assert_eq!(outcome.cluster_sizes, vec![3]);
    }

    /// Two geometries with a large RMSD end up in separate clusters.
    #[test]
    fn test_distinct_conformers() {
        let compact = vec![0.0, 0.0, 0.0, 1.0, 0.0, 0.0];
        let stretched = vec![0.0, 0.0, 0.0, 100.0, 0.0, 0.0];
        let outcome = butina_cluster(&[compact, stretched], 0.5);
        assert_eq!(outcome.n_clusters, 2);
    }

    /// No conformers in, no clusters out.
    #[test]
    fn test_empty_input() {
        let nothing: Vec<Vec<f64>> = Vec::new();
        let outcome = butina_cluster(&nothing, 1.0);
        assert_eq!(outcome.n_clusters, 0);
    }

    /// The serial RMSD matrix has a zero diagonal and is symmetric.
    #[test]
    fn test_rmsd_matrix_symmetry() {
        // Two-atom conformers that differ only in the second atom's x coordinate.
        let ensemble: Vec<Vec<f64>> = [1.0, 1.5, 2.0]
            .iter()
            .map(|&x| vec![0.0, 0.0, 0.0, x, 0.0, 0.0])
            .collect();
        let matrix = compute_rmsd_matrix(&ensemble);

        for (i, row) in matrix.iter().enumerate() {
            assert!(row[i].abs() < 1e-10, "diagonal must be 0");
            for (j, &value) in row.iter().enumerate() {
                assert!(
                    (value - matrix[j][i]).abs() < 1e-10,
                    "matrix must be symmetric"
                );
            }
        }
    }
}