Skip to main content

sci_form/
clustering.rs

//! Butina (Taylor-Butina) clustering for conformer ensembles.
//!
//! Groups conformers by RMSD similarity using greedy distance-based clustering.
4
5use crate::alignment::kabsch;
6use serde::{Deserialize, Serialize};
7
8/// Result of Butina clustering on a conformer ensemble.
9#[derive(Debug, Clone, Serialize, Deserialize)]
10pub struct ClusterResult {
11    /// Number of clusters found.
12    pub n_clusters: usize,
13    /// Cluster assignment for each conformer (0-indexed cluster ID).
14    pub assignments: Vec<usize>,
15    /// Indices of cluster centroids (representative conformers).
16    pub centroid_indices: Vec<usize>,
17    /// Number of members in each cluster.
18    pub cluster_sizes: Vec<usize>,
19    /// RMSD cutoff used.
20    pub rmsd_cutoff: f64,
21}
22
23/// Compute the all-pairs RMSD matrix for a set of conformers.
24///
25/// `conformers`: slice of flat coordinate arrays `[x0,y0,z0, x1,y1,z1, ...]`.
26/// All conformers must have the same number of atoms.
27pub fn compute_rmsd_matrix(conformers: &[Vec<f64>]) -> Vec<Vec<f64>> {
28    let n = conformers.len();
29    let mut matrix = vec![vec![0.0f64; n]; n];
30
31    for i in 0..n {
32        for j in (i + 1)..n {
33            let rmsd = kabsch::compute_rmsd(&conformers[i], &conformers[j]);
34            matrix[i][j] = rmsd;
35            matrix[j][i] = rmsd;
36        }
37    }
38
39    matrix
40}
41
42/// Perform Butina (Taylor-Butina) clustering on a set of conformers.
43///
44/// Algorithm:
45/// 1. Compute the all-pairs RMSD matrix after Kabsch alignment
46/// 2. For each conformer, count neighbors within `rmsd_cutoff`
47/// 3. Select the conformer with the most neighbors as a cluster centroid
48/// 4. Remove all its neighbors from the pool; repeat until empty
49///
50/// # Arguments
51/// - `conformers`: slice of flat coordinate arrays
52/// - `rmsd_cutoff`: RMSD threshold for clustering (Å), typically 1.0
53///
54/// # Returns
55/// `ClusterResult` with cluster assignments, centroids, and sizes.
56pub fn butina_cluster(conformers: &[Vec<f64>], rmsd_cutoff: f64) -> ClusterResult {
57    let n = conformers.len();
58
59    if n == 0 {
60        return ClusterResult {
61            n_clusters: 0,
62            assignments: vec![],
63            centroid_indices: vec![],
64            cluster_sizes: vec![],
65            rmsd_cutoff,
66        };
67    }
68
69    if n == 1 {
70        return ClusterResult {
71            n_clusters: 1,
72            assignments: vec![0],
73            centroid_indices: vec![0],
74            cluster_sizes: vec![1],
75            rmsd_cutoff,
76        };
77    }
78
79    // 1. Compute RMSD matrix
80    let rmsd_matrix = compute_rmsd_matrix(conformers);
81
82    // 2. Build neighbor lists
83    let mut neighbor_counts: Vec<(usize, usize)> = (0..n)
84        .map(|i| {
85            let count = (0..n)
86                .filter(|&j| j != i && rmsd_matrix[i][j] <= rmsd_cutoff)
87                .count();
88            (i, count)
89        })
90        .collect();
91
92    // 3. Greedy clustering
93    let mut assignments = vec![usize::MAX; n];
94    let mut centroid_indices = Vec::new();
95    let mut cluster_sizes = Vec::new();
96    let mut assigned = vec![false; n];
97
98    loop {
99        // Sort by neighbor count (descending)
100        neighbor_counts.sort_by(|a, b| b.1.cmp(&a.1));
101
102        // Find the unassigned conformer with the most neighbors
103        let centroid = neighbor_counts
104            .iter()
105            .find(|(idx, _)| !assigned[*idx])
106            .map(|(idx, _)| *idx);
107
108        let centroid = match centroid {
109            Some(c) => c,
110            None => break,
111        };
112
113        let cluster_id = centroid_indices.len();
114        centroid_indices.push(centroid);
115        assigned[centroid] = true;
116        assignments[centroid] = cluster_id;
117        let mut size = 1;
118
119        // Assign all unassigned neighbors to this cluster
120        for j in 0..n {
121            if !assigned[j] && rmsd_matrix[centroid][j] <= rmsd_cutoff {
122                assigned[j] = true;
123                assignments[j] = cluster_id;
124                size += 1;
125            }
126        }
127
128        cluster_sizes.push(size);
129
130        // Update neighbor counts for remaining unassigned
131        for entry in &mut neighbor_counts {
132            if assigned[entry.0] {
133                entry.1 = 0;
134            } else {
135                entry.1 = (0..n)
136                    .filter(|&j| {
137                        !assigned[j] && j != entry.0 && rmsd_matrix[entry.0][j] <= rmsd_cutoff
138                    })
139                    .count();
140            }
141        }
142
143        if assigned.iter().all(|&a| a) {
144            break;
145        }
146    }
147
148    ClusterResult {
149        n_clusters: centroid_indices.len(),
150        assignments,
151        centroid_indices,
152        cluster_sizes,
153        rmsd_cutoff,
154    }
155}
156
157/// Filter a conformer ensemble to keep only cluster centroids.
158///
159/// Returns the indices of representative conformers.
160pub fn filter_diverse_conformers(conformers: &[Vec<f64>], rmsd_cutoff: f64) -> Vec<usize> {
161    let result = butina_cluster(conformers, rmsd_cutoff);
162    result.centroid_indices
163}
164
#[cfg(test)]
mod tests {
    use super::*;

