Skip to main content

rustalign_aligner/
smith_waterman.rs

//! Smith-Waterman dynamic programming alignment
//!
//! This implements Smith-Waterman local alignment with affine gap penalties,
//! as a full DP and as a banded variant, matching the C++ aligner_sw.h implementation.
5
6use super::Alignment;
7use rustalign_common::{Nuc, Result as BwtResult, Score, Strand};
8
9/// Parameters for Smith-Waterman alignment
10#[derive(Debug, Clone, Copy, PartialEq, Eq)]
11pub struct SwParams {
12    /// Match score
13    pub match_score: Score,
14
15    /// Mismatch penalty
16    pub mismatch_penalty: Score,
17
18    /// Gap open penalty
19    pub gap_open: Score,
20
21    /// Gap extend penalty
22    pub gap_extend: Score,
23
24    /// Band width for banded alignment (0 = no banding)
25    pub band_width: usize,
26
27    /// Minimum score threshold
28    pub min_score: Score,
29}
30
31impl Default for SwParams {
32    fn default() -> Self {
33        Self {
34            match_score: 2,
35            mismatch_penalty: -3,
36            gap_open: -5,
37            gap_extend: -2,
38            band_width: 15,
39            min_score: 30,
40        }
41    }
42}
43
44/// DP cell for Smith-Waterman alignment
45#[derive(Debug, Clone, Copy, Default)]
46struct DpCell {
47    /// Score from diagonal (match/mismatch)
48    h: Score,
49    /// Score from above (insertion)
50    e: Score,
51    /// Score from left (deletion)
52    f: Score,
53    /// Best score at this cell
54    #[allow(dead_code)]
55    best: Score,
56}
57
58/// Smith-Waterman aligner with banded DP
59pub struct SwAligner {
60    /// Alignment parameters
61    params: SwParams,
62}
63
64impl SwAligner {
65    /// Create a new Smith-Waterman aligner
66    pub fn new(params: SwParams) -> Self {
67        Self { params }
68    }
69
70    /// Align a read to a reference sequence
71    ///
72    /// Uses Smith-Waterman DP with optional banding for optimization.
73    /// When band_width > 0, only computes cells within band_width/2 of the main diagonal.
74    pub fn align(&self, read: &[Nuc], reference: &[Nuc]) -> BwtResult<Alignment> {
75        let read_len = read.len();
76        let ref_len = reference.len();
77
78        if read_len == 0 || ref_len == 0 {
79            return Ok(Alignment {
80                ref_start: 0,
81                ref_end: 0,
82                score: 0,
83                strand: Strand::Forward,
84                edits: 0,
85                ..Default::default()
86            });
87        }
88
89        // Use banded alignment if band_width is set
90        if self.params.band_width > 0 && self.params.band_width < read_len {
91            return self.align_banded(read, reference);
92        }
93
94        // Use full DP matrix for correctness or small sequences
95        self.align_full(read, reference)
96    }
97
98    /// Full (non-banded) Smith-Waterman alignment
99    fn align_full(&self, read: &[Nuc], reference: &[Nuc]) -> BwtResult<Alignment> {
100        let read_len = read.len();
101        let ref_len = reference.len();
102
103        let mut dp_matrix = vec![vec![DpCell::default(); ref_len + 1]; read_len + 1];
104        let mut best_score = Score::MIN;
105        let mut best_pos = (0, 0);
106
107        // Initialize first row and column
108        #[allow(clippy::needless_range_loop)]
109        for i in 0..=read_len {
110            dp_matrix[i][0].h = 0;
111            dp_matrix[i][0].e = Score::MIN / 2;
112            dp_matrix[i][0].f = Score::MIN / 2;
113        }
114        #[allow(clippy::needless_range_loop)]
115        for j in 0..=ref_len {
116            dp_matrix[0][j].h = 0;
117            dp_matrix[0][j].e = Score::MIN / 2;
118            dp_matrix[0][j].f = Score::MIN / 2;
119        }
120
121        // Fill DP matrix
122        for i in 1..=read_len {
123            for j in 1..=ref_len {
124                // Calculate H score (diagonal)
125                let match_score = if read[i - 1] == reference[j - 1] {
126                    self.params.match_score
127                } else {
128                    self.params.mismatch_penalty
129                };
130                let h_diag = dp_matrix[i - 1][j - 1].h + match_score;
131
132                // Calculate E score (from above/insertion)
133                let e_gap = dp_matrix[i - 1][j].e + self.params.gap_extend;
134                let h_gap = dp_matrix[i - 1][j].h + self.params.gap_open + self.params.gap_extend;
135                let e_new = e_gap.max(h_gap).max(Score::MIN / 2);
136
137                // Calculate F score (from left/deletion)
138                let f_gap = dp_matrix[i][j - 1].f + self.params.gap_extend;
139                let h_prev = dp_matrix[i][j - 1].h;
140                let f_new = f_gap
141                    .max(h_prev + self.params.gap_open + self.params.gap_extend)
142                    .max(Score::MIN / 2);
143
144                // H cell: max of diagonal, E, F, and 0 (for local alignment)
145                let h_new = h_diag.max(e_new).max(f_new).max(0);
146
147                dp_matrix[i][j].h = h_new;
148                dp_matrix[i][j].e = e_new;
149                dp_matrix[i][j].f = f_new;
150
151                // Track best score
152                if h_new > best_score {
153                    best_score = h_new;
154                    best_pos = (i, j);
155                }
156            }
157        }
158
159        // Count edits from traceback
160        let (edits, ref_start) = self.traceback_simple(read, reference, best_pos);
161
162        Ok(Alignment {
163            ref_start,
164            ref_end: best_pos.1,
165            score: best_score,
166            strand: Strand::Forward,
167            edits,
168            ..Default::default()
169        })
170    }
171
172    /// Banded Smith-Waterman alignment
173    ///
174    /// Only computes cells within band_width/2 of the main diagonal.
175    /// This reduces complexity from O(m*n) to O(band_width * min(m,n)).
176    #[allow(clippy::manual_div_ceil)]
177    fn align_banded(&self, read: &[Nuc], reference: &[Nuc]) -> BwtResult<Alignment> {
178        let read_len = read.len();
179        let ref_len = reference.len();
180        let band_half = (self.params.band_width + 1) / 2; // Round up
181
182        // Banded DP matrix: only store cells within the band
183        // For each row i, we store columns from max(0, i-band_half) to min(ref_len, i+band_half)
184        let mut best_score = Score::MIN;
185        let mut best_pos = (0, 0);
186
187        // Use a 2D array for the banded region to simplify diagonal access
188        // Band size needs to account for the offset
189        let band_size = self.params.band_width + 2;
190        let mut dp_band = vec![vec![DpCell::default(); band_size]; read_len + 1];
191
192        for i in 0..=read_len {
193            // Determine column range for this row
194            let col_start = i.saturating_sub(band_half);
195            let col_end = (i + band_half + 1).min(ref_len + 1);
196
197            for j in col_start..col_end {
198                let band_idx = j.saturating_sub(col_start);
199                if band_idx >= band_size {
200                    continue;
201                }
202
203                // Calculate all needed scores first to avoid borrow conflicts
204                let (h_diag, e_above, h_above, f_left, h_left, _nuc_match) = if i == 0 || j == 0 {
205                    (Score::MIN, Score::MIN / 2, 0, Score::MIN / 2, 0, false)
206                } else {
207                    // Calculate diagonal score
208                    let prev_col_start = (i - 1).saturating_sub(band_half);
209                    let diag_band_idx = (j - 1).saturating_sub(prev_col_start);
210                    let h_diag = if diag_band_idx < band_size && i > 0 && j > 0 {
211                        let match_score = if read[i - 1] == reference[j - 1] {
212                            self.params.match_score
213                        } else {
214                            self.params.mismatch_penalty
215                        };
216                        dp_band[i - 1][diag_band_idx].h + match_score
217                    } else {
218                        Score::MIN
219                    };
220
221                    // Calculate score from above
222                    let above_band_idx = j.saturating_sub(col_start);
223                    let (e_above, h_above) = if i > 0 && above_band_idx < band_size {
224                        (
225                            dp_band[i - 1][above_band_idx].e,
226                            dp_band[i - 1][above_band_idx].h,
227                        )
228                    } else {
229                        (Score::MIN / 2, 0)
230                    };
231
232                    // Calculate score from left
233                    let left_band_idx = (j - 1).saturating_sub(col_start);
234                    let (f_left, h_left) = if left_band_idx < band_size {
235                        (dp_band[i][left_band_idx].f, dp_band[i][left_band_idx].h)
236                    } else {
237                        (Score::MIN / 2, 0)
238                    };
239
240                    let nuc_match = read[i - 1] == reference[j - 1];
241                    (h_diag, e_above, h_above, f_left, h_left, nuc_match)
242                };
243
244                let cell = &mut dp_band[i][band_idx];
245
246                if i == 0 || j == 0 {
247                    // Boundary conditions
248                    cell.h = 0;
249                    cell.e = Score::MIN / 2;
250                    cell.f = Score::MIN / 2;
251                } else {
252                    // Calculate E score (from above/insertion)
253                    let e_gap = e_above + self.params.gap_extend;
254                    let h_gap = h_above + self.params.gap_open + self.params.gap_extend;
255                    cell.e = e_gap.max(h_gap).max(Score::MIN / 2);
256
257                    // Calculate F score (from left/deletion)
258                    let f_gap = f_left + self.params.gap_extend;
259                    let h_prev = h_left;
260                    cell.f = f_gap
261                        .max(h_prev + self.params.gap_open + self.params.gap_extend)
262                        .max(Score::MIN / 2);
263
264                    // H cell: max of diagonal, E, F, and 0
265                    // We already computed h_diag earlier, use it directly
266                    cell.h = h_diag.max(cell.e).max(cell.f).max(0);
267                }
268
269                // Track best score
270                if cell.h > best_score {
271                    best_score = cell.h;
272                    best_pos = (i, j);
273                }
274            }
275        }
276
277        // For banded alignment, estimate edits based on score
278        let matches = (best_score as usize).div_ceil(self.params.match_score.max(1) as usize);
279        let edits = best_pos.0.saturating_sub(matches);
280
281        Ok(Alignment {
282            ref_start: best_pos.1.saturating_sub(best_pos.0.min(best_pos.1)),
283            ref_end: best_pos.1,
284            score: best_score,
285            strand: Strand::Forward,
286            edits,
287            ..Default::default()
288        })
289    }
290
291    /// Simple traceback to count edits
292    fn traceback_simple(
293        &self,
294        read: &[Nuc],
295        reference: &[Nuc],
296        best_pos: (usize, usize),
297    ) -> (usize, usize) {
298        let (mut i, mut j) = best_pos;
299        let mut edits = 0;
300        let mut ref_start = j;
301
302        while i > 0 && j > 0 {
303            let _match_score = if read[i - 1] == reference[j - 1] {
304                self.params.match_score
305            } else {
306                self.params.mismatch_penalty
307            };
308
309            // For simplicity, just count mismatches along diagonal
310            if read[i - 1] != reference[j - 1] {
311                edits += 1;
312            }
313
314            ref_start = j;
315            i -= 1;
316            j -= 1;
317        }
318
319        (edits, ref_start)
320    }
321
322    /// Traceback from best position to reconstruct alignment
323    #[allow(dead_code)]
324    fn traceback_internal(
325        dp_matrix: &[DpCell],
326        read: &[Nuc],
327        reference: &[Nuc],
328        best_pos: (usize, usize),
329        band_half: usize,
330        params: &SwParams,
331    ) -> (usize, usize, usize) {
332        let (mut i, mut j) = best_pos;
333        let mut edits = 0;
334        let mut ref_start = j;
335
336        // Simple traceback - count edits
337        #[allow(clippy::implicit_saturating_sub)]
338        while i > 0 && j > 0 {
339            let row_start = if i > band_half { i - band_half } else { 0 };
340            let band_idx = j - row_start;
341
342            if band_idx >= dp_matrix.len() {
343                break;
344            }
345
346            let cell = dp_matrix[band_idx];
347
348            // Check if we came from diagonal
349            let prev_band_idx = if j > row_start && band_idx > 0 {
350                band_idx - 1
351            } else {
352                0
353            };
354
355            let match_score = if read[i - 1] == reference[j - 1] {
356                params.match_score
357            } else {
358                params.mismatch_penalty
359            };
360
361            let h_diag = if prev_band_idx < dp_matrix.len() {
362                dp_matrix[prev_band_idx].h + match_score
363            } else {
364                Score::MIN
365            };
366
367            if cell.h == h_diag && cell.h > 0 {
368                // Came from diagonal
369                if read[i - 1] != reference[j - 1] {
370                    edits += 1;
371                }
372                i -= 1;
373                j -= 1;
374                ref_start = j;
375            } else if cell.h == cell.e && cell.h > 0 {
376                // Came from above (insertion in reference)
377                edits += 1;
378                i -= 1;
379            } else if cell.h == cell.f && cell.h > 0 {
380                // Came from left (deletion in reference)
381                edits += 1;
382                j -= 1;
383                ref_start = j;
384            } else {
385                // Reached boundary
386                break;
387            }
388        }
389
390        (edits, ref_start, best_pos.1)
391    }
392}
393
394impl Default for SwAligner {
395    fn default() -> Self {
396        Self::new(SwParams::default())
397    }
398}
399
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds `len` nucleotides cycling through A, C, G, T, so index `i`
    /// holds the `i % 4`-th of those.
    fn cyclic_nucs(len: usize) -> Vec<Nuc> {
        (0..len)
            .map(|i| match i % 4 {
                0 => Nuc::A,
                1 => Nuc::C,
                2 => Nuc::G,
                _ => Nuc::T,
            })
            .collect()
    }

    /// Aligner with the given band width, `min_score` 0 and default scoring.
    fn aligner_with_band(band_width: usize) -> SwAligner {
        SwAligner::new(SwParams {
            band_width,
            min_score: 0,
            ..SwParams::default()
        })
    }

    #[test]
    fn test_params_default() {
        let params = SwParams::default();
        assert_eq!(params.match_score, 2);
        assert_eq!(params.mismatch_penalty, -3);
        assert_eq!(params.band_width, 15);
    }

    #[test]
    fn test_aligner_new() {
        let aligner = SwAligner::new(SwParams::default());
        assert_eq!(aligner.params.match_score, 2);
    }

    #[test]
    fn test_perfect_match() {
        let read = vec![Nuc::A, Nuc::C, Nuc::G, Nuc::T];
        let reference = vec![Nuc::A, Nuc::C, Nuc::G, Nuc::T];
        // The band (10) is not narrower than the read (4), so `align` takes
        // the full-DP path here.
        let aligner = aligner_with_band(10);

        let aln = aligner.align(&read, &reference).unwrap();
        // Four matches at +2 each.
        assert!(aln.score >= 8);
        assert_eq!(aln.edits, 0);
    }

    #[test]
    fn test_one_mismatch() {
        let read = vec![Nuc::A, Nuc::C, Nuc::G, Nuc::T];
        let reference = vec![Nuc::A, Nuc::C, Nuc::A, Nuc::T]; // G->A mismatch
        let aligner = aligner_with_band(10);

        let aln = aligner.align(&read, &reference).unwrap();
        // The best local alignment is the leading "AC" (score 4); extending
        // through the mismatch to the final T only reaches 4 - 3 + 2 = 3.
        assert!(aln.score >= 3);
    }

    #[test]
    fn test_empty_sequence() {
        let read: Vec<Nuc> = Vec::new();
        let reference = vec![Nuc::A, Nuc::C, Nuc::G, Nuc::T];
        let aligner = SwAligner::new(SwParams::default());

        let aln = aligner.align(&read, &reference).unwrap();
        assert_eq!(aln.score, 0);
        assert_eq!(aln.edits, 0);
    }

    #[test]
    fn test_banded_alignment() {
        let read = cyclic_nucs(50);
        let mut reference = read.clone();
        // Only index 10 is an actual substitution (G -> T). Indices 20 and 30
        // already hold A and G in the cyclic pattern, so those two
        // assignments leave the reference unchanged.
        // NOTE(review): the test was probably meant to introduce three
        // mutations; changing the values would also change the expected
        // score and edit estimate.
        reference[10] = Nuc::T;
        reference[20] = Nuc::A;
        reference[30] = Nuc::G;

        // Band (15) is narrower than the read (50), so banding is active.
        let aligner = aligner_with_band(15);
        let aln = aligner.align(&read, &reference).unwrap();

        // Should find good alignment despite the mutation
        assert!(aln.score > 80);
        assert!(aln.edits <= 5);
    }

    #[test]
    fn test_banded_vs_full_consistency() {
        let read = cyclic_nucs(30);
        let mut reference = read.clone();
        // Index 15 already holds T (15 % 4 == 3), so read and reference stay
        // identical and both paths are compared on a perfect match.
        reference[15] = Nuc::T;

        // Banded alignment
        let aln_banded = aligner_with_band(10).align(&read, &reference).unwrap();

        // Full alignment (band_width 0 disables banding)
        let aln_full = aligner_with_band(0).align(&read, &reference).unwrap();

        // Both paths must agree on the score
        assert_eq!(aln_banded.score, aln_full.score);
    }

    #[test]
    fn test_banded_narrow_band() {
        let read = cyclic_nucs(100);
        let reference = read.clone();

        // Very narrow band - should still work for perfect match
        let aligner = aligner_with_band(5);
        let aln = aligner.align(&read, &reference).unwrap();

        // Perfect match should get max score
        assert_eq!(aln.score, 200); // 100 * 2
    }

    #[test]
    fn test_banded_disabled() {
        let read = cyclic_nucs(30);
        let reference = read.clone();

        // band_width = 0 means no banding
        let aligner = aligner_with_band(0);
        let aln = aligner.align(&read, &reference).unwrap();

        assert_eq!(aln.score, 60); // 30 * 2
    }
}