Skip to main content

stackforge_core/layer/tcp/
checksum.rs

1//! TCP checksum calculation.
2//!
3//! TCP uses the Internet checksum (RFC 1071) computed over a pseudo-header
4//! followed by the TCP header and payload.
5//!
6//! # IPv4 Pseudo-header
7//!
8//! ```text
9//! +--------+--------+--------+--------+
10//! |           Source Address          |
11//! +--------+--------+--------+--------+
12//! |         Destination Address       |
13//! +--------+--------+--------+--------+
14//! |  zero  |  PTCL  |    TCP Length   |
15//! +--------+--------+--------+--------+
16//! ```
17//!
18//! # IPv6 Pseudo-header
19//!
20//! ```text
21//! +--------+--------+--------+--------+
22//! |                                   |
23//! +                                   +
24//! |           Source Address          |
25//! +                                   +
26//! |                                   |
27//! +                                   +
28//! |                                   |
29//! +--------+--------+--------+--------+
30//! |                                   |
31//! +                                   +
32//! |         Destination Address       |
33//! +                                   +
34//! |                                   |
35//! +                                   +
36//! |                                   |
37//! +--------+--------+--------+--------+
38//! |         Upper-Layer Packet Length |
39//! +--------+--------+--------+--------+
40//! |           zero          |Next Hdr|
41//! +--------+--------+--------+--------+
42//! ```
43
44use std::net::{Ipv4Addr, Ipv6Addr};
45
46use crate::layer::ipv4::checksum::{pseudo_header_checksum, transport_checksum};
47use crate::utils::{finalize_checksum, partial_checksum};
48
/// IP protocol number identifying TCP (decimal 6).
///
/// Placed in the PTCL field of the IPv4 pseudo-header and in the Next Header
/// field of the IPv6 pseudo-header when computing TCP checksums.
pub const TCP_PROTOCOL: u8 = 0x06;
51
52/// Compute TCP checksum with IPv4 pseudo-header.
53///
54/// # Arguments
55///
56/// * `src_ip` - Source IPv4 address
57/// * `dst_ip` - Destination IPv4 address
58/// * `tcp_data` - Complete TCP segment (header + payload)
59///
60/// # Returns
61///
62/// The computed checksum value.
63pub fn tcp_checksum_ipv4(src_ip: Ipv4Addr, dst_ip: Ipv4Addr, tcp_data: &[u8]) -> u16 {
64    transport_checksum(&src_ip.octets(), &dst_ip.octets(), TCP_PROTOCOL, tcp_data)
65}
66
67/// Compute TCP checksum with IPv6 pseudo-header.
68///
69/// # Arguments
70///
71/// * `src_ip` - Source IPv6 address
72/// * `dst_ip` - Destination IPv6 address
73/// * `tcp_data` - Complete TCP segment (header + payload)
74///
75/// # Returns
76///
77/// The computed checksum value.
78pub fn tcp_checksum_ipv6(src_ip: Ipv6Addr, dst_ip: Ipv6Addr, tcp_data: &[u8]) -> u16 {
79    let tcp_len = tcp_data.len() as u32;
80
81    // IPv6 pseudo-header checksum
82    let mut sum: u32 = 0;
83
84    // Source address (16 bytes)
85    let src_octets = src_ip.octets();
86    for chunk in src_octets.chunks(2) {
87        sum += u16::from_be_bytes([chunk[0], chunk[1]]) as u32;
88    }
89
90    // Destination address (16 bytes)
91    let dst_octets = dst_ip.octets();
92    for chunk in dst_octets.chunks(2) {
93        sum += u16::from_be_bytes([chunk[0], chunk[1]]) as u32;
94    }
95
96    // Upper-layer packet length (4 bytes)
97    sum += ((tcp_len >> 16) & 0xFFFF) as u32;
98    sum += (tcp_len & 0xFFFF) as u32;
99
100    // Zero + Next Header (4 bytes, but only last byte is non-zero)
101    sum += TCP_PROTOCOL as u32;
102
103    // Add TCP data
104    sum = partial_checksum(tcp_data, sum);
105
106    // Finalize
107    finalize_checksum(sum)
108}
109
110/// Compute TCP checksum (generic version).
111///
112/// Automatically handles IPv4 or IPv6 based on address type.
113/// Uses raw byte slices for maximum flexibility.
114///
115/// # Arguments
116///
117/// * `src_ip` - Source IP address bytes (4 for IPv4, 16 for IPv6)
118/// * `dst_ip` - Destination IP address bytes
119/// * `tcp_data` - Complete TCP segment (header + payload)
120///
121/// # Returns
122///
123/// The computed checksum value, or None if address lengths are invalid.
124pub fn tcp_checksum(src_ip: &[u8], dst_ip: &[u8], tcp_data: &[u8]) -> Option<u16> {
125    match (src_ip.len(), dst_ip.len()) {
126        (4, 4) => {
127            let src: [u8; 4] = src_ip.try_into().ok()?;
128            let dst: [u8; 4] = dst_ip.try_into().ok()?;
129            Some(transport_checksum(&src, &dst, TCP_PROTOCOL, tcp_data))
130        },
131        (16, 16) => {
132            let src: [u8; 16] = src_ip.try_into().ok()?;
133            let dst: [u8; 16] = dst_ip.try_into().ok()?;
134            Some(tcp_checksum_ipv6(
135                Ipv6Addr::from(src),
136                Ipv6Addr::from(dst),
137                tcp_data,
138            ))
139        },
140        _ => None,
141    }
142}
143
144/// Verify TCP checksum with IPv4 pseudo-header.
145///
146/// # Arguments
147///
148/// * `src_ip` - Source IPv4 address
149/// * `dst_ip` - Destination IPv4 address
150/// * `tcp_data` - Complete TCP segment (header + payload) with checksum
151///
152/// # Returns
153///
154/// `true` if the checksum is valid.
155pub fn verify_tcp_checksum(src_ip: Ipv4Addr, dst_ip: Ipv4Addr, tcp_data: &[u8]) -> bool {
156    if tcp_data.len() < 20 {
157        return false;
158    }
159
160    let checksum = tcp_checksum_ipv4(src_ip, dst_ip, tcp_data);
161    checksum == 0 || checksum == 0xFFFF
162}
163
164/// Verify TCP checksum with IPv6 pseudo-header.
165pub fn verify_tcp_checksum_ipv6(src_ip: Ipv6Addr, dst_ip: Ipv6Addr, tcp_data: &[u8]) -> bool {
166    if tcp_data.len() < 20 {
167        return false;
168    }
169
170    let checksum = tcp_checksum_ipv6(src_ip, dst_ip, tcp_data);
171    checksum == 0 || checksum == 0xFFFF
172}
173
174/// Build the IPv4 pseudo-header bytes for TCP.
175///
176/// # Arguments
177///
178/// * `src_ip` - Source IPv4 address
179/// * `dst_ip` - Destination IPv4 address
180/// * `tcp_len` - Length of TCP segment (header + payload)
181///
182/// # Returns
183///
184/// 12-byte pseudo-header.
185pub fn ipv4_pseudo_header(src_ip: Ipv4Addr, dst_ip: Ipv4Addr, tcp_len: u16) -> [u8; 12] {
186    let mut header = [0u8; 12];
187
188    // Source IP (4 bytes)
189    header[0..4].copy_from_slice(&src_ip.octets());
190
191    // Destination IP (4 bytes)
192    header[4..8].copy_from_slice(&dst_ip.octets());
193
194    // Zero (1 byte)
195    header[8] = 0;
196
197    // Protocol (1 byte)
198    header[9] = TCP_PROTOCOL;
199
200    // TCP Length (2 bytes)
201    header[10..12].copy_from_slice(&tcp_len.to_be_bytes());
202
203    header
204}
205
206/// Build the IPv6 pseudo-header bytes for TCP.
207///
208/// # Arguments
209///
210/// * `src_ip` - Source IPv6 address
211/// * `dst_ip` - Destination IPv6 address
212/// * `tcp_len` - Length of TCP segment (header + payload)
213///
214/// # Returns
215///
216/// 40-byte pseudo-header.
217pub fn ipv6_pseudo_header(src_ip: Ipv6Addr, dst_ip: Ipv6Addr, tcp_len: u32) -> [u8; 40] {
218    let mut header = [0u8; 40];
219
220    // Source IP (16 bytes)
221    header[0..16].copy_from_slice(&src_ip.octets());
222
223    // Destination IP (16 bytes)
224    header[16..32].copy_from_slice(&dst_ip.octets());
225
226    // Upper-Layer Packet Length (4 bytes)
227    header[32..36].copy_from_slice(&tcp_len.to_be_bytes());
228
229    // Zero (3 bytes) + Next Header (1 byte)
230    header[36] = 0;
231    header[37] = 0;
232    header[38] = 0;
233    header[39] = TCP_PROTOCOL;
234
235    header
236}
237
238/// Compute checksum for partial data (for segmented computation).
239///
240/// Useful when computing checksum across multiple buffers.
241pub fn tcp_partial_checksum(src_ip: &[u8; 4], dst_ip: &[u8; 4], tcp_len: u16) -> u32 {
242    pseudo_header_checksum(src_ip, dst_ip, TCP_PROTOCOL, tcp_len)
243}
244
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal 20-byte TCP SYN header (port 80 -> 8080) with a zeroed checksum field.
    const SYN_HEADER: [u8; 20] = [
        0x00, 0x50, // Source port: 80
        0x1F, 0x90, // Dest port: 8080
        0x00, 0x00, 0x00, 0x01, // Seq number
        0x00, 0x00, 0x00, 0x00, // Ack number
        0x50, 0x02, // Data offset + flags (SYN)
        0xFF, 0xFF, // Window
        0x00, 0x00, // Checksum (zeroed)
        0x00, 0x00, // Urgent pointer
    ];

    /// Byte offset of the 16-bit checksum field inside the TCP header.
    const CHECKSUM_OFFSET: usize = 16;

    /// Store `checksum` into the segment's checksum field in network byte order.
    fn store_checksum(segment: &mut [u8], checksum: u16) {
        segment[CHECKSUM_OFFSET..CHECKSUM_OFFSET + 2].copy_from_slice(&checksum.to_be_bytes());
    }

    /// IPv4 address pair shared by several tests.
    fn v4_pair() -> (Ipv4Addr, Ipv4Addr) {
        (Ipv4Addr::new(192, 168, 1, 1), Ipv4Addr::new(192, 168, 1, 2))
    }

    /// IPv6 address pair shared by several tests.
    fn v6_pair() -> (Ipv6Addr, Ipv6Addr) {
        (
            Ipv6Addr::new(0x2001, 0xdb8, 0, 0, 0, 0, 0, 1),
            Ipv6Addr::new(0x2001, 0xdb8, 0, 0, 0, 0, 0, 2),
        )
    }

    #[test]
    fn test_tcp_checksum_ipv4() {
        let (src_ip, dst_ip) = v4_pair();

        let checksum = tcp_checksum_ipv4(src_ip, dst_ip, &SYN_HEADER);
        assert_ne!(checksum, 0);

        let mut segment = SYN_HEADER;
        store_checksum(&mut segment, checksum);
        assert!(verify_tcp_checksum(src_ip, dst_ip, &segment));
    }

    #[test]
    fn test_tcp_checksum_ipv4_known_value() {
        // Hand-computed per RFC 1071: the pseudo-header words sum to 0x1836D and
        // the header words to 0x16FE2; 0x2F34F folds to 0xF351, complement 0x0CAE.
        // Guards against a bug that compute and verify would share symmetrically.
        let (src_ip, dst_ip) = v4_pair();
        assert_eq!(tcp_checksum_ipv4(src_ip, dst_ip, &SYN_HEADER), 0x0CAE);
    }

    #[test]
    fn test_tcp_checksum_ipv6() {
        let (src_ip, dst_ip) = v6_pair();

        let checksum = tcp_checksum_ipv6(src_ip, dst_ip, &SYN_HEADER);
        assert_ne!(checksum, 0);

        let mut segment = SYN_HEADER;
        store_checksum(&mut segment, checksum);
        assert!(verify_tcp_checksum_ipv6(src_ip, dst_ip, &segment));
    }

    #[test]
    fn test_tcp_checksum_ipv6_known_value() {
        // Hand-computed per RFC 1071: the pseudo-header words sum to 0x5B8F and
        // the header words to 0x16FE2; 0x1CB71 folds to 0xCB72, complement 0x348D.
        let (src_ip, dst_ip) = v6_pair();
        assert_eq!(tcp_checksum_ipv6(src_ip, dst_ip, &SYN_HEADER), 0x348D);
    }

    #[test]
    fn test_tcp_checksum_generic() {
        let src_ipv4: [u8; 4] = [192, 168, 1, 1];
        let dst_ipv4: [u8; 4] = [192, 168, 1, 2];

        let direct = tcp_checksum_ipv4(
            Ipv4Addr::from(src_ipv4),
            Ipv4Addr::from(dst_ipv4),
            &SYN_HEADER,
        );
        assert_eq!(tcp_checksum(&src_ipv4, &dst_ipv4, &SYN_HEADER), Some(direct));
    }

    #[test]
    fn test_tcp_checksum_generic_ipv6() {
        let (src_ip, dst_ip) = v6_pair();

        let direct = tcp_checksum_ipv6(src_ip, dst_ip, &SYN_HEADER);
        assert_eq!(
            tcp_checksum(&src_ip.octets(), &dst_ip.octets(), &SYN_HEADER),
            Some(direct)
        );
    }

    #[test]
    fn test_tcp_checksum_generic_rejects_bad_lengths() {
        let v4: [u8; 4] = [192, 168, 1, 1];
        let v6 = [0u8; 16];

        // Mixed families are rejected, not silently truncated or padded.
        assert_eq!(tcp_checksum(&v4, &v6, &SYN_HEADER), None);
        assert_eq!(tcp_checksum(&v6, &v4, &SYN_HEADER), None);
        // Lengths that match each other but no address family.
        assert_eq!(tcp_checksum(&[1, 2, 3], &[1, 2, 3], &SYN_HEADER), None);
        assert_eq!(tcp_checksum(&[], &[], &SYN_HEADER), None);
    }

    #[test]
    fn test_pseudo_header() {
        let (src_ip, dst_ip) = v4_pair();
        let tcp_len = 20u16;

        let header = ipv4_pseudo_header(src_ip, dst_ip, tcp_len);

        assert_eq!(&header[0..4], &[192, 168, 1, 1]);
        assert_eq!(&header[4..8], &[192, 168, 1, 2]);
        assert_eq!(header[8], 0);
        assert_eq!(header[9], TCP_PROTOCOL);
        assert_eq!(&header[10..12], &[0, 20]);
    }

    #[test]
    fn test_ipv6_pseudo_header() {
        let (src_ip, dst_ip) = v6_pair();

        let header = ipv6_pseudo_header(src_ip, dst_ip, 20);

        assert_eq!(&header[0..16], &src_ip.octets());
        assert_eq!(&header[16..32], &dst_ip.octets());
        assert_eq!(&header[32..36], &[0, 0, 0, 20]);
        assert_eq!(&header[36..40], &[0, 0, 0, TCP_PROTOCOL]);
    }

    #[test]
    fn test_invalid_checksum() {
        let (src_ip, dst_ip) = v4_pair();

        let mut segment = SYN_HEADER;
        store_checksum(&mut segment, 0xFFFF); // Bad checksum
        assert!(!verify_tcp_checksum(src_ip, dst_ip, &segment));
    }

    #[test]
    fn test_verify_rejects_truncated_header() {
        let (v4_src, v4_dst) = v4_pair();
        let (v6_src, v6_dst) = v6_pair();

        // One byte short of the minimum TCP header, and empty input.
        assert!(!verify_tcp_checksum(v4_src, v4_dst, &SYN_HEADER[..19]));
        assert!(!verify_tcp_checksum_ipv6(v6_src, v6_dst, &SYN_HEADER[..19]));
        assert!(!verify_tcp_checksum(v4_src, v4_dst, &[]));
        assert!(!verify_tcp_checksum_ipv6(v6_src, v6_dst, &[]));
    }

    #[test]
    fn test_checksum_with_payload() {
        let src_ip = Ipv4Addr::new(10, 0, 0, 1);
        let dst_ip = Ipv4Addr::new(10, 0, 0, 2);

        // TCP header + "Hello" payload (odd length exercises the padding path).
        let mut tcp_segment = vec![
            0x00, 0x50, // Source port: 80
            0x00, 0x51, // Dest port: 81
            0x00, 0x00, 0x00, 0x01, // Seq number
            0x00, 0x00, 0x00, 0x01, // Ack number
            0x50, 0x18, // Data offset + flags (PSH+ACK)
            0xFF, 0xFF, // Window
            0x00, 0x00, // Checksum (zeroed)
            0x00, 0x00, // Urgent pointer
        ];
        tcp_segment.extend_from_slice(b"Hello");

        let checksum = tcp_checksum_ipv4(src_ip, dst_ip, &tcp_segment);
        store_checksum(&mut tcp_segment, checksum);

        assert!(verify_tcp_checksum(src_ip, dst_ip, &tcp_segment));
    }
}