small_bwt/
lib.rs

1//! # BWT construction in small space
2//!
3//! Implementation of the BWT construction algorithm in small space,
4//! described in Algorithm 11.8 of the book:
5//! [Compact Data Structures - A Practical Approach](https://users.dcc.uchile.cl/~gnavarro/CDSbook/),
6//! Gonzalo Navarro, 2016.
7//!
8//! Given a typical text, it runs in `O(n log n loglog n)` time and `O(n)` additional bits of space,
9//! where `n` is the length of the input string and the alphabet size is much smaller than `n`.
10//! See the book for more details.
11//!
12//! ## Basic usage
13//!
14//! [`BwtBuilder`] provides a simple interface to build the BWT.
15//! It inputs a byte slice and outputs the BWT to a write stream.
16//!
17//! ```
18//! # fn main() -> Result<(), Box<dyn std::error::Error>> {
19//! use small_bwt::BwtBuilder;
20//!
21//! // The text must end with a smallest terminal character.
22//! let text = "abracadabra$";
23//!
24//! // Build the BWT.
25//! let mut bwt = vec![];
26//! BwtBuilder::new(text.as_bytes())?.build(&mut bwt)?;
27//! let bwt_str = String::from_utf8_lossy(&bwt);
28//! assert_eq!(bwt_str, "ard$rcaaaabb");
29//! # Ok(())
30//! # }
31//! ```
32#![deny(missing_docs)]
33mod radixsort;
34
35use std::io::Write;
36
37use anyhow::{anyhow, Result};
38
39use radixsort::MsdRadixSorter;
40
41/// BWT builder in small space.
42///
43/// Given a typical text, it runs in `O(n log n loglog n)` time and `O(n)` additional bits of space,
44/// where `n` is the length of the input string and the alphabet size is much smaller than `n`.
45/// See the book for more details.
46///
47/// # Requirements
48///
49/// This assumes that the smallest character appears only at the end of the text.
50/// Given an unexpected text, the behavior is undefined.
51/// If you want to verify the text, use [`verify_terminator`].
52///
53/// # Examples
54///
55/// See [the top page](crate).
56pub struct BwtBuilder<'a> {
57    text: &'a [u8],
58    chunk_size: usize,
59    progress: Progress,
60}
61
62impl<'a> BwtBuilder<'a> {
63    /// Creates a new builder.
64    ///
65    /// # Arguments
66    ///
67    /// * `text` - The text to be transformed, which should satisfy [`verify_terminator`].
68    ///
69    /// # Errors
70    ///
71    /// An error is returned if `text` is empty.
72    pub fn new(text: &'a [u8]) -> Result<Self> {
73        if text.is_empty() {
74            return Err(anyhow!("text must not be empty."));
75        }
76        let n = text.len() as f64;
77        let chunk_size = (n / n.log2()).ceil() as usize;
78        let chunk_size = chunk_size.max(1);
79        Ok(Self {
80            text,
81            chunk_size,
82            progress: Progress::new(false),
83        })
84    }
85
86    /// Sets the chunk size (for experiments).
87    ///
88    /// # Arguments
89    ///
90    /// * `chunk_size` - The chunk size.
91    ///
92    /// # Default value
93    ///
94    /// `ceil(n / log2(n))`, where `n` is the text length.
95    ///
96    /// # Errors
97    ///
98    /// An error is returned if `chunk_size` is zero.
99    #[doc(hidden)]
100    pub fn chunk_size(mut self, chunk_size: usize) -> Result<Self> {
101        if chunk_size == 0 {
102            return Err(anyhow!("chunk_size must be positive."));
103        }
104        self.chunk_size = chunk_size;
105        Ok(self)
106    }
107
108    /// Sets the verbosity.
109    /// If `verbose` is `true`, the progress is printed to stderr.
110    ///
111    /// # Arguments
112    ///
113    /// * `verbose` - The verbosity.
114    ///
115    /// # Default value
116    ///
117    /// `false`
118    pub const fn verbose(mut self, verbose: bool) -> Self {
119        self.progress = Progress::new(verbose);
120        self
121    }
122
123    /// Builds the BWT and writes it to `wrt`.
124    ///
125    /// # Arguments
126    ///
127    /// * `wrt` - The writer to write the BWT.
128    ///
129    /// # Errors
130    ///
131    /// An error is returned if `wrt` returns an error.
132    pub fn build<W: Write>(&self, wrt: W) -> Result<()> {
133        assert!(!self.text.is_empty());
134        assert_ne!(self.chunk_size, 0);
135
136        let text = self.text;
137        let chunk_size = self.chunk_size;
138        let n_expected_cuts = text.len() / chunk_size;
139
140        self.progress
141            .print(&format!("Text length: {:?} MiB", to_mib(text.len())));
142        self.progress
143            .print(&format!("Chunk size: {:?} M", to_mb(chunk_size)));
144        self.progress
145            .print(&format!("Expected number of cuts: {:?}", n_expected_cuts));
146
147        self.progress.print("Generating cuts...");
148        let cuts = CutGenerator::generate(text, chunk_size);
149        self.progress
150            .print(&format!("Actual number of cuts: {:?}", cuts.len()));
151
152        bwt_from_cuts(text, &cuts, wrt, &self.progress)
153    }
154}
155
156fn bwt_from_cuts<W: Write>(
157    text: &[u8],
158    cuts: &[Vec<u8>],
159    mut wrt: W,
160    progress: &Progress,
161) -> Result<()> {
162    assert!(cuts[0].is_empty());
163    let mut chunks = vec![];
164    for q in 1..=cuts.len() {
165        progress.print(&format!("Generating BWT: {}/{}", q, cuts.len()));
166        progress.print(&format!("Length of the cut: {:?}", cuts[q - 1].len()));
167
168        let cut_p = cuts[q - 1].as_slice();
169        if q < cuts.len() {
170            let cut_q = cuts[q].as_slice();
171            for j in 0..text.len() {
172                let suffix = &text[j..];
173                if cut_p < suffix && suffix <= cut_q {
174                    chunks.push(j);
175                }
176            }
177        } else {
178            for j in 0..text.len() {
179                let suffix = &text[j..];
180                if cut_p < suffix {
181                    chunks.push(j);
182                }
183            }
184        }
185
186        progress.print(&format!("Length of the chunks: {:?}", chunks.len()));
187        chunks = MsdRadixSorter::sort(text, chunks, 256);
188
189        for &j in &chunks {
190            let c = if j == 0 {
191                *text.last().unwrap()
192            } else {
193                text[j - 1]
194            };
195            wrt.write_all(&[c])?;
196        }
197        chunks.clear();
198    }
199    Ok(())
200}
201
202struct CutGenerator<'a> {
203    text: &'a [u8],
204    chunk_size: usize,
205    cuts: Vec<Vec<u8>>,
206    lens: Vec<usize>,
207}
208
209impl<'a> CutGenerator<'a> {
210    fn generate(text: &'a [u8], chunk_size: usize) -> Vec<Vec<u8>> {
211        let mut builder = Self {
212            text,
213            chunk_size,
214            cuts: vec![vec![]],
215            lens: vec![],
216        };
217        builder.expand(vec![]);
218        builder.cuts
219    }
220
221    fn expand(&mut self, mut cut: Vec<u8>) {
222        let freqs = symbol_freqs(self.text, &cut);
223        cut.push(0); // dummy last symbol
224        for (symbol, &freq) in freqs.iter().enumerate() {
225            if freq == 0 {
226                continue;
227            }
228            *cut.last_mut().unwrap() = symbol as u8;
229            if freq <= self.chunk_size {
230                if self.lens.is_empty() || *self.lens.last().unwrap() + freq > self.chunk_size {
231                    self.cuts.push(vec![]);
232                    self.lens.push(0);
233                }
234                *self.cuts.last_mut().unwrap() = cut.clone();
235                *self.lens.last_mut().unwrap() += freq;
236            } else {
237                self.expand(cut.clone());
238            }
239        }
240    }
241}
242
/// Counts, for every byte value, how often it directly follows an occurrence of `cut` in `text`.
///
/// The result always has 256 entries, indexed by byte value. An occurrence of
/// `cut` at the very end of `text` has no following symbol and is not counted.
fn symbol_freqs(text: &[u8], cut: &[u8]) -> Vec<usize> {
    let mut freqs = vec![0; 256];
    // Each window is a candidate occurrence of `cut` plus the symbol after it.
    for window in text.windows(cut.len() + 1) {
        if let Some((&next, head)) = window.split_last() {
            if head == cut {
                freqs[usize::from(next)] += 1;
            }
        }
    }
    freqs
}
254
/// Minimal progress reporter writing to stderr.
struct Progress {
    /// Whether messages are actually printed.
    verbose: bool,
}

