Skip to main content

sci_form/alignment/
kabsch.rs

1//! Kabsch algorithm for optimal rotation alignment and RMSD computation.
2//!
3//! Provides a clean f64 public API for molecular alignment.
4
/// Result of a rigid-body superposition (Kabsch or quaternion method).
///
/// The mobile structure is mapped onto the reference as
/// `aligned_i = R · (x_i − c_mobile) + c_reference`, where `R` is
/// [`rotation`](Self::rotation) and `c_mobile` / `c_reference` are the
/// centroids of the two input structures.
#[derive(Debug, Clone)]
pub struct AlignmentResult {
    /// Root-mean-square deviation between the aligned mobile coordinates and
    /// the reference after optimal alignment (Å).
    pub rmsd: f64,
    /// Optimal proper rotation `R` (orthonormal, `det R = +1`), stored
    /// row-major so that `rotation[row][col]` is `R[row, col]`.
    ///
    /// It acts on column vectors that have already been centred on the mobile
    /// centroid.
    pub rotation: [[f64; 3]; 3],
    /// Centroid displacement `c_reference − c_mobile` (Å).
    ///
    /// Adding this vector to the mobile centroid moves it onto the reference
    /// centroid. Note that it is *not* the translation `t` of the rigid
    /// transform written as `x ↦ R·x + t`; that vector is
    /// `c_reference − R · c_mobile`.
    pub translation: [f64; 3],
    /// Aligned mobile coordinates, flat `[x0, y0, z0, x1, ...]`.
    ///
    /// Same length and atom order as the input.
    pub aligned_coords: Vec<f64>,
}
17
18/// Compute RMSD between two conformers after Kabsch alignment.
19///
20/// `coords`: flat [x0,y0,z0, x1,y1,z1,...] mobile structure.
21/// `reference`: flat [x0,y0,z0,...] reference structure.
22/// Both must have the same number of atoms.
23pub fn compute_rmsd(coords: &[f64], reference: &[f64]) -> f64 {
24    align_coordinates(coords, reference).rmsd
25}
26
27/// Kabsch alignment: find optimal rotation mapping `coords` onto `reference`.
28///
29/// `coords`, `reference`: flat [x0,y0,z0, x1,y1,z1,...] in Å.
30pub fn align_coordinates(coords: &[f64], reference: &[f64]) -> AlignmentResult {
31    assert_eq!(coords.len(), reference.len(), "coordinate length mismatch");
32    assert_eq!(coords.len() % 3, 0, "coordinates must be xyz triples");
33    let n = coords.len() / 3;
34
35    if n == 0 {
36        return AlignmentResult {
37            rmsd: 0.0,
38            rotation: [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]],
39            translation: [0.0; 3],
40            aligned_coords: Vec::new(),
41        };
42    }
43
44    // Compute centroids
45    let mut c1 = [0.0f64; 3];
46    let mut c2 = [0.0f64; 3];
47    for i in 0..n {
48        for k in 0..3 {
49            c1[k] += coords[i * 3 + k];
50            c2[k] += reference[i * 3 + k];
51        }
52    }
53    for k in 0..3 {
54        c1[k] /= n as f64;
55        c2[k] /= n as f64;
56    }
57
58    // Build H = Σ (p_i - c1)(q_i - c2)^T  (3×3)
59    let mut h = [[0.0f64; 3]; 3];
60    for i in 0..n {
61        let p = [
62            coords[i * 3] - c1[0],
63            coords[i * 3 + 1] - c1[1],
64            coords[i * 3 + 2] - c1[2],
65        ];
66        let q = [
67            reference[i * 3] - c2[0],
68            reference[i * 3 + 1] - c2[1],
69            reference[i * 3 + 2] - c2[2],
70        ];
71        for r in 0..3 {
72            for c in 0..3 {
73                h[r][c] += p[r] * q[c];
74            }
75        }
76    }
77
78    // SVD via nalgebra
79    let h_mat = nalgebra::Matrix3::new(
80        h[0][0], h[0][1], h[0][2], h[1][0], h[1][1], h[1][2], h[2][0], h[2][1], h[2][2],
81    );
82    let svd = h_mat.svd(true, true);
83    let u = svd.u.unwrap();
84    let v_t = svd.v_t.unwrap();
85    let v = v_t.transpose();
86
87    // Handle reflection
88    let mut d = nalgebra::Matrix3::<f64>::identity();
89    if (v * u.transpose()).determinant() < 0.0 {
90        d[(2, 2)] = -1.0;
91    }
92    let r_mat = v * d * u.transpose();
93
94    // Build rotation as row-major array
95    let rotation = [
96        [r_mat[(0, 0)], r_mat[(0, 1)], r_mat[(0, 2)]],
97        [r_mat[(1, 0)], r_mat[(1, 1)], r_mat[(1, 2)]],
98        [r_mat[(2, 0)], r_mat[(2, 1)], r_mat[(2, 2)]],
99    ];
100
101    let translation = [c2[0] - c1[0], c2[1] - c1[1], c2[2] - c1[2]];
102
103    // Apply rotation and compute RMSD
104    let mut aligned = vec![0.0f64; coords.len()];
105    let mut sum_sq = 0.0;
106    for i in 0..n {
107        let p = [
108            coords[i * 3] - c1[0],
109            coords[i * 3 + 1] - c1[1],
110            coords[i * 3 + 2] - c1[2],
111        ];
112        for k in 0..3 {
113            let rotated = r_mat[(k, 0)] * p[0] + r_mat[(k, 1)] * p[1] + r_mat[(k, 2)] * p[2];
114            aligned[i * 3 + k] = rotated + c2[k];
115        }
116        for k in 0..3 {
117            let diff = aligned[i * 3 + k] - reference[i * 3 + k];
118            sum_sq += diff * diff;
119        }
120    }
121    let rmsd = (sum_sq / n as f64).sqrt();
122
123    AlignmentResult {
124        rmsd,
125        rotation,
126        translation,
127        aligned_coords: aligned,
128    }
129}
130
131/// Quaternion-based optimal rotation alignment (Coutsias et al. 2004).
132///
133/// Uses the quaternion eigenvector method which is numerically more stable
134/// than SVD for near-degenerate cases.  Produces the same result as Kabsch
135/// but avoids explicit SVD decomposition.
136pub fn align_quaternion(coords: &[f64], reference: &[f64]) -> AlignmentResult {
137    assert_eq!(coords.len(), reference.len(), "coordinate length mismatch");
138    assert_eq!(coords.len() % 3, 0, "coordinates must be xyz triples");
139    let n = coords.len() / 3;
140
141    if n == 0 {
142        return AlignmentResult {
143            rmsd: 0.0,
144            rotation: [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]],
145            translation: [0.0; 3],
146            aligned_coords: Vec::new(),
147        };
148    }
149
150    // Centroids
151    let mut c1 = [0.0f64; 3];
152    let mut c2 = [0.0f64; 3];
153    for i in 0..n {
154        for k in 0..3 {
155            c1[k] += coords[i * 3 + k];
156            c2[k] += reference[i * 3 + k];
157        }
158    }
159    for k in 0..3 {
160        c1[k] /= n as f64;
161        c2[k] /= n as f64;
162    }
163
164    // Build cross-covariance elements: R_ij = Σ (p_i - c1_i)(q_j - c2_j)
165    let mut r = [[0.0f64; 3]; 3];
166    for i in 0..n {
167        let p = [
168            coords[i * 3] - c1[0],
169            coords[i * 3 + 1] - c1[1],
170            coords[i * 3 + 2] - c1[2],
171        ];
172        let q = [
173            reference[i * 3] - c2[0],
174            reference[i * 3 + 1] - c2[1],
175            reference[i * 3 + 2] - c2[2],
176        ];
177        for a in 0..3 {
178            for b in 0..3 {
179                r[a][b] += p[a] * q[b];
180            }
181        }
182    }
183
184    // Build the 4×4 symmetric key matrix (Davenport/Coutsias):
185    //   F = [[ Sxx+Syy+Szz, Syz-Szy, Szx-Sxz, Sxy-Syx ],
186    //        [ Syz-Szy, Sxx-Syy-Szz, Sxy+Syx, Szx+Sxz ],
187    //        [ Szx-Sxz, Sxy+Syx, -Sxx+Syy-Szz, Syz+Szy ],
188    //        [ Sxy-Syx, Szx+Sxz, Syz+Szy, -Sxx-Syy+Szz ]]
189    let sxx = r[0][0];
190    let sxy = r[0][1];
191    let sxz = r[0][2];
192    let syx = r[1][0];
193    let syy = r[1][1];
194    let syz = r[1][2];
195    let szx = r[2][0];
196    let szy = r[2][1];
197    let szz = r[2][2];
198
199    let f = nalgebra::Matrix4::new(
200        sxx + syy + szz,
201        syz - szy,
202        szx - sxz,
203        sxy - syx,
204        syz - szy,
205        sxx - syy - szz,
206        sxy + syx,
207        szx + sxz,
208        szx - sxz,
209        sxy + syx,
210        -sxx + syy - szz,
211        syz + szy,
212        sxy - syx,
213        szx + sxz,
214        syz + szy,
215        -sxx - syy + szz,
216    );
217
218    // The optimal rotation quaternion is the eigenvector of F with the largest eigenvalue
219    let eig = f.symmetric_eigen();
220    let mut best_idx = 0;
221    let mut best_val = eig.eigenvalues[0];
222    for i in 1..4 {
223        if eig.eigenvalues[i] > best_val {
224            best_val = eig.eigenvalues[i];
225            best_idx = i;
226        }
227    }
228
229    let q0 = eig.eigenvectors[(0, best_idx)];
230    let q1 = eig.eigenvectors[(1, best_idx)];
231    let q2 = eig.eigenvectors[(2, best_idx)];
232    let q3 = eig.eigenvectors[(3, best_idx)];
233
234    // quaternion → rotation matrix
235    let rotation = [
236        [
237            q0 * q0 + q1 * q1 - q2 * q2 - q3 * q3,
238            2.0 * (q1 * q2 - q0 * q3),
239            2.0 * (q1 * q3 + q0 * q2),
240        ],
241        [
242            2.0 * (q1 * q2 + q0 * q3),
243            q0 * q0 - q1 * q1 + q2 * q2 - q3 * q3,
244            2.0 * (q2 * q3 - q0 * q1),
245        ],
246        [
247            2.0 * (q1 * q3 - q0 * q2),
248            2.0 * (q2 * q3 + q0 * q1),
249            q0 * q0 - q1 * q1 - q2 * q2 + q3 * q3,
250        ],
251    ];
252
253    let translation = [c2[0] - c1[0], c2[1] - c1[1], c2[2] - c1[2]];
254
255    // Apply rotation and compute RMSD
256    let mut aligned = vec![0.0f64; coords.len()];
257    let mut sum_sq = 0.0;
258    for i in 0..n {
259        let p = [
260            coords[i * 3] - c1[0],
261            coords[i * 3 + 1] - c1[1],
262            coords[i * 3 + 2] - c1[2],
263        ];
264        for k in 0..3 {
265            let rotated = rotation[k][0] * p[0] + rotation[k][1] * p[1] + rotation[k][2] * p[2];
266            aligned[i * 3 + k] = rotated + c2[k];
267        }
268        for k in 0..3 {
269            let diff = aligned[i * 3 + k] - reference[i * 3 + k];
270            sum_sq += diff * diff;
271        }
272    }
273    let rmsd = (sum_sq / n as f64).sqrt();
274
275    AlignmentResult {
276        rmsd,
277        rotation,
278        translation,
279        aligned_coords: aligned,
280    }
281}
282
#[cfg(test)]
mod tests {
    use super::*;

