1use crate::TcpLayer;
2
3use super::config::FlowConfig;
4use super::error::FlowError;
5use super::key::FlowDirection;
6use super::tcp_reassembly::TcpReassembler;
7
/// Coarse TCP connection state tracked for one conversation.
///
/// Variants mirror the classic TCP state-machine states; [`TcpConnectionState::name`]
/// (and the `Display` impl) render them in the conventional upper-case form.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum TcpConnectionState {
    /// `LISTEN` — the initial state used by a freshly created conversation.
    Listen,
    /// `SYN_SENT`.
    SynSent,
    /// `SYN_RCVD`.
    SynRcvd,
    /// `ESTABLISHED`.
    Established,
    /// `FIN_WAIT_1`.
    FinWait1,
    /// `FIN_WAIT_2`.
    FinWait2,
    /// `CLOSE_WAIT`.
    CloseWait,
    /// `CLOSING`.
    Closing,
    /// `LAST_ACK`.
    LastAck,
    /// `TIME_WAIT`.
    TimeWait,
    /// `CLOSED`.
    Closed,
}

impl TcpConnectionState {
    /// Returns the conventional upper-case name of the state, e.g. `"FIN_WAIT_1"`.
    #[must_use]
    pub fn name(&self) -> &'static str {
        match *self {
            // Connection setup.
            Self::Listen => "LISTEN",
            Self::SynSent => "SYN_SENT",
            Self::SynRcvd => "SYN_RCVD",
            // Data transfer.
            Self::Established => "ESTABLISHED",
            // Connection teardown.
            Self::FinWait1 => "FIN_WAIT_1",
            Self::FinWait2 => "FIN_WAIT_2",
            Self::CloseWait => "CLOSE_WAIT",
            Self::Closing => "CLOSING",
            Self::LastAck => "LAST_ACK",
            Self::TimeWait => "TIME_WAIT",
            Self::Closed => "CLOSED",
        }
    }

    /// Returns `true` for `TIME_WAIT` and `CLOSED`, `false` for every other state.
    #[must_use]
    pub fn is_closed(&self) -> bool {
        // Listed exhaustively so adding a variant forces a decision here.
        match *self {
            Self::TimeWait | Self::Closed => true,
            Self::Listen
            | Self::SynSent
            | Self::SynRcvd
            | Self::Established
            | Self::FinWait1
            | Self::FinWait2
            | Self::CloseWait
            | Self::Closing
            | Self::LastAck => false,
        }
    }

    /// Returns `true` for the pre-`ESTABLISHED` states `LISTEN`, `SYN_SENT` and
    /// `SYN_RCVD`, `false` otherwise.
    #[must_use]
    pub fn is_half_open(&self) -> bool {
        match *self {
            Self::Listen | Self::SynSent | Self::SynRcvd => true,
            Self::Established
            | Self::FinWait1
            | Self::FinWait2
            | Self::CloseWait
            | Self::Closing
            | Self::LastAck
            | Self::TimeWait
            | Self::Closed => false,
        }
    }
}

impl std::fmt::Display for TcpConnectionState {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Writes the bare name; width/fill flags from the caller are not applied.
        write!(f, "{}", self.name())
    }
}
61
/// Sequence/acknowledgement bookkeeping for one side (one sending direction)
/// of a TCP conversation.
#[derive(Debug, Clone, Default)]
pub struct TcpEndpointState {
    /// Sequence number following the last segment recorded from this side
    /// (SYN and FIN each count as one sequence number).
    pub next_expected_seq: u32,
    /// Most recent acknowledgement number recorded from this side.
    pub last_ack: u32,
    /// Most recent window value read from this side's TCP header; no window
    /// scaling is applied to it.
    pub window_size: u16,
    /// Initial sequence number from this side's SYN, once one has been seen.
    pub initial_seq: Option<u32>,
}

