Skip to main content

stackforge_core/layer/ipv4/
checksum.rs

//! IPv4 header checksum calculation.
//!
//! Implements the RFC 1071 Internet checksum algorithm used for IPv4 headers.
//! The checksum is computed lazily: only when the packet is serialized.
5
6use crate::utils::internet_checksum;
7
8/// Compute the Internet checksum (RFC 1071) over a byte slice.
9///
10/// This is the standard one's complement sum used for IP, ICMP, TCP, and UDP.
11///
12/// # Algorithm
13///
14/// 1. Sum all 16-bit words in the data
15/// 2. Add any odd byte as the high byte of a word
16/// 3. Fold 32-bit sum to 16 bits by adding carry bits
17/// 4. Take one's complement of the result
18///
19/// # Example
20///
21/// ```rust
22/// use stackforge_core::layer::ipv4::checksum::ipv4_checksum;
23///
24/// let header = [
25///     0x45, 0x00, 0x00, 0x3c, 0x1c, 0x46, 0x40, 0x00,
26///     0x40, 0x06, 0x00, 0x00, 0xac, 0x10, 0x0a, 0x63,
27///     0xac, 0x10, 0x0a, 0x0c,
28/// ];
29/// let checksum = ipv4_checksum(&header);
30/// ```
31#[inline]
32pub fn ipv4_checksum(data: &[u8]) -> u16 {
33    internet_checksum(data)
34}
35
// Note: internet_checksum, partial_checksum, and finalize_checksum now live in
// crate::utils::checksum; internet_checksum is imported above and the other two
// are re-exported further down in this module.
38
39/// Verify that a checksum is valid.
40///
41/// When computed over data that includes a valid checksum, the result
42/// should be 0x0000 or 0xFFFF.
43#[inline]
44pub fn verify_ipv4_checksum(data: &[u8]) -> bool {
45    let sum = internet_checksum(data);
46    sum == 0 || sum == 0xFFFF
47}
48
/// Adjust an existing checksum after a single 16-bit word of the data changes.
///
/// This avoids recomputing the checksum over the whole buffer. It implements
/// RFC 1624 equation 3, `HC' = ~(~HC + ~m + m')`, using 32-bit arithmetic with
/// end-around-carry folding.
///
/// # Arguments
///
/// * `old_checksum` - The checksum currently stored for the data
/// * `old_value` - The 16-bit word as it was before the change
/// * `new_value` - The 16-bit word as it is after the change
///
/// # Returns
///
/// The checksum the data has once `old_value` has been replaced by `new_value`.
#[inline]
pub fn incremental_update_checksum(old_checksum: u16, old_value: u16, new_value: u16) -> u16 {
    let mut acc = u32::from(!old_checksum) + u32::from(!old_value) + u32::from(new_value);

    // Fold the carries back into the low 16 bits (end-around carry).
    while acc > 0xFFFF {
        acc = (acc >> 16) + (acc & 0xFFFF);
    }

    // `acc` now fits in 16 bits, so the narrowing cast is lossless.
    !(acc as u16)
}
79
80/// Incrementally update checksum when a 32-bit field changes.
81///
82/// Useful for updating checksum after IP address changes.
83#[inline]
84pub fn incremental_update_checksum_32(old_checksum: u16, old_value: u32, new_value: u32) -> u16 {
85    // Update for high 16 bits, then low 16 bits
86    let old_high = (old_value >> 16) as u16;
87    let old_low = (old_value & 0xFFFF) as u16;
88    let new_high = (new_value >> 16) as u16;
89    let new_low = (new_value & 0xFFFF) as u16;
90
91    let tmp = incremental_update_checksum(old_checksum, old_high, new_high);
92    incremental_update_checksum(tmp, old_low, new_low)
93}
94
/// Sum the IPv4 pseudo-header words used by the TCP and UDP checksums.
///
/// The pseudo-header consists of:
/// - Source IP (4 bytes)
/// - Destination IP (4 bytes)
/// - Zero (1 byte)
/// - Protocol (1 byte)
/// - TCP/UDP length (2 bytes)
///
/// # Arguments
///
/// * `src_ip` - Source IP address (4 bytes)
/// * `dst_ip` - Destination IP address (4 bytes)
/// * `protocol` - IP protocol number
/// * `transport_len` - Length of the transport layer data
///
/// # Returns
///
/// The unfolded, uncomplemented 32-bit partial sum, ready to be extended with
/// the transport data words.
pub fn pseudo_header_checksum(
    src_ip: &[u8; 4],
    dst_ip: &[u8; 4],
    protocol: u8,
    transport_len: u16,
) -> u32 {
    // Each address contributes two big-endian 16-bit words.
    let address_words: u32 = src_ip
        .chunks_exact(2)
        .chain(dst_ip.chunks_exact(2))
        .map(|pair| u32::from(u16::from_be_bytes([pair[0], pair[1]])))
        .sum();

    // The zero byte followed by the protocol byte forms the word 0x00PP.
    address_words + u32::from(protocol) + u32::from(transport_len)
}
138
139/// Compute complete transport layer checksum (TCP or UDP).
140///
141/// # Arguments
142///
143/// * `src_ip` - Source IP address
144/// * `dst_ip` - Destination IP address
145/// * `protocol` - IP protocol number (6 for TCP, 17 for UDP)
146/// * `transport_data` - The complete transport layer header and payload
147///
148/// # Returns
149///
150/// The computed checksum value.
151pub fn transport_checksum(
152    src_ip: &[u8; 4],
153    dst_ip: &[u8; 4],
154    protocol: u8,
155    transport_data: &[u8],
156) -> u16 {
157    // Start with pseudo-header checksum
158    let mut sum = pseudo_header_checksum(src_ip, dst_ip, protocol, transport_data.len() as u16);
159
160    // Add transport data
161    let mut chunks = transport_data.chunks_exact(2);
162    for chunk in chunks.by_ref() {
163        sum += u16::from_be_bytes([chunk[0], chunk[1]]) as u32;
164    }
165
166    // Handle odd byte
167    if let Some(&last) = chunks.remainder().first() {
168        sum += (last as u32) << 8;
169    }
170
171    // Fold and complement
172    while (sum >> 16) != 0 {
173        sum = (sum & 0xFFFF) + (sum >> 16);
174    }
175
176    // For UDP, if checksum is 0, use 0xFFFF instead (RFC 768)
177    let result = !sum as u16;
178    if result == 0 && protocol == 17 {
179        0xFFFF
180    } else {
181        result
182    }
183}
184
// Note: partial_checksum and finalize_checksum moved to crate::utils::checksum
// and are re-exported from this module for backwards compatibility.
187pub use crate::utils::{finalize_checksum, partial_checksum};
188
/// Zero the two-byte checksum field at `offset` in `buf`.
///
/// Does nothing if the field does not fit inside `buf`. That includes offsets so
/// large that `offset + 2` would overflow `usize`.
#[inline]
pub fn zero_checksum(buf: &mut [u8], offset: usize) {
    // `checked_add` avoids the overflow a plain `offset + 2` hits near
    // `usize::MAX` (a debug panic, or an out-of-bounds index after wrapping).
    if let Some(field) = offset.checked_add(2).and_then(|end| buf.get_mut(offset..end)) {
        field.fill(0);
    }
}
197
/// Write `checksum` in network (big-endian) byte order to the two bytes at
/// `offset` in `buf`.
///
/// Does nothing if the field does not fit inside `buf`. That includes offsets so
/// large that `offset + 2` would overflow `usize`.
#[inline]
pub fn write_checksum(buf: &mut [u8], offset: usize, checksum: u16) {
    // `checked_add` avoids the overflow a plain `offset + 2` hits near
    // `usize::MAX` (a debug panic, or an out-of-bounds index after wrapping).
    if let Some(field) = offset.checked_add(2).and_then(|end| buf.get_mut(offset..end)) {
        field.copy_from_slice(&checksum.to_be_bytes());
    }
}
207
/// Read the big-endian two-byte checksum field at `offset` in `buf`.
///
/// Returns `None` if the field does not fit inside `buf`. That includes offsets
/// so large that `offset + 2` would overflow `usize`.
#[inline]
pub fn read_checksum(buf: &[u8], offset: usize) -> Option<u16> {
    // `checked_add` avoids the overflow a plain `offset + 2` hits near `usize::MAX`.
    let end = offset.checked_add(2)?;
    let field = buf.get(offset..end)?;
    Some(u16::from_be_bytes([field[0], field[1]]))
}
217
#[cfg(test)]
mod tests {
    use super::*;