    /// Determinant of a row-major 3×3 matrix.
    fn det3(m: &[[f64; 3]; 3]) -> f64 {
        m[0][0] * (m[1][1] * m[2][2] - m[1][2] * m[2][1])
            - m[0][1] * (m[1][0] * m[2][2] - m[1][2] * m[2][0])
            + m[0][2] * (m[1][0] * m[2][1] - m[1][1] * m[2][0])
    }

    /// Asserts that `m` is orthonormal with determinant +1 (a proper rotation).
    fn assert_proper_rotation(m: &[[f64; 3]; 3]) {
        for i in 0..3 {
            for j in 0..3 {
                let dot: f64 = (0..3).map(|k| m[i][k] * m[j][k]).sum();
                let expected = if i == j { 1.0 } else { 0.0 };
                assert!(
                    (dot - expected).abs() < 1e-10,
                    "R·Rᵀ[{i}][{j}] = {dot}, expected {expected}"
                );
            }
        }
        let det = det3(m);
        assert!((det - 1.0).abs() < 1e-10, "det(R) = {det}, expected +1");
    }

    /// Asserts element-wise agreement of two flat coordinate arrays.
    fn assert_coords_close(actual: &[f64], expected: &[f64], tol: f64) {
        assert_eq!(actual.len(), expected.len(), "length mismatch");
        for (i, (a, e)) in actual.iter().zip(expected).enumerate() {
            assert!((a - e).abs() < tol, "mismatch at index {i}: {a} vs {e}");
        }
    }

