Skip to main content

tdb_succinct/
wavelettree.rs

1//! A succinct data structure for quick lookup of entry positions in a sequence.
2
3use bitvec::vec::BitVec;
4use futures::Stream;
5use futures::TryStreamExt;
6
7use crate::storage::{FileLoad, FileStore};
8
9use super::bitarray::*;
10use super::bitindex::*;
11use super::logarray::*;
12use super::util;
13
14use std::convert::TryInto;
15use std::io;
16
/// A wavelet tree, encoding a u64 array for fast lookup of number positions.
///
/// A wavelet tree consists of a series of layers of bitarrays, stored
/// back to back as one big bitarray. Every layer is as long as the
/// encoded array. The number of layers is the log2 of the alphabet
/// size, rounded up to make it an integer. Since we're encoding u64
/// values, the number of layers can never be larger than 64.
#[derive(Clone)]
pub struct WaveletTree {
    /// All layers concatenated into a single bitarray, queried through
    /// rank and select operations.
    bits: BitIndex,
    /// The number of layers stored in `bits`.
    num_layers: u8,
}
28
/// A lookup for all positions of a particular entry.
///
/// This struct caches part of the calculation required to get
/// positions out of a wavelet tree, allowing for quick iteration over
/// all positions for a given entry.
#[derive(Clone)]
pub struct WaveletLookup {
    /// The entry this lookup was created for.
    pub entry: u64,
    /// The tree this lookup was created from.
    tree: WaveletTree,
    /// One `(bit, start, end)` triple per layer, top layer first: the
    /// bit value the entry follows on that layer, and the range
    /// `[start, end)` of the tree's bitarray that was examined there.
    slices: Vec<(bool, u64, u64)>,
}
41
42impl WaveletLookup {
43    /// Returns the amount of positions found in this lookup.
44    pub fn len(&self) -> usize {
45        let (b, start, end) = *self.slices.last().unwrap();
46
47        if b {
48            self.tree.bits.rank1_from_range(start, end) as usize
49        } else {
50            self.tree.bits.rank0_from_range(start, end) as usize
51        }
52    }
53
54    /// Returns the position of the index'th entry of this lookup
55    pub fn entry(&self, index: usize) -> u64 {
56        if index >= self.len() {
57            panic!("entry is out of bounds");
58        }
59
60        let mut result = (index + 1) as u64;
61        for &(b, start_index, end_index) in self.slices.iter().rev() {
62            if b {
63                result = self
64                    .tree
65                    .bits
66                    .select1_from_range(result, start_index, end_index)
67                    .unwrap()
68                    - start_index
69                    + 1;
70            } else {
71                result = self
72                    .tree
73                    .bits
74                    .select0_from_range(result, start_index, end_index)
75                    .unwrap()
76                    - start_index
77                    + 1;
78            }
79        }
80
81        result - 1
82    }
83
84    /// Returns an Iterator over all positions for the entry of this lookup
85    pub fn iter(&self) -> impl Iterator<Item = u64> {
86        let cloned = self.clone();
87        (0..self.len()).map(move |i| cloned.entry(i))
88    }
89}
90
91impl WaveletTree {
92    /// Construct a wavelet tree from a bitindex and a layer count.
93    pub fn from_parts(bits: BitIndex, num_layers: u8) -> WaveletTree {
94        if num_layers != 0 && bits.len() % num_layers as usize != 0 {
95            panic!("the bitarray length is not a multiple of the number of layers");
96        }
97
98        WaveletTree { bits, num_layers }
99    }
100
101    /// Returns the length of the encoded array.
102    pub fn len(&self) -> usize {
103        if self.num_layers == 0 {
104            0
105        } else {
106            self.bits.len() / self.num_layers as usize
107        }
108    }
109
110    /// Returns the amount of layers.
111    pub fn num_layers(&self) -> usize {
112        self.num_layers as usize
113    }
114
115    /// Decode the wavelet tree to the original u64 sequence. This returns an iterator.
116    pub fn decode(&self) -> impl Iterator<Item = u64> {
117        let owned = self.clone();
118        (0..self.len()).map(move |i| owned.decode_one(i))
119    }
120
121    /// Decode a single position of the original u64 sequence.
122    pub fn decode_one(&self, index: usize) -> u64 {
123        let len = self.len() as u64;
124        let mut offset = index as u64;
125        let mut alphabet_start = 0;
126        let mut alphabet_end = 2_u64.pow(self.num_layers as u32) as u64;
127        let mut range_start = 0;
128        let mut range_end = len;
129        for i in 0..self.num_layers as u64 {
130            let index = i * len + range_start + offset;
131            if index as usize >= self.bits.len() {
132                panic!("inner loop reached an index that is too high");
133            }
134            let bit = self.bits.get(index);
135
136            let range_start_index = i * len + range_start;
137            let range_end_index = i * len + range_end;
138            if bit {
139                alphabet_start = (alphabet_start + alphabet_end) / 2;
140                offset = self.bits.rank1_from_range(range_start_index, index + 1) - 1;
141
142                let zeros_in_range = self
143                    .bits
144                    .rank0_from_range(range_start_index, range_end_index);
145                range_start += zeros_in_range;
146            } else {
147                alphabet_end = (alphabet_start + alphabet_end) / 2;
148                offset = self.bits.rank0_from_range(range_start_index, index + 1) - 1;
149
150                let ones_in_range = self
151                    .bits
152                    .rank1_from_range(range_start_index, range_end_index);
153                range_end -= ones_in_range;
154            }
155        }
156
157        assert!(alphabet_start == alphabet_end - 1);
158
159        alphabet_start
160    }
161
162    /// Lookup the given entry. This returns a `WaveletLookup` which can then be used to find all positions.
163    pub fn lookup(&self, entry: u64) -> Option<WaveletLookup> {
164        if self.num_layers == 0 {
165            // without any layers, there's not going to be any elements
166            return None;
167        }
168
169        let width = self.len() as u64;
170        let mut slices = Vec::with_capacity(self.num_layers as usize);
171        let mut alphabet_start = 0;
172        let mut alphabet_end = 2_u64.pow(self.num_layers as u32) as u64;
173
174        if entry >= alphabet_end {
175            return None;
176        }
177
178        let mut start_index = 0_u64;
179        let mut end_index = self.len() as u64;
180        for i in 0..self.num_layers {
181            let full_start_index = (i as u64) * width + start_index;
182            let full_end_index = (i as u64) * width + end_index;
183            let b = entry >= (alphabet_start + alphabet_end) / 2;
184            slices.push((b, full_start_index, full_end_index));
185            if b {
186                alphabet_start += 2_u64.pow((self.num_layers - i - 1) as u32);
187                start_index += self.bits.rank0_from_range(full_start_index, full_end_index);
188            } else {
189                alphabet_end -= 2_u64.pow((self.num_layers - i - 1) as u32);
190                end_index -= self.bits.rank1_from_range(full_start_index, full_end_index);
191            }
192
193            if start_index == end_index {
194                return None;
195            }
196        }
197
198        Some(WaveletLookup {
199            entry,
200            slices,
201            tree: self.clone(),
202        })
203    }
204
205    /// Lookup the given entry. This returns a single result, even if there's multiple.
206    pub fn lookup_one(&self, entry: u64) -> Option<u64> {
207        self.lookup(entry).map(|l| l.entry(0))
208    }
209}
210
/// Collects the bits for one fragment of a wavelet tree layer.
///
/// A fragment covers the value range `[fragment_start, fragment_end)`
/// and records, for every pushed value inside that range, whether it
/// lies in the upper half of the range.
#[derive(Debug)]
struct FragmentBuilder {
    /// Inclusive lower bound of the values this fragment accepts.
    fragment_start: u64,
    /// Midpoint of the range; values at or above it are recorded as a set bit.
    fragment_half: u64,
    /// Exclusive upper bound of the values this fragment accepts.
    fragment_end: u64,
    /// One bit per accepted value, in push order.
    bits: BitVec,
}
218
219impl FragmentBuilder {
220    fn new(fragment_start: u64, fragment_end: u64) -> Self {
221        let fragment_half = (fragment_start + fragment_end) / 2;
222
223        Self {
224            fragment_start,
225            fragment_half,
226            fragment_end,
227            bits: BitVec::new(),
228        }
229    }
230
231    fn push(&mut self, num: u64) {
232        if num < self.fragment_start || num >= self.fragment_end {
233            // this number doesn't fit in this fragment so ignore
234            return;
235        }
236
237        self.bits.push(num >= self.fragment_half);
238    }
239}
240
/// Consuming a fragment yields its recorded bits in push order.
impl IntoIterator for FragmentBuilder {
    type Item = bool;
    type IntoIter = bitvec::boxed::IntoIter<usize>;