impl Progress {
    /// Creates a reporter that prints only if `verbose` is `true`.
    const fn new(verbose: bool) -> Self {
        Self { verbose }
    }

    /// Prints `msg` to stderr with an `[INFO]` prefix when verbose; otherwise does nothing.
    fn print(&self, msg: &str) {
        if !self.verbose {
            return;
        }
        eprintln!("[INFO] {}", msg);
    }
}
270
/// Converts a byte count to megabytes (dividing by 1000 twice).
fn to_mb(bytes: usize) -> f64 {
    // Two successive divisions are kept so the rounding matches exactly.
    let kilobytes = bytes as f64 / 1000.0;
    kilobytes / 1000.0
}
274
/// Converts a byte count to mebibytes (dividing by 1024 twice).
fn to_mib(bytes: usize) -> f64 {
    let kibibytes = bytes as f64 / 1024.0;
    kibibytes / 1024.0
}
278
279/// Verifies that the smallest character appears only at the end of the text.
280///
281/// # Arguments
282///
283/// * `text` - The text to be verified.
284///
285/// # Errors
286///
287/// An error is returned if the smallest character does not appear only at the end of the text.
288///
289/// # Examples
290///
291/// ```
292/// use small_bwt::verify_terminator;
293///
294/// let text = "abracadabra$";
295/// let result = verify_terminator(text.as_bytes());
296/// assert!(result.is_ok());
297///
298/// let text = "abrac$dabra$";
299/// let result = verify_terminator(text.as_bytes());
300/// assert!(result.is_err());
301/// ```
302pub fn verify_terminator(text: &[u8]) -> Result<()> {
303    if text.is_empty() {
304        return Err(anyhow!("text must not be empty."));
305    }
306    let smallest = *text.last().unwrap();
307    for (i, &c) in text[..text.len() - 1].iter().enumerate() {
308        if c <= smallest {
309            return Err(anyhow!(
310                "text must have the smallest special character only at the end, but found {c:?} at position {i}."
311            ));
312        }
313    }
314    Ok(())
315}
316
317/// Decodes the original text from a given BWT.
318///
319/// It runs in `O(n)` time and `O(n log n)` bits of space,
320/// where `n` is the length of the text.
321///
322/// # Arguments
323///
324/// * `bwt` - The Burrows-Wheeler transform of a text.
325///
326/// # Errors
327///
328/// An error is returned if the Burrows-Wheeler transform is invalid.
329///
330/// # Examples
331///
332/// ```
333/// # fn main() -> Result<(), Box<dyn std::error::Error>> {
334/// use small_bwt::decode_bwt;
335///
336/// let bwt = "ard$rcaaaabb";
337/// let decoded = decode_bwt(bwt.as_bytes())?;
338/// assert_eq!(decoded, "abracadabra$".as_bytes());
339/// # Ok(())
340/// # }
341/// ```
342pub fn decode_bwt(bwt: &[u8]) -> Result<Vec<u8>> {
343    if bwt.is_empty() {
344        return Err(anyhow!("bwt must not be empty."));
345    }
346
347    let (counts, ranks) = {
348        let mut counts = vec![0; 256];
349        let mut ranks = vec![0; bwt.len()];
350        for (&c, r) in bwt.iter().zip(ranks.iter_mut()) {
351            *r = counts[c as usize];
352            counts[c as usize] += 1;
353        }
354        (counts, ranks)
355    };
356
357    let occ = {
358        let mut occ = vec![0; 256];
359        let mut rank = 0;
360        for i in 0..256 {
361            occ[i] = rank;
362            rank += counts[i];
363        }
364        occ
365    };
366
367    let terminator = counts.iter().position(|&c| c != 0).unwrap();
368    if counts[terminator] != 1 {
369        return Err(anyhow!(
370            "bwt must have exactly one terminator character, but found {:x} {} times.",
371            terminator,
372            counts[terminator]
373        ));
374    }
375    let terminator = terminator as u8;
376
377    let mut decoded = Vec::with_capacity(bwt.len());
378    decoded.push(terminator);
379
380    let mut i = 0;
381    while bwt[i] != terminator {
382        assert!(decoded.len() < bwt.len());
383        decoded.push(bwt[i]);
384        i = occ[bwt[i] as usize] + ranks[i];
385    }
386    decoded.reverse();
387
388    Ok(decoded)
389}
390
#[cfg(test)]
mod tests {
    use super::*;

