1#![deny(missing_docs)]
33mod radixsort;
34
35use std::io::Write;
36
37use anyhow::{anyhow, Result};
38
39use radixsort::MsdRadixSorter;
40
41pub struct BwtBuilder<'a> {
57 text: &'a [u8],
58 chunk_size: usize,
59 progress: Progress,
60}
61
62impl<'a> BwtBuilder<'a> {
63 pub fn new(text: &'a [u8]) -> Result<Self> {
73 if text.is_empty() {
74 return Err(anyhow!("text must not be empty."));
75 }
76 let n = text.len() as f64;
77 let chunk_size = (n / n.log2()).ceil() as usize;
78 let chunk_size = chunk_size.max(1);
79 Ok(Self {
80 text,
81 chunk_size,
82 progress: Progress::new(false),
83 })
84 }
85
86 #[doc(hidden)]
100 pub fn chunk_size(mut self, chunk_size: usize) -> Result<Self> {
101 if chunk_size == 0 {
102 return Err(anyhow!("chunk_size must be positive."));
103 }
104 self.chunk_size = chunk_size;
105 Ok(self)
106 }
107
108 pub const fn verbose(mut self, verbose: bool) -> Self {
119 self.progress = Progress::new(verbose);
120 self
121 }
122
123 pub fn build<W: Write>(&self, wrt: W) -> Result<()> {
133 assert!(!self.text.is_empty());
134 assert_ne!(self.chunk_size, 0);
135
136 let text = self.text;
137 let chunk_size = self.chunk_size;
138 let n_expected_cuts = text.len() / chunk_size;
139
140 self.progress
141 .print(&format!("Text length: {:?} MiB", to_mib(text.len())));
142 self.progress
143 .print(&format!("Chunk size: {:?} M", to_mb(chunk_size)));
144 self.progress
145 .print(&format!("Expected number of cuts: {:?}", n_expected_cuts));
146
147 self.progress.print("Generating cuts...");
148 let cuts = CutGenerator::generate(text, chunk_size);
149 self.progress
150 .print(&format!("Actual number of cuts: {:?}", cuts.len()));
151
152 bwt_from_cuts(text, &cuts, wrt, &self.progress)
153 }
154}
155
156fn bwt_from_cuts<W: Write>(
157 text: &[u8],
158 cuts: &[Vec<u8>],
159 mut wrt: W,
160 progress: &Progress,
161) -> Result<()> {
162 assert!(cuts[0].is_empty());
163 let mut chunks = vec![];
164 for q in 1..=cuts.len() {
165 progress.print(&format!("Generating BWT: {}/{}", q, cuts.len()));
166 progress.print(&format!("Length of the cut: {:?}", cuts[q - 1].len()));
167
168 let cut_p = cuts[q - 1].as_slice();
169 if q < cuts.len() {
170 let cut_q = cuts[q].as_slice();
171 for j in 0..text.len() {
172 let suffix = &text[j..];
173 if cut_p < suffix && suffix <= cut_q {
174 chunks.push(j);
175 }
176 }
177 } else {
178 for j in 0..text.len() {
179 let suffix = &text[j..];
180 if cut_p < suffix {
181 chunks.push(j);
182 }
183 }
184 }
185
186 progress.print(&format!("Length of the chunks: {:?}", chunks.len()));
187 chunks = MsdRadixSorter::sort(text, chunks, 256);
188
189 for &j in &chunks {
190 let c = if j == 0 {
191 *text.last().unwrap()
192 } else {
193 text[j - 1]
194 };
195 wrt.write_all(&[c])?;
196 }
197 chunks.clear();
198 }
199 Ok(())
200}
201
202struct CutGenerator<'a> {
203 text: &'a [u8],
204 chunk_size: usize,
205 cuts: Vec<Vec<u8>>,
206 lens: Vec<usize>,
207}
208
209impl<'a> CutGenerator<'a> {
210 fn generate(text: &'a [u8], chunk_size: usize) -> Vec<Vec<u8>> {
211 let mut builder = Self {
212 text,
213 chunk_size,
214 cuts: vec![vec![]],
215 lens: vec![],
216 };
217 builder.expand(vec![]);
218 builder.cuts
219 }
220
221 fn expand(&mut self, mut cut: Vec<u8>) {
222 let freqs = symbol_freqs(self.text, &cut);
223 cut.push(0); for (symbol, &freq) in freqs.iter().enumerate() {
225 if freq == 0 {
226 continue;
227 }
228 *cut.last_mut().unwrap() = symbol as u8;
229 if freq <= self.chunk_size {
230 if self.lens.is_empty() || *self.lens.last().unwrap() + freq > self.chunk_size {
231 self.cuts.push(vec![]);
232 self.lens.push(0);
233 }
234 *self.cuts.last_mut().unwrap() = cut.clone();
235 *self.lens.last_mut().unwrap() += freq;
236 } else {
237 self.expand(cut.clone());
238 }
239 }
240 }
241}
242
/// Counts, for every byte value, how often it directly follows an occurrence
/// of `cut` in `text`.
///
/// The returned vector has 256 entries indexed by byte value. An occurrence of
/// `cut` at the very end of `text` has no following byte and is not counted.
fn symbol_freqs(text: &[u8], cut: &[u8]) -> Vec<usize> {
    let mut freqs = vec![0; 256];
    // Each window is a candidate occurrence of `cut` plus the byte after it.
    let followers = text
        .windows(cut.len() + 1)
        .filter_map(|window| window.split_last())
        .filter(|&(_, head)| head == cut)
        .map(|(&next, _)| next);
    for symbol in followers {
        freqs[usize::from(symbol)] += 1;
    }
    freqs
}
254
/// Optional reporter of progress messages.
struct Progress {
    /// Whether messages are printed at all.
    verbose: bool,
}