    /// Sample IPv4 header with the checksum field (bytes 10..12) zeroed.
    const SAMPLE_HEADER: [u8; 20] = [
        0x45, 0x00, 0x00, 0x3c, // Version, IHL, TOS, Total Length
        0x1c, 0x46, 0x40, 0x00, // ID, Flags, Fragment Offset
        0x40, 0x06, 0x00, 0x00, // TTL, Protocol, Checksum (zeroed)
        0xac, 0x10, 0x0a, 0x63, // Source: 172.16.10.99
        0xac, 0x10, 0x0a, 0x0c, // Dest: 172.16.10.12
    ];

    /// Known-good checksum of `SAMPLE_HEADER`.
    const SAMPLE_CHECKSUM: u16 = 0xb1e6;

    /// Return a copy of `header` with `checksum` stored big-endian in bytes 10..12.
    fn with_checksum(mut header: [u8; 20], checksum: u16) -> [u8; 20] {
        header[10..12].copy_from_slice(&checksum.to_be_bytes());
        header
    }

    #[test]
    fn test_ipv4_checksum() {
        let checksum = ipv4_checksum(&SAMPLE_HEADER);
        assert_eq!(checksum, SAMPLE_CHECKSUM);

        // Placing the checksum in the header must make it verify.
        assert!(verify_ipv4_checksum(&with_checksum(SAMPLE_HEADER, checksum)));
    }

