1use crate::ip::IpNextLevelProtocol;
4use crate::Packet;
5use crate::PrimitiveValues;
6
7use alloc::{vec, vec::Vec};
8
9use xenet_macro::packet;
10use xenet_macro_helper::types::*;
11
12use crate::util::{self, Octets};
13use std::net::Ipv4Addr;
14use std::net::Ipv6Addr;
15
16#[cfg(feature = "serde")]
17use serde::{Deserialize, Serialize};
18
19pub const TCP_HEADER_LEN: usize = MutableTcpPacket::minimum_packet_size();
21pub const TCP_MIN_DATA_OFFSET: u8 = 5;
23pub const TCP_OPTION_MAX_LEN: usize = 40;
25pub const TCP_HEADER_MAX_LEN: usize = TCP_HEADER_LEN + TCP_OPTION_MAX_LEN;
27
28#[derive(Clone, Debug, PartialEq, Eq)]
30#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
31pub struct TcpOptionHeader {
32 pub kind: TcpOptionKind,
33 pub length: Option<u8>,
34 pub data: Vec<u8>,
35}
36
37impl TcpOptionHeader {
38 pub fn get_timestamp(&self) -> (u32, u32) {
40 if self.kind == TcpOptionKind::TIMESTAMPS && self.data.len() >= 8 {
41 let mut my: [u8; 4] = [0; 4];
42 my.copy_from_slice(&self.data[0..4]);
43 let mut their: [u8; 4] = [0; 4];
44 their.copy_from_slice(&self.data[4..8]);
45 (u32::from_be_bytes(my), u32::from_be_bytes(their))
46 } else {
47 return (0, 0);
48 }
49 }
50 pub fn get_mss(&self) -> u16 {
52 if self.kind == TcpOptionKind::MSS && self.data.len() >= 2 {
53 let mut mss: [u8; 2] = [0; 2];
54 mss.copy_from_slice(&self.data[0..2]);
55 u16::from_be_bytes(mss)
56 } else {
57 0
58 }
59 }
60 pub fn get_wscale(&self) -> u8 {
62 if self.kind == TcpOptionKind::WSCALE && self.data.len() > 0 {
63 self.data[0]
64 } else {
65 0
66 }
67 }
68}
69
70#[derive(Clone, Debug, PartialEq, Eq)]
72#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
73pub struct TcpHeader {
74 pub source: u16be,
75 pub destination: u16be,
76 pub sequence: u32be,
77 pub acknowledgement: u32be,
78 pub data_offset: u4,
79 pub reserved: u4,
80 pub flags: u8,
81 pub window: u16be,
82 pub checksum: u16be,
83 pub urgent_ptr: u16be,
84 pub options: Vec<TcpOptionHeader>,
85}
86
87impl TcpHeader {
88 pub fn from_bytes(packet: &[u8]) -> Result<TcpHeader, String> {
90 if packet.len() < TCP_HEADER_LEN {
91 return Err("Packet is too small for TCP header".to_string());
92 }
93 match TcpPacket::new(packet) {
94 Some(tcp_packet) => Ok(TcpHeader {
95 source: tcp_packet.get_source(),
96 destination: tcp_packet.get_destination(),
97 sequence: tcp_packet.get_sequence(),
98 acknowledgement: tcp_packet.get_acknowledgement(),
99 data_offset: tcp_packet.get_data_offset(),
100 reserved: tcp_packet.get_reserved(),
101 flags: tcp_packet.get_flags(),
102 window: tcp_packet.get_window(),
103 checksum: tcp_packet.get_checksum(),
104 urgent_ptr: tcp_packet.get_urgent_ptr(),
105 options: tcp_packet
106 .get_options_iter()
107 .map(|opt| TcpOptionHeader {
108 kind: opt.get_kind(),
109 length: opt.get_length_raw().first().cloned(),
110 data: opt.payload().to_vec(),
111 })
112 .collect(),
113 }),
114 None => Err("Failed to parse TCP packet".to_string()),
115 }
116 }
117 pub(crate) fn from_packet(tcp_packet: &TcpPacket) -> TcpHeader {
119 TcpHeader {
120 source: tcp_packet.get_source(),
121 destination: tcp_packet.get_destination(),
122 sequence: tcp_packet.get_sequence(),
123 acknowledgement: tcp_packet.get_acknowledgement(),
124 data_offset: tcp_packet.get_data_offset(),
125 reserved: tcp_packet.get_reserved(),
126 flags: tcp_packet.get_flags(),
127 window: tcp_packet.get_window(),
128 checksum: tcp_packet.get_checksum(),
129 urgent_ptr: tcp_packet.get_urgent_ptr(),
130 options: tcp_packet
131 .get_options_iter()
132 .map(|opt| TcpOptionHeader {
133 kind: opt.get_kind(),
134 length: opt.get_length_raw().first().cloned(),
135 data: opt.payload().to_vec(),
136 })
137 .collect(),
138 }
139 }
140}
141
/// Bit masks for the 8-bit TCP control-flags field.
///
/// Combine with `|` and test with `&`, e.g. `TcpFlags::PSH | TcpFlags::ACK`.
#[allow(non_snake_case)]
#[allow(non_upper_case_globals)]
pub mod TcpFlags {
    /// Congestion Window Reduced.
    pub const CWR: u8 = 1 << 7;
    /// ECN-Echo.
    pub const ECE: u8 = 1 << 6;
    /// Urgent pointer field is significant.
    pub const URG: u8 = 1 << 5;
    /// Acknowledgement field is significant.
    pub const ACK: u8 = 1 << 4;
    /// Push function.
    pub const PSH: u8 = 1 << 3;
    /// Reset the connection.
    pub const RST: u8 = 1 << 2;
    /// Synchronize sequence numbers.
    pub const SYN: u8 = 1 << 1;
    /// No more data from sender.
    pub const FIN: u8 = 1 << 0;
}
172
#[packet]
/// Wire layout of a TCP segment: fixed header, options, then payload.
pub struct Tcp {
    // Source port.
    pub source: u16be,
    // Destination port.
    pub destination: u16be,
    // Sequence number.
    pub sequence: u32be,
    // Acknowledgement number.
    pub acknowledgement: u32be,
    // Header length in 32-bit words; drives the size of `options`.
    pub data_offset: u4,
    // Reserved bits.
    pub reserved: u4,
    // Control flags (see TcpFlags).
    pub flags: u8,
    // Receive window.
    pub window: u16be,
    // Checksum (see ipv4_checksum / ipv6_checksum).
    pub checksum: u16be,
    // Urgent pointer.
    pub urgent_ptr: u16be,
    // Options region; its byte length comes from tcp_options_length
    // (data_offset * 4 - 20, or 0 when data_offset <= 5).
    #[length_fn = "tcp_options_length"]
    pub options: Vec<TcpOption>,
    // Segment payload (marked #[payload] for the packet macro).
    #[payload]
    pub payload: Vec<u8>,
}
191
/// Kind byte of a TCP option.
///
/// Only the kinds listed here are representable; see `TcpOptionKind::new`.
#[allow(non_camel_case_types)]
#[repr(u8)]
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum TcpOptionKind {
    /// End of option list.
    EOL = 0,
    /// No-operation (padding).
    NOP = 1,
    /// Maximum segment size.
    MSS = 2,
    /// Window scale.
    WSCALE = 3,
    /// Selective acknowledgement permitted.
    SACK_PERMITTED = 4,
    /// Selective acknowledgement blocks.
    SACK = 5,
    /// Timestamps.
    TIMESTAMPS = 8,
}

