1use super::Alignment;
7use rustalign_common::{Nuc, Result as BwtResult, Score, Strand};
8
/// Scoring and banding parameters consumed by [`SwAligner`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct SwParams {
    /// Score added when a read base equals the reference base.
    pub match_score: Score,

    /// Score added when a read base differs from the reference base.
    pub mismatch_penalty: Score,

    /// Score added, together with `gap_extend`, when a gap is opened from a cell.
    pub gap_open: Score,

    /// Score added for every gap position, including the first one.
    pub gap_extend: Score,

    /// Band width; `align` takes the banded path only when this is non-zero and
    /// smaller than the read length.
    pub band_width: usize,

    /// Minimum score threshold.
    /// NOTE(review): not read anywhere in the visible code — presumably applied by
    /// callers; confirm.
    pub min_score: Score,
}
30
31impl Default for SwParams {
32 fn default() -> Self {
33 Self {
34 match_score: 2,
35 mismatch_penalty: -3,
36 gap_open: -5,
37 gap_extend: -2,
38 band_width: 15,
39 min_score: 30,
40 }
41 }
42}
43
/// One cell of the affine-gap dynamic-programming matrix.
#[derive(Debug, Clone, Copy, Default)]
struct DpCell {
    /// Best local score of any path ending in this cell (never below zero once computed).
    h: Score,
    /// Best score of a path entering this cell through a gap from the previous read row.
    e: Score,
    /// Best score of a path entering this cell through a gap from the previous reference column.
    f: Score,
    /// Never written or read by the visible code; kept at its default value.
    #[allow(dead_code)]
    best: Score,
}
57
/// Smith-Waterman style local aligner with affine gap scoring.
pub struct SwAligner {
    /// Scoring and banding configuration used for every alignment.
    params: SwParams,
}
63
64impl SwAligner {
65 pub fn new(params: SwParams) -> Self {
67 Self { params }
68 }
69
70 pub fn align(&self, read: &[Nuc], reference: &[Nuc]) -> BwtResult<Alignment> {
75 let read_len = read.len();
76 let ref_len = reference.len();
77
78 if read_len == 0 || ref_len == 0 {
79 return Ok(Alignment {
80 ref_start: 0,
81 ref_end: 0,
82 score: 0,
83 strand: Strand::Forward,
84 edits: 0,
85 ..Default::default()
86 });
87 }
88
89 if self.params.band_width > 0 && self.params.band_width < read_len {
91 return self.align_banded(read, reference);
92 }
93
94 self.align_full(read, reference)
96 }
97
98 fn align_full(&self, read: &[Nuc], reference: &[Nuc]) -> BwtResult<Alignment> {
100 let read_len = read.len();
101 let ref_len = reference.len();
102
103 let mut dp_matrix = vec![vec![DpCell::default(); ref_len + 1]; read_len + 1];
104 let mut best_score = Score::MIN;
105 let mut best_pos = (0, 0);
106
107 #[allow(clippy::needless_range_loop)]
109 for i in 0..=read_len {
110 dp_matrix[i][0].h = 0;
111 dp_matrix[i][0].e = Score::MIN / 2;
112 dp_matrix[i][0].f = Score::MIN / 2;
113 }
114 #[allow(clippy::needless_range_loop)]
115 for j in 0..=ref_len {
116 dp_matrix[0][j].h = 0;
117 dp_matrix[0][j].e = Score::MIN / 2;
118 dp_matrix[0][j].f = Score::MIN / 2;
119 }
120
121 for i in 1..=read_len {
123 for j in 1..=ref_len {
124 let match_score = if read[i - 1] == reference[j - 1] {
126 self.params.match_score
127 } else {
128 self.params.mismatch_penalty
129 };
130 let h_diag = dp_matrix[i - 1][j - 1].h + match_score;
131
132 let e_gap = dp_matrix[i - 1][j].e + self.params.gap_extend;
134 let h_gap = dp_matrix[i - 1][j].h + self.params.gap_open + self.params.gap_extend;
135 let e_new = e_gap.max(h_gap).max(Score::MIN / 2);
136
137 let f_gap = dp_matrix[i][j - 1].f + self.params.gap_extend;
139 let h_prev = dp_matrix[i][j - 1].h;
140 let f_new = f_gap
141 .max(h_prev + self.params.gap_open + self.params.gap_extend)
142 .max(Score::MIN / 2);
143
144 let h_new = h_diag.max(e_new).max(f_new).max(0);
146
147 dp_matrix[i][j].h = h_new;
148 dp_matrix[i][j].e = e_new;
149 dp_matrix[i][j].f = f_new;
150
151 if h_new > best_score {
153 best_score = h_new;
154 best_pos = (i, j);
155 }
156 }
157 }
158
159 let (edits, ref_start) = self.traceback_simple(read, reference, best_pos);
161
162 Ok(Alignment {
163 ref_start,
164 ref_end: best_pos.1,
165 score: best_score,
166 strand: Strand::Forward,
167 edits,
168 ..Default::default()
169 })
170 }
171
172 #[allow(clippy::manual_div_ceil)]
177 fn align_banded(&self, read: &[Nuc], reference: &[Nuc]) -> BwtResult<Alignment> {
178 let read_len = read.len();
179 let ref_len = reference.len();
180 let band_half = (self.params.band_width + 1) / 2; let mut best_score = Score::MIN;
185 let mut best_pos = (0, 0);
186
187 let band_size = self.params.band_width + 2;
190 let mut dp_band = vec![vec![DpCell::default(); band_size]; read_len + 1];
191
192 for i in 0..=read_len {
193 let col_start = i.saturating_sub(band_half);
195 let col_end = (i + band_half + 1).min(ref_len + 1);
196
197 for j in col_start..col_end {
198 let band_idx = j.saturating_sub(col_start);
199 if band_idx >= band_size {
200 continue;
201 }
202
203 let (h_diag, e_above, h_above, f_left, h_left, _nuc_match) = if i == 0 || j == 0 {
205 (Score::MIN, Score::MIN / 2, 0, Score::MIN / 2, 0, false)
206 } else {
207 let prev_col_start = (i - 1).saturating_sub(band_half);
209 let diag_band_idx = (j - 1).saturating_sub(prev_col_start);
210 let h_diag = if diag_band_idx < band_size && i > 0 && j > 0 {
211 let match_score = if read[i - 1] == reference[j - 1] {
212 self.params.match_score
213 } else {
214 self.params.mismatch_penalty
215 };
216 dp_band[i - 1][diag_band_idx].h + match_score
217 } else {
218 Score::MIN
219 };
220
221 let above_band_idx = j.saturating_sub(col_start);
223 let (e_above, h_above) = if i > 0 && above_band_idx < band_size {
224 (
225 dp_band[i - 1][above_band_idx].e,
226 dp_band[i - 1][above_band_idx].h,
227 )
228 } else {
229 (Score::MIN / 2, 0)
230 };
231
232 let left_band_idx = (j - 1).saturating_sub(col_start);
234 let (f_left, h_left) = if left_band_idx < band_size {
235 (dp_band[i][left_band_idx].f, dp_band[i][left_band_idx].h)
236 } else {
237 (Score::MIN / 2, 0)
238 };
239
240 let nuc_match = read[i - 1] == reference[j - 1];
241 (h_diag, e_above, h_above, f_left, h_left, nuc_match)
242 };
243
244 let cell = &mut dp_band[i][band_idx];
245
246 if i == 0 || j == 0 {
247 cell.h = 0;
249 cell.e = Score::MIN / 2;
250 cell.f = Score::MIN / 2;
251 } else {
252 let e_gap = e_above + self.params.gap_extend;
254 let h_gap = h_above + self.params.gap_open + self.params.gap_extend;
255 cell.e = e_gap.max(h_gap).max(Score::MIN / 2);
256
257 let f_gap = f_left + self.params.gap_extend;
259 let h_prev = h_left;
260 cell.f = f_gap
261 .max(h_prev + self.params.gap_open + self.params.gap_extend)
262 .max(Score::MIN / 2);
263
264 cell.h = h_diag.max(cell.e).max(cell.f).max(0);
267 }
268
269 if cell.h > best_score {
271 best_score = cell.h;
272 best_pos = (i, j);
273 }
274 }
275 }
276
277 let matches = (best_score as usize).div_ceil(self.params.match_score.max(1) as usize);
279 let edits = best_pos.0.saturating_sub(matches);
280
281 Ok(Alignment {
282 ref_start: best_pos.1.saturating_sub(best_pos.0.min(best_pos.1)),
283 ref_end: best_pos.1,
284 score: best_score,
285 strand: Strand::Forward,
286 edits,
287 ..Default::default()
288 })
289 }
290
291 fn traceback_simple(
293 &self,
294 read: &[Nuc],
295 reference: &[Nuc],
296 best_pos: (usize, usize),
297 ) -> (usize, usize) {
298 let (mut i, mut j) = best_pos;
299 let mut edits = 0;
300 let mut ref_start = j;
301
302 while i > 0 && j > 0 {
303 let _match_score = if read[i - 1] == reference[j - 1] {
304 self.params.match_score
305 } else {
306 self.params.mismatch_penalty
307 };
308
309 if read[i - 1] != reference[j - 1] {
311 edits += 1;
312 }
313
314 ref_start = j;
315 i -= 1;
316 j -= 1;
317 }
318
319 (edits, ref_start)
320 }
321
322 #[allow(dead_code)]
324 fn traceback_internal(
325 dp_matrix: &[DpCell],
326 read: &[Nuc],
327 reference: &[Nuc],
328 best_pos: (usize, usize),
329 band_half: usize,
330 params: &SwParams,
331 ) -> (usize, usize, usize) {
332 let (mut i, mut j) = best_pos;
333 let mut edits = 0;
334 let mut ref_start = j;
335
336 #[allow(clippy::implicit_saturating_sub)]
338 while i > 0 && j > 0 {
339 let row_start = if i > band_half { i - band_half } else { 0 };
340 let band_idx = j - row_start;
341
342 if band_idx >= dp_matrix.len() {
343 break;
344 }
345
346 let cell = dp_matrix[band_idx];
347
348 let prev_band_idx = if j > row_start && band_idx > 0 {
350 band_idx - 1
351 } else {
352 0
353 };
354
355 let match_score = if read[i - 1] == reference[j - 1] {
356 params.match_score
357 } else {
358 params.mismatch_penalty
359 };
360
361 let h_diag = if prev_band_idx < dp_matrix.len() {
362 dp_matrix[prev_band_idx].h + match_score
363 } else {
364 Score::MIN
365 };
366
367 if cell.h == h_diag && cell.h > 0 {
368 if read[i - 1] != reference[j - 1] {
370 edits += 1;
371 }
372 i -= 1;
373 j -= 1;
374 ref_start = j;
375 } else if cell.h == cell.e && cell.h > 0 {
376 edits += 1;
378 i -= 1;
379 } else if cell.h == cell.f && cell.h > 0 {
380 edits += 1;
382 j -= 1;
383 ref_start = j;
384 } else {
385 break;
387 }
388 }
389
390 (edits, ref_start, best_pos.1)
391 }
392}
393
394impl Default for SwAligner {
395 fn default() -> Self {
396 Self::new(SwParams::default())
397 }
398}
399
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the cyclic `A, C, G, T, A, C, …` sequence of the requested length.
    fn cyclic_seq(len: usize) -> Vec<Nuc> {
        (0..len)
            .map(|i| match i % 4 {
                0 => Nuc::A,
                1 => Nuc::C,
                2 => Nuc::G,
                _ => Nuc::T,
            })
            .collect()
    }

    /// Default scoring with the given band width and no minimum-score threshold.
    fn params_with_band(band_width: usize) -> SwParams {
        SwParams {
            band_width,
            min_score: 0,
            ..SwParams::default()
        }
    }

    #[test]
    fn test_params_default() {
        let params = SwParams::default();
        assert_eq!(params.match_score, 2);
        assert_eq!(params.mismatch_penalty, -3);
        assert_eq!(params.band_width, 15);
    }

    #[test]
    fn test_aligner_new() {
        let aligner = SwAligner::new(SwParams::default());
        assert_eq!(aligner.params.match_score, 2);
    }

    #[test]
    fn test_perfect_match() {
        let read = vec![Nuc::A, Nuc::C, Nuc::G, Nuc::T];
        let reference = vec![Nuc::A, Nuc::C, Nuc::G, Nuc::T];
        let aligner = SwAligner::new(params_with_band(10));

        let aln = aligner.align(&read, &reference).unwrap();
        // Four matches at +2 each.
        assert!(aln.score >= 8);
        assert_eq!(aln.edits, 0);
    }

    #[test]
    fn test_one_mismatch() {
        let read = vec![Nuc::A, Nuc::C, Nuc::G, Nuc::T];
        // Third base differs from the read (G vs A).
        let reference = vec![Nuc::A, Nuc::C, Nuc::A, Nuc::T];
        let aligner = SwAligner::new(params_with_band(10));

        let aln = aligner.align(&read, &reference).unwrap();
        assert!(aln.score >= 3);
    }

    #[test]
    fn test_empty_sequence() {
        let read: Vec<Nuc> = Vec::new();
        let reference = vec![Nuc::A, Nuc::C, Nuc::G, Nuc::T];
        let aligner = SwAligner::new(SwParams::default());

        let aln = aligner.align(&read, &reference).unwrap();
        assert_eq!(aln.score, 0);
        assert_eq!(aln.edits, 0);
    }

    #[test]
    fn test_banded_alignment() {
        let read = cyclic_seq(50);
        let mut reference = read.clone();
        // Index 10 holds G in the cyclic pattern, so this is a real substitution.
        // Indices 20 and 30 already hold A and G, so writing those bases there (as
        // this test used to) changed nothing and has been dropped.
        reference[10] = Nuc::T;
        assert!(read[10] != reference[10]);

        // Band (15) is narrower than the read (50), so the banded path is taken.
        let aligner = SwAligner::new(params_with_band(15));
        let aln = aligner.align(&read, &reference).unwrap();

        assert!(aln.score > 80);
        assert!(aln.edits <= 5);
    }

    #[test]
    fn test_banded_vs_full_consistency() {
        let read = cyclic_seq(30);
        let mut reference = read.clone();
        // Index 15 holds T in the cyclic pattern; A makes it a genuine mismatch so
        // both paths have to score through a substitution.
        reference[15] = Nuc::A;
        assert!(read[15] != reference[15]);

        let aln_banded = SwAligner::new(params_with_band(10))
            .align(&read, &reference)
            .unwrap();

        // A band width of 0 disables banding and forces the full matrix.
        let aln_full = SwAligner::new(params_with_band(0))
            .align(&read, &reference)
            .unwrap();

        assert_eq!(aln_banded.score, aln_full.score);
    }

    #[test]
    fn test_banded_narrow_band() {
        let read = cyclic_seq(100);
        let reference = read.clone();

        let aligner = SwAligner::new(params_with_band(5));
        let aln = aligner.align(&read, &reference).unwrap();

        // 100 matches at +2 each along the main diagonal.
        assert_eq!(aln.score, 200);
    }

    #[test]
    fn test_banded_disabled() {
        let read = cyclic_seq(30);
        let reference = read.clone();

        let aligner = SwAligner::new(params_with_band(0));
        let aln = aligner.align(&read, &reference).unwrap();

        // 30 matches at +2 each.
        assert_eq!(aln.score, 60);
    }
}