stackforge_core/flow/key.rs

1use std::net::IpAddr;
2
3use crate::Packet;
4use crate::layer::LayerKind;
5use crate::layer::ipv6::Ipv6Layer;
6
7use super::error::FlowError;
8
/// IP-level transport protocol carried by a flow.
///
/// Well-known protocol numbers get dedicated variants; every other IP
/// protocol number is kept verbatim in [`TransportProtocol::Other`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TransportProtocol {
    /// Transmission Control Protocol (IP protocol 6).
    Tcp,
    /// User Datagram Protocol (IP protocol 17).
    Udp,
    /// Internet Control Message Protocol for IPv4 (IP protocol 1).
    Icmp,
    /// Internet Control Message Protocol for IPv6 (IP protocol 58).
    Icmpv6,
    /// Any other protocol, identified by its raw IP protocol number.
    Other(u8),
}

impl TransportProtocol {
    const PROTO_ICMP: u8 = 1;
    const PROTO_TCP: u8 = 6;
    const PROTO_UDP: u8 = 17;
    const PROTO_ICMPV6: u8 = 58;

    /// Map an IP protocol number (IPv4 `protocol` / IPv6 `next header`
    /// value) to a variant; unknown numbers become [`TransportProtocol::Other`].
    pub fn from_ip_protocol(proto: u8) -> Self {
        match proto {
            Self::PROTO_ICMP => Self::Icmp,
            Self::PROTO_TCP => Self::Tcp,
            Self::PROTO_UDP => Self::Udp,
            Self::PROTO_ICMPV6 => Self::Icmpv6,
            unknown => Self::Other(unknown),
        }
    }

    /// Short display label; all [`TransportProtocol::Other`] values share
    /// the label `"Other"` regardless of their number.
    pub fn name(&self) -> &'static str {
        match *self {
            Self::Tcp => "TCP",
            Self::Udp => "UDP",
            Self::Icmp => "ICMP",
            Self::Icmpv6 => "ICMPv6",
            Self::Other(_) => "Other",
        }
    }
}

