Skip to main content

stackforge_core/flow/
tcp_state.rs

1use crate::TcpLayer;
2
3use super::config::FlowConfig;
4use super::error::FlowError;
5use super::key::FlowDirection;
6use super::tcp_reassembly::TcpReassembler;
7
/// TCP connection states per RFC 793.
///
/// A `TcpConversationState` starts in `Listen` and moves through the
/// handshake states as SYN, SYN-ACK and ACK segments are observed. After
/// `Established`, the closing states are named from the point of view of the
/// forward (addr_a) endpoint: a FIN in the forward direction leads to
/// `FinWait1`, a FIN in the reverse direction leads to `CloseWait`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TcpConnectionState {
    /// Initial state: no opening SYN (SYN without ACK) has been observed yet.
    Listen,
    /// An opening SYN has been seen; waiting for the SYN-ACK.
    SynSent,
    /// The SYN-ACK has been seen; waiting for the final handshake ACK.
    SynRcvd,
    /// Handshake complete; segment payload is fed to the reassemblers.
    Established,
    /// The forward endpoint sent FIN first; waiting on the reverse endpoint.
    FinWait1,
    /// The forward endpoint's FIN has been ACKed; waiting for the reverse FIN.
    FinWait2,
    /// The reverse endpoint sent FIN first; waiting for the forward FIN.
    CloseWait,
    /// Both endpoints sent FIN (forward first); waiting for the final ACK.
    Closing,
    /// Both endpoints sent FIN (reverse first); waiting for the final ACK.
    LastAck,
    /// Both directions have closed; only an RST moves the tracker on (to `Closed`).
    TimeWait,
    /// Terminal: reached via RST, or via the final ACK in `LastAck`.
    Closed,
}
23
24impl TcpConnectionState {
25    /// Human-readable state name.
26    pub fn name(&self) -> &'static str {
27        match self {
28            Self::Listen => "LISTEN",
29            Self::SynSent => "SYN_SENT",
30            Self::SynRcvd => "SYN_RCVD",
31            Self::Established => "ESTABLISHED",
32            Self::FinWait1 => "FIN_WAIT_1",
33            Self::FinWait2 => "FIN_WAIT_2",
34            Self::CloseWait => "CLOSE_WAIT",
35            Self::Closing => "CLOSING",
36            Self::LastAck => "LAST_ACK",
37            Self::TimeWait => "TIME_WAIT",
38            Self::Closed => "CLOSED",
39        }
40    }
41
42    /// Whether this is a terminal/closed state.
43    pub fn is_closed(&self) -> bool {
44        matches!(self, Self::Closed | Self::TimeWait)
45    }
46
47    /// Whether this is a half-open state (not yet established).
48    pub fn is_half_open(&self) -> bool {
49        matches!(self, Self::Listen | Self::SynSent | Self::SynRcvd)
50    }
51}
52
53impl std::fmt::Display for TcpConnectionState {
54    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
55        f.write_str(self.name())
56    }
57}
58
/// Per-endpoint sequence tracking state.
///
/// One instance exists per direction of a conversation; it describes the
/// segments *sent by* that endpoint.
#[derive(Debug, Clone)]
pub struct TcpEndpointState {
    /// Sequence number expected to follow this endpoint's most recently
    /// tracked segment: its seq plus payload length, plus one more when the
    /// segment carried SYN or FIN (each consumes one sequence number).
    pub next_expected_seq: u32,
    /// ACK number carried by this endpoint's most recently tracked segment,
    /// i.e. what this endpoint has acknowledged of its peer's stream.
    pub last_ack: u32,
    /// Raw 16-bit window field from this endpoint's latest segment (no
    /// window-scale option is applied in this file).
    pub window_size: u16,
    /// Initial sequence number, recorded from this endpoint's SYN or SYN-ACK.
    pub initial_seq: Option<u32>,
}
71
72impl TcpEndpointState {
73    pub fn new() -> Self {
74        Self {
75            next_expected_seq: 0,
76            last_ack: 0,
77            window_size: 0,
78            initial_seq: None,
79        }
80    }
81}
82
83impl Default for TcpEndpointState {
84    fn default() -> Self {
85        Self::new()
86    }
87}
88
/// Complete TCP conversation state including connection tracking,
/// per-endpoint sequence state, and stream reassembly.
///
/// Updated one segment at a time by `process_packet`. "Forward" and
/// "reverse" are relative to the flow's canonical key: forward is
/// addr_a → addr_b, reverse is addr_b → addr_a.
#[derive(Debug)]
pub struct TcpConversationState {
    /// Current connection state (RFC 793 state machine). After the handshake,
    /// closing states are named from the forward (addr_a) endpoint's view.
    pub conn_state: TcpConnectionState,
    /// Sequence tracking for the forward direction (addr_a → addr_b).
    pub forward_endpoint: TcpEndpointState,
    /// Sequence tracking for the reverse direction (addr_b → addr_a).
    pub reverse_endpoint: TcpEndpointState,
    /// Stream reassembly for forward-direction payload; when the handshake
    /// completes it is initialized (if not already) at the forward ISN + 1.
    pub reassembler_fwd: TcpReassembler,
    /// Stream reassembly for reverse-direction payload; when the handshake
    /// completes it is initialized (if not already) at the reverse ISN + 1.
    pub reassembler_rev: TcpReassembler,
}
104
105impl TcpConversationState {
106    pub fn new() -> Self {
107        Self {
108            conn_state: TcpConnectionState::Listen,
109            forward_endpoint: TcpEndpointState::new(),
110            reverse_endpoint: TcpEndpointState::new(),
111            reassembler_fwd: TcpReassembler::new(),
112            reassembler_rev: TcpReassembler::new(),
113        }
114    }
115
116    /// Process a TCP packet, updating connection state and reassembly buffers.
117    ///
118    /// `direction` indicates whether this packet is Forward (addr_a → addr_b)
119    /// or Reverse (addr_b → addr_a) relative to the canonical key.
120    /// `tcp` is the TCP layer view, `buf` is the full packet buffer.
121    pub fn process_packet(
122        &mut self,
123        direction: FlowDirection,
124        tcp: &TcpLayer,
125        buf: &[u8],
126        config: &FlowConfig,
127    ) -> Result<(), FlowError> {
128        let flags = tcp
129            .flags(buf)
130            .map_err(|e| FlowError::PacketError(e.into()))?;
131        let seq = tcp.seq(buf).map_err(|e| FlowError::PacketError(e.into()))?;
132        let ack = tcp.ack(buf).map_err(|e| FlowError::PacketError(e.into()))?;
133        let window = tcp
134            .window(buf)
135            .map_err(|e| FlowError::PacketError(e.into()))?;
136
137        // Determine payload boundaries
138        let data_offset = tcp
139            .data_offset(buf)
140            .map_err(|e| FlowError::PacketError(e.into()))?;
141        let header_bytes = (data_offset as usize) * 4;
142        let tcp_start = tcp.index.start;
143        // TCP payload starts after the TCP header. Since the TCP layer's
144        // index.end marks the header boundary, the payload is everything
145        // from header end to the end of the packet buffer.
146        let payload_start = tcp_start + header_bytes;
147        let payload = if payload_start < buf.len() {
148            &buf[payload_start..buf.len()]
149        } else {
150            &[]
151        };
152
153        // Get mutable refs to endpoint and reassembler for this direction
154        let (sender, _receiver, reassembler) = match direction {
155            FlowDirection::Forward => (
156                &mut self.forward_endpoint,
157                &mut self.reverse_endpoint,
158                &mut self.reassembler_fwd,
159            ),
160            FlowDirection::Reverse => (
161                &mut self.reverse_endpoint,
162                &mut self.forward_endpoint,
163                &mut self.reassembler_rev,
164            ),
165        };
166
167        // Update endpoint state
168        sender.window_size = window;
169
170        // State machine transitions
171        if flags.rst {
172            self.conn_state = TcpConnectionState::Closed;
173            return Ok(());
174        }
175
176        match self.conn_state {
177            TcpConnectionState::Listen => {
178                if flags.syn && !flags.ack {
179                    // SYN from initiator
180                    sender.initial_seq = Some(seq);
181                    sender.next_expected_seq = seq.wrapping_add(1); // SYN consumes 1 seq
182                    self.conn_state = TcpConnectionState::SynSent;
183                }
184            },
185            TcpConnectionState::SynSent => {
186                if flags.syn && flags.ack {
187                    // SYN-ACK from responder
188                    sender.initial_seq = Some(seq);
189                    sender.next_expected_seq = seq.wrapping_add(1);
190                    sender.last_ack = ack;
191                    self.conn_state = TcpConnectionState::SynRcvd;
192                }
193            },
194            TcpConnectionState::SynRcvd => {
195                if flags.ack && !flags.syn {
196                    // Final ACK of 3-way handshake
197                    sender.last_ack = ack;
198                    self.conn_state = TcpConnectionState::Established;
199                    // Initialize reassemblers with ISN+1 (after SYN)
200                    if !self.reassembler_fwd.is_initialized() {
201                        if let Some(isn) = self.forward_endpoint.initial_seq {
202                            self.reassembler_fwd.initialize(isn.wrapping_add(1));
203                        }
204                    }
205                    if !self.reassembler_rev.is_initialized() {
206                        if let Some(isn) = self.reverse_endpoint.initial_seq {
207                            self.reassembler_rev.initialize(isn.wrapping_add(1));
208                        }
209                    }
210                }
211            },
212            TcpConnectionState::Established => {
213                sender.last_ack = ack;
214
215                // Process payload through reassembler
216                if !payload.is_empty() {
217                    // Ignore reassembly errors (buffer full, etc.) — they don't
218                    // affect connection state tracking
219                    let _ = reassembler.process_segment(seq, payload, config);
220                }
221
222                if flags.fin {
223                    sender.next_expected_seq =
224                        seq.wrapping_add(payload.len() as u32).wrapping_add(1); // FIN consumes 1 seq
225                    match direction {
226                        FlowDirection::Forward => {
227                            self.conn_state = TcpConnectionState::FinWait1;
228                        },
229                        FlowDirection::Reverse => {
230                            self.conn_state = TcpConnectionState::CloseWait;
231                        },
232                    }
233                } else {
234                    sender.next_expected_seq = seq.wrapping_add(payload.len() as u32);
235                }
236            },
237            TcpConnectionState::FinWait1 => {
238                if flags.fin && flags.ack {
239                    // Simultaneous close
240                    self.conn_state = TcpConnectionState::TimeWait;
241                } else if flags.ack {
242                    self.conn_state = TcpConnectionState::FinWait2;
243                } else if flags.fin {
244                    self.conn_state = TcpConnectionState::Closing;
245                }
246            },
247            TcpConnectionState::FinWait2 => {
248                if flags.fin {
249                    self.conn_state = TcpConnectionState::TimeWait;
250                }
251            },
252            TcpConnectionState::CloseWait => {
253                if flags.fin {
254                    self.conn_state = TcpConnectionState::LastAck;
255                }
256            },
257            TcpConnectionState::Closing => {
258                if flags.ack {
259                    self.conn_state = TcpConnectionState::TimeWait;
260                }
261            },
262            TcpConnectionState::LastAck => {
263                if flags.ack {
264                    self.conn_state = TcpConnectionState::Closed;
265                }
266            },
267            TcpConnectionState::TimeWait | TcpConnectionState::Closed => {
268                // Terminal states — no further transitions
269            },
270        }
271
272        Ok(())
273    }
274}
275
276impl Default for TcpConversationState {
277    fn default() -> Self {
278        Self::new()
279    }
280}
281
#[cfg(test)]
mod tests {
    use super::*;
    use crate::layer::stack::{LayerStack, LayerStackEntry};
    use crate::{EthernetBuilder, Ipv4Builder, TcpBuilder};

