/// Outcome of superimposing one coordinate set onto a reference set.
///
/// Coordinates are handled as flat `[x0, y0, z0, x1, y1, z1, ...]` lists.
/// `PartialEq` is derived so results can be compared directly (for example
/// with `assert_eq!`); note that a result whose `rmsd` is NaN never compares
/// equal, not even to itself.
#[derive(Debug, Clone, PartialEq)]
pub struct AlignmentResult {
    /// Root-mean-square deviation between the aligned coordinates and the
    /// reference. The alignment routines in this file set it to `f64::NAN`
    /// when they reject their input (e.g. mismatched slice lengths).
    pub rmsd: f64,
    /// 3x3 rotation matrix, indexed as `rotation[row][col]`.
    pub rotation: [[f64; 3]; 3],
    /// Reference centroid minus input centroid.
    ///
    /// NOTE(review): the alignment routines compute the aligned points as
    /// `rotation * (x - input_centroid) + reference_centroid`, so this vector
    /// is the centroid shift, not the `t` of a `rotation * x + t` transform.
    pub translation: [f64; 3],
    /// The input coordinates after the fit has been applied, in the same flat
    /// layout as the input.
    pub aligned_coords: Vec<f64>,
}
17
18pub fn compute_rmsd(coords: &[f64], reference: &[f64]) -> f64 {
24 align_coordinates(coords, reference).rmsd
25}
26
27pub fn align_coordinates(coords: &[f64], reference: &[f64]) -> AlignmentResult {
31 if coords.len() != reference.len() || !coords.len().is_multiple_of(3) {
32 return AlignmentResult {
33 rmsd: f64::NAN,
34 rotation: [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]],
35 translation: [0.0; 3],
36 aligned_coords: coords.to_vec(),
37 };
38 }
39 let n = coords.len() / 3;
40
41 if n == 0 {
42 return AlignmentResult {
43 rmsd: 0.0,
44 rotation: [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]],
45 translation: [0.0; 3],
46 aligned_coords: Vec::new(),
47 };
48 }
49
50 let mut c1 = [0.0f64; 3];
52 let mut c2 = [0.0f64; 3];
53 for i in 0..n {
54 for k in 0..3 {
55 c1[k] += coords[i * 3 + k];
56 c2[k] += reference[i * 3 + k];
57 }
58 }
59 for k in 0..3 {
60 c1[k] /= n as f64;
61 c2[k] /= n as f64;
62 }
63
64 let mut h = [[0.0f64; 3]; 3];
66 for i in 0..n {
67 let p = [
68 coords[i * 3] - c1[0],
69 coords[i * 3 + 1] - c1[1],
70 coords[i * 3 + 2] - c1[2],
71 ];
72 let q = [
73 reference[i * 3] - c2[0],
74 reference[i * 3 + 1] - c2[1],
75 reference[i * 3 + 2] - c2[2],
76 ];
77 for r in 0..3 {
78 for c in 0..3 {
79 h[r][c] += p[r] * q[c];
80 }
81 }
82 }
83
84 let h_mat = nalgebra::Matrix3::new(
86 h[0][0], h[0][1], h[0][2], h[1][0], h[1][1], h[1][2], h[2][0], h[2][1], h[2][2],
87 );
88 let svd = h_mat.svd(true, true);
89 let (u, v_t) = match (svd.u, svd.v_t) {
90 (Some(u), Some(v_t)) => (u, v_t),
91 _ => {
92 let mut sum_sq = 0.0;
94 for i in 0..coords.len() {
95 let diff = coords[i] - reference[i];
96 sum_sq += diff * diff;
97 }
98 return AlignmentResult {
99 rmsd: (sum_sq / n as f64).sqrt(),
100 rotation: [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]],
101 translation: [0.0; 3],
102 aligned_coords: coords.to_vec(),
103 };
104 }
105 };
106 let v = v_t.transpose();
107
108 let mut d = nalgebra::Matrix3::<f64>::identity();
113 if (v * u.transpose()).determinant() < 0.0 {
114 d[(2, 2)] = -1.0;
115 }
116 let r_mat = v * d * u.transpose();
117
118 let rotation = [
120 [r_mat[(0, 0)], r_mat[(0, 1)], r_mat[(0, 2)]],
121 [r_mat[(1, 0)], r_mat[(1, 1)], r_mat[(1, 2)]],
122 [r_mat[(2, 0)], r_mat[(2, 1)], r_mat[(2, 2)]],
123 ];
124
125 let translation = [c2[0] - c1[0], c2[1] - c1[1], c2[2] - c1[2]];
126
127 let mut aligned = vec![0.0f64; coords.len()];
129 let mut sum_sq = 0.0;
130 for i in 0..n {
131 let p = [
132 coords[i * 3] - c1[0],
133 coords[i * 3 + 1] - c1[1],
134 coords[i * 3 + 2] - c1[2],
135 ];
136 for k in 0..3 {
137 let rotated = r_mat[(k, 0)] * p[0] + r_mat[(k, 1)] * p[1] + r_mat[(k, 2)] * p[2];
138 aligned[i * 3 + k] = rotated + c2[k];
139 }
140 for k in 0..3 {
141 let diff = aligned[i * 3 + k] - reference[i * 3 + k];
142 sum_sq += diff * diff;
143 }
144 }
145 let rmsd = (sum_sq / n as f64).sqrt();
146
147 AlignmentResult {
148 rmsd,
149 rotation,
150 translation,
151 aligned_coords: aligned,
152 }
153}
154
155pub fn align_quaternion(coords: &[f64], reference: &[f64]) -> AlignmentResult {
161 if coords.len() != reference.len() || !coords.len().is_multiple_of(3) {
162 return AlignmentResult {
163 rmsd: f64::NAN,
164 rotation: [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]],
165 translation: [0.0; 3],
166 aligned_coords: coords.to_vec(),
167 };
168 }
169 let n = coords.len() / 3;
170
171 if n == 0 {
172 return AlignmentResult {
173 rmsd: 0.0,
174 rotation: [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]],
175 translation: [0.0; 3],
176 aligned_coords: Vec::new(),
177 };
178 }
179
180 let mut c1 = [0.0f64; 3];
182 let mut c2 = [0.0f64; 3];
183 for i in 0..n {
184 for k in 0..3 {
185 c1[k] += coords[i * 3 + k];
186 c2[k] += reference[i * 3 + k];
187 }
188 }
189 for k in 0..3 {
190 c1[k] /= n as f64;
191 c2[k] /= n as f64;
192 }
193
194 let mut r = [[0.0f64; 3]; 3];
196 for i in 0..n {
197 let p = [
198 coords[i * 3] - c1[0],
199 coords[i * 3 + 1] - c1[1],
200 coords[i * 3 + 2] - c1[2],
201 ];
202 let q = [
203 reference[i * 3] - c2[0],
204 reference[i * 3 + 1] - c2[1],
205 reference[i * 3 + 2] - c2[2],
206 ];
207 for a in 0..3 {
208 for b in 0..3 {
209 r[a][b] += p[a] * q[b];
210 }
211 }
212 }
213
214 let sxx = r[0][0];
220 let sxy = r[0][1];
221 let sxz = r[0][2];
222 let syx = r[1][0];
223 let syy = r[1][1];
224 let syz = r[1][2];
225 let szx = r[2][0];
226 let szy = r[2][1];
227 let szz = r[2][2];
228
229 let f = nalgebra::Matrix4::new(
230 sxx + syy + szz,
231 syz - szy,
232 szx - sxz,
233 sxy - syx,
234 syz - szy,
235 sxx - syy - szz,
236 sxy + syx,
237 szx + sxz,
238 szx - sxz,
239 sxy + syx,
240 -sxx + syy - szz,
241 syz + szy,
242 sxy - syx,
243 szx + sxz,
244 syz + szy,
245 -sxx - syy + szz,
246 );
247
248 let eig = f.symmetric_eigen();
250 let mut best_idx = 0;
251 let mut best_val = eig.eigenvalues[0];
252 for i in 1..4 {
253 if eig.eigenvalues[i] > best_val {
254 best_val = eig.eigenvalues[i];
255 best_idx = i;
256 }
257 }
258
259 let q0 = eig.eigenvectors[(0, best_idx)];
260 let q1 = eig.eigenvectors[(1, best_idx)];
261 let q2 = eig.eigenvectors[(2, best_idx)];
262 let q3 = eig.eigenvectors[(3, best_idx)];
263
264 let rotation = [
266 [
267 q0 * q0 + q1 * q1 - q2 * q2 - q3 * q3,
268 2.0 * (q1 * q2 - q0 * q3),
269 2.0 * (q1 * q3 + q0 * q2),
270 ],
271 [
272 2.0 * (q1 * q2 + q0 * q3),
273 q0 * q0 - q1 * q1 + q2 * q2 - q3 * q3,
274 2.0 * (q2 * q3 - q0 * q1),
275 ],
276 [
277 2.0 * (q1 * q3 - q0 * q2),
278 2.0 * (q2 * q3 + q0 * q1),
279 q0 * q0 - q1 * q1 - q2 * q2 + q3 * q3,
280 ],
281 ];
282
283 let translation = [c2[0] - c1[0], c2[1] - c1[1], c2[2] - c1[2]];
284
285 let mut aligned = vec![0.0f64; coords.len()];
287 let mut sum_sq = 0.0;
288 for i in 0..n {
289 let p = [
290 coords[i * 3] - c1[0],
291 coords[i * 3 + 1] - c1[1],
292 coords[i * 3 + 2] - c1[2],
293 ];
294 for k in 0..3 {
295 let rotated = rotation[k][0] * p[0] + rotation[k][1] * p[1] + rotation[k][2] * p[2];
296 aligned[i * 3 + k] = rotated + c2[k];
297 }
298 for k in 0..3 {
299 let diff = aligned[i * 3 + k] - reference[i * 3 + k];
300 sum_sq += diff * diff;
301 }
302 }
303 let rmsd = (sum_sq / n as f64).sqrt();
304
305 AlignmentResult {
306 rmsd,
307 rotation,
308 translation,
309 aligned_coords: aligned,
310 }
311}
312
#[cfg(test)]
mod tests {
    use super::*;

