Skip to main content

stackforge_core/layer/udp/
checksum.rs

//! UDP checksum calculation and verification.
//!
//! UDP uses the same pseudo-header scheme as TCP for its checksum.
//! For IPv4 the checksum is optional (a transmitted 0 means "no checksum"); for IPv6 it is mandatory.
//! Per RFC 768, a computed checksum of 0 is transmitted as 0xFFFF.
6
7use std::net::{Ipv4Addr, Ipv6Addr};
8
9use crate::utils::{finalize_checksum, partial_checksum};
10
/// IP protocol number assigned to UDP (0x11 = 17); it is summed into the checksum pseudo-header.
const UDP_PROTOCOL: u8 = 0x11;
13
14/// Calculate UDP checksum with IPv4 pseudo-header.
15///
16/// # Arguments
17/// * `src_ip` - Source IPv4 address
18/// * `dst_ip` - Destination IPv4 address
19/// * `udp_data` - Complete UDP packet (header + payload)
20///
21/// # Returns
22/// The calculated checksum. Per RFC 768, if result is 0, returns 0xFFFF.
23#[must_use]
24pub fn udp_checksum_ipv4(src_ip: Ipv4Addr, dst_ip: Ipv4Addr, udp_data: &[u8]) -> u16 {
25    let mut sum: u32 = 0;
26
27    // Add IPv4 pseudo-header
28    // Source IP (4 bytes)
29    for chunk in src_ip.octets().chunks(2) {
30        sum += u32::from(u16::from_be_bytes([chunk[0], chunk[1]]));
31    }
32
33    // Destination IP (4 bytes)
34    for chunk in dst_ip.octets().chunks(2) {
35        sum += u32::from(u16::from_be_bytes([chunk[0], chunk[1]]));
36    }
37
38    // Zero + Protocol (2 bytes): 0x00 (zero) + 0x11 (UDP) = 0x0011 as u16
39    sum += u32::from(UDP_PROTOCOL);
40
41    // UDP Length (2 bytes)
42    sum += udp_data.len() as u32;
43
44    // Add UDP data
45    sum = partial_checksum(udp_data, sum);
46
47    let checksum = finalize_checksum(sum);
48
49    // RFC 768: If computed checksum is 0, use 0xFFFF
50    if checksum == 0 { 0xFFFF } else { checksum }
51}
52
53/// Calculate UDP checksum with IPv6 pseudo-header.
54///
55/// # Arguments
56/// * `src_ip` - Source IPv6 address
57/// * `dst_ip` - Destination IPv6 address
58/// * `udp_data` - Complete UDP packet (header + payload)
59///
60/// # Returns
61/// The calculated checksum. Per RFC 2460, if result is 0, returns 0xFFFF.
62#[must_use]
63pub fn udp_checksum_ipv6(src_ip: Ipv6Addr, dst_ip: Ipv6Addr, udp_data: &[u8]) -> u16 {
64    let mut sum: u32 = 0;
65
66    // Add IPv6 pseudo-header
67    // Source IP (16 bytes)
68    for chunk in src_ip.octets().chunks(2) {
69        sum += u32::from(u16::from_be_bytes([chunk[0], chunk[1]]));
70    }
71
72    // Destination IP (16 bytes)
73    for chunk in dst_ip.octets().chunks(2) {
74        sum += u32::from(u16::from_be_bytes([chunk[0], chunk[1]]));
75    }
76
77    // UDP Length (4 bytes, big-endian)
78    let udp_len = udp_data.len() as u32;
79    sum += udp_len >> 16;
80    sum += udp_len & 0xFFFF;
81
82    // Three zero bytes + Next Header (4 bytes total)
83    sum += u32::from(UDP_PROTOCOL);
84
85    // Add UDP data
86    sum = partial_checksum(udp_data, sum);
87
88    let checksum = finalize_checksum(sum);
89
90    // RFC 2460: If computed checksum is 0, use 0xFFFF
91    if checksum == 0 { 0xFFFF } else { checksum }
92}
93
94/// Verify UDP checksum with IPv4 pseudo-header.
95#[must_use]
96pub fn verify_udp_checksum_ipv4(src_ip: Ipv4Addr, dst_ip: Ipv4Addr, udp_data: &[u8]) -> bool {
97    if udp_data.len() < 8 {
98        return false;
99    }
100
101    // Check if checksum is 0 (optional for IPv4)
102    let stored_checksum = u16::from_be_bytes([udp_data[6], udp_data[7]]);
103    if stored_checksum == 0 {
104        return true; // No checksum provided, considered valid
105    }
106
107    let calculated = udp_checksum_ipv4(src_ip, dst_ip, udp_data);
108
109    // The checksum should be 0xFFFF when calculated over data with valid checksum
110    calculated == 0xFFFF || calculated == stored_checksum
111}
112
113/// Verify UDP checksum with IPv6 pseudo-header.
114#[must_use]
115pub fn verify_udp_checksum_ipv6(src_ip: Ipv6Addr, dst_ip: Ipv6Addr, udp_data: &[u8]) -> bool {
116    if udp_data.len() < 8 {
117        return false;
118    }
119
120    let stored_checksum = u16::from_be_bytes([udp_data[6], udp_data[7]]);
121
122    // For IPv6, checksum is mandatory (cannot be 0)
123    if stored_checksum == 0 {
124        return false;
125    }
126
127    let calculated = udp_checksum_ipv6(src_ip, dst_ip, udp_data);
128
129    // The checksum should be 0xFFFF when calculated over data with valid checksum
130    calculated == 0xFFFF || calculated == stored_checksum
131}
132
#[cfg(test)]
mod tests {
    use super::*;

