Skip to main content

stackforge_core/utils/
checksum.rs

//! Checksum calculation and verification utilities.
//!
//! Provides a generic RFC 1071 Internet checksum implementation that can be used
//! by all protocol layers (IP, ICMP, TCP, UDP, etc.).

6/// Calculate the Internet checksum (RFC 1071).
7///
8/// This is used for IP, ICMP, TCP, and UDP checksums.
9///
10/// # Algorithm
11///
12/// 1. Sum all 16-bit words in the data
13/// 2. Add any odd byte as the high byte of a word
14/// 3. Fold 32-bit sum to 16 bits by adding carry bits
15/// 4. Take one's complement of the result
16#[must_use]
17pub fn internet_checksum(data: &[u8]) -> u16 {
18    let sum = partial_checksum(data, 0);
19    finalize_checksum(sum)
20}
21
/// Compute a partial Internet checksum sum (before folding and complement).
///
/// Useful for computing a checksum across multiple data segments or when
/// combining a pseudo-header with data. Feed the result of one call in as
/// `initial` of the next, then pass the final value to `finalize_checksum()`.
///
/// Every segment except the last must have an even length. An odd trailing
/// byte is padded with zero on the right, so any segment chained after it
/// would be misaligned by one byte.
///
/// # Arguments
///
/// * `data` - The data to checksum, read as big-endian 16-bit words
/// * `initial` - Initial checksum value (use 0 for starting fresh)
///
/// # Returns
///
/// The 32-bit partial sum (not yet folded to 16 bits or complemented).
/// If the running total exceeds `u32::MAX`, the bits above bit 31 are folded
/// back into the low 32 bits (end-around carry). Because
/// `2^32 ≡ 1 (mod 0xFFFF)`, this does not change the final one's-complement
/// checksum. Arbitrarily long inputs and long chains of calls are therefore
/// safe and never overflow.
#[inline]
#[must_use]
pub fn partial_checksum(data: &[u8], initial: u32) -> u32 {
    // Accumulate in u64: a u32 accumulator overflows after ~64K words of
    // 0xFFFF (panicking in debug, dropping carries in release). A u64 can
    // only overflow after more than 2^48 words, which no real input reaches.
    let mut sum = u64::from(initial);

    // Process 16-bit words
    let mut chunks = data.chunks_exact(2);
    for chunk in chunks.by_ref() {
        sum += u64::from(u16::from_be_bytes([chunk[0], chunk[1]]));
    }

    // Handle odd byte (pad with zero on the right)
    if let Some(&last) = chunks.remainder().first() {
        sum += u64::from(last) << 8;
    }

    // Fold back into 32 bits with end-around carry; a no-op for sums that
    // already fit, so results are unchanged for all non-overflowing inputs.
    while sum > u64::from(u32::MAX) {
        sum = (sum & 0xFFFF_FFFF) + (sum >> 32);
    }
    // The loop guarantees `sum <= u32::MAX`, so this narrowing is lossless.
    sum as u32
}

/// Turn a partial checksum sum into the final 16-bit Internet checksum.
///
/// Carries above bit 15 are folded back into the low 16 bits (end-around
/// carry), and the one's complement of the folded value is returned.
///
/// # Arguments
///
/// * `sum` - The 32-bit partial sum from `partial_checksum()`
///
/// # Returns
///
/// The final 16-bit checksum value.
#[inline]
#[must_use]
pub fn finalize_checksum(sum: u32) -> u16 {
    // Two folds always suffice for a 32-bit input: the first leaves at most
    // 0xFFFF + 0xFFFF = 0x1_FFFE, the second at most 0xFFFE + 0x1 = 0xFFFF.
    // A fold of a value that already fits in 16 bits leaves it unchanged.
    let once = (sum & 0xFFFF) + (sum >> 16);
    let folded = (once & 0xFFFF) + (once >> 16);
    // `folded <= 0xFFFF`, so the narrowing cast drops no bits.
    !(folded as u16)
}

78/// Calculate checksum with pseudo-header (for TCP/UDP).
79#[must_use]
80pub fn transport_checksum(src_ip: &[u8], dst_ip: &[u8], protocol: u8, data: &[u8]) -> u16 {
81    let mut pseudo_header = Vec::with_capacity(12 + data.len());
82
83    // Source IP
84    pseudo_header.extend_from_slice(src_ip);
85    // Destination IP
86    pseudo_header.extend_from_slice(dst_ip);
87    // Zero
88    pseudo_header.push(0);
89    // Protocol
90    pseudo_header.push(protocol);
91    // Length (big-endian)
92    let len = data.len() as u16;
93    pseudo_header.extend_from_slice(&len.to_be_bytes());
94    // Data
95    pseudo_header.extend_from_slice(data);
96
97    internet_checksum(&pseudo_header)
98}
99
100/// Verify a checksum is valid (should be 0 or 0xFFFF when calculated over data with checksum).
101#[must_use]
102pub fn verify_checksum(data: &[u8]) -> bool {
103    let sum = internet_checksum(data);
104    sum == 0 || sum == 0xFFFF
105}
106
#[cfg(test)]
mod tests {
    use super::*;