    /// Two-atom conformer with the second atom at `x = bond` on the x axis.
    fn diatomic(bond: f64) -> Vec<f64> {
        vec![0.0, 0.0, 0.0, bond, 0.0, 0.0]
    }

    #[test]
    fn test_single_conformer() {
        let conformers = vec![vec![0.0, 0.0, 0.0, 1.0, 0.0, 0.0]];
        let result = butina_cluster(&conformers, 1.0);
        assert_eq!(result.n_clusters, 1);
        assert_eq!(result.assignments, vec![0]);
    }

    #[test]
    fn test_identical_conformers() {
        let coords = vec![0.0, 0.0, 0.0, 1.0, 0.0, 0.0];
        let conformers = vec![coords.clone(), coords.clone(), coords.clone()];
        let result = butina_cluster(&conformers, 0.5);
        // All identical → 1 cluster
        assert_eq!(result.n_clusters, 1);
        assert_eq!(result.cluster_sizes, vec![3]);
    }

    #[test]
    fn test_distinct_conformers() {
        // Two very different conformers (large RMSD)
        let c1 = vec![0.0, 0.0, 0.0, 1.0, 0.0, 0.0];
        let c2 = vec![0.0, 0.0, 0.0, 100.0, 0.0, 0.0];
        let conformers = vec![c1, c2];
        let result = butina_cluster(&conformers, 0.5);
        assert_eq!(result.n_clusters, 2);
    }

    #[test]
    fn test_empty_input() {
        let conformers: Vec<Vec<f64>> = vec![];
        let result = butina_cluster(&conformers, 1.0);
        assert_eq!(result.n_clusters, 0);
    }

    #[test]
    fn test_rmsd_matrix_symmetry() {
        let c1 = vec![0.0, 0.0, 0.0, 1.0, 0.0, 0.0];
        let c2 = vec![0.0, 0.0, 0.0, 1.5, 0.0, 0.0];
        let c3 = vec![0.0, 0.0, 0.0, 2.0, 0.0, 0.0];
        let conformers = vec![c1, c2, c3];
        let matrix = compute_rmsd_matrix(&conformers);

        for i in 0..3 {
            assert!((matrix[i][i]).abs() < 1e-10, "diagonal must be 0");
            for j in 0..3 {
                assert!(
                    (matrix[i][j] - matrix[j][i]).abs() < 1e-10,
                    "matrix must be symmetric"
                );
            }
        }
    }

    #[test]
    fn test_chain_picks_hub_as_centroid() {
        // Bond lengths 1.0 / 1.5 / 2.0: adjacent pairs differ less than the
        // outer pair. The 0.4 Å cutoff is meant to admit only adjacent pairs
        // (assumes the usual superposed RMSD — TODO confirm against kabsch).
        // The middle conformer then has the most neighbors and absorbs both.
        let conformers = vec![diatomic(1.0), diatomic(1.5), diatomic(2.0)];
        let result = butina_cluster(&conformers, 0.4);
        assert_eq!(result.n_clusters, 1);
        assert_eq!(result.centroid_indices, vec![1]);
        assert_eq!(result.cluster_sizes, vec![3]);
        assert_eq!(result.assignments, vec![0, 0, 0]);
    }

    #[test]
    fn test_result_invariants() {
        let conformers: Vec<Vec<f64>> = [1.0, 1.1, 3.0, 3.05, 10.0]
            .iter()
            .map(|&d| diatomic(d))
            .collect();
        let result = butina_cluster(&conformers, 0.2);

        assert_eq!(result.assignments.len(), conformers.len());
        assert_eq!(result.centroid_indices.len(), result.n_clusters);
        assert_eq!(result.cluster_sizes.len(), result.n_clusters);
        assert_eq!(result.cluster_sizes.iter().sum::<usize>(), conformers.len());
        assert!(result.assignments.iter().all(|&a| a < result.n_clusters));

        // Each centroid belongs to its own cluster.
        for (cluster_id, &centroid) in result.centroid_indices.iter().enumerate() {
            assert_eq!(result.assignments[centroid], cluster_id);
        }
        // Reported sizes match the assignment vector.
        for (cluster_id, &size) in result.cluster_sizes.iter().enumerate() {
            let members = result
                .assignments
                .iter()
                .filter(|&&a| a == cluster_id)
                .count();
            assert_eq!(members, size);
        }

        assert_eq!(
            filter_diverse_conformers(&conformers, 0.2),
            result.centroid_indices
        );
    }
}