stackforge_core/layer/udp/
checksum.rs1use std::net::{Ipv4Addr, Ipv6Addr};
8
9use crate::utils::{finalize_checksum, partial_checksum};
10
/// IP protocol number of UDP (decimal 17). It is summed into the protocol /
/// Next Header slot of both the IPv4 and the IPv6 pseudo-header below.
const UDP_PROTOCOL: u8 = 0x11;

14pub fn udp_checksum_ipv4(src_ip: Ipv4Addr, dst_ip: Ipv4Addr, udp_data: &[u8]) -> u16 {
24 let mut sum: u32 = 0;
25
26 for chunk in src_ip.octets().chunks(2) {
29 sum += u16::from_be_bytes([chunk[0], chunk[1]]) as u32;
30 }
31
32 for chunk in dst_ip.octets().chunks(2) {
34 sum += u16::from_be_bytes([chunk[0], chunk[1]]) as u32;
35 }
36
37 sum += UDP_PROTOCOL as u32;
39
40 sum += udp_data.len() as u32;
42
43 sum = partial_checksum(udp_data, sum);
45
46 let checksum = finalize_checksum(sum);
47
48 if checksum == 0 { 0xFFFF } else { checksum }
50}
51
52pub fn udp_checksum_ipv6(src_ip: Ipv6Addr, dst_ip: Ipv6Addr, udp_data: &[u8]) -> u16 {
62 let mut sum: u32 = 0;
63
64 for chunk in src_ip.octets().chunks(2) {
67 sum += u16::from_be_bytes([chunk[0], chunk[1]]) as u32;
68 }
69
70 for chunk in dst_ip.octets().chunks(2) {
72 sum += u16::from_be_bytes([chunk[0], chunk[1]]) as u32;
73 }
74
75 let udp_len = udp_data.len() as u32;
77 sum += (udp_len >> 16) as u32;
78 sum += (udp_len & 0xFFFF) as u32;
79
80 sum += UDP_PROTOCOL as u32;
82
83 sum = partial_checksum(udp_data, sum);
85
86 let checksum = finalize_checksum(sum);
87
88 if checksum == 0 { 0xFFFF } else { checksum }
90}
91
92pub fn verify_udp_checksum_ipv4(src_ip: Ipv4Addr, dst_ip: Ipv4Addr, udp_data: &[u8]) -> bool {
94 if udp_data.len() < 8 {
95 return false;
96 }
97
98 let stored_checksum = u16::from_be_bytes([udp_data[6], udp_data[7]]);
100 if stored_checksum == 0 {
101 return true; }
103
104 let calculated = udp_checksum_ipv4(src_ip, dst_ip, udp_data);
105
106 calculated == 0xFFFF || calculated == stored_checksum
108}
109
110pub fn verify_udp_checksum_ipv6(src_ip: Ipv6Addr, dst_ip: Ipv6Addr, udp_data: &[u8]) -> bool {
112 if udp_data.len() < 8 {
113 return false;
114 }
115
116 let stored_checksum = u16::from_be_bytes([udp_data[6], udp_data[7]]);
117
118 if stored_checksum == 0 {
120 return false;
121 }
122
123 let calculated = udp_checksum_ipv6(src_ip, dst_ip, udp_data);
124
125 calculated == 0xFFFF || calculated == stored_checksum
127}
128
#[cfg(test)]
mod tests {
    use super::*;

    /// Writes `checksum` big-endian into the UDP checksum field (bytes 6..8).
    fn store_checksum(udp_data: &mut [u8], checksum: u16) {
        udp_data[6..8].copy_from_slice(&checksum.to_be_bytes());
    }

    /// UDP header (src port 1234, dst port 53, length 12, checksum 0) followed
    /// by a 4-byte payload; shared by the IPv4 and IPv6 tests.
    fn sample_datagram() -> Vec<u8> {
        vec![
            0x04, 0xd2, 0x00, 0x35, 0x00, 0x0c, 0x00, 0x00, // header
            0x01, 0x02, 0x03, 0x04, // payload
        ]
    }

    #[test]
    fn test_udp_checksum_ipv4() {
        let mut udp_data = sample_datagram();
        let src_ip: Ipv4Addr = "192.168.1.1".parse().unwrap();
        let dst_ip: Ipv4Addr = "8.8.8.8".parse().unwrap();

        let checksum = udp_checksum_ipv4(src_ip, dst_ip, &udp_data);
        // Hand-computed: the one's-complement sum of pseudo-header and datagram
        // is 0xDAEF, whose complement is 0x2510.
        assert_eq!(checksum, 0x2510);

        store_checksum(&mut udp_data, checksum);
        assert!(verify_udp_checksum_ipv4(src_ip, dst_ip, &udp_data));
    }

