1use crate::ip::IpNextLevelProtocol;
4use crate::Packet;
5use crate::PrimitiveValues;
6
7use alloc::{vec, vec::Vec};
8
9use xenet_macro::packet;
10use xenet_macro_helper::types::*;
11
12use crate::util::{self, Octets};
13use std::net::Ipv4Addr;
14use std::net::Ipv6Addr;
15
16#[cfg(feature = "serde")]
17use serde::{Deserialize, Serialize};
18
/// Length in bytes of the fixed part of a TCP header, without options.
///
/// This is the minimum size `#[packet]` computes for [`Tcp`]. It covers only the fixed
/// fields (2+2+4+4+½+½+1+2+2+2 = 20 bytes), since options and payload are variable-length.
pub const TCP_HEADER_LEN: usize = MutableTcpPacket::minimum_packet_size();
/// Smallest meaningful data-offset value, in 32-bit words (5 words = the 20-byte fixed header).
pub const TCP_MIN_DATA_OFFSET: u8 = 5;
/// Maximum number of option bytes a header can carry.
///
/// The 4-bit data offset allows at most 15 words = 60 bytes, minus the 20-byte fixed header.
pub const TCP_OPTION_MAX_LEN: usize = 40;
/// Maximum total TCP header length in bytes (fixed header plus maximum options).
pub const TCP_HEADER_MAX_LEN: usize = TCP_HEADER_LEN + TCP_OPTION_MAX_LEN;
27
/// Owned, parsed form of a single TCP option, as collected by [`TcpHeader::from_bytes`].
#[derive(Clone, Debug, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct TcpOptionHeader {
    /// Option kind byte.
    pub kind: TcpOptionKind,
    /// Value of the option's length byte (total option length, including the kind and
    /// length bytes). `None` for single-byte options (EOL/NOP), which have no length byte.
    pub length: Option<u8>,
    /// Option data following the kind and length bytes.
    pub data: Vec<u8>,
}
36
37impl TcpOptionHeader {
38    pub fn get_timestamp(&self) -> (u32, u32) {
40        if self.kind == TcpOptionKind::TIMESTAMPS && self.data.len() >= 8 {
41            let mut my: [u8; 4] = [0; 4];
42            my.copy_from_slice(&self.data[0..4]);
43            let mut their: [u8; 4] = [0; 4];
44            their.copy_from_slice(&self.data[4..8]);
45            (u32::from_be_bytes(my), u32::from_be_bytes(their))
46        } else {
47            return (0, 0);
48        }
49    }
50    pub fn get_mss(&self) -> u16 {
52        if self.kind == TcpOptionKind::MSS && self.data.len() >= 2 {
53            let mut mss: [u8; 2] = [0; 2];
54            mss.copy_from_slice(&self.data[0..2]);
55            u16::from_be_bytes(mss)
56        } else {
57            0
58        }
59    }
60    pub fn get_wscale(&self) -> u8 {
62        if self.kind == TcpOptionKind::WSCALE && self.data.len() > 0 {
63            self.data[0]
64        } else {
65            0
66        }
67    }
68}
69
/// Owned, fully parsed copy of a TCP header, including its options.
///
/// Field order and widths mirror the wire layout of [`Tcp`]; build one with
/// [`TcpHeader::from_bytes`]. The segment payload is not stored.
#[derive(Clone, Debug, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct TcpHeader {
    /// Source port.
    pub source: u16be,
    /// Destination port.
    pub destination: u16be,
    /// Sequence number.
    pub sequence: u32be,
    /// Acknowledgement number.
    pub acknowledgement: u32be,
    /// Header length in 32-bit words (5 means no options).
    pub data_offset: u4,
    /// Reserved bits between the data offset and the flags.
    pub reserved: u4,
    /// Control flags; see the bit masks in [`TcpFlags`].
    pub flags: u8,
    /// Receive window.
    pub window: u16be,
    /// Checksum exactly as read from the packet; `from_bytes` does not verify it.
    pub checksum: u16be,
    /// Urgent pointer.
    pub urgent_ptr: u16be,
    /// Options parsed from the bytes between the fixed header and the payload.
    pub options: Vec<TcpOptionHeader>,
}
86
87impl TcpHeader {
88    pub fn from_bytes(packet: &[u8]) -> Result<TcpHeader, String> {
90        if packet.len() < TCP_HEADER_LEN {
91            return Err("Packet is too small for TCP header".to_string());
92        }
93        match TcpPacket::new(packet) {
94            Some(tcp_packet) => Ok(TcpHeader {
95                source: tcp_packet.get_source(),
96                destination: tcp_packet.get_destination(),
97                sequence: tcp_packet.get_sequence(),
98                acknowledgement: tcp_packet.get_acknowledgement(),
99                data_offset: tcp_packet.get_data_offset(),
100                reserved: tcp_packet.get_reserved(),
101                flags: tcp_packet.get_flags(),
102                window: tcp_packet.get_window(),
103                checksum: tcp_packet.get_checksum(),
104                urgent_ptr: tcp_packet.get_urgent_ptr(),
105                options: tcp_packet
106                    .get_options_iter()
107                    .map(|opt| TcpOptionHeader {
108                        kind: opt.get_kind(),
109                        length: opt.get_length_raw().first().cloned(),
110                        data: opt.payload().to_vec(),
111                    })
112                    .collect(),
113            }),
114            None => Err("Failed to parse TCP packet".to_string()),
115        }
116    }
117    pub(crate) fn from_packet(tcp_packet: &TcpPacket) -> TcpHeader {
119        TcpHeader {
120            source: tcp_packet.get_source(),
121            destination: tcp_packet.get_destination(),
122            sequence: tcp_packet.get_sequence(),
123            acknowledgement: tcp_packet.get_acknowledgement(),
124            data_offset: tcp_packet.get_data_offset(),
125            reserved: tcp_packet.get_reserved(),
126            flags: tcp_packet.get_flags(),
127            window: tcp_packet.get_window(),
128            checksum: tcp_packet.get_checksum(),
129            urgent_ptr: tcp_packet.get_urgent_ptr(),
130            options: tcp_packet
131                .get_options_iter()
132                .map(|opt| TcpOptionHeader {
133                    kind: opt.get_kind(),
134                    length: opt.get_length_raw().first().cloned(),
135                    data: opt.payload().to_vec(),
136                })
137                .collect(),
138        }
139    }
140}
141
/// Bit masks for the individual control flags in the TCP `flags` byte.
///
/// Combine masks with `|` and test them with `&`, e.g. `flags & TcpFlags::SYN != 0`.
// Named like a type so the masks read as an enum-style namespace (`TcpFlags::SYN`).
#[allow(non_snake_case)]
#[allow(non_upper_case_globals)]
pub mod TcpFlags {
    /// Congestion Window Reduced (ECN, RFC 3168).
    pub const CWR: u8 = 1 << 7;
    /// ECN-Echo (RFC 3168).
    pub const ECE: u8 = 1 << 6;
    /// The urgent pointer field is significant.
    pub const URG: u8 = 1 << 5;
    /// The acknowledgement field is significant.
    pub const ACK: u8 = 1 << 4;
    /// Push function.
    pub const PSH: u8 = 1 << 3;
    /// Reset the connection.
    pub const RST: u8 = 1 << 2;
    /// Synchronize sequence numbers.
    pub const SYN: u8 = 1 << 1;
    /// No more data from the sender.
    pub const FIN: u8 = 1 << 0;
}
172
#[packet]
/// Wire layout of a TCP segment, from which `#[packet]` generates `TcpPacket` and
/// `MutableTcpPacket`.
pub struct Tcp {
    // Source port.
    pub source: u16be,
    // Destination port.
    pub destination: u16be,
    pub sequence: u32be,
    pub acknowledgement: u32be,
    // Header length in 32-bit words; 5 means there are no options.
    pub data_offset: u4,
    pub reserved: u4,
    // Control flags; see `TcpFlags`.
    pub flags: u8,
    pub window: u16be,
    pub checksum: u16be,
    pub urgent_ptr: u16be,
    // Options fill the bytes between the 20-byte fixed header and `data_offset * 4`.
    // The byte count comes from `tcp_options_length`.
    #[length_fn = "tcp_options_length"]
    pub options: Vec<TcpOption>,
    // Everything after the options.
    #[payload]
    pub payload: Vec<u8>,
}
191
/// TCP option kinds understood by this module.
///
/// Each discriminant is the on-the-wire "kind" byte, so `kind as u8` yields the byte that
/// appears in the packet.
// Variant names mirror the conventional upper-case option mnemonics.
#[allow(non_camel_case_types)]
#[repr(u8)]
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum TcpOptionKind {
    /// End of option list.
    EOL = 0,
    /// No-operation (padding).
    NOP = 1,
    /// Maximum segment size.
    MSS = 2,
    /// Window scale.
    WSCALE = 3,
    /// Selective acknowledgement permitted.
    SACK_PERMITTED = 4,
    /// Selective acknowledgement blocks.
    SACK = 5,
    /// Timestamps.
    TIMESTAMPS = 8,
}

