1use crate::utils::internet_checksum;
7
8#[inline]
32#[must_use]
33pub fn ipv4_checksum(data: &[u8]) -> u16 {
34 internet_checksum(data)
35}
36
37#[inline]
45#[must_use]
46pub fn verify_ipv4_checksum(data: &[u8]) -> bool {
47 let sum = internet_checksum(data);
48 sum == 0 || sum == 0xFFFF
49}
50
/// Incrementally updates a checksum after one 16-bit word of the covered data changes.
///
/// Implements RFC 1624, equation 3: `HC' = ~(~HC + ~m + m')`, where `HC` is
/// `old_checksum`, `m` is `old_value` and `m'` is `new_value`, all added with
/// end-around carry. This avoids recomputing the checksum over the whole buffer.
#[inline]
#[must_use]
pub fn incremental_update_checksum(old_checksum: u16, old_value: u16, new_value: u16) -> u16 {
    // Three 16-bit terms sum to at most 0x2FFFD, which comfortably fits in a u32.
    let mut acc: u32 = [!old_checksum, !old_value, new_value]
        .iter()
        .map(|&word| u32::from(word))
        .sum();

    // End-around carry: fold everything above bit 15 back into the low 16 bits.
    while acc > 0xFFFF {
        acc = (acc >> 16) + (acc & 0xFFFF);
    }

    // The loop guarantees `acc <= 0xFFFF`, so the cast keeps every bit.
    !(acc as u16)
}
82
83#[inline]
87#[must_use]
88pub fn incremental_update_checksum_32(old_checksum: u16, old_value: u32, new_value: u32) -> u16 {
89 let old_high = (old_value >> 16) as u16;
91 let old_low = (old_value & 0xFFFF) as u16;
92 let new_high = (new_value >> 16) as u16;
93 let new_low = (new_value & 0xFFFF) as u16;
94
95 let tmp = incremental_update_checksum(old_checksum, old_high, new_high);
96 incremental_update_checksum(tmp, old_low, new_low)
97}
98
/// Returns the unfolded 32-bit sum of the IPv4 pseudo-header used by TCP/UDP checksums.
///
/// The pseudo-header consists of the source address, the destination address, a zero
/// byte followed by `protocol`, and `transport_len`. Each part is summed as big-endian
/// 16-bit words. The result is neither folded nor complemented, so it can seed a running
/// sum over the transport segment.
#[must_use]
pub fn pseudo_header_checksum(
    src_ip: &[u8; 4],
    dst_ip: &[u8; 4],
    protocol: u8,
    transport_len: u16,
) -> u32 {
    let address_words: u32 = src_ip
        .chunks_exact(2)
        .chain(dst_ip.chunks_exact(2))
        .map(|pair| u32::from(u16::from_be_bytes([pair[0], pair[1]])))
        .sum();

    // The zero pad byte in front of `protocol` makes that word numerically equal to
    // `protocol`. The whole sum stays far below u32::MAX.
    address_words + u32::from(protocol) + u32::from(transport_len)
}
143
144#[must_use]
157pub fn transport_checksum(
158 src_ip: &[u8; 4],
159 dst_ip: &[u8; 4],
160 protocol: u8,
161 transport_data: &[u8],
162) -> u16 {
163 let mut sum = pseudo_header_checksum(src_ip, dst_ip, protocol, transport_data.len() as u16);
165
166 let mut chunks = transport_data.chunks_exact(2);
168 for chunk in chunks.by_ref() {
169 sum += u32::from(u16::from_be_bytes([chunk[0], chunk[1]]));
170 }
171
172 if let Some(&last) = chunks.remainder().first() {
174 sum += u32::from(last) << 8;
175 }
176
177 while (sum >> 16) != 0 {
179 sum = (sum & 0xFFFF) + (sum >> 16);
180 }
181
182 let result = !sum as u16;
184 if result == 0 && protocol == 17 {
185 0xFFFF
186 } else {
187 result
188 }
189}
190
191pub use crate::utils::{finalize_checksum, partial_checksum};
194
/// Zeroes the two-byte checksum field that starts at `offset` in `buf`.
///
/// Does nothing if the field would extend past the end of `buf`, including when
/// `offset` is so large that `offset + 2` would overflow.
#[inline]
pub fn zero_checksum(buf: &mut [u8], offset: usize) {
    // `get_mut(offset..)` is `None` when `offset > buf.len()`, so no addition can overflow.
    if let Some(field) = buf.get_mut(offset..).and_then(|tail| tail.get_mut(..2)) {
        field.fill(0);
    }
}
203
/// Writes `checksum` in network (big-endian) byte order into the two-byte field that
/// starts at `offset` in `buf`.
///
/// Does nothing if the field would extend past the end of `buf`, including when
/// `offset` is so large that `offset + 2` would overflow.
#[inline]
pub fn write_checksum(buf: &mut [u8], offset: usize, checksum: u16) {
    // `get_mut(offset..)` is `None` when `offset > buf.len()`, so no addition can overflow.
    if let Some(field) = buf.get_mut(offset..).and_then(|tail| tail.get_mut(..2)) {
        field.copy_from_slice(&checksum.to_be_bytes());
    }
}
213
/// Reads a big-endian 16-bit checksum from the two bytes that start at `offset` in `buf`.
///
/// Returns `None` if the field would extend past the end of `buf`, including when
/// `offset` is so large that `offset + 2` would overflow.
#[inline]
#[must_use]
pub fn read_checksum(buf: &[u8], offset: usize) -> Option<u16> {
    // `get(offset..)` is `None` when `offset > buf.len()`, so no addition can overflow.
    let field = buf.get(offset..)?.get(..2)?;
    Some(u16::from_be_bytes([field[0], field[1]]))
}
224
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_ipv4_checksum() {
        // 20-byte IPv4 header with the checksum field (bytes 10-11) zeroed.
        let header = [
            0x45, 0x00, 0x00, 0x3c, 0x1c, 0x46, 0x40, 0x00, 0x40, 0x06, 0x00, 0x00, 0xac, 0x10,
            0x0a, 0x63, 0xac, 0x10, 0x0a, 0x0c,
        ];

        let checksum = ipv4_checksum(&header);
        // Known-good checksum for this header (same value as in `test_verify_valid_checksum`).
        assert_eq!(checksum, 0xb1e6);

        let mut header_with_cksum = header;
        header_with_cksum[10..12].copy_from_slice(&checksum.to_be_bytes());

        assert!(verify_ipv4_checksum(&header_with_cksum));
    }

    #[test]
    fn test_verify_valid_checksum() {
        let header = [
            0x45, 0x00, 0x00, 0x3c, 0x1c, 0x46, 0x40, 0x00, 0x40, 0x06, 0xb1, 0xe6, 0xac, 0x10,
            0x0a, 0x63, 0xac, 0x10, 0x0a, 0x0c,
        ];

        assert!(verify_ipv4_checksum(&header));
    }

    #[test]
    fn test_verify_invalid_checksum() {
        let header = [
            0x45, 0x00, 0x00, 0x3c, 0x1c, 0x46, 0x40, 0x00, 0x40, 0x06, 0xFF, 0xFF, 0xac, 0x10,
            0x0a, 0x63, 0xac, 0x10, 0x0a, 0x0c,
        ];

        assert!(!verify_ipv4_checksum(&header));
    }

    #[test]
    fn test_incremental_update() {
        let mut header = [
            0x45, 0x00, 0x00, 0x3c, 0x1c, 0x46, 0x40, 0x00, 0x40, 0x06, 0x00, 0x00, 0xac, 0x10,
            0x0a, 0x63, 0xac, 0x10, 0x0a, 0x0c,
        ];

        let initial_checksum = ipv4_checksum(&header);
        header[10] = (initial_checksum >> 8) as u8;
        header[11] = (initial_checksum & 0xFF) as u8;

        // Decrement TTL (byte 8) and patch the checksum incrementally.
        let old_ttl_word = u16::from_be_bytes([header[8], header[9]]);
        header[8] = 0x3F;
        let new_ttl_word = u16::from_be_bytes([header[8], header[9]]);

        let new_checksum =
            incremental_update_checksum(initial_checksum, old_ttl_word, new_ttl_word);
        header[10] = (new_checksum >> 8) as u8;
        header[11] = (new_checksum & 0xFF) as u8;

        assert!(verify_ipv4_checksum(&header));
        // Lowering the TTL by one raises the checksum by 0x0100.
        assert_eq!(new_checksum, 0xb2e6);

        // The incremental result must match a full recomputation.
        let mut recomputed = header;
        recomputed[10] = 0;
        recomputed[11] = 0;
        assert_eq!(ipv4_checksum(&recomputed), new_checksum);
    }

    #[test]
    fn test_pseudo_header_checksum() {
        let src = [192, 168, 1, 1];
        let dst = [192, 168, 1, 2];
        let protocol = 6;
        let length = 20;
        let sum = pseudo_header_checksum(&src, &dst, protocol, length);

        assert!(sum > 0);
        // 0xC0A8 + 0x0101 + 0xC0A8 + 0x0102 + 6 + 20
        assert_eq!(sum, 0x1836D);
    }

    #[test]
    fn test_transport_checksum_tcp() {
        let src_ip = [192, 168, 1, 1];
        let dst_ip = [192, 168, 1, 2];

        // 20-byte TCP SYN header with the checksum field (bytes 16-17) zeroed.
        let tcp_header = [
            0x00, 0x50, 0x1F, 0x90, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x50, 0x02,
            0xFF, 0xFF, 0x00, 0x00, 0x00, 0x00,
        ];

        let checksum = transport_checksum(&src_ip, &dst_ip, 6, &tcp_header);
        assert_ne!(checksum, 0);
        // Hand-computed value for this pseudo-header + segment.
        assert_eq!(checksum, 0x0cae);

        // Once the checksum is stored in the segment, recomputing yields zero.
        let mut filled = tcp_header;
        filled[16..18].copy_from_slice(&checksum.to_be_bytes());
        assert_eq!(transport_checksum(&src_ip, &dst_ip, 6, &filled), 0);
    }

    #[test]
    fn test_transport_checksum_udp_zero() {
        let src_ip = [0, 0, 0, 0];
        let dst_ip = [0, 0, 0, 0];

        // 8-byte UDP header: ports 0, length 8, checksum field zeroed.
        let udp_header = [0x00, 0x00, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00];

        let checksum = transport_checksum(&src_ip, &dst_ip, 17, &udp_header);
        assert_ne!(checksum, 0);
        assert_eq!(checksum, 0xffde);

        // A segment whose one's-complement sum is 0xFFFF computes to 0x0000:
        // pseudo-header (protocol + length 2) plus the data word equals 0xFFFF.
        // UDP must send 0xFFFF instead (RFC 768), while TCP keeps 0x0000.
        assert_eq!(transport_checksum(&src_ip, &dst_ip, 17, &[0xff, 0xec]), 0xffff);
        assert_eq!(transport_checksum(&src_ip, &dst_ip, 6, &[0xff, 0xf7]), 0x0000);
    }

    #[test]
    fn test_partial_checksum() {
        let data1 = [0x01, 0x02, 0x03, 0x04];
        let data2 = [0x05, 0x06, 0x07, 0x08];

        let sum1 = partial_checksum(&data1, 0);
        let sum2 = partial_checksum(&data2, sum1);
        let checksum1 = finalize_checksum(sum2);

        let combined = [0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08];
        let checksum2 = internet_checksum(&combined);

        assert_eq!(checksum1, checksum2);
    }

    #[test]
    fn test_odd_length_data() {
        let data = [0x45, 0x00, 0x00, 0x3c, 0x1c];
        let checksum = internet_checksum(&data);

        assert_ne!(checksum, 0);
        // RFC 1071: an odd trailing byte is summed as if padded with a zero byte.
        let padded = [0x45, 0x00, 0x00, 0x3c, 0x1c, 0x00];
        assert_eq!(checksum, internet_checksum(&padded));
    }

    #[test]
    fn test_zero_and_write_checksum() {
        let mut buf = [0x45, 0x00, 0xAB, 0xCD, 0x00, 0x00];

        zero_checksum(&mut buf, 2);
        assert_eq!(buf[2], 0);
        assert_eq!(buf[3], 0);

        write_checksum(&mut buf, 2, 0x1234);
        assert_eq!(buf[2], 0x12);
        assert_eq!(buf[3], 0x34);

        assert_eq!(read_checksum(&buf, 2), Some(0x1234));

        // A field that would run past the end is left untouched and reads back as `None`.
        let before = buf;
        zero_checksum(&mut buf, 5);
        write_checksum(&mut buf, 5, 0xFFFF);
        assert_eq!(buf, before);
        assert_eq!(read_checksum(&buf, 5), None);
        assert_eq!(read_checksum(&buf, 4), Some(0x0000));
    }
}