Skip to main content

sklears_simd/
error_correction.rs

//! Error correction and error detection codes
//!
//! This module provides Hamming(7,4) single-bit error correction, a simplified
//! Reed-Solomon-style block code (which does not perform error correction), CRC-32,
//! and parity/XOR checksums. The implementations are currently portable scalar code.
5
6#[cfg(feature = "no-std")]
7use alloc::{
8    string::{String, ToString},
9    vec,
10    vec::Vec,
11};
12
/// Hamming(7,4) code: encodes 4 data bits into a 7-bit codeword that can
/// correct any single-bit error.
///
/// The two matrices below are reference data only. [`HammingCode74::encode`] and
/// [`HammingCode74::decode`] compute parity bits and syndromes with inline bit
/// arithmetic and never read these fields.
#[derive(Debug, Clone)]
pub struct HammingCode74 {
    // Generator matrix in systematic form: one row per data bit (d1..d4), columns
    // ordered d1 d2 d3 d4 p1 p2 p3. The parity columns agree with the parity
    // equations in encode(), but encode() emits bits in the order p1 p2 d1 p3 d2 d3 d4,
    // so this matrix does not describe encode()'s output layout directly.
    #[allow(dead_code)] // Reference only; encode() inlines the parity arithmetic
    generator_matrix: [[u8; 7]; 4],
    // NOTE(review): a standard Hamming(7,4) parity-check matrix is 3x7 (or its 7x3
    // transpose); this 7x4 table does not obviously correspond to the syndrome that
    // decode() computes — verify before relying on it.
    #[allow(dead_code)] // Reference only; decode() inlines the syndrome calculation
    parity_check_matrix: [[u8; 4]; 7],
}
23
24impl HammingCode74 {
25    pub fn new() -> Self {
26        Self {
27            // Standard generator matrix for Hamming(7,4)
28            generator_matrix: [
29                [1, 0, 0, 0, 1, 1, 0], // d1
30                [0, 1, 0, 0, 1, 0, 1], // d2
31                [0, 0, 1, 0, 0, 1, 1], // d3
32                [0, 0, 0, 1, 1, 1, 1], // d4
33            ],
34            parity_check_matrix: [
35                [1, 1, 0, 1], // Position 1
36                [1, 0, 1, 1], // Position 2
37                [1, 0, 0, 0], // Position 3
38                [0, 1, 1, 1], // Position 4
39                [0, 1, 0, 0], // Position 5
40                [0, 0, 1, 0], // Position 6
41                [0, 0, 0, 1], // Position 7
42            ],
43        }
44    }
45
46    /// Encode 4 data bits into 7-bit Hamming code
47    pub fn encode(&self, data: u8) -> u8 {
48        debug_assert!(data < 16, "Data must be 4 bits (0-15)");
49
50        // Place data bits in positions 3, 5, 6, 7 (0-indexed: 2, 4, 5, 6)
51        let d1 = data & 1;
52        let d2 = (data >> 1) & 1;
53        let d3 = (data >> 2) & 1;
54        let d4 = (data >> 3) & 1;
55
56        // Calculate parity bits
57        let p1 = d1 ^ d2 ^ d4; // Parity for positions 1, 3, 5, 7
58        let p2 = d1 ^ d3 ^ d4; // Parity for positions 2, 3, 6, 7
59        let p3 = d2 ^ d3 ^ d4; // Parity for positions 4, 5, 6, 7
60
61        // Construct codeword: p1 p2 d1 p3 d2 d3 d4
62
63        p1 | (p2 << 1) | (d1 << 2) | (p3 << 3) | (d2 << 4) | (d3 << 5) | (d4 << 6)
64    }
65
66    /// Decode 7-bit Hamming code and correct single-bit errors
67    pub fn decode(&self, codeword: u8) -> Result<u8, String> {
68        debug_assert!(codeword < 128, "Codeword must be 7 bits (0-127)");
69
70        // Calculate syndrome using the standard Hamming(7,4) method
71        let h1 =
72            (codeword & 1) ^ ((codeword >> 2) & 1) ^ ((codeword >> 4) & 1) ^ ((codeword >> 6) & 1);
73        let h2 = ((codeword >> 1) & 1)
74            ^ ((codeword >> 2) & 1)
75            ^ ((codeword >> 5) & 1)
76            ^ ((codeword >> 6) & 1);
77        let h3 = ((codeword >> 3) & 1)
78            ^ ((codeword >> 4) & 1)
79            ^ ((codeword >> 5) & 1)
80            ^ ((codeword >> 6) & 1);
81
82        let syndrome = h1 | (h2 << 1) | (h3 << 2);
83
84        let corrected_codeword = if syndrome == 0 {
85            // No error detected
86            codeword
87        } else {
88            // Single-bit error - correct it
89            // The syndrome directly gives us the error position (1-indexed)
90            let error_position = syndrome - 1; // Convert to 0-indexed
91            codeword ^ (1 << error_position)
92        };
93
94        // Extract data bits (positions 3, 5, 6, 7 -> bits 2, 4, 5, 6)
95        let data = ((corrected_codeword >> 2) & 1)
96            | (((corrected_codeword >> 4) & 1) << 1)
97            | (((corrected_codeword >> 5) & 1) << 2)
98            | (((corrected_codeword >> 6) & 1) << 3);
99
100        Ok(data)
101    }
102
103    /// Encode a byte array using Hamming(7,4) codes
104    pub fn encode_bytes(&self, data: &[u8]) -> Vec<u8> {
105        let mut encoded = Vec::new();
106
107        for &byte in data {
108            // Split byte into two 4-bit nibbles
109            let high_nibble = (byte >> 4) & 0x0F;
110            let low_nibble = byte & 0x0F;
111
112            let encoded_high = self.encode(high_nibble);
113            let encoded_low = self.encode(low_nibble);
114
115            // Pack two 7-bit codes into a 16-bit word (with 2 bits unused)
116            let packed = ((encoded_high as u16) << 7) | (encoded_low as u16);
117            encoded.extend_from_slice(&packed.to_le_bytes());
118        }
119
120        encoded
121    }
122
123    /// Decode a byte array encoded with Hamming(7,4) codes
124    pub fn decode_bytes(&self, encoded: &[u8]) -> Result<Vec<u8>, String> {
125        if !encoded.len().is_multiple_of(2) {
126            return Err("Encoded data length must be even".to_string());
127        }
128
129        let mut decoded = Vec::new();
130
131        for chunk in encoded.chunks(2) {
132            let packed = u16::from_le_bytes([chunk[0], chunk[1]]);
133            let encoded_high = ((packed >> 7) & 0x7F) as u8;
134            let encoded_low = (packed & 0x7F) as u8;
135
136            let high_nibble = self.decode(encoded_high)?;
137            let low_nibble = self.decode(encoded_low)?;
138
139            let original_byte = (high_nibble << 4) | low_nibble;
140            decoded.push(original_byte);
141        }
142
143        Ok(decoded)
144    }
145}
146
147impl Default for HammingCode74 {
148    fn default() -> Self {
149        Self::new()
150    }
151}
152
/// CRC-32 checksum calculator (IEEE 802.3, reflected polynomial `0xEDB88320`)
/// backed by a 256-entry lookup table.
///
/// Checksums are computed one byte per table lookup; no SIMD instructions are used.
#[derive(Debug, Clone)]
pub struct CRC32 {
    /// `table[i]` is the CRC register after running byte value `i` through eight
    /// shift/XOR rounds of the polynomial; filled once by `CRC32::new`.
    table: [u32; 256],
}
157
158impl CRC32 {
159    pub fn new() -> Self {
160        let mut table = [0u32; 256];
161
162        // IEEE 802.3 polynomial: 0xEDB88320
163        const POLYNOMIAL: u32 = 0xEDB88320;
164
165        for (i, entry) in table.iter_mut().enumerate() {
166            let mut crc = i as u32;
167            for _ in 0..8 {
168                if crc & 1 != 0 {
169                    crc = (crc >> 1) ^ POLYNOMIAL;
170                } else {
171                    crc >>= 1;
172                }
173            }
174            *entry = crc;
175        }
176
177        Self { table }
178    }
179
180    /// Calculate CRC-32 checksum
181    pub fn checksum(&self, data: &[u8]) -> u32 {
182        let mut crc = 0xFFFFFFFFu32;
183
184        // Process data in chunks for better cache efficiency
185        for &byte in data {
186            let table_index = ((crc ^ byte as u32) & 0xFF) as usize;
187            crc = (crc >> 8) ^ self.table[table_index];
188        }
189
190        crc ^ 0xFFFFFFFF
191    }
192
193    /// Verify data integrity using CRC-32
194    pub fn verify(&self, data: &[u8], expected_crc: u32) -> bool {
195        self.checksum(data) == expected_crc
196    }
197}
198
199impl Default for CRC32 {
200    fn default() -> Self {
201        Self::new()
202    }
203}
204
/// A simplified, Reed-Solomon-inspired systematic block code over GF(256).
///
/// A codeword is `n` bytes: the `k` data bytes followed by `n - k` parity bytes.
/// This is NOT a true Reed-Solomon code — the parity coefficients follow an
/// ad hoc pattern rather than a generator polynomial — and it is intended for
/// educational use only. Decoding does not correct errors.
#[derive(Debug, Clone)]
pub struct SimpleReedSolomon {
    n: usize, // Total codeword length in bytes
    k: usize, // Number of data bytes at the start of each codeword
    t: usize, // Nominal correction capability (n - k) / 2; decode() does not achieve it
}
213
214impl SimpleReedSolomon {
215    pub fn new(n: usize, k: usize) -> Self {
216        assert!(n > k, "Total length must be greater than data length");
217        let t = (n - k) / 2;
218        Self { n, k, t }
219    }
220
221    /// Finite field multiplication in GF(256)
222    fn gf_multiply(a: u8, b: u8) -> u8 {
223        if a == 0 || b == 0 {
224            return 0;
225        }
226
227        // Simple multiplication in GF(256) using primitive polynomial x^8 + x^4 + x^3 + x^2 + 1
228        let mut result = 0u8;
229        let mut temp_a = a;
230        let mut temp_b = b;
231
232        for _ in 0..8 {
233            if temp_b & 1 != 0 {
234                result ^= temp_a;
235            }
236            let carry = temp_a & 0x80;
237            temp_a <<= 1;
238            if carry != 0 {
239                temp_a ^= 0x1D; // Primitive polynomial
240            }
241            temp_b >>= 1;
242        }
243
244        result
245    }
246
247    /// Generate parity symbols (simplified)
248    pub fn encode(&self, data: &[u8]) -> Vec<u8> {
249        assert_eq!(data.len(), self.k, "Data length must equal k");
250
251        let mut codeword = vec![0u8; self.n];
252        codeword[..self.k].copy_from_slice(data);
253
254        // Generate parity symbols using systematic encoding
255        for (i, &di) in data.iter().enumerate() {
256            for (j, parity) in codeword[self.k..].iter_mut().enumerate() {
257                let generator_coeff = ((i + j + 1) % 255 + 1) as u8; // Simplified generator
258                *parity ^= Self::gf_multiply(di, generator_coeff);
259            }
260        }
261
262        codeword
263    }
264
265    /// Attempt to decode and correct errors (simplified)
266    pub fn decode(&self, received: &[u8]) -> Result<Vec<u8>, String> {
267        assert_eq!(received.len(), self.n, "Received data length must equal n");
268
269        // Calculate syndromes
270        let mut syndromes = vec![0u8; self.n - self.k];
271        for (i, syn) in syndromes.iter_mut().enumerate() {
272            for (j, &rec) in received.iter().enumerate() {
273                let eval_point = ((i + 1) % 255 + 1) as u8;
274                let power = Self::gf_multiply(eval_point, j as u8);
275                *syn ^= Self::gf_multiply(rec, power);
276            }
277        }
278
279        // Check if there are errors
280        let has_errors = syndromes.iter().any(|&s| s != 0);
281
282        if !has_errors {
283            // No errors detected
284            Ok(received[..self.k].to_vec())
285        } else {
286            // For this simplified implementation, we'll just return an error
287            // A full Reed-Solomon implementation would use the Berlekamp-Massey algorithm
288            Err("Error correction not implemented in this simplified version".to_string())
289        }
290    }
291
292    /// Get the error correction capability
293    pub fn error_correction_capability(&self) -> usize {
294        self.t
295    }
296}
297
/// Longitudinal parity of `data`: the XOR of every byte (`0` for an empty slice).
///
/// Each bit of the result is the even-parity bit for that bit column across all bytes.
pub fn calculate_parity(data: &[u8]) -> u8 {
    let mut parity = 0u8;
    for byte in data {
        parity ^= byte;
    }
    parity
}
302
303/// Even parity check
304pub fn check_even_parity(data: &[u8], parity: u8) -> bool {
305    calculate_parity(data) == parity
306}
307
308/// Odd parity check  
309pub fn check_odd_parity(data: &[u8], parity: u8) -> bool {
310    calculate_parity(data) ^ 1 == parity
311}
312
/// XOR checksum of `data` read as little-endian 32-bit words.
///
/// A trailing partial word (1–3 bytes) is zero-padded in its high bytes before
/// being folded in, so byte `i` always contributes to bits `8 * (i % 4)..`.
/// Despite the name, this is plain scalar code that consumes four bytes per step.
pub fn xor_checksum_simd(data: &[u8]) -> u32 {
    let words = data.chunks_exact(4);
    let tail = words.remainder();

    // Zero-pad the leftover bytes into one more little-endian word.
    let mut padded = [0u8; 4];
    padded[..tail.len()].copy_from_slice(tail);

    words
        .map(|w| u32::from_le_bytes([w[0], w[1], w[2], w[3]]))
        .fold(u32::from_le_bytes(padded), |acc, word| acc ^ word)
}
333
// Unit tests run only in std builds; the no-std configuration is not exercised here.
#[cfg(all(test, not(feature = "no-std")))]
mod tests {
    use super::*;