    /// Hands out the iterator of the underlying bit vector; the range
    /// bounds are dropped.
    fn into_iter(self) -> Self::IntoIter {
        self.bits.into_iter()
    }
}
249
250fn create_fragments(width: u8) -> Vec<FragmentBuilder> {
251    let upper = 2_u64.pow(width as u32);
252
253    let len = 2_usize.pow(width as u32) - 1;
254    let mut result = Vec::with_capacity(len);
255
256    for i in 0..width {
257        let increment = upper >> i;
258        let num = 2_u64.pow(i as u32);
259        for j in 0..num {
260            result.push(FragmentBuilder::new(j * increment, (j + 1) * increment));
261        }
262    }
263
264    result
265}
266
267fn push_to_fragments(num: u64, width: u8, fragments: &mut Vec<FragmentBuilder>) {
268    let mut num_it: usize = num.try_into().unwrap(); // this will ensure that we get some sort of error on 32 bit for large nums
269    for i in 0..width {
270        num_it >>= 1;
271        let index = num_it + 2_usize.pow((width - i - 1) as u32) - 1;
272        fragments[index].push(num);
273    }
274}
275
276/// Build a wavelet tree from a stream
277pub async fn build_wavelet_tree_from_stream<
278    S: Stream<Item = io::Result<u64>> + Unpin,
279    F: 'static + FileLoad + FileStore,
280>(
281    width: u8,
282    mut source: S,
283    destination_bits: F,
284    destination_blocks: F,
285    destination_sblocks: F,
286) -> io::Result<()> {
287    let mut bits = BitArrayFileBuilder::new(destination_bits.open_write().await?);
288    let mut fragments = create_fragments(width);
289
290    while let Some(num) = source.try_next().await? {
291        push_to_fragments(num, width, &mut fragments);
292    }
293
294    let iter = fragments.into_iter().flat_map(|f| f.into_iter());
295
296    bits.push_all(util::stream_iter_ok(iter)).await?;
297    bits.finalize().await?;
298
299    build_bitindex(
300        destination_bits.open_read().await?,
301        destination_blocks.open_write().await?,
302        destination_sblocks.open_write().await?,
303    )
304    .await?;
305
306    Ok(())
307}
308
309/// Build a wavelet tree from an iterator
310pub async fn build_wavelet_tree_from_iter<
311    I: Iterator<Item = u64>,
312    F: 'static + FileLoad + FileStore,
313>(
314    width: u8,
315    source: I,
316    destination_bits: F,
317    destination_blocks: F,
318    destination_sblocks: F,
319) -> io::Result<()> {
320    let mut bits = BitArrayFileBuilder::new(destination_bits.open_write().await?);
321    let mut fragments = create_fragments(width);
322
323    for num in source {
324        push_to_fragments(num, width, &mut fragments);
325    }
326
327    let iter = fragments.into_iter().flat_map(|f| f.into_iter());
328
329    bits.push_all(util::stream_iter_ok(iter)).await?;
330    bits.finalize().await?;
331
332    build_bitindex(
333        destination_bits.open_read().await?,
334        destination_blocks.open_write().await?,
335        destination_sblocks.open_write().await?,
336    )
337    .await?;
338
339    Ok(())
340}
341
342/// Build a wavelet tree from a file storing a logarray.
343pub async fn build_wavelet_tree_from_logarray<
344    FLoad: 'static + FileLoad,
345    F: 'static + FileLoad + FileStore,
346>(
347    source: FLoad,
348    destination_bits: F,
349    destination_blocks: F,
350    destination_sblocks: F,
351) -> io::Result<()> {
352    let (_, width) = logarray_file_get_length_and_width(source.clone()).await?;
353    let stream = logarray_stream_entries(source).await?;
354
355    build_wavelet_tree_from_stream(
356        width,
357        stream,
358        destination_bits,
359        destination_blocks,
360        destination_sblocks,
361    )
362    .await?;
363
364    Ok(())
365}
366
#[cfg(test)]
mod tests {
    use crate::storage::memory::MemoryBackedStore;