impl std::fmt::Display for TransportProtocol {
    /// Known protocols print their label; `Other` includes the raw number,
    /// e.g. `Other(47)`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        if let Self::Other(number) = *self {
            write!(f, "Other({number})")
        } else {
            f.write_str(self.name())
        }
    }
}
51
/// Which way a packet travels relative to its conversation's canonical key.
///
/// * `Forward`: the packet's source endpoint became (`addr_a`, `port_a`).
///   That is the endpoint with the smaller address bytes or, when both
///   addresses are equal, the smaller port.
/// * `Reverse`: the packet's source endpoint became (`addr_b`, `port_b`).
///
/// When source and destination are completely identical (same address and
/// same port), the packet is reported as `Forward`.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum FlowDirection {
    /// The source is the canonical `a` endpoint.
    Forward,
    /// The source is the canonical `b` endpoint.
    Reverse,
}
61
62/// Bidirectional canonical conversation key.
63///
64/// Uses Wireshark-style canonical ordering: the smaller IP address is always
65/// `addr_a` with its corresponding port as `port_a`. This ensures that both
66/// directions of a conversation hash to the same key.
67#[derive(Debug, Clone, PartialEq, Eq, Hash)]
68pub struct CanonicalKey {
69    /// The smaller IP address (or first if equal, then by port).
70    pub addr_a: IpAddr,
71    /// The larger IP address.
72    pub addr_b: IpAddr,
73    /// Port corresponding to `addr_a`.
74    pub port_a: u16,
75    /// Port corresponding to `addr_b`.
76    pub port_b: u16,
77    /// Transport protocol.
78    pub protocol: TransportProtocol,
79    /// Optional VLAN ID for deinterlacing.
80    pub vlan_id: Option<u16>,
81}
82
83/// Helper to get byte representation of an IP address for comparison.
84fn ip_to_bytes(ip: &IpAddr) -> Vec<u8> {
85    match ip {
86        IpAddr::V4(v4) => v4.octets().to_vec(),
87        IpAddr::V6(v6) => v6.octets().to_vec(),
88    }
89}
90
91impl CanonicalKey {
92    /// Create a new canonical key with deterministic ordering.
93    ///
94    /// Returns the key and the direction of the original packet relative
95    /// to the canonical ordering.
96    pub fn new(
97        src_ip: IpAddr,
98        dst_ip: IpAddr,
99        src_port: u16,
100        dst_port: u16,
101        protocol: TransportProtocol,
102        vlan_id: Option<u16>,
103    ) -> (Self, FlowDirection) {
104        let src_bytes = ip_to_bytes(&src_ip);
105        let dst_bytes = ip_to_bytes(&dst_ip);
106
107        let (addr_a, port_a, addr_b, port_b, direction) = match src_bytes.cmp(&dst_bytes) {
108            std::cmp::Ordering::Less => {
109                (src_ip, src_port, dst_ip, dst_port, FlowDirection::Forward)
110            },
111            std::cmp::Ordering::Greater => {
112                (dst_ip, dst_port, src_ip, src_port, FlowDirection::Reverse)
113            },
114            std::cmp::Ordering::Equal => {
115                // IPs are equal, sort by port
116                if src_port <= dst_port {
117                    (src_ip, src_port, dst_ip, dst_port, FlowDirection::Forward)
118                } else {
119                    (dst_ip, dst_port, src_ip, src_port, FlowDirection::Reverse)
120                }
121            },
122        };
123
124        (
125            Self {
126                addr_a,
127                addr_b,
128                port_a,
129                port_b,
130                protocol,
131                vlan_id,
132            },
133            direction,
134        )
135    }
136}
137
138impl std::fmt::Display for CanonicalKey {
139    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
140        write!(
141            f,
142            "{}:{} <-> {}:{} [{}]",
143            self.addr_a, self.port_a, self.addr_b, self.port_b, self.protocol
144        )
145    }
146}
147
148/// Extract a canonical key and direction from a parsed packet.
149///
150/// Reads the IP layer (IPv4 or IPv6) for addresses and protocol number,
151/// and the transport layer (TCP or UDP) for ports. For ICMP and other
152/// protocols without ports, ports are set to 0.
153pub fn extract_key(packet: &Packet) -> Result<(CanonicalKey, FlowDirection), FlowError> {
154    if !packet.is_parsed() {
155        return Err(FlowError::PacketNotParsed);
156    }
157
158    let buf = packet.as_bytes();
159
160    // Extract IP addresses and protocol
161    let (src_ip, dst_ip, proto) = if let Some(ipv4) = packet.ipv4() {
162        let src = ipv4
163            .src(buf)
164            .map_err(|e| FlowError::PacketError(e.into()))?;
165        let dst = ipv4
166            .dst(buf)
167            .map_err(|e| FlowError::PacketError(e.into()))?;
168        let protocol = ipv4
169            .protocol(buf)
170            .map_err(|e| FlowError::PacketError(e.into()))?;
171        (IpAddr::V4(src), IpAddr::V4(dst), protocol)
172    } else if let Some(idx) = packet.get_layer(LayerKind::Ipv6) {
173        let ipv6 = Ipv6Layer { index: *idx };
174        let src = ipv6
175            .src(buf)
176            .map_err(|e| FlowError::PacketError(e.into()))?;
177        let dst = ipv6
178            .dst(buf)
179            .map_err(|e| FlowError::PacketError(e.into()))?;
180        let next_header = ipv6
181            .next_header(buf)
182            .map_err(|e| FlowError::PacketError(e.into()))?;
183        (IpAddr::V6(src), IpAddr::V6(dst), next_header)
184    } else {
185        return Err(FlowError::NoIpLayer);
186    };
187
188    let transport = TransportProtocol::from_ip_protocol(proto);
189
190    // Extract ports from transport layer
191    let (src_port, dst_port) = match transport {
192        TransportProtocol::Tcp => {
193            let tcp = packet.tcp().ok_or(FlowError::NoTransportLayer)?;
194            let sport = tcp
195                .src_port(buf)
196                .map_err(|e| FlowError::PacketError(e.into()))?;
197            let dport = tcp
198                .dst_port(buf)
199                .map_err(|e| FlowError::PacketError(e.into()))?;
200            (sport, dport)
201        },
202        TransportProtocol::Udp => {
203            let udp = packet.udp().ok_or(FlowError::NoTransportLayer)?;
204            let sport = udp
205                .src_port(buf)
206                .map_err(|e| FlowError::PacketError(e.into()))?;
207            let dport = udp
208                .dst_port(buf)
209                .map_err(|e| FlowError::PacketError(e.into()))?;
210            (sport, dport)
211        },
212        // ICMP and other protocols have no ports
213        _ => (0u16, 0u16),
214    };
215
216    // Check for VLAN tag
217    let vlan_id = if packet.get_layer(LayerKind::Dot1Q).is_some() {
218        // TODO: Extract actual VLAN ID from Dot1Q layer if needed
219        None
220    } else {
221        None
222    };
223
224    Ok(CanonicalKey::new(
225        src_ip, dst_ip, src_port, dst_port, transport, vlan_id,
226    ))
227}
228
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::{Ipv4Addr, Ipv6Addr};

    /// Shorthand for an IPv4 `IpAddr`.
    fn v4(a: u8, b: u8, c: u8, d: u8) -> IpAddr {
        IpAddr::V4(Ipv4Addr::new(a, b, c, d))
    }

    /// Shorthand for the link-local IPv6 address `fe80::<last>`.
    fn link_local(last: u16) -> IpAddr {
        IpAddr::V6(Ipv6Addr::new(0xfe80, 0, 0, 0, 0, 0, 0, last))
    }

    /// Build a TCP key without a VLAN tag, the common case in these tests.
    fn tcp_key(src: IpAddr, dst: IpAddr, sport: u16, dport: u16) -> (CanonicalKey, FlowDirection) {
        CanonicalKey::new(src, dst, sport, dport, TransportProtocol::Tcp, None)
    }

    #[test]
    fn test_canonical_key_forward() {
        let (key, dir) = tcp_key(v4(10, 0, 0, 1), v4(192, 168, 1, 1), 12345, 80);
        assert_eq!(dir, FlowDirection::Forward);
        assert_eq!(key.addr_a, v4(10, 0, 0, 1));
        assert_eq!(key.addr_b, v4(192, 168, 1, 1));
        assert_eq!((key.port_a, key.port_b), (12345, 80));
    }

    #[test]
    fn test_canonical_key_reverse() {
        let (key, dir) = tcp_key(v4(192, 168, 1, 1), v4(10, 0, 0, 1), 80, 12345);
        assert_eq!(dir, FlowDirection::Reverse);
        assert_eq!(key.addr_a, v4(10, 0, 0, 1));
        assert_eq!(key.addr_b, v4(192, 168, 1, 1));
        assert_eq!((key.port_a, key.port_b), (12345, 80));
    }

    #[test]
    fn test_canonical_key_bidirectional_match() {
        let client = v4(192, 168, 1, 100);
        let server = v4(81, 209, 179, 69);
        let (outbound, _) = tcp_key(client, server, 50272, 80);
        let (inbound, _) = tcp_key(server, client, 80, 50272);
        assert_eq!(outbound, inbound);
    }

    #[test]
    fn test_canonical_key_equal_ips_sort_by_port() {
        let host = v4(10, 0, 0, 1);
        let (key, dir) = tcp_key(host, host, 8080, 80);
        assert_eq!(dir, FlowDirection::Reverse);
        assert_eq!((key.port_a, key.port_b), (80, 8080));
    }

    #[test]
    fn test_canonical_key_ipv6() {
        let (one, two) = (link_local(1), link_local(2));
        let (there, _) = tcp_key(one, two, 1234, 80);
        let (back, _) = tcp_key(two, one, 80, 1234);
        assert_eq!(there, back);
    }

    #[test]
    fn test_canonical_key_different_protocols() {
        let make = |proto| {
            CanonicalKey::new(v4(10, 0, 0, 1), v4(10, 0, 0, 2), 1234, 80, proto, None).0
        };
        assert_ne!(make(TransportProtocol::Tcp), make(TransportProtocol::Udp));
    }

    #[test]
    fn test_transport_protocol_from_ip() {
        let cases = [
            (6, TransportProtocol::Tcp),
            (17, TransportProtocol::Udp),
            (1, TransportProtocol::Icmp),
            (58, TransportProtocol::Icmpv6),
            (47, TransportProtocol::Other(47)),
        ];
        for (number, expected) in cases {
            assert_eq!(
                TransportProtocol::from_ip_protocol(number),
                expected,
                "protocol number {number}"
            );
        }
    }

    #[test]
    fn test_canonical_key_display() {
        let (key, _) = tcp_key(v4(10, 0, 0, 1), v4(10, 0, 0, 2), 1234, 80);
        let rendered = key.to_string();
        for needle in ["10.0.0.1:1234", "10.0.0.2:80", "TCP"] {
            assert!(rendered.contains(needle), "{rendered:?} lacks {needle:?}");
        }
    }
}