1use super::alignment_lib::*;
3
4pub fn wavefront_align(
7 query: &str,
8 text: &str,
9 pens: &Penalties,
10) -> Result<Alignment, AlignmentError> {
11 if query.is_empty() || text.is_empty() {
12 return Err(AlignmentError::ZeroLength(format!(
13 "At least one of the string slices passed to wavefront_align had a length of zero.
14 Length of query: {}
15 Length of text: {}",
16 query.len(),
17 text.len()
18 )));
19 }
20 if query.len() > text.len() {
21 return Err(
22 AlignmentError::QueryTooLong(
23 "Query is longer than the reference string.
24 The length of the first string must be <= to the the length of the second string".to_string()
25 )
26 );
27 }
28 let mut current_front = new_wavefront_state(query, text, pens);
29 loop {
30 current_front.extend();
31 if current_front.is_finished() {
32 break;
33 }
34 current_front.increment_score();
35 current_front.next();
36 }
37 current_front.backtrace()
38}
39
/// Mutable state of one gap-affine wavefront alignment run.
///
/// Diagonal `k` contains the cells where `query_pos = text_pos + k`. The
/// offsets stored in `grid` are text positions (this is how `extend` turns an
/// offset back into a pair of string positions).
#[derive(Debug, PartialEq, Eq)]
struct WavefrontState<'a> {
    /// The query string as passed to the constructor.
    query: &'a str,
    /// The reference (text) string as passed to the constructor.
    text: &'a str,
    /// Mismatch, gap-open and gap-extend penalties.
    pens: &'a Penalties,
    /// `query` split into `char`s; alignment positions index this vector.
    q_chars: Vec<char>,
    /// `text` split into `char`s; alignment positions index this vector.
    t_chars: Vec<char>,

    /// Score of the wavefront currently being processed; `increment_score`
    /// raises it by one per step.
    current_score: u32,

    /// Wavefront cells addressed by (layer, score, diagonal). Each cell holds
    /// an offset and the layer it was derived from.
    grid: WavefrontGrid,

    /// Number of diagonals, `len(query) + len(text) + 1` (in `char`s).
    num_diags: i32,

    /// Diagonal on which a complete alignment ends, `len(query) - len(text)`.
    final_diagonal: i32,

    /// Largest diagonal a wavefront may cover, `len(query)`.
    highest_diag: i32,
    /// Smallest diagonal a wavefront may cover, `-len(text)`.
    lowest_diag: i32,
}
66
67fn new_wavefront_state<'a>(
70 query: &'a str,
71 text: &'a str,
72 pens: &'a Penalties,
73) -> WavefrontState<'a> {
74 let q_chars: Vec<char> = query.chars().collect();
75 let t_chars: Vec<char> = text.chars().collect();
76
77 let final_diagonal = (q_chars.len() as i32) - (t_chars.len() as i32); let num_diags = (q_chars.len() + t_chars.len() + 1) as i32;
79 let highest_diag = q_chars.len() as i32;
80 let lowest_diag = 0 - t_chars.len() as i32;
81
82 let mut matches = vec![vec![None; num_diags as usize]; 1];
83 matches[0][(0 - lowest_diag) as usize] = Some((0, AlignmentLayer::Matches)); let grid = new_wavefront_grid();
86
87 WavefrontState {
88 query,
89 text,
90 pens,
91 q_chars,
92 t_chars,
93 current_score: 0,
94 num_diags,
95 final_diagonal,
96 highest_diag,
97 lowest_diag,
98 grid,
99 }
100}
101
102impl Wavefront for WavefrontState<'_> {
103 fn extend(&mut self) {
104 let diag_range = self
107 .grid
108 .get_diag_range(self.current_score)
109 .expect("get_diag_range returned None at wavefront_extend");
110
111 for diag in (diag_range.0)..=(diag_range.1) {
112 let text_pos = match self
113 .grid
114 .get(AlignmentLayer::Matches, self.current_score, diag)
115 {
116 Some((val, _)) => val,
117 _ => continue,
118 };
119 let mut query_pos = (text_pos as i32 + diag) as usize;
120 let mut text_pos = text_pos as usize;
121 while query_pos < self.q_chars.len() && text_pos < self.t_chars.len() {
128 match (
129 self.q_chars.get(query_pos as usize),
130 self.t_chars.get(text_pos as usize),
131 ) {
132 (Some(q), Some(t)) => {
133 if q == t {
134 self.grid.increment(self.current_score, diag);
135 query_pos += 1;
136 text_pos += 1;
137 } else {
138 break;
139 }
140 }
141 _ => break,
142 }
143 }
144 }
145 }
146
147 fn increment_score(&mut self) {
148 self.current_score += 1;
150 }
151
152 fn is_finished(&self) -> bool {
153 match self.grid.get(
157 AlignmentLayer::Matches,
158 self.current_score,
159 self.final_diagonal,
160 ) {
161 Some((score, _)) => score as usize >= self.t_chars.len(),
162 _ => false,
163 }
164 }
165
166 fn next(&mut self) {
167 let mut hi = 1 + vec![
171 self.current_score.checked_sub(self.pens.mismatch_pen),
172 self.current_score.checked_sub(self.pens.open_pen + self.pens.extd_pen),
173 self.current_score.checked_sub(self.pens.extd_pen),
174 ]
175 .into_iter()
176 .filter(|x| x.is_some())
177 .map(|x| x.unwrap())
178 .map(|x| self.grid.get_diag_range(x).unwrap().1)
179 .max()
180 .unwrap_or(-1);
181
182 if hi > self.highest_diag {
183 hi = self.highest_diag;
184 }
185
186 let mut lo = vec![
187 self.current_score.checked_sub(self.pens.mismatch_pen),
188 self.current_score.checked_sub(self.pens.open_pen + self.pens.extd_pen),
189 self.current_score.checked_sub(self.pens.extd_pen),
190 ]
191 .into_iter()
192 .filter(|x| x.is_some())
193 .map(|x| x.unwrap())
194 .map(|x| self.grid.get_diag_range(x).unwrap().0)
195 .min()
196 .unwrap_or(1)
197 - 1;
198
199 if lo < self.lowest_diag {
200 lo = self.lowest_diag;
201 }
202
203 self.grid.add_layer(lo, hi);
204
205 for diag in lo..=hi {
206 self.update_ins(diag);
207 self.update_del(diag);
208 self.update_mat(diag);
209 }
210 }
211
212 fn backtrace(&self) -> Result<Alignment, AlignmentError> {
213 let mut curr_score = self.current_score;
214 let mut curr_diag = self.final_diagonal;
215 let mut curr_layer = AlignmentLayer::Matches;
216
217 let mut query_aligned = String::new();
218 let mut text_aligned = String::new();
219
220 while curr_score > 0 {
221 match &mut curr_layer {
222 AlignmentLayer::Matches => {
224 match self
225 .grid
226 .get(AlignmentLayer::Matches, curr_score, curr_diag)
227 {
228 Some((score, AlignmentLayer::Inserts)) => {
229 curr_layer = AlignmentLayer::Inserts;
230 let mut current_char = score;
231 while current_char
232 > self
233 .grid
234 .get(AlignmentLayer::Inserts, curr_score, curr_diag)
235 .unwrap()
236 .0
237 {
238 query_aligned
239 .push(self.q_chars[(current_char as i32 + curr_diag - 1) as usize]);
240 text_aligned.push(self.t_chars[(current_char - 1) as usize]);
241 current_char -= 1;
242 }
243 }
244 Some((score, AlignmentLayer::Deletes)) => {
245 curr_layer = AlignmentLayer::Deletes;
246 let mut current_char = score;
247 while current_char
248 > self
249 .grid
250 .get(AlignmentLayer::Deletes, curr_score, curr_diag)
251 .unwrap()
252 .0
253 {
254 query_aligned
255 .push(self.q_chars[(current_char as i32 + curr_diag - 1) as usize]);
256 text_aligned.push(self.t_chars[(current_char - 1) as usize]);
257 current_char -= 1;
258 }
259 }
260 Some((score, AlignmentLayer::Matches)) => {
261 let mut current_char = score;
262 curr_score -= self.pens.mismatch_pen;
263 while current_char
264 > self
265 .grid
266 .get(AlignmentLayer::Matches, curr_score, curr_diag)
267 .unwrap()
268 .0
269 {
270 query_aligned
271 .push(self.q_chars[(current_char as i32 + curr_diag - 1) as usize]);
272 text_aligned.push(self.t_chars[(current_char - 1) as usize]);
273 current_char -= 1;
274 }
275 }
276 _ => panic!(),
277 };
278 }
279 AlignmentLayer::Inserts => {
281 match self
282 .grid
283 .get(AlignmentLayer::Inserts, curr_score, curr_diag)
284 {
285 Some((_, AlignmentLayer::Matches)) => {
286 let previous = self
287 .grid
288 .get(
289 AlignmentLayer::Matches,
290 curr_score - self.pens.extd_pen - self.pens.open_pen,
291 curr_diag - 1,
292 )
293 .unwrap();
294 query_aligned.push(self.q_chars[(previous.0 as i32 + curr_diag - 1) as usize]);
295 text_aligned.push('-');
296 curr_diag -= 1;
297 curr_score -= self.pens.extd_pen + self.pens.open_pen;
298 curr_layer = AlignmentLayer::Matches;
299 }
300 Some((_, AlignmentLayer::Inserts)) => {
301 let previous = self
302 .grid
303 .get(
304 AlignmentLayer::Inserts,
305 curr_score - self.pens.extd_pen,
306 curr_diag - 1,
307 )
308 .unwrap();
309 query_aligned.push(self.q_chars[(previous.0 as i32 + curr_diag - 1) as usize]);
310 text_aligned.push('-');
311 curr_diag -= 1;
312 curr_score -= self.pens.extd_pen;
313 }
314 _ => panic!(),
315 };
316 }
317 AlignmentLayer::Deletes => {
318 match self
319 .grid
320 .get(AlignmentLayer::Deletes, curr_score, curr_diag)
321 {
322 Some((_, AlignmentLayer::Matches)) => {
323 let previous = self
324 .grid
325 .get(
326 AlignmentLayer::Matches,
327 curr_score - self.pens.extd_pen - self.pens.open_pen,
328 curr_diag + 1,
329 )
330 .unwrap();
331 query_aligned.push('-');
332 text_aligned.push(self.t_chars[(previous.0) as usize]);
333 curr_diag += 1;
334 curr_score -= self.pens.extd_pen + self.pens.open_pen;
335 curr_layer = AlignmentLayer::Matches;
336 }
337
338 Some((_, AlignmentLayer::Deletes)) => {
339 let previous = self
340 .grid
341 .get(
342 AlignmentLayer::Deletes,
343 curr_score - self.pens.extd_pen,
344 curr_diag + 1,
345 )
346 .unwrap();
347 query_aligned.push('-');
348 text_aligned.push(self.t_chars[(previous.0) as usize]);
349 curr_diag += 1;
350 curr_score -= self.pens.extd_pen;
351 }
352 _ => panic!(),
353 };
354 }
355 };
356 }
357 if let AlignmentLayer::Matches = curr_layer {
358 if curr_score == 0 {
359 let remaining = self.grid.get(AlignmentLayer::Matches, 0, 0).unwrap().0 as usize;
360 if remaining > 0 {
361 query_aligned =
362 query_aligned + &self.q_chars[..remaining].iter().rev().collect::<String>();
363 text_aligned =
364 text_aligned + &self.t_chars[..remaining].iter().rev().collect::<String>();
365 }
366 }
367 }
368
369 let q = query_aligned.chars().rev().collect();
370 let t = text_aligned.chars().rev().collect();
371
372 Ok(Alignment {
373 score: self.current_score,
374 query_aligned: q,
375 text_aligned: t,
376 })
377 }
378}
379
380impl<'a> WavefrontState<'a> {
381 fn update_ins(&mut self, diag: i32) {
382 let from_open = if self.current_score >= (self.pens.open_pen + self.pens.extd_pen)
383 {
384 self.grid.get(
385 AlignmentLayer::Matches,
386 self.current_score - (self.pens.open_pen + self.pens.extd_pen),
387 diag - 1,
388 )
389 } else {
390 None
391 };
392 let from_extd = if self.current_score >= self.pens.extd_pen {
393 self.grid.get(
394 AlignmentLayer::Inserts,
395 self.current_score - self.pens.extd_pen,
396 diag - 1,
397 )
398 } else {
399 None
400 };
401 match (from_open, from_extd) {
402 (None, None) => (),
403 (Some(x), None) => {
404 self.grid.set(
405 AlignmentLayer::Inserts,
406 self.current_score,
407 diag,
408 Some((x.0, AlignmentLayer::Matches)),
409 );
410 }
411 (None, Some(x)) => {
412 self.grid.set(
413 AlignmentLayer::Inserts,
414 self.current_score,
415 diag,
416 Some((x.0, AlignmentLayer::Inserts)),
417 );
418 }
419 (Some(x), Some(y)) => {
420 if x.0 > y.0 {
421 self.grid.set(
422 AlignmentLayer::Inserts,
423 self.current_score,
424 diag,
425 Some((x.0, AlignmentLayer::Matches)),
426 );
427 } else {
428 self.grid.set(
429 AlignmentLayer::Inserts,
430 self.current_score,
431 diag,
432 Some((y.0, AlignmentLayer::Inserts)),
433 );
434 }
435 }
436 }
437 }
438
439 fn update_del(&mut self, diag: i32) {
440 let from_open = if self.current_score >= self.pens.open_pen + self.pens.extd_pen
441 {
442 self.grid.get(
443 AlignmentLayer::Matches,
444 self.current_score - (self.pens.open_pen + self.pens.extd_pen),
445 diag + 1,
446 )
447 } else {
448 None
449 };
450 let from_extd = if self.current_score >= self.pens.extd_pen {
451 self.grid.get(
452 AlignmentLayer::Deletes,
453 self.current_score - self.pens.extd_pen,
454 diag + 1,
455 )
456 } else {
457 None
458 };
459
460 match (from_open, from_extd) {
461 (None, None) => (),
462 (Some(x), None) => {
463 self.grid.set(
464 AlignmentLayer::Deletes,
465 self.current_score,
466 diag,
467 Some((x.0 + 1, AlignmentLayer::Matches)),
468 );
469 }
470 (None, Some(x)) => {
471 self.grid.set(
472 AlignmentLayer::Deletes,
473 self.current_score,
474 diag,
475 Some((x.0 + 1, AlignmentLayer::Deletes)),
476 );
477 }
478 (Some(x), Some(y)) => {
479 if x.0 >= y.0 {
480 self.grid.set(
481 AlignmentLayer::Deletes,
482 self.current_score,
483 diag,
484 Some((x.0 + 1, AlignmentLayer::Matches)),
485 );
486 } else {
487 self.grid.set(
488 AlignmentLayer::Deletes,
489 self.current_score,
490 diag,
491 Some((y.0 + 1, AlignmentLayer::Deletes)),
492 );
493 }
494 }
495 }
496 }
497
498 fn update_mat(&mut self, diag: i32) {
499 let from_mismatch = if self.current_score >= self.pens.mismatch_pen {
500 self.grid.get(
501 AlignmentLayer::Matches,
502 self.current_score - self.pens.mismatch_pen,
503 diag,
504 )
505 } else {
506 None
507 };
508
509 self.grid.set(
510 AlignmentLayer::Matches,
511 self.current_score,
512 diag,
513 match (
514 from_mismatch,
515 self.grid
516 .get(AlignmentLayer::Inserts, self.current_score, diag),
517 self.grid
518 .get(AlignmentLayer::Deletes, self.current_score, diag),
519 ) {
520 (None, None, None) => None,
521 (Some(x), None, None) => Some((x.0 + 1, AlignmentLayer::Matches)),
522 (None, Some(x), None) => Some((x.0, AlignmentLayer::Inserts)),
523 (None, None, Some(x)) => Some((x.0, AlignmentLayer::Deletes)),
524 (Some(x), Some(y), None) => Some(if x.0 + 1 >= y.0 {
525 (x.0 + 1, AlignmentLayer::Matches)
526 } else {
527 (y.0, AlignmentLayer::Inserts)
528 }),
529
530 (Some(x), None, Some(y)) => Some(if x.0 + 1 >= y.0 {
531 (x.0 + 1, AlignmentLayer::Matches)
532 } else {
533 (y.0, AlignmentLayer::Deletes)
534 }),
535
536 (None, Some(x), Some(y)) => Some(if x.0 > y.0 {
537 (x.0, AlignmentLayer::Inserts)
538 } else {
539 (y.0, AlignmentLayer::Deletes)
540 }),
541
542 (Some(x), Some(y), Some(z)) => Some(if x.0 + 1 >= y.0 {
543 if x.0 + 1 >= z.0 {
544 (x.0 + 1, AlignmentLayer::Matches)
545 } else {
546 (z.0, AlignmentLayer::Deletes)
547 }
548 } else if y.0 > z.0 {
549 (y.0, AlignmentLayer::Inserts)
550 } else {
551 (z.0, AlignmentLayer::Deletes)
552 }),
553 },
554 )
555 }
556}
557
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_new_wavefront_state() {
        let state = new_wavefront_state(
            "GATA",
            "TAGAC",
            &Penalties {
                mismatch_pen: 1,
                open_pen: 2,
                extd_pen: 3,
            },
        );

        let manual = WavefrontState {
            query: "GATA",
            text: "TAGAC",
            pens: &Penalties {
                mismatch_pen: 1,
                open_pen: 2,
                extd_pen: 3,
            },
            q_chars: "GATA".chars().collect(),
            t_chars: "TAGAC".chars().collect(),
            current_score: 0,
            num_diags: 10,
            final_diagonal: -1,
            highest_diag: 4,
            lowest_diag: -5,
            grid: new_wavefront_grid(),
        };

        assert_eq!(state, manual);
    }

    // Placeholders: ignored rather than left empty, so they are not reported
    // as passing tests.
    #[test]
    #[ignore = "not yet implemented"]
    fn test_wavefront_update_ins() {}

    #[test]
    #[ignore = "not yet implemented"]
    fn test_wavefront_update_mat() {}

    #[test]
    #[ignore = "not yet implemented"]
    fn test_wavefront_backtrace() {}

    #[test]
    fn test_wavefront_align_rejects_invalid_input() {
        let pens = Penalties {
            mismatch_pen: 1,
            extd_pen: 1,
            open_pen: 1,
        };
        assert!(matches!(
            wavefront_align("", "CAT", &pens),
            Err(AlignmentError::ZeroLength(_))
        ));
        assert!(matches!(
            wavefront_align("CAT", "", &pens),
            Err(AlignmentError::ZeroLength(_))
        ));
        assert!(matches!(
            wavefront_align("CATS", "CAT", &pens),
            Err(AlignmentError::QueryTooLong(_))
        ));
    }

    #[test]
    fn test_align_avd() {
        assert_eq!(
            wavefront_align(
                "AViidI",
                "ViidIM",
                &Penalties {
                    mismatch_pen: 3,
                    extd_pen: 1,
                    open_pen: 1,
                }
            ),
            Ok(Alignment {
                query_aligned: "AViidI-".to_string(),
                text_aligned: "-ViidIM".to_string(),
                score: 4,
            })
        );

        assert_eq!(
            wavefront_align(
                "AVD",
                "VDM",
                &Penalties {
                    mismatch_pen: 2,
                    extd_pen: 1,
                    open_pen: 1,
                }
            ),
            Ok(Alignment {
                query_aligned: "AVD-".to_string(),
                text_aligned: "-VDM".to_string(),
                score: 4,
            })
        );

        assert_eq!(
            wavefront_align(
                "AV",
                "VM",
                &Penalties {
                    mismatch_pen: 2,
                    extd_pen: 1,
                    open_pen: 1,
                }
            ),
            Ok(Alignment {
                query_aligned: "AV".to_string(),
                text_aligned: "VM".to_string(),
                score: 4,
            })
        );
    }

    #[test]
    fn test_wavefront_align() {
        assert_eq!(
            wavefront_align(
                "CAT",
                "CAT",
                &Penalties {
                    mismatch_pen: 1,
                    extd_pen: 1,
                    open_pen: 1,
                }
            ),
            Ok(Alignment {
                query_aligned: "CAT".to_string(),
                text_aligned: "CAT".to_string(),
                score: 0,
            })
        );
        assert_eq!(
            wavefront_align(
                "CAT",
                "CATS",
                &Penalties {
                    mismatch_pen: 1,
                    extd_pen: 1,
                    open_pen: 1,
                }
            ),
            Ok(Alignment {
                query_aligned: "CAT-".to_string(),
                text_aligned: "CATS".to_string(),
                score: 2,
            })
        );
        assert_eq!(
            wavefront_align(
                "XX",
                "YY",
                &Penalties {
                    mismatch_pen: 1,
                    extd_pen: 100,
                    open_pen: 100,
                }
            ),
            Ok(Alignment {
                query_aligned: "XX".to_string(),
                text_aligned: "YY".to_string(),
                score: 2,
            })
        );
        assert_eq!(
            wavefront_align(
                "XX",
                "YY",
                &Penalties {
                    mismatch_pen: 100,
                    extd_pen: 1,
                    open_pen: 1,
                }
            ),
            Ok(Alignment {
                query_aligned: "XX--".to_string(),
                text_aligned: "--YY".to_string(),
                score: 6,
            })
        );
        assert_eq!(
            wavefront_align(
                "XX",
                "YYYYYYYY",
                &Penalties {
                    mismatch_pen: 100,
                    extd_pen: 1,
                    open_pen: 1,
                }
            ),
            Ok(Alignment {
                query_aligned: "XX--------".to_string(),
                text_aligned: "--YYYYYYYY".to_string(),
                score: 12,
            })
        );
        assert_eq!(
            wavefront_align(
                "XXZZ",
                "XXYZ",
                &Penalties {
                    mismatch_pen: 100,
                    extd_pen: 1,
                    open_pen: 1,
                }
            ),
            Ok(Alignment {
                query_aligned: "XX-ZZ".to_string(),
                text_aligned: "XXYZ-".to_string(),
                score: 4,
            })
        );
    }

    #[test]
    fn assert_align_score() {
        // A failed alignment panics with the error instead of being mapped to
        // a placeholder score.
        assert_eq!(
            wavefront_align(
                "TCTTTACTCGCGCGTTGGAGAAATACAATAGT",
                "TCTATACTGCGCGTTTGGAGAAATAAAATAGT",
                &Penalties {
                    mismatch_pen: 1,
                    extd_pen: 1,
                    open_pen: 1,
                }
            )
            .expect("alignment should succeed")
            .score,
            6
        );

        assert_eq!(
            wavefront_align(
                "TCTTTACTCGCGCGTTGGAGAAATACAATAGT",
                "TCTATACTGCGCGTTTGGAGAAATAAAATAGT",
                &Penalties {
                    mismatch_pen: 135,
                    extd_pen: 19,
                    open_pen: 82,
                }
            )
            .expect("alignment should succeed")
            .score,
            472
        );
    }
}