Skip to main content

sci_form/alignment/
kabsch.rs

//! Kabsch algorithm for optimal rotation alignment and RMSD computation.
//!
//! Provides a clean f64 public API for molecular alignment.
4
/// Result of superimposing a mobile structure onto a reference structure.
///
/// Produced by [`align_coordinates`] and [`align_quaternion`]. The aligned
/// coordinates are obtained as `R · (x − c_coords) + c_reference`, where
/// `c_coords` / `c_reference` are the centroids of the two inputs.
///
/// Implements `PartialEq` so results can be compared directly; note that a
/// result whose `rmsd` is `NaN` never compares equal to anything, itself
/// included, because of IEEE-754 semantics.
#[derive(Debug, Clone, PartialEq)]
pub struct AlignmentResult {
    /// RMSD after optimal alignment (Å).
    ///
    /// `NaN` when the input was rejected (e.g. the two slices differ in length
    /// or the length is not a multiple of three); `0.0` for empty input.
    pub rmsd: f64,
    /// 3×3 rotation matrix (row-major, `rotation[row][col]`), applied to the
    /// centred mobile coordinates. Identity when no alignment was performed.
    pub rotation: [[f64; 3]; 3],
    /// Displacement of the `coords` centroid: reference centroid minus
    /// `coords` centroid.
    ///
    /// This is *not* the `t` of `x' = R·x + t`; see the type-level docs for the
    /// exact transform. Zero when no alignment was performed.
    pub translation: [f64; 3],
    /// Aligned coordinates: flat [x0,y0,z0, x1,...].
    ///
    /// A copy of the unmodified input when no alignment was performed.
    pub aligned_coords: Vec<f64>,
}
17
18/// Compute RMSD between two conformers after Kabsch alignment.
19///
20/// `coords`: flat [x0,y0,z0, x1,y1,z1,...] mobile structure.
21/// `reference`: flat [x0,y0,z0,...] reference structure.
22/// Both must have the same number of atoms.
23pub fn compute_rmsd(coords: &[f64], reference: &[f64]) -> f64 {
24    align_coordinates(coords, reference).rmsd
25}
26
27/// Kabsch alignment: find optimal rotation mapping `coords` onto `reference`.
28///
29/// `coords`, `reference`: flat [x0,y0,z0, x1,y1,z1,...] in Å.
30pub fn align_coordinates(coords: &[f64], reference: &[f64]) -> AlignmentResult {
31    if coords.len() != reference.len() || !coords.len().is_multiple_of(3) {
32        return AlignmentResult {
33            rmsd: f64::NAN,
34            rotation: [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]],
35            translation: [0.0; 3],
36            aligned_coords: coords.to_vec(),
37        };
38    }
39    let n = coords.len() / 3;
40
41    if n == 0 {
42        return AlignmentResult {
43            rmsd: 0.0,
44            rotation: [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]],
45            translation: [0.0; 3],
46            aligned_coords: Vec::new(),
47        };
48    }
49
50    // Compute centroids
51    let mut c1 = [0.0f64; 3];
52    let mut c2 = [0.0f64; 3];
53    for i in 0..n {
54        for k in 0..3 {
55            c1[k] += coords[i * 3 + k];
56            c2[k] += reference[i * 3 + k];
57        }
58    }
59    for k in 0..3 {
60        c1[k] /= n as f64;
61        c2[k] /= n as f64;
62    }
63
64    // Build H = Σ (p_i - c1)(q_i - c2)^T  (3×3)
65    let mut h = [[0.0f64; 3]; 3];
66    for i in 0..n {
67        let p = [
68            coords[i * 3] - c1[0],
69            coords[i * 3 + 1] - c1[1],
70            coords[i * 3 + 2] - c1[2],
71        ];
72        let q = [
73            reference[i * 3] - c2[0],
74            reference[i * 3 + 1] - c2[1],
75            reference[i * 3 + 2] - c2[2],
76        ];
77        for r in 0..3 {
78            for c in 0..3 {
79                h[r][c] += p[r] * q[c];
80            }
81        }
82    }
83
84    // SVD via nalgebra
85    let h_mat = nalgebra::Matrix3::new(
86        h[0][0], h[0][1], h[0][2], h[1][0], h[1][1], h[1][2], h[2][0], h[2][1], h[2][2],
87    );
88    let svd = h_mat.svd(true, true);
89    let (u, v_t) = match (svd.u, svd.v_t) {
90        (Some(u), Some(v_t)) => (u, v_t),
91        _ => {
92            // SVD failed — return identity alignment with raw RMSD
93            let mut sum_sq = 0.0;
94            for i in 0..coords.len() {
95                let diff = coords[i] - reference[i];
96                sum_sq += diff * diff;
97            }
98            return AlignmentResult {
99                rmsd: (sum_sq / n as f64).sqrt(),
100                rotation: [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]],
101                translation: [0.0; 3],
102                aligned_coords: coords.to_vec(),
103            };
104        }
105    };
106    let v = v_t.transpose();
107
108    // Handle reflection: correct the sign of the smallest singular value
109    // to ensure a proper rotation (det(R) = +1).
110    // For coplanar molecules (one singular value ≈ 0), this prevents
111    // an improper rotation (reflection) from being selected.
112    let mut d = nalgebra::Matrix3::<f64>::identity();
113    if (v * u.transpose()).determinant() < 0.0 {
114        d[(2, 2)] = -1.0;
115    }
116    let r_mat = v * d * u.transpose();
117
118    // Build rotation as row-major array
119    let rotation = [
120        [r_mat[(0, 0)], r_mat[(0, 1)], r_mat[(0, 2)]],
121        [r_mat[(1, 0)], r_mat[(1, 1)], r_mat[(1, 2)]],
122        [r_mat[(2, 0)], r_mat[(2, 1)], r_mat[(2, 2)]],
123    ];
124
125    let translation = [c2[0] - c1[0], c2[1] - c1[1], c2[2] - c1[2]];
126
127    // Apply rotation and compute RMSD
128    let mut aligned = vec![0.0f64; coords.len()];
129    let mut sum_sq = 0.0;
130    for i in 0..n {
131        let p = [
132            coords[i * 3] - c1[0],
133            coords[i * 3 + 1] - c1[1],
134            coords[i * 3 + 2] - c1[2],
135        ];
136        for k in 0..3 {
137            let rotated = r_mat[(k, 0)] * p[0] + r_mat[(k, 1)] * p[1] + r_mat[(k, 2)] * p[2];
138            aligned[i * 3 + k] = rotated + c2[k];
139        }
140        for k in 0..3 {
141            let diff = aligned[i * 3 + k] - reference[i * 3 + k];
142            sum_sq += diff * diff;
143        }
144    }
145    let rmsd = (sum_sq / n as f64).sqrt();
146
147    AlignmentResult {
148        rmsd,
149        rotation,
150        translation,
151        aligned_coords: aligned,
152    }
153}
154
155/// Quaternion-based optimal rotation alignment (Coutsias et al. 2004).
156///
157/// Uses the quaternion eigenvector method which is numerically more stable
158/// than SVD for near-degenerate cases.  Produces the same result as Kabsch
159/// but avoids explicit SVD decomposition.
160pub fn align_quaternion(coords: &[f64], reference: &[f64]) -> AlignmentResult {
161    if coords.len() != reference.len() || !coords.len().is_multiple_of(3) {
162        return AlignmentResult {
163            rmsd: f64::NAN,
164            rotation: [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]],
165            translation: [0.0; 3],
166            aligned_coords: coords.to_vec(),
167        };
168    }
169    let n = coords.len() / 3;
170
171    if n == 0 {
172        return AlignmentResult {
173            rmsd: 0.0,
174            rotation: [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]],
175            translation: [0.0; 3],
176            aligned_coords: Vec::new(),
177        };
178    }
179
180    // Centroids
181    let mut c1 = [0.0f64; 3];
182    let mut c2 = [0.0f64; 3];
183    for i in 0..n {
184        for k in 0..3 {
185            c1[k] += coords[i * 3 + k];
186            c2[k] += reference[i * 3 + k];
187        }
188    }
189    for k in 0..3 {
190        c1[k] /= n as f64;
191        c2[k] /= n as f64;
192    }
193
194    // Build cross-covariance elements: R_ij = Σ (p_i - c1_i)(q_j - c2_j)
195    let mut r = [[0.0f64; 3]; 3];
196    for i in 0..n {
197        let p = [
198            coords[i * 3] - c1[0],
199            coords[i * 3 + 1] - c1[1],
200            coords[i * 3 + 2] - c1[2],
201        ];
202        let q = [
203            reference[i * 3] - c2[0],
204            reference[i * 3 + 1] - c2[1],
205            reference[i * 3 + 2] - c2[2],
206        ];
207        for a in 0..3 {
208            for b in 0..3 {
209                r[a][b] += p[a] * q[b];
210            }
211        }
212    }
213
214    // Build the 4×4 symmetric key matrix (Davenport/Coutsias):
215    //   F = [[ Sxx+Syy+Szz, Syz-Szy, Szx-Sxz, Sxy-Syx ],
216    //        [ Syz-Szy, Sxx-Syy-Szz, Sxy+Syx, Szx+Sxz ],
217    //        [ Szx-Sxz, Sxy+Syx, -Sxx+Syy-Szz, Syz+Szy ],
218    //        [ Sxy-Syx, Szx+Sxz, Syz+Szy, -Sxx-Syy+Szz ]]
219    let sxx = r[0][0];
220    let sxy = r[0][1];
221    let sxz = r[0][2];
222    let syx = r[1][0];
223    let syy = r[1][1];
224    let syz = r[1][2];
225    let szx = r[2][0];
226    let szy = r[2][1];
227    let szz = r[2][2];
228
229    let f = nalgebra::Matrix4::new(
230        sxx + syy + szz,
231        syz - szy,
232        szx - sxz,
233        sxy - syx,
234        syz - szy,
235        sxx - syy - szz,
236        sxy + syx,
237        szx + sxz,
238        szx - sxz,
239        sxy + syx,
240        -sxx + syy - szz,
241        syz + szy,
242        sxy - syx,
243        szx + sxz,
244        syz + szy,
245        -sxx - syy + szz,
246    );
247
248    // The optimal rotation quaternion is the eigenvector of F with the largest eigenvalue
249    let eig = f.symmetric_eigen();
250    let mut best_idx = 0;
251    let mut best_val = eig.eigenvalues[0];
252    for i in 1..4 {
253        if eig.eigenvalues[i] > best_val {
254            best_val = eig.eigenvalues[i];
255            best_idx = i;
256        }
257    }
258
259    let q0 = eig.eigenvectors[(0, best_idx)];
260    let q1 = eig.eigenvectors[(1, best_idx)];
261    let q2 = eig.eigenvectors[(2, best_idx)];
262    let q3 = eig.eigenvectors[(3, best_idx)];
263
264    // quaternion → rotation matrix
265    let rotation = [
266        [
267            q0 * q0 + q1 * q1 - q2 * q2 - q3 * q3,
268            2.0 * (q1 * q2 - q0 * q3),
269            2.0 * (q1 * q3 + q0 * q2),
270        ],
271        [
272            2.0 * (q1 * q2 + q0 * q3),
273            q0 * q0 - q1 * q1 + q2 * q2 - q3 * q3,
274            2.0 * (q2 * q3 - q0 * q1),
275        ],
276        [
277            2.0 * (q1 * q3 - q0 * q2),
278            2.0 * (q2 * q3 + q0 * q1),
279            q0 * q0 - q1 * q1 - q2 * q2 + q3 * q3,
280        ],
281    ];
282
283    let translation = [c2[0] - c1[0], c2[1] - c1[1], c2[2] - c1[2]];
284
285    // Apply rotation and compute RMSD
286    let mut aligned = vec![0.0f64; coords.len()];
287    let mut sum_sq = 0.0;
288    for i in 0..n {
289        let p = [
290            coords[i * 3] - c1[0],
291            coords[i * 3 + 1] - c1[1],
292            coords[i * 3 + 2] - c1[2],
293        ];
294        for k in 0..3 {
295            let rotated = rotation[k][0] * p[0] + rotation[k][1] * p[1] + rotation[k][2] * p[2];
296            aligned[i * 3 + k] = rotated + c2[k];
297        }
298        for k in 0..3 {
299            let diff = aligned[i * 3 + k] - reference[i * 3 + k];
300            sum_sq += diff * diff;
301        }
302    }
303    let rmsd = (sum_sq / n as f64).sqrt();
304
305    AlignmentResult {
306        rmsd,
307        rotation,
308        translation,
309        aligned_coords: aligned,
310    }
311}
312
#[cfg(test)]
mod tests {
    use super::*;