    use super::*;
    use futures::executor::block_on;

    /// Loads the wavelet tree with `num_layers` layers that was built
    /// into the three given stores.
    async fn load_tree(
        bits_file: MemoryBackedStore,
        blocks_file: MemoryBackedStore,
        sblocks_file: MemoryBackedStore,
        num_layers: u8,
    ) -> WaveletTree {
        let bits = bits_file.map().await.unwrap();
        let blocks = blocks_file.map().await.unwrap();
        let sblocks = sblocks_file.map().await.unwrap();

        let bitindex = BitIndex::from_maps(bits, blocks, sblocks);
        WaveletTree::from_parts(bitindex, num_layers)
    }

    /// Builds a wavelet tree of `width` layers over `contents` in
    /// memory-backed stores and loads it back.
    fn build_tree(width: u8, contents: &[u64]) -> WaveletTree {
        let bits_file = MemoryBackedStore::new();
        let blocks_file = MemoryBackedStore::new();
        let sblocks_file = MemoryBackedStore::new();

        block_on(build_wavelet_tree_from_iter(
            width,
            contents.iter().copied(),
            bits_file.clone(),
            blocks_file.clone(),
            sblocks_file.clone(),
        ))
        .unwrap();

        block_on(load_tree(bits_file, blocks_file, sblocks_file, width))
    }