    #[test]
    fn test_hamming_code_encode_decode() {
        let hamming = HammingCode74::new();

        for data in 0..16u8 {
            let encoded = hamming.encode(data);
            let decoded = hamming
                .decode(encoded)
                .expect("Hamming decode of a 7-bit codeword should succeed");
            assert_eq!(decoded, data, "Failed for data: {}", data);
        }
    }

    #[test]
    fn test_hamming_known_codewords() {
        // Bit layout from bit 0 upward: p1 p2 d1 p3 d2 d3 d4.
        let hamming = HammingCode74::new();
        assert_eq!(hamming.encode(0), 0);
        assert_eq!(hamming.encode(0b0001), 0b000_0111);
        assert_eq!(hamming.encode(0b1010), 0b101_0010);
        assert_eq!(hamming.encode(0b1111), 0b111_1111);
    }

    #[test]
    fn test_hamming_code_error_correction() {
        let hamming = HammingCode74::new();

        // Every single-bit error in every codeword must be corrected.
        for data in 0..16u8 {
            let encoded = hamming.encode(data);
            for error_pos in 0..7 {
                let corrupted = encoded ^ (1 << error_pos);
                let decoded = hamming
                    .decode(corrupted)
                    .expect("Hamming decode of a 7-bit codeword should succeed");
                assert_eq!(
                    decoded, data,
                    "Failed to correct error at position {} for data {}",
                    error_pos, data
                );
            }
        }
    }

