1use crate::offsets::{EliasFanoOffsets, OffsetsVector};
7use crate::kmer::{Kmer, KmerBits};
8use std::io::{self, Read, Write};
9
10pub struct SpectrumPreservingStringSet {
18 strings: Vec<u8>,
20 offsets: EliasFanoOffsets,
23 k: usize,
25 m: usize,
27}
28
29impl SpectrumPreservingStringSet {
30 pub fn new(k: usize, m: usize) -> Self {
32 Self {
33 strings: Vec::new(),
34 offsets: EliasFanoOffsets::from_vec(&[0]),
35 k,
36 m,
37 }
38 }
39
40 pub fn from_parts(strings: Vec<u8>, offsets: OffsetsVector, k: usize, m: usize) -> Self {
50 Self {
51 strings,
52 offsets: EliasFanoOffsets::from_offsets_vector(offsets),
53 k,
54 m,
55 }
56 }
57
58 pub fn string_offsets(&self, string_id: u32) -> (u64, u64) {
60 let id = string_id as usize;
61 let begin = self.offsets.access(id);
62 let end = self.offsets.access(id + 1);
63 (begin, end)
64 }
65
66 pub fn num_strings(&self) -> u64 {
68 if !self.offsets.is_empty() {
69 (self.offsets.len() - 1) as u64
70 } else {
71 0
72 }
73 }
74
75 pub fn string_offset(&self, string_id: u64) -> u64 {
77 self.offsets.access(string_id as usize)
78 }
79
80 pub fn k(&self) -> usize {
82 self.k
83 }
84
85 pub fn m(&self) -> usize {
87 self.m
88 }
89
90 pub fn total_bases(&self) -> u64 {
92 if !self.offsets.is_empty() {
93 self.offsets.access(self.offsets.len() - 1)
94 } else {
95 0
96 }
97 }
98
99 #[inline]
102 pub fn locate(&self, absolute_pos: u64) -> Option<(u64, u64)> {
103 self.offsets.locate(absolute_pos)
104 }
105
106 #[inline]
110 pub fn locate_with_end(&self, absolute_pos: u64) -> Option<(u64, u64, u64)> {
111 self.offsets.locate_with_end(absolute_pos)
112 }
113
114 pub fn num_bits(&self) -> u64 {
116 (self.strings.len() as u64) * 8 + self.offsets.num_bits()
117 }
118
119 pub fn strings_bytes(&self) -> usize {
121 self.strings.len()
122 }
123
124 pub fn offsets_bytes(&self) -> usize {
126 self.offsets.num_bytes() as usize
127 }
128
129 pub fn string_length(&self, string_id: u64) -> usize {
131 let (begin, end) = self.string_offsets(string_id as u32);
132 (end - begin) as usize
133 }
134
135 #[inline]
139 pub fn decode_kmer<const K: usize>(&self, string_id: u64, kmer_pos: usize) -> Kmer<K>
140 where
141 Kmer<K>: KmerBits,
142 {
143 let (begin, _end) = self.string_offsets(string_id as u32);
144 let start_base = (begin as usize) + kmer_pos;
145
146 let byte_offset = start_base / 4;
147 let bit_shift = (start_base % 4) * 2;
148 let needed_bits = K * 2;
149
150 let needed_bytes = (needed_bits + bit_shift).div_ceil(8);
153
154 if needed_bytes <= 8 {
155 let mut buf = [0u8; 8];
157 let avail = self.strings.len().saturating_sub(byte_offset).min(8);
158 buf[..avail].copy_from_slice(&self.strings[byte_offset..byte_offset + avail]);
159 let raw = u64::from_le_bytes(buf);
160 let shifted = raw >> bit_shift;
161 let mask = if needed_bits >= 64 { u64::MAX } else { (1u64 << needed_bits) - 1 };
162 Kmer::<K>::new(<Kmer<K> as KmerBits>::from_u64(shifted & mask))
163 } else {
164 let mut buf = [0u8; 17];
169 let avail = self.strings.len().saturating_sub(byte_offset).min(17);
170 buf[..avail].copy_from_slice(&self.strings[byte_offset..byte_offset + avail]);
171 let raw = u128::from_le_bytes(buf[..16].try_into().unwrap());
172 let shifted = if bit_shift > 0 {
173 let extra = buf[16] as u128;
174 (raw >> bit_shift) | (extra << (128 - bit_shift))
175 } else {
176 raw
177 };
178 let mask = if needed_bits >= 128 { u128::MAX } else { (1u128 << needed_bits) - 1 };
179 Kmer::<K>::new(<Kmer<K> as KmerBits>::from_u128(shifted & mask))
180 }
181 }
182
183 #[inline]
187 pub fn decode_kmer_at<const K: usize>(&self, absolute_pos: usize) -> Kmer<K>
188 where
189 Kmer<K>: KmerBits,
190 {
191 let byte_offset = absolute_pos / 4;
192 let bit_shift = (absolute_pos % 4) * 2;
193 let needed_bits = K * 2;
194 let needed_bytes = (needed_bits + bit_shift).div_ceil(8);
195
196 if needed_bytes <= 8 {
197 let raw = if byte_offset + 8 <= self.strings.len() {
203 unsafe {
204 std::ptr::read_unaligned(
205 self.strings.as_ptr().add(byte_offset) as *const u64
206 )
207 }
208 } else {
209 let mut buf = [0u8; 8];
211 let avail = self.strings.len() - byte_offset;
212 buf[..avail].copy_from_slice(&self.strings[byte_offset..byte_offset + avail]);
213 u64::from_le_bytes(buf)
214 };
215 let shifted = raw >> bit_shift;
216 let mask = if needed_bits >= 64 { u64::MAX } else { (1u64 << needed_bits) - 1 };
217 Kmer::<K>::new(<Kmer<K> as KmerBits>::from_u64(shifted & mask))
218 } else {
219 let (raw, extra_byte) = if byte_offset + 17 <= self.strings.len() {
222 let r = unsafe {
223 std::ptr::read_unaligned(
224 self.strings.as_ptr().add(byte_offset) as *const u128
225 )
226 };
227 (r, self.strings[byte_offset + 16])
228 } else if byte_offset + 16 <= self.strings.len() {
229 let r = unsafe {
230 std::ptr::read_unaligned(
231 self.strings.as_ptr().add(byte_offset) as *const u128
232 )
233 };
234 let extra = if byte_offset + 16 < self.strings.len() {
235 self.strings[byte_offset + 16]
236 } else {
237 0u8
238 };
239 (r, extra)
240 } else {
241 let mut buf = [0u8; 17];
242 let avail = self.strings.len() - byte_offset;
243 buf[..avail].copy_from_slice(&self.strings[byte_offset..byte_offset + avail]);
244 (u128::from_le_bytes(buf[..16].try_into().unwrap()), buf[16])
245 };
246 let shifted = if bit_shift > 0 {
247 (raw >> bit_shift) | ((extra_byte as u128) << (128 - bit_shift))
248 } else {
249 raw
250 };
251 let mask = if needed_bits >= 128 { u128::MAX } else { (1u128 << needed_bits) - 1 };
252 Kmer::<K>::new(<Kmer<K> as KmerBits>::from_u128(shifted & mask))
253 }
254 }
255
256 pub fn serialize_to<W: Write>(&self, writer: &mut W) -> io::Result<()> {
265 writer.write_all(&(self.k as u64).to_le_bytes())?;
266 writer.write_all(&(self.m as u64).to_le_bytes())?;
267 writer.write_all(&(self.strings.len() as u64).to_le_bytes())?;
268 writer.write_all(&self.strings)?;
269 self.offsets.write_to(writer)?;
270 Ok(())
271 }
272
273 pub fn deserialize_from<R: Read>(reader: &mut R) -> io::Result<Self> {
275 let mut buf8 = [0u8; 8];
276 reader.read_exact(&mut buf8)?;
277 let k = u64::from_le_bytes(buf8) as usize;
278 reader.read_exact(&mut buf8)?;
279 let m = u64::from_le_bytes(buf8) as usize;
280 reader.read_exact(&mut buf8)?;
281 let strings_len = u64::from_le_bytes(buf8) as usize;
282 let mut strings = vec![0u8; strings_len];
283 reader.read_exact(&mut strings)?;
284 let offsets = EliasFanoOffsets::read_from(reader)?;
285 Ok(Self { strings, offsets, k, m })
286 }
287}
288
289impl std::fmt::Debug for SpectrumPreservingStringSet {
290 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
291 f.debug_struct("SpectrumPreservingStringSet")
292 .field("k", &self.k)
293 .field("m", &self.m)
294 .field("num_strings", &self.num_strings())
295 .field("total_bases", &self.total_bases())
296 .field("num_bits", &self.num_bits())
297 .finish()
298 }
299}
300
#[cfg(test)]
mod tests {
    use super::*;
    use crate::offsets::OffsetsVector;

