1use crate::utils::internet_checksum;
7
/// Computes the IPv4 header checksum over `data`.
///
/// This is a direct delegation to `internet_checksum`. As the tests in this file do, pass the
/// header with its checksum field (bytes 10..12) zeroed; the returned value is the one to store
/// back into that field in big-endian order.
#[inline]
pub fn ipv4_checksum(data: &[u8]) -> u16 {
    internet_checksum(data)
}
35
36#[inline]
44pub fn verify_ipv4_checksum(data: &[u8]) -> bool {
45 let sum = internet_checksum(data);
46 sum == 0 || sum == 0xFFFF
47}
48
/// Incrementally updates a 16-bit one's-complement checksum after a single 16-bit word of the
/// covered data changes from `old_value` to `new_value`.
///
/// Implements `HC' = ~(~HC + ~m + m')` (RFC 1624, eqn. 3). This avoids re-summing the whole
/// buffer when, for example, the TTL/protocol word of an IPv4 header is rewritten.
#[inline]
pub fn incremental_update_checksum(old_checksum: u16, old_value: u16, new_value: u16) -> u16 {
    // At most 3 * 0xFFFF, so the u32 accumulator cannot overflow.
    let mut acc: u32 = [!old_checksum, !old_value, new_value]
        .iter()
        .map(|&word| u32::from(word))
        .sum();

    // End-around carry: fold everything above bit 15 back into the low 16 bits.
    while acc > 0xFFFF {
        acc = (acc >> 16) + (acc & 0xFFFF);
    }

    !(acc as u16)
}
79
80#[inline]
84pub fn incremental_update_checksum_32(old_checksum: u16, old_value: u32, new_value: u32) -> u16 {
85 let old_high = (old_value >> 16) as u16;
87 let old_low = (old_value & 0xFFFF) as u16;
88 let new_high = (new_value >> 16) as u16;
89 let new_low = (new_value & 0xFFFF) as u16;
90
91 let tmp = incremental_update_checksum(old_checksum, old_high, new_high);
92 incremental_update_checksum(tmp, old_low, new_low)
93}
94
/// Returns the *unfolded* 32-bit sum of the IPv4 pseudo-header's 16-bit words.
///
/// The words summed are the source address (two words), the destination address (two words),
/// the zero-padded protocol byte and `transport_len`. The result is neither folded nor
/// complemented, so callers can keep adding transport data before finishing the checksum.
pub fn pseudo_header_checksum(
    src_ip: &[u8; 4],
    dst_ip: &[u8; 4],
    protocol: u8,
    transport_len: u16,
) -> u32 {
    // Each address contributes two big-endian 16-bit words.
    let address_sum: u32 = src_ip
        .chunks_exact(2)
        .chain(dst_ip.chunks_exact(2))
        .map(|pair| u32::from(u16::from_be_bytes([pair[0], pair[1]])))
        .sum();

    // The zero byte followed by the protocol byte forms a word whose value is just `protocol`.
    address_sum + u32::from(protocol) + u32::from(transport_len)
}
138
139pub fn transport_checksum(
152 src_ip: &[u8; 4],
153 dst_ip: &[u8; 4],
154 protocol: u8,
155 transport_data: &[u8],
156) -> u16 {
157 let mut sum = pseudo_header_checksum(src_ip, dst_ip, protocol, transport_data.len() as u16);
159
160 let mut chunks = transport_data.chunks_exact(2);
162 for chunk in chunks.by_ref() {
163 sum += u16::from_be_bytes([chunk[0], chunk[1]]) as u32;
164 }
165
166 if let Some(&last) = chunks.remainder().first() {
168 sum += (last as u32) << 8;
169 }
170
171 while (sum >> 16) != 0 {
173 sum = (sum & 0xFFFF) + (sum >> 16);
174 }
175
176 let result = !sum as u16;
178 if result == 0 && protocol == 17 {
179 0xFFFF
180 } else {
181 result
182 }
183}
184
185pub use crate::utils::{finalize_checksum, partial_checksum};
188
/// Clears the 2-byte checksum field starting at `offset` in `buf`.
///
/// Does nothing if the field does not fit entirely inside `buf`. This includes offsets so
/// large that `offset + 2` would overflow `usize`.
#[inline]
pub fn zero_checksum(buf: &mut [u8], offset: usize) {
    // The two-step `get_mut` bounds-checks without ever computing `offset + 2`.
    if let Some(field) = buf.get_mut(offset..).and_then(|tail| tail.get_mut(..2)) {
        field.fill(0);
    }
}
197
/// Writes `checksum` in network (big-endian) byte order to `buf[offset..offset + 2]`.
///
/// Does nothing if the field does not fit entirely inside `buf`. This includes offsets so
/// large that `offset + 2` would overflow `usize`.
#[inline]
pub fn write_checksum(buf: &mut [u8], offset: usize, checksum: u16) {
    // The two-step `get_mut` bounds-checks without ever computing `offset + 2`.
    if let Some(field) = buf.get_mut(offset..).and_then(|tail| tail.get_mut(..2)) {
        field.copy_from_slice(&checksum.to_be_bytes());
    }
}
207
/// Reads the big-endian 16-bit checksum stored at `buf[offset..offset + 2]`.
///
/// Returns `None` if the field does not fit entirely inside `buf`. This includes offsets so
/// large that `offset + 2` would overflow `usize`.
#[inline]
pub fn read_checksum(buf: &[u8], offset: usize) -> Option<u16> {
    // The two-step `get` bounds-checks without ever computing `offset + 2`.
    let field = buf.get(offset..)?.get(..2)?;
    Some(u16::from_be_bytes([field[0], field[1]]))
}
217
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_ipv4_checksum() {
        // Sample IPv4 header with the checksum field (bytes 10..12) zeroed.
        let header = [
            0x45, 0x00, 0x00, 0x3c, 0x1c, 0x46, 0x40, 0x00, 0x40, 0x06, 0x00, 0x00, 0xac, 0x10,
            0x0a, 0x63, 0xac, 0x10, 0x0a, 0x0c,
        ];

        let checksum = ipv4_checksum(&header);
        // Hand-computed one's-complement checksum for this header.
        assert_eq!(checksum, 0xb1e6);

        let mut header_with_cksum = header;
        header_with_cksum[10..12].copy_from_slice(&checksum.to_be_bytes());
        assert!(verify_ipv4_checksum(&header_with_cksum));
    }

    #[test]
    fn test_verify_valid_checksum() {
        let header = [
            0x45, 0x00, 0x00, 0x3c, 0x1c, 0x46, 0x40, 0x00, 0x40, 0x06, 0xb1, 0xe6, 0xac, 0x10,
            0x0a, 0x63, 0xac, 0x10, 0x0a, 0x0c,
        ];

        assert!(verify_ipv4_checksum(&header));
    }

    #[test]
    fn test_verify_invalid_checksum() {
        let header = [
            0x45, 0x00, 0x00, 0x3c, 0x1c, 0x46, 0x40, 0x00, 0x40, 0x06, 0xFF, 0xFF, 0xac, 0x10,
            0x0a, 0x63, 0xac, 0x10, 0x0a, 0x0c,
        ];

        assert!(!verify_ipv4_checksum(&header));
    }

    #[test]
    fn test_incremental_update() {
        let mut header = [
            0x45, 0x00, 0x00, 0x3c, 0x1c, 0x46, 0x40, 0x00, 0x40, 0x06, 0x00, 0x00, 0xac, 0x10,
            0x0a, 0x63, 0xac, 0x10, 0x0a, 0x0c,
        ];

        let initial_checksum = ipv4_checksum(&header);
        header[10..12].copy_from_slice(&initial_checksum.to_be_bytes());

        // Decrement the TTL (byte 8), which shares a 16-bit word with the protocol byte.
        let old_ttl_word = u16::from_be_bytes([header[8], header[9]]);
        header[8] = 0x3F;
        let new_ttl_word = u16::from_be_bytes([header[8], header[9]]);

        let new_checksum =
            incremental_update_checksum(initial_checksum, old_ttl_word, new_ttl_word);
        header[10..12].copy_from_slice(&new_checksum.to_be_bytes());
        assert!(verify_ipv4_checksum(&header));

        // The incremental result must match a full recomputation.
        let mut zeroed = header;
        zeroed[10] = 0;
        zeroed[11] = 0;
        assert_eq!(new_checksum, ipv4_checksum(&zeroed));
        assert_eq!(new_checksum, 0xb2e6);
    }

    #[test]
    fn test_incremental_update_32() {
        let mut header = [
            0x45, 0x00, 0x00, 0x3c, 0x1c, 0x46, 0x40, 0x00, 0x40, 0x06, 0xb1, 0xe6, 0xac, 0x10,
            0x0a, 0x63, 0xac, 0x10, 0x0a, 0x0c,
        ];

        // Rewrite the source address (bytes 12..16) to 192.168.1.1.
        let old_src = u32::from_be_bytes([header[12], header[13], header[14], header[15]]);
        let new_src: u32 = 0xc0a8_0101;
        header[12..16].copy_from_slice(&new_src.to_be_bytes());

        let new_checksum = incremental_update_checksum_32(0xb1e6, old_src, new_src);
        // Hand-computed full recomputation for the rewritten header.
        assert_eq!(new_checksum, 0xa6b0);

        header[10..12].copy_from_slice(&new_checksum.to_be_bytes());
        assert!(verify_ipv4_checksum(&header));
    }

    #[test]
    fn test_pseudo_header_checksum() {
        let src = [192, 168, 1, 1];
        let dst = [192, 168, 1, 2];
        let protocol = 6;
        let length = 20;

        let sum = pseudo_header_checksum(&src, &dst, protocol, length);

        // 0xc0a8 + 0x0101 + 0xc0a8 + 0x0102 + 6 + 20, unfolded.
        assert_eq!(sum, 0x1836d);
    }

    #[test]
    fn test_transport_checksum_tcp() {
        let src_ip = [192, 168, 1, 1];
        let dst_ip = [192, 168, 1, 2];

        let tcp_header = [
            0x00, 0x50, // source port 80
            0x1F, 0x90, // destination port 8080
            0x00, 0x00, 0x00, 0x01, // sequence number
            0x00, 0x00, 0x00, 0x00, // acknowledgment number
            0x50, 0x02, // data offset 5, SYN
            0xFF, 0xFF, // window
            0x00, 0x00, // checksum (zeroed)
            0x00, 0x00, // urgent pointer
        ];

        let checksum = transport_checksum(&src_ip, &dst_ip, 6, &tcp_header);
        assert_eq!(checksum, 0x0cae);

        // Re-summing with the checksum stored in place (bytes 16..18) must give zero.
        let mut with_cksum = tcp_header;
        with_cksum[16..18].copy_from_slice(&checksum.to_be_bytes());
        assert_eq!(transport_checksum(&src_ip, &dst_ip, 6, &with_cksum), 0);
    }

    #[test]
    fn test_transport_checksum_udp_zero() {
        let src_ip = [0, 0, 0, 0];
        let dst_ip = [0, 0, 0, 0];

        // Pseudo-header (17 + 8) plus the length word 0x0008 sums to 0x0021.
        let udp_header = [
            0x00, 0x00, // source port
            0x00, 0x00, // destination port
            0x00, 0x08, // length
            0x00, 0x00, // checksum (zeroed)
        ];
        assert_eq!(transport_checksum(&src_ip, &dst_ip, 17, &udp_header), 0xFFDE);

        // With 0xFFDE stored, the folded sum is 0xFFFF and the raw checksum is 0x0000.
        // UDP must report that as 0xFFFF, because zero means "no checksum" (RFC 768).
        let udp_with_cksum = [0x00, 0x00, 0x00, 0x00, 0x00, 0x08, 0xFF, 0xDE];
        assert_eq!(transport_checksum(&src_ip, &dst_ip, 17, &udp_with_cksum), 0xFFFF);

        // Other protocols keep a zero result: (6 + 8) + 0x0008 + 0xFFE9 = 0xFFFF.
        let tcp_like = [0x00, 0x00, 0x00, 0x00, 0x00, 0x08, 0xFF, 0xE9];
        assert_eq!(transport_checksum(&src_ip, &dst_ip, 6, &tcp_like), 0);
    }

    #[test]
    fn test_partial_checksum() {
        let data1 = [0x01, 0x02, 0x03, 0x04];
        let data2 = [0x05, 0x06, 0x07, 0x08];

        let sum1 = partial_checksum(&data1, 0);
        let sum2 = partial_checksum(&data2, sum1);
        let checksum1 = finalize_checksum(sum2);

        let combined = [0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08];
        let checksum2 = internet_checksum(&combined);

        assert_eq!(checksum1, checksum2);
    }

    #[test]
    fn test_odd_length_data() {
        let data = [0x45, 0x00, 0x00, 0x3c, 0x1c];
        let checksum = internet_checksum(&data);

        assert_ne!(checksum, 0);
        // A trailing odd byte is summed as if padded with a zero low byte (RFC 1071).
        assert_eq!(checksum, internet_checksum(&[0x45, 0x00, 0x00, 0x3c, 0x1c, 0x00]));
    }

    #[test]
    fn test_zero_and_write_checksum() {
        let mut buf = [0x45, 0x00, 0xAB, 0xCD, 0x00, 0x00];

        zero_checksum(&mut buf, 2);
        assert_eq!(buf[2], 0);
        assert_eq!(buf[3], 0);

        write_checksum(&mut buf, 2, 0x1234);
        assert_eq!(buf[2], 0x12);
        assert_eq!(buf[3], 0x34);

        assert_eq!(read_checksum(&buf, 2), Some(0x1234));

        // A 2-byte field that would run past the end is ignored (writes) or absent (reads).
        write_checksum(&mut buf, 5, 0xFFFF);
        zero_checksum(&mut buf, 5);
        assert_eq!(buf, [0x45, 0x00, 0x12, 0x34, 0x00, 0x00]);
        assert_eq!(read_checksum(&buf, 5), None);
        assert_eq!(read_checksum(&buf, 4), Some(0));
    }
}