    #[test]
    fn test_verify_valid_checksum() {
        assert!(verify_ipv4_checksum(&with_checksum(SAMPLE_HEADER, SAMPLE_CHECKSUM)));
    }

    #[test]
    fn test_verify_invalid_checksum() {
        assert!(!verify_ipv4_checksum(&with_checksum(SAMPLE_HEADER, 0xFFFF)));
    }

    #[test]
    fn test_incremental_update() {
        let mut header = with_checksum(SAMPLE_HEADER, SAMPLE_CHECKSUM);

        // Change TTL from 0x40 to 0x3F using an incremental update.
        let old_ttl_word = u16::from_be_bytes([header[8], header[9]]);
        header[8] = 0x3F;
        let new_ttl_word = u16::from_be_bytes([header[8], header[9]]);

        let new_checksum =
            incremental_update_checksum(SAMPLE_CHECKSUM, old_ttl_word, new_ttl_word);

        // Must agree with a full recomputation over the modified header.
        let recomputed = ipv4_checksum(&with_checksum(header, 0));
        assert_eq!(new_checksum, recomputed);
        assert!(verify_ipv4_checksum(&with_checksum(header, new_checksum)));
    }

    #[test]
    fn test_incremental_update_32() {
        // Rewrite the source address 172.16.10.99 -> 172.16.10.100.
        let old_src = u32::from_be_bytes([0xac, 0x10, 0x0a, 0x63]);
        let new_src = u32::from_be_bytes([0xac, 0x10, 0x0a, 0x64]);

        let mut modified = SAMPLE_HEADER;
        modified[12..16].copy_from_slice(&new_src.to_be_bytes());

        let updated = incremental_update_checksum_32(SAMPLE_CHECKSUM, old_src, new_src);
        assert_eq!(updated, ipv4_checksum(&modified));
    }

    #[test]
    fn test_pseudo_header_checksum() {
        let src = [192, 168, 1, 1];
        let dst = [192, 168, 1, 2];
        let protocol = 6; // TCP
        let length = 20; // TCP header only

        let sum = pseudo_header_checksum(&src, &dst, protocol, length);

        // 0xC0A8 + 0x0101 + 0xC0A8 + 0x0102 + 6 + 20
        assert_eq!(sum, 0x1836D);
    }

