Skip to main content

sshash_lib/
streaming_query.rs

1//! Streaming query for efficient k-mer lookups
2//!
3//! This module implements streaming queries, which optimize lookup performance
4//! when querying consecutive k-mers (sliding window over a sequence).
5//!
6//! Key optimizations:
7//! - Incremental k-mer updates (drop first base, add last base)
8//! - Reuse minimizer state across adjacent k-mers
9//! - Extend within the same string when possible (avoiding MPHF lookups)
10//! - Skip searches when minimizer unchanged and previous lookup failed
11
12use crate::kmer::{Kmer, KmerBits};
13use crate::minimizer::{MinimizerInfo, MinimizerIterator};
14use crate::encoding::encode_base;
15
/// Outcome of looking up a single k-mer.
///
/// A result whose `kmer_id` equals `u64::MAX` represents "not found"; use
/// [`LookupResult::is_found`] rather than comparing against the sentinel directly.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct LookupResult {
    /// Global k-mer ID, counted across all strings (`u64::MAX` when not found).
    pub kmer_id: u64,
    /// Index of the k-mer inside its string (0 <= kmer_id_in_string < string_size).
    pub kmer_id_in_string: u64,
    /// Offset of the k-mer within the encoded string data.
    pub kmer_offset: u64,
    /// Orientation of the hit: +1 forward, -1 reverse complement.
    pub kmer_orientation: i8,

    /// ID of the string that contains the k-mer.
    pub string_id: u64,
    /// First base position of that string.
    pub string_begin: u64,
    /// One-past-the-last base position of that string.
    pub string_end: u64,

    /// Whether the k-mer's minimizer is present in the index.
    pub minimizer_found: bool,
}

impl LookupResult {
    /// Build the sentinel "not found" result.
    ///
    /// All index fields are `u64::MAX`, the orientation is forward (+1), and
    /// `minimizer_found` starts out `true`. A fresh result therefore never
    /// satisfies the "previous minimizer was absent" condition used by the
    /// streaming search to skip work.
    pub fn not_found() -> Self {
        const MISSING: u64 = u64::MAX;
        LookupResult {
            kmer_id: MISSING,
            kmer_id_in_string: MISSING,
            kmer_offset: MISSING,
            kmer_orientation: 1,
            string_id: MISSING,
            string_begin: MISSING,
            string_end: MISSING,
            minimizer_found: true,
        }
    }

    /// `true` when this result refers to an actual k-mer occurrence.
    #[inline]
    pub fn is_found(&self) -> bool {
        self.kmer_id < u64::MAX
    }

    /// Length in bases of the containing string, or 0 for a "not found" result.
    #[inline]
    pub fn string_length(&self) -> u64 {
        if !self.is_found() {
            return 0;
        }
        self.string_end - self.string_begin
    }
}