    /// Aligning a structure onto itself must give (numerically) zero RMSD.
    #[test]
    fn test_identical_zero_rmsd() {
        let coords = vec![0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0];
        let rmsd = compute_rmsd(&coords, &coords);
        assert!(rmsd < 1e-10);
    }

    /// A pure translation is removed entirely by the fit.
    #[test]
    fn test_translated_zero_rmsd() {
        let reference = vec![0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0];
        // Every coordinate shifted by +5.
        let coords: Vec<f64> = reference.iter().map(|x| x + 5.0).collect();
        let rmsd = compute_rmsd(&coords, &reference);
        assert!(rmsd < 1e-10, "got rmsd = {rmsd}");
    }

    /// A 90 degree rotation about the z axis is removed entirely by the fit.
    #[test]
    fn test_rotation_90deg_z() {
        let reference = vec![1.0, 0.0, 0.0, 0.0, 1.0, 0.0, -1.0, 0.0, 0.0];
        // Each reference point (x, y, z) mapped to (-y, x, z).
        let rotated = vec![0.0, 1.0, 0.0, -1.0, 0.0, 0.0, 0.0, -1.0, 0.0];
        let rmsd = compute_rmsd(&rotated, &reference);
        assert!(rmsd < 1e-10, "got rmsd = {rmsd}");
    }