    /// Determinant of a row-major 3×3 matrix; used to check that a returned
    /// rotation is proper (det = +1) rather than a reflection (det = -1).
    fn det3(m: &[[f64; 3]; 3]) -> f64 {
        m[0][0] * (m[1][1] * m[2][2] - m[1][2] * m[2][1])
            - m[0][1] * (m[1][0] * m[2][2] - m[1][2] * m[2][0])
            + m[0][2] * (m[1][0] * m[2][1] - m[1][1] * m[2][0])
    }

    #[test]
    fn test_identical_zero_rmsd() {
        let coords = vec![0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0];
        let rmsd = compute_rmsd(&coords, &coords);
        assert!(rmsd < 1e-10);
    }

    #[test]
    fn test_translated_zero_rmsd() {
        // Pure translation → RMSD should be ~0 after alignment.
        let reference = vec![0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0];
        let coords: Vec<f64> = reference.iter().map(|x| x + 5.0).collect();
        let rmsd = compute_rmsd(&coords, &reference);
        assert!(rmsd < 1e-10, "got rmsd = {rmsd}");
    }

    #[test]
    fn test_rotation_90deg_z() {
        // 90° rotation around Z-axis → RMSD should be ~0 after alignment.
        let reference = vec![1.0, 0.0, 0.0, 0.0, 1.0, 0.0, -1.0, 0.0, 0.0];
        // Rotate 90° around Z: (x,y) -> (-y,x)
        let rotated = vec![0.0, 1.0, 0.0, -1.0, 0.0, 0.0, 0.0, -1.0, 0.0];
        let rmsd = compute_rmsd(&rotated, &reference);
        assert!(rmsd < 1e-10, "got rmsd = {rmsd}");
    }