impl TcpOptionKind {
    /// Maps a raw kind byte to its `TcpOptionKind`.
    ///
    /// # Panics
    ///
    /// Panics for any byte other than 0, 1, 2, 3, 4, 5 or 8.
    /// NOTE(review): `TcpOption.kind` is declared `#[construct_with(u8)]`, so the
    /// packet parser presumably routes raw option bytes through here. An option
    /// kind not listed above (in received data) would then panic; confirm callers
    /// are protected.
    pub fn new(n: u8) -> TcpOptionKind {
        match n {
            0 => Self::EOL,
            1 => Self::NOP,
            2 => Self::MSS,
            3 => Self::WSCALE,
            4 => Self::SACK_PERMITTED,
            5 => Self::SACK,
            8 => Self::TIMESTAMPS,
            unknown => panic!("Unknown TCP option kind: {}", unknown),
        }
    }

    /// Returns the variant's name, e.g. `"SACK_PERMITTED"`.
    pub fn name(&self) -> String {
        let label = match self {
            Self::EOL => "EOL",
            Self::NOP => "NOP",
            Self::MSS => "MSS",
            Self::WSCALE => "WSCALE",
            Self::SACK_PERMITTED => "SACK_PERMITTED",
            Self::SACK => "SACK",
            Self::TIMESTAMPS => "TIMESTAMPS",
        };
        label.to_string()
    }

