Skip to main content

stackforge_core/layer/ipv4/
checksum.rs

//! IPv4 header checksum calculation.
//!
//! Implements the RFC 1071 Internet checksum algorithm used for IPv4 headers,
//! along with RFC 1624 incremental updates and TCP/UDP pseudo-header checksums.
//! The checksum is computed lazily - only when the packet is serialized.

6use crate::utils::internet_checksum;
7
8/// Compute the Internet checksum (RFC 1071) over a byte slice.
9///
10/// This is the standard one's complement sum used for IP, ICMP, TCP, and UDP.
11///
12/// # Algorithm
13///
14/// 1. Sum all 16-bit words in the data
15/// 2. Add any odd byte as the high byte of a word
16/// 3. Fold 32-bit sum to 16 bits by adding carry bits
17/// 4. Take one's complement of the result
18///
19/// # Example
20///
21/// ```rust
22/// use stackforge_core::layer::ipv4::checksum::ipv4_checksum;
23///
24/// let header = [
25///     0x45, 0x00, 0x00, 0x3c, 0x1c, 0x46, 0x40, 0x00,
26///     0x40, 0x06, 0x00, 0x00, 0xac, 0x10, 0x0a, 0x63,
27///     0xac, 0x10, 0x0a, 0x0c,
28/// ];
29/// let checksum = ipv4_checksum(&header);
30/// ```
31#[inline]
32#[must_use]
33pub fn ipv4_checksum(data: &[u8]) -> u16 {
34    internet_checksum(data)
35}
36
// Note: `internet_checksum` is imported from `crate::utils`; `partial_checksum` and
// `finalize_checksum` also come from `crate::utils::checksum` and are re-exported below.
39
40/// Verify that a checksum is valid.
41///
42/// When computed over data that includes a valid checksum, the result
43/// should be 0x0000 or 0xFFFF.
44#[inline]
45#[must_use]
46pub fn verify_ipv4_checksum(data: &[u8]) -> bool {
47    let sum = internet_checksum(data);
48    sum == 0 || sum == 0xFFFF
49}
50
/// Incrementally updates a checksum after one 16-bit word of the covered data changes.
///
/// This avoids re-summing the whole buffer when only a single field is
/// modified. It uses RFC 1624 equation 3: `HC' = ~(~HC + ~m + m')`.
///
/// # Arguments
///
/// * `old_checksum` - The checksum currently stored in the packet (`HC`)
/// * `old_value` - The 16-bit word before the change (`m`)
/// * `new_value` - The 16-bit word after the change (`m'`)
///
/// # Returns
///
/// The checksum that is valid for the modified data (`HC'`).
#[inline]
#[must_use]
pub fn incremental_update_checksum(old_checksum: u16, old_value: u16, new_value: u16) -> u16 {
    // Widen the three one's complement operands so no carry is lost while adding.
    let mut acc: u32 = [!old_checksum, !old_value, new_value]
        .iter()
        .copied()
        .map(u32::from)
        .sum();

    // End-around carry: fold anything above bit 15 back into the low word.
    while acc > 0xFFFF {
        acc = (acc >> 16) + (acc & 0xFFFF);
    }

    // `acc` fits in 16 bits here, so the narrowing cast is lossless.
    !(acc as u16)
}
82
83/// Incrementally update checksum when a 32-bit field changes.
84///
85/// Useful for updating checksum after IP address changes.
86#[inline]
87#[must_use]
88pub fn incremental_update_checksum_32(old_checksum: u16, old_value: u32, new_value: u32) -> u16 {
89    // Update for high 16 bits, then low 16 bits
90    let old_high = (old_value >> 16) as u16;
91    let old_low = (old_value & 0xFFFF) as u16;
92    let new_high = (new_value >> 16) as u16;
93    let new_low = (new_value & 0xFFFF) as u16;
94
95    let tmp = incremental_update_checksum(old_checksum, old_high, new_high);
96    incremental_update_checksum(tmp, old_low, new_low)
97}
98
/// Sums the IPv4 pseudo-header used by the TCP and UDP checksums.
///
/// The pseudo-header consists of:
/// - Source IP (4 bytes)
/// - Destination IP (4 bytes)
/// - Zero (1 byte)
/// - Protocol (1 byte)
/// - TCP/UDP length (2 bytes)
///
/// # Arguments
///
/// * `src_ip` - Source IPv4 address
/// * `dst_ip` - Destination IPv4 address
/// * `protocol` - IP protocol number
/// * `transport_len` - Length in bytes of the transport header plus payload
///
/// # Returns
///
/// The raw partial sum of the pseudo-header's 16-bit words. It is neither
/// folded nor complemented, so it can seed a running checksum over the segment.
#[must_use]
pub fn pseudo_header_checksum(
    src_ip: &[u8; 4],
    dst_ip: &[u8; 4],
    protocol: u8,
    transport_len: u16,
) -> u32 {
    // Each address contributes two big-endian 16-bit words.
    let address_words = src_ip
        .chunks_exact(2)
        .chain(dst_ip.chunks_exact(2))
        .map(|pair| u32::from(u16::from_be_bytes([pair[0], pair[1]])));

    // The zero byte followed by the protocol forms the word 0x00PP; the length
    // is its own 16-bit word.
    let trailer = u32::from(protocol) + u32::from(transport_len);

    address_words.sum::<u32>() + trailer
}
143
144/// Compute complete transport layer checksum (TCP or UDP).
145///
146/// # Arguments
147///
148/// * `src_ip` - Source IP address
149/// * `dst_ip` - Destination IP address
150/// * `protocol` - IP protocol number (6 for TCP, 17 for UDP)
151/// * `transport_data` - The complete transport layer header and payload
152///
153/// # Returns
154///
155/// The computed checksum value.
156#[must_use]
157pub fn transport_checksum(
158    src_ip: &[u8; 4],
159    dst_ip: &[u8; 4],
160    protocol: u8,
161    transport_data: &[u8],
162) -> u16 {
163    // Start with pseudo-header checksum
164    let mut sum = pseudo_header_checksum(src_ip, dst_ip, protocol, transport_data.len() as u16);
165
166    // Add transport data
167    let mut chunks = transport_data.chunks_exact(2);
168    for chunk in chunks.by_ref() {
169        sum += u32::from(u16::from_be_bytes([chunk[0], chunk[1]]));
170    }
171
172    // Handle odd byte
173    if let Some(&last) = chunks.remainder().first() {
174        sum += u32::from(last) << 8;
175    }
176
177    // Fold and complement
178    while (sum >> 16) != 0 {
179        sum = (sum & 0xFFFF) + (sum >> 16);
180    }
181
182    // For UDP, if checksum is 0, use 0xFFFF instead (RFC 768)
183    let result = !sum as u16;
184    if result == 0 && protocol == 17 {
185        0xFFFF
186    } else {
187        result
188    }
189}
190
// Note: `partial_checksum` and `finalize_checksum` moved to `crate::utils::checksum`;
// they are re-exported from this module for backwards compatibility.
193pub use crate::utils::{finalize_checksum, partial_checksum};
194
/// Zeroes the two-byte checksum field at `offset` in `buf`.
///
/// Does nothing if `buf` has fewer than two bytes starting at `offset`. This
/// includes offsets past the end and offsets so large that `offset + 2` would
/// overflow `usize`.
#[inline]
pub fn zero_checksum(buf: &mut [u8], offset: usize) {
    // Slicing with `get_mut` bounds-checks without computing `offset + 2`, so a
    // huge offset cannot trigger an arithmetic-overflow panic.
    if let Some(field) = buf.get_mut(offset..).and_then(|tail| tail.get_mut(..2)) {
        field.fill(0);
    }
}
203
/// Writes `checksum` in big-endian (network) byte order at `offset` in `buf`.
///
/// Does nothing if `buf` has fewer than two bytes starting at `offset`. This
/// includes offsets past the end and offsets so large that `offset + 2` would
/// overflow `usize`.
#[inline]
pub fn write_checksum(buf: &mut [u8], offset: usize, checksum: u16) {
    // `get_mut` bounds-checks without the overflow-prone `offset + 2`.
    if let Some(field) = buf.get_mut(offset..).and_then(|tail| tail.get_mut(..2)) {
        field.copy_from_slice(&checksum.to_be_bytes());
    }
}
213
/// Reads a big-endian (network byte order) checksum at `offset` in `buf`.
///
/// Returns `None` if `buf` has fewer than two bytes starting at `offset`. This
/// includes offsets past the end and offsets so large that `offset + 2` would
/// overflow `usize`.
#[inline]
#[must_use]
pub fn read_checksum(buf: &[u8], offset: usize) -> Option<u16> {
    // `get` bounds-checks without the overflow-prone `offset + 2`.
    let field = buf.get(offset..)?.get(..2)?;
    Some(u16::from_be_bytes([field[0], field[1]]))
}
224
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_ipv4_checksum() {
        // Widely used sample IPv4 header with the checksum field zeroed.
        let header = [
            0x45, 0x00, 0x00, 0x3c, // Version, IHL, TOS, Total Length
            0x1c, 0x46, 0x40, 0x00, // ID, Flags, Fragment Offset
            0x40, 0x06, 0x00, 0x00, // TTL, Protocol, Checksum (zeroed)
            0xac, 0x10, 0x0a, 0x63, // Source: 172.16.10.99
            0xac, 0x10, 0x0a, 0x0c, // Dest: 172.16.10.12
        ];

        let checksum = ipv4_checksum(&header);
        assert_eq!(checksum, 0xb1e6);

        // Place checksum in header and verify
        let mut header_with_cksum = header;
        header_with_cksum[10..12].copy_from_slice(&checksum.to_be_bytes());

        assert!(verify_ipv4_checksum(&header_with_cksum));
    }

    #[test]
    fn test_verify_valid_checksum() {
        // Header with pre-computed valid checksum
        let header = [
            0x45, 0x00, 0x00, 0x3c, 0x1c, 0x46, 0x40, 0x00, 0x40, 0x06, 0xb1,
            0xe6, // checksum
            0xac, 0x10, 0x0a, 0x63, 0xac, 0x10, 0x0a, 0x0c,
        ];

        assert!(verify_ipv4_checksum(&header));
    }

    #[test]
    fn test_verify_invalid_checksum() {
        // Header with corrupted checksum
        let header = [
            0x45, 0x00, 0x00, 0x3c, 0x1c, 0x46, 0x40, 0x00, 0x40, 0x06, 0xFF,
            0xFF, // bad checksum
            0xac, 0x10, 0x0a, 0x63, 0xac, 0x10, 0x0a, 0x0c,
        ];

        assert!(!verify_ipv4_checksum(&header));
    }

    #[test]
    fn test_incremental_update() {
        // Original header; checksum field zeroed, filled in below.
        let mut header = [
            0x45, 0x00, 0x00, 0x3c, 0x1c, 0x46, 0x40, 0x00, 0x40, 0x06, 0x00, 0x00, 0xac, 0x10,
            0x0a, 0x63, 0xac, 0x10, 0x0a, 0x0c,
        ];

        // Compute initial checksum
        let initial_checksum = ipv4_checksum(&header);
        header[10..12].copy_from_slice(&initial_checksum.to_be_bytes());

        // Change TTL from 0x40 to 0x3F using incremental update
        let old_ttl_word = u16::from_be_bytes([header[8], header[9]]);
        header[8] = 0x3F;
        let new_ttl_word = u16::from_be_bytes([header[8], header[9]]);

        let new_checksum =
            incremental_update_checksum(initial_checksum, old_ttl_word, new_ttl_word);
        header[10..12].copy_from_slice(&new_checksum.to_be_bytes());

        // The incrementally updated checksum is valid ...
        assert!(verify_ipv4_checksum(&header));

        // ... and matches a full recomputation.
        let mut zeroed = header;
        zeroed[10] = 0;
        zeroed[11] = 0;
        assert_eq!(new_checksum, ipv4_checksum(&zeroed));
        assert_eq!(new_checksum, 0xb2e6);
    }

    #[test]
    fn test_incremental_update_32() {
        let mut header = [
            0x45, 0x00, 0x00, 0x3c, 0x1c, 0x46, 0x40, 0x00, 0x40, 0x06, 0x00, 0x00, 0xac, 0x10,
            0x0a, 0x63, 0xac, 0x10, 0x0a, 0x0c,
        ];
        let initial_checksum = ipv4_checksum(&header);

        // Rewrite the source address 172.16.10.99 -> 10.0.0.1.
        let old_src = u32::from_be_bytes([header[12], header[13], header[14], header[15]]);
        let new_src: u32 = 0x0a00_0001;
        header[12..16].copy_from_slice(&new_src.to_be_bytes());

        let updated = incremental_update_checksum_32(initial_checksum, old_src, new_src);
        assert_eq!(updated, ipv4_checksum(&header));
        assert_eq!(updated, 0x5e59);
    }

    #[test]
    fn test_pseudo_header_checksum() {
        let src = [192, 168, 1, 1];
        let dst = [192, 168, 1, 2];
        let protocol = 6; // TCP
        let length = 20; // TCP header only

        let sum = pseudo_header_checksum(&src, &dst, protocol, length);

        // src words + dst words + (zero, protocol) + length, unfolded.
        assert_eq!(sum, 0xc0a8 + 0x0101 + 0xc0a8 + 0x0102 + 6 + 20);
    }

    #[test]
    fn test_transport_checksum_tcp() {
        let src_ip = [192, 168, 1, 1];
        let dst_ip = [192, 168, 1, 2];

        // Minimal TCP header with zeroed checksum
        let tcp_header = [
            0x00, 0x50, // Source port: 80
            0x1F, 0x90, // Dest port: 8080
            0x00, 0x00, 0x00, 0x01, // Seq number
            0x00, 0x00, 0x00, 0x00, // Ack number
            0x50, 0x02, // Data offset + flags (SYN)
            0xFF, 0xFF, // Window
            0x00, 0x00, // Checksum (zeroed)
            0x00, 0x00, // Urgent pointer
        ];

        let checksum = transport_checksum(&src_ip, &dst_ip, 6, &tcp_header);
        assert_eq!(checksum, 0x0cae);

        // Re-summing the segment with the checksum filled in must yield zero.
        let mut with_cksum = tcp_header;
        with_cksum[16..18].copy_from_slice(&checksum.to_be_bytes());
        assert_eq!(transport_checksum(&src_ip, &dst_ip, 6, &with_cksum), 0);
    }

    #[test]
    fn test_transport_checksum_udp_zero() {
        let src_ip = [0, 0, 0, 0];
        let dst_ip = [0, 0, 0, 0];

        // Source port 0xFFDE makes the pseudo-header (17 + 8) plus the segment
        // words (0xFFDE + 0x0008) sum to exactly 0xFFFF, so the computed
        // checksum is 0x0000.
        let udp_header = [
            0xFF, 0xDE, // Source port
            0x00, 0x00, // Dest port
            0x00, 0x08, // Length
            0x00, 0x00, // Checksum (zeroed)
        ];

        // RFC 768: a computed UDP checksum of zero is transmitted as 0xFFFF.
        let checksum = transport_checksum(&src_ip, &dst_ip, 17, &udp_header);
        assert_eq!(checksum, 0xFFFF);
    }

    #[test]
    fn test_transport_checksum_tcp_zero_is_kept() {
        // Pseudo-header (6 + 8) plus segment words (0xFFE9 + 0x0008) == 0xFFFF,
        // so the checksum is zero; unlike UDP, TCP keeps zero as-is.
        let segment = [0xFF, 0xE9, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00];
        assert_eq!(transport_checksum(&[0; 4], &[0; 4], 6, &segment), 0);
    }

    #[test]
    fn test_partial_checksum() {
        let data1 = [0x01, 0x02, 0x03, 0x04];
        let data2 = [0x05, 0x06, 0x07, 0x08];

        // Compute separately and combine
        let sum1 = partial_checksum(&data1, 0);
        let sum2 = partial_checksum(&data2, sum1);
        let checksum1 = finalize_checksum(sum2);

        // Compute together
        let combined = [0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08];
        let checksum2 = internet_checksum(&combined);

        assert_eq!(checksum1, checksum2);
    }

    #[test]
    fn test_odd_length_data() {
        // Test with odd number of bytes
        let data = [0x45, 0x00, 0x00, 0x3c, 0x1c];
        let checksum = internet_checksum(&data);

        assert_ne!(checksum, 0);
        // The trailing odd byte is summed as the high byte of a zero-padded word.
        assert_eq!(
            checksum,
            internet_checksum(&[0x45, 0x00, 0x00, 0x3c, 0x1c, 0x00])
        );
    }

    #[test]
    fn test_zero_and_write_checksum() {
        let mut buf = [0x45, 0x00, 0xAB, 0xCD, 0x00, 0x00];

        zero_checksum(&mut buf, 2);
        assert_eq!(buf[2], 0);
        assert_eq!(buf[3], 0);

        write_checksum(&mut buf, 2, 0x1234);
        assert_eq!(buf[2], 0x12);
        assert_eq!(buf[3], 0x34);

        assert_eq!(read_checksum(&buf, 2), Some(0x1234));

        // Offsets without two bytes of room are ignored rather than panicking.
        zero_checksum(&mut buf, 5);
        write_checksum(&mut buf, 5, 0xFFFF);
        assert_eq!(buf, [0x45, 0x00, 0x12, 0x34, 0x00, 0x00]);
        assert_eq!(read_checksum(&buf, 4), Some(0));
        assert_eq!(read_checksum(&buf, 5), None);
        assert_eq!(read_checksum(&buf, 6), None);
    }
}
395}