Skip to main content

stackforge_core/flow/
tcp_state.rs

1use crate::TcpLayer;
2
3use super::config::FlowConfig;
4use super::error::FlowError;
5use super::key::FlowDirection;
6use super::tcp_reassembly::TcpReassembler;
7
/// TCP connection states per RFC 793.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TcpConnectionState {
    /// Waiting for a connection request (RFC 793 `LISTEN`).
    Listen,
    /// A SYN has been sent; waiting for the matching SYN-ACK.
    SynSent,
    /// SYN-ACK seen; waiting for the final ACK of the handshake.
    SynRcvd,
    /// Handshake complete; data may flow in both directions.
    Established,
    /// Local side sent FIN; waiting for its ACK or the peer's FIN.
    FinWait1,
    /// Local FIN acknowledged; waiting for the peer's FIN.
    FinWait2,
    /// Peer sent FIN; waiting for the local side to close.
    CloseWait,
    /// Both sides sent FIN; waiting for the ACK of the local FIN.
    Closing,
    /// Peer closed first and the local FIN was sent; waiting for its ACK.
    LastAck,
    /// Waiting for enough time to pass to be sure the peer received the
    /// ACK of its FIN.
    TimeWait,
    /// No connection state at all.
    Closed,
}

impl TcpConnectionState {
    /// Human-readable state name in RFC 793 style (e.g. `"SYN_SENT"`).
    #[must_use]
    pub const fn name(&self) -> &'static str {
        match self {
            Self::Listen => "LISTEN",
            Self::SynSent => "SYN_SENT",
            Self::SynRcvd => "SYN_RCVD",
            Self::Established => "ESTABLISHED",
            Self::FinWait1 => "FIN_WAIT_1",
            Self::FinWait2 => "FIN_WAIT_2",
            Self::CloseWait => "CLOSE_WAIT",
            Self::Closing => "CLOSING",
            Self::LastAck => "LAST_ACK",
            Self::TimeWait => "TIME_WAIT",
            Self::Closed => "CLOSED",
        }
    }

    /// Whether this is a terminal/closed state (`Closed` or `TimeWait`).
    #[must_use]
    pub const fn is_closed(&self) -> bool {
        matches!(self, Self::Closed | Self::TimeWait)
    }

    /// Whether this is a half-open state, i.e. the handshake has not yet
    /// completed (`Listen`, `SynSent` or `SynRcvd`).
    #[must_use]
    pub const fn is_half_open(&self) -> bool {
        matches!(self, Self::Listen | Self::SynSent | Self::SynRcvd)
    }
}

impl std::fmt::Display for TcpConnectionState {
    /// Writes [`TcpConnectionState::name`].
    ///
    /// `Formatter::pad` is used (not `write_str`) so width, fill and
    /// alignment flags such as `{:<12}` are honoured when printing tables.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.pad(self.name())
    }
}
61
/// Per-endpoint sequence tracking state.
///
/// All counters start at zero and no initial sequence number is recorded;
/// [`TcpEndpointState::new`] and [`Default::default`] produce the same value.
#[derive(Debug, Clone, Default)]
pub struct TcpEndpointState {
    /// Next expected sequence number from this endpoint.
    pub next_expected_seq: u32,
    /// Last acknowledged sequence number from this endpoint.
    pub last_ack: u32,
    /// Advertised receive window size.
    pub window_size: u16,
    /// Initial sequence number (set on SYN).
    pub initial_seq: Option<u32>,
}