impl TcpEndpointState {
    /// Creates an endpoint with every counter at zero and no ISN recorded.
    ///
    /// Equivalent to [`TcpEndpointState::default`].
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }
}
92
93#[derive(Debug)]
96pub struct TcpConversationState {
97 pub conn_state: TcpConnectionState,
99 pub forward_endpoint: TcpEndpointState,
101 pub reverse_endpoint: TcpEndpointState,
103 pub reassembler_fwd: TcpReassembler,
105 pub reassembler_rev: TcpReassembler,
107 pub dropped_segments_fwd: u64,
109 pub dropped_segments_rev: u64,
111}
112
113impl TcpConversationState {
114 #[must_use]
115 pub fn new() -> Self {
116 Self {
117 conn_state: TcpConnectionState::Listen,
118 forward_endpoint: TcpEndpointState::new(),
119 reverse_endpoint: TcpEndpointState::new(),
120 reassembler_fwd: TcpReassembler::new(),
121 reassembler_rev: TcpReassembler::new(),
122 dropped_segments_fwd: 0,
123 dropped_segments_rev: 0,
124 }
125 }
126
127 #[must_use]
129 pub fn total_dropped_segments(&self) -> u64 {
130 self.dropped_segments_fwd + self.dropped_segments_rev
131 }
132
133 pub fn process_packet(
139 &mut self,
140 direction: FlowDirection,
141 tcp: &TcpLayer,
142 buf: &[u8],
143 config: &FlowConfig,
144 ) -> Result<(), FlowError> {
145 let flags = tcp
146 .flags(buf)
147 .map_err(|e| FlowError::PacketError(e.into()))?;
148 let seq = tcp.seq(buf).map_err(|e| FlowError::PacketError(e.into()))?;
149 let ack = tcp.ack(buf).map_err(|e| FlowError::PacketError(e.into()))?;
150 let window = tcp
151 .window(buf)
152 .map_err(|e| FlowError::PacketError(e.into()))?;
153
154 let data_offset = tcp
156 .data_offset(buf)
157 .map_err(|e| FlowError::PacketError(e.into()))?;
158 let header_bytes = (data_offset as usize) * 4;
159 let tcp_start = tcp.index.start;
160 let payload_start = tcp_start + header_bytes;
164 let payload = if payload_start < buf.len() {
165 &buf[payload_start..buf.len()]
166 } else {
167 &[]
168 };
169
170 let (sender, _receiver, reassembler) = match direction {
172 FlowDirection::Forward => (
173 &mut self.forward_endpoint,
174 &mut self.reverse_endpoint,
175 &mut self.reassembler_fwd,
176 ),
177 FlowDirection::Reverse => (
178 &mut self.reverse_endpoint,
179 &mut self.forward_endpoint,
180 &mut self.reassembler_rev,
181 ),
182 };
183
184 sender.window_size = window;
186
187 if flags.rst {
189 self.conn_state = TcpConnectionState::Closed;
190 return Ok(());
191 }
192
193 match self.conn_state {
194 TcpConnectionState::Listen => {
195 if flags.syn && !flags.ack {
196 sender.initial_seq = Some(seq);
198 sender.next_expected_seq = seq.wrapping_add(1); self.conn_state = TcpConnectionState::SynSent;
200 }
201 },
202 TcpConnectionState::SynSent => {
203 if flags.syn && flags.ack {
204 sender.initial_seq = Some(seq);
206 sender.next_expected_seq = seq.wrapping_add(1);
207 sender.last_ack = ack;
208 self.conn_state = TcpConnectionState::SynRcvd;
209 }
210 },
211 TcpConnectionState::SynRcvd => {
212 if flags.ack && !flags.syn {
213 sender.last_ack = ack;
215 self.conn_state = TcpConnectionState::Established;
216 if !self.reassembler_fwd.is_initialized()
218 && let Some(isn) = self.forward_endpoint.initial_seq
219 {
220 self.reassembler_fwd.initialize(isn.wrapping_add(1));
221 }
222 if !self.reassembler_rev.is_initialized()
223 && let Some(isn) = self.reverse_endpoint.initial_seq
224 {
225 self.reassembler_rev.initialize(isn.wrapping_add(1));
226 }
227 }
228 },
229 TcpConnectionState::Established => {
230 sender.last_ack = ack;
231
232 if !payload.is_empty() {
234 if let Err(e) = reassembler.process_segment(seq, payload, config) {
235 match direction {
237 FlowDirection::Forward => self.dropped_segments_fwd += 1,
238 FlowDirection::Reverse => self.dropped_segments_rev += 1,
239 }
240 let total = self.dropped_segments_fwd + self.dropped_segments_rev;
242 if total == 1 || total.is_power_of_two() {
243 eprintln!(
244 "[!] stackforge: TCP reassembly dropped segment ({e}), {total} total drops for this flow"
245 );
246 }
247 }
248 }
249
250 if flags.fin {
251 sender.next_expected_seq =
252 seq.wrapping_add(payload.len() as u32).wrapping_add(1); match direction {
254 FlowDirection::Forward => {
255 self.conn_state = TcpConnectionState::FinWait1;
256 },
257 FlowDirection::Reverse => {
258 self.conn_state = TcpConnectionState::CloseWait;
259 },
260 }
261 } else {
262 sender.next_expected_seq = seq.wrapping_add(payload.len() as u32);
263 }
264 },
265 TcpConnectionState::FinWait1 => {
266 if flags.fin && flags.ack {
267 self.conn_state = TcpConnectionState::TimeWait;
269 } else if flags.ack {
270 self.conn_state = TcpConnectionState::FinWait2;
271 } else if flags.fin {
272 self.conn_state = TcpConnectionState::Closing;
273 }
274 },
275 TcpConnectionState::FinWait2 => {
276 if flags.fin {
277 self.conn_state = TcpConnectionState::TimeWait;
278 }
279 },
280 TcpConnectionState::CloseWait => {
281 if flags.fin {
282 self.conn_state = TcpConnectionState::LastAck;
283 }
284 },
285 TcpConnectionState::Closing => {
286 if flags.ack {
287 self.conn_state = TcpConnectionState::TimeWait;
288 }
289 },
290 TcpConnectionState::LastAck => {
291 if flags.ack {
292 self.conn_state = TcpConnectionState::Closed;
293 }
294 },
295 TcpConnectionState::TimeWait | TcpConnectionState::Closed => {
296 },
298 }
299
300 Ok(())
301 }
302}
303
304impl Default for TcpConversationState {
305 fn default() -> Self {
306 Self::new()
307 }
308}
309
#[cfg(test)]
mod tests {
    use super::*;
    use crate::layer::stack::{LayerStack, LayerStackEntry};
    use crate::{EthernetBuilder, Ipv4Builder, TcpBuilder};