    /// Build an Ethernet/IPv4/TCP packet from 10.0.0.1 to 10.0.0.2 with a
    /// 65535 window and, when `payload` is non-empty, a raw payload layer.
    ///
    /// `flags` is a string of flag letters: `S` (SYN), `A` (ACK), `F` (FIN),
    /// `R` (RST) and `P` (PSH). Any other character panics, so a typo cannot
    /// silently build a packet with the wrong flags.
    fn make_tcp_packet(
        src_port: u16,
        dst_port: u16,
        seq: u32,
        ack_num: u32,
        flags: &str,
        payload: &[u8],
    ) -> crate::Packet {
        let mut builder = TcpBuilder::new()
            .src_port(src_port)
            .dst_port(dst_port)
            .seq(seq)
            .ack_num(ack_num)
            .window(65535);

        for c in flags.chars() {
            builder = match c {
                'S' => builder.syn(),
                'A' => builder.ack(),
                'F' => builder.fin(),
                'R' => builder.rst(),
                'P' => builder.psh(),
                other => panic!("unsupported TCP flag character: {:?}", other),
            };
        }

        let stack = LayerStack::new()
            .push(LayerStackEntry::Ethernet(
                EthernetBuilder::new()
                    .dst(crate::MacAddress::BROADCAST)
                    .src(crate::MacAddress::new([0, 1, 2, 3, 4, 5])),
            ))
            .push(LayerStackEntry::Ipv4(
                Ipv4Builder::new()
                    .src(std::net::Ipv4Addr::new(10, 0, 0, 1))
                    .dst(std::net::Ipv4Addr::new(10, 0, 0, 2)),
            ))
            .push(LayerStackEntry::Tcp(builder));

        let stack = if !payload.is_empty() {
            stack.push(LayerStackEntry::Raw(payload.to_vec()))
        } else {
            stack
        };

        stack.build_packet()
    }