    #[test]
    fn test_identical_zero_rmsd() {
        let coords = vec![0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0];
        let rmsd = compute_rmsd(&coords, &coords);
        assert!(rmsd < 1e-10);
    }

    #[test]
    fn test_translated_zero_rmsd() {
        // Pure translation → RMSD should be ~0 after alignment.
        let reference = vec![0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0];
        let coords: Vec<f64> = reference.iter().map(|x| x + 5.0).collect();
        let rmsd = compute_rmsd(&coords, &reference);
        assert!(rmsd < 1e-10, "got rmsd = {rmsd}");
    }

    #[test]
    fn test_rotation_90deg_z() {
        // 90° rotation around Z-axis → RMSD should be ~0 after alignment.
        let reference = vec![1.0, 0.0, 0.0, 0.0, 1.0, 0.0, -1.0, 0.0, 0.0];
        // Rotate 90° around Z: (x,y) -> (-y,x)
        let rotated = vec![0.0, 1.0, 0.0, -1.0, 0.0, 0.0, 0.0, -1.0, 0.0];
        let result = align_coordinates(&rotated, &reference);
        assert!(result.rmsd < 1e-10, "got rmsd = {}", result.rmsd);
        assert_proper_rotation(&result.rotation);
    }

    #[test]
    fn test_known_rmsd() {
        // Slightly perturbed structure → nonzero RMSD.
        let reference = vec![0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0];
        let perturbed = vec![0.1, 0.0, 0.0, 1.0, 0.1, 0.0, 0.0, 1.0, 0.1];
        let rmsd = compute_rmsd(&perturbed, &reference);
        assert!(rmsd > 0.01);
        assert!(rmsd < 1.0);
    }

