Skip to main content

sshash_lib/
streaming_query.rs

1//! Streaming query for efficient k-mer lookups
2//!
3//! This module implements streaming queries, which optimize lookup performance
4//! when querying consecutive k-mers (sliding window over a sequence).
5//!
6//! Key optimizations:
7//! - Incremental k-mer updates (drop first base, add last base)
8//! - Reuse minimizer state across adjacent k-mers
9//! - Extend within the same string when possible (avoiding MPHF lookups)
10//! - Skip searches when minimizer unchanged and previous lookup failed
11
12use crate::kmer::{Kmer, KmerBits};
13use crate::minimizer::{MinimizerInfo, MinimizerIterator};
14use crate::encoding::encode_base;
15
/// Outcome of looking up a single k-mer.
///
/// A miss is encoded with sentinel values: `not_found()` sets every ID and
/// position field to `u64::MAX`. Use [`LookupResult::is_found`] to tell hits
/// and misses apart.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct LookupResult {
    /// Absolute k-mer ID (global across all strings)
    pub kmer_id: u64,
    /// Relative k-mer ID within the string (0 <= kmer_id_in_string < string_size)
    pub kmer_id_in_string: u64,
    /// Bit offset into the string data
    pub kmer_offset: u64,
    /// Orientation: +1 for forward, -1 for reverse complement
    pub kmer_orientation: i8,

    /// String ID containing this k-mer
    pub string_id: u64,
    /// Start position of the string (in bases)
    pub string_begin: u64,
    /// End position of the string (in bases)
    pub string_end: u64,

    /// Whether the minimizer was found in the index
    pub minimizer_found: bool,
}

impl LookupResult {
    /// Build the "not found" result.
    ///
    /// All ID/position fields are `u64::MAX`, the orientation is forward (+1)
    /// and `minimizer_found` is `true`.
    pub fn not_found() -> Self {
        let invalid = u64::MAX;
        Self {
            kmer_id: invalid,
            kmer_id_in_string: invalid,
            kmer_offset: invalid,
            kmer_orientation: 1, // forward
            string_id: invalid,
            string_begin: invalid,
            string_end: invalid,
            minimizer_found: true,
        }
    }

    /// Returns `true` when this result describes a k-mer that was found,
    /// i.e. `kmer_id` is not the `u64::MAX` sentinel.
    #[inline]
    pub fn is_found(&self) -> bool {
        !matches!(self.kmer_id, u64::MAX)
    }

    /// Length of the containing string (`string_end - string_begin`), or 0
    /// for a result that is not found.
    #[inline]
    pub fn string_length(&self) -> u64 {
        if !self.is_found() {
            return 0;
        }
        self.string_end - self.string_begin
    }
}