    /// Parsed TCP layer view plus the raw bytes of `pkt`.
    fn get_tcp_and_buf(pkt: &crate::Packet) -> (TcpLayer, &[u8]) {
        let tcp = pkt.tcp().unwrap();
        let buf = pkt.as_bytes();
        (tcp, buf)
    }

    /// Run `pkt` through `state` as a segment travelling in `direction`.
    fn feed(
        state: &mut TcpConversationState,
        direction: FlowDirection,
        pkt: &crate::Packet,
        config: &FlowConfig,
    ) {
        let (tcp, buf) = get_tcp_and_buf(pkt);
        state
            .process_packet(direction, &tcp, buf, config)
            .expect("process_packet failed");
    }

    #[test]
    fn test_three_way_handshake() {
        let config = FlowConfig::default();
        let mut state = TcpConversationState::new();

        // SYN (client → server, forward)
        let pkt = make_tcp_packet(12345, 80, 1000, 0, "S", &[]);
        feed(&mut state, FlowDirection::Forward, &pkt, &config);
        assert_eq!(state.conn_state, TcpConnectionState::SynSent);

        // SYN-ACK (server → client, reverse)
        let pkt = make_tcp_packet(80, 12345, 2000, 1001, "SA", &[]);
        feed(&mut state, FlowDirection::Reverse, &pkt, &config);
        assert_eq!(state.conn_state, TcpConnectionState::SynRcvd);

        // ACK (client → server, forward)
        let pkt = make_tcp_packet(12345, 80, 1001, 2001, "A", &[]);
        feed(&mut state, FlowDirection::Forward, &pkt, &config);
        assert_eq!(state.conn_state, TcpConnectionState::Established);
    }