    #[test]
    fn test_aligned_coords_returned() {
        let reference = vec![0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0];
        let coords: Vec<f64> = reference.iter().map(|x| x + 10.0).collect();
        let result = align_coordinates(&coords, &reference);
        assert_eq!(result.aligned_coords.len(), 9);
        // Aligned should be close to reference.
        assert_coords_close(&result.aligned_coords, &reference, 1e-8);
    }

    #[test]
    fn test_reflection_handling() {
        // Three points are always coplanar, so their mirror image CAN be
        // superimposed by a proper rotation. The reflection correction must
        // still yield det(R) = +1 and an RMSD of ~0.
        let reference = vec![1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0];
        let reflected = vec![-1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0];
        let kabsch = align_coordinates(&reflected, &reference);
        let quat = align_quaternion(&reflected, &reference);
        for result in [&kabsch, &quat] {
            assert!(result.rmsd.is_finite());
            assert!(result.rmsd < 1e-8, "got rmsd = {}", result.rmsd);
            assert_proper_rotation(&result.rotation);
        }
    }

    #[test]
    fn test_chiral_mirror_not_superimposable() {
        // A tetrahedron and its mirror image (x -> -x) cannot be superimposed
        // by a proper rotation. The singular values of H are 1, 1, 1/4 with
        // det H < 0, so the best proper rotation gives
        // E = 2.25 + 2.25 - 2 * (1 + 1 - 1/4) = 1 and RMSD = sqrt(1 / 4) = 0.5.
        // A missing or misplaced reflection flip would give a different value.
        let reference = vec![0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0];
        let mirrored = vec![0.0, 0.0, 0.0, -1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0];
        let kabsch = align_coordinates(&mirrored, &reference);
        let quat = align_quaternion(&mirrored, &reference);
        for result in [&kabsch, &quat] {
            assert!((result.rmsd - 0.5).abs() < 1e-8, "got rmsd = {}", result.rmsd);
            assert_proper_rotation(&result.rotation);
        }
    }