impl Default for LookupResult {
    /// The default result is the "not found" result.
    fn default() -> Self {
        Self::not_found()
    }
}
76
77/// Streaming query engine for efficient consecutive k-mer lookups
78///
79/// This struct maintains state across multiple lookups to optimize
80/// queries for sliding windows over sequences.
81///
82/// # Example
83/// ```no_run
84/// use sshash_lib::streaming_query::StreamingQuery;
85/// // Assuming we have a dictionary...
86/// // let mut query = StreamingQuery::new(&dict, true); // canonical=true
87/// // 
88/// // Process consecutive k-mers efficiently
89/// // let result1 = query.lookup("ACGTACGTACGTACGTACGTACGTACGTACG");
90/// // let result2 = query.lookup("CGTACGTACGTACGTACGTACGTACGTACGT"); // Sliding by 1
91/// ```
92pub struct StreamingQuery<const K: usize>
93where
94    Kmer<K>: KmerBits,
95{
96    k: usize,
97    _m: usize, // Will be used in full Dictionary lookup
98    _canonical: bool, // Will be used in full Dictionary lookup
99    
100    // K-mer state
101    start: bool,
102    kmer: Option<Kmer<K>>,
103    kmer_rc: Option<Kmer<K>>,
104    
105    // Minimizer state
106    minimizer_it: MinimizerIterator,
107    minimizer_it_rc: MinimizerIterator,
108    curr_mini_info: MinimizerInfo,
109    prev_mini_info: MinimizerInfo,
110    curr_mini_info_rc: MinimizerInfo,
111    prev_mini_info_rc: MinimizerInfo,
112    
113    // String extension state
114    remaining_string_bases: u64,
115    
116    // Result state
117    result: LookupResult,
118    
119    // Performance counters
120    num_searches: u64,
121    num_extensions: u64,
122    num_invalid: u64,
123    num_negative: u64,
124}
125
126impl<const K: usize> StreamingQuery<K>
127where
128    Kmer<K>: KmerBits,
129{
130    /// Create a new streaming query engine
131    ///
132    /// # Arguments
133    /// * `k` - K-mer size
134    /// * `m` - Minimizer size
135    /// * `canonical` - Whether to use canonical k-mers (min of forward/RC)
136    pub fn new(k: usize, m: usize, canonical: bool) -> Self {
137        assert_eq!(k, K, "k parameter must match const generic K");
138        
139        let dummy_mini = MinimizerInfo::new(u64::MAX, 0, 0);
140        
141        Self {
142            k,
143            _m: m,
144            _canonical: canonical,
145            start: true,
146            kmer: None,
147            kmer_rc: None,
148            minimizer_it: MinimizerIterator::with_seed(k, m, 1),
149            minimizer_it_rc: MinimizerIterator::with_seed(k, m, 1),
150            curr_mini_info: dummy_mini,
151            prev_mini_info: dummy_mini,
152            curr_mini_info_rc: dummy_mini,
153            prev_mini_info_rc: dummy_mini,
154            remaining_string_bases: 0,
155            result: LookupResult::not_found(),
156            num_searches: 0,
157            num_extensions: 0,
158            num_invalid: 0,
159            num_negative: 0,
160        }
161    }
162
163    /// Reset the query state (call this when starting a new sequence)
164    pub fn reset(&mut self) {
165        self.start = true;
166        self.remaining_string_bases = 0;
167        self.result = LookupResult::not_found();
168        self.minimizer_it.set_position(0);
169        self.minimizer_it_rc.set_position(0);
170    }
171
172    /// Perform a streaming lookup for a k-mer
173    ///
174    /// This is the main entry point for queries. For optimal performance,
175    /// call this with consecutive k-mers (sliding by 1 base at a time).
176    ///
177    /// # Arguments
178    /// * `kmer_str` - DNA string of length K
179    ///
180    /// # Returns
181    /// A LookupResult indicating whether the k-mer was found and its location
182    pub fn lookup(&mut self, kmer_bytes: &[u8]) -> LookupResult {
183        // MVP version without Dictionary integration (always seeds)
184        self.lookup_internal(kmer_bytes, None)
185    }
186
187    /// Perform a streaming lookup with dictionary integration
188    ///
189    /// Internal method used by StreamingQueryEngine
190    pub(crate) fn lookup_with_dict(&mut self, kmer_bytes: &[u8], dict: &crate::dictionary::Dictionary) -> LookupResult {
191        self.lookup_internal(kmer_bytes, Some(dict))
192    }
193
194    fn lookup_internal(&mut self, kmer_bytes: &[u8], dict_opt: Option<&crate::dictionary::Dictionary>) -> LookupResult {
195        // 1. Validation
196        let is_valid = if self.start {
197            self.is_valid_kmer_bytes(kmer_bytes)
198        } else {
199            self.is_valid_base(kmer_bytes[self.k - 1])
200        };
201
202        if !is_valid {
203            self.num_invalid += 1;
204            self.reset();
205            return self.result.clone();
206        }
207
208        // 2. Compute k-mer and reverse complement, update minimizers
209        if self.start {
210            // First k-mer: parse from scratch using fast byte encoding
211            let km = Kmer::<K>::from_ascii_unchecked(kmer_bytes);
212            self.kmer = Some(km);
213            let rc = km.reverse_complement();
214            self.kmer_rc = Some(rc);
215
216            self.curr_mini_info = self.minimizer_it.next(km);
217            self.curr_mini_info_rc = self.minimizer_it_rc.next(rc);
218        } else {
219            // Update incrementally: drop first base, add new last base
220            if let Some(mut km) = self.kmer {
221                // Drop first base (shift left)
222                for i in 0..(self.k - 1) {
223                    let base = km.get_base(i + 1);
224                    km.set_base(i, base);
225                }
226
227                // Add new last base
228                let new_base = kmer_bytes[self.k - 1];
229                if let Ok(encoded) = encode_base(new_base) {
230                    km.set_base(self.k - 1, encoded);
231
232                    self.kmer = Some(km);
233
234                    // Update RC: pad (shift right), set first base to complement
235                    if let Some(mut km_rc) = self.kmer_rc {
236                        for i in (1..self.k).rev() {
237                            let base = km_rc.get_base(i - 1);
238                            km_rc.set_base(i, base);
239                        }
240                        
241                        // Complement of new base at position 0
242                        let complement = crate::encoding::complement_base(encoded);
243                        km_rc.set_base(0, complement);
244
245                        self.kmer_rc = Some(km_rc);
246
247                        self.curr_mini_info = self.minimizer_it.next(km);
248                        self.curr_mini_info_rc = self.minimizer_it_rc.next(km_rc);
249                    }
250                }
251            }
252        }
253
254        // 3. Compute result (either extend or search)
255        if self.remaining_string_bases == 0 {
256            self.seed(dict_opt);
257        } else {
258            // Try to extend within current string
259            if let Some(dict) = dict_opt {
260                self.try_extend(dict);
261            } else {
262                // No dictionary, can't extend
263                self.seed(dict_opt);
264            }
265        }
266
267        // 4. Update state
268        self.prev_mini_info = self.curr_mini_info;
269        self.prev_mini_info_rc = self.curr_mini_info_rc;
270        self.start = false;
271
272        self.result.clone()
273    }
274
275    /// Validate a full k-mer byte slice
276    fn is_valid_kmer_bytes(&self, bytes: &[u8]) -> bool {
277        if bytes.len() != self.k {
278            return false;
279        }
280        for &b in bytes {
281            if !matches!(b, b'A' | b'C' | b'G' | b'T' | b'a' | b'c' | b'g' | b't') {
282                return false;
283            }
284        }
285        true
286    }
287
288    /// Validate a single base
289    fn is_valid_base(&self, b: u8) -> bool {
290        matches!(b, b'A' | b'C' | b'G' | b'T' | b'a' | b'c' | b'g' | b't')
291    }
292
293    /// Perform a full search (seed) for the current k-mer
294    ///
295    /// This is called when we can't extend within the current string.
296    fn seed(&mut self, dict_opt: Option<&crate::dictionary::Dictionary>) {
297        self.remaining_string_bases = 0;
298
299        // Optimization: if minimizer unchanged and previous was not found, skip
300        if !self.start
301            && self.curr_mini_info.value == self.prev_mini_info.value
302            && self.curr_mini_info_rc.value == self.prev_mini_info_rc.value
303            && !self.result.minimizer_found
304        {
305            assert_eq!(self.result.kmer_id, u64::MAX);
306            self.num_negative += 1;
307            return;
308        }
309
310        if let (Some(dict), Some(kmer)) = (dict_opt, self.kmer) {
311            if self._canonical {
312                // Canonical mode: matching C++ lookup_canonical logic in seed.
313                //
314                // Use freshly extracted minimizer info for the lookup because
315                // the streaming MinimizerIterator's pos_in_kmer can be wrong
316                // for RC k-mers (it doesn't account for the reverse sliding
317                // direction, matching C++'s reverse template parameter).
318                // The streaming values are still correct for the negative
319                // optimization check above (which only uses .value).
320                let kmer_rc = kmer.reverse_complement();
321                let mini_fwd = dict.extract_minimizer::<K>(&kmer);
322                let mini_rc = dict.extract_minimizer::<K>(&kmer_rc);
323
324                if mini_fwd.value < mini_rc.value {
325                    self.result = dict.lookup_canonical_streaming::<K>(&kmer, &kmer_rc, mini_fwd);
326                } else if mini_rc.value < mini_fwd.value {
327                    self.result = dict.lookup_canonical_streaming::<K>(&kmer, &kmer_rc, mini_rc);
328                } else {
329                    self.result = dict.lookup_canonical_streaming::<K>(&kmer, &kmer_rc, mini_fwd);
330                    if self.result.kmer_id == u64::MAX {
331                        self.result = dict.lookup_canonical_streaming::<K>(&kmer, &kmer_rc, mini_rc);
332                    }
333                }
334            } else {
335                // Regular mode: try forward, then RC with backward orientation.
336                // Also use fresh minimizer extraction for correct pos_in_kmer.
337                let mini_fwd = dict.extract_minimizer::<K>(&kmer);
338                self.result = dict.lookup_regular_streaming::<K>(&kmer, mini_fwd);
339                let minimizer_found = self.result.minimizer_found;
340                if self.result.kmer_id == u64::MAX {
341                    assert_eq!(self.result.kmer_orientation, 1); // forward
342                    let kmer_rc = kmer.reverse_complement();
343                    let mini_rc = dict.extract_minimizer::<K>(&kmer_rc);
344                    self.result = dict.lookup_regular_streaming::<K>(&kmer_rc, mini_rc);
345                    self.result.kmer_orientation = -1; // backward
346                    let minimizer_rc_found = self.result.minimizer_found;
347                    self.result.minimizer_found = minimizer_rc_found || minimizer_found;
348                }
349            }
350
351            if self.result.kmer_id == u64::MAX {
352                self.num_negative += 1;
353                return;
354            }
355
356            assert!(self.result.minimizer_found);
357            self.num_searches += 1;
358
359            // Calculate remaining bases for extension, matching C++ exactly:
360            //   forward:  (string_end - string_begin - k) - kmer_id_in_string
361            //   backward: kmer_id_in_string
362            let string_size = self.result.string_end - self.result.string_begin;
363            if self.result.kmer_orientation > 0 {
364                self.remaining_string_bases =
365                    (string_size - self.k as u64) - self.result.kmer_id_in_string;
366            } else {
367                self.remaining_string_bases = self.result.kmer_id_in_string;
368            }
369        } else {
370            // No dictionary available
371            self.result = LookupResult::not_found();
372            self.num_negative += 1;
373        }
374    }
375    
376    /// Try to extend within the current string
377    ///
378    /// Matches C++ streaming_query extension logic:
379    /// - Read the expected next k-mer from the string data
380    /// - If it matches the current k-mer (or its RC), update result fields
381    fn try_extend(&mut self, dict: &crate::dictionary::Dictionary) {
382        if let (Some(kmer), Some(kmer_rc)) = (self.kmer, self.kmer_rc) {
383            // Compute the absolute position of the expected next k-mer
384            // C++: kmer_offset = 2 * (kmer_id + string_id * (k-1))
385            // The absolute base position in the concatenated strings
386            let abs_pos = self.result.kmer_id_in_string as usize
387                + self.result.string_begin as usize;
388
389            let next_abs_pos = if self.result.kmer_orientation > 0 {
390                abs_pos + 1
391            } else {
392                abs_pos.wrapping_sub(1)
393            };
394
395            // Read expected k-mer from string data at the next position
396            let expected_kmer: Kmer<K> = dict.spss().decode_kmer_at(next_abs_pos);
397
398            if expected_kmer.bits() == kmer.bits()
399                || expected_kmer.bits() == kmer_rc.bits()
400            {
401                // Successfully extended!
402                self.num_extensions += 1;
403                let delta = self.result.kmer_orientation as i64;
404                self.result.kmer_id = (self.result.kmer_id as i64 + delta) as u64;
405                self.result.kmer_id_in_string =
406                    (self.result.kmer_id_in_string as i64 + delta) as u64;
407                self.result.kmer_offset =
408                    (self.result.kmer_offset as i64 + delta) as u64;
409                self.remaining_string_bases -= 1;
410                return;
411            }
412        }
413        
414        // Extension failed, do a full search
415        self.seed(Some(dict));
416    }
417
418    /// Get the number of full searches performed
419    pub fn num_searches(&self) -> u64 {
420        self.num_searches
421    }
422
423    /// Get the number of extensions (no search needed)
424    pub fn num_extensions(&self) -> u64 {
425        self.num_extensions
426    }
427
428    /// Get the number of positive lookups (found)
429    pub fn num_positive_lookups(&self) -> u64 {
430        self.num_searches + self.num_extensions
431    }
432
433    /// Get the number of negative lookups (not found)
434    pub fn num_negative_lookups(&self) -> u64 {
435        self.num_negative
436    }
437
438    /// Get the number of invalid lookups (malformed input)
439    pub fn num_invalid_lookups(&self) -> u64 {
440        self.num_invalid
441    }
442}
443
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_lookup_result_creation() {
        let result = LookupResult::not_found();
        assert!(!result.is_found());
        assert_eq!(result.kmer_id, u64::MAX);
        assert_eq!(result.string_length(), 0);
        assert_eq!(LookupResult::default(), result);
    }

    #[test]
    fn test_lookup_result_string_length() {
        let mut result = LookupResult::not_found();
        result.string_begin = 100;
        result.string_end = 200;
        result.kmer_id = 42; // Mark as found

        assert_eq!(result.string_length(), 100);
    }

    #[test]
    fn test_streaming_query_creation() {
        let query: StreamingQuery<31> = StreamingQuery::new(31, 13, true);
        assert_eq!(query.k, 31);
        assert_eq!(query._m, 13);
        assert!(query._canonical);
        assert_eq!(query.num_searches(), 0);
    }

    #[test]
    fn test_streaming_query_reset() {
        let mut query: StreamingQuery<31> = StreamingQuery::new(31, 13, false);
        query.num_searches = 10;
        query.num_extensions = 5;

        query.reset();

        assert!(query.start);
        assert_eq!(query.remaining_string_bases, 0);
        // reset() clears streaming state only; counters are preserved
        assert_eq!(query.num_searches(), 10);
        assert_eq!(query.num_extensions(), 5);
    }

    #[test]
    fn test_streaming_query_validation() {
        let query: StreamingQuery<31> = StreamingQuery::new(31, 13, true);

        assert!(query.is_valid_kmer_bytes(b"ACGTACGTACGTACGTACGTACGTACGTACG")); // 31 bases
        assert!(query.is_valid_kmer_bytes(b"acgtacgtacgtacgtacgtacgtacgtacg")); // Lowercase accepted
        assert!(!query.is_valid_kmer_bytes(b"ACGT")); // Too short
        assert!(!query.is_valid_kmer_bytes(b"ACGTACGTACGTACGTACGTACGTACGTACGN")); // Too long (32 bytes)
        assert!(!query.is_valid_kmer_bytes(b"ACGTACGTACGTACGTACGTACGTACGTACN")); // Right length, invalid base

        assert!(query.is_valid_base(b'A'));
        assert!(query.is_valid_base(b'a'));
        assert!(!query.is_valid_base(b'N'));
    }

    #[test]
    fn test_streaming_query_lookup_invalid() {
        let mut query: StreamingQuery<15> = StreamingQuery::new(15, 7, true);

        // Invalid: too short
        let result = query.lookup(b"ACGT");
        assert!(!result.is_found());
        assert_eq!(query.num_invalid_lookups(), 1);

        // Invalid: has 'N'
        query.reset();
        let result = query.lookup(b"ACGTACGTACGTACN");
        assert!(!result.is_found());
        assert_eq!(query.num_invalid_lookups(), 2);
    }

    #[test]
    fn test_streaming_query_incremental_update() {
        let mut query: StreamingQuery<9> = StreamingQuery::new(9, 5, false);

        // First lookup
        let result1 = query.lookup(b"ACGTACGTA");
        assert!(!query.start); // No longer in start state
        assert!(!result1.is_found());

        // Second lookup (sliding by 1)
        let result2 = query.lookup(b"CGTACGTAC");

        // Even though lookups fail (no dictionary), state should update
        assert!(!query.start);
        assert!(!result2.is_found());
        assert_eq!(query.num_invalid_lookups(), 0);
        assert_eq!(query.num_positive_lookups(), 0);
        assert_eq!(query.num_negative_lookups(), 2);
    }
}
530
531/// Streaming query engine integrated with Dictionary
532///
533/// This provides the full streaming query functionality by connecting
534/// to a Dictionary instance for actual k-mer lookups.
535pub struct StreamingQueryEngine<'a, const K: usize>
536where
537    Kmer<K>: KmerBits,
538{
539    dict: &'a crate::dictionary::Dictionary,
540    query: StreamingQuery<K>,
541}
542
543impl<'a, const K: usize> StreamingQueryEngine<'a, K>
544where
545    Kmer<K>: KmerBits,
546{
547    /// Create a new streaming query engine for a dictionary
548    pub fn new(dict: &'a crate::dictionary::Dictionary) -> Self {
549        let canonical = dict.canonical();
550        Self {
551            dict,
552            query: StreamingQuery::new(dict.k(), dict.m(), canonical),
553        }
554    }
555    
556    /// Reset the query state
557    pub fn reset(&mut self) {
558        self.query.reset();
559    }
560    
561    /// Perform a streaming lookup
562    pub fn lookup(&mut self, kmer_bytes: &[u8]) -> LookupResult {
563        // Perform streaming lookup with dictionary integration
564        self.query.lookup_with_dict(kmer_bytes, self.dict)
565    }
566    
567    /// Get the number of full searches performed
568    pub fn num_searches(&self) -> u64 {
569        self.query.num_searches()
570    }
571    
572    /// Get the number of extensions (no search needed)
573    pub fn num_extensions(&self) -> u64 {
574        self.query.num_extensions()
575    }
576    
577    /// Get statistics
578    pub fn stats(&self) -> StreamingQueryStats {
579        StreamingQueryStats {
580            num_searches: self.query.num_searches(),
581            num_extensions: self.query.num_extensions(),
582            num_invalid: self.query.num_invalid_lookups(),
583            num_negative: self.query.num_negative_lookups(),
584        }
585    }
586}
587
/// Statistics from streaming queries
///
/// A plain snapshot of the query counters; all fields are simple counts.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct StreamingQueryStats {
    /// Number of k-mers found through a full search of the index
    pub num_searches: u64,
    /// Number of k-mers resolved by extending from a previous result
    pub num_extensions: u64,
    /// Number of lookups rejected because the input was malformed
    /// (wrong length or a base other than A/C/G/T)
    pub num_invalid: u64,
    /// Number of k-mers not found in the dictionary
    pub num_negative: u64,
}