Skip to main content

sshash_lib/
encoding.rs

1//! DNA nucleotide encoding
2//!
3//! This module implements the 2-bit encoding scheme for DNA nucleotides,
4//! matching the C++ SSHash implementation exactly.
5//!
6//! Default encoding (SSHash custom):
7//! - A (65/97)  -> 00
8//! - C (67/99)  -> 01
9//! - G (71/103) -> 11
10//! - T (84/116) -> 10
11
12use thiserror::Error;
13
/// Errors reported by the DNA encoding routines in this module.
///
/// The `Display` text of each variant comes from its `#[error(...)]`
/// attribute (generated by `thiserror`).
#[derive(Error, Debug, Clone, PartialEq, Eq)]
pub enum EncodingError {
    /// The input byte is not a valid DNA base (A/C/G/T, either case).
    ///
    /// Carries the offending byte. Returned by `encode_base`. Note that the
    /// message formats the byte with `{:?}`, so it is shown as its numeric
    /// value (e.g. `78`), not as a character.
    #[error("Invalid DNA base: {0:?}")]
    InvalidBase(u8),
    /// The input string is not a valid k-mer.
    ///
    /// Carries a human-readable description. `encode_sequence` uses this
    /// variant to report the position and character of the first invalid base.
    #[error("Invalid k-mer string: {0}")]
    InvalidKmer(String),
    /// The input string length does not match the expected k-mer length.
    ///
    /// Not constructed by any function visible in this module's encoding
    /// routines; presumably produced by k-mer parsing code elsewhere.
    #[error("K-mer length mismatch: expected {expected}, got {actual}")]
    LengthMismatch {
        /// Expected k-mer length
        expected: usize,
        /// Actual string length
        actual: usize,
    },
}
32
33/// Encode a single DNA nucleotide to 2 bits
34///
35/// Uses the SSHash custom encoding by default:
36/// A -> 00, C -> 01, G -> 11, T -> 10
37#[inline]
38pub const fn encode_base(base: u8) -> Result<u8, EncodingError> {
39    match base {
40        b'A' | b'a' => Ok(0b00),
41        b'C' | b'c' => Ok(0b01),
42        b'G' | b'g' => Ok(0b11),
43        b'T' | b't' => Ok(0b10),
44        _ => Err(EncodingError::InvalidBase(base)),
45    }
46}
47
/// Turn a 2-bit code back into its uppercase nucleotide byte.
///
/// Only the two least-significant bits of `bits` are inspected; any higher
/// bits are ignored. The mapping is the inverse of the encoding table:
/// 00 -> A, 01 -> C, 11 -> G, 10 -> T
#[inline]
pub const fn decode_base(bits: u8) -> u8 {
    // Indexed by the 2-bit code; note that G (11) and T (10) are not in
    // alphabetical order in this encoding.
    const NUCLEOTIDES: [u8; 4] = [b'A', b'C', b'T', b'G'];
    NUCLEOTIDES[(bits & 0b11) as usize]
}
59
/// Complement an encoded DNA base.
///
/// In this encoding the two members of each complementary pair differ only
/// in bit 1: A(00) <-> T(10) and C(01) <-> G(11). Flipping that bit therefore
/// yields the complement. All other bits of `bits` are returned unchanged.
#[inline]
pub const fn complement_base(bits: u8) -> u8 {
    // The single bit that distinguishes a base from its complement.
    const COMPLEMENT_FLIP: u8 = 0b10;
    bits ^ COMPLEMENT_FLIP
}
67
68/// Encode a DNA string to a bit-packed representation
69///
70/// # Errors
71/// Returns an error if the string contains invalid bases
72pub fn encode_string(s: &str) -> Result<Vec<u8>, EncodingError> {
73    encode_sequence(s.as_bytes())
74}
75
76/// Encode a DNA sequence (byte slice) to a bit-packed representation
77///
78/// # Errors
79/// Returns an error if the sequence contains invalid bases
80pub fn encode_sequence(sequence: &[u8]) -> Result<Vec<u8>, EncodingError> {
81    let mut result = Vec::with_capacity(sequence.len().div_ceil(4));
82    let mut current = 0u8;
83    let mut bit_pos = 0;
84
85    for (i, &base) in sequence.iter().enumerate() {
86        let encoded = encode_base(base).map_err(|_| {
87            EncodingError::InvalidKmer(format!("Invalid base at position {}: {:?}", i, base as char))
88        })?;
89
90        current |= encoded << bit_pos;
91        bit_pos += 2;
92
93        if bit_pos == 8 {
94            result.push(current);
95            current = 0;
96            bit_pos = 0;
97        }
98    }
99
100    if bit_pos > 0 {
101        result.push(current);
102    }
103
104    Ok(result)
105}
106
107/// Decode a bit-packed representation back to a DNA string
108pub fn decode_string(data: &[u8], length: usize) -> String {
109    let mut result = String::with_capacity(length);
110    let mut bit_pos = 0;
111    let mut byte_idx = 0;
112
113    for _ in 0..length {
114        if byte_idx >= data.len() {
115            break;
116        }
117
118        let bits = (data[byte_idx] >> bit_pos) & 0b11;
119        result.push(decode_base(bits) as char);
120
121        bit_pos += 2;
122        if bit_pos == 8 {
123            bit_pos = 0;
124            byte_idx += 1;
125        }
126    }
127
128    result
129}
130
#[cfg(test)]
mod tests {
    use super::*;

    /// Each base encodes to its 2-bit code in either case; other bytes fail.
    #[test]
    fn test_encode_base() {
        let valid = [
            (b'A', 0b00),
            (b'a', 0b00),
            (b'C', 0b01),
            (b'c', 0b01),
            (b'G', 0b11),
            (b'g', 0b11),
            (b'T', 0b10),
            (b't', 0b10),
        ];
        for (base, code) in valid {
            assert_eq!(encode_base(base).unwrap(), code);
        }

        // Invalid bases
        for base in [b'N', b'X', b'0'] {
            assert!(encode_base(base).is_err());
        }
    }

    /// Each 2-bit code decodes to its uppercase nucleotide.
    #[test]
    fn test_decode_base() {
        let table = [(0b00, b'A'), (0b01, b'C'), (0b11, b'G'), (0b10, b'T')];
        for (code, base) in table {
            assert_eq!(decode_base(code), base);
        }
    }

    /// Complementing swaps A <-> T and C <-> G.
    #[test]
    fn test_complement_base() {
        let pairs = [
            (0b00, 0b10), // A -> T
            (0b10, 0b00), // T -> A
            (0b01, 0b11), // C -> G
            (0b11, 0b01), // G -> C
        ];
        for (code, complement) in pairs {
            assert_eq!(complement_base(code), complement);
        }
    }

    /// Encoding then decoding returns the uppercase form of the input.
    #[test]
    fn test_encode_decode_roundtrip() {
        for seq in ["ACGT", "AAAA", "TTTT", "ACGTACGT", "GATTACA"] {
            let packed = encode_string(seq).unwrap();
            let unpacked = decode_string(&packed, seq.len());
            assert_eq!(unpacked, seq.to_uppercase());
        }
    }

    /// Lower- and upper-case input produce the same packed bytes.
    #[test]
    fn test_encode_mixed_case() {
        let from_lower = encode_string("acgt").unwrap();
        let from_upper = encode_string("ACGT").unwrap();
        assert_eq!(from_lower, from_upper);
    }

    /// Strings containing a non-ACGT character are rejected.
    #[test]
    fn test_encode_invalid() {
        for bad in ["ACGTN", "ACGT X"] {
            assert!(encode_string(bad).is_err());
        }
    }
}