    #[test]
    fn generate_and_decode_wavelet_tree_from_vec() {
        let contents: Vec<u64> = vec![21, 1, 30, 13, 23, 21, 3, 0, 21, 21, 12, 11];
        let wavelet_tree = build_tree(5, &contents);

        assert_eq!(contents.len(), wavelet_tree.len());

        assert_eq!(contents, wavelet_tree.decode().collect::<Vec<_>>());
    }

    // This test runs inside the tokio runtime, so every future is
    // awaited rather than driven by a nested blocking executor.
    #[tokio::test]
    async fn generate_and_decode_wavelet_tree_from_logarray() {
        let contents: Vec<u64> = vec![21, 1, 30, 13, 23, 21, 3, 0, 21, 21, 12, 11];

        let logarray_file = MemoryBackedStore::new();
        let mut logarray_builder =
            LogArrayFileBuilder::new(logarray_file.open_write().await.unwrap(), 5);
        logarray_builder
            .push_all(util::stream_iter_ok(contents.clone()))
            .await
            .unwrap();
        logarray_builder.finalize().await.unwrap();

        let wavelet_bits_file = MemoryBackedStore::new();
        let wavelet_blocks_file = MemoryBackedStore::new();
        let wavelet_sblocks_file = MemoryBackedStore::new();

        build_wavelet_tree_from_logarray(
            logarray_file,
            wavelet_bits_file.clone(),
            wavelet_blocks_file.clone(),
            wavelet_sblocks_file.clone(),
        )
        .await
        .unwrap();

        let wavelet_tree = load_tree(
            wavelet_bits_file,
            wavelet_blocks_file,
            wavelet_sblocks_file,
            5,
        )
        .await;

        assert_eq!(contents.len(), wavelet_tree.len());

        assert_eq!(contents, wavelet_tree.decode().collect::<Vec<_>>());
    }

    #[test]
    fn slice_wavelet_tree() {
        let contents = [8, 3, 8, 8, 1, 2, 3, 2, 8, 9, 3, 3, 6, 7, 0, 4, 8, 7, 3];
        let wavelet_tree = build_tree(4, &contents);

        let slice = wavelet_tree.lookup(8).unwrap();
        assert_eq!(vec![0, 2, 3, 8, 16], slice.iter().collect::<Vec<_>>());
        let slice = wavelet_tree.lookup(3).unwrap();
        assert_eq!(vec![1, 6, 10, 11, 18], slice.iter().collect::<Vec<_>>());
        let slice = wavelet_tree.lookup(0).unwrap();
        assert_eq!(vec![14], slice.iter().collect::<Vec<_>>());
        let slice = wavelet_tree.lookup(5);
        assert!(slice.is_none());
    }

    #[test]
    fn empty_wavelet_tree() {
        let wavelet_tree = build_tree(4, &[]);

        assert!(wavelet_tree.lookup(3).is_none());
    }

    #[test]
    fn lookup_wavelet_beyond_end() {
        let contents = [8, 3, 8, 8, 1, 2, 3, 2, 8, 9, 3, 3, 6, 7, 0, 4, 8, 7, 3];
        let wavelet_tree = build_tree(4, &contents);

        assert!(wavelet_tree.lookup(100).is_none());
    }

    #[test]
    fn lookup_wavelet_with_just_one_char_type() {
        let contents = [5, 5, 5, 5, 5, 5, 5, 5, 5, 5];
        let wavelet_tree = build_tree(4, &contents);

        assert_eq!(
            vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
            wavelet_tree.lookup(5).unwrap().iter().collect::<Vec<_>>()
        );
        assert!(wavelet_tree.lookup(4).is_none());
        assert!(wavelet_tree.lookup(6).is_none());
    }

    #[test]
    fn wavelet_lookup_one() {
        let contents = [3, 6, 2, 1, 8, 5, 4, 7];
        let wavelet_tree = build_tree(4, &contents);

        assert_eq!(Some(3), wavelet_tree.lookup_one(1));
        assert_eq!(Some(2), wavelet_tree.lookup_one(2));
        assert_eq!(Some(6), wavelet_tree.lookup_one(4));
        assert_eq!(Some(5), wavelet_tree.lookup_one(5));
        assert_eq!(Some(1), wavelet_tree.lookup_one(6));
        assert_eq!(Some(7), wavelet_tree.lookup_one(7));
        assert_eq!(Some(4), wavelet_tree.lookup_one(8));
    }
}