impl Progress {
    /// Creates a reporter that prints only when `verbose` is `true`.
    const fn new(verbose: bool) -> Self {
        Self { verbose }
    }

    /// Prints `msg` to stderr with an `[INFO]` prefix, or does nothing when
    /// the reporter is not verbose.
    fn print(&self, msg: &str) {
        if !self.verbose {
            return;
        }
        eprintln!("[INFO] {}", msg);
    }
}
270
/// Converts a byte count to megabytes (units of 1000 × 1000 bytes).
fn to_mb(bytes: usize) -> f64 {
    const KILO: f64 = 1000.0;
    // Two successive divisions, so the floating-point rounding is unchanged.
    let kilobytes = bytes as f64 / KILO;
    kilobytes / KILO
}
274
/// Converts a byte count to mebibytes (units of 1024 × 1024 bytes).
fn to_mib(bytes: usize) -> f64 {
    const KIBI: f64 = 1024.0;
    let kibibytes = bytes as f64 / KIBI;
    kibibytes / KIBI
}
278
279pub fn verify_terminator(text: &[u8]) -> Result<()> {
303 if text.is_empty() {
304 return Err(anyhow!("text must not be empty."));
305 }
306 let smallest = *text.last().unwrap();
307 for (i, &c) in text[..text.len() - 1].iter().enumerate() {
308 if c <= smallest {
309 return Err(anyhow!(
310 "text must have the smallest special character only at the end, but found {c:?} at position {i}."
311 ));
312 }
313 }
314 Ok(())
315}
316
317pub fn decode_bwt(bwt: &[u8]) -> Result<Vec<u8>> {
343 if bwt.is_empty() {
344 return Err(anyhow!("bwt must not be empty."));
345 }
346
347 let (counts, ranks) = {
348 let mut counts = vec![0; 256];
349 let mut ranks = vec![0; bwt.len()];
350 for (&c, r) in bwt.iter().zip(ranks.iter_mut()) {
351 *r = counts[c as usize];
352 counts[c as usize] += 1;
353 }
354 (counts, ranks)
355 };
356
357 let occ = {
358 let mut occ = vec![0; 256];
359 let mut rank = 0;
360 for i in 0..256 {
361 occ[i] = rank;
362 rank += counts[i];
363 }
364 occ
365 };
366
367 let terminator = counts.iter().position(|&c| c != 0).unwrap();
368 if counts[terminator] != 1 {
369 return Err(anyhow!(
370 "bwt must have exactly one terminator character, but found {:x} {} times.",
371 terminator,
372 counts[terminator]
373 ));
374 }
375 let terminator = terminator as u8;
376
377 let mut decoded = Vec::with_capacity(bwt.len());
378 decoded.push(terminator);
379
380 let mut i = 0;
381 while bwt[i] != terminator {
382 assert!(decoded.len() < bwt.len());
383 decoded.push(bwt[i]);
384 i = occ[bwt[i] as usize] + ranks[i];
385 }
386 decoded.reverse();
387
388 Ok(decoded)
389}
390
#[cfg(test)]
mod tests {
    use super::*;