    /// Sample text shared by most tests.
    const TEXT: &str = "abracadabra$";
    /// The BWT of [`TEXT`].
    const BWT: &str = "ard$rcaaaabb";

    /// Builds the BWT of `text` through `BwtBuilder`, optionally overriding the chunk size.
    fn build_bwt(text: &str, chunk_size: Option<usize>) -> String {
        let mut builder = BwtBuilder::new(text.as_bytes()).unwrap();
        if let Some(size) = chunk_size {
            builder = builder.chunk_size(size).unwrap();
        }
        let mut bwt = vec![];
        builder.build(&mut bwt).unwrap();
        String::from_utf8_lossy(&bwt).into_owned()
    }

    /// Runs `bwt_from_cuts` on `text` with the given cuts and returns the output as a string.
    fn bwt_with_cuts(text: &[u8], cuts: &[&[u8]]) -> String {
        let cuts: Vec<Vec<u8>> = cuts.iter().map(|cut| cut.to_vec()).collect();
        let mut bwt = vec![];
        bwt_from_cuts(text, &cuts, &mut bwt, &Progress::new(false)).unwrap();
        String::from_utf8_lossy(&bwt).into_owned()
    }

    /// Builds a 256-entry frequency table from `(symbol, frequency)` pairs.
    fn freq_table(entries: &[(u8, usize)]) -> Vec<usize> {
        let mut table = vec![0; 256];
        for &(symbol, freq) in entries {
            table[symbol as usize] = freq;
        }
        table
    }

