stackforge_core/layer/tcp/checksum.rs
1//! TCP checksum calculation.
2//!
3//! TCP uses the Internet checksum (RFC 1071) computed over a pseudo-header
4//! followed by the TCP header and payload.
5//!
6//! # IPv4 Pseudo-header
7//!
8//! ```text
//! +--------+--------+--------+--------+
//! |           Source Address          |
//! +--------+--------+--------+--------+
//! |         Destination Address       |
//! +--------+--------+--------+--------+
//! |  zero  |  PTCL  |    TCP Length   |
//! +--------+--------+--------+--------+
16//! ```
17//!
18//! # IPv6 Pseudo-header
19//!
20//! ```text
//! +--------+--------+--------+--------+
//! |                                   |
//! +                                   +
//! |           Source Address          |
//! +                                   +
//! |                                   |
//! +                                   +
//! |                                   |
//! +--------+--------+--------+--------+
//! |                                   |
//! +                                   +
//! |         Destination Address       |
//! +                                   +
//! |                                   |
//! +                                   +
//! |                                   |
//! +--------+--------+--------+--------+
//! |     Upper-Layer Packet Length     |
//! +--------+--------+--------+--------+
//! |           zero           |Next Hdr|
//! +--------+--------+--------+--------+
42//! ```
43
44use std::net::{Ipv4Addr, Ipv6Addr};
45
46use crate::layer::ipv4::checksum::{pseudo_header_checksum, transport_checksum};
47use crate::utils::{finalize_checksum, partial_checksum};
48
/// IP protocol number assigned to TCP.
///
/// It is written into the protocol byte of the IPv4 pseudo-header and into
/// the next-header byte of the IPv6 pseudo-header.
pub const TCP_PROTOCOL: u8 = 0x06;
51
52/// Compute TCP checksum with IPv4 pseudo-header.
53///
54/// # Arguments
55///
56/// * `src_ip` - Source IPv4 address
57/// * `dst_ip` - Destination IPv4 address
58/// * `tcp_data` - Complete TCP segment (header + payload)
59///
60/// # Returns
61///
62/// The computed checksum value.
63pub fn tcp_checksum_ipv4(src_ip: Ipv4Addr, dst_ip: Ipv4Addr, tcp_data: &[u8]) -> u16 {
64 transport_checksum(&src_ip.octets(), &dst_ip.octets(), TCP_PROTOCOL, tcp_data)
65}
66
67/// Compute TCP checksum with IPv6 pseudo-header.
68///
69/// # Arguments
70///
71/// * `src_ip` - Source IPv6 address
72/// * `dst_ip` - Destination IPv6 address
73/// * `tcp_data` - Complete TCP segment (header + payload)
74///
75/// # Returns
76///
77/// The computed checksum value.
78pub fn tcp_checksum_ipv6(src_ip: Ipv6Addr, dst_ip: Ipv6Addr, tcp_data: &[u8]) -> u16 {
79 let tcp_len = tcp_data.len() as u32;
80
81 // IPv6 pseudo-header checksum
82 let mut sum: u32 = 0;
83
84 // Source address (16 bytes)
85 let src_octets = src_ip.octets();
86 for chunk in src_octets.chunks(2) {
87 sum += u16::from_be_bytes([chunk[0], chunk[1]]) as u32;
88 }
89
90 // Destination address (16 bytes)
91 let dst_octets = dst_ip.octets();
92 for chunk in dst_octets.chunks(2) {
93 sum += u16::from_be_bytes([chunk[0], chunk[1]]) as u32;
94 }
95
96 // Upper-layer packet length (4 bytes)
97 sum += ((tcp_len >> 16) & 0xFFFF) as u32;
98 sum += (tcp_len & 0xFFFF) as u32;
99
100 // Zero + Next Header (4 bytes, but only last byte is non-zero)
101 sum += TCP_PROTOCOL as u32;
102
103 // Add TCP data
104 sum = partial_checksum(tcp_data, sum);
105
106 // Finalize
107 finalize_checksum(sum)
108}
109
110/// Compute TCP checksum (generic version).
111///
112/// Automatically handles IPv4 or IPv6 based on address type.
113/// Uses raw byte slices for maximum flexibility.
114///
115/// # Arguments
116///
117/// * `src_ip` - Source IP address bytes (4 for IPv4, 16 for IPv6)
118/// * `dst_ip` - Destination IP address bytes
119/// * `tcp_data` - Complete TCP segment (header + payload)
120///
121/// # Returns
122///
123/// The computed checksum value, or None if address lengths are invalid.
124pub fn tcp_checksum(src_ip: &[u8], dst_ip: &[u8], tcp_data: &[u8]) -> Option<u16> {
125 match (src_ip.len(), dst_ip.len()) {
126 (4, 4) => {
127 let src: [u8; 4] = src_ip.try_into().ok()?;
128 let dst: [u8; 4] = dst_ip.try_into().ok()?;
129 Some(transport_checksum(&src, &dst, TCP_PROTOCOL, tcp_data))
130 },
131 (16, 16) => {
132 let src: [u8; 16] = src_ip.try_into().ok()?;
133 let dst: [u8; 16] = dst_ip.try_into().ok()?;
134 Some(tcp_checksum_ipv6(
135 Ipv6Addr::from(src),
136 Ipv6Addr::from(dst),
137 tcp_data,
138 ))
139 },
140 _ => None,
141 }
142}
143
144/// Verify TCP checksum with IPv4 pseudo-header.
145///
146/// # Arguments
147///
148/// * `src_ip` - Source IPv4 address
149/// * `dst_ip` - Destination IPv4 address
150/// * `tcp_data` - Complete TCP segment (header + payload) with checksum
151///
152/// # Returns
153///
154/// `true` if the checksum is valid.
155pub fn verify_tcp_checksum(src_ip: Ipv4Addr, dst_ip: Ipv4Addr, tcp_data: &[u8]) -> bool {
156 if tcp_data.len() < 20 {
157 return false;
158 }
159
160 let checksum = tcp_checksum_ipv4(src_ip, dst_ip, tcp_data);
161 checksum == 0 || checksum == 0xFFFF
162}
163
164/// Verify TCP checksum with IPv6 pseudo-header.
165pub fn verify_tcp_checksum_ipv6(src_ip: Ipv6Addr, dst_ip: Ipv6Addr, tcp_data: &[u8]) -> bool {
166 if tcp_data.len() < 20 {
167 return false;
168 }
169
170 let checksum = tcp_checksum_ipv6(src_ip, dst_ip, tcp_data);
171 checksum == 0 || checksum == 0xFFFF
172}
173
174/// Build the IPv4 pseudo-header bytes for TCP.
175///
176/// # Arguments
177///
178/// * `src_ip` - Source IPv4 address
179/// * `dst_ip` - Destination IPv4 address
180/// * `tcp_len` - Length of TCP segment (header + payload)
181///
182/// # Returns
183///
184/// 12-byte pseudo-header.
185pub fn ipv4_pseudo_header(src_ip: Ipv4Addr, dst_ip: Ipv4Addr, tcp_len: u16) -> [u8; 12] {
186 let mut header = [0u8; 12];
187
188 // Source IP (4 bytes)
189 header[0..4].copy_from_slice(&src_ip.octets());
190
191 // Destination IP (4 bytes)
192 header[4..8].copy_from_slice(&dst_ip.octets());
193
194 // Zero (1 byte)
195 header[8] = 0;
196
197 // Protocol (1 byte)
198 header[9] = TCP_PROTOCOL;
199
200 // TCP Length (2 bytes)
201 header[10..12].copy_from_slice(&tcp_len.to_be_bytes());
202
203 header
204}
205
206/// Build the IPv6 pseudo-header bytes for TCP.
207///
208/// # Arguments
209///
210/// * `src_ip` - Source IPv6 address
211/// * `dst_ip` - Destination IPv6 address
212/// * `tcp_len` - Length of TCP segment (header + payload)
213///
214/// # Returns
215///
216/// 40-byte pseudo-header.
217pub fn ipv6_pseudo_header(src_ip: Ipv6Addr, dst_ip: Ipv6Addr, tcp_len: u32) -> [u8; 40] {
218 let mut header = [0u8; 40];
219
220 // Source IP (16 bytes)
221 header[0..16].copy_from_slice(&src_ip.octets());
222
223 // Destination IP (16 bytes)
224 header[16..32].copy_from_slice(&dst_ip.octets());
225
226 // Upper-Layer Packet Length (4 bytes)
227 header[32..36].copy_from_slice(&tcp_len.to_be_bytes());
228
229 // Zero (3 bytes) + Next Header (1 byte)
230 header[36] = 0;
231 header[37] = 0;
232 header[38] = 0;
233 header[39] = TCP_PROTOCOL;
234
235 header
236}
237
238/// Compute checksum for partial data (for segmented computation).
239///
240/// Useful when computing checksum across multiple buffers.
241pub fn tcp_partial_checksum(src_ip: &[u8; 4], dst_ip: &[u8; 4], tcp_len: u16) -> u32 {
242 pseudo_header_checksum(src_ip, dst_ip, TCP_PROTOCOL, tcp_len)
243}
244
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal 20-byte TCP SYN header (port 80 -> 8080, seq 1) with the
    /// checksum field (bytes 16..18) zeroed.
    const SYN_HEADER: [u8; 20] = [
        0x00, 0x50, // Source port: 80
        0x1F, 0x90, // Dest port: 8080
        0x00, 0x00, 0x00, 0x01, // Seq number
        0x00, 0x00, 0x00, 0x00, // Ack number
        0x50, 0x02, // Data offset + flags (SYN)
        0xFF, 0xFF, // Window
        0x00, 0x00, // Checksum (zeroed)
        0x00, 0x00, // Urgent pointer
    ];

    /// Store `checksum` big-endian in the TCP checksum field (bytes 16..18).
    fn set_checksum(segment: &mut [u8], checksum: u16) {
        segment[16..18].copy_from_slice(&checksum.to_be_bytes());
    }

    #[test]
    fn test_tcp_checksum_ipv4() {
        let src_ip = Ipv4Addr::new(192, 168, 1, 1);
        let dst_ip = Ipv4Addr::new(192, 168, 1, 2);

        // Known answer, computed by hand per RFC 1071: the pseudo-header plus
        // header words sum to 0x2F34F, which folds to 0xF351; the one's
        // complement of that is 0x0CAE. A round-trip alone would not catch a
        // wrong pseudo-header.
        let checksum = tcp_checksum_ipv4(src_ip, dst_ip, &SYN_HEADER);
        assert_eq!(checksum, 0x0CAE);

        let mut tcp_with_checksum = SYN_HEADER;
        set_checksum(&mut tcp_with_checksum, checksum);
        assert!(verify_tcp_checksum(src_ip, dst_ip, &tcp_with_checksum));
    }

    #[test]
    fn test_tcp_checksum_ipv6() {
        let src_ip = Ipv6Addr::new(0x2001, 0xdb8, 0, 0, 0, 0, 0, 1);
        let dst_ip = Ipv6Addr::new(0x2001, 0xdb8, 0, 0, 0, 0, 0, 2);

        // Known answer: the pseudo-header words sum to 0x5B8F and the header
        // words to 0x16FE2. The total 0x1CB71 folds to 0xCB72, whose
        // complement is 0x348D.
        let checksum = tcp_checksum_ipv6(src_ip, dst_ip, &SYN_HEADER);
        assert_eq!(checksum, 0x348D);

        let mut tcp_with_checksum = SYN_HEADER;
        set_checksum(&mut tcp_with_checksum, checksum);
        assert!(verify_tcp_checksum_ipv6(src_ip, dst_ip, &tcp_with_checksum));
    }

    #[test]
    fn test_tcp_checksum_generic() {
        let src_ipv4: [u8; 4] = [192, 168, 1, 1];
        let dst_ipv4: [u8; 4] = [192, 168, 1, 2];

        let checksum = tcp_checksum(&src_ipv4, &dst_ipv4, &SYN_HEADER);
        let checksum_direct = tcp_checksum_ipv4(
            Ipv4Addr::from(src_ipv4),
            Ipv4Addr::from(dst_ipv4),
            &SYN_HEADER,
        );
        assert_eq!(checksum, Some(checksum_direct));
    }

    #[test]
    fn test_tcp_checksum_generic_ipv6() {
        let src_ip = Ipv6Addr::new(0x2001, 0xdb8, 0, 0, 0, 0, 0, 1);
        let dst_ip = Ipv6Addr::new(0x2001, 0xdb8, 0, 0, 0, 0, 0, 2);

        assert_eq!(
            tcp_checksum(&src_ip.octets(), &dst_ip.octets(), &SYN_HEADER),
            Some(tcp_checksum_ipv6(src_ip, dst_ip, &SYN_HEADER))
        );
    }

    #[test]
    fn test_tcp_checksum_generic_rejects_bad_lengths() {
        let v4: [u8; 4] = [192, 168, 1, 1];
        let v6 = [0u8; 16];

        // Mixed address families are not a valid pseudo-header.
        assert_eq!(tcp_checksum(&v4, &v6, &SYN_HEADER), None);
        assert_eq!(tcp_checksum(&v6, &v4, &SYN_HEADER), None);
        // Neither 4 nor 16 bytes.
        assert_eq!(tcp_checksum(&[], &[], &SYN_HEADER), None);
        assert_eq!(tcp_checksum(&[0u8; 5], &[0u8; 5], &SYN_HEADER), None);
    }

    #[test]
    fn test_pseudo_header() {
        let src_ip = Ipv4Addr::new(192, 168, 1, 1);
        let dst_ip = Ipv4Addr::new(192, 168, 1, 2);
        let tcp_len = 20u16;

        let header = ipv4_pseudo_header(src_ip, dst_ip, tcp_len);

        assert_eq!(&header[0..4], &[192, 168, 1, 1]);
        assert_eq!(&header[4..8], &[192, 168, 1, 2]);
        assert_eq!(header[8], 0);
        assert_eq!(header[9], TCP_PROTOCOL);
        assert_eq!(&header[10..12], &[0, 20]);
    }

    #[test]
    fn test_ipv6_pseudo_header() {
        let src_ip = Ipv6Addr::new(0x2001, 0xdb8, 0, 0, 0, 0, 0, 1);
        let dst_ip = Ipv6Addr::new(0x2001, 0xdb8, 0, 0, 0, 0, 0, 2);

        let header = ipv6_pseudo_header(src_ip, dst_ip, 20);

        assert_eq!(&header[0..16], &src_ip.octets());
        assert_eq!(&header[16..32], &dst_ip.octets());
        assert_eq!(&header[32..36], &[0, 0, 0, 20]);
        assert_eq!(&header[36..40], &[0, 0, 0, TCP_PROTOCOL]);
    }

    #[test]
    fn test_invalid_checksum() {
        let src_ip = Ipv4Addr::new(192, 168, 1, 1);
        let dst_ip = Ipv4Addr::new(192, 168, 1, 2);

        // TCP header with bad checksum
        let mut tcp_header = SYN_HEADER;
        set_checksum(&mut tcp_header, 0xFFFF);

        assert!(!verify_tcp_checksum(src_ip, dst_ip, &tcp_header));
    }

    #[test]
    fn test_verify_rejects_truncated_segment() {
        // One byte short of the minimum TCP header.
        let short = [0u8; 19];

        assert!(!verify_tcp_checksum(
            Ipv4Addr::new(10, 0, 0, 1),
            Ipv4Addr::new(10, 0, 0, 2),
            &short
        ));
        assert!(!verify_tcp_checksum_ipv6(
            Ipv6Addr::LOCALHOST,
            Ipv6Addr::LOCALHOST,
            &short
        ));
    }

    #[test]
    fn test_checksum_with_payload() {
        let src_ip = Ipv4Addr::new(10, 0, 0, 1);
        let dst_ip = Ipv4Addr::new(10, 0, 0, 2);

        // TCP header + "Hello" payload (odd length exercises the pad byte).
        let mut tcp_segment = vec![
            0x00, 0x50, // Source port: 80
            0x00, 0x51, // Dest port: 81
            0x00, 0x00, 0x00, 0x01, // Seq number
            0x00, 0x00, 0x00, 0x01, // Ack number
            0x50, 0x18, // Data offset + flags (PSH+ACK)
            0xFF, 0xFF, // Window
            0x00, 0x00, // Checksum (zeroed)
            0x00, 0x00, // Urgent pointer
        ];
        tcp_segment.extend_from_slice(b"Hello");

        let checksum = tcp_checksum_ipv4(src_ip, dst_ip, &tcp_segment);
        set_checksum(&mut tcp_segment, checksum);

        assert!(verify_tcp_checksum(src_ip, dst_ip, &tcp_segment));
    }
}
386}