    /// Text shared by the construction tests.
    const TEXT: &[u8] = b"abracadabra$";
    /// BWT of [`TEXT`].
    const EXPECTED_BWT: &str = "ard$rcaaaabb";

    /// Builds the BWT of [`TEXT`], optionally overriding the chunk size.
    fn build_bwt(chunk_size: Option<usize>) -> String {
        let mut builder = BwtBuilder::new(TEXT).unwrap();
        if let Some(size) = chunk_size {
            builder = builder.chunk_size(size).unwrap();
        }
        let mut bwt = vec![];
        builder.build(&mut bwt).unwrap();
        String::from_utf8_lossy(&bwt).into_owned()
    }

    /// Builds the BWT of [`TEXT`] from hand-written cuts.
    fn bwt_with_cuts(cuts: &[&str]) -> String {
        let cuts: Vec<Vec<u8>> = cuts.iter().map(|cut| cut.as_bytes().to_vec()).collect();
        let mut bwt = vec![];
        bwt_from_cuts(TEXT, &cuts, &mut bwt, &Progress::new(false)).unwrap();
        String::from_utf8_lossy(&bwt).into_owned()
    }

    /// Builds a 256-entry frequency table from `(symbol, frequency)` pairs.
    fn freq_table(entries: &[(u8, usize)]) -> Vec<usize> {
        let mut table = vec![0; 256];
        for &(symbol, freq) in entries {
            table[symbol as usize] = freq;
        }
        table
    }

    #[test]
    fn test_bwt_builder() {
        assert_eq!(build_bwt(None), EXPECTED_BWT);
    }

    #[test]
    fn test_bwt_builder_3() {
        assert_eq!(build_bwt(Some(3)), EXPECTED_BWT);
    }

    #[test]
    fn test_bwt_builder_4() {
        assert_eq!(build_bwt(Some(4)), EXPECTED_BWT);
    }

    #[test]
    fn test_bwt_builder_empty() {
        assert!(BwtBuilder::new("".as_bytes()).is_err());
    }

    #[test]
    fn test_bwt_from_cuts_3() {
        let bwt = bwt_with_cuts(&["", "a$", "ac", "b", "d", "r"]);
        assert_eq!(bwt, EXPECTED_BWT);
    }

    #[test]
    fn test_bwt_from_cuts_4() {
        let bwt = bwt_with_cuts(&["", "ab", "b", "r"]);
        assert_eq!(bwt, EXPECTED_BWT);
    }

    #[test]
    fn test_symbol_freqs() {
        let expected = freq_table(&[(b'$', 1), (b'c', 1)]);
        assert_eq!(symbol_freqs(TEXT, b"ra"), expected);
    }

    #[test]
    fn test_symbol_freqs_empty() {
        let expected = freq_table(&[
            (b'$', 1),
            (b'a', 5),
            (b'b', 2),
            (b'c', 1),
            (b'd', 1),
            (b'r', 2),
        ]);
        assert_eq!(symbol_freqs(TEXT, b""), expected);
    }

    #[test]
    fn test_verify_terminator_empty() {
        assert!(verify_terminator("".as_bytes()).is_err());
    }

    #[test]
    fn test_decode_bwt_single() {
        let decoded = decode_bwt("$".as_bytes()).unwrap();
        assert_eq!(decoded, "$".as_bytes());
    }

    #[test]
    fn test_decode_bwt_empty() {
        assert!(decode_bwt("".as_bytes()).is_err());
    }

    #[test]
    fn test_decode_bwt_invalid_terminator() {
        assert!(decode_bwt("ard$rcaaa$bb".as_bytes()).is_err());
    }
}