    /// Builds an Ethernet/IPv4/TCP packet from 10.0.0.1 to 10.0.0.2 with a
    /// window of 65535.
    ///
    /// `flags` is a string of single-letter TCP flags (`S`, `A`, `F`, `R`, `P`);
    /// any other character is ignored. A non-empty `payload` is appended as a
    /// raw layer after the TCP header.
    fn make_tcp_packet(
        src_port: u16,
        dst_port: u16,
        seq: u32,
        ack_num: u32,
        flags: &str,
        payload: &[u8],
    ) -> crate::Packet {
        let base = TcpBuilder::new()
            .src_port(src_port)
            .dst_port(dst_port)
            .seq(seq)
            .ack_num(ack_num)
            .window(65535);
        let tcp = flags.chars().fold(base, |b, flag| match flag {
            'S' => b.syn(),
            'A' => b.ack(),
            'F' => b.fin(),
            'R' => b.rst(),
            'P' => b.psh(),
            _ => b,
        });

        let ethernet = EthernetBuilder::new()
            .dst(crate::MacAddress::BROADCAST)
            .src(crate::MacAddress::new([0, 1, 2, 3, 4, 5]));
        let ipv4 = Ipv4Builder::new()
            .src(std::net::Ipv4Addr::new(10, 0, 0, 1))
            .dst(std::net::Ipv4Addr::new(10, 0, 0, 2));

        let mut stack = LayerStack::new()
            .push(LayerStackEntry::Ethernet(ethernet))
            .push(LayerStackEntry::Ipv4(ipv4))
            .push(LayerStackEntry::Tcp(tcp));
        if !payload.is_empty() {
            stack = stack.push(LayerStackEntry::Raw(payload.to_vec()));
        }
        stack.build_packet()
    }