    /// Worked example from RFC 1071, section 3.
    const RFC1071_DATA: [u8; 8] = [0x00, 0x01, 0xf2, 0x03, 0xf4, 0xf5, 0xf6, 0xf7];

    /// IPv4 header (20 bytes) with its checksum field (bytes 10..12) zeroed.
    fn ipv4_header_without_checksum() -> Vec<u8> {
        let mut data = vec![0x45, 0x00, 0x00, 0x3c, 0x1c, 0x46, 0x40, 0x00];
        data.extend_from_slice(&[0x40, 0x06, 0x00, 0x00]); // checksum = 0 initially
        data.extend_from_slice(&[0xac, 0x10, 0x0a, 0x63]); // src IP
        data.extend_from_slice(&[0xac, 0x10, 0x0a, 0x0c]); // dst IP
        data
    }

    #[test]
    fn test_internet_checksum() {
        // The words sum to 0x2ddf0, which folds to 0xddf2; the checksum is
        // its one's complement.
        assert_eq!(partial_checksum(&RFC1071_DATA, 0), 0x2ddf0);
        assert_eq!(internet_checksum(&RFC1071_DATA), 0x220d);
    }

    #[test]
    fn test_internet_checksum_odd_length() {
        // A trailing odd byte is the high byte of a zero-padded word.
        assert_eq!(internet_checksum(&[0x12]), 0xEDFF);
        assert_eq!(internet_checksum(&[0x12, 0x00]), internet_checksum(&[0x12]));
    }

    #[test]
    fn test_internet_checksum_empty() {
        assert_eq!(internet_checksum(&[]), 0xFFFF);
    }

    #[test]
    fn test_partial_checksum_chaining() {
        let chained = partial_checksum(&RFC1071_DATA[4..], partial_checksum(&RFC1071_DATA[..4], 0));
        assert_eq!(chained, partial_checksum(&RFC1071_DATA, 0));
    }

    #[test]
    fn test_finalize_checksum_edges() {
        assert_eq!(finalize_checksum(0), 0xFFFF);
        assert_eq!(finalize_checksum(0xFFFF), 0);
        assert_eq!(finalize_checksum(u32::MAX), 0);
    }

    #[test]
    fn test_checksum_verify() {
        let mut data = ipv4_header_without_checksum();

        // Known checksum for this header.
        let checksum = internet_checksum(&data);
        assert_eq!(checksum, 0xb1e6);

        // Set the checksum
        data[10..12].copy_from_slice(&checksum.to_be_bytes());

        // Now verification should pass
        assert!(verify_checksum(&data));

        // ...and any single-byte corruption must be detected.
        data[0] ^= 0x01;
        assert!(!verify_checksum(&data));
    }

    /// Build the IPv4 pseudo-header + segment by hand for comparison.
    fn manual_ipv4_pseudo(src: &[u8], dst: &[u8], protocol: u8, segment: &[u8]) -> Vec<u8> {
        let mut buf = Vec::with_capacity(12 + segment.len());
        buf.extend_from_slice(src);
        buf.extend_from_slice(dst);
        buf.extend_from_slice(&[0, protocol]);
        buf.extend_from_slice(&(segment.len() as u16).to_be_bytes());
        buf.extend_from_slice(segment);
        buf
    }

    #[test]
    fn test_transport_checksum_matches_manual_pseudo_header() {
        let src = [192, 168, 0, 1];
        let dst = [192, 168, 0, 199];
        // UDP header (1234 -> 53, length 12, checksum zeroed) + 4 payload bytes.
        let segment = [0x04, 0xd2, 0x00, 0x35, 0x00, 0x0c, 0x00, 0x00, 0xde, 0xad, 0xbe, 0xef];

        let expected = internet_checksum(&manual_ipv4_pseudo(&src, &dst, 17, &segment));
        assert_eq!(transport_checksum(&src, &dst, 17, &segment), expected);
    }

    #[test]
    fn test_transport_checksum_verifies() {
        let src = [10, 0, 0, 1];
        let dst = [10, 0, 0, 2];
        let mut segment = [0x04, 0xd2, 0x00, 0x35, 0x00, 0x0b, 0x00, 0x00, 0x61, 0x62, 0x63];

        let checksum = transport_checksum(&src, &dst, 17, &segment);
        segment[6..8].copy_from_slice(&checksum.to_be_bytes());

        assert!(verify_checksum(&manual_ipv4_pseudo(&src, &dst, 17, &segment)));
    }
}