    /// Writes `checksum` big-endian into the UDP checksum field (bytes 6..8).
    fn set_checksum(udp_data: &mut [u8], checksum: u16) {
        udp_data[6..8].copy_from_slice(&checksum.to_be_bytes());
    }

    #[test]
    fn test_udp_checksum_ipv4() {
        let mut udp_data = [
            0x04, 0xd2, // sport = 1234
            0x00, 0x35, // dport = 53
            0x00, 0x0c, // length = 12
            0x00, 0x00, // checksum = 0 (to be calculated)
            0x01, 0x02, 0x03, 0x04, // payload
        ];

        let src_ip: Ipv4Addr = "192.168.1.1".parse().unwrap();
        let dst_ip: Ipv4Addr = "8.8.8.8".parse().unwrap();

        let checksum = udp_checksum_ipv4(src_ip, dst_ip, &udp_data);

        // Known answer: one's-complement sum is 0xDAEF, complemented to 0x2510.
        assert_eq!(checksum, 0x2510);

        set_checksum(&mut udp_data, checksum);
        assert!(verify_udp_checksum_ipv4(src_ip, dst_ip, &udp_data));
    }

    #[test]
    fn test_udp_checksum_ipv4_zero() {
        // With all-zero addresses, sport = 0xFFDE makes the one's-complement sum
        // exactly 0xFFFF (pseudo-header 0x0011 + 0x0008, header 0xFFDE + 0x0008).
        // The raw checksum is therefore 0 and must be reported as 0xFFFF.
        let mut udp_data = [
            0xff, 0xde, // sport = 0xFFDE
            0x00, 0x00, // dport
            0x00, 0x08, // length = 8 (header only)
            0x00, 0x00, // checksum placeholder
        ];

        let src_ip: Ipv4Addr = "0.0.0.0".parse().unwrap();
        let dst_ip: Ipv4Addr = "0.0.0.0".parse().unwrap();

        let checksum = udp_checksum_ipv4(src_ip, dst_ip, &udp_data);
        assert_eq!(checksum, 0xFFFF);

        // A stored 0xFFFF is the RFC 768 encoding of a zero checksum and must verify.
        set_checksum(&mut udp_data, checksum);
        assert!(verify_udp_checksum_ipv4(src_ip, dst_ip, &udp_data));
    }