    #[test]
    fn test_udp_checksum_ipv4_detects_corruption() {
        let mut udp_data = sample_datagram();
        let src_ip: Ipv4Addr = "192.168.1.1".parse().unwrap();
        let dst_ip: Ipv4Addr = "8.8.8.8".parse().unwrap();

        let checksum = udp_checksum_ipv4(src_ip, dst_ip, &udp_data);
        store_checksum(&mut udp_data, checksum);
        // Flip one payload bit after the checksum has been stored.
        udp_data[11] ^= 0x01;

        assert!(!verify_udp_checksum_ipv4(src_ip, dst_ip, &udp_data));
    }

    #[test]
    fn test_udp_checksum_ipv4_zero() {
        let udp_data = vec![0x00, 0x00, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00];

        let src_ip: Ipv4Addr = "0.0.0.0".parse().unwrap();
        let dst_ip: Ipv4Addr = "0.0.0.0".parse().unwrap();

        let checksum = udp_checksum_ipv4(src_ip, dst_ip, &udp_data);

        assert_ne!(checksum, 0);
    }

    #[test]
    fn test_udp_checksum_ipv4_zero_result_sent_as_ffff() {
        // Source port 0xFFDE makes pseudo-header + datagram sum to exactly 0xFFFF,
        // so the complemented checksum is 0 and must be reported as 0xFFFF.
        let mut udp_data = vec![0xff, 0xde, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00];
        let src_ip: Ipv4Addr = "0.0.0.0".parse().unwrap();
        let dst_ip: Ipv4Addr = "0.0.0.0".parse().unwrap();

        let checksum = udp_checksum_ipv4(src_ip, dst_ip, &udp_data);
        assert_eq!(checksum, 0xFFFF);

        store_checksum(&mut udp_data, checksum);
        assert!(verify_udp_checksum_ipv4(src_ip, dst_ip, &udp_data));
    }

    #[test]
    fn test_udp_checksum_ipv6() {
        let mut udp_data = sample_datagram();
        let src_ip: Ipv6Addr = "2001:db8::1".parse().unwrap();
        let dst_ip: Ipv6Addr = "2001:db8::2".parse().unwrap();

        let checksum = udp_checksum_ipv6(src_ip, dst_ip, &udp_data);
        // Hand-computed: the one's-complement sum is 0x64AB, complement 0x9B54.
        assert_eq!(checksum, 0x9B54);

        store_checksum(&mut udp_data, checksum);
        assert!(verify_udp_checksum_ipv6(src_ip, dst_ip, &udp_data));
    }

    #[test]
    fn test_udp_checksum_ipv6_detects_corruption() {
        let mut udp_data = sample_datagram();
        let src_ip: Ipv6Addr = "2001:db8::1".parse().unwrap();
        let dst_ip: Ipv6Addr = "2001:db8::2".parse().unwrap();

        let checksum = udp_checksum_ipv6(src_ip, dst_ip, &udp_data);
        store_checksum(&mut udp_data, checksum);
        // Flip one payload bit after the checksum has been stored.
        udp_data[11] ^= 0x01;

        assert!(!verify_udp_checksum_ipv6(src_ip, dst_ip, &udp_data));
    }

    #[test]
    fn test_verify_udp_checksum_ipv4_optional() {
        // Checksum field 0 over IPv4 means "not computed" and is accepted.
        let udp_data = vec![0x04, 0xd2, 0x00, 0x35, 0x00, 0x08, 0x00, 0x00];

        let src_ip: Ipv4Addr = "192.168.1.1".parse().unwrap();
        let dst_ip: Ipv4Addr = "8.8.8.8".parse().unwrap();

        assert!(verify_udp_checksum_ipv4(src_ip, dst_ip, &udp_data));
    }

    #[test]
    fn test_verify_udp_checksum_ipv6_mandatory() {
        // Checksum field 0 over IPv6 is invalid and must be rejected.
        let udp_data = vec![0x04, 0xd2, 0x00, 0x35, 0x00, 0x08, 0x00, 0x00];

        let src_ip: Ipv6Addr = "2001:db8::1".parse().unwrap();
        let dst_ip: Ipv6Addr = "2001:db8::2".parse().unwrap();

        assert!(!verify_udp_checksum_ipv6(src_ip, dst_ip, &udp_data));
    }

    #[test]
    fn test_verify_rejects_truncated_header() {
        // Seven bytes cannot hold the 8-byte UDP header.
        let short = [0u8; 7];
        let v4: Ipv4Addr = "192.168.1.1".parse().unwrap();
        let v6: Ipv6Addr = "2001:db8::1".parse().unwrap();

        assert!(!verify_udp_checksum_ipv4(v4, v4, &short));
        assert!(!verify_udp_checksum_ipv6(v6, v6, &short));
    }
}