Skip to main content

stackforge_core/utils/
checksum.rs

//! Checksum calculation and verification utilities.
//!
//! Provides a generic RFC 1071 Internet checksum implementation that can be used
//! by all protocol layers (IP, ICMP, TCP, UDP, etc.).
5
6/// Calculate the Internet checksum (RFC 1071).
7///
8/// This is used for IP, ICMP, TCP, and UDP checksums.
9///
10/// # Algorithm
11///
12/// 1. Sum all 16-bit words in the data
13/// 2. Add any odd byte as the high byte of a word
14/// 3. Fold 32-bit sum to 16 bits by adding carry bits
15/// 4. Take one's complement of the result
16pub fn internet_checksum(data: &[u8]) -> u16 {
17    let sum = partial_checksum(data, 0);
18    finalize_checksum(sum)
19}
20
/// Compute a partial checksum (before folding and complement).
///
/// Useful for computing a checksum across multiple data segments or when
/// combining a pseudo-header with data: pass the result of one call as
/// `initial` to the next, then hand the final value to `finalize_checksum()`.
///
/// Every segment except the last must have an even length. An odd-length
/// segment's trailing byte is padded with a zero low byte, so a following
/// segment would otherwise be summed with the wrong byte alignment.
///
/// # Arguments
///
/// * `data` - The data to checksum; bytes are paired into big-endian 16-bit words
/// * `initial` - Initial checksum value (use 0 for starting fresh)
///
/// # Returns
///
/// The 32-bit partial sum (not yet folded to 16 bits or complemented). When the
/// true sum fits in a `u32` it is returned unchanged. Otherwise the bits above
/// bit 31 are added back into the low 32 bits (end-around carry). This is valid
/// because `2^32 ≡ 1 (mod 0xFFFF)`, so the final checksum is preserved instead of
/// the sum overflowing.
#[inline]
pub fn partial_checksum(data: &[u8], initial: u32) -> u32 {
    // Accumulate in 64 bits: a u32 accumulator overflows after roughly 128 KiB of
    // data (less with a large `initial`), which panics in debug builds and silently
    // drops carries in release builds.
    let mut sum = u64::from(initial);

    // Process 16-bit words
    let mut words = data.chunks_exact(2);
    for word in words.by_ref() {
        sum += u64::from(u16::from_be_bytes([word[0], word[1]]));
    }

    // Handle odd byte (pad with zero on the right)
    if let Some(&last) = words.remainder().first() {
        sum += u64::from(last) << 8;
    }

    // End-around carry from bit 32 down into the low 32 bits; at most two passes.
    while sum > u64::from(u32::MAX) {
        sum = (sum & 0xFFFF_FFFF) + (sum >> 32);
    }
    // The loop guarantees `sum <= u32::MAX`, so this cast is lossless.
    sum as u32
}
51
/// Turn a partial sum into the final 16-bit checksum.
///
/// Repeatedly adds the bits above bit 15 back into the low 16 bits
/// (end-around carry) until the value fits in 16 bits, then returns its
/// one's complement.
///
/// # Arguments
///
/// * `sum` - The 32-bit partial sum produced by `partial_checksum()`
///
/// # Returns
///
/// The final 16-bit checksum value.
#[inline]
pub fn finalize_checksum(sum: u32) -> u16 {
    let mut folded = sum;
    while folded > 0xFFFF {
        let carry = folded >> 16;
        folded = (folded & 0xFFFF) + carry;
    }
    // `folded` fits in 16 bits here, so narrowing before inverting is lossless.
    !(folded as u16)
}
74
75/// Calculate checksum with pseudo-header (for TCP/UDP).
76pub fn transport_checksum(src_ip: &[u8], dst_ip: &[u8], protocol: u8, data: &[u8]) -> u16 {
77    let mut pseudo_header = Vec::with_capacity(12 + data.len());
78
79    // Source IP
80    pseudo_header.extend_from_slice(src_ip);
81    // Destination IP
82    pseudo_header.extend_from_slice(dst_ip);
83    // Zero
84    pseudo_header.push(0);
85    // Protocol
86    pseudo_header.push(protocol);
87    // Length (big-endian)
88    let len = data.len() as u16;
89    pseudo_header.extend_from_slice(&len.to_be_bytes());
90    // Data
91    pseudo_header.extend_from_slice(data);
92
93    internet_checksum(&pseudo_header)
94}
95
96/// Verify a checksum is valid (should be 0 or 0xFFFF when calculated over data with checksum).
97pub fn verify_checksum(data: &[u8]) -> bool {
98    let sum = internet_checksum(data);
99    sum == 0 || sum == 0xFFFF
100}
101
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_internet_checksum() {
        // Worked example from RFC 1071 §3: the folded one's-complement sum of these
        // words is 0xddf2, so the checksum is its complement, 0x220d.
        let data = [0x00, 0x01, 0xf2, 0x03, 0xf4, 0xf5, 0xf6, 0xf7];
        assert_eq!(internet_checksum(&data), 0x220d);
    }

    #[test]
    fn test_internet_checksum_odd_length() {
        // A trailing odd byte is the high byte of a zero-padded word.
        assert_eq!(internet_checksum(&[0x01]), !0x0100u16);
        assert_eq!(
            internet_checksum(&[0x12, 0x34, 0x56]),
            internet_checksum(&[0x12, 0x34, 0x56, 0x00])
        );
    }

    #[test]
    fn test_partial_checksum_chaining() {
        // Summing even-length segments in sequence matches a single pass.
        let data = [0x00, 0x01, 0xf2, 0x03, 0xf4, 0xf5, 0xf6, 0xf7];
        let first = partial_checksum(&data[..4], 0);
        let whole = partial_checksum(&data[4..], first);
        assert_eq!(finalize_checksum(whole), internet_checksum(&data));
    }

    #[test]
    fn test_ipv4_header_checksum() {
        // IPv4 header with its checksum field zeroed. The words sum to 0x2479c, which
        // folds to 0x479e, whose complement is 0xb861.
        let header = [
            0x45, 0x00, 0x00, 0x73, 0x00, 0x00, 0x40, 0x00, 0x40, 0x11, 0x00, 0x00, 0xc0,
            0xa8, 0x00, 0x01, 0xc0, 0xa8, 0x00, 0xc7,
        ];
        assert_eq!(internet_checksum(&header), 0xb861);
    }

    #[test]
    fn test_transport_checksum() {
        // UDP 0x1234 -> 0x5678, length 8, checksum field zero, 192.168.0.1 -> .2.
        let src = [192, 168, 0, 1];
        let dst = [192, 168, 0, 2];
        let mut udp = [0x12, 0x34, 0x56, 0x78, 0x00, 0x08, 0x00, 0x00];

        let checksum = transport_checksum(&src, &dst, 17, &udp);
        assert_eq!(checksum, 0x15de);

        // Once the checksum is filled in, recomputing over the segment yields zero.
        udp[6..8].copy_from_slice(&checksum.to_be_bytes());
        assert_eq!(transport_checksum(&src, &dst, 17, &udp), 0);
    }

    #[test]
    fn test_checksum_verify() {
        // Create data with valid checksum
        let mut data = vec![0x45, 0x00, 0x00, 0x3c, 0x1c, 0x46, 0x40, 0x00];
        data.extend_from_slice(&[0x40, 0x06, 0x00, 0x00]); // checksum = 0 initially
        data.extend_from_slice(&[0xac, 0x10, 0x0a, 0x63]); // src IP
        data.extend_from_slice(&[0xac, 0x10, 0x0a, 0x0c]); // dst IP

        let checksum = internet_checksum(&data);
        // Set the checksum
        data[10] = (checksum >> 8) as u8;
        data[11] = checksum as u8;

        // Now verification should pass
        assert!(verify_checksum(&data));
    }
}