Skip to main content

stackforge_core/layer/tcp/
checksum.rs

1//! TCP checksum calculation.
2//!
3//! TCP uses the Internet checksum (RFC 1071) computed over a pseudo-header
4//! followed by the TCP header and payload.
5//!
6//! # IPv4 Pseudo-header
7//!
8//! ```text
9//! +--------+--------+--------+--------+
10//! |           Source Address          |
11//! +--------+--------+--------+--------+
12//! |         Destination Address       |
13//! +--------+--------+--------+--------+
14//! |  zero  |  PTCL  |    TCP Length   |
15//! +--------+--------+--------+--------+
16//! ```
17//!
18//! # IPv6 Pseudo-header
19//!
20//! ```text
21//! +--------+--------+--------+--------+
22//! |                                   |
23//! +                                   +
24//! |           Source Address          |
25//! +                                   +
26//! |                                   |
27//! +                                   +
28//! |                                   |
29//! +--------+--------+--------+--------+
30//! |                                   |
31//! +                                   +
32//! |         Destination Address       |
33//! +                                   +
34//! |                                   |
35//! +                                   +
36//! |                                   |
37//! +--------+--------+--------+--------+
38//! |         Upper-Layer Packet Length |
39//! +--------+--------+--------+--------+
40//! |           zero          |Next Hdr|
41//! +--------+--------+--------+--------+
42//! ```
43
44use std::net::{Ipv4Addr, Ipv6Addr};
45
46use crate::layer::ipv4::checksum::{pseudo_header_checksum, transport_checksum};
47use crate::utils::{finalize_checksum, partial_checksum};
48
/// Protocol number identifying TCP in the pseudo-header.
///
/// Written into the protocol byte of the IPv4 pseudo-header and the
/// Next Header byte of the IPv6 pseudo-header.
pub const TCP_PROTOCOL: u8 = 0x06;
51
52/// Compute TCP checksum with IPv4 pseudo-header.
53///
54/// # Arguments
55///
56/// * `src_ip` - Source IPv4 address
57/// * `dst_ip` - Destination IPv4 address
58/// * `tcp_data` - Complete TCP segment (header + payload)
59///
60/// # Returns
61///
62/// The computed checksum value.
63#[must_use]
64pub fn tcp_checksum_ipv4(src_ip: Ipv4Addr, dst_ip: Ipv4Addr, tcp_data: &[u8]) -> u16 {
65    transport_checksum(&src_ip.octets(), &dst_ip.octets(), TCP_PROTOCOL, tcp_data)
66}
67
68/// Compute TCP checksum with IPv6 pseudo-header.
69///
70/// # Arguments
71///
72/// * `src_ip` - Source IPv6 address
73/// * `dst_ip` - Destination IPv6 address
74/// * `tcp_data` - Complete TCP segment (header + payload)
75///
76/// # Returns
77///
78/// The computed checksum value.
79#[must_use]
80pub fn tcp_checksum_ipv6(src_ip: Ipv6Addr, dst_ip: Ipv6Addr, tcp_data: &[u8]) -> u16 {
81    let tcp_len = tcp_data.len() as u32;
82
83    // IPv6 pseudo-header checksum
84    let mut sum: u32 = 0;
85
86    // Source address (16 bytes)
87    let src_octets = src_ip.octets();
88    for chunk in src_octets.chunks(2) {
89        sum += u32::from(u16::from_be_bytes([chunk[0], chunk[1]]));
90    }
91
92    // Destination address (16 bytes)
93    let dst_octets = dst_ip.octets();
94    for chunk in dst_octets.chunks(2) {
95        sum += u32::from(u16::from_be_bytes([chunk[0], chunk[1]]));
96    }
97
98    // Upper-layer packet length (4 bytes)
99    sum += (tcp_len >> 16) & 0xFFFF;
100    sum += tcp_len & 0xFFFF;
101
102    // Zero + Next Header (4 bytes, but only last byte is non-zero)
103    sum += u32::from(TCP_PROTOCOL);
104
105    // Add TCP data
106    sum = partial_checksum(tcp_data, sum);
107
108    // Finalize
109    finalize_checksum(sum)
110}
111
112/// Compute TCP checksum (generic version).
113///
114/// Automatically handles IPv4 or IPv6 based on address type.
115/// Uses raw byte slices for maximum flexibility.
116///
117/// # Arguments
118///
119/// * `src_ip` - Source IP address bytes (4 for IPv4, 16 for IPv6)
120/// * `dst_ip` - Destination IP address bytes
121/// * `tcp_data` - Complete TCP segment (header + payload)
122///
123/// # Returns
124///
125/// The computed checksum value, or None if address lengths are invalid.
126#[must_use]
127pub fn tcp_checksum(src_ip: &[u8], dst_ip: &[u8], tcp_data: &[u8]) -> Option<u16> {
128    match (src_ip.len(), dst_ip.len()) {
129        (4, 4) => {
130            let src: [u8; 4] = src_ip.try_into().ok()?;
131            let dst: [u8; 4] = dst_ip.try_into().ok()?;
132            Some(transport_checksum(&src, &dst, TCP_PROTOCOL, tcp_data))
133        },
134        (16, 16) => {
135            let src: [u8; 16] = src_ip.try_into().ok()?;
136            let dst: [u8; 16] = dst_ip.try_into().ok()?;
137            Some(tcp_checksum_ipv6(
138                Ipv6Addr::from(src),
139                Ipv6Addr::from(dst),
140                tcp_data,
141            ))
142        },
143        _ => None,
144    }
145}
146
147/// Verify TCP checksum with IPv4 pseudo-header.
148///
149/// # Arguments
150///
151/// * `src_ip` - Source IPv4 address
152/// * `dst_ip` - Destination IPv4 address
153/// * `tcp_data` - Complete TCP segment (header + payload) with checksum
154///
155/// # Returns
156///
157/// `true` if the checksum is valid.
158#[must_use]
159pub fn verify_tcp_checksum(src_ip: Ipv4Addr, dst_ip: Ipv4Addr, tcp_data: &[u8]) -> bool {
160    if tcp_data.len() < 20 {
161        return false;
162    }
163
164    let checksum = tcp_checksum_ipv4(src_ip, dst_ip, tcp_data);
165    checksum == 0 || checksum == 0xFFFF
166}
167
168/// Verify TCP checksum with IPv6 pseudo-header.
169#[must_use]
170pub fn verify_tcp_checksum_ipv6(src_ip: Ipv6Addr, dst_ip: Ipv6Addr, tcp_data: &[u8]) -> bool {
171    if tcp_data.len() < 20 {
172        return false;
173    }
174
175    let checksum = tcp_checksum_ipv6(src_ip, dst_ip, tcp_data);
176    checksum == 0 || checksum == 0xFFFF
177}
178
179/// Build the IPv4 pseudo-header bytes for TCP.
180///
181/// # Arguments
182///
183/// * `src_ip` - Source IPv4 address
184/// * `dst_ip` - Destination IPv4 address
185/// * `tcp_len` - Length of TCP segment (header + payload)
186///
187/// # Returns
188///
189/// 12-byte pseudo-header.
190#[must_use]
191pub fn ipv4_pseudo_header(src_ip: Ipv4Addr, dst_ip: Ipv4Addr, tcp_len: u16) -> [u8; 12] {
192    let mut header = [0u8; 12];
193
194    // Source IP (4 bytes)
195    header[0..4].copy_from_slice(&src_ip.octets());
196
197    // Destination IP (4 bytes)
198    header[4..8].copy_from_slice(&dst_ip.octets());
199
200    // Zero (1 byte)
201    header[8] = 0;
202
203    // Protocol (1 byte)
204    header[9] = TCP_PROTOCOL;
205
206    // TCP Length (2 bytes)
207    header[10..12].copy_from_slice(&tcp_len.to_be_bytes());
208
209    header
210}
211
212/// Build the IPv6 pseudo-header bytes for TCP.
213///
214/// # Arguments
215///
216/// * `src_ip` - Source IPv6 address
217/// * `dst_ip` - Destination IPv6 address
218/// * `tcp_len` - Length of TCP segment (header + payload)
219///
220/// # Returns
221///
222/// 40-byte pseudo-header.
223#[must_use]
224pub fn ipv6_pseudo_header(src_ip: Ipv6Addr, dst_ip: Ipv6Addr, tcp_len: u32) -> [u8; 40] {
225    let mut header = [0u8; 40];
226
227    // Source IP (16 bytes)
228    header[0..16].copy_from_slice(&src_ip.octets());
229
230    // Destination IP (16 bytes)
231    header[16..32].copy_from_slice(&dst_ip.octets());
232
233    // Upper-Layer Packet Length (4 bytes)
234    header[32..36].copy_from_slice(&tcp_len.to_be_bytes());
235
236    // Zero (3 bytes) + Next Header (1 byte)
237    header[36] = 0;
238    header[37] = 0;
239    header[38] = 0;
240    header[39] = TCP_PROTOCOL;
241
242    header
243}
244
245/// Compute checksum for partial data (for segmented computation).
246///
247/// Useful when computing checksum across multiple buffers.
248#[must_use]
249pub fn tcp_partial_checksum(src_ip: &[u8; 4], dst_ip: &[u8; 4], tcp_len: u16) -> u32 {
250    pseudo_header_checksum(src_ip, dst_ip, TCP_PROTOCOL, tcp_len)
251}
252
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal 20-byte SYN header (port 80 -> 8080) with a zeroed checksum field.
    const SYN_HEADER: [u8; 20] = [
        0x00, 0x50, // Source port: 80
        0x1F, 0x90, // Dest port: 8080
        0x00, 0x00, 0x00, 0x01, // Seq number
        0x00, 0x00, 0x00, 0x00, // Ack number
        0x50, 0x02, // Data offset + flags (SYN)
        0xFF, 0xFF, // Window
        0x00, 0x00, // Checksum (zeroed)
        0x00, 0x00, // Urgent pointer
    ];

    /// Independent RFC 1071 one's-complement checksum over `bytes`, used as an
    /// oracle for the production implementations. An odd trailing byte is
    /// treated as the high byte of a zero-padded 16-bit word.
    fn reference_checksum(bytes: &[u8]) -> u16 {
        let mut sum: u32 = bytes
            .chunks(2)
            .map(|pair| (u32::from(pair[0]) << 8) | pair.get(1).map_or(0, |&b| u32::from(b)))
            .sum();
        while sum > 0xFFFF {
            sum = (sum & 0xFFFF) + (sum >> 16);
        }
        // The loop above guarantees `sum` fits in 16 bits.
        !(sum as u16)
    }

    /// Write `checksum` big-endian into the TCP checksum field (bytes 16..18).
    fn insert_checksum(segment: &mut [u8], checksum: u16) {
        segment[16..18].copy_from_slice(&checksum.to_be_bytes());
    }

    #[test]
    fn test_tcp_checksum_ipv4() {
        let src_ip = Ipv4Addr::new(192, 168, 1, 1);
        let dst_ip = Ipv4Addr::new(192, 168, 1, 2);

        let checksum = tcp_checksum_ipv4(src_ip, dst_ip, &SYN_HEADER);
        assert_ne!(checksum, 0);

        // Insert checksum and verify
        let mut tcp_with_checksum = SYN_HEADER;
        insert_checksum(&mut tcp_with_checksum, checksum);

        assert!(verify_tcp_checksum(src_ip, dst_ip, &tcp_with_checksum));
    }

    #[test]
    fn test_tcp_checksum_ipv4_known_value() {
        // Hand-computed: pseudo-header + SYN_HEADER words sum to 0x2F34F,
        // which folds to 0xF351; its complement is 0x0CAE.
        let checksum = tcp_checksum_ipv4(
            Ipv4Addr::new(192, 168, 1, 1),
            Ipv4Addr::new(192, 168, 1, 2),
            &SYN_HEADER,
        );
        assert_eq!(checksum, 0x0CAE);
    }

    #[test]
    fn test_tcp_checksum_ipv6() {
        let src_ip = Ipv6Addr::new(0x2001, 0xdb8, 0, 0, 0, 0, 0, 1);
        let dst_ip = Ipv6Addr::new(0x2001, 0xdb8, 0, 0, 0, 0, 0, 2);

        let checksum = tcp_checksum_ipv6(src_ip, dst_ip, &SYN_HEADER);
        assert_ne!(checksum, 0);

        // Insert checksum and verify
        let mut tcp_with_checksum = SYN_HEADER;
        insert_checksum(&mut tcp_with_checksum, checksum);

        assert!(verify_tcp_checksum_ipv6(src_ip, dst_ip, &tcp_with_checksum));
    }

    #[test]
    fn test_tcp_checksum_ipv6_known_value() {
        // Hand-computed: pseudo-header + SYN_HEADER words sum to 0x1CB71,
        // which folds to 0xCB72; its complement is 0x348D.
        let checksum = tcp_checksum_ipv6(
            Ipv6Addr::new(0x2001, 0xdb8, 0, 0, 0, 0, 0, 1),
            Ipv6Addr::new(0x2001, 0xdb8, 0, 0, 0, 0, 0, 2),
            &SYN_HEADER,
        );
        assert_eq!(checksum, 0x348D);
    }

    #[test]
    fn test_checksums_match_reference_over_pseudo_headers() {
        let mut segment = SYN_HEADER.to_vec();
        // Odd-length payload exercises the trailing-byte padding rule.
        segment.extend_from_slice(b"Hello");

        let v4_src = Ipv4Addr::new(10, 0, 0, 1);
        let v4_dst = Ipv4Addr::new(10, 0, 0, 2);
        let v4_len = u16::try_from(segment.len()).expect("test segment fits in u16");
        let mut v4_bytes = ipv4_pseudo_header(v4_src, v4_dst, v4_len).to_vec();
        v4_bytes.extend_from_slice(&segment);
        assert_eq!(
            tcp_checksum_ipv4(v4_src, v4_dst, &segment),
            reference_checksum(&v4_bytes)
        );

        let v6_src = Ipv6Addr::new(0x2001, 0xdb8, 0, 0, 0, 0, 0, 1);
        let v6_dst = Ipv6Addr::new(0x2001, 0xdb8, 0, 0, 0, 0, 0, 2);
        let v6_len = u32::try_from(segment.len()).expect("test segment fits in u32");
        let mut v6_bytes = ipv6_pseudo_header(v6_src, v6_dst, v6_len).to_vec();
        v6_bytes.extend_from_slice(&segment);
        assert_eq!(
            tcp_checksum_ipv6(v6_src, v6_dst, &segment),
            reference_checksum(&v6_bytes)
        );
    }

    #[test]
    fn test_tcp_checksum_generic() {
        let src_ipv4 = [192, 168, 1, 1];
        let dst_ipv4 = [192, 168, 1, 2];

        let checksum = tcp_checksum(&src_ipv4, &dst_ipv4, &SYN_HEADER);
        assert!(checksum.is_some());

        let checksum_direct = tcp_checksum_ipv4(
            Ipv4Addr::from(src_ipv4),
            Ipv4Addr::from(dst_ipv4),
            &SYN_HEADER,
        );
        assert_eq!(checksum.unwrap(), checksum_direct);
    }

    #[test]
    fn test_tcp_checksum_generic_ipv6() {
        let src = Ipv6Addr::new(0x2001, 0xdb8, 0, 0, 0, 0, 0, 1);
        let dst = Ipv6Addr::new(0x2001, 0xdb8, 0, 0, 0, 0, 0, 2);
        assert_eq!(
            tcp_checksum(&src.octets(), &dst.octets(), &SYN_HEADER),
            Some(tcp_checksum_ipv6(src, dst, &SYN_HEADER)),
        );
    }

    #[test]
    fn test_tcp_checksum_generic_rejects_bad_lengths() {
        let v4: [u8; 4] = [192, 168, 1, 1];
        let v6 = [0u8; 16];
        assert_eq!(tcp_checksum(&v4, &v6, &SYN_HEADER), None);
        assert_eq!(tcp_checksum(&v6, &v4, &SYN_HEADER), None);
        assert_eq!(tcp_checksum(&[], &[], &SYN_HEADER), None);
        assert_eq!(tcp_checksum(&[0; 5], &[0; 5], &SYN_HEADER), None);
    }

    #[test]
    fn test_pseudo_header() {
        let src_ip = Ipv4Addr::new(192, 168, 1, 1);
        let dst_ip = Ipv4Addr::new(192, 168, 1, 2);
        let tcp_len = 20u16;

        let header = ipv4_pseudo_header(src_ip, dst_ip, tcp_len);

        assert_eq!(&header[0..4], &[192, 168, 1, 1]);
        assert_eq!(&header[4..8], &[192, 168, 1, 2]);
        assert_eq!(header[8], 0);
        assert_eq!(header[9], TCP_PROTOCOL);
        assert_eq!(&header[10..12], &[0, 20]);
    }

    #[test]
    fn test_ipv6_pseudo_header() {
        let src_ip = Ipv6Addr::new(0x2001, 0xdb8, 0, 0, 0, 0, 0, 1);
        let dst_ip = Ipv6Addr::new(0x2001, 0xdb8, 0, 0, 0, 0, 0, 2);

        let header = ipv6_pseudo_header(src_ip, dst_ip, 0x0001_0014);

        assert_eq!(&header[0..16], &src_ip.octets());
        assert_eq!(&header[16..32], &dst_ip.octets());
        assert_eq!(&header[32..36], &[0x00, 0x01, 0x00, 0x14]);
        assert_eq!(&header[36..40], &[0, 0, 0, TCP_PROTOCOL]);
    }

    #[test]
    fn test_invalid_checksum() {
        let src_ip = Ipv4Addr::new(192, 168, 1, 1);
        let dst_ip = Ipv4Addr::new(192, 168, 1, 2);

        // TCP header with bad checksum
        let tcp_header = [
            0x00, 0x50, 0x1F, 0x90, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x50, 0x02,
            0xFF, 0xFF, 0xFF, 0xFF, // Bad checksum
            0x00, 0x00,
        ];

        assert!(!verify_tcp_checksum(src_ip, dst_ip, &tcp_header));
    }

    #[test]
    fn test_verify_rejects_short_segment() {
        let short = &SYN_HEADER[..19];
        assert!(!verify_tcp_checksum(Ipv4Addr::LOCALHOST, Ipv4Addr::LOCALHOST, short));
        assert!(!verify_tcp_checksum_ipv6(Ipv6Addr::LOCALHOST, Ipv6Addr::LOCALHOST, short));
    }

    #[test]
    fn test_checksum_with_payload() {
        let src_ip = Ipv4Addr::new(10, 0, 0, 1);
        let dst_ip = Ipv4Addr::new(10, 0, 0, 2);

        // TCP header + "Hello" payload
        let mut tcp_segment = vec![
            0x00, 0x50, // Source port: 80
            0x00, 0x51, // Dest port: 81
            0x00, 0x00, 0x00, 0x01, // Seq number
            0x00, 0x00, 0x00, 0x01, // Ack number
            0x50, 0x18, // Data offset + flags (PSH+ACK)
            0xFF, 0xFF, // Window
            0x00, 0x00, // Checksum (zeroed)
            0x00, 0x00, // Urgent pointer
        ];
        tcp_segment.extend_from_slice(b"Hello");

        let checksum = tcp_checksum_ipv4(src_ip, dst_ip, &tcp_segment);
        insert_checksum(&mut tcp_segment, checksum);

        assert!(verify_tcp_checksum(src_ip, dst_ip, &tcp_segment));
    }
}