1use crate::alignment::kabsch;
6use serde::{Deserialize, Serialize};
7
/// Outcome of a clustering run, as populated by `butina_cluster`.
///
/// The per-cluster vectors (`centroid_indices`, `cluster_sizes`) are indexed by
/// cluster id; `assignments` is indexed by the conformer's position in the input.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ClusterResult {
    /// Number of clusters produced; `butina_cluster` sets this to
    /// `centroid_indices.len()`.
    pub n_clusters: usize,
    /// Cluster id assigned to each input conformer, in input order.
    pub assignments: Vec<usize>,
    /// Input index of the conformer chosen as centroid of each cluster, in the
    /// order the clusters were created.
    pub centroid_indices: Vec<usize>,
    /// Number of conformers in each cluster, counting the centroid itself.
    pub cluster_sizes: Vec<usize>,
    /// The cutoff value the clustering was run with (echoed from the caller).
    pub rmsd_cutoff: f64,
}
22
23pub fn compute_rmsd_matrix(conformers: &[Vec<f64>]) -> Vec<Vec<f64>> {
28 let n = conformers.len();
29 let mut matrix = vec![vec![0.0f64; n]; n];
30
31 for i in 0..n {
32 for j in (i + 1)..n {
33 let rmsd = kabsch::compute_rmsd(&conformers[i], &conformers[j]);
34 matrix[i][j] = rmsd;
35 matrix[j][i] = rmsd;
36 }
37 }
38
39 matrix
40}
41
42pub fn butina_cluster(conformers: &[Vec<f64>], rmsd_cutoff: f64) -> ClusterResult {
57 let n = conformers.len();
58
59 if n == 0 {
60 return ClusterResult {
61 n_clusters: 0,
62 assignments: vec![],
63 centroid_indices: vec![],
64 cluster_sizes: vec![],
65 rmsd_cutoff,
66 };
67 }
68
69 if n == 1 {
70 return ClusterResult {
71 n_clusters: 1,
72 assignments: vec![0],
73 centroid_indices: vec![0],
74 cluster_sizes: vec![1],
75 rmsd_cutoff,
76 };
77 }
78
79 let rmsd_matrix = compute_rmsd_matrix(conformers);
81
82 let mut neighbor_counts: Vec<(usize, usize)> = (0..n)
84 .map(|i| {
85 let count = (0..n)
86 .filter(|&j| j != i && rmsd_matrix[i][j] <= rmsd_cutoff)
87 .count();
88 (i, count)
89 })
90 .collect();
91
92 let mut assignments = vec![usize::MAX; n];
94 let mut centroid_indices = Vec::new();
95 let mut cluster_sizes = Vec::new();
96 let mut assigned = vec![false; n];
97
98 loop {
99 neighbor_counts.sort_by(|a, b| b.1.cmp(&a.1));
101
102 let centroid = neighbor_counts
104 .iter()
105 .find(|(idx, _)| !assigned[*idx])
106 .map(|(idx, _)| *idx);
107
108 let centroid = match centroid {
109 Some(c) => c,
110 None => break,
111 };
112
113 let cluster_id = centroid_indices.len();
114 centroid_indices.push(centroid);
115 assigned[centroid] = true;
116 assignments[centroid] = cluster_id;
117 let mut size = 1;
118
119 for j in 0..n {
121 if !assigned[j] && rmsd_matrix[centroid][j] <= rmsd_cutoff {
122 assigned[j] = true;
123 assignments[j] = cluster_id;
124 size += 1;
125 }
126 }
127
128 cluster_sizes.push(size);
129
130 for entry in &mut neighbor_counts {
132 if assigned[entry.0] {
133 entry.1 = 0;
134 } else {
135 entry.1 = (0..n)
136 .filter(|&j| {
137 !assigned[j] && j != entry.0 && rmsd_matrix[entry.0][j] <= rmsd_cutoff
138 })
139 .count();
140 }
141 }
142
143 if assigned.iter().all(|&a| a) {
144 break;
145 }
146 }
147
148 ClusterResult {
149 n_clusters: centroid_indices.len(),
150 assignments,
151 centroid_indices,
152 cluster_sizes,
153 rmsd_cutoff,
154 }
155}
156
157pub fn filter_diverse_conformers(conformers: &[Vec<f64>], rmsd_cutoff: f64) -> Vec<usize> {
161 let result = butina_cluster(conformers, rmsd_cutoff);
162 result.centroid_indices
163}
164
#[cfg(test)]
mod tests {
    use super::*;

    /// A lone conformer must form exactly one cluster containing itself.
    #[test]
    fn test_single_conformer() {
        let lone = vec![vec![0.0, 0.0, 0.0, 1.0, 0.0, 0.0]];
        let outcome = butina_cluster(&lone, 1.0);
        assert_eq!(outcome.n_clusters, 1);
        assert_eq!(outcome.assignments, vec![0]);
    }

    /// Three copies of the same coordinates are expected to share one cluster.
    #[test]
    fn test_identical_conformers() {
        let conformers = vec![vec![0.0, 0.0, 0.0, 1.0, 0.0, 0.0]; 3];
        let outcome = butina_cluster(&conformers, 0.5);
        assert_eq!(outcome.n_clusters, 1);
        assert_eq!(outcome.cluster_sizes, vec![3]);
    }

    /// Two very different geometries are expected to end up in separate clusters.
    #[test]
    fn test_distinct_conformers() {
        let short = vec![0.0, 0.0, 0.0, 1.0, 0.0, 0.0];
        let stretched = vec![0.0, 0.0, 0.0, 100.0, 0.0, 0.0];
        let outcome = butina_cluster(&[short, stretched], 0.5);
        assert_eq!(outcome.n_clusters, 2);
    }

    /// No conformers in, no clusters out.
    #[test]
    fn test_empty_input() {
        let nothing: Vec<Vec<f64>> = Vec::new();
        let outcome = butina_cluster(&nothing, 1.0);
        assert_eq!(outcome.n_clusters, 0);
    }

    /// The pairwise matrix must have a zero diagonal and be symmetric.
    #[test]
    fn test_rmsd_matrix_symmetry() {
        let conformers: Vec<Vec<f64>> = [1.0, 1.5, 2.0]
            .iter()
            .map(|&x| vec![0.0, 0.0, 0.0, x, 0.0, 0.0])
            .collect();
        let matrix = compute_rmsd_matrix(&conformers);

        let size = conformers.len();
        for i in 0..size {
            let on_diagonal = matrix[i][i];
            assert!(on_diagonal.abs() < 1e-10, "diagonal must be 0");
            for j in 0..size {
                let mismatch = (matrix[i][j] - matrix[j][i]).abs();
                assert!(mismatch < 1e-10, "matrix must be symmetric");
            }
        }
    }
}