1use std::net::{Ipv4Addr, Ipv6Addr};
45
46use crate::layer::ipv4::checksum::{pseudo_header_checksum, transport_checksum};
47use crate::utils::{finalize_checksum, partial_checksum};
48
/// IP protocol number assigned to TCP.
///
/// Used in this module as the protocol / next-header value of the checksum pseudo-headers.
pub const TCP_PROTOCOL: u8 = 6;
51
52#[must_use]
64pub fn tcp_checksum_ipv4(src_ip: Ipv4Addr, dst_ip: Ipv4Addr, tcp_data: &[u8]) -> u16 {
65 transport_checksum(&src_ip.octets(), &dst_ip.octets(), TCP_PROTOCOL, tcp_data)
66}
67
68#[must_use]
80pub fn tcp_checksum_ipv6(src_ip: Ipv6Addr, dst_ip: Ipv6Addr, tcp_data: &[u8]) -> u16 {
81 let tcp_len = tcp_data.len() as u32;
82
83 let mut sum: u32 = 0;
85
86 let src_octets = src_ip.octets();
88 for chunk in src_octets.chunks(2) {
89 sum += u32::from(u16::from_be_bytes([chunk[0], chunk[1]]));
90 }
91
92 let dst_octets = dst_ip.octets();
94 for chunk in dst_octets.chunks(2) {
95 sum += u32::from(u16::from_be_bytes([chunk[0], chunk[1]]));
96 }
97
98 sum += (tcp_len >> 16) & 0xFFFF;
100 sum += tcp_len & 0xFFFF;
101
102 sum += u32::from(TCP_PROTOCOL);
104
105 sum = partial_checksum(tcp_data, sum);
107
108 finalize_checksum(sum)
110}
111
112#[must_use]
127pub fn tcp_checksum(src_ip: &[u8], dst_ip: &[u8], tcp_data: &[u8]) -> Option<u16> {
128 match (src_ip.len(), dst_ip.len()) {
129 (4, 4) => {
130 let src: [u8; 4] = src_ip.try_into().ok()?;
131 let dst: [u8; 4] = dst_ip.try_into().ok()?;
132 Some(transport_checksum(&src, &dst, TCP_PROTOCOL, tcp_data))
133 },
134 (16, 16) => {
135 let src: [u8; 16] = src_ip.try_into().ok()?;
136 let dst: [u8; 16] = dst_ip.try_into().ok()?;
137 Some(tcp_checksum_ipv6(
138 Ipv6Addr::from(src),
139 Ipv6Addr::from(dst),
140 tcp_data,
141 ))
142 },
143 _ => None,
144 }
145}
146
147#[must_use]
159pub fn verify_tcp_checksum(src_ip: Ipv4Addr, dst_ip: Ipv4Addr, tcp_data: &[u8]) -> bool {
160 if tcp_data.len() < 20 {
161 return false;
162 }
163
164 let checksum = tcp_checksum_ipv4(src_ip, dst_ip, tcp_data);
165 checksum == 0 || checksum == 0xFFFF
166}
167
168#[must_use]
170pub fn verify_tcp_checksum_ipv6(src_ip: Ipv6Addr, dst_ip: Ipv6Addr, tcp_data: &[u8]) -> bool {
171 if tcp_data.len() < 20 {
172 return false;
173 }
174
175 let checksum = tcp_checksum_ipv6(src_ip, dst_ip, tcp_data);
176 checksum == 0 || checksum == 0xFFFF
177}
178
179#[must_use]
191pub fn ipv4_pseudo_header(src_ip: Ipv4Addr, dst_ip: Ipv4Addr, tcp_len: u16) -> [u8; 12] {
192 let mut header = [0u8; 12];
193
194 header[0..4].copy_from_slice(&src_ip.octets());
196
197 header[4..8].copy_from_slice(&dst_ip.octets());
199
200 header[8] = 0;
202
203 header[9] = TCP_PROTOCOL;
205
206 header[10..12].copy_from_slice(&tcp_len.to_be_bytes());
208
209 header
210}
211
212#[must_use]
224pub fn ipv6_pseudo_header(src_ip: Ipv6Addr, dst_ip: Ipv6Addr, tcp_len: u32) -> [u8; 40] {
225 let mut header = [0u8; 40];
226
227 header[0..16].copy_from_slice(&src_ip.octets());
229
230 header[16..32].copy_from_slice(&dst_ip.octets());
232
233 header[32..36].copy_from_slice(&tcp_len.to_be_bytes());
235
236 header[36] = 0;
238 header[37] = 0;
239 header[38] = 0;
240 header[39] = TCP_PROTOCOL;
241
242 header
243}
244
/// Returns the pseudo-header checksum contribution for a TCP segment over IPv4.
///
/// Thin wrapper around [`pseudo_header_checksum`] that fixes the protocol argument to
/// [`TCP_PROTOCOL`].
///
/// # Arguments
///
/// * `src_ip` - IPv4 source address octets.
/// * `dst_ip` - IPv4 destination address octets.
/// * `tcp_len` - TCP length value forwarded to `pseudo_header_checksum`.
///
/// NOTE(review): the `u32` result is presumably an unfinalized running sum meant to seed
/// a later checksum pass over the segment - confirm against `pseudo_header_checksum`.
#[must_use]
pub fn tcp_partial_checksum(src_ip: &[u8; 4], dst_ip: &[u8; 4], tcp_len: u16) -> u32 {
    pseudo_header_checksum(src_ip, dst_ip, TCP_PROTOCOL, tcp_len)
}
252
#[cfg(test)]
mod tests {
    use super::*;