impl Default for LookupResult {
    /// Same as [`LookupResult::not_found`].
    fn default() -> Self {
        LookupResult::not_found()
    }
}
76
77/// Streaming query engine for efficient consecutive k-mer lookups
78///
79/// This struct maintains state across multiple lookups to optimize
80/// queries for sliding windows over sequences.
81///
82/// # Example
83/// ```no_run
84/// use sshash_lib::streaming_query::StreamingQuery;
85/// // Assuming we have a dictionary...
86/// // let mut query = StreamingQuery::new(&dict, true); // canonical=true
87/// // 
88/// // Process consecutive k-mers efficiently
89/// // let result1 = query.lookup("ACGTACGTACGTACGTACGTACGTACGTACG");
90/// // let result2 = query.lookup("CGTACGTACGTACGTACGTACGTACGTACGT"); // Sliding by 1
91/// ```
92pub struct StreamingQuery<const K: usize>
93where
94    Kmer<K>: KmerBits,
95{
96    k: usize,
97    _m: usize, // Will be used in full Dictionary lookup
98    _canonical: bool, // Will be used in full Dictionary lookup
99    
100    // K-mer state
101    start: bool,
102    kmer: Option<Kmer<K>>,
103    kmer_rc: Option<Kmer<K>>,
104    
105    // Minimizer state
106    minimizer_it: MinimizerIterator,
107    minimizer_it_rc: MinimizerIterator,
108    curr_mini_info: MinimizerInfo,
109    prev_mini_info: MinimizerInfo,
110    curr_mini_info_rc: MinimizerInfo,
111    prev_mini_info_rc: MinimizerInfo,
112    
113    // String extension state
114    remaining_string_bases: u64,
115    
116    // Result state
117    result: LookupResult,
118    
119    // Performance counters
120    num_searches: u64,
121    num_extensions: u64,
122    num_invalid: u64,
123    num_negative: u64,
124}
125
126impl<const K: usize> StreamingQuery<K>
127where
128    Kmer<K>: KmerBits,
129{
130    /// Create a new streaming query engine
131    ///
132    /// # Arguments
133    /// * `k` - K-mer size
134    /// * `m` - Minimizer size
135    /// * `canonical` - Whether to use canonical k-mers (min of forward/RC)
136    pub fn new(k: usize, m: usize, canonical: bool) -> Self {
137        assert_eq!(k, K, "k parameter must match const generic K");
138        
139        let dummy_mini = MinimizerInfo::new(u64::MAX, 0, 0);
140        
141        Self {
142            k,
143            _m: m,
144            _canonical: canonical,
145            start: true,
146            kmer: None,
147            kmer_rc: None,
148            minimizer_it: MinimizerIterator::with_seed(k, m, 1),
149            minimizer_it_rc: MinimizerIterator::with_seed(k, m, 1),
150            curr_mini_info: dummy_mini,
151            prev_mini_info: dummy_mini,
152            curr_mini_info_rc: dummy_mini,
153            prev_mini_info_rc: dummy_mini,
154            remaining_string_bases: 0,
155            result: LookupResult::not_found(),
156            num_searches: 0,
157            num_extensions: 0,
158            num_invalid: 0,
159            num_negative: 0,
160        }
161    }
162
163    /// Reset the query state (call this when starting a new sequence)
164    pub fn reset(&mut self) {
165        self.start = true;
166        self.remaining_string_bases = 0;
167        self.result = LookupResult::not_found();
168        self.minimizer_it.set_position(0);
169        self.minimizer_it_rc.set_position(0);
170    }
171
172    /// Perform a streaming lookup for a k-mer
173    ///
174    /// This is the main entry point for queries. For optimal performance,
175    /// call this with consecutive k-mers (sliding by 1 base at a time).
176    ///
177    /// # Arguments
178    /// * `kmer_str` - DNA string of length K
179    ///
180    /// # Returns
181    /// A LookupResult indicating whether the k-mer was found and its location
182    pub fn lookup(&mut self, kmer_bytes: &[u8]) -> LookupResult {
183        // MVP version without Dictionary integration (always seeds)
184        self.lookup_internal(kmer_bytes, None)
185    }
186
187    /// Perform a streaming lookup with dictionary integration.
188    ///
189    /// Accepts a `&Dictionary` at call time rather than storing a reference,
190    /// so callers can manage the dictionary's lifetime independently (e.g. via `Arc`).
191    pub fn lookup_with_dict(&mut self, kmer_bytes: &[u8], dict: &crate::dictionary::Dictionary) -> LookupResult {
192        self.lookup_internal(kmer_bytes, Some(dict))
193    }
194
195    fn lookup_internal(&mut self, kmer_bytes: &[u8], dict_opt: Option<&crate::dictionary::Dictionary>) -> LookupResult {
196        // 1. Validation
197        let is_valid = if self.start {
198            self.is_valid_kmer_bytes(kmer_bytes)
199        } else {
200            self.is_valid_base(kmer_bytes[self.k - 1])
201        };
202
203        if !is_valid {
204            self.num_invalid += 1;
205            self.reset();
206            return self.result.clone();
207        }
208
209        // 2. Compute k-mer and reverse complement, update minimizers
210        if self.start {
211            // First k-mer: parse from scratch using fast byte encoding
212            let km = Kmer::<K>::from_ascii_unchecked(kmer_bytes);
213            self.kmer = Some(km);
214            let rc = km.reverse_complement();
215            self.kmer_rc = Some(rc);
216
217            self.curr_mini_info = self.minimizer_it.next(km);
218            self.curr_mini_info_rc = self.minimizer_it_rc.next(rc);
219        } else {
220            // Update incrementally: drop first base, add new last base
221            if let Some(mut km) = self.kmer {
222                // Drop first base (shift left)
223                for i in 0..(self.k - 1) {
224                    let base = km.get_base(i + 1);
225                    km.set_base(i, base);
226                }
227
228                // Add new last base
229                let new_base = kmer_bytes[self.k - 1];
230                if let Ok(encoded) = encode_base(new_base) {
231                    km.set_base(self.k - 1, encoded);
232
233                    self.kmer = Some(km);
234
235                    // Update RC: pad (shift right), set first base to complement
236                    if let Some(mut km_rc) = self.kmer_rc {
237                        for i in (1..self.k).rev() {
238                            let base = km_rc.get_base(i - 1);
239                            km_rc.set_base(i, base);
240                        }
241                        
242                        // Complement of new base at position 0
243                        let complement = crate::encoding::complement_base(encoded);
244                        km_rc.set_base(0, complement);
245
246                        self.kmer_rc = Some(km_rc);
247
248                        self.curr_mini_info = self.minimizer_it.next(km);
249                        self.curr_mini_info_rc = self.minimizer_it_rc.next(km_rc);
250                    }
251                }
252            }
253        }
254
255        // 3. Compute result (either extend or search)
256        if self.remaining_string_bases == 0 {
257            self.seed(dict_opt);
258        } else {
259            // Try to extend within current string
260            if let Some(dict) = dict_opt {
261                self.try_extend(dict);
262            } else {
263                // No dictionary, can't extend
264                self.seed(dict_opt);
265            }
266        }
267
268        // 4. Update state
269        self.prev_mini_info = self.curr_mini_info;
270        self.prev_mini_info_rc = self.curr_mini_info_rc;
271        self.start = false;
272
273        self.result.clone()
274    }
275
276    /// Validate a full k-mer byte slice
277    fn is_valid_kmer_bytes(&self, bytes: &[u8]) -> bool {
278        if bytes.len() != self.k {
279            return false;
280        }
281        for &b in bytes {
282            if !matches!(b, b'A' | b'C' | b'G' | b'T' | b'a' | b'c' | b'g' | b't') {
283                return false;
284            }
285        }
286        true
287    }
288
289    /// Validate a single base
290    fn is_valid_base(&self, b: u8) -> bool {
291        matches!(b, b'A' | b'C' | b'G' | b'T' | b'a' | b'c' | b'g' | b't')
292    }
293
294    /// Perform a full search (seed) for the current k-mer
295    ///
296    /// This is called when we can't extend within the current string.
297    fn seed(&mut self, dict_opt: Option<&crate::dictionary::Dictionary>) {
298        self.remaining_string_bases = 0;
299
300        // Optimization: if minimizer unchanged and previous was not found, skip
301        if !self.start
302            && self.curr_mini_info.value == self.prev_mini_info.value
303            && self.curr_mini_info_rc.value == self.prev_mini_info_rc.value
304            && !self.result.minimizer_found
305        {
306            assert_eq!(self.result.kmer_id, u64::MAX);
307            self.num_negative += 1;
308            return;
309        }
310
311        if let (Some(dict), Some(kmer)) = (dict_opt, self.kmer) {
312            if self._canonical {
313                // Canonical mode: matching C++ lookup_canonical logic in seed.
314                //
315                // Use freshly extracted minimizer info for the lookup because
316                // the streaming MinimizerIterator's pos_in_kmer can be wrong
317                // for RC k-mers (it doesn't account for the reverse sliding
318                // direction, matching C++'s reverse template parameter).
319                // The streaming values are still correct for the negative
320                // optimization check above (which only uses .value).
321                let kmer_rc = kmer.reverse_complement();
322                let mini_fwd = dict.extract_minimizer::<K>(&kmer);
323                let mini_rc = dict.extract_minimizer::<K>(&kmer_rc);
324
325                if mini_fwd.value < mini_rc.value {
326                    self.result = dict.lookup_canonical_streaming::<K>(&kmer, &kmer_rc, mini_fwd);
327                } else if mini_rc.value < mini_fwd.value {
328                    self.result = dict.lookup_canonical_streaming::<K>(&kmer, &kmer_rc, mini_rc);
329                } else {
330                    self.result = dict.lookup_canonical_streaming::<K>(&kmer, &kmer_rc, mini_fwd);
331                    if self.result.kmer_id == u64::MAX {
332                        self.result = dict.lookup_canonical_streaming::<K>(&kmer, &kmer_rc, mini_rc);
333                    }
334                }
335            } else {
336                // Regular mode: try forward, then RC with backward orientation.
337                // Also use fresh minimizer extraction for correct pos_in_kmer.
338                let mini_fwd = dict.extract_minimizer::<K>(&kmer);
339                self.result = dict.lookup_regular_streaming::<K>(&kmer, mini_fwd);
340                let minimizer_found = self.result.minimizer_found;
341                if self.result.kmer_id == u64::MAX {
342                    assert_eq!(self.result.kmer_orientation, 1); // forward
343                    let kmer_rc = kmer.reverse_complement();
344                    let mini_rc = dict.extract_minimizer::<K>(&kmer_rc);
345                    self.result = dict.lookup_regular_streaming::<K>(&kmer_rc, mini_rc);
346                    self.result.kmer_orientation = -1; // backward
347                    let minimizer_rc_found = self.result.minimizer_found;
348                    self.result.minimizer_found = minimizer_rc_found || minimizer_found;
349                }
350            }
351
352            if self.result.kmer_id == u64::MAX {
353                self.num_negative += 1;
354                return;
355            }
356
357            assert!(self.result.minimizer_found);
358            self.num_searches += 1;
359
360            // Calculate remaining bases for extension, matching C++ exactly:
361            //   forward:  (string_end - string_begin - k) - kmer_id_in_string
362            //   backward: kmer_id_in_string
363            let string_size = self.result.string_end - self.result.string_begin;
364            if self.result.kmer_orientation > 0 {
365                self.remaining_string_bases =
366                    (string_size - self.k as u64) - self.result.kmer_id_in_string;
367            } else {
368                self.remaining_string_bases = self.result.kmer_id_in_string;
369            }
370        } else {
371            // No dictionary available
372            self.result = LookupResult::not_found();
373            self.num_negative += 1;
374        }
375    }
376    
377    /// Try to extend within the current string
378    ///
379    /// Matches C++ streaming_query extension logic:
380    /// - Read the expected next k-mer from the string data
381    /// - If it matches the current k-mer (or its RC), update result fields
382    fn try_extend(&mut self, dict: &crate::dictionary::Dictionary) {
383        if let (Some(kmer), Some(kmer_rc)) = (self.kmer, self.kmer_rc) {
384            // Compute the absolute position of the expected next k-mer
385            // C++: kmer_offset = 2 * (kmer_id + string_id * (k-1))
386            // The absolute base position in the concatenated strings
387            let abs_pos = self.result.kmer_id_in_string as usize
388                + self.result.string_begin as usize;
389
390            let next_abs_pos = if self.result.kmer_orientation > 0 {
391                abs_pos + 1
392            } else {
393                abs_pos.wrapping_sub(1)
394            };
395
396            // Read expected k-mer from string data at the next position
397            let expected_kmer: Kmer<K> = dict.spss().decode_kmer_at(next_abs_pos);
398
399            if expected_kmer.bits() == kmer.bits()
400                || expected_kmer.bits() == kmer_rc.bits()
401            {
402                // Successfully extended!
403                self.num_extensions += 1;
404                let delta = self.result.kmer_orientation as i64;
405                self.result.kmer_id = (self.result.kmer_id as i64 + delta) as u64;
406                self.result.kmer_id_in_string =
407                    (self.result.kmer_id_in_string as i64 + delta) as u64;
408                self.result.kmer_offset =
409                    (self.result.kmer_offset as i64 + delta) as u64;
410                self.remaining_string_bases -= 1;
411                return;
412            }
413        }
414        
415        // Extension failed, do a full search
416        self.seed(Some(dict));
417    }
418
419    /// Get the number of full searches performed
420    pub fn num_searches(&self) -> u64 {
421        self.num_searches
422    }
423
424    /// Get the number of extensions (no search needed)
425    pub fn num_extensions(&self) -> u64 {
426        self.num_extensions
427    }
428
429    /// Get the number of positive lookups (found)
430    pub fn num_positive_lookups(&self) -> u64 {
431        self.num_searches + self.num_extensions
432    }
433
434    /// Get the number of negative lookups (not found)
435    pub fn num_negative_lookups(&self) -> u64 {
436        self.num_negative
437    }
438
439    /// Get the number of invalid lookups (malformed input)
440    pub fn num_invalid_lookups(&self) -> u64 {
441        self.num_invalid
442    }
443}
444
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_lookup_result_creation() {
        let result = LookupResult::not_found();
        assert!(!result.is_found());
        assert_eq!(result.kmer_id, u64::MAX);
        // `Default` yields the same "not found" sentinel.
        assert_eq!(LookupResult::default(), result);
    }

    #[test]
    fn test_lookup_result_string_length() {
        let mut result = LookupResult::not_found();
        result.string_begin = 100;
        result.string_end = 200;
        // Not found: length is reported as 0 regardless of the bounds.
        assert_eq!(result.string_length(), 0);

        result.kmer_id = 42; // Mark as found
        assert_eq!(result.string_length(), 100);
    }

    #[test]
    fn test_streaming_query_creation() {
        let query: StreamingQuery<31> = StreamingQuery::new(31, 13, true);
        assert_eq!(query.k, 31);
        assert_eq!(query._m, 13);
        assert!(query._canonical);
        assert!(query.start);
        assert_eq!(query.num_searches(), 0);
        assert_eq!(query.num_extensions(), 0);
        assert_eq!(query.num_negative_lookups(), 0);
        assert_eq!(query.num_invalid_lookups(), 0);
    }

    #[test]
    #[should_panic(expected = "k parameter must match const generic K")]
    fn test_streaming_query_k_mismatch_panics() {
        let _query: StreamingQuery<31> = StreamingQuery::new(21, 13, true);
    }

    #[test]
    fn test_streaming_query_reset() {
        let mut query: StreamingQuery<31> = StreamingQuery::new(31, 13, false);
        query.start = false;
        query.remaining_string_bases = 7;
        query.num_searches = 10;
        query.num_extensions = 5;

        query.reset();

        assert!(query.start);
        assert_eq!(query.remaining_string_bases, 0);
        assert!(!query.result.is_found());
        // Counters are cumulative and survive a reset.
        assert_eq!(query.num_searches(), 10);
        assert_eq!(query.num_extensions(), 5);
    }

    #[test]
    fn test_streaming_query_validation() {
        let query: StreamingQuery<31> = StreamingQuery::new(31, 13, true);

        assert!(query.is_valid_kmer_bytes(b"ACGTACGTACGTACGTACGTACGTACGTACG")); // 31 bases
        assert!(query.is_valid_kmer_bytes(b"acgtacgtacgtacgtacgtacgtacgtacg")); // lowercase
        assert!(!query.is_valid_kmer_bytes(b"ACGT")); // Too short
        assert!(!query.is_valid_kmer_bytes(b"ACGTACGTACGTACGTACGTACGTACGTACGT")); // Too long
        assert!(!query.is_valid_kmer_bytes(b"ACGTACGTACGTACGTACGTACGTACGTACN")); // Invalid base

        assert!(query.is_valid_base(b'A'));
        assert!(query.is_valid_base(b'a'));
        assert!(!query.is_valid_base(b'N'));
    }

    #[test]
    fn test_streaming_query_lookup_invalid() {
        let mut query: StreamingQuery<15> = StreamingQuery::new(15, 7, true);

        // Invalid: too short
        let result = query.lookup(b"ACGT");
        assert!(!result.is_found());
        assert_eq!(query.num_invalid_lookups(), 1);

        // Invalid: has 'N'
        query.reset();
        let result = query.lookup(b"ACGTACGTACGTACN");
        assert!(!result.is_found());
        assert_eq!(query.num_invalid_lookups(), 2);
    }

    #[test]
    fn test_streaming_query_invalid_base_mid_stream_resets() {
        let mut query: StreamingQuery<9> = StreamingQuery::new(9, 5, false);
        let _ = query.lookup(b"ACGTACGTA");
        assert!(!query.start);

        // The next window ends in 'N': invalid, and the engine returns to start state.
        let result = query.lookup(b"CGTACGTAN");
        assert!(!result.is_found());
        assert_eq!(query.num_invalid_lookups(), 1);
        assert!(query.start);

        // A fresh full k-mer restarts streaming.
        let _ = query.lookup(b"GTACGTACG");
        assert!(!query.start);
        assert_eq!(query.num_invalid_lookups(), 1);
    }

    #[test]
    fn test_streaming_query_incremental_update() {
        let mut query: StreamingQuery<9> = StreamingQuery::new(9, 5, false);

        // First lookup
        let result1 = query.lookup(b"ACGTACGTA");
        assert!(!result1.is_found());
        assert!(!query.start); // No longer in start state

        // Second lookup (sliding by 1)
        let result2 = query.lookup(b"CGTACGTAC");
        assert!(!result2.is_found());
        assert!(!query.start);

        // Without a dictionary every valid lookup counts as a negative.
        assert_eq!(query.num_negative_lookups(), 2);
        assert_eq!(query.num_positive_lookups(), 0);
        assert_eq!(query.num_invalid_lookups(), 0);

        // The incrementally updated k-mer and RC match a fresh parse of the window.
        let fresh = Kmer::<9>::from_ascii_unchecked(b"CGTACGTAC");
        assert!(query.kmer.unwrap().bits() == fresh.bits());
        assert!(query.kmer_rc.unwrap().bits() == fresh.reverse_complement().bits());
    }
}
531
532/// Streaming query engine integrated with Dictionary
533///
534/// This provides the full streaming query functionality by connecting
535/// to a Dictionary instance for actual k-mer lookups.
536pub struct StreamingQueryEngine<'a, const K: usize>
537where
538    Kmer<K>: KmerBits,
539{
540    dict: &'a crate::dictionary::Dictionary,
541    query: StreamingQuery<K>,
542}
543
544impl<'a, const K: usize> StreamingQueryEngine<'a, K>
545where
546    Kmer<K>: KmerBits,
547{
548    /// Create a new streaming query engine for a dictionary
549    pub fn new(dict: &'a crate::dictionary::Dictionary) -> Self {
550        let canonical = dict.canonical();
551        Self {
552            dict,
553            query: StreamingQuery::new(dict.k(), dict.m(), canonical),
554        }
555    }
556    
557    /// Reset the query state
558    pub fn reset(&mut self) {
559        self.query.reset();
560    }
561    
562    /// Perform a streaming lookup
563    pub fn lookup(&mut self, kmer_bytes: &[u8]) -> LookupResult {
564        // Perform streaming lookup with dictionary integration
565        self.query.lookup_with_dict(kmer_bytes, self.dict)
566    }
567    
568    /// Get the number of full searches performed
569    pub fn num_searches(&self) -> u64 {
570        self.query.num_searches()
571    }
572    
573    /// Get the number of extensions (no search needed)
574    pub fn num_extensions(&self) -> u64 {
575        self.query.num_extensions()
576    }
577    
578    /// Get statistics
579    pub fn stats(&self) -> StreamingQueryStats {
580        StreamingQueryStats {
581            num_searches: self.query.num_searches(),
582            num_extensions: self.query.num_extensions(),
583            num_invalid: self.query.num_invalid_lookups(),
584            num_negative: self.query.num_negative_lookups(),
585        }
586    }
587}
588
/// Snapshot of the counters accumulated by a streaming query engine.
///
/// `num_searches + num_extensions` is the number of positive (found) lookups;
/// see [`StreamingQueryStats::num_positive`].
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct StreamingQueryStats {
    /// Number of k-mers found by a full dictionary search (seed).
    pub num_searches: u64,
    /// Number of k-mers found by extending the previous hit within its string,
    /// without a full search.
    pub num_extensions: u64,
    /// Number of lookups rejected as malformed input: wrong length, or a base
    /// other than A/C/G/T (either case).
    pub num_invalid: u64,
    /// Number of valid k-mers not found in the dictionary.
    pub num_negative: u64,
}

impl StreamingQueryStats {
    /// Number of positive lookups (`num_searches + num_extensions`).
    pub fn num_positive(&self) -> u64 {
        self.num_searches + self.num_extensions
    }
}