    #[test]
    fn test_known_rmsd() {
        // Slightly perturbed structure → nonzero RMSD.
        let reference = vec![0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0];
        let perturbed = vec![0.1, 0.0, 0.0, 1.0, 0.1, 0.0, 0.0, 1.0, 0.1];
        let rmsd = compute_rmsd(&perturbed, &reference);
        assert!(rmsd > 0.01);
        assert!(rmsd < 1.0);
    }

    #[test]
    fn test_aligned_coords_returned() {
        let reference = vec![0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0];
        let coords: Vec<f64> = reference.iter().map(|x| x + 10.0).collect();
        let result = align_coordinates(&coords, &reference);
        assert_eq!(result.aligned_coords.len(), 9);
        // Aligned should be close to reference
        for i in 0..9 {
            assert!(
                (result.aligned_coords[i] - reference[i]).abs() < 1e-8,
                "mismatch at index {i}"
            );
        }
    }

    #[test]
    fn test_reflection_handling() {
        // Mirror image (reflection): Kabsch should handle determinant < 0.
        let reference = vec![1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0];
        let reflected = vec![-1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0];
        let result = align_coordinates(&reflected, &reference);
        // Should still give a valid RMSD (may not be zero for true reflection)
        assert!(result.rmsd.is_finite());
        // The returned matrix must be a proper rotation, never a reflection.
        let det = det3(&result.rotation);
        assert!((det - 1.0).abs() < 1e-8, "got det = {det}");
    }

