1use crate::kmer::{Kmer, KmerBits};
13use crate::minimizer::{MinimizerInfo, MinimizerIterator};
14use crate::encoding::encode_base;
15
/// Outcome of a single k-mer lookup.
///
/// A miss is encoded with sentinels rather than an `Option`: every `u64` field is
/// `u64::MAX` (see [`LookupResult::not_found`]).
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct LookupResult {
    /// Identifier of the k-mer; `u64::MAX` means "not found".
    pub kmer_id: u64,
    /// Index of the k-mer inside the string that contains it.
    pub kmer_id_in_string: u64,
    /// Offset of the k-mer (presumably absolute within the string set; confirm with the
    /// dictionary code that fills this in).
    pub kmer_offset: u64,
    /// Orientation of the match: positive for forward, `-1` for reverse complement.
    pub kmer_orientation: i8,

    /// Identifier of the string that contains the k-mer.
    pub string_id: u64,
    /// Start of the containing string.
    pub string_begin: u64,
    /// End of the containing string; `string_end - string_begin` is its length.
    pub string_end: u64,

    /// Whether the minimizer was found during the lookup. Defaults to `true` so that a
    /// freshly created result never enables the "same minimizer, still absent" shortcut.
    pub minimizer_found: bool,
}

impl LookupResult {
    /// Sentinel stored in every `u64` field of a "not found" result.
    const INVALID: u64 = u64::MAX;

    /// Builds the "not found" result: all `u64` fields are `u64::MAX`, the orientation is
    /// forward (`1`) and `minimizer_found` is `true`.
    pub fn not_found() -> Self {
        let invalid = Self::INVALID;
        Self {
            kmer_id: invalid,
            kmer_id_in_string: invalid,
            kmer_offset: invalid,
            kmer_orientation: 1,
            string_id: invalid,
            string_begin: invalid,
            string_end: invalid,
            minimizer_found: true,
        }
    }

    /// Returns `true` when this result refers to an actual k-mer.
    #[inline]
    pub fn is_found(&self) -> bool {
        self.kmer_id != Self::INVALID
    }

    /// Length of the containing string, or `0` when nothing was found.
    #[inline]
    pub fn string_length(&self) -> u64 {
        if !self.is_found() {
            return 0;
        }
        self.string_end - self.string_begin
    }
}