    #[test]
    fn test_rigid_motion_recovered() {
        // Non-planar reference; mobile = 90° rotation about x,
        // (x, y, z) -> (x, -z, y), followed by a translation of (2, -1, 3).
        let reference = vec![0.0, 0.0, 0.0, 1.5, 0.0, 0.0, 0.0, 1.2, 0.0, 0.3, 0.4, 1.1];
        let mobile = vec![2.0, -1.0, 3.0, 3.5, -1.0, 3.0, 2.0, -1.0, 4.2, 2.3, -2.1, 3.4];
        let kabsch = align_coordinates(&mobile, &reference);
        let quat = align_quaternion(&mobile, &reference);
        for result in [&kabsch, &quat] {
            assert!(result.rmsd < 1e-9, "got rmsd = {}", result.rmsd);
            assert_proper_rotation(&result.rotation);
            assert_coords_close(&result.aligned_coords, &reference, 1e-9);
        }
    }

    #[test]
    fn test_empty_input() {
        let identity = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]];
        let kabsch = align_coordinates(&[], &[]);
        let quat = align_quaternion(&[], &[]);
        for result in [&kabsch, &quat] {
            assert_eq!(result.rmsd, 0.0);
            assert_eq!(result.rotation, identity);
            assert_eq!(result.translation, [0.0; 3]);
            assert!(result.aligned_coords.is_empty());
        }
        assert_eq!(compute_rmsd(&[], &[]), 0.0);
    }

    #[test]
    #[should_panic(expected = "coordinate length mismatch")]
    fn test_length_mismatch_panics() {
        align_coordinates(&[0.0, 0.0, 0.0], &[0.0, 0.0, 0.0, 1.0, 1.0, 1.0]);
    }

    #[test]
    #[should_panic(expected = "coordinates must be xyz triples")]
    fn test_non_triple_panics() {
        align_quaternion(&[0.0, 0.0], &[0.0, 0.0]);
    }

    #[test]
    fn test_quaternion_identical() {
        let coords = vec![0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0];
        let result = align_quaternion(&coords, &coords);
        assert!(result.rmsd < 1e-10);
    }

    #[test]
    fn test_quaternion_translated() {
        let reference = vec![0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0];
        let coords: Vec<f64> = reference.iter().map(|x| x + 5.0).collect();
        let result = align_quaternion(&coords, &reference);
        assert!(result.rmsd < 1e-10, "got rmsd = {}", result.rmsd);
    }

    #[test]
    fn test_quaternion_rotated_90() {
        let reference = vec![1.0, 0.0, 0.0, 0.0, 1.0, 0.0, -1.0, 0.0, 0.0];
        let rotated = vec![0.0, 1.0, 0.0, -1.0, 0.0, 0.0, 0.0, -1.0, 0.0];
        let result = align_quaternion(&rotated, &reference);
        assert!(result.rmsd < 1e-10, "got rmsd = {}", result.rmsd);
        assert_proper_rotation(&result.rotation);
    }

    #[test]
    fn test_quaternion_matches_kabsch() {
        // Both methods should give the same RMSD.
        let reference = vec![0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.5, 0.5, 1.0];
        let perturbed = vec![
            0.1, -0.05, 0.02, 1.1, 0.1, -0.05, -0.1, 0.9, 0.1, 0.6, 0.4, 1.1,
        ];

        let kabsch = align_coordinates(&perturbed, &reference);
        let quat = align_quaternion(&perturbed, &reference);

        assert!(
            (kabsch.rmsd - quat.rmsd).abs() < 1e-8,
            "Kabsch RMSD = {}, Quaternion RMSD = {}",
            kabsch.rmsd,
            quat.rmsd,
        );

        // Aligned coords should also match.
        for i in 0..reference.len() {
            assert!(
                (kabsch.aligned_coords[i] - quat.aligned_coords[i]).abs() < 1e-6,
                "aligned mismatch at {}: {:.8} vs {:.8}",
                i,
                kabsch.aligned_coords[i],
                quat.aligned_coords[i],
            );
        }
    }
}