    #[test]
    fn test_transport_checksum_tcp() {
        let src_ip = [192, 168, 1, 1];
        let dst_ip = [192, 168, 1, 2];

        // Minimal TCP header with zeroed checksum
        let mut tcp_header = [
            0x00, 0x50, // Source port: 80
            0x1F, 0x90, // Dest port: 8080
            0x00, 0x00, 0x00, 0x01, // Seq number
            0x00, 0x00, 0x00, 0x00, // Ack number
            0x50, 0x02, // Data offset + flags (SYN)
            0xFF, 0xFF, // Window
            0x00, 0x00, // Checksum (zeroed)
            0x00, 0x00, // Urgent pointer
        ];

        let checksum = transport_checksum(&src_ip, &dst_ip, 6, &tcp_header);
        assert_ne!(checksum, 0);

        // Recomputing over the segment with its checksum in place yields 0.
        tcp_header[16..18].copy_from_slice(&checksum.to_be_bytes());
        assert_eq!(transport_checksum(&src_ip, &dst_ip, 6, &tcp_header), 0);
    }

    #[test]
    fn test_transport_checksum_udp_zero() {
        let src_ip = [0, 0, 0, 0];
        let dst_ip = [0, 0, 0, 0];

        // Pseudo-header sum is 17 + 8 = 0x19. A source port of 0xFFDE plus the
        // length 0x0008 brings the total to exactly 0xFFFF, whose complement is 0.
        let udp_header = [
            0xFF, 0xDE, // Source port
            0x00, 0x00, // Dest port
            0x00, 0x08, // Length
            0x00, 0x00, // Checksum (zeroed)
        ];

        let checksum = transport_checksum(&src_ip, &dst_ip, 17, &udp_header);
        // A computed UDP checksum of 0 must be sent as 0xFFFF (RFC 768).
        assert_eq!(checksum, 0xFFFF);
    }

    #[test]
    fn test_transport_checksum_tcp_zero_is_kept() {
        let src_ip = [0, 0, 0, 0];
        let dst_ip = [0, 0, 0, 0];

        // Pseudo-header sum is 6 + 8 = 0x0E; 0xFFE9 + 0x0008 + 0x0E = 0xFFFF.
        let data = [0xFF, 0xE9, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00];

        // The UDP zero substitution must not apply to TCP.
        assert_eq!(transport_checksum(&src_ip, &dst_ip, 6, &data), 0);
    }

    #[test]
    fn test_partial_checksum() {
        let data1 = [0x01, 0x02, 0x03, 0x04];
        let data2 = [0x05, 0x06, 0x07, 0x08];

        // Compute separately and combine
        let sum1 = partial_checksum(&data1, 0);
        let sum2 = partial_checksum(&data2, sum1);
        let checksum1 = finalize_checksum(sum2);

        // Compute together
        let combined = [0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08];
        let checksum2 = internet_checksum(&combined);

        assert_eq!(checksum1, checksum2);
    }

    #[test]
    fn test_odd_length_data() {
        // Test with odd number of bytes
        let data = [0x45, 0x00, 0x00, 0x3c, 0x1c];
        let checksum = internet_checksum(&data);
        assert_ne!(checksum, 0);

        // The trailing byte is treated as the high byte of a zero-padded word.
        let padded = [0x45, 0x00, 0x00, 0x3c, 0x1c, 0x00];
        assert_eq!(checksum, internet_checksum(&padded));
    }

    #[test]
    fn test_zero_and_write_checksum() {
        let mut buf = [0x45, 0x00, 0xAB, 0xCD, 0x00, 0x00];

        zero_checksum(&mut buf, 2);
        assert_eq!(buf[2], 0);
        assert_eq!(buf[3], 0);

        write_checksum(&mut buf, 2, 0x1234);
        assert_eq!(buf[2], 0x12);
        assert_eq!(buf[3], 0x34);

        assert_eq!(read_checksum(&buf, 2), Some(0x1234));

        // A field that would run past the end of the buffer is ignored.
        zero_checksum(&mut buf, 5);
        write_checksum(&mut buf, 5, 0xFFFF);
        assert_eq!(buf, [0x45, 0x00, 0x12, 0x34, 0x00, 0x00]);
        assert_eq!(read_checksum(&buf, 5), None);
    }
}