    /// Total size in bytes of an option of this kind, counting the kind and
    /// length bytes. `SACK` is variable-length; the 10 reported here matches a
    /// single pair of 32-bit edges.
    pub fn size(&self) -> usize {
        match self {
            Self::EOL | Self::NOP => 1,
            Self::SACK_PERMITTED => 2,
            Self::WSCALE => 3,
            Self::MSS => 4,
            Self::SACK | Self::TIMESTAMPS => 10,
        }
    }
}
247
248impl PrimitiveValues for TcpOptionKind {
249 type T = (u8,);
250 fn to_primitive_values(&self) -> (u8,) {
251 (*self as u8,)
252 }
253}
254
#[packet]
/// Wire layout of a single TCP option: kind, optional length byte, data.
pub struct TcpOption {
    // Option kind byte.
    #[construct_with(u8)]
    kind: TcpOptionKind,
    // Zero or one length byte: tcp_option_length gives 0 for EOL/NOP, 1 otherwise.
    // When present it counts the whole option (kind + length + data).
    #[length_fn = "tcp_option_length"]
    length: Vec<u8>,
    // Option data; tcp_option_payload_length sizes it as length byte - 2.
    #[length_fn = "tcp_option_payload_length"]
    #[payload]
    data: Vec<u8>,
}
268
269impl TcpOption {
270 pub fn nop() -> Self {
272 TcpOption {
273 kind: TcpOptionKind::NOP,
274 length: vec![],
275 data: vec![],
276 }
277 }
278
279 pub fn timestamp(my: u32, their: u32) -> Self {
283 let mut data = vec![];
284 data.extend_from_slice(&my.octets()[..]);
285 data.extend_from_slice(&their.octets()[..]);
286
287 TcpOption {
288 kind: TcpOptionKind::TIMESTAMPS,
289 length: vec![10],
290 data: data,
291 }
292 }
293
294 pub fn mss(val: u16) -> Self {
297 let mut data = vec![];
298 data.extend_from_slice(&val.octets()[..]);
299
300 TcpOption {
301 kind: TcpOptionKind::MSS,
302 length: vec![4],
303 data: data,
304 }
305 }
306
307 pub fn wscale(val: u8) -> Self {
310 TcpOption {
311 kind: TcpOptionKind::WSCALE,
312 length: vec![3],
313 data: vec![val],
314 }
315 }
316
317 pub fn sack_perm() -> Self {
321 TcpOption {
322 kind: TcpOptionKind::SACK_PERMITTED,
323 length: vec![2],
324 data: vec![],
325 }
326 }
327
328 pub fn selective_ack(acks: &[u32]) -> Self {
333 let mut data = vec![];
334 for ack in acks {
335 data.extend_from_slice(&ack.octets()[..]);
336 }
337 TcpOption {
338 kind: TcpOptionKind::SACK,
339 length: vec![1 + 1 + data.len() as u8],
340 data: data,
341 }
342 }
343 pub fn kind(&self) -> TcpOptionKind {
345 self.kind
346 }
347 pub fn length(&self) -> u8 {
349 if self.length.is_empty() {
350 0
351 } else {
352 self.length[0]
353 }
354 }
355 pub fn get_timestamp(&self) -> (u32, u32) {
357 if self.kind == TcpOptionKind::TIMESTAMPS && self.data.len() >= 8 {
358 let mut my: [u8; 4] = [0; 4];
359 my.copy_from_slice(&self.data[0..4]);
360 let mut their: [u8; 4] = [0; 4];
361 their.copy_from_slice(&self.data[4..8]);
362 (u32::from_be_bytes(my), u32::from_be_bytes(their))
363 } else {
364 return (0, 0);
365 }
366 }
367 pub fn get_mss(&self) -> u16 {
369 if self.kind == TcpOptionKind::MSS && self.data.len() >= 2 {
370 let mut mss: [u8; 2] = [0; 2];
371 mss.copy_from_slice(&self.data[0..2]);
372 u16::from_be_bytes(mss)
373 } else {
374 0
375 }
376 }
377 pub fn get_wscale(&self) -> u8 {
379 if self.kind == TcpOptionKind::WSCALE && self.data.len() > 0 {
380 self.data[0]
381 } else {
382 0
383 }
384 }
385}
386
387#[inline]
391fn tcp_option_length(option: &TcpOptionPacket) -> usize {
392 match option.get_kind() {
393 TcpOptionKind::EOL => 0,
394 TcpOptionKind::NOP => 0,
395 _ => 1,
396 }
397}
398
399fn tcp_option_payload_length(ipv4_option: &TcpOptionPacket) -> usize {
400 match ipv4_option.get_length_raw().first() {
401 Some(len) if *len >= 2 => *len as usize - 2,
402 _ => 0,
403 }
404}
405
406#[inline]
407fn tcp_options_length(tcp: &TcpPacket) -> usize {
408 let data_offset = tcp.get_data_offset();
409
410 if data_offset > 5 {
411 data_offset as usize * 4 - 20
412 } else {
413 0
414 }
415}
416
417pub fn ipv4_checksum(packet: &TcpPacket, source: &Ipv4Addr, destination: &Ipv4Addr) -> u16 {
419 ipv4_checksum_adv(packet, &[], source, destination)
420}
421
422pub fn ipv4_checksum_adv(
430 packet: &TcpPacket,
431 extra_data: &[u8],
432 source: &Ipv4Addr,
433 destination: &Ipv4Addr,
434) -> u16 {
435 util::ipv4_checksum(
436 packet.packet(),
437 8,
438 extra_data,
439 source,
440 destination,
441 IpNextLevelProtocol::Tcp,
442 )
443}
444
445pub fn ipv6_checksum(packet: &TcpPacket, source: &Ipv6Addr, destination: &Ipv6Addr) -> u16 {
447 ipv6_checksum_adv(packet, &[], source, destination)
448}
449
450pub fn ipv6_checksum_adv(
458 packet: &TcpPacket,
459 extra_data: &[u8],
460 source: &Ipv6Addr,
461 destination: &Ipv6Addr,
462) -> u16 {
463 util::ipv6_checksum(
464 packet.packet(),
465 8,
466 extra_data,
467 source,
468 destination,
469 IpNextLevelProtocol::Tcp,
470 )
471}
472
/// Builds an IPv4 + TCP packet with NOP/NOP/timestamp options and a 4-byte
/// payload, then checks the TCP bytes (including the checksum) against a
/// reference capture.
#[test]
fn tcp_header_ipv4_test() {
    use crate::ip::IpNextLevelProtocol;
    use crate::ipv4::MutableIpv4Packet;

    const IPV4_HEADER_LEN: usize = 20;
    // Header including 12 option bytes (data offset 8 = 32 bytes). This
    // deliberately shadows the module-level TCP_HEADER_LEN inside the test.
    const TCP_HEADER_LEN: usize = 32;
    const TEST_DATA_LEN: usize = 4;

    let mut packet = [0u8; IPV4_HEADER_LEN + TCP_HEADER_LEN + TEST_DATA_LEN];
    let ipv4_source = Ipv4Addr::new(192, 168, 2, 1);
    let ipv4_destination = Ipv4Addr::new(192, 168, 111, 51);
    {
        let mut ip_header = MutableIpv4Packet::new(&mut packet[..]).unwrap();
        ip_header.set_next_level_protocol(IpNextLevelProtocol::Tcp);
        ip_header.set_source(ipv4_source);
        ip_header.set_destination(ipv4_destination);
    }

    // Payload "test" directly after the TCP header.
    packet[IPV4_HEADER_LEN + TCP_HEADER_LEN] = 't' as u8;
    packet[IPV4_HEADER_LEN + TCP_HEADER_LEN + 1] = 'e' as u8;
    packet[IPV4_HEADER_LEN + TCP_HEADER_LEN + 2] = 's' as u8;
    packet[IPV4_HEADER_LEN + TCP_HEADER_LEN + 3] = 't' as u8;

    {
        let mut tcp_header = MutableTcpPacket::new(&mut packet[IPV4_HEADER_LEN..]).unwrap();
        tcp_header.set_source(49511);
        assert_eq!(tcp_header.get_source(), 49511);

        tcp_header.set_destination(9000);
        assert_eq!(tcp_header.get_destination(), 9000);

        tcp_header.set_sequence(0x9037d2b8);
        assert_eq!(tcp_header.get_sequence(), 0x9037d2b8);

        tcp_header.set_acknowledgement(0x944bb276);
        assert_eq!(tcp_header.get_acknowledgement(), 0x944bb276);

        tcp_header.set_flags(TcpFlags::PSH | TcpFlags::ACK);
        assert_eq!(tcp_header.get_flags(), TcpFlags::PSH | TcpFlags::ACK);

        tcp_header.set_window(4015);
        assert_eq!(tcp_header.get_window(), 4015);

        // Must precede set_options: the options region is sized from
        // data_offset via tcp_options_length.
        tcp_header.set_data_offset(8);
        assert_eq!(tcp_header.get_data_offset(), 8);

        // NOP + NOP (2 bytes) + timestamps (10 bytes) = 12 option bytes.
        let ts = TcpOption::timestamp(743951781, 44056978);
        tcp_header.set_options(&vec![TcpOption::nop(), TcpOption::nop(), ts]);

        // Computed last, after every header field and the payload are in place.
        let checksum = ipv4_checksum(&tcp_header.to_immutable(), &ipv4_source, &ipv4_destination);
        tcp_header.set_checksum(checksum);
        assert_eq!(tcp_header.get_checksum(), 0xc031);
    }
    let ref_packet = [
        0xc1, 0x67, // source port
        0x23, 0x28, // destination port
        0x90, 0x37, 0xd2, 0xb8, // sequence
        0x94, 0x4b, 0xb2, 0x76, // acknowledgement
        0x80, 0x18, 0x0f, 0xaf, // data offset/reserved, flags, window
        0xc0, 0x31, // checksum
        0x00, 0x00, // urgent pointer
        0x01, 0x01, // options: NOP, NOP
        0x08, 0x0a, 0x2c, 0x57, 0xcd, 0xa5, 0x02, 0xa0, 0x41, 0x92, // options: timestamps
        0x74, 0x65, 0x73, 0x74, // payload "test"
    ];
    assert_eq!(&ref_packet[..], &packet[20..]);
}
542
/// `get_options_iter` must not panic when data_offset points past the buffer.
#[test]
fn tcp_test_options_invalid_offset() {
    // Header-only buffer: no room for any options.
    let mut buf = [0; 20];
    {
        // `expect` rather than `if let`: a failed construction must fail the
        // test instead of letting it pass vacuously.
        let mut tcp = MutableTcpPacket::new(&mut buf[..])
            .expect("20 bytes holds a complete TCP header");
        // Claims 20 option bytes that the buffer does not contain.
        tcp.set_data_offset(10);
    }

    let tcp = TcpPacket::new(&buf[..]).expect("20 bytes holds a complete TCP header");
    let _options = tcp.get_options_iter();
}
556
/// `get_options` must not panic when data_offset points past the buffer.
#[test]
fn tcp_test_options_vec_invalid_offset() {
    // Header-only buffer: no room for any options.
    let mut buf = [0; 20];
    {
        // `expect` rather than `if let`: a failed construction must fail the
        // test instead of letting it pass vacuously.
        let mut tcp = MutableTcpPacket::new(&mut buf[..])
            .expect("20 bytes holds a complete TCP header");
        // Claims 20 option bytes that the buffer does not contain.
        tcp.set_data_offset(10);
    }

    let tcp = TcpPacket::new(&buf[..]).expect("20 bytes holds a complete TCP header");
    let _options = tcp.get_options();
}
570
/// `get_options_raw` must not panic when data_offset points past the buffer.
#[test]
fn tcp_test_options_slice_invalid_offset() {
    // Header-only buffer: no room for any options.
    let mut buf = [0; 20];
    {
        // `expect` rather than `if let`: a failed construction must fail the
        // test instead of letting it pass vacuously.
        let mut tcp = MutableTcpPacket::new(&mut buf[..])
            .expect("20 bytes holds a complete TCP header");
        // Claims 20 option bytes that the buffer does not contain.
        tcp.set_data_offset(10);
    }

    let tcp = TcpPacket::new(&buf[..]).expect("20 bytes holds a complete TCP header");
    let _options = tcp.get_options_raw();
}
584
/// Iterating options must not panic when an option's length byte claims more
/// bytes than the options region holds.
#[test]
fn tcp_test_option_invalid_len() {
    use std::println;
    // 20-byte header plus 4 option bytes (data offset 6).
    let mut buf = [0; 24];
    {
        // `expect` rather than `if let`: a failed construction must fail the
        // test instead of letting it pass vacuously.
        let mut tcp = MutableTcpPacket::new(&mut buf[..])
            .expect("24 bytes holds a TCP header plus 4 option bytes");
        tcp.set_data_offset(6);
    }
    buf[20] = 2; // option kind: MSS
    buf[21] = 8; // invalid length: only 4 option bytes exist

    let tcp = TcpPacket::new(&buf[..]).expect("24 bytes holds a TCP header plus 4 option bytes");
    let options = tcp.get_options_iter();
    for opt in options {
        println!("{:?}", opt);
    }
}
604
/// `payload()` must be empty, not panic, when data_offset points past the buffer.
#[test]
fn tcp_test_payload_slice_invalid_offset() {
    // Header-only buffer: no room for options or payload.
    let mut buf = [0; 20];
    {
        // `expect` rather than `if let`: otherwise the assertion below is
        // skipped whenever construction fails.
        let mut tcp = MutableTcpPacket::new(&mut buf[..])
            .expect("20 bytes holds a complete TCP header");
        // Claims 20 option bytes that the buffer does not contain.
        tcp.set_data_offset(10);
    }

    let tcp = TcpPacket::new(&buf[..]).expect("20 bytes holds a complete TCP header");
    assert_eq!(tcp.payload().len(), 0);
}
617}