    /// Two-bit code for a nucleotide, case-insensitive: A=0, C=1, T=2, G=3.
    fn base_code(base: u8) -> u8 {
        match base.to_ascii_uppercase() {
            b'A' => 0,
            b'C' => 1,
            b'T' => 2,
            b'G' => 3,
            _ => panic!("invalid base"),
        }
    }

    /// Packs `strings` back-to-back (four bases per byte, lowest bits first)
    /// and records each string's end offset.
    fn build_test_spss(k: usize, m: usize, strings: &[&str]) -> SpectrumPreservingStringSet {
        let mut packed: Vec<u8> = Vec::new();
        let mut offsets = OffsetsVector::new();
        let mut written: u64 = 0;

        for s in strings {
            for code in s.bytes().map(base_code) {
                let slot = (written % 4) as usize;
                // Every fourth base starts a fresh byte.
                if slot == 0 {
                    packed.push(0);
                }
                *packed.last_mut().expect("byte for current base exists") |= code << (slot * 2);
                written += 1;
            }
            offsets.push(written);
        }

        SpectrumPreservingStringSet::from_parts(packed, offsets, k, m)
    }

    #[test]
    fn test_spss_creation() {
        let spss = SpectrumPreservingStringSet::new(31, 13);
        assert_eq!(spss.k(), 31);
        assert_eq!(spss.m(), 13);
        assert_eq!(spss.num_strings(), 0);
    }