impl TcpEndpointState {
    /// Creates an endpoint with every counter zeroed and `initial_seq` unset.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }
}
92
93/// Complete TCP conversation state including connection tracking,
94/// per-endpoint sequence state, and stream reassembly.
95#[derive(Debug)]
96pub struct TcpConversationState {
97    /// Current connection state (RFC 793 state machine).
98    pub conn_state: TcpConnectionState,
99    /// Sequence tracking for the forward direction (`addr_a` → `addr_b`).
100    pub forward_endpoint: TcpEndpointState,
101    /// Sequence tracking for the reverse direction (`addr_b` → `addr_a`).
102    pub reverse_endpoint: TcpEndpointState,
103    /// Stream reassembly for forward direction.
104    pub reassembler_fwd: TcpReassembler,
105    /// Stream reassembly for reverse direction.
106    pub reassembler_rev: TcpReassembler,
107    /// Number of forward segments dropped due to buffer/fragment limits.
108    pub dropped_segments_fwd: u64,
109    /// Number of reverse segments dropped due to buffer/fragment limits.
110    pub dropped_segments_rev: u64,
111}
112
113impl TcpConversationState {
114    #[must_use]
115    pub fn new() -> Self {
116        Self {
117            conn_state: TcpConnectionState::Listen,
118            forward_endpoint: TcpEndpointState::new(),
119            reverse_endpoint: TcpEndpointState::new(),
120            reassembler_fwd: TcpReassembler::new(),
121            reassembler_rev: TcpReassembler::new(),
122            dropped_segments_fwd: 0,
123            dropped_segments_rev: 0,
124        }
125    }
126
127    /// Total dropped segments across both directions.
128    #[must_use]
129    pub fn total_dropped_segments(&self) -> u64 {
130        self.dropped_segments_fwd + self.dropped_segments_rev
131    }
132
133    /// Process a TCP packet, updating connection state and reassembly buffers.
134    ///
135    /// `direction` indicates whether this packet is Forward (`addr_a` → `addr_b`)
136    /// or Reverse (`addr_b` → `addr_a`) relative to the canonical key.
137    /// `tcp` is the TCP layer view, `buf` is the full packet buffer.
138    pub fn process_packet(
139        &mut self,
140        direction: FlowDirection,
141        tcp: &TcpLayer,
142        buf: &[u8],
143        config: &FlowConfig,
144    ) -> Result<(), FlowError> {
145        let flags = tcp
146            .flags(buf)
147            .map_err(|e| FlowError::PacketError(e.into()))?;
148        let seq = tcp.seq(buf).map_err(|e| FlowError::PacketError(e.into()))?;
149        let ack = tcp.ack(buf).map_err(|e| FlowError::PacketError(e.into()))?;
150        let window = tcp
151            .window(buf)
152            .map_err(|e| FlowError::PacketError(e.into()))?;
153
154        // Determine payload boundaries
155        let data_offset = tcp
156            .data_offset(buf)
157            .map_err(|e| FlowError::PacketError(e.into()))?;
158        let header_bytes = (data_offset as usize) * 4;
159        let tcp_start = tcp.index.start;
160        // TCP payload starts after the TCP header. Since the TCP layer's
161        // index.end marks the header boundary, the payload is everything
162        // from header end to the end of the packet buffer.
163        let payload_start = tcp_start + header_bytes;
164        let payload = if payload_start < buf.len() {
165            &buf[payload_start..buf.len()]
166        } else {
167            &[]
168        };
169
170        // Get mutable refs to endpoint and reassembler for this direction
171        let (sender, _receiver, reassembler) = match direction {
172            FlowDirection::Forward => (
173                &mut self.forward_endpoint,
174                &mut self.reverse_endpoint,
175                &mut self.reassembler_fwd,
176            ),
177            FlowDirection::Reverse => (
178                &mut self.reverse_endpoint,
179                &mut self.forward_endpoint,
180                &mut self.reassembler_rev,
181            ),
182        };
183
184        // Update endpoint state
185        sender.window_size = window;
186
187        // State machine transitions
188        if flags.rst {
189            self.conn_state = TcpConnectionState::Closed;
190            return Ok(());
191        }
192
193        match self.conn_state {
194            TcpConnectionState::Listen => {
195                if flags.syn && !flags.ack {
196                    // SYN from initiator
197                    sender.initial_seq = Some(seq);
198                    sender.next_expected_seq = seq.wrapping_add(1); // SYN consumes 1 seq
199                    self.conn_state = TcpConnectionState::SynSent;
200                }
201            },
202            TcpConnectionState::SynSent => {
203                if flags.syn && flags.ack {
204                    // SYN-ACK from responder
205                    sender.initial_seq = Some(seq);
206                    sender.next_expected_seq = seq.wrapping_add(1);
207                    sender.last_ack = ack;
208                    self.conn_state = TcpConnectionState::SynRcvd;
209                }
210            },
211            TcpConnectionState::SynRcvd => {
212                if flags.ack && !flags.syn {
213                    // Final ACK of 3-way handshake
214                    sender.last_ack = ack;
215                    self.conn_state = TcpConnectionState::Established;
216                    // Initialize reassemblers with ISN+1 (after SYN)
217                    if !self.reassembler_fwd.is_initialized()
218                        && let Some(isn) = self.forward_endpoint.initial_seq
219                    {
220                        self.reassembler_fwd.initialize(isn.wrapping_add(1));
221                    }
222                    if !self.reassembler_rev.is_initialized()
223                        && let Some(isn) = self.reverse_endpoint.initial_seq
224                    {
225                        self.reassembler_rev.initialize(isn.wrapping_add(1));
226                    }
227                }
228            },
229            TcpConnectionState::Established => {
230                sender.last_ack = ack;
231
232                // Process payload through reassembler
233                if !payload.is_empty() {
234                    if let Err(e) = reassembler.process_segment(seq, payload, config) {
235                        // Buffer full or too many fragments — track the drop
236                        match direction {
237                            FlowDirection::Forward => self.dropped_segments_fwd += 1,
238                            FlowDirection::Reverse => self.dropped_segments_rev += 1,
239                        }
240                        // Log once at thresholds to avoid flooding stderr
241                        let total = self.dropped_segments_fwd + self.dropped_segments_rev;
242                        if total == 1 || total.is_power_of_two() {
243                            eprintln!(
244                                "[!] stackforge: TCP reassembly dropped segment ({e}), {total} total drops for this flow"
245                            );
246                        }
247                    }
248                }
249
250                if flags.fin {
251                    sender.next_expected_seq =
252                        seq.wrapping_add(payload.len() as u32).wrapping_add(1); // FIN consumes 1 seq
253                    match direction {
254                        FlowDirection::Forward => {
255                            self.conn_state = TcpConnectionState::FinWait1;
256                        },
257                        FlowDirection::Reverse => {
258                            self.conn_state = TcpConnectionState::CloseWait;
259                        },
260                    }
261                } else {
262                    sender.next_expected_seq = seq.wrapping_add(payload.len() as u32);
263                }
264            },
265            TcpConnectionState::FinWait1 => {
266                if flags.fin && flags.ack {
267                    // Simultaneous close
268                    self.conn_state = TcpConnectionState::TimeWait;
269                } else if flags.ack {
270                    self.conn_state = TcpConnectionState::FinWait2;
271                } else if flags.fin {
272                    self.conn_state = TcpConnectionState::Closing;
273                }
274            },
275            TcpConnectionState::FinWait2 => {
276                if flags.fin {
277                    self.conn_state = TcpConnectionState::TimeWait;
278                }
279            },
280            TcpConnectionState::CloseWait => {
281                if flags.fin {
282                    self.conn_state = TcpConnectionState::LastAck;
283                }
284            },
285            TcpConnectionState::Closing => {
286                if flags.ack {
287                    self.conn_state = TcpConnectionState::TimeWait;
288                }
289            },
290            TcpConnectionState::LastAck => {
291                if flags.ack {
292                    self.conn_state = TcpConnectionState::Closed;
293                }
294            },
295            TcpConnectionState::TimeWait | TcpConnectionState::Closed => {
296                // Terminal states — no further transitions
297            },
298        }
299
300        Ok(())
301    }
302}
303
304impl Default for TcpConversationState {
305    fn default() -> Self {
306        Self::new()
307    }
308}
309
#[cfg(test)]
mod tests {
    use super::*;
    use crate::layer::stack::{LayerStack, LayerStackEntry};
    use crate::{EthernetBuilder, Ipv4Builder, TcpBuilder};