    /// 20-byte TCP header used as the common fixture; its checksum field is zero.
    const SYN_HEADER: [u8; 20] = [
        0x00, 0x50, // source port
        0x1F, 0x90, // destination port
        0x00, 0x00, 0x00, 0x01, // sequence number
        0x00, 0x00, 0x00, 0x00, // acknowledgement number
        0x50, 0x02, // data offset and flags
        0xFF, 0xFF, // window
        0x00, 0x00, // checksum (bytes 16..18)
        0x00, 0x00, // urgent pointer
    ];

    /// Writes `checksum` in network byte order into the checksum field (bytes 16..18).
    fn store_checksum(segment: &mut [u8], checksum: u16) {
        segment[16..18].copy_from_slice(&checksum.to_be_bytes());
    }

    #[test]
    fn test_tcp_checksum_ipv4() {
        let src_ip = Ipv4Addr::new(192, 168, 1, 1);
        let dst_ip = Ipv4Addr::new(192, 168, 1, 2);

        let checksum = tcp_checksum_ipv4(src_ip, dst_ip, &SYN_HEADER);
        assert_ne!(checksum, 0);

        let mut tcp_with_checksum = SYN_HEADER;
        store_checksum(&mut tcp_with_checksum, checksum);

        assert!(verify_tcp_checksum(src_ip, dst_ip, &tcp_with_checksum));
    }

    #[test]
    fn test_tcp_checksum_ipv6() {
        let src_ip = Ipv6Addr::new(0x2001, 0xdb8, 0, 0, 0, 0, 0, 1);
        let dst_ip = Ipv6Addr::new(0x2001, 0xdb8, 0, 0, 0, 0, 0, 2);

        let checksum = tcp_checksum_ipv6(src_ip, dst_ip, &SYN_HEADER);
        assert_ne!(checksum, 0);

        let mut tcp_with_checksum = SYN_HEADER;
        store_checksum(&mut tcp_with_checksum, checksum);

        assert!(verify_tcp_checksum_ipv6(src_ip, dst_ip, &tcp_with_checksum));
    }

    #[test]
    fn test_tcp_checksum_generic() {
        let src_ipv4 = [192, 168, 1, 1];
        let dst_ipv4 = [192, 168, 1, 2];

        let checksum_direct = tcp_checksum_ipv4(
            Ipv4Addr::from(src_ipv4),
            Ipv4Addr::from(dst_ipv4),
            &SYN_HEADER,
        );

        assert_eq!(
            tcp_checksum(&src_ipv4, &dst_ipv4, &SYN_HEADER),
            Some(checksum_direct)
        );
    }

    #[test]
    fn test_tcp_checksum_generic_ipv6() {
        let src_ip = Ipv6Addr::new(0x2001, 0xdb8, 0, 0, 0, 0, 0, 1);
        let dst_ip = Ipv6Addr::new(0x2001, 0xdb8, 0, 0, 0, 0, 0, 2);

        let checksum_direct = tcp_checksum_ipv6(src_ip, dst_ip, &SYN_HEADER);

        assert_eq!(
            tcp_checksum(&src_ip.octets(), &dst_ip.octets(), &SYN_HEADER),
            Some(checksum_direct)
        );
    }

