1use crate::alignment::kabsch;
6use serde::{Deserialize, Serialize};
7
8#[derive(Debug, Clone, Serialize, Deserialize)]
10pub struct ClusterResult {
11 pub n_clusters: usize,
13 pub assignments: Vec<usize>,
15 pub centroid_indices: Vec<usize>,
17 pub cluster_sizes: Vec<usize>,
19 pub rmsd_cutoff: f64,
21}
22
23pub fn compute_rmsd_matrix(conformers: &[Vec<f64>]) -> Vec<Vec<f64>> {
28 let n = conformers.len();
29 let mut matrix = vec![vec![0.0f64; n]; n];
30
31 for i in 0..n {
32 for j in (i + 1)..n {
33 let rmsd = kabsch::compute_rmsd(&conformers[i], &conformers[j]);
34 matrix[i][j] = rmsd;
35 matrix[j][i] = rmsd;
36 }
37 }
38
39 matrix
40}
41
/// Variant of the pairwise RMSD matrix computation that evaluates the pairs
/// through rayon's parallel iterators.
///
/// Only compiled when the `parallel` feature is enabled. All index pairs
/// `(i, j)` with `i < j` are listed up front, each one is evaluated with
/// `kabsch::compute_rmsd` in a parallel `map`, and the results are then
/// written into a symmetric matrix on the calling thread. Diagonal entries
/// stay at `0.0`.
///
/// # Returns
/// A square `n x n` matrix where `n == conformers.len()`.
#[cfg(feature = "parallel")]
pub fn compute_rmsd_matrix_parallel(conformers: &[Vec<f64>]) -> Vec<Vec<f64>> {
    use rayon::prelude::*;

    let count = conformers.len();

    // Enumerate the strict upper triangle of the matrix.
    let mut index_pairs: Vec<(usize, usize)> = Vec::new();
    for row in 0..count {
        for col in (row + 1)..count {
            index_pairs.push((row, col));
        }
    }

    let distances: Vec<(usize, usize, f64)> = index_pairs
        .into_par_iter()
        .map(|(row, col)| {
            let value = kabsch::compute_rmsd(&conformers[row], &conformers[col]);
            (row, col, value)
        })
        .collect();

    // Scatter the computed values into both triangles.
    let mut matrix = vec![vec![0.0f64; count]; count];
    for &(row, col, value) in &distances {
        matrix[row][col] = value;
        matrix[col][row] = value;
    }

    matrix
}
72
73pub fn butina_cluster(conformers: &[Vec<f64>], rmsd_cutoff: f64) -> ClusterResult {
88 let n = conformers.len();
89
90 if n == 0 {
91 return ClusterResult {
92 n_clusters: 0,
93 assignments: vec![],
94 centroid_indices: vec![],
95 cluster_sizes: vec![],
96 rmsd_cutoff,
97 };
98 }
99
100 if n == 1 {
101 return ClusterResult {
102 n_clusters: 1,
103 assignments: vec![0],
104 centroid_indices: vec![0],
105 cluster_sizes: vec![1],
106 rmsd_cutoff,
107 };
108 }
109
110 #[cfg(feature = "parallel")]
112 let rmsd_matrix = compute_rmsd_matrix_parallel(conformers);
113 #[cfg(not(feature = "parallel"))]
114 let rmsd_matrix = compute_rmsd_matrix(conformers);
115
116 let mut neighbor_counts: Vec<(usize, usize)> = (0..n)
118 .map(|i| {
119 let count = (0..n)
120 .filter(|&j| j != i && rmsd_matrix[i][j] <= rmsd_cutoff)
121 .count();
122 (i, count)
123 })
124 .collect();
125
126 let mut assignments = vec![usize::MAX; n];
128 let mut centroid_indices = Vec::new();
129 let mut cluster_sizes = Vec::new();
130 let mut assigned = vec![false; n];
131
132 loop {
133 neighbor_counts.sort_by(|a, b| b.1.cmp(&a.1));
135
136 let centroid = neighbor_counts
138 .iter()
139 .find(|(idx, _)| !assigned[*idx])
140 .map(|(idx, _)| *idx);
141
142 let centroid = match centroid {
143 Some(c) => c,
144 None => break,
145 };
146
147 let cluster_id = centroid_indices.len();
148 centroid_indices.push(centroid);
149 assigned[centroid] = true;
150 assignments[centroid] = cluster_id;
151 let mut size = 1;
152
153 for j in 0..n {
155 if !assigned[j] && rmsd_matrix[centroid][j] <= rmsd_cutoff {
156 assigned[j] = true;
157 assignments[j] = cluster_id;
158 size += 1;
159 }
160 }
161
162 cluster_sizes.push(size);
163
164 for entry in &mut neighbor_counts {
166 if assigned[entry.0] {
167 entry.1 = 0;
168 } else {
169 entry.1 = (0..n)
170 .filter(|&j| {
171 !assigned[j] && j != entry.0 && rmsd_matrix[entry.0][j] <= rmsd_cutoff
172 })
173 .count();
174 }
175 }
176
177 if assigned.iter().all(|&a| a) {
178 break;
179 }
180 }
181
182 ClusterResult {
183 n_clusters: centroid_indices.len(),
184 assignments,
185 centroid_indices,
186 cluster_sizes,
187 rmsd_cutoff,
188 }
189}
190
191pub fn filter_diverse_conformers(conformers: &[Vec<f64>], rmsd_cutoff: f64) -> Vec<usize> {
195 let result = butina_cluster(conformers, rmsd_cutoff);
196 result.centroid_indices
197}
198
#[cfg(test)]
mod tests {
    use super::*;

    /// Six-value coordinate vector that is all zeros except the fourth value,
    /// which is set to `x`.
    fn coords_with_x(x: f64) -> Vec<f64> {
        vec![0.0, 0.0, 0.0, x, 0.0, 0.0]
    }

    #[test]
    fn test_single_conformer() {
        let input = vec![coords_with_x(1.0)];
        let clustered = butina_cluster(&input, 1.0);
        assert_eq!(clustered.n_clusters, 1);
        assert_eq!(clustered.assignments, vec![0]);
    }

    #[test]
    fn test_identical_conformers() {
        // Three copies of the same coordinates must collapse into one cluster.
        let input: Vec<Vec<f64>> = (0..3).map(|_| coords_with_x(1.0)).collect();
        let clustered = butina_cluster(&input, 0.5);
        assert_eq!(clustered.n_clusters, 1);
        assert_eq!(clustered.cluster_sizes, vec![3]);
    }

    #[test]
    fn test_distinct_conformers() {
        // Very different fourth values with a tight cutoff: expected to split.
        let input = vec![coords_with_x(1.0), coords_with_x(100.0)];
        let clustered = butina_cluster(&input, 0.5);
        assert_eq!(clustered.n_clusters, 2);
    }

    #[test]
    fn test_empty_input() {
        let input: Vec<Vec<f64>> = Vec::new();
        let clustered = butina_cluster(&input, 1.0);
        assert_eq!(clustered.n_clusters, 0);
    }

    #[test]
    fn test_rmsd_matrix_symmetry() {
        let input = vec![coords_with_x(1.0), coords_with_x(1.5), coords_with_x(2.0)];
        let matrix = compute_rmsd_matrix(&input);

        for row in 0..3 {
            assert!((matrix[row][row]).abs() < 1e-10, "diagonal must be 0");
            for col in 0..3 {
                let difference = matrix[row][col] - matrix[col][row];
                assert!(difference.abs() < 1e-10, "matrix must be symmetric");
            }
        }
    }
}