    /// Build an Ethernet/IPv4/TCP packet from 10.0.0.1 to 10.0.0.2.
    ///
    /// `flags` is a string of single-letter TCP flags (`S`, `A`, `F`, `R`,
    /// `P`); other characters are ignored. A non-empty `payload` is appended
    /// as a raw layer after the TCP header.
    fn make_tcp_packet(
        src_port: u16,
        dst_port: u16,
        seq: u32,
        ack_num: u32,
        flags: &str,
        payload: &[u8],
    ) -> crate::Packet {
        let base = TcpBuilder::new()
            .src_port(src_port)
            .dst_port(dst_port)
            .seq(seq)
            .ack_num(ack_num)
            .window(65535);
        let tcp = flags.chars().fold(base, |b, flag| match flag {
            'S' => b.syn(),
            'A' => b.ack(),
            'F' => b.fin(),
            'R' => b.rst(),
            'P' => b.psh(),
            _ => b,
        });

        let ethernet = EthernetBuilder::new()
            .dst(crate::MacAddress::BROADCAST)
            .src(crate::MacAddress::new([0, 1, 2, 3, 4, 5]));
        let ipv4 = Ipv4Builder::new()
            .src(std::net::Ipv4Addr::new(10, 0, 0, 1))
            .dst(std::net::Ipv4Addr::new(10, 0, 0, 2));

        let mut stack = LayerStack::new()
            .push(LayerStackEntry::Ethernet(ethernet))
            .push(LayerStackEntry::Ipv4(ipv4))
            .push(LayerStackEntry::Tcp(tcp));
        if !payload.is_empty() {
            stack = stack.push(LayerStackEntry::Raw(payload.to_vec()));
        }
        stack.build_packet()
    }