    #[test]
    fn test_rst_closes_connection() {
        let config = FlowConfig::default();
        let mut state = TcpConversationState::new();
        state.conn_state = TcpConnectionState::Established;

        let pkt = make_tcp_packet(12345, 80, 1000, 0, "R", &[]);
        feed(&mut state, FlowDirection::Forward, &pkt, &config);
        assert_eq!(state.conn_state, TcpConnectionState::Closed);
    }

    #[test]
    fn test_fin_handshake() {
        let config = FlowConfig::default();
        let mut state = TcpConversationState::new();
        state.conn_state = TcpConnectionState::Established;

        // FIN from forward direction
        let pkt = make_tcp_packet(12345, 80, 1000, 2000, "FA", &[]);
        feed(&mut state, FlowDirection::Forward, &pkt, &config);
        assert_eq!(state.conn_state, TcpConnectionState::FinWait1);

        // ACK of FIN from reverse
        let pkt = make_tcp_packet(80, 12345, 2000, 1001, "A", &[]);
        feed(&mut state, FlowDirection::Reverse, &pkt, &config);
        assert_eq!(state.conn_state, TcpConnectionState::FinWait2);

        // FIN from reverse
        let pkt = make_tcp_packet(80, 12345, 2000, 1001, "FA", &[]);
        feed(&mut state, FlowDirection::Reverse, &pkt, &config);
        assert_eq!(state.conn_state, TcpConnectionState::TimeWait);
    }

    #[test]
    fn test_data_transfer_and_reassembly() {
        let config = FlowConfig::default();
        let mut state = TcpConversationState::new();
        state.conn_state = TcpConnectionState::Established;

        // Initialize forward reassembler
        state.forward_endpoint.initial_seq = Some(999);
        state.reassembler_fwd.initialize(1000);

        // Data from forward direction
        let pkt = make_tcp_packet(12345, 80, 1000, 2000, "A", b"GET /");
        feed(&mut state, FlowDirection::Forward, &pkt, &config);

        assert_eq!(state.reassembler_fwd.reassembled_data(), b"GET /");
    }

    #[test]
    fn test_fin_segment_payload_is_reassembled() {
        let config = FlowConfig::default();
        let mut state = TcpConversationState::new();
        state.conn_state = TcpConnectionState::Established;
        state.forward_endpoint.initial_seq = Some(999);
        state.reassembler_fwd.initialize(1000);

        let pkt = make_tcp_packet(12345, 80, 1000, 2000, "FA", b"bye");
        feed(&mut state, FlowDirection::Forward, &pkt, &config);

        assert_eq!(state.conn_state, TcpConnectionState::FinWait1);
        assert_eq!(state.reassembler_fwd.reassembled_data(), b"bye");
        // The FIN consumes one sequence number after the 3 payload bytes.
        assert_eq!(state.forward_endpoint.next_expected_seq, 1004);
    }

    #[test]
    #[should_panic(expected = "unsupported TCP flag character")]
    fn test_unknown_flag_character_panics() {
        let _ = make_tcp_packet(12345, 80, 1000, 0, "X", &[]);
    }

    #[test]
    fn test_state_display() {
        assert_eq!(TcpConnectionState::Established.name(), "ESTABLISHED");
        assert_eq!(TcpConnectionState::SynSent.name(), "SYN_SENT");
        assert!(TcpConnectionState::Closed.is_closed());
        assert!(TcpConnectionState::TimeWait.is_closed());
        assert!(TcpConnectionState::SynSent.is_half_open());
        assert!(!TcpConnectionState::Established.is_half_open());
    }
}