    #[test]
    fn test_bwt_builder() {
        assert_eq!(build_bwt(TEXT, None), BWT);
    }

    #[test]
    fn test_bwt_builder_3() {
        assert_eq!(build_bwt(TEXT, Some(3)), BWT);
    }

    #[test]
    fn test_bwt_builder_4() {
        assert_eq!(build_bwt(TEXT, Some(4)), BWT);
    }

    #[test]
    fn test_bwt_builder_empty() {
        assert!(BwtBuilder::new("".as_bytes()).is_err());
    }

    #[test]
    fn test_bwt_from_cuts_3() {
        let cuts: [&[u8]; 6] = [b"", b"a$", b"ac", b"b", b"d", b"r"];
        assert_eq!(bwt_with_cuts(TEXT.as_bytes(), &cuts), BWT);
    }

    #[test]
    fn test_bwt_from_cuts_4() {
        let cuts: [&[u8]; 4] = [b"", b"ab", b"b", b"r"];
        assert_eq!(bwt_with_cuts(TEXT.as_bytes(), &cuts), BWT);
    }

    #[test]
    fn test_symbol_freqs() {
        let freqs = symbol_freqs(TEXT.as_bytes(), b"ra");
        assert_eq!(freqs, freq_table(&[(b'$', 1), (b'c', 1)]));
    }

    #[test]
    fn test_symbol_freqs_empty() {
        let freqs = symbol_freqs(TEXT.as_bytes(), b"");
        let expected = freq_table(&[
            (b'$', 1),
            (b'a', 5),
            (b'b', 2),
            (b'c', 1),
            (b'd', 1),
            (b'r', 2),
        ]);
        assert_eq!(freqs, expected);
    }

    #[test]
    fn test_verify_terminator_empty() {
        assert!(verify_terminator("".as_bytes()).is_err());
    }

    #[test]
    fn test_decode_bwt_single() {
        let decoded = decode_bwt("$".as_bytes()).unwrap();
        assert_eq!(decoded, "$".as_bytes());
    }

    #[test]
    fn test_decode_bwt_empty() {
        assert!(decode_bwt("".as_bytes()).is_err());
    }

    #[test]
    fn test_decode_bwt_invalid_terminator() {
        assert!(decode_bwt("ard$rcaaa$bb".as_bytes()).is_err());
    }
}