    #[test]
    fn test_hamming_bytes_roundtrip() {
        let hamming = HammingCode74::new();
        let data = b"Hello, World!";

        let encoded = hamming.encode_bytes(data);
        assert_eq!(encoded.len(), data.len() * 2);
        let decoded = hamming
            .decode_bytes(&encoded)
            .expect("decoding an even-length Hamming stream should succeed");

        assert_eq!(decoded, data);
    }

    #[test]
    fn test_hamming_decode_bytes_length_check() {
        let hamming = HammingCode74::new();
        assert!(hamming.decode_bytes(&[0x00]).is_err());
        assert_eq!(
            hamming
                .decode_bytes(&[])
                .expect("empty input is a valid (empty) stream"),
            Vec::<u8>::new()
        );
    }

    #[test]
    fn test_crc32() {
        let crc = CRC32::new();
        let data = b"Hello, World!";

        let checksum1 = crc.checksum(data);
        let checksum2 = crc.checksum(data);
        assert_eq!(checksum1, checksum2, "CRC should be deterministic");

        // Verify integrity
        assert!(crc.verify(data, checksum1));
        assert!(!crc.verify(b"Different data", checksum1));
    }

    #[test]
    fn test_crc32_reference_vectors() {
        let crc = CRC32::new();
        // Standard CRC-32/ISO-HDLC check value.
        assert_eq!(crc.checksum(b"123456789"), 0xCBF4_3926);
        assert_eq!(
            crc.checksum(b"The quick brown fox jumps over the lazy dog"),
            0x414F_A339
        );
    }