    #[test]
    fn test_verify_udp_checksum_ipv4_corrupted() {
        let mut udp_data = [
            0x04, 0xd2, 0x00, 0x35, 0x00, 0x0c, 0x00, 0x00, 0x01, 0x02, 0x03, 0x04,
        ];
        let src_ip: Ipv4Addr = "192.168.1.1".parse().unwrap();
        let dst_ip: Ipv4Addr = "8.8.8.8".parse().unwrap();

        let checksum = udp_checksum_ipv4(src_ip, dst_ip, &udp_data);
        set_checksum(&mut udp_data, checksum);

        // Flip a payload bit after the checksum was computed.
        udp_data[11] ^= 0x01;
        assert!(!verify_udp_checksum_ipv4(src_ip, dst_ip, &udp_data));
    }

    #[test]
    fn test_udp_checksum_ipv6() {
        let mut udp_data = [
            0x04, 0xd2, // sport = 1234
            0x00, 0x35, // dport = 53
            0x00, 0x0c, // length = 12
            0x00, 0x00, // checksum = 0 (to be calculated)
            0x01, 0x02, 0x03, 0x04, // payload
        ];

        let src_ip: Ipv6Addr = "2001:db8::1".parse().unwrap();
        let dst_ip: Ipv6Addr = "2001:db8::2".parse().unwrap();

        let checksum = udp_checksum_ipv6(src_ip, dst_ip, &udp_data);

        // Checksum must never be 0 for IPv6; known answer is 0x9B54.
        assert_ne!(checksum, 0);
        assert_eq!(checksum, 0x9B54);

        set_checksum(&mut udp_data, checksum);
        assert!(verify_udp_checksum_ipv6(src_ip, dst_ip, &udp_data));
    }

    #[test]
    fn test_verify_udp_checksum_ipv6_corrupted() {
        let mut udp_data = [
            0x04, 0xd2, 0x00, 0x35, 0x00, 0x0c, 0x00, 0x00, 0x01, 0x02, 0x03, 0x04,
        ];
        let src_ip: Ipv6Addr = "2001:db8::1".parse().unwrap();
        let dst_ip: Ipv6Addr = "2001:db8::2".parse().unwrap();

        let checksum = udp_checksum_ipv6(src_ip, dst_ip, &udp_data);
        set_checksum(&mut udp_data, checksum);

        // Flip a payload bit after the checksum was computed.
        udp_data[11] ^= 0x01;
        assert!(!verify_udp_checksum_ipv6(src_ip, dst_ip, &udp_data));
    }

    #[test]
    fn test_verify_udp_checksum_ipv4_optional() {
        // IPv4 UDP checksum is optional; 0 means no checksum
        let udp_data = [
            0x04, 0xd2, // sport = 1234
            0x00, 0x35, // dport = 53
            0x00, 0x08, // length = 8
            0x00, 0x00, // checksum = 0 (no checksum)
        ];

        let src_ip: Ipv4Addr = "192.168.1.1".parse().unwrap();
        let dst_ip: Ipv4Addr = "8.8.8.8".parse().unwrap();

        // Should be valid even with 0 checksum
        assert!(verify_udp_checksum_ipv4(src_ip, dst_ip, &udp_data));
    }

    #[test]
    fn test_verify_udp_checksum_ipv6_mandatory() {
        // IPv6 UDP checksum is mandatory; 0 is invalid
        let udp_data = [
            0x04, 0xd2, // sport = 1234
            0x00, 0x35, // dport = 53
            0x00, 0x08, // length = 8
            0x00, 0x00, // checksum = 0 (invalid for IPv6)
        ];

        let src_ip: Ipv6Addr = "2001:db8::1".parse().unwrap();
        let dst_ip: Ipv6Addr = "2001:db8::2".parse().unwrap();

        // Should be invalid with 0 checksum
        assert!(!verify_udp_checksum_ipv6(src_ip, dst_ip, &udp_data));
    }

    #[test]
    fn test_verify_rejects_truncated_header() {
        let short = [0u8; 7];
        let v4: Ipv4Addr = "192.168.1.1".parse().unwrap();
        let v6: Ipv6Addr = "2001:db8::1".parse().unwrap();

        assert!(!verify_udp_checksum_ipv4(v4, v4, &short));
        assert!(!verify_udp_checksum_ipv6(v6, v6, &short));
    }
}