    /// Split a built packet into its TCP layer view and raw bytes.
    fn get_tcp_and_buf(pkt: &crate::Packet) -> (TcpLayer, &[u8]) {
        (pkt.tcp().unwrap(), pkt.as_bytes())
    }

    /// Run one packet through `state`, panicking if processing fails.
    fn feed(
        state: &mut TcpConversationState,
        direction: FlowDirection,
        pkt: &crate::Packet,
        config: &FlowConfig,
    ) {
        let (tcp, buf) = get_tcp_and_buf(pkt);
        state.process_packet(direction, &tcp, buf, config).unwrap();
    }

    #[test]
    fn test_three_way_handshake() {
        let config = FlowConfig::default();
        let mut state = TcpConversationState::new();

        let steps = [
            // SYN (client → server, forward)
            (
                FlowDirection::Forward,
                make_tcp_packet(12345, 80, 1000, 0, "S", &[]),
                TcpConnectionState::SynSent,
            ),
            // SYN-ACK (server → client, reverse)
            (
                FlowDirection::Reverse,
                make_tcp_packet(80, 12345, 2000, 1001, "SA", &[]),
                TcpConnectionState::SynRcvd,
            ),
            // ACK (client → server, forward)
            (
                FlowDirection::Forward,
                make_tcp_packet(12345, 80, 1001, 2001, "A", &[]),
                TcpConnectionState::Established,
            ),
        ];

        for (direction, pkt, expected) in steps {
            feed(&mut state, direction, &pkt, &config);
            assert_eq!(state.conn_state, expected);
        }
    }

    #[test]
    fn test_rst_closes_connection() {
        let config = FlowConfig::default();
        let mut state = TcpConversationState::new();
        state.conn_state = TcpConnectionState::Established;

        let rst = make_tcp_packet(12345, 80, 1000, 0, "R", &[]);
        feed(&mut state, FlowDirection::Forward, &rst, &config);

        assert_eq!(state.conn_state, TcpConnectionState::Closed);
    }

    #[test]
    fn test_fin_handshake() {
        let config = FlowConfig::default();
        let mut state = TcpConversationState::new();
        state.conn_state = TcpConnectionState::Established;

        let steps = [
            // FIN from forward direction
            (
                FlowDirection::Forward,
                make_tcp_packet(12345, 80, 1000, 2000, "FA", &[]),
                TcpConnectionState::FinWait1,
            ),
            // ACK of FIN from reverse
            (
                FlowDirection::Reverse,
                make_tcp_packet(80, 12345, 2000, 1001, "A", &[]),
                TcpConnectionState::FinWait2,
            ),
            // FIN from reverse
            (
                FlowDirection::Reverse,
                make_tcp_packet(80, 12345, 2000, 1001, "FA", &[]),
                TcpConnectionState::TimeWait,
            ),
        ];

        for (direction, pkt, expected) in steps {
            feed(&mut state, direction, &pkt, &config);
            assert_eq!(state.conn_state, expected);
        }
    }

    #[test]
    fn test_data_transfer_and_reassembly() {
        let config = FlowConfig::default();
        let mut state = TcpConversationState::new();
        state.conn_state = TcpConnectionState::Established;

        // Pretend the forward SYN used ISN 999, so data starts at 1000.
        state.forward_endpoint.initial_seq = Some(999);
        state.reassembler_fwd.initialize(1000);

        let data = make_tcp_packet(12345, 80, 1000, 2000, "A", b"GET /");
        feed(&mut state, FlowDirection::Forward, &data, &config);

        assert_eq!(state.reassembler_fwd.reassembled_data(), b"GET /");
    }

    #[test]
    fn test_state_display() {
        assert_eq!(TcpConnectionState::Established.name(), "ESTABLISHED");
        assert_eq!(TcpConnectionState::SynSent.name(), "SYN_SENT");

        assert!(TcpConnectionState::Closed.is_closed());
        assert!(TcpConnectionState::TimeWait.is_closed());

        assert!(TcpConnectionState::SynSent.is_half_open());
        assert!(!TcpConnectionState::Established.is_half_open());
    }
}