    #[test]
    fn test_crc32_different_data() {
        let crc = CRC32::new();

        let checksum1 = crc.checksum(b"data1");
        let checksum2 = crc.checksum(b"data2");
        assert_ne!(
            checksum1, checksum2,
            "Different data should have different CRCs"
        );
    }

    #[test]
    fn test_simple_reed_solomon() {
        let rs = SimpleReedSolomon::new(10, 6); // (10,6) code
        assert_eq!(rs.error_correction_capability(), 2);

        let data = b"hello!";
        let encoded = rs.encode(data);
        assert_eq!(encoded.len(), 10);
        assert_eq!(&encoded[..6], &data[..], "encoding must be systematic");

        // The simplified decoder cannot correct errors; whatever it reports for
        // a clean codeword, it must never hand back different data.
        if let Ok(decoded) = rs.decode(&encoded) {
            assert_eq!(decoded, data);
        }
    }

    #[test]
    fn test_parity_checks() {
        let data = b"test data";
        let parity = calculate_parity(data);

        assert!(check_even_parity(data, parity));
        assert!(!check_odd_parity(data, parity));
        assert!(!check_even_parity(data, parity ^ 1));
        assert!(check_odd_parity(data, parity ^ 1));
    }

    #[test]
    fn test_xor_checksum() {
        let data1 = b"Hello, World!";
        let data2 = b"Hello, World!";
        let data3 = b"Different data";

        let checksum1 = xor_checksum_simd(data1);
        let checksum2 = xor_checksum_simd(data2);
        let checksum3 = xor_checksum_simd(data3);

        assert_eq!(checksum1, checksum2);
        assert_ne!(checksum1, checksum3);
    }

