1use thiserror::Error;
13
14#[derive(Error, Debug, Clone, PartialEq, Eq)]
16pub enum EncodingError {
17 #[error("Invalid DNA base: {0:?}")]
19 InvalidBase(u8),
20 #[error("Invalid k-mer string: {0}")]
22 InvalidKmer(String),
23 #[error("K-mer length mismatch: expected {expected}, got {actual}")]
25 LengthMismatch {
26 expected: usize,
28 actual: usize,
30 },
31}
32
33#[inline]
38pub const fn encode_base(base: u8) -> Result<u8, EncodingError> {
39 match base {
40 b'A' | b'a' => Ok(0b00),
41 b'C' | b'c' => Ok(0b01),
42 b'G' | b'g' => Ok(0b11),
43 b'T' | b't' => Ok(0b10),
44 _ => Err(EncodingError::InvalidBase(base)),
45 }
46}
47
/// Converts a 2-bit base code back into its uppercase ASCII letter.
///
/// Only the two lowest bits of `bits` are examined; any higher bits are
/// ignored. The codes map as `0b00 -> A`, `0b01 -> C`, `0b10 -> T`,
/// `0b11 -> G`.
#[inline]
pub const fn decode_base(bits: u8) -> u8 {
    // Indexed by the 2-bit code, so the order is A, C, T, G (not alphabetical).
    const LETTERS: [u8; 4] = [b'A', b'C', b'T', b'G'];
    // Masking keeps the index within 0..=3, so the lookup cannot go out of bounds.
    LETTERS[(bits & 0b11) as usize]
}
59
/// Returns the code of the complementary base for a 2-bit base code.
///
/// With the encoding used by `encode_base` (`A = 0b00`, `C = 0b01`,
/// `T = 0b10`, `G = 0b11`), flipping the second-lowest bit swaps A with T
/// and C with G. No masking is applied: every other bit of `bits` is
/// passed through unchanged.
#[inline]
pub const fn complement_base(bits: u8) -> u8 {
    // The single bit that distinguishes a base from its complement.
    const COMPLEMENT_MASK: u8 = 0b10;
    bits ^ COMPLEMENT_MASK
}
67
68pub fn encode_string(s: &str) -> Result<Vec<u8>, EncodingError> {
73 encode_sequence(s.as_bytes())
74}
75
76pub fn encode_sequence(sequence: &[u8]) -> Result<Vec<u8>, EncodingError> {
81 let mut result = Vec::with_capacity(sequence.len().div_ceil(4));
82 let mut current = 0u8;
83 let mut bit_pos = 0;
84
85 for (i, &base) in sequence.iter().enumerate() {
86 let encoded = encode_base(base).map_err(|_| {
87 EncodingError::InvalidKmer(format!("Invalid base at position {}: {:?}", i, base as char))
88 })?;
89
90 current |= encoded << bit_pos;
91 bit_pos += 2;
92
93 if bit_pos == 8 {
94 result.push(current);
95 current = 0;
96 bit_pos = 0;
97 }
98 }
99
100 if bit_pos > 0 {
101 result.push(current);
102 }
103
104 Ok(result)
105}
106
107pub fn decode_string(data: &[u8], length: usize) -> String {
109 let mut result = String::with_capacity(length);
110 let mut bit_pos = 0;
111 let mut byte_idx = 0;
112
113 for _ in 0..length {
114 if byte_idx >= data.len() {
115 break;
116 }
117
118 let bits = (data[byte_idx] >> bit_pos) & 0b11;
119 result.push(decode_base(bits) as char);
120
121 bit_pos += 2;
122 if bit_pos == 8 {
123 bit_pos = 0;
124 byte_idx += 1;
125 }
126 }
127
128 result
129}
130
#[cfg(test)]
mod tests {
    use super::*;

    /// Every base letter, in both cases, maps to its 2-bit code; other bytes fail.
    #[test]
    fn test_encode_base() {
        let valid = [
            (b'A', 0b00),
            (b'a', 0b00),
            (b'C', 0b01),
            (b'c', 0b01),
            (b'G', 0b11),
            (b'g', 0b11),
            (b'T', 0b10),
            (b't', 0b10),
        ];
        for (base, code) in valid {
            assert_eq!(encode_base(base).unwrap(), code);
        }

        for invalid in [b'N', b'X', b'0'] {
            assert!(encode_base(invalid).is_err());
        }
    }

    /// Each 2-bit code decodes to its uppercase base letter.
    #[test]
    fn test_decode_base() {
        let table = [(0b00, b'A'), (0b01, b'C'), (0b11, b'G'), (0b10, b'T')];
        for (code, letter) in table {
            assert_eq!(decode_base(code), letter);
        }
    }

    /// Complementing swaps A with T and C with G.
    #[test]
    fn test_complement_base() {
        let pairs = [
            (0b00, 0b10), // A -> T
            (0b10, 0b00), // T -> A
            (0b01, 0b11), // C -> G
            (0b11, 0b01), // G -> C
        ];
        for (code, complement) in pairs {
            assert_eq!(complement_base(code), complement);
        }
    }

    /// Encoding then decoding with the original length reproduces the sequence.
    #[test]
    fn test_encode_decode_roundtrip() {
        for seq in ["ACGT", "AAAA", "TTTT", "ACGTACGT", "GATTACA"] {
            let packed = encode_string(seq).unwrap();
            let unpacked = decode_string(&packed, seq.len());
            assert_eq!(unpacked, seq.to_uppercase());
        }
    }

    /// Lowercase and uppercase input produce identical packed bytes.
    #[test]
    fn test_encode_mixed_case() {
        let from_lower = encode_string("acgt").unwrap();
        let from_upper = encode_string("ACGT").unwrap();
        assert_eq!(from_lower, from_upper);
    }

    /// Sequences containing anything other than A/C/G/T are rejected.
    #[test]
    fn test_encode_invalid() {
        for bad in ["ACGTN", "ACGT X"] {
            assert!(encode_string(bad).is_err());
        }
    }
}