lib/
wavefront_alignment.rs

1//! This module exports the wavefront alignment functions.
2use super::alignment_lib::*;
3
4/// This function is exported and can be called to perform an alignment.
5/// The query cannot be longer than the text.
6pub fn wavefront_align(
7    query: &str,
8    text: &str,
9    pens: &Penalties,
10) -> Result<Alignment, AlignmentError> {
11    if query.is_empty() || text.is_empty() {
12        return Err(AlignmentError::ZeroLength(format!(
13            "At least one of the string slices passed to wavefront_align had a length of zero.
14                        Length of query: {}
15                        Length of text:  {}",
16            query.len(),
17            text.len()
18        )));
19    }
20    if query.len() > text.len() {
21        return Err(
22                   AlignmentError::QueryTooLong(
23                       "Query is longer than the reference string.
24                        The length of the first string must be <= to the the length of the second string".to_string()
25                      )
26                  );
27    }
28    let mut current_front = new_wavefront_state(query, text, pens);
29    loop {
30        current_front.extend();
31        if current_front.is_finished() {
32            break;
33        }
34        current_front.increment_score();
35        current_front.next();
36    }
37    current_front.backtrace()
38}
39
/// Main struct, implementing the algorithm.
///
/// It bundles the two input strings, the penalties, the score currently
/// being explored and the grid of wavefront cells that the `Wavefront`
/// methods read and update.
#[derive(Debug, PartialEq, Eq)]
struct WavefrontState<'a> {
    /// The query string, as passed in by the caller.
    query: &'a str,
    /// The text string, as passed in by the caller.
    text: &'a str,
    /// Mismatch, gap-open and gap-extend penalties used for scoring.
    pens: &'a Penalties,
    /// The `char`s of `query`, collected once so they can be indexed.
    q_chars: Vec<char>,
    /// The `char`s of `text`, collected once so they can be indexed.
    t_chars: Vec<char>,

    /// Counter for looping and later backtracking.
    current_score: u32,

    /// Wavefront cells, addressed by layer, score and diagonal.
    grid: WavefrontGrid,

    /// Number of diagonals in the query-text alignment
    /// == q_chars.len() + t_chars.len() + 1.
    num_diags: i32,

    /// The only diagonal on which we can align every char of query and
    /// text (== q_chars.len() - t_chars.len()).
    final_diagonal: i32,

    /// Highest possible diag (== q_chars.len()).
    highest_diag: i32,
    /// Lowest possible diag (== -t_chars.len()).
    lowest_diag: i32,
}
66
67/// Initializes a WavefrontState with the correct fields, for 2 string
68/// slices and a penalties struct.
69fn new_wavefront_state<'a>(
70    query: &'a str,
71    text: &'a str,
72    pens: &'a Penalties,
73) -> WavefrontState<'a> {
74    let q_chars: Vec<char> = query.chars().collect();
75    let t_chars: Vec<char> = text.chars().collect();
76
77    let final_diagonal = (q_chars.len() as i32) - (t_chars.len() as i32); // A_k in the article
78    let num_diags = (q_chars.len() + t_chars.len() + 1) as i32;
79    let highest_diag = q_chars.len() as i32;
80    let lowest_diag = 0 - t_chars.len() as i32;
81
82    let mut matches = vec![vec![None; num_diags as usize]; 1];
83    matches[0][(0 - lowest_diag) as usize] = Some((0, AlignmentLayer::Matches)); // Initialize the starting cell.
84
85    let grid = new_wavefront_grid();
86
87    WavefrontState {
88        query,
89        text,
90        pens,
91        q_chars,
92        t_chars,
93        current_score: 0,
94        num_diags,
95        final_diagonal,
96        highest_diag,
97        lowest_diag,
98        grid,
99    }
100}
101
102impl Wavefront for WavefrontState<'_> {
103    fn extend(&mut self) {
104        //! Extends the matches wavefronts to the furthest reaching point
105        //! of the current score.
106        let diag_range = self
107            .grid
108            .get_diag_range(self.current_score)
109            .expect("get_diag_range returned None at wavefront_extend");
110
111        for diag in (diag_range.0)..=(diag_range.1) {
112            let text_pos = match self
113                .grid
114                .get(AlignmentLayer::Matches, self.current_score, diag)
115            {
116                Some((val, _)) => val,
117                _ => continue,
118            };
119            let mut query_pos = (text_pos as i32 + diag) as usize;
120            let mut text_pos = text_pos as usize;
121            // The furthest reaching point value stored is the number
122            // of matched chars in the Text string.
123            // For any diagonal on the dynamic programming alignment
124            // matrix, the number of chars matched for the Query is the
125            // number of Text chars matched + diagonal.
126
127            while query_pos < self.q_chars.len() && text_pos < self.t_chars.len() {
128                match (
129                    self.q_chars.get(query_pos as usize),
130                    self.t_chars.get(text_pos as usize),
131                ) {
132                    (Some(q), Some(t)) => {
133                        if q == t {
134                            self.grid.increment(self.current_score, diag);
135                            query_pos += 1;
136                            text_pos += 1;
137                        } else {
138                            break;
139                        }
140                    }
141                    _ => break,
142                }
143            }
144        }
145    }
146
    fn increment_score(&mut self) {
        //! Increments the current score by 1. `wavefront_align` calls this
        //! right before `next`, which then builds the wavefront for the
        //! new score.
        self.current_score += 1;
    }