impl Default for LookupResult {
    /// The default result is the "not found" result.
    fn default() -> Self {
        Self::not_found()
    }
}
76
/// State for looking up a stream of consecutive, overlapping k-mers.
///
/// After the first k-mer following construction or [`StreamingQuery::reset`], each call
/// only reads the last base of its input and rolls the stored k-mer forward by one base.
pub struct StreamingQuery<const K: usize>
where
    Kmer<K>: KmerBits,
{
    /// K-mer length; asserted equal to the const generic `K` in `new`.
    k: usize,
    /// `m` as passed to `new` and forwarded to the minimizer iterators
    /// (presumably the minimizer length; confirm against `MinimizerIterator`).
    _m: usize,
    /// Selects the canonical lookup path in `seed`.
    _canonical: bool,
    /// `true` until a valid k-mer has been processed since construction or the last reset.
    start: bool,
    /// Current k-mer; `None` until the first valid lookup.
    kmer: Option<Kmer<K>>,
    /// Reverse complement of `kmer`, maintained in lock-step with it.
    kmer_rc: Option<Kmer<K>>,

    /// Minimizer iterator fed with the forward k-mers.
    minimizer_it: MinimizerIterator,
    /// Minimizer iterator fed with the reverse-complement k-mers.
    minimizer_it_rc: MinimizerIterator,
    /// Minimizer info of the current forward k-mer.
    curr_mini_info: MinimizerInfo,
    /// Minimizer info of the previous forward k-mer.
    prev_mini_info: MinimizerInfo,
    /// Minimizer info of the current reverse-complement k-mer.
    curr_mini_info_rc: MinimizerInfo,
    /// Minimizer info of the previous reverse-complement k-mer.
    prev_mini_info_rc: MinimizerInfo,

    /// Number of extension steps still possible inside the current string;
    /// `0` forces the next lookup to seed again.
    remaining_string_bases: u64,

    /// Result of the most recent lookup.
    result: LookupResult,

    /// Number of successful seeded lookups.
    num_searches: u64,
    /// Number of lookups answered by extending the previous match.
    num_extensions: u64,
    /// Number of lookups rejected as invalid input.
    num_invalid: u64,
    /// Number of lookups that found nothing.
    num_negative: u64,
}
125
126impl<const K: usize> StreamingQuery<K>
127where
128 Kmer<K>: KmerBits,
129{
130 pub fn new(k: usize, m: usize, canonical: bool) -> Self {
137 assert_eq!(k, K, "k parameter must match const generic K");
138
139 let dummy_mini = MinimizerInfo::new(u64::MAX, 0, 0);
140
141 Self {
142 k,
143 _m: m,
144 _canonical: canonical,
145 start: true,
146 kmer: None,
147 kmer_rc: None,
148 minimizer_it: MinimizerIterator::with_seed(k, m, 1),
149 minimizer_it_rc: MinimizerIterator::with_seed(k, m, 1),
150 curr_mini_info: dummy_mini,
151 prev_mini_info: dummy_mini,
152 curr_mini_info_rc: dummy_mini,
153 prev_mini_info_rc: dummy_mini,
154 remaining_string_bases: 0,
155 result: LookupResult::not_found(),
156 num_searches: 0,
157 num_extensions: 0,
158 num_invalid: 0,
159 num_negative: 0,
160 }
161 }
162
163 pub fn reset(&mut self) {
165 self.start = true;
166 self.remaining_string_bases = 0;
167 self.result = LookupResult::not_found();
168 self.minimizer_it.set_position(0);
169 self.minimizer_it_rc.set_position(0);
170 }
171
/// Processes the next k-mer of the stream without a dictionary.
///
/// The k-mer and minimizer state are still updated, but with no dictionary to search
/// every valid k-mer yields a "not found" result.
pub fn lookup(&mut self, kmer_bytes: &[u8]) -> LookupResult {
    self.lookup_internal(kmer_bytes, None)
}
186
/// Processes the next k-mer of the stream and looks it up in `dict`.
///
/// Returns a copy of the updated lookup result.
pub fn lookup_with_dict(&mut self, kmer_bytes: &[u8], dict: &crate::dictionary::Dictionary) -> LookupResult {
    self.lookup_internal(kmer_bytes, Some(dict))
}
194
195 fn lookup_internal(&mut self, kmer_bytes: &[u8], dict_opt: Option<&crate::dictionary::Dictionary>) -> LookupResult {
196 let is_valid = if self.start {
198 self.is_valid_kmer_bytes(kmer_bytes)
199 } else {
200 self.is_valid_base(kmer_bytes[self.k - 1])
201 };
202
203 if !is_valid {
204 self.num_invalid += 1;
205 self.reset();
206 return self.result.clone();
207 }
208
209 if self.start {
211 let km = Kmer::<K>::from_ascii_unchecked(kmer_bytes);
213 self.kmer = Some(km);
214 let rc = km.reverse_complement();
215 self.kmer_rc = Some(rc);
216
217 self.curr_mini_info = self.minimizer_it.next(km);
218 self.curr_mini_info_rc = self.minimizer_it_rc.next(rc);
219 } else {
220 if let Some(mut km) = self.kmer {
222 for i in 0..(self.k - 1) {
224 let base = km.get_base(i + 1);
225 km.set_base(i, base);
226 }
227
228 let new_base = kmer_bytes[self.k - 1];
230 if let Ok(encoded) = encode_base(new_base) {
231 km.set_base(self.k - 1, encoded);
232
233 self.kmer = Some(km);
234
235 if let Some(mut km_rc) = self.kmer_rc {
237 for i in (1..self.k).rev() {
238 let base = km_rc.get_base(i - 1);
239 km_rc.set_base(i, base);
240 }
241
242 let complement = crate::encoding::complement_base(encoded);
244 km_rc.set_base(0, complement);
245
246 self.kmer_rc = Some(km_rc);
247
248 self.curr_mini_info = self.minimizer_it.next(km);
249 self.curr_mini_info_rc = self.minimizer_it_rc.next(km_rc);
250 }
251 }
252 }
253 }
254
255 if self.remaining_string_bases == 0 {
257 self.seed(dict_opt);
258 } else {
259 if let Some(dict) = dict_opt {
261 self.try_extend(dict);
262 } else {
263 self.seed(dict_opt);
265 }
266 }
267
268 self.prev_mini_info = self.curr_mini_info;
270 self.prev_mini_info_rc = self.curr_mini_info_rc;
271 self.start = false;
272
273 self.result.clone()
274 }
275
276 fn is_valid_kmer_bytes(&self, bytes: &[u8]) -> bool {
278 if bytes.len() != self.k {
279 return false;
280 }
281 for &b in bytes {
282 if !matches!(b, b'A' | b'C' | b'G' | b'T' | b'a' | b'c' | b'g' | b't') {
283 return false;
284 }
285 }
286 true
287 }
288
289 fn is_valid_base(&self, b: u8) -> bool {
291 matches!(b, b'A' | b'C' | b'G' | b'T' | b'a' | b'c' | b'g' | b't')
292 }
293
294 fn seed(&mut self, dict_opt: Option<&crate::dictionary::Dictionary>) {
298 self.remaining_string_bases = 0;
299
300 if !self.start
302 && self.curr_mini_info.value == self.prev_mini_info.value
303 && self.curr_mini_info_rc.value == self.prev_mini_info_rc.value
304 && !self.result.minimizer_found
305 {
306 assert_eq!(self.result.kmer_id, u64::MAX);
307 self.num_negative += 1;
308 return;
309 }
310
311 if let (Some(dict), Some(kmer)) = (dict_opt, self.kmer) {
312 if self._canonical {
313 let kmer_rc = kmer.reverse_complement();
322 let mini_fwd = dict.extract_minimizer::<K>(&kmer);
323 let mini_rc = dict.extract_minimizer::<K>(&kmer_rc);
324
325 if mini_fwd.value < mini_rc.value {
326 self.result = dict.lookup_canonical_streaming::<K>(&kmer, &kmer_rc, mini_fwd);
327 } else if mini_rc.value < mini_fwd.value {
328 self.result = dict.lookup_canonical_streaming::<K>(&kmer, &kmer_rc, mini_rc);
329 } else {
330 self.result = dict.lookup_canonical_streaming::<K>(&kmer, &kmer_rc, mini_fwd);
331 if self.result.kmer_id == u64::MAX {
332 self.result = dict.lookup_canonical_streaming::<K>(&kmer, &kmer_rc, mini_rc);
333 }
334 }
335 } else {
336 let mini_fwd = dict.extract_minimizer::<K>(&kmer);
339 self.result = dict.lookup_regular_streaming::<K>(&kmer, mini_fwd);
340 let minimizer_found = self.result.minimizer_found;
341 if self.result.kmer_id == u64::MAX {
342 assert_eq!(self.result.kmer_orientation, 1); let kmer_rc = kmer.reverse_complement();
344 let mini_rc = dict.extract_minimizer::<K>(&kmer_rc);
345 self.result = dict.lookup_regular_streaming::<K>(&kmer_rc, mini_rc);
346 self.result.kmer_orientation = -1; let minimizer_rc_found = self.result.minimizer_found;
348 self.result.minimizer_found = minimizer_rc_found || minimizer_found;
349 }
350 }
351
352 if self.result.kmer_id == u64::MAX {
353 self.num_negative += 1;
354 return;
355 }
356
357 assert!(self.result.minimizer_found);
358 self.num_searches += 1;
359
360 let string_size = self.result.string_end - self.result.string_begin;
364 if self.result.kmer_orientation > 0 {
365 self.remaining_string_bases =
366 (string_size - self.k as u64) - self.result.kmer_id_in_string;
367 } else {
368 self.remaining_string_bases = self.result.kmer_id_in_string;
369 }
370 } else {
371 self.result = LookupResult::not_found();
373 self.num_negative += 1;
374 }
375 }
376
377 fn try_extend(&mut self, dict: &crate::dictionary::Dictionary) {
383 if let (Some(kmer), Some(kmer_rc)) = (self.kmer, self.kmer_rc) {
384 let abs_pos = self.result.kmer_id_in_string as usize
388 + self.result.string_begin as usize;
389
390 let next_abs_pos = if self.result.kmer_orientation > 0 {
391 abs_pos + 1
392 } else {
393 abs_pos.wrapping_sub(1)
394 };
395
396 let expected_kmer: Kmer<K> = dict.spss().decode_kmer_at(next_abs_pos);
398
399 if expected_kmer.bits() == kmer.bits()
400 || expected_kmer.bits() == kmer_rc.bits()
401 {
402 self.num_extensions += 1;
404 let delta = self.result.kmer_orientation as i64;
405 self.result.kmer_id = (self.result.kmer_id as i64 + delta) as u64;
406 self.result.kmer_id_in_string =
407 (self.result.kmer_id_in_string as i64 + delta) as u64;
408 self.result.kmer_offset =
409 (self.result.kmer_offset as i64 + delta) as u64;
410 self.remaining_string_bases -= 1;
411 return;
412 }
413 }
414
415 self.seed(Some(dict));
417 }
418
/// Number of lookups that were resolved by a successful seeded search.
pub fn num_searches(&self) -> u64 {
    self.num_searches
}
423
/// Number of lookups that were resolved by extending the previous match.
pub fn num_extensions(&self) -> u64 {
    self.num_extensions
}
428
/// Total number of successful lookups: seeded searches plus extensions.
pub fn num_positive_lookups(&self) -> u64 {
    self.num_searches + self.num_extensions
}
433
/// Number of valid lookups that found nothing.
pub fn num_negative_lookups(&self) -> u64 {
    self.num_negative
}
438
/// Number of lookups rejected because the input was not a valid k-mer or base.
pub fn num_invalid_lookups(&self) -> u64 {
    self.num_invalid
}
443}
444
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_lookup_result_creation() {
        let result = LookupResult::not_found();
        assert!(!result.is_found());
        assert_eq!(result.kmer_id, u64::MAX);
    }

    #[test]
    fn test_lookup_result_string_length() {
        let mut result = LookupResult::not_found();
        result.string_begin = 100;
        result.string_end = 200;
        // string_length() only reports a length for a found result.
        result.kmer_id = 42;
        assert_eq!(result.string_length(), 100);
    }

    #[test]
    fn test_streaming_query_creation() {
        let query: StreamingQuery<31> = StreamingQuery::new(31, 13, true);
        assert_eq!(query.k, 31);
        assert_eq!(query._m, 13);
        assert!(query._canonical);
        assert_eq!(query.num_searches(), 0);
    }

    #[test]
    fn test_streaming_query_reset() {
        let mut query: StreamingQuery<31> = StreamingQuery::new(31, 13, false);
        query.num_searches = 10;
        query.num_extensions = 5;

        query.reset();

        assert!(query.start);
        assert_eq!(query.remaining_string_bases, 0);
        // reset() restarts the stream but does not touch the statistics counters.
        assert_eq!(query.num_searches(), 10);
        assert_eq!(query.num_extensions(), 5);
    }

    #[test]
    fn test_streaming_query_validation() {
        let query: StreamingQuery<31> = StreamingQuery::new(31, 13, true);

        // Exactly 31 valid bases.
        assert!(query.is_valid_kmer_bytes(b"ACGTACGTACGTACGTACGTACGTACGTACG"));
        // Too short.
        assert!(!query.is_valid_kmer_bytes(b"ACGT"));
        // Too long (32 bytes): rejected on length alone.
        assert!(!query.is_valid_kmer_bytes(b"ACGTACGTACGTACGTACGTACGTACGTACGN"));
        // Correct length (31 bytes) but ends in an invalid base: exercises the base check.
        assert!(!query.is_valid_kmer_bytes(b"ACGTACGTACGTACGTACGTACGTACGTACN"));

        assert!(query.is_valid_base(b'A'));
        assert!(query.is_valid_base(b'a'));
        assert!(!query.is_valid_base(b'N'));
    }

    #[test]
    fn test_streaming_query_lookup_invalid() {
        let mut query: StreamingQuery<15> = StreamingQuery::new(15, 7, true);

        // Wrong length.
        let result = query.lookup(b"ACGT");
        assert!(!result.is_found());
        assert_eq!(query.num_invalid_lookups(), 1);

        query.reset();
        // Correct length, invalid last base.
        let result = query.lookup(b"ACGTACGTACGTACN");
        assert!(!result.is_found());
        assert_eq!(query.num_invalid_lookups(), 2);
    }

    #[test]
    fn test_streaming_query_incremental_update() {
        let mut query: StreamingQuery<9> = StreamingQuery::new(9, 5, false);

        let _result1 = query.lookup(b"ACGTACGTA");
        // A valid first k-mer moves the query into streaming mode.
        assert!(!query.start);

        // The next k-mer overlaps the previous one by k - 1 bases.
        let _result2 = query.lookup(b"CGTACGTAC");

        assert!(!query.start);
    }
}
531
/// A [`StreamingQuery`] bundled with the dictionary it queries, so callers do not have
/// to pass the dictionary on every lookup.
pub struct StreamingQueryEngine<'a, const K: usize>
where
    Kmer<K>: KmerBits,
{
    /// Dictionary every lookup is run against.
    dict: &'a crate::dictionary::Dictionary,
    /// Streaming state, configured from the dictionary's parameters.
    query: StreamingQuery<K>,
}
543
544impl<'a, const K: usize> StreamingQueryEngine<'a, K>
545where
546 Kmer<K>: KmerBits,
547{
548 pub fn new(dict: &'a crate::dictionary::Dictionary) -> Self {
550 let canonical = dict.canonical();
551 Self {
552 dict,
553 query: StreamingQuery::new(dict.k(), dict.m(), canonical),
554 }
555 }
556
557 pub fn reset(&mut self) {
559 self.query.reset();
560 }
561
562 pub fn lookup(&mut self, kmer_bytes: &[u8]) -> LookupResult {
564 self.query.lookup_with_dict(kmer_bytes, self.dict)
566 }
567
568 pub fn num_searches(&self) -> u64 {
570 self.query.num_searches()
571 }
572
573 pub fn num_extensions(&self) -> u64 {
575 self.query.num_extensions()
576 }
577
578 pub fn stats(&self) -> StreamingQueryStats {
580 StreamingQueryStats {
581 num_searches: self.query.num_searches(),
582 num_extensions: self.query.num_extensions(),
583 num_invalid: self.query.num_invalid_lookups(),
584 num_negative: self.query.num_negative_lookups(),
585 }
586 }
587}
588
/// Snapshot of a streaming query's counters, as returned by `StreamingQueryEngine::stats`.
#[derive(Debug, Clone)]
pub struct StreamingQueryStats {
    /// Lookups resolved by a successful seeded search.
    pub num_searches: u64,
    /// Lookups resolved by extending the previous match.
    pub num_extensions: u64,
    /// Lookups rejected as invalid input.
    pub num_invalid: u64,
    /// Valid lookups that found nothing.
    pub num_negative: u64,
}