    #[test]
    fn test_mismatched_lengths_yield_nan() {
        // Different atom counts → rejected: NaN RMSD, input echoed back.
        let coords = vec![0.0, 0.0, 0.0];
        let reference = vec![0.0, 0.0, 0.0, 1.0, 0.0, 0.0];

        let kabsch = align_coordinates(&coords, &reference);
        assert!(kabsch.rmsd.is_nan());
        assert_eq!(kabsch.aligned_coords, coords);

        let quat = align_quaternion(&coords, &reference);
        assert!(quat.rmsd.is_nan());
        assert_eq!(quat.aligned_coords, coords);

        assert!(compute_rmsd(&coords, &reference).is_nan());
    }

    #[test]
    fn test_length_not_multiple_of_three_yields_nan() {
        // Equal lengths, but not a whole number of (x, y, z) triples.
        let coords = vec![1.0, 2.0];
        assert!(align_coordinates(&coords, &coords).rmsd.is_nan());
        assert!(align_quaternion(&coords, &coords).rmsd.is_nan());
    }

    #[test]
    fn test_empty_input() {
        // Zero atoms → zero RMSD and no coordinates, for both methods.
        let empty: Vec<f64> = Vec::new();

        let kabsch = align_coordinates(&empty, &empty);
        assert_eq!(kabsch.rmsd, 0.0);
        assert!(kabsch.aligned_coords.is_empty());

        let quat = align_quaternion(&empty, &empty);
        assert_eq!(quat.rmsd, 0.0);
        assert!(quat.aligned_coords.is_empty());
    }