    #[test]
    fn test_tcp_checksum_generic_rejects_bad_lengths() {
        let ipv4 = [192u8, 168, 1, 1];
        let ipv6 = [0u8; 16];

        // Mixed address families.
        assert_eq!(tcp_checksum(&ipv4, &ipv6, &SYN_HEADER), None);
        assert_eq!(tcp_checksum(&ipv6, &ipv4, &SYN_HEADER), None);

        // Equal lengths that are neither 4 nor 16.
        assert_eq!(tcp_checksum(&[], &[], &SYN_HEADER), None);
        assert_eq!(tcp_checksum(&[0u8; 5], &[0u8; 5], &SYN_HEADER), None);
    }

    #[test]
    fn test_pseudo_header() {
        let src_ip = Ipv4Addr::new(192, 168, 1, 1);
        let dst_ip = Ipv4Addr::new(192, 168, 1, 2);
        let tcp_len = 20u16;

        let header = ipv4_pseudo_header(src_ip, dst_ip, tcp_len);

        assert_eq!(&header[0..4], &[192, 168, 1, 1]);
        assert_eq!(&header[4..8], &[192, 168, 1, 2]);
        assert_eq!(header[8], 0);
        assert_eq!(header[9], TCP_PROTOCOL);
        assert_eq!(&header[10..12], &[0, 20]);
    }

    #[test]
    fn test_ipv6_pseudo_header() {
        let src_ip = Ipv6Addr::new(0x2001, 0xdb8, 0, 0, 0, 0, 0, 1);
        let dst_ip = Ipv6Addr::new(0x2001, 0xdb8, 0, 0, 0, 0, 0, 2);

        let header = ipv6_pseudo_header(src_ip, dst_ip, 0x0102_0304);

        assert_eq!(&header[0..16], &src_ip.octets()[..]);
        assert_eq!(&header[16..32], &dst_ip.octets()[..]);
        assert_eq!(&header[32..36], &[1u8, 2, 3, 4]);
        assert_eq!(&header[36..39], &[0u8, 0, 0]);
        assert_eq!(header[39], TCP_PROTOCOL);
    }

    #[test]
    fn test_invalid_checksum() {
        let src_ip = Ipv4Addr::new(192, 168, 1, 1);
        let dst_ip = Ipv4Addr::new(192, 168, 1, 2);

        // Same header as the fixture, but with a bogus value in the checksum field.
        let mut tcp_header = SYN_HEADER;
        store_checksum(&mut tcp_header, 0xFFFF);

        assert!(!verify_tcp_checksum(src_ip, dst_ip, &tcp_header));
    }

    #[test]
    fn test_verify_rejects_short_segment() {
        let src_v4 = Ipv4Addr::new(192, 168, 1, 1);
        let dst_v4 = Ipv4Addr::new(192, 168, 1, 2);
        let src_v6 = Ipv6Addr::new(0x2001, 0xdb8, 0, 0, 0, 0, 0, 1);
        let dst_v6 = Ipv6Addr::new(0x2001, 0xdb8, 0, 0, 0, 0, 0, 2);

        // Anything shorter than 20 bytes is rejected before a checksum is computed.
        assert!(!verify_tcp_checksum(src_v4, dst_v4, &SYN_HEADER[..19]));
        assert!(!verify_tcp_checksum(src_v4, dst_v4, &[]));
        assert!(!verify_tcp_checksum_ipv6(src_v6, dst_v6, &SYN_HEADER[..19]));
        assert!(!verify_tcp_checksum_ipv6(src_v6, dst_v6, &[]));
    }

    #[test]
    fn test_checksum_with_payload() {
        let src_ip = Ipv4Addr::new(10, 0, 0, 1);
        let dst_ip = Ipv4Addr::new(10, 0, 0, 2);

        let mut tcp_segment = vec![
            0x00, 0x50, 0x00, 0x51, // ports
            0x00, 0x00, 0x00, 0x01, // sequence number
            0x00, 0x00, 0x00, 0x01, // acknowledgement number
            0x50, 0x18, // data offset and flags
            0xFF, 0xFF, // window
            0x00, 0x00, // checksum
            0x00, 0x00, // urgent pointer
        ];
        // Odd-length payload, so the segment does not end on a 16-bit boundary.
        tcp_segment.extend_from_slice(b"Hello");

        let checksum = tcp_checksum_ipv4(src_ip, dst_ip, &tcp_segment);
        store_checksum(&mut tcp_segment, checksum);

        assert!(verify_tcp_checksum(src_ip, dst_ip, &tcp_segment));
    }
}