    /// Runs `pkt` through `state` as a segment travelling in `direction`,
    /// panicking if the packet cannot be processed.
    fn feed(
        state: &mut TcpConversationState,
        direction: FlowDirection,
        pkt: &crate::Packet,
        config: &FlowConfig,
    ) {
        let tcp = pkt.tcp().unwrap();
        state
            .process_packet(direction, &tcp, pkt.as_bytes(), config)
            .unwrap();
    }

    #[test]
    fn test_three_way_handshake() {
        let config = FlowConfig::default();
        let mut state = TcpConversationState::new();

        // (direction, segment, state expected after the segment)
        let steps = [
            (
                FlowDirection::Forward,
                make_tcp_packet(12345, 80, 1000, 0, "S", &[]),
                TcpConnectionState::SynSent,
            ),
            (
                FlowDirection::Reverse,
                make_tcp_packet(80, 12345, 2000, 1001, "SA", &[]),
                TcpConnectionState::SynRcvd,
            ),
            (
                FlowDirection::Forward,
                make_tcp_packet(12345, 80, 1001, 2001, "A", &[]),
                TcpConnectionState::Established,
            ),
        ];
        for (direction, pkt, expected) in steps {
            feed(&mut state, direction, &pkt, &config);
            assert_eq!(state.conn_state, expected);
        }
    }

    #[test]
    fn test_rst_closes_connection() {
        let config = FlowConfig::default();
        let mut state = TcpConversationState::new();
        state.conn_state = TcpConnectionState::Established;

        let rst = make_tcp_packet(12345, 80, 1000, 0, "R", &[]);
        feed(&mut state, FlowDirection::Forward, &rst, &config);
        assert_eq!(state.conn_state, TcpConnectionState::Closed);
    }

    #[test]
    fn test_fin_handshake() {
        let config = FlowConfig::default();
        let mut state = TcpConversationState::new();
        state.conn_state = TcpConnectionState::Established;

        // (direction, segment, state expected after the segment)
        let steps = [
            (
                FlowDirection::Forward,
                make_tcp_packet(12345, 80, 1000, 2000, "FA", &[]),
                TcpConnectionState::FinWait1,
            ),
            (
                FlowDirection::Reverse,
                make_tcp_packet(80, 12345, 2000, 1001, "A", &[]),
                TcpConnectionState::FinWait2,
            ),
            (
                FlowDirection::Reverse,
                make_tcp_packet(80, 12345, 2000, 1001, "FA", &[]),
                TcpConnectionState::TimeWait,
            ),
        ];
        for (direction, pkt, expected) in steps {
            feed(&mut state, direction, &pkt, &config);
            assert_eq!(state.conn_state, expected);
        }
    }

    #[test]
    fn test_data_transfer_and_reassembly() {
        let config = FlowConfig::default();
        let mut state = TcpConversationState::new();
        state.conn_state = TcpConnectionState::Established;

        // Pretend the handshake happened with client ISN 999.
        state.forward_endpoint.initial_seq = Some(999);
        state.reassembler_fwd.initialize(1000);

        let request = make_tcp_packet(12345, 80, 1000, 2000, "A", b"GET /");
        feed(&mut state, FlowDirection::Forward, &request, &config);

        assert_eq!(state.reassembler_fwd.reassembled_data(), b"GET /");
    }

    #[test]
    fn test_state_display() {
        assert_eq!(TcpConnectionState::Established.name(), "ESTABLISHED");
        assert_eq!(TcpConnectionState::SynSent.name(), "SYN_SENT");
        assert!(TcpConnectionState::Closed.is_closed());
        assert!(TcpConnectionState::TimeWait.is_closed());
        assert!(TcpConnectionState::SynSent.is_half_open());
        assert!(!TcpConnectionState::Established.is_half_open());
    }
}