    /// A small per-point perturbation leaves a small but non-zero RMSD.
    #[test]
    fn test_known_rmsd() {
        let reference = vec![0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0];
        // Each point displaced by 0.1 along a different axis.
        let perturbed = vec![0.1, 0.0, 0.0, 1.0, 0.1, 0.0, 0.0, 1.0, 0.1];
        let rmsd = compute_rmsd(&perturbed, &reference);
        assert!(rmsd > 0.01);
        assert!(rmsd < 1.0);
    }

    /// The aligned coordinates of a translated copy land on the reference.
    #[test]
    fn test_aligned_coords_returned() {
        let reference = vec![0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0];
        let coords: Vec<f64> = reference.iter().map(|x| x + 10.0).collect();
        let result = align_coordinates(&coords, &reference);
        assert_eq!(result.aligned_coords.len(), 9);
        for i in 0..9 {
            assert!(
                (result.aligned_coords[i] - reference[i]).abs() < 1e-8,
                "mismatch at index {i}"
            );
        }
    }

    /// A mirrored input must still be fitted with a proper rotation.
    #[test]
    fn test_reflection_handling() {
        let reference = vec![1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0];
        // Reference mirrored through the yz plane (x of the first point negated).
        let reflected = vec![-1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0];
        let result = align_coordinates(&reflected, &reference);
        assert!(result.rmsd.is_finite());

        // The point of the reflection correction: the returned matrix has to
        // be a rotation (determinant +1), never a mirror (determinant -1).
        let m = result.rotation;
        let det = m[0][0] * (m[1][1] * m[2][2] - m[1][2] * m[2][1])
            - m[0][1] * (m[1][0] * m[2][2] - m[1][2] * m[2][0])
            + m[0][2] * (m[1][0] * m[2][1] - m[1][1] * m[2][0]);
        assert!(
            (det - 1.0).abs() < 1e-8,
            "rotation is not proper: det = {det}"
        );
    }

    /// Quaternion fit: identical structures give zero RMSD.
    #[test]
    fn test_quaternion_identical() {
        let coords = vec![0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0];
        let result = align_quaternion(&coords, &coords);
        assert!(result.rmsd < 1e-10);
    }

    /// Quaternion fit: a pure translation is removed entirely.
    #[test]
    fn test_quaternion_translated() {
        let reference = vec![0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0];
        let coords: Vec<f64> = reference.iter().map(|x| x + 5.0).collect();
        let result = align_quaternion(&coords, &reference);
        assert!(result.rmsd < 1e-10, "got rmsd = {}", result.rmsd);
    }

    /// Quaternion fit: a 90 degree rotation about z is removed entirely.
    #[test]
    fn test_quaternion_rotated_90() {
        let reference = vec![1.0, 0.0, 0.0, 0.0, 1.0, 0.0, -1.0, 0.0, 0.0];
        let rotated = vec![0.0, 1.0, 0.0, -1.0, 0.0, 0.0, 0.0, -1.0, 0.0];
        let result = align_quaternion(&rotated, &reference);
        assert!(result.rmsd < 1e-10, "got rmsd = {}", result.rmsd);
    }

    /// Both fitting methods must agree on a non-degenerate, noisy input.
    #[test]
    fn test_quaternion_matches_kabsch() {
        let reference = vec![0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.5, 0.5, 1.0];
        // Four non-coplanar points, each nudged slightly off the reference.
        let perturbed = vec![
            0.1, -0.05, 0.02, 1.1, 0.1, -0.05, -0.1, 0.9, 0.1, 0.6, 0.4, 1.1,
        ];

        let kabsch = align_coordinates(&perturbed, &reference);
        let quat = align_quaternion(&perturbed, &reference);

        assert!(
            (kabsch.rmsd - quat.rmsd).abs() < 1e-8,
            "Kabsch RMSD = {}, Quaternion RMSD = {}",
            kabsch.rmsd,
            quat.rmsd,
        );

        for i in 0..reference.len() {
            // Same optimum, so the transformed points must coincide as well.
            assert!(
                (kabsch.aligned_coords[i] - quat.aligned_coords[i]).abs() < 1e-6,
                "aligned mismatch at {}: {:.8} vs {:.8}",
                i,
                kabsch.aligned_coords[i],
                quat.aligned_coords[i],
            );
        }
    }
}