151
152    fn is_finished(&self) -> bool {
153        //! Checks if the alignment is completed: for the current score,
154        //! on the final diagonal, the furthest reaching point matches every
155        //! char of Text and Query.
156        match self.grid.get(
157            AlignmentLayer::Matches,
158            self.current_score,
159            self.final_diagonal,
160        ) {
161            Some((score, _)) => score as usize >= self.t_chars.len(),
162            _ => false,
163        }
164    }
165
166    fn next(&mut self) {
167        //! Equivalent of WAVEFRONT_NEXT
168
169        // Calculating the next highest diagonal of the wavefront
170        let mut hi = 1 + vec![
171            self.current_score.checked_sub(self.pens.mismatch_pen),
172            self.current_score.checked_sub(self.pens.open_pen + self.pens.extd_pen),
173            self.current_score.checked_sub(self.pens.extd_pen),
174        ]
175        .into_iter()
176        .filter(|x| x.is_some())
177        .map(|x| x.unwrap())
178        .map(|x| self.grid.get_diag_range(x).unwrap().1)
179        .max()
180        .unwrap_or(-1);
181
182        if hi > self.highest_diag {
183            hi = self.highest_diag;
184        }
185
186        let mut lo = vec![
187            self.current_score.checked_sub(self.pens.mismatch_pen),
188            self.current_score.checked_sub(self.pens.open_pen + self.pens.extd_pen),
189            self.current_score.checked_sub(self.pens.extd_pen),
190        ]
191        .into_iter()
192        .filter(|x| x.is_some())
193        .map(|x| x.unwrap())
194        .map(|x| self.grid.get_diag_range(x).unwrap().0)
195        .min()
196        .unwrap_or(1)
197            - 1;
198
199        if lo < self.lowest_diag {
200            lo = self.lowest_diag;
201        }
202
203        self.grid.add_layer(lo, hi);
204
205        for diag in lo..=hi {
206            self.update_ins(diag);
207            self.update_del(diag);
208            self.update_mat(diag);
209        }
210    }
211
212    fn backtrace(&self) -> Result<Alignment, AlignmentError> {
213        let mut curr_score = self.current_score;
214        let mut curr_diag = self.final_diagonal;
215        let mut curr_layer = AlignmentLayer::Matches;
216
217        let mut query_aligned = String::new();
218        let mut text_aligned = String::new();
219
220        while curr_score > 0 {
221            match &mut curr_layer {
222                // If we're on a match
223                AlignmentLayer::Matches => {
224                    match self
225                        .grid
226                        .get(AlignmentLayer::Matches, curr_score, curr_diag)
227                    {
228                        Some((score, AlignmentLayer::Inserts)) => {
229                            curr_layer = AlignmentLayer::Inserts;
230                            let mut current_char = score;
231                            while current_char
232                                > self
233                                    .grid
234                                    .get(AlignmentLayer::Inserts, curr_score, curr_diag)
235                                    .unwrap()
236                                    .0
237                            {
238                                query_aligned
239                                    .push(self.q_chars[(current_char as i32 + curr_diag - 1) as usize]);
240                                text_aligned.push(self.t_chars[(current_char - 1) as usize]);
241                                current_char -= 1;
242                            }
243                        }
244                        Some((score, AlignmentLayer::Deletes)) => {
245                            curr_layer = AlignmentLayer::Deletes;
246                            let mut current_char = score;
247                            while current_char
248                                > self
249                                    .grid
250                                    .get(AlignmentLayer::Deletes, curr_score, curr_diag)
251                                    .unwrap()
252                                    .0
253                            {
254                                query_aligned
255                                    .push(self.q_chars[(current_char as i32 + curr_diag - 1) as usize]);
256                                text_aligned.push(self.t_chars[(current_char - 1) as usize]);
257                                current_char -= 1;
258                            }
259                        }
260                        Some((score, AlignmentLayer::Matches)) => {
261                            let mut current_char = score;
262                            curr_score -= self.pens.mismatch_pen;
263                            while current_char
264                                > self
265                                    .grid
266                                    .get(AlignmentLayer::Matches, curr_score, curr_diag)
267                                    .unwrap()
268                                    .0
269                            {
270                                query_aligned
271                                    .push(self.q_chars[(current_char as i32 + curr_diag - 1) as usize]);
272                                text_aligned.push(self.t_chars[(current_char - 1) as usize]);
273                                current_char -= 1;
274                            }
275                        }
276                        _ => panic!(),
277                    };
278                }
279                // If we're on the Inserts layer.
280                AlignmentLayer::Inserts => {
281                    match self
282                        .grid
283                        .get(AlignmentLayer::Inserts, curr_score, curr_diag)
284                    {
285                        Some((_, AlignmentLayer::Matches)) => {
286                            let previous = self
287                                .grid
288                                .get(
289                                    AlignmentLayer::Matches,
290                                    curr_score - self.pens.extd_pen - self.pens.open_pen,
291                                    curr_diag - 1,
292                                )
293                                .unwrap();
294                            query_aligned.push(self.q_chars[(previous.0 as i32 + curr_diag - 1) as usize]);
295                            text_aligned.push('-');
296                            curr_diag -= 1;
297                            curr_score -= self.pens.extd_pen + self.pens.open_pen;
298                            curr_layer = AlignmentLayer::Matches;
299                        }
300                        Some((_, AlignmentLayer::Inserts)) => {
301                            let previous = self
302                                .grid
303                                .get(
304                                    AlignmentLayer::Inserts,
305                                    curr_score - self.pens.extd_pen,
306                                    curr_diag - 1,
307                                )
308                                .unwrap();
309                            query_aligned.push(self.q_chars[(previous.0 as i32 + curr_diag - 1) as usize]);
310                            text_aligned.push('-');
311                            curr_diag -= 1;
312                            curr_score -= self.pens.extd_pen;
313                        }
314                        _ => panic!(),
315                    };
316                }
317                AlignmentLayer::Deletes => {
318                    match self
319                        .grid
320                        .get(AlignmentLayer::Deletes, curr_score, curr_diag)
321                    {
322                        Some((_, AlignmentLayer::Matches)) => {
323                            let previous = self
324                                .grid
325                                .get(
326                                    AlignmentLayer::Matches,
327                                    curr_score - self.pens.extd_pen - self.pens.open_pen,
328                                    curr_diag + 1,
329                                )
330                                .unwrap();
331                            query_aligned.push('-');
332                            text_aligned.push(self.t_chars[(previous.0) as usize]);
333                            curr_diag += 1;
334                            curr_score -= self.pens.extd_pen + self.pens.open_pen;
335                            curr_layer = AlignmentLayer::Matches;
336                        }
337
338                        Some((_, AlignmentLayer::Deletes)) => {
339                            let previous = self
340                                .grid
341                                .get(
342                                    AlignmentLayer::Deletes,
343                                    curr_score - self.pens.extd_pen,
344                                    curr_diag + 1,
345                                )
346                                .unwrap();
347                            query_aligned.push('-');
348                            text_aligned.push(self.t_chars[(previous.0) as usize]);
349                            curr_diag += 1;
350                            curr_score -= self.pens.extd_pen;
351                        }
352                        _ => panic!(),
353                    };
354                }
355            };
356        }
357        if let AlignmentLayer::Matches = curr_layer {
358            if curr_score == 0 {
359                let remaining = self.grid.get(AlignmentLayer::Matches, 0, 0).unwrap().0 as usize;
360                if remaining > 0 {
361                    query_aligned =
362                        query_aligned + &self.q_chars[..remaining].iter().rev().collect::<String>();
363                    text_aligned =
364                        text_aligned + &self.t_chars[..remaining].iter().rev().collect::<String>();
365                }
366            }
367        }
368
369        let q = query_aligned.chars().rev().collect();
370        let t = text_aligned.chars().rev().collect();
371
372        Ok(Alignment {
373            score: self.current_score,
374            query_aligned: q,
375            text_aligned: t,
376        })
377    }
378}
379
380impl<'a> WavefrontState<'a> {
381    fn update_ins(&mut self, diag: i32) {
382        let from_open = if self.current_score >= (self.pens.open_pen + self.pens.extd_pen)
383        {
384            self.grid.get(
385                AlignmentLayer::Matches,
386                self.current_score - (self.pens.open_pen + self.pens.extd_pen),
387                diag - 1,
388            )
389        } else {
390            None
391        };
392        let from_extd = if self.current_score >= self.pens.extd_pen {
393            self.grid.get(
394                AlignmentLayer::Inserts,
395                self.current_score - self.pens.extd_pen,
396                diag - 1,
397            )
398        } else {
399            None
400        };
401        match (from_open, from_extd) {
402            (None, None) => (),
403            (Some(x), None) => {
404                self.grid.set(
405                    AlignmentLayer::Inserts,
406                    self.current_score,
407                    diag,
408                    Some((x.0, AlignmentLayer::Matches)),
409                );
410            }
411            (None, Some(x)) => {
412                self.grid.set(
413                    AlignmentLayer::Inserts,
414                    self.current_score,
415                    diag,
416                    Some((x.0, AlignmentLayer::Inserts)),
417                );
418            }
419            (Some(x), Some(y)) => {
420                if x.0 > y.0 {
421                    self.grid.set(
422                        AlignmentLayer::Inserts,
423                        self.current_score,
424                        diag,
425                        Some((x.0, AlignmentLayer::Matches)),
426                    );
427                } else {
428                    self.grid.set(
429                        AlignmentLayer::Inserts,
430                        self.current_score,
431                        diag,
432                        Some((y.0, AlignmentLayer::Inserts)),
433                    );
434                }
435            }
436        }
437    }
438
439    fn update_del(&mut self, diag: i32) {
440        let from_open = if self.current_score >= self.pens.open_pen + self.pens.extd_pen
441        {
442            self.grid.get(
443                AlignmentLayer::Matches,
444                self.current_score - (self.pens.open_pen + self.pens.extd_pen),
445                diag + 1,
446            )
447        } else {
448            None
449        };
450        let from_extd = if self.current_score >= self.pens.extd_pen {
451            self.grid.get(
452                AlignmentLayer::Deletes,
453                self.current_score - self.pens.extd_pen,
454                diag + 1,
455            )
456        } else {
457            None
458        };
459
460        match (from_open, from_extd) {
461            (None, None) => (),
462            (Some(x), None) => {
463                self.grid.set(
464                    AlignmentLayer::Deletes,
465                    self.current_score,
466                    diag,
467                    Some((x.0 + 1, AlignmentLayer::Matches)),
468                );
469            }
470            (None, Some(x)) => {
471                self.grid.set(
472                    AlignmentLayer::Deletes,
473                    self.current_score,
474                    diag,
475                    Some((x.0 + 1, AlignmentLayer::Deletes)),
476                );
477            }
478            (Some(x), Some(y)) => {
479                if x.0 >= y.0 {
480                    self.grid.set(
481                        AlignmentLayer::Deletes,
482                        self.current_score,
483                        diag,
484                        Some((x.0 + 1, AlignmentLayer::Matches)),
485                    );
486                } else {
487                    self.grid.set(
488                        AlignmentLayer::Deletes,
489                        self.current_score,
490                        diag,
491                        Some((y.0 + 1, AlignmentLayer::Deletes)),
492                    );
493                }
494            }
495        }
496    }
497
498    fn update_mat(&mut self, diag: i32) {
499        let from_mismatch = if self.current_score >= self.pens.mismatch_pen {
500            self.grid.get(
501                AlignmentLayer::Matches,
502                self.current_score - self.pens.mismatch_pen,
503                diag,
504            )
505        } else {
506            None
507        };
508
509        self.grid.set(
510            AlignmentLayer::Matches,
511            self.current_score,
512            diag,
513            match (
514                from_mismatch,
515                self.grid
516                    .get(AlignmentLayer::Inserts, self.current_score, diag),
517                self.grid
518                    .get(AlignmentLayer::Deletes, self.current_score, diag),
519            ) {
520                (None, None, None) => None,
521                (Some(x), None, None) => Some((x.0 + 1, AlignmentLayer::Matches)),
522                (None, Some(x), None) => Some((x.0, AlignmentLayer::Inserts)),
523                (None, None, Some(x)) => Some((x.0, AlignmentLayer::Deletes)),
524                (Some(x), Some(y), None) => Some(if x.0 + 1 >= y.0 {
525                    (x.0 + 1, AlignmentLayer::Matches)
526                } else {
527                    (y.0, AlignmentLayer::Inserts)
528                }),
529
530                (Some(x), None, Some(y)) => Some(if x.0 + 1 >= y.0 {
531                    (x.0 + 1, AlignmentLayer::Matches)
532                } else {
533                    (y.0, AlignmentLayer::Deletes)
534                }),
535
536                (None, Some(x), Some(y)) => Some(if x.0 > y.0 {
537                    (x.0, AlignmentLayer::Inserts)
538                } else {
539                    (y.0, AlignmentLayer::Deletes)
540                }),
541
542                (Some(x), Some(y), Some(z)) => Some(if x.0 + 1 >= y.0 {
543                    if x.0 + 1 >= z.0 {
544                        (x.0 + 1, AlignmentLayer::Matches)
545                    } else {
546                        (z.0, AlignmentLayer::Deletes)
547                    }
548                } else if y.0 > z.0 {
549                    (y.0, AlignmentLayer::Inserts)
550                } else {
551                    (z.0, AlignmentLayer::Deletes)
552                }),
553            },
554        )
555    }
556}
557
558#[cfg(test)]
559mod tests {
560    use super::*;
561
562    #[test]
563    fn test_new_wavefront_state() {
564        // Doesn't do much currently but at least if we accidently
565        // change the behaviour/meaning of the wavefront state structs,
566        // we'll notice.
567        let state = new_wavefront_state(
568            "GATA",
569            "TAGAC",
570            &Penalties {
571                mismatch_pen: 1,
572                open_pen: 2,
573                extd_pen: 3,
574            },
575        );
576
577        let mut manual_matches = vec![vec![None; 10]; 1];
578        manual_matches[0][5] = Some((0, AlignmentLayer::Matches));
579        let manual = WavefrontState {
580            query: "GATA",
581            text: "TAGAC",
582            pens: &Penalties {
583                mismatch_pen: 1,
584                open_pen: 2,
585                extd_pen: 3,
586            },
587            q_chars: "GATA".chars().collect(),
588            t_chars: "TAGAC".chars().collect(),
589            current_score: 0,
590            num_diags: 10,
591            final_diagonal: -1,
592            highest_diag: 4,
593            lowest_diag: -5,
594            grid: new_wavefront_grid(),
595        };
596
597        assert_eq!(state, manual);
598    }
    #[test]
    fn test_wavefront_update_ins() {
        // TODO: placeholder, presumably meant to cover `update_ins`.
        // It holds no assertion yet, so it always passes.
    }
603
    #[test]
    fn test_wavefront_update_mat() {
        // TODO: placeholder, presumably meant to cover `update_mat`.
        // It holds no assertion yet, so it always passes.
    }
608
    #[test]
    fn test_wavefront_backtrace() {
        // TODO: placeholder, presumably meant to cover `backtrace`.
        // It holds no assertion yet, so it always passes.
    }
613
614    #[test]
615    fn test_align_avd() {
616        assert_eq!(
617            wavefront_align(
618                "AViidI",
619                "ViidIM",
620                &Penalties {
621                    mismatch_pen: 3,
622                    extd_pen: 1,
623                    open_pen: 1,
624                }
625            ),
626            Ok(Alignment {
627                query_aligned: "AViidI-".to_string(),
628                text_aligned: "-ViidIM".to_string(),
629                score: 4,
630            })
631        );
632
633        assert_eq!(
634            wavefront_align(
635                "AVD",
636                "VDM",
637                &Penalties {
638                    mismatch_pen: 2,
639                    extd_pen: 1,
640                    open_pen: 1,
641                }
642            ),
643            Ok(Alignment {
644                query_aligned: "AVD-".to_string(),
645                text_aligned: "-VDM".to_string(),
646                score: 4,
647            })
648        );
649
650        assert_eq!(
651            wavefront_align(
652                "AV",
653                "VM",
654                &Penalties {
655                    mismatch_pen: 2,
656                    extd_pen: 1,
657                    open_pen: 1,
658                }
659            ),
660            Ok(Alignment {
661                query_aligned: "AV".to_string(),
662                text_aligned: "VM".to_string(),
663                score: 4,
664            })
665        );
666    }
667
668    #[test]
669    fn test_wavefront_align() {
670        assert_eq!(
671            wavefront_align(
672                "CAT",
673                "CAT",
674                &Penalties {
675                    mismatch_pen: 1,
676                    extd_pen: 1,
677                    open_pen: 1,
678                }
679            ),
680            Ok(Alignment {
681                query_aligned: "CAT".to_string(),
682                text_aligned: "CAT".to_string(),
683                score: 0,
684            })
685        );
686        assert_eq!(
687            wavefront_align(
688                "CAT",
689                "CATS",
690                &Penalties {
691                    mismatch_pen: 1,
692                    extd_pen: 1,
693                    open_pen: 1,
694                }
695            ),
696            Ok(Alignment {
697                query_aligned: "CAT-".to_string(),
698                text_aligned: "CATS".to_string(),
699                score: 2,
700            })
701        );
702        assert_eq!(
703            wavefront_align(
704                "XX",
705                "YY",
706                &Penalties {
707                    mismatch_pen: 1,
708                    extd_pen: 100,
709                    open_pen: 100,
710                }
711            ),
712            Ok(Alignment {
713                query_aligned: "XX".to_string(),
714                text_aligned: "YY".to_string(),
715                score: 2,
716            })
717        );
718        assert_eq!(
719            wavefront_align(
720                "XX",
721                "YY",
722                &Penalties {
723                    mismatch_pen: 100,
724                    extd_pen: 1,
725                    open_pen: 1,
726                }
727            ),
728            Ok(Alignment {
729                query_aligned: "XX--".to_string(),
730                text_aligned: "--YY".to_string(),
731                score: 6,
732            })
733        );
734        assert_eq!(
735            wavefront_align(
736                "XX",
737                "YYYYYYYY",
738                &Penalties {
739                    mismatch_pen: 100,
740                    extd_pen: 1,
741                    open_pen: 1,
742                }
743            ),
744            Ok(Alignment {
745                query_aligned: "XX--------".to_string(),
746                text_aligned: "--YYYYYYYY".to_string(),
747                score: 12,
748            })
749        );
750        assert_eq!(
751            wavefront_align(
752                "XXZZ",
753                "XXYZ",
754                &Penalties {
755                    mismatch_pen: 100,
756                    extd_pen: 1,
757                    open_pen: 1,
758                }
759            ),
760            Ok(Alignment {
761                query_aligned: "XX-ZZ".to_string(),
762                text_aligned: "XXYZ-".to_string(),
763                score: 4,
764            })
765        );
766    }
767
768    #[test]
769    fn assert_align_score() {
770        assert_eq!(
771            match wavefront_align(
772                "TCTTTACTCGCGCGTTGGAGAAATACAATAGT",
773                "TCTATACTGCGCGTTTGGAGAAATAAAATAGT",
774                &Penalties {
775                    mismatch_pen: 1,
776                    extd_pen: 1,
777                    open_pen: 1,
778                }
779            ) {
780                Ok(s) => s.score,
781                _ => 1,
782            },
783            6
784        );
785
786        assert_eq!(
787            match wavefront_align(
788                "TCTTTACTCGCGCGTTGGAGAAATACAATAGT",
789                "TCTATACTGCGCGTTTGGAGAAATAAAATAGT",
790                &Penalties {
791                    mismatch_pen: 135,
792                    extd_pen: 19,
793                    open_pen: 82,
794                }
795            ) {
796                Ok(s) => s.score,
797                _ => 1,
798            },
799            472
800        );
801    }
802}