impl TcpOptionKind {
    /// Converts a raw option-kind byte into a [`TcpOptionKind`].
    ///
    /// # Panics
    ///
    /// Panics if `n` is not one of the kinds listed in this enum (0–5 or 8).
    pub fn new(n: u8) -> TcpOptionKind {
        match n {
            0 => Self::EOL,
            1 => Self::NOP,
            2 => Self::MSS,
            3 => Self::WSCALE,
            4 => Self::SACK_PERMITTED,
            5 => Self::SACK,
            8 => Self::TIMESTAMPS,
            unknown => panic!("Unknown TCP option kind: {}", unknown),
        }
    }

    /// Returns the variant's mnemonic, e.g. `"MSS"`.
    pub fn name(&self) -> String {
        let label = match self {
            Self::EOL => "EOL",
            Self::NOP => "NOP",
            Self::MSS => "MSS",
            Self::WSCALE => "WSCALE",
            Self::SACK_PERMITTED => "SACK_PERMITTED",
            Self::SACK => "SACK",
            Self::TIMESTAMPS => "TIMESTAMPS",
        };
        label.to_string()
    }

    /// Returns the total encoded size in bytes of an option of this kind, including the
    /// kind and length bytes.
    ///
    /// SACK options are variable-length. The value given here (10) corresponds to a
    /// single SACK block.
    pub fn size(&self) -> usize {
        match self {
            Self::EOL | Self::NOP => 1,
            Self::SACK_PERMITTED => 2,
            Self::WSCALE => 3,
            Self::MSS => 4,
            Self::SACK | Self::TIMESTAMPS => 10,
        }
    }
}
247
248impl PrimitiveValues for TcpOptionKind {
249    type T = (u8,);
250    fn to_primitive_values(&self) -> (u8,) {
251        (*self as u8,)
252    }
253}
254
#[packet]
/// A single TCP option as laid out on the wire: one kind byte, an optional length byte,
/// and the option data.
pub struct TcpOption {
    // Option kind byte, decoded from its raw `u8` value.
    #[construct_with(u8)]
    kind: TcpOptionKind,
    // Optional length byte. It is absent for EOL/NOP and one byte otherwise (see
    // `tcp_option_length`); a `Vec` is used to model the optional field.
    #[length_fn = "tcp_option_length"]
    length: Vec<u8>,
    // Option data: `length - 2` bytes (see `tcp_option_payload_length`).
    #[length_fn = "tcp_option_payload_length"]
    #[payload]
    data: Vec<u8>,
}
268
269impl TcpOption {
270    pub fn nop() -> Self {
272        TcpOption {
273            kind: TcpOptionKind::NOP,
274            length: vec![],
275            data: vec![],
276        }
277    }
278
279    pub fn timestamp(my: u32, their: u32) -> Self {
283        let mut data = vec![];
284        data.extend_from_slice(&my.octets()[..]);
285        data.extend_from_slice(&their.octets()[..]);
286
287        TcpOption {
288            kind: TcpOptionKind::TIMESTAMPS,
289            length: vec![10],
290            data: data,
291        }
292    }
293
294    pub fn mss(val: u16) -> Self {
297        let mut data = vec![];
298        data.extend_from_slice(&val.octets()[..]);
299
300        TcpOption {
301            kind: TcpOptionKind::MSS,
302            length: vec![4],
303            data: data,
304        }
305    }
306
307    pub fn wscale(val: u8) -> Self {
310        TcpOption {
311            kind: TcpOptionKind::WSCALE,
312            length: vec![3],
313            data: vec![val],
314        }
315    }
316
317    pub fn sack_perm() -> Self {
321        TcpOption {
322            kind: TcpOptionKind::SACK_PERMITTED,
323            length: vec![2],
324            data: vec![],
325        }
326    }
327
328    pub fn selective_ack(acks: &[u32]) -> Self {
333        let mut data = vec![];
334        for ack in acks {
335            data.extend_from_slice(&ack.octets()[..]);
336        }
337        TcpOption {
338            kind: TcpOptionKind::SACK,
339            length: vec![1 + 1 + data.len() as u8],
340            data: data,
341        }
342    }
343    pub fn kind(&self) -> TcpOptionKind {
345        self.kind
346    }
347    pub fn length(&self) -> u8 {
349        if self.length.is_empty() {
350            0
351        } else {
352            self.length[0]
353        }
354    }
355    pub fn get_timestamp(&self) -> (u32, u32) {
357        if self.kind == TcpOptionKind::TIMESTAMPS && self.data.len() >= 8 {
358            let mut my: [u8; 4] = [0; 4];
359            my.copy_from_slice(&self.data[0..4]);
360            let mut their: [u8; 4] = [0; 4];
361            their.copy_from_slice(&self.data[4..8]);
362            (u32::from_be_bytes(my), u32::from_be_bytes(their))
363        } else {
364            return (0, 0);
365        }
366    }
367    pub fn get_mss(&self) -> u16 {
369        if self.kind == TcpOptionKind::MSS && self.data.len() >= 2 {
370            let mut mss: [u8; 2] = [0; 2];
371            mss.copy_from_slice(&self.data[0..2]);
372            u16::from_be_bytes(mss)
373        } else {
374            0
375        }
376    }
377    pub fn get_wscale(&self) -> u8 {
379        if self.kind == TcpOptionKind::WSCALE && self.data.len() > 0 {
380            self.data[0]
381        } else {
382            0
383        }
384    }
385}
386
387#[inline]
391fn tcp_option_length(option: &TcpOptionPacket) -> usize {
392    match option.get_kind() {
393        TcpOptionKind::EOL => 0,
394        TcpOptionKind::NOP => 0,
395        _ => 1,
396    }
397}
398
399fn tcp_option_payload_length(ipv4_option: &TcpOptionPacket) -> usize {
400    match ipv4_option.get_length_raw().first() {
401        Some(len) if *len >= 2 => *len as usize - 2,
402        _ => 0,
403    }
404}
405
406#[inline]
407fn tcp_options_length(tcp: &TcpPacket) -> usize {
408    let data_offset = tcp.get_data_offset();
409
410    if data_offset > 5 {
411        data_offset as usize * 4 - 20
412    } else {
413        0
414    }
415}
416
417pub fn ipv4_checksum(packet: &TcpPacket, source: &Ipv4Addr, destination: &Ipv4Addr) -> u16 {
419    ipv4_checksum_adv(packet, &[], source, destination)
420}
421
422pub fn ipv4_checksum_adv(
430    packet: &TcpPacket,
431    extra_data: &[u8],
432    source: &Ipv4Addr,
433    destination: &Ipv4Addr,
434) -> u16 {
435    util::ipv4_checksum(
436        packet.packet(),
437        8,
438        extra_data,
439        source,
440        destination,
441        IpNextLevelProtocol::Tcp,
442    )
443}
444
445pub fn ipv6_checksum(packet: &TcpPacket, source: &Ipv6Addr, destination: &Ipv6Addr) -> u16 {
447    ipv6_checksum_adv(packet, &[], source, destination)
448}
449
450pub fn ipv6_checksum_adv(
458    packet: &TcpPacket,
459    extra_data: &[u8],
460    source: &Ipv6Addr,
461    destination: &Ipv6Addr,
462) -> u16 {
463    util::ipv6_checksum(
464        packet.packet(),
465        8,
466        extra_data,
467        source,
468        destination,
469        IpNextLevelProtocol::Tcp,
470    )
471}
472
/// Builds a TCP segment with options and a payload inside an IPv4 packet. Checks the
/// computed checksum and the serialized bytes against a known-good reference.
#[test]
fn tcp_header_ipv4_test() {
    use crate::ipv4::MutableIpv4Packet;

    const IPV4_HEADER_LEN: usize = 20;
    // Fixed 20-byte TCP header plus 12 option bytes (NOP, NOP, 10-byte timestamp), i.e.
    // a data offset of 8 words. Named so it does not shadow the module-level
    // `TCP_HEADER_LEN`, which is the 20-byte fixed header.
    const TCP_HEADER_WITH_OPTIONS_LEN: usize = 32;
    const TEST_DATA_LEN: usize = 4;

    let payload_start = IPV4_HEADER_LEN + TCP_HEADER_WITH_OPTIONS_LEN;
    let mut packet = [0u8; IPV4_HEADER_LEN + TCP_HEADER_WITH_OPTIONS_LEN + TEST_DATA_LEN];
    let ipv4_source = Ipv4Addr::new(192, 168, 2, 1);
    let ipv4_destination = Ipv4Addr::new(192, 168, 111, 51);
    {
        let mut ip_header = MutableIpv4Packet::new(&mut packet[..]).unwrap();
        ip_header.set_next_level_protocol(IpNextLevelProtocol::Tcp);
        ip_header.set_source(ipv4_source);
        ip_header.set_destination(ipv4_destination);
    }

    packet[payload_start..].copy_from_slice(b"test");

    {
        let mut tcp_header = MutableTcpPacket::new(&mut packet[IPV4_HEADER_LEN..]).unwrap();
        tcp_header.set_source(49511);
        assert_eq!(tcp_header.get_source(), 49511);

        tcp_header.set_destination(9000);
        assert_eq!(tcp_header.get_destination(), 9000);

        tcp_header.set_sequence(0x9037d2b8);
        assert_eq!(tcp_header.get_sequence(), 0x9037d2b8);

        tcp_header.set_acknowledgement(0x944bb276);
        assert_eq!(tcp_header.get_acknowledgement(), 0x944bb276);

        tcp_header.set_flags(TcpFlags::PSH | TcpFlags::ACK);
        assert_eq!(tcp_header.get_flags(), TcpFlags::PSH | TcpFlags::ACK);

        tcp_header.set_window(4015);
        assert_eq!(tcp_header.get_window(), 4015);

        tcp_header.set_data_offset(8);
        assert_eq!(tcp_header.get_data_offset(), 8);

        let ts = TcpOption::timestamp(743951781, 44056978);
        tcp_header.set_options(&vec![TcpOption::nop(), TcpOption::nop(), ts]);

        let checksum = ipv4_checksum(&tcp_header.to_immutable(), &ipv4_source, &ipv4_destination);
        tcp_header.set_checksum(checksum);
        assert_eq!(tcp_header.get_checksum(), 0xc031);
    }

    let ref_packet = [
        0xc1, 0x67, // source port 49511
        0x23, 0x28, // destination port 9000
        0x90, 0x37, 0xd2, 0xb8, // sequence number
        0x94, 0x4b, 0xb2, 0x76, // acknowledgement number
        0x80, // data offset 8, reserved 0
        0x18, // flags: PSH | ACK
        0x0f, 0xaf, // window 4015
        0xc0, 0x31, // checksum
        0x00, 0x00, // urgent pointer
        0x01, // NOP
        0x01, // NOP
        0x08, 0x0a, // timestamp option: kind 8, length 10
        0x2c, 0x57, 0xcd, 0xa5, // timestamp value 743951781
        0x02, 0xa0, 0x41, 0x92, // timestamp echo 44056978
        0x74, 0x65, 0x73, 0x74, // payload "test"
    ];
    assert_eq!(&ref_packet[..], &packet[IPV4_HEADER_LEN..]);
}
542
/// A data offset that points past the end of the buffer must not panic when iterating
/// options, and must yield no options because no option bytes exist.
#[test]
fn tcp_test_options_invalid_offset() {
    // Only the 20-byte fixed header: there is no room for options.
    let mut buf = [0u8; 20];
    {
        let mut tcp = MutableTcpPacket::new(&mut buf[..])
            .expect("20 bytes is enough for the fixed TCP header");
        // Claim a 40-byte header (offset 10) that the buffer cannot hold.
        tcp.set_data_offset(10);
    }

    let tcp = TcpPacket::new(&buf[..]).expect("20 bytes is enough for the fixed TCP header");
    assert_eq!(tcp.get_options_iter().count(), 0);
}
556
/// A data offset that points past the end of the buffer must not panic when collecting
/// options into a `Vec`, and must produce an empty list.
#[test]
fn tcp_test_options_vec_invalid_offset() {
    // Only the 20-byte fixed header: there is no room for options.
    let mut buf = [0u8; 20];
    {
        let mut tcp = MutableTcpPacket::new(&mut buf[..])
            .expect("20 bytes is enough for the fixed TCP header");
        // Claim a 40-byte header (offset 10) that the buffer cannot hold.
        tcp.set_data_offset(10);
    }

    let tcp = TcpPacket::new(&buf[..]).expect("20 bytes is enough for the fixed TCP header");
    assert!(tcp.get_options().is_empty());
}
570
/// A data offset that points past the end of the buffer must not panic when taking the
/// raw options slice, and the slice must be empty.
#[test]
fn tcp_test_options_slice_invalid_offset() {
    // Only the 20-byte fixed header: there is no room for options.
    let mut buf = [0u8; 20];
    {
        let mut tcp = MutableTcpPacket::new(&mut buf[..])
            .expect("20 bytes is enough for the fixed TCP header");
        // Claim a 40-byte header (offset 10) that the buffer cannot hold.
        tcp.set_data_offset(10);
    }

    let tcp = TcpPacket::new(&buf[..]).expect("20 bytes is enough for the fixed TCP header");
    assert!(tcp.get_options_raw().is_empty());
}
584
/// An option whose length byte claims more bytes than remain must not panic the
/// options iterator.
#[test]
fn tcp_test_option_invalid_len() {
    use std::println;
    // 20-byte fixed header plus 4 bytes of option space (data offset 6).
    let mut buf = [0; 24];
    if let Some(mut tcp) = MutableTcpPacket::new(&mut buf[..]) {
        tcp.set_data_offset(6);
    }
    // MSS option (kind 2) whose length byte claims 8 bytes, but only 4 are available.
    buf[20] = 2;
    buf[21] = 8;

    if let Some(tcp) = TcpPacket::new(&buf[..]) {
        tcp.get_options_iter().for_each(|opt| println!("{:?}", opt));
    }
}
604
/// A data offset that points past the end of the buffer must yield an empty payload
/// rather than panicking.
#[test]
fn tcp_test_payload_slice_invalid_offset() {
    // Only the 20-byte fixed header is present.
    let mut buf = [0; 20];
    if let Some(mut tcp) = MutableTcpPacket::new(&mut buf[..]) {
        // Offset 10 claims a 40-byte header, beyond the end of the buffer.
        tcp.set_data_offset(10);
    }

    if let Some(tcp) = TcpPacket::new(&buf[..]) {
        assert!(tcp.payload().is_empty());
    }
}
617}