    #[test]
    fn test_spss_two_strings() {
        let inputs = [
            "ACGTACGTACGTACGTACGTACGTACGTACG",
            "TGCATGCATGCATGCATGCATGCATGCATGCA",
        ];
        let spss = build_test_spss(31, 13, &inputs);
        assert_eq!(spss.num_strings(), 2);
    }

    #[test]
    fn test_spss_string_offsets() {
        let inputs = [
            "ACGTACGTACGTACGTACGTACGTACGTACG",
            "TGCATGCATGCATGCATGCATGCATGCATGC",
        ];
        let spss = build_test_spss(31, 13, &inputs);

        let (first_begin, first_end) = spss.string_offsets(0);
        let (second_begin, second_end) = spss.string_offsets(1);

        assert_eq!(first_begin, 0);
        assert_eq!(first_end - first_begin, 31);
        assert_eq!(second_begin, 31);
        assert_eq!(second_end - second_begin, 31);
    }

    #[test]
    fn test_spss_total_bases() {
        let inputs = [
            "ACGTACGTACGTACGTACGTACGTACGTACG",
            "TGCATGCATGCATGCATGCATGCATGCATGC",
        ];
        let spss = build_test_spss(31, 13, &inputs);
        assert_eq!(spss.total_bases(), 62);
    }

    #[test]
    fn test_spss_debug() {
        let spss = build_test_spss(31, 13, &["ACGTACGTACGTACGTACGTACGTACGTACG"]);
        let rendered = format!("{:?}", spss);
        assert!(rendered.contains("SpectrumPreservingStringSet"));
        assert!(rendered.contains("k: 31"));
    }
}