    #[test]
    fn test_xor_checksum_known_values() {
        // Little-endian words; a trailing partial word is zero-padded.
        assert_eq!(xor_checksum_simd(&[1, 2, 3, 4]), 0x0403_0201);
        assert_eq!(xor_checksum_simd(&[1, 2, 3, 4, 5]), 0x0403_0204);
    }

    #[test]
    fn test_empty_data() {
        let crc = CRC32::new();
        let empty_checksum = crc.checksum(&[]);
        assert_eq!(empty_checksum, 0);
        assert!(crc.verify(&[], empty_checksum));

        let empty_parity = calculate_parity(&[]);
        assert_eq!(empty_parity, 0);

        let empty_xor = xor_checksum_simd(&[]);
        assert_eq!(empty_xor, 0);
    }

    #[test]
    fn test_single_byte() {
        let hamming = HammingCode74::new();
        let data = [0x42]; // Single byte

        let encoded = hamming.encode_bytes(&data);
        let decoded = hamming
            .decode_bytes(&encoded)
            .expect("decoding an even-length Hamming stream should succeed");
        assert_eq!(decoded, data);
    }

    #[test]
    fn test_gf_multiply() {
        // Basic properties of finite field multiplication
        assert_eq!(SimpleReedSolomon::gf_multiply(0, 5), 0);
        assert_eq!(SimpleReedSolomon::gf_multiply(5, 0), 0);
        assert_eq!(SimpleReedSolomon::gf_multiply(1, 5), 5);
        assert_eq!(SimpleReedSolomon::gf_multiply(5, 1), 5);

        // Commutativity
        for a in 1..=10u8 {
            for b in 1..=10u8 {
                assert_eq!(
                    SimpleReedSolomon::gf_multiply(a, b),
                    SimpleReedSolomon::gf_multiply(b, a)
                );
            }
        }
    }
}