Skip to main content

stackforge_core/layer/udp/
checksum.rs

1//! UDP checksum calculation and verification.
2//!
3//! UDP uses the same pseudo-header as TCP for checksum calculation.
4//! For IPv4, the checksum is optional (can be 0), but for IPv6 it's mandatory.
5//! Per RFC 768, if the computed checksum is 0, it should be set to 0xFFFF.
6
7use std::net::{Ipv4Addr, Ipv6Addr};
8
9use crate::utils::{finalize_checksum, partial_checksum};
10
/// IP protocol number assigned to UDP (decimal 17). It fills the protocol
/// field of the IPv4 pseudo-header and the next-header field of the IPv6
/// pseudo-header.
const UDP_PROTOCOL: u8 = 0x11;
13
14/// Calculate UDP checksum with IPv4 pseudo-header.
15///
16/// # Arguments
17/// * `src_ip` - Source IPv4 address
18/// * `dst_ip` - Destination IPv4 address
19/// * `udp_data` - Complete UDP packet (header + payload)
20///
21/// # Returns
22/// The calculated checksum. Per RFC 768, if result is 0, returns 0xFFFF.
23pub fn udp_checksum_ipv4(src_ip: Ipv4Addr, dst_ip: Ipv4Addr, udp_data: &[u8]) -> u16 {
24    let mut sum: u32 = 0;
25
26    // Add IPv4 pseudo-header
27    // Source IP (4 bytes)
28    for chunk in src_ip.octets().chunks(2) {
29        sum += u16::from_be_bytes([chunk[0], chunk[1]]) as u32;
30    }
31
32    // Destination IP (4 bytes)
33    for chunk in dst_ip.octets().chunks(2) {
34        sum += u16::from_be_bytes([chunk[0], chunk[1]]) as u32;
35    }
36
37    // Zero + Protocol (2 bytes): 0x00 (zero) + 0x11 (UDP) = 0x0011 as u16
38    sum += UDP_PROTOCOL as u32;
39
40    // UDP Length (2 bytes)
41    sum += udp_data.len() as u32;
42
43    // Add UDP data
44    sum = partial_checksum(udp_data, sum);
45
46    let checksum = finalize_checksum(sum);
47
48    // RFC 768: If computed checksum is 0, use 0xFFFF
49    if checksum == 0 { 0xFFFF } else { checksum }
50}
51
52/// Calculate UDP checksum with IPv6 pseudo-header.
53///
54/// # Arguments
55/// * `src_ip` - Source IPv6 address
56/// * `dst_ip` - Destination IPv6 address
57/// * `udp_data` - Complete UDP packet (header + payload)
58///
59/// # Returns
60/// The calculated checksum. Per RFC 2460, if result is 0, returns 0xFFFF.
61pub fn udp_checksum_ipv6(src_ip: Ipv6Addr, dst_ip: Ipv6Addr, udp_data: &[u8]) -> u16 {
62    let mut sum: u32 = 0;
63
64    // Add IPv6 pseudo-header
65    // Source IP (16 bytes)
66    for chunk in src_ip.octets().chunks(2) {
67        sum += u16::from_be_bytes([chunk[0], chunk[1]]) as u32;
68    }
69
70    // Destination IP (16 bytes)
71    for chunk in dst_ip.octets().chunks(2) {
72        sum += u16::from_be_bytes([chunk[0], chunk[1]]) as u32;
73    }
74
75    // UDP Length (4 bytes, big-endian)
76    let udp_len = udp_data.len() as u32;
77    sum += (udp_len >> 16) as u32;
78    sum += (udp_len & 0xFFFF) as u32;
79
80    // Three zero bytes + Next Header (4 bytes total)
81    sum += UDP_PROTOCOL as u32;
82
83    // Add UDP data
84    sum = partial_checksum(udp_data, sum);
85
86    let checksum = finalize_checksum(sum);
87
88    // RFC 2460: If computed checksum is 0, use 0xFFFF
89    if checksum == 0 { 0xFFFF } else { checksum }
90}
91
92/// Verify UDP checksum with IPv4 pseudo-header.
93pub fn verify_udp_checksum_ipv4(src_ip: Ipv4Addr, dst_ip: Ipv4Addr, udp_data: &[u8]) -> bool {
94    if udp_data.len() < 8 {
95        return false;
96    }
97
98    // Check if checksum is 0 (optional for IPv4)
99    let stored_checksum = u16::from_be_bytes([udp_data[6], udp_data[7]]);
100    if stored_checksum == 0 {
101        return true; // No checksum provided, considered valid
102    }
103
104    let calculated = udp_checksum_ipv4(src_ip, dst_ip, udp_data);
105
106    // The checksum should be 0xFFFF when calculated over data with valid checksum
107    calculated == 0xFFFF || calculated == stored_checksum
108}
109
110/// Verify UDP checksum with IPv6 pseudo-header.
111pub fn verify_udp_checksum_ipv6(src_ip: Ipv6Addr, dst_ip: Ipv6Addr, udp_data: &[u8]) -> bool {
112    if udp_data.len() < 8 {
113        return false;
114    }
115
116    let stored_checksum = u16::from_be_bytes([udp_data[6], udp_data[7]]);
117
118    // For IPv6, checksum is mandatory (cannot be 0)
119    if stored_checksum == 0 {
120        return false;
121    }
122
123    let calculated = udp_checksum_ipv6(src_ip, dst_ip, udp_data);
124
125    // The checksum should be 0xFFFF when calculated over data with valid checksum
126    calculated == 0xFFFF || calculated == stored_checksum
127}
128
#[cfg(test)]
mod tests {
    use super::*;

    /// Store `checksum` big-endian in the UDP checksum field (bytes 6..8).
    fn set_checksum(udp_data: &mut [u8], checksum: u16) {
        udp_data[6..8].copy_from_slice(&checksum.to_be_bytes());
    }

    /// 12-byte datagram with sport 1234, dport 53, length 12, a zeroed
    /// checksum field, and payload 01 02 03 04.
    fn sample_datagram() -> Vec<u8> {
        vec![
            0x04, 0xd2, // sport = 1234
            0x00, 0x35, // dport = 53
            0x00, 0x0c, // length = 12
            0x00, 0x00, // checksum = 0 (to be calculated)
            0x01, 0x02, 0x03, 0x04, // payload
        ]
    }

    #[test]
    fn test_udp_checksum_ipv4() {
        let mut udp_data = sample_datagram();

        let src_ip: Ipv4Addr = "192.168.1.1".parse().unwrap();
        let dst_ip: Ipv4Addr = "8.8.8.8".parse().unwrap();

        let checksum = udp_checksum_ipv4(src_ip, dst_ip, &udp_data);

        // Hand-computed: the one's-complement sum of pseudo-header and datagram
        // is 0xDAEF, so the checksum is !0xDAEF = 0x2510.
        assert_eq!(checksum, 0x2510);

        set_checksum(&mut udp_data, checksum);
        assert!(verify_udp_checksum_ipv4(src_ip, dst_ip, &udp_data));
    }

    #[test]
    fn test_udp_checksum_ipv4_zero() {
        // Choose fields whose one's-complement sum is exactly 0xFFFF, so the
        // raw checksum is 0. With 0.0.0.0 addresses:
        // * the pseudo-header contributes 0x0011 (protocol) + 0x0008 (length);
        // * the header contributes 0x0008 (length) + 0xFFDE (sport);
        // for a total of 0xFFFF.
        let mut udp_data = vec![
            0xff, 0xde, // sport = 0xFFDE
            0x00, 0x00, // dport = 0
            0x00, 0x08, // length = 8 (header only)
            0x00, 0x00, // checksum placeholder
        ];

        let src_ip = Ipv4Addr::UNSPECIFIED;
        let dst_ip = Ipv4Addr::UNSPECIFIED;

        let checksum = udp_checksum_ipv4(src_ip, dst_ip, &udp_data);

        // RFC 768: a computed 0 must be transmitted as 0xFFFF.
        assert_eq!(checksum, 0xFFFF);

        set_checksum(&mut udp_data, checksum);
        assert!(verify_udp_checksum_ipv4(src_ip, dst_ip, &udp_data));
    }

    #[test]
    fn test_verify_udp_checksum_ipv4_rejects_corruption() {
        let mut udp_data = sample_datagram();

        let src_ip: Ipv4Addr = "192.168.1.1".parse().unwrap();
        let dst_ip: Ipv4Addr = "8.8.8.8".parse().unwrap();

        let checksum = udp_checksum_ipv4(src_ip, dst_ip, &udp_data);
        set_checksum(&mut udp_data, checksum);

        // Flip one bit of the payload after the checksum was computed.
        udp_data[11] ^= 0x01;
        assert!(!verify_udp_checksum_ipv4(src_ip, dst_ip, &udp_data));
    }

    #[test]
    fn test_udp_checksum_ipv6() {
        let mut udp_data = sample_datagram();

        let src_ip: Ipv6Addr = "2001:db8::1".parse().unwrap();
        let dst_ip: Ipv6Addr = "2001:db8::2".parse().unwrap();

        let checksum = udp_checksum_ipv6(src_ip, dst_ip, &udp_data);

        // Checksum should not be 0 for IPv6.
        assert_ne!(checksum, 0);

        // Hand-computed: the one's-complement sum of pseudo-header and datagram
        // is 0x64AB, so the checksum is !0x64AB = 0x9B54.
        assert_eq!(checksum, 0x9B54);

        set_checksum(&mut udp_data, checksum);
        assert!(verify_udp_checksum_ipv6(src_ip, dst_ip, &udp_data));
    }

    #[test]
    fn test_verify_udp_checksum_ipv6_rejects_corruption() {
        let mut udp_data = sample_datagram();

        let src_ip: Ipv6Addr = "2001:db8::1".parse().unwrap();
        let dst_ip: Ipv6Addr = "2001:db8::2".parse().unwrap();

        let checksum = udp_checksum_ipv6(src_ip, dst_ip, &udp_data);
        set_checksum(&mut udp_data, checksum);

        // Flip one bit of the payload after the checksum was computed.
        udp_data[11] ^= 0x01;
        assert!(!verify_udp_checksum_ipv6(src_ip, dst_ip, &udp_data));
    }

    #[test]
    fn test_verify_udp_checksum_ipv4_optional() {
        // IPv4 UDP checksum is optional; 0 means no checksum.
        let udp_data = vec![
            0x04, 0xd2, // sport = 1234
            0x00, 0x35, // dport = 53
            0x00, 0x08, // length = 8
            0x00, 0x00, // checksum = 0 (no checksum)
        ];

        let src_ip: Ipv4Addr = "192.168.1.1".parse().unwrap();
        let dst_ip: Ipv4Addr = "8.8.8.8".parse().unwrap();

        // Should be valid even with 0 checksum.
        assert!(verify_udp_checksum_ipv4(src_ip, dst_ip, &udp_data));
    }

    #[test]
    fn test_verify_udp_checksum_ipv6_mandatory() {
        // IPv6 UDP checksum is mandatory; 0 is invalid.
        let udp_data = vec![
            0x04, 0xd2, // sport = 1234
            0x00, 0x35, // dport = 53
            0x00, 0x08, // length = 8
            0x00, 0x00, // checksum = 0 (invalid for IPv6)
        ];

        let src_ip: Ipv6Addr = "2001:db8::1".parse().unwrap();
        let dst_ip: Ipv6Addr = "2001:db8::2".parse().unwrap();

        // Should be invalid with 0 checksum.
        assert!(!verify_udp_checksum_ipv6(src_ip, dst_ip, &udp_data));
    }
}