    #[test]
    fn test_quaternion_identical() {
        let coords = vec![0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0];
        let result = align_quaternion(&coords, &coords);
        assert!(result.rmsd < 1e-10);
    }

    #[test]
    fn test_quaternion_translated() {
        let reference = vec![0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0];
        let coords: Vec<f64> = reference.iter().map(|x| x + 5.0).collect();
        let result = align_quaternion(&coords, &reference);
        assert!(result.rmsd < 1e-10, "got rmsd = {}", result.rmsd);
    }

    #[test]
    fn test_quaternion_rotated_90() {
        let reference = vec![1.0, 0.0, 0.0, 0.0, 1.0, 0.0, -1.0, 0.0, 0.0];
        let rotated = vec![0.0, 1.0, 0.0, -1.0, 0.0, 0.0, 0.0, -1.0, 0.0];
        let result = align_quaternion(&rotated, &reference);
        assert!(result.rmsd < 1e-10, "got rmsd = {}", result.rmsd);
    }

    #[test]
    fn test_quaternion_matches_kabsch() {
        // Both methods should give the same RMSD.
        let reference = vec![0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.5, 0.5, 1.0];
        let perturbed = vec![
            0.1, -0.05, 0.02, 1.1, 0.1, -0.05, -0.1, 0.9, 0.1, 0.6, 0.4, 1.1,
        ];

        let kabsch = align_coordinates(&perturbed, &reference);
        let quat = align_quaternion(&perturbed, &reference);

        assert!(
            (kabsch.rmsd - quat.rmsd).abs() < 1e-8,
            "Kabsch RMSD = {}, Quaternion RMSD = {}",
            kabsch.rmsd,
            quat.rmsd,
        );

        // Both rotations must be proper (det = +1).
        let det_kabsch = det3(&kabsch.rotation);
        let det_quat = det3(&quat.rotation);
        assert!((det_kabsch - 1.0).abs() < 1e-8, "got det = {det_kabsch}");
        assert!((det_quat - 1.0).abs() < 1e-8, "got det = {det_quat}");

        // Aligned coords should also match
        for i in 0..reference.len() {
            assert!(
                (kabsch.aligned_coords[i] - quat.aligned_coords[i]).abs() < 1e-6,
                "aligned mismatch at {}: {:.8} vs {:.8}",
                i,
                kabsch.aligned_coords[i],
                quat.aligned_coords[i],
            );
        }
    }
}