stackforge_core/layer/udp/
checksum.rs1use std::net::{Ipv4Addr, Ipv6Addr};
8
9use crate::utils::{finalize_checksum, partial_checksum};
10
/// IP protocol number assigned to UDP (17 decimal). It is placed in the
/// protocol / next-header slot of the pseudo-header for both the IPv4 and IPv6
/// checksum computations.
const UDP_PROTOCOL: u8 = 0x11;
13
14#[must_use]
24pub fn udp_checksum_ipv4(src_ip: Ipv4Addr, dst_ip: Ipv4Addr, udp_data: &[u8]) -> u16 {
25 let mut sum: u32 = 0;
26
27 for chunk in src_ip.octets().chunks(2) {
30 sum += u32::from(u16::from_be_bytes([chunk[0], chunk[1]]));
31 }
32
33 for chunk in dst_ip.octets().chunks(2) {
35 sum += u32::from(u16::from_be_bytes([chunk[0], chunk[1]]));
36 }
37
38 sum += u32::from(UDP_PROTOCOL);
40
41 sum += udp_data.len() as u32;
43
44 sum = partial_checksum(udp_data, sum);
46
47 let checksum = finalize_checksum(sum);
48
49 if checksum == 0 { 0xFFFF } else { checksum }
51}
52
53#[must_use]
63pub fn udp_checksum_ipv6(src_ip: Ipv6Addr, dst_ip: Ipv6Addr, udp_data: &[u8]) -> u16 {
64 let mut sum: u32 = 0;
65
66 for chunk in src_ip.octets().chunks(2) {
69 sum += u32::from(u16::from_be_bytes([chunk[0], chunk[1]]));
70 }
71
72 for chunk in dst_ip.octets().chunks(2) {
74 sum += u32::from(u16::from_be_bytes([chunk[0], chunk[1]]));
75 }
76
77 let udp_len = udp_data.len() as u32;
79 sum += udp_len >> 16;
80 sum += udp_len & 0xFFFF;
81
82 sum += u32::from(UDP_PROTOCOL);
84
85 sum = partial_checksum(udp_data, sum);
87
88 let checksum = finalize_checksum(sum);
89
90 if checksum == 0 { 0xFFFF } else { checksum }
92}
93
94#[must_use]
96pub fn verify_udp_checksum_ipv4(src_ip: Ipv4Addr, dst_ip: Ipv4Addr, udp_data: &[u8]) -> bool {
97 if udp_data.len() < 8 {
98 return false;
99 }
100
101 let stored_checksum = u16::from_be_bytes([udp_data[6], udp_data[7]]);
103 if stored_checksum == 0 {
104 return true; }
106
107 let calculated = udp_checksum_ipv4(src_ip, dst_ip, udp_data);
108
109 calculated == 0xFFFF || calculated == stored_checksum
111}
112
113#[must_use]
115pub fn verify_udp_checksum_ipv6(src_ip: Ipv6Addr, dst_ip: Ipv6Addr, udp_data: &[u8]) -> bool {
116 if udp_data.len() < 8 {
117 return false;
118 }
119
120 let stored_checksum = u16::from_be_bytes([udp_data[6], udp_data[7]]);
121
122 if stored_checksum == 0 {
124 return false;
125 }
126
127 let calculated = udp_checksum_ipv6(src_ip, dst_ip, udp_data);
128
129 calculated == 0xFFFF || calculated == stored_checksum
131}
132
133#[cfg(test)]
134mod tests {
135 use super::*;
136
137 #[test]
138 fn test_udp_checksum_ipv4() {
139 let mut udp_data = vec![
141 0x04, 0xd2, 0x00, 0x35, 0x00, 0x0c, 0x00, 0x00, 0x01, 0x02, 0x03, 0x04, ];
147
148 let src_ip: Ipv4Addr = "192.168.1.1".parse().unwrap();
149 let dst_ip: Ipv4Addr = "8.8.8.8".parse().unwrap();
150
151 let checksum = udp_checksum_ipv4(src_ip, dst_ip, &udp_data);
152
153 udp_data[6] = (checksum >> 8) as u8;
155 udp_data[7] = (checksum & 0xFF) as u8;
156
157 assert!(verify_udp_checksum_ipv4(src_ip, dst_ip, &udp_data));
159 }
160
161 #[test]
162 fn test_udp_checksum_ipv4_zero() {
163 let udp_data = vec![
166 0x00, 0x00, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, ];
171
172 let src_ip: Ipv4Addr = "0.0.0.0".parse().unwrap();
173 let dst_ip: Ipv4Addr = "0.0.0.0".parse().unwrap();
174
175 let checksum = udp_checksum_ipv4(src_ip, dst_ip, &udp_data);
176
177 assert_ne!(checksum, 0);
179 }
180
181 #[test]
182 fn test_udp_checksum_ipv6() {
183 let mut udp_data = vec![
184 0x04, 0xd2, 0x00, 0x35, 0x00, 0x0c, 0x00, 0x00, 0x01, 0x02, 0x03, 0x04, ];
190
191 let src_ip: Ipv6Addr = "2001:db8::1".parse().unwrap();
192 let dst_ip: Ipv6Addr = "2001:db8::2".parse().unwrap();
193
194 let checksum = udp_checksum_ipv6(src_ip, dst_ip, &udp_data);
195
196 assert_ne!(checksum, 0);
198
199 udp_data[6] = (checksum >> 8) as u8;
201 udp_data[7] = (checksum & 0xFF) as u8;
202
203 assert!(verify_udp_checksum_ipv6(src_ip, dst_ip, &udp_data));
205 }
206
207 #[test]
208 fn test_verify_udp_checksum_ipv4_optional() {
209 let udp_data = vec![
211 0x04, 0xd2, 0x00, 0x35, 0x00, 0x08, 0x00, 0x00, ];
216
217 let src_ip: Ipv4Addr = "192.168.1.1".parse().unwrap();
218 let dst_ip: Ipv4Addr = "8.8.8.8".parse().unwrap();
219
220 assert!(verify_udp_checksum_ipv4(src_ip, dst_ip, &udp_data));
222 }
223
224 #[test]
225 fn test_verify_udp_checksum_ipv6_mandatory() {
226 let udp_data = vec![
228 0x04, 0xd2, 0x00, 0x35, 0x00, 0x08, 0x00, 0x00, ];
233
234 let src_ip: Ipv6Addr = "2001:db8::1".parse().unwrap();
235 let dst_ip: Ipv6Addr = "2001:db8::2".parse().unwrap();
236
237 assert!(!verify_udp_checksum_ipv6(src_ip, dst_ip, &udp_data));
239 }
240}