rustalign_aligner/
sw_sse.rs1use crate::{Alignment, SwParams};
8use rustalign_common::{Nuc, Result as BwtResult, Score};
9use rustalign_simd::SseReg;
/// Number of 16-bit lanes packed into one register (8 x 16 bits = 128 bits).
const NWORDS_PER_REG: usize = 8;

/// Width of a single lane, in bits.
// Not referenced yet; kept to document the lane layout.
#[allow(dead_code)]
const NBITS_PER_WORD: usize = 16;

/// Width of a single lane, in bytes.
// Not referenced yet; kept to document the lane layout.
#[allow(dead_code)]
const NBYTES_PER_WORD: usize = NBITS_PER_WORD / 8;

/// Width of a whole register, in bytes.
// Not referenced yet; kept to document the lane layout.
#[allow(dead_code)]
const NBYTES_PER_REG: usize = NWORDS_PER_REG * NBYTES_PER_WORD;

/// Number of nucleotide codes with a row in the query profile (A, C, G, T, N).
const ALPHA_SIZE: usize = 5;

/// Per-nucleotide substitution scores for one read.
///
/// Each row is padded out to a whole number of `NWORDS_PER_REG`-word segments.
#[derive(Debug)]
pub struct QueryProfile {
    /// One score row per nucleotide code; row `c`, column `i` scores read
    /// position `i` against nucleotide `c`.
    profile: Vec<Vec<i16>>,

    /// Number of bases in the read the profile was built from.
    query_len: usize,

    /// Number of `NWORDS_PER_REG`-word segments needed to cover the read.
    nseg: usize,

    // The scoring parameters below are only read by this module's tests, hence
    // the `dead_code` allowances.
    /// Score stored where the read base equals the row's nucleotide.
    #[allow(dead_code)]
    match_score: i16,
    /// Score stored everywhere else, including padding.
    #[allow(dead_code)]
    mismatch_penalty: i16,
    /// Gap-open value recorded with the profile.
    #[allow(dead_code)]
    gap_open: i16,
    /// Gap-extend value recorded with the profile.
    #[allow(dead_code)]
    gap_extend: i16,
}

55impl QueryProfile {
56 #[allow(clippy::manual_div_ceil)]
58 pub fn build(read: &[Nuc], match_score: i16, mismatch_penalty: i16) -> Self {
59 let query_len = read.len();
60 let nseg = (query_len + NWORDS_PER_REG - 1) / NWORDS_PER_REG;
61
62 let mut profile = vec![vec![mismatch_penalty; nseg * NWORDS_PER_REG]; ALPHA_SIZE];
64
65 for (i, &nuc) in read.iter().enumerate() {
67 let seg = i / NWORDS_PER_REG;
68 let word = i % NWORDS_PER_REG;
69
70 let ref_char = match nuc {
72 Nuc::A => 0,
73 Nuc::C => 1,
74 Nuc::G => 2,
75 Nuc::T => 3,
76 Nuc::N => 4,
77 };
78
79 profile[ref_char][seg * NWORDS_PER_REG + word] = match_score;
80 }
81
82 Self {
83 profile,
84 query_len,
85 nseg,
86 match_score,
87 mismatch_penalty,
88 gap_open: 0, gap_extend: 0,
90 }
91 }
92
93 #[inline]
95 fn get_profile(&self, ref_char: usize, seg: usize) -> &[i16] {
96 let start = seg * NWORDS_PER_REG;
97 &self.profile[ref_char][start..start + NWORDS_PER_REG]
98 }
99}
100
101pub struct SwSseAligner {
103 profile: QueryProfile,
105
106 params: SwParams,
108}
109
110impl SwSseAligner {
111 pub fn new(read: &[Nuc], params: SwParams) -> Self {
113 let match_score = params.match_score as i16;
114 let mismatch_penalty = params.mismatch_penalty as i16;
115
116 let profile = QueryProfile::build(read, match_score, mismatch_penalty);
117
118 Self { profile, params }
119 }
120
121 pub fn align(&mut self, reference: &[Nuc]) -> BwtResult<Alignment> {
126 let read_len = self.profile.query_len;
127 let ref_len = reference.len();
128
129 if read_len == 0 || ref_len == 0 {
130 return Ok(Alignment {
131 ref_start: 0,
132 ref_end: 0,
133 score: 0,
134 strand: rustalign_common::Strand::Forward,
135 edits: 0,
136 ..Default::default()
137 });
138 }
139
140 let mut best_score = Score::MIN;
143
144 let gap_open = self.params.gap_open as i16;
145 let gap_extend = self.params.gap_extend as i16;
146
147 #[allow(clippy::needless_range_loop)]
149 for ref_col in 0..ref_len {
150 let ref_nuc = reference[ref_col];
151 let ref_char = match ref_nuc {
152 Nuc::A => 0,
153 Nuc::C => 1,
154 Nuc::G => 2,
155 Nuc::T => 3,
156 Nuc::N => 4,
157 };
158
159 let mut h_prev = SseReg::zero();
161 let mut f_prev = SseReg::set1_epi16(gap_open);
162
163 for seg in 0..self.profile.nseg {
164 let _prof = self.profile.get_profile(ref_char, seg);
165
166 let e_gap = SseReg::set1_epi16(gap_extend);
168 let h_new = h_prev.max_epi16(&e_gap).max_epi16(&f_prev);
169
170 f_prev = f_prev.adds_epi16(&SseReg::set1_epi16(gap_extend));
172
173 let max_in_vec = h_new.hmax_epi16();
175 if max_in_vec as Score > best_score {
176 best_score = max_in_vec as Score;
177 }
178
179 h_prev = h_new;
180 }
181 }
182
183 Ok(Alignment {
186 ref_start: 0,
187 ref_end: ref_len,
188 score: best_score.max(0),
189 strand: rustalign_common::Strand::Forward,
190 edits: 0, ..Default::default()
192 })
193 }
194}
195
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a read of `len` bases cycling through A, C, G, T.
    fn cyclic_read(len: usize) -> Vec<Nuc> {
        (0..len)
            .map(|i| match i % 4 {
                0 => Nuc::A,
                1 => Nuc::C,
                2 => Nuc::G,
                _ => Nuc::T,
            })
            .collect()
    }

    #[test]
    fn test_query_profile_build() {
        let read = vec![Nuc::A, Nuc::C, Nuc::G, Nuc::T];
        let profile = QueryProfile::build(&read, 2, -3);

        assert_eq!(profile.query_len, 4);
        assert_eq!(profile.nseg, 1);
        assert_eq!(profile.match_score, 2);
        assert_eq!(profile.mismatch_penalty, -3);

        // Each nucleotide row scores a match only where the read carries that
        // base; the padding words past the read score as mismatches.
        assert_eq!(profile.get_profile(0, 0), &[2, -3, -3, -3, -3, -3, -3, -3]);
        assert_eq!(profile.get_profile(1, 0), &[-3, 2, -3, -3, -3, -3, -3, -3]);
        assert_eq!(profile.get_profile(2, 0), &[-3, -3, 2, -3, -3, -3, -3, -3]);
        assert_eq!(profile.get_profile(3, 0), &[-3, -3, -3, 2, -3, -3, -3, -3]);
        assert_eq!(profile.get_profile(4, 0), &[-3; 8]);
    }

    #[test]
    fn test_query_profile_longer_read() {
        let read = cyclic_read(20);
        let profile = QueryProfile::build(&read, 2, -3);

        assert_eq!(profile.query_len, 20);
        assert_eq!(profile.nseg, 3);

        // Segment 2 holds positions 16..24: A, C, G, T, then four padding words.
        assert_eq!(profile.get_profile(0, 2), &[2, -3, -3, -3, -3, -3, -3, -3]);
        assert_eq!(profile.get_profile(3, 2), &[-3, -3, -3, 2, -3, -3, -3, -3]);
    }

    #[test]
    fn test_sw_sse_aligner_new() {
        let read = vec![Nuc::A, Nuc::C, Nuc::G, Nuc::T];
        let params = SwParams::default();
        let aligner = SwSseAligner::new(&read, params);

        assert_eq!(aligner.profile.query_len, 4);
    }

    #[test]
    fn test_sw_sse_perfect_match() {
        let read = vec![Nuc::A, Nuc::C, Nuc::G, Nuc::T];
        let reference = read.clone();

        let params = SwParams {
            band_width: 10,
            min_score: 0,
            ..SwParams::default()
        };

        let mut aligner = SwSseAligner::new(&read, params);
        let aln = aligner.align(&reference).unwrap();

        assert!(aln.score >= 0);
    }

    #[test]
    fn test_sw_sse_empty_read() {
        let read: Vec<Nuc> = vec![];
        let reference = vec![Nuc::A, Nuc::C, Nuc::G, Nuc::T];
        let params = SwParams::default();
        let mut aligner = SwSseAligner::new(&read, params);

        let aln = aligner.align(&reference).unwrap();
        assert_eq!(aln.score, 0);
        assert_eq!(aln.edits, 0);
    }

    #[test]
    fn test_sw_sse_empty_reference() {
        let read = vec![Nuc::A, Nuc::C, Nuc::G, Nuc::T];
        let params = SwParams::default();
        let mut aligner = SwSseAligner::new(&read, params);

        let aln = aligner.align(&[]).unwrap();
        assert_eq!(aln.score, 0);
        assert_eq!(aln.edits, 0);
    }
}