/// Result of rigidly superimposing one set of 3-D points onto another.
///
/// Produced by [`align_coordinates`] (Kabsch / SVD) and [`align_quaternion`]
/// (Horn's unit-quaternion method).
#[derive(Debug, Clone, PartialEq)]
pub struct AlignmentResult {
    /// Root-mean-square deviation between the aligned coordinates and the reference,
    /// averaged over points: `sqrt(Σ|x'_i − y_i|² / N)`.
    pub rmsd: f64,
    /// Row-major 3×3 rotation matrix `R`. A centred point `p` is rotated as `R · p`
    /// (row `k` dotted with `p` gives output component `k`).
    pub rotation: [[f64; 3]; 3],
    /// Reference centroid minus mobile centroid.
    ///
    /// The rotation is applied about the mobile centroid, so the full map is
    /// `x' = R·(x − c_mobile) + c_mobile + translation`. This is *not* the `t` of
    /// `x' = R·x + t` unless `R` is the identity.
    pub translation: [f64; 3],
    /// The mobile coordinates after superposition. This is a flat
    /// `[x0, y0, z0, x1, …]` array with the same length as the input slice.
    pub aligned_coords: Vec<f64>,
}
17
18pub fn compute_rmsd(coords: &[f64], reference: &[f64]) -> f64 {
24 align_coordinates(coords, reference).rmsd
25}
26
27pub fn align_coordinates(coords: &[f64], reference: &[f64]) -> AlignmentResult {
31 assert_eq!(coords.len(), reference.len(), "coordinate length mismatch");
32 assert_eq!(coords.len() % 3, 0, "coordinates must be xyz triples");
33 let n = coords.len() / 3;
34
35 if n == 0 {
36 return AlignmentResult {
37 rmsd: 0.0,
38 rotation: [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]],
39 translation: [0.0; 3],
40 aligned_coords: Vec::new(),
41 };
42 }
43
44 let mut c1 = [0.0f64; 3];
46 let mut c2 = [0.0f64; 3];
47 for i in 0..n {
48 for k in 0..3 {
49 c1[k] += coords[i * 3 + k];
50 c2[k] += reference[i * 3 + k];
51 }
52 }
53 for k in 0..3 {
54 c1[k] /= n as f64;
55 c2[k] /= n as f64;
56 }
57
58 let mut h = [[0.0f64; 3]; 3];
60 for i in 0..n {
61 let p = [
62 coords[i * 3] - c1[0],
63 coords[i * 3 + 1] - c1[1],
64 coords[i * 3 + 2] - c1[2],
65 ];
66 let q = [
67 reference[i * 3] - c2[0],
68 reference[i * 3 + 1] - c2[1],
69 reference[i * 3 + 2] - c2[2],
70 ];
71 for r in 0..3 {
72 for c in 0..3 {
73 h[r][c] += p[r] * q[c];
74 }
75 }
76 }
77
78 let h_mat = nalgebra::Matrix3::new(
80 h[0][0], h[0][1], h[0][2], h[1][0], h[1][1], h[1][2], h[2][0], h[2][1], h[2][2],
81 );
82 let svd = h_mat.svd(true, true);
83 let u = svd.u.unwrap();
84 let v_t = svd.v_t.unwrap();
85 let v = v_t.transpose();
86
87 let mut d = nalgebra::Matrix3::<f64>::identity();
89 if (v * u.transpose()).determinant() < 0.0 {
90 d[(2, 2)] = -1.0;
91 }
92 let r_mat = v * d * u.transpose();
93
94 let rotation = [
96 [r_mat[(0, 0)], r_mat[(0, 1)], r_mat[(0, 2)]],
97 [r_mat[(1, 0)], r_mat[(1, 1)], r_mat[(1, 2)]],
98 [r_mat[(2, 0)], r_mat[(2, 1)], r_mat[(2, 2)]],
99 ];
100
101 let translation = [c2[0] - c1[0], c2[1] - c1[1], c2[2] - c1[2]];
102
103 let mut aligned = vec![0.0f64; coords.len()];
105 let mut sum_sq = 0.0;
106 for i in 0..n {
107 let p = [
108 coords[i * 3] - c1[0],
109 coords[i * 3 + 1] - c1[1],
110 coords[i * 3 + 2] - c1[2],
111 ];
112 for k in 0..3 {
113 let rotated = r_mat[(k, 0)] * p[0] + r_mat[(k, 1)] * p[1] + r_mat[(k, 2)] * p[2];
114 aligned[i * 3 + k] = rotated + c2[k];
115 }
116 for k in 0..3 {
117 let diff = aligned[i * 3 + k] - reference[i * 3 + k];
118 sum_sq += diff * diff;
119 }
120 }
121 let rmsd = (sum_sq / n as f64).sqrt();
122
123 AlignmentResult {
124 rmsd,
125 rotation,
126 translation,
127 aligned_coords: aligned,
128 }
129}
130
131pub fn align_quaternion(coords: &[f64], reference: &[f64]) -> AlignmentResult {
137 assert_eq!(coords.len(), reference.len(), "coordinate length mismatch");
138 assert_eq!(coords.len() % 3, 0, "coordinates must be xyz triples");
139 let n = coords.len() / 3;
140
141 if n == 0 {
142 return AlignmentResult {
143 rmsd: 0.0,
144 rotation: [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]],
145 translation: [0.0; 3],
146 aligned_coords: Vec::new(),
147 };
148 }
149
150 let mut c1 = [0.0f64; 3];
152 let mut c2 = [0.0f64; 3];
153 for i in 0..n {
154 for k in 0..3 {
155 c1[k] += coords[i * 3 + k];
156 c2[k] += reference[i * 3 + k];
157 }
158 }
159 for k in 0..3 {
160 c1[k] /= n as f64;
161 c2[k] /= n as f64;
162 }
163
164 let mut r = [[0.0f64; 3]; 3];
166 for i in 0..n {
167 let p = [
168 coords[i * 3] - c1[0],
169 coords[i * 3 + 1] - c1[1],
170 coords[i * 3 + 2] - c1[2],
171 ];
172 let q = [
173 reference[i * 3] - c2[0],
174 reference[i * 3 + 1] - c2[1],
175 reference[i * 3 + 2] - c2[2],
176 ];
177 for a in 0..3 {
178 for b in 0..3 {
179 r[a][b] += p[a] * q[b];
180 }
181 }
182 }
183
184 let sxx = r[0][0];
190 let sxy = r[0][1];
191 let sxz = r[0][2];
192 let syx = r[1][0];
193 let syy = r[1][1];
194 let syz = r[1][2];
195 let szx = r[2][0];
196 let szy = r[2][1];
197 let szz = r[2][2];
198
199 let f = nalgebra::Matrix4::new(
200 sxx + syy + szz,
201 syz - szy,
202 szx - sxz,
203 sxy - syx,
204 syz - szy,
205 sxx - syy - szz,
206 sxy + syx,
207 szx + sxz,
208 szx - sxz,
209 sxy + syx,
210 -sxx + syy - szz,
211 syz + szy,
212 sxy - syx,
213 szx + sxz,
214 syz + szy,
215 -sxx - syy + szz,
216 );
217
218 let eig = f.symmetric_eigen();
220 let mut best_idx = 0;
221 let mut best_val = eig.eigenvalues[0];
222 for i in 1..4 {
223 if eig.eigenvalues[i] > best_val {
224 best_val = eig.eigenvalues[i];
225 best_idx = i;
226 }
227 }
228
229 let q0 = eig.eigenvectors[(0, best_idx)];
230 let q1 = eig.eigenvectors[(1, best_idx)];
231 let q2 = eig.eigenvectors[(2, best_idx)];
232 let q3 = eig.eigenvectors[(3, best_idx)];
233
234 let rotation = [
236 [
237 q0 * q0 + q1 * q1 - q2 * q2 - q3 * q3,
238 2.0 * (q1 * q2 - q0 * q3),
239 2.0 * (q1 * q3 + q0 * q2),
240 ],
241 [
242 2.0 * (q1 * q2 + q0 * q3),
243 q0 * q0 - q1 * q1 + q2 * q2 - q3 * q3,
244 2.0 * (q2 * q3 - q0 * q1),
245 ],
246 [
247 2.0 * (q1 * q3 - q0 * q2),
248 2.0 * (q2 * q3 + q0 * q1),
249 q0 * q0 - q1 * q1 - q2 * q2 + q3 * q3,
250 ],
251 ];
252
253 let translation = [c2[0] - c1[0], c2[1] - c1[1], c2[2] - c1[2]];
254
255 let mut aligned = vec![0.0f64; coords.len()];
257 let mut sum_sq = 0.0;
258 for i in 0..n {
259 let p = [
260 coords[i * 3] - c1[0],
261 coords[i * 3 + 1] - c1[1],
262 coords[i * 3 + 2] - c1[2],
263 ];
264 for k in 0..3 {
265 let rotated = rotation[k][0] * p[0] + rotation[k][1] * p[1] + rotation[k][2] * p[2];
266 aligned[i * 3 + k] = rotated + c2[k];
267 }
268 for k in 0..3 {
269 let diff = aligned[i * 3 + k] - reference[i * 3 + k];
270 sum_sq += diff * diff;
271 }
272 }
273 let rmsd = (sum_sq / n as f64).sqrt();
274
275 AlignmentResult {
276 rmsd,
277 rotation,
278 translation,
279 aligned_coords: aligned,
280 }
281}
282
#[cfg(test)]
mod tests {
    use super::*;

    /// Determinant of a row-major 3×3 matrix.
    fn det3(m: &[[f64; 3]; 3]) -> f64 {
        m[0][0] * (m[1][1] * m[2][2] - m[1][2] * m[2][1])
            - m[0][1] * (m[1][0] * m[2][2] - m[1][2] * m[2][0])
            + m[0][2] * (m[1][0] * m[2][1] - m[1][1] * m[2][0])
    }

    /// Four non-coplanar points (a chiral set) and their mirror image through the
    /// yz-plane, returned as `(mirrored, reference)`. No proper rotation superimposes
    /// them; the best achievable RMSD is exactly 0.5.
    fn chiral_mirror_pair() -> (Vec<f64>, Vec<f64>) {
        let reference = vec![0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0];
        let mirrored = vec![0.0, 0.0, 0.0, -1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0];
        (mirrored, reference)
    }

    #[test]
    fn test_identical_zero_rmsd() {
        let coords = vec![0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0];
        let rmsd = compute_rmsd(&coords, &coords);
        assert!(rmsd < 1e-10);
    }

    #[test]
    fn test_translated_zero_rmsd() {
        let reference = vec![0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0];
        let coords: Vec<f64> = reference.iter().map(|x| x + 5.0).collect();
        let rmsd = compute_rmsd(&coords, &reference);
        assert!(rmsd < 1e-10, "got rmsd = {rmsd}");
    }

    #[test]
    fn test_rotation_90deg_z() {
        let reference = vec![1.0, 0.0, 0.0, 0.0, 1.0, 0.0, -1.0, 0.0, 0.0];
        let rotated = vec![0.0, 1.0, 0.0, -1.0, 0.0, 0.0, 0.0, -1.0, 0.0];
        let rmsd = compute_rmsd(&rotated, &reference);
        assert!(rmsd < 1e-10, "got rmsd = {rmsd}");
    }

    #[test]
    fn test_known_rmsd() {
        let reference = vec![0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0];
        let perturbed = vec![0.1, 0.0, 0.0, 1.0, 0.1, 0.0, 0.0, 1.0, 0.1];
        let rmsd = compute_rmsd(&perturbed, &reference);
        assert!(rmsd > 0.01);
        assert!(rmsd < 1.0);
    }

    #[test]
    fn test_aligned_coords_returned() {
        let reference = vec![0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0];
        let coords: Vec<f64> = reference.iter().map(|x| x + 10.0).collect();
        let result = align_coordinates(&coords, &reference);
        assert_eq!(result.aligned_coords.len(), 9);
        for i in 0..9 {
            assert!(
                (result.aligned_coords[i] - reference[i]).abs() < 1e-8,
                "mismatch at index {i}"
            );
        }
    }

    #[test]
    fn test_translation_is_centroid_offset() {
        // `translation` is centroid(reference) - centroid(coords) for both methods.
        let reference = vec![0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0];
        let coords: Vec<f64> = reference.iter().map(|x| x + 5.0).collect();
        let results = [
            align_coordinates(&coords, &reference),
            align_quaternion(&coords, &reference),
        ];
        for result in results.iter() {
            for &t in result.translation.iter() {
                assert!((t + 5.0).abs() < 1e-10, "translation = {:?}", result.translation);
            }
        }
    }

    #[test]
    fn test_empty_input() {
        let kabsch = align_coordinates(&[], &[]);
        let quat = align_quaternion(&[], &[]);
        for result in [kabsch, quat].iter() {
            assert_eq!(result.rmsd, 0.0);
            assert!(result.aligned_coords.is_empty());
            assert_eq!(result.translation, [0.0; 3]);
        }
    }

    #[test]
    #[should_panic(expected = "coordinate length mismatch")]
    fn test_length_mismatch_panics() {
        align_coordinates(&[0.0, 0.0, 0.0], &[]);
    }

    #[test]
    fn test_reflection_handling() {
        let reference = vec![1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0];
        let reflected = vec![-1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0];
        let result = align_coordinates(&reflected, &reference);
        assert!(result.rmsd.is_finite());
        // The returned matrix must be a proper rotation, never a reflection.
        let det = det3(&result.rotation);
        assert!((det - 1.0).abs() < 1e-10, "det = {det}");
        // Three points are always coplanar. A mirrored triangle can therefore be reached
        // by a proper rotation that flips its plane over, so the fit is exact.
        assert!(result.rmsd < 1e-8, "got rmsd = {}", result.rmsd);
    }

    #[test]
    fn test_chiral_mirror_not_reflected() {
        // If the reflection guard were missing, the optimum would be the mirror with
        // RMSD 0. With the guard, the best proper rotation leaves RMSD exactly 0.5.
        let (mirrored, reference) = chiral_mirror_pair();
        let result = align_coordinates(&mirrored, &reference);
        let det = det3(&result.rotation);
        assert!((det - 1.0).abs() < 1e-10, "det = {det}");
        assert!((result.rmsd - 0.5).abs() < 1e-8, "got rmsd = {}", result.rmsd);
    }

    #[test]
    fn test_quaternion_identical() {
        let coords = vec![0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0];
        let result = align_quaternion(&coords, &coords);
        assert!(result.rmsd < 1e-10);
    }

    #[test]
    fn test_quaternion_translated() {
        let reference = vec![0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0];
        let coords: Vec<f64> = reference.iter().map(|x| x + 5.0).collect();
        let result = align_quaternion(&coords, &reference);
        assert!(result.rmsd < 1e-10, "got rmsd = {}", result.rmsd);
    }

    #[test]
    fn test_quaternion_rotated_90() {
        let reference = vec![1.0, 0.0, 0.0, 0.0, 1.0, 0.0, -1.0, 0.0, 0.0];
        let rotated = vec![0.0, 1.0, 0.0, -1.0, 0.0, 0.0, 0.0, -1.0, 0.0];
        let result = align_quaternion(&rotated, &reference);
        assert!(result.rmsd < 1e-10, "got rmsd = {}", result.rmsd);
    }

    #[test]
    fn test_quaternion_chiral_mirror() {
        let (mirrored, reference) = chiral_mirror_pair();
        let result = align_quaternion(&mirrored, &reference);
        let det = det3(&result.rotation);
        assert!((det - 1.0).abs() < 1e-10, "det = {det}");
        assert!((result.rmsd - 0.5).abs() < 1e-8, "got rmsd = {}", result.rmsd);
    }

    #[test]
    fn test_quaternion_matches_kabsch() {
        let reference = vec![0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.5, 0.5, 1.0];
        let perturbed = vec![
            0.1, -0.05, 0.02, 1.1, 0.1, -0.05, -0.1, 0.9, 0.1, 0.6, 0.4, 1.1,
        ];

        let kabsch = align_coordinates(&perturbed, &reference);
        let quat = align_quaternion(&perturbed, &reference);

        assert!(
            (kabsch.rmsd - quat.rmsd).abs() < 1e-8,
            "Kabsch RMSD = {}, Quaternion RMSD = {}",
            kabsch.rmsd,
            quat.rmsd,
        );
        assert!((det3(&kabsch.rotation) - 1.0).abs() < 1e-10);
        assert!((det3(&quat.rotation) - 1.0).abs() < 1e-10);

        for i in 0..reference.len() {
            assert!(
                (kabsch.aligned_coords[i] - quat.aligned_coords[i]).abs() < 1e-6,
                "aligned mismatch at {}: {:.8} vs {:.8}",
                i,
                kabsch.aligned_coords[i],
                quat.aligned_coords[i],
            );
        }
    }
}