Skip to main content

stackforge_core/flow/
key.rs

1use std::net::IpAddr;
2
3use crate::Packet;
4use crate::layer::LayerKind;
5use crate::layer::ipv6::Ipv6Layer;
6
7use super::error::FlowError;
8
9/// Z-Wave conversation key based on home ID and node pair.
10///
11/// Uses canonical ordering: the smaller node ID is always `node_a`.
12/// This ensures that both directions of a Z-Wave conversation hash
13/// to the same key.
14#[derive(Debug, Clone, PartialEq, Eq, Hash)]
15pub struct ZWaveKey {
16    /// Z-Wave network home ID.
17    pub home_id: u32,
18    /// The smaller node ID.
19    pub node_a: u8,
20    /// The larger node ID.
21    pub node_b: u8,
22}
23
24impl ZWaveKey {
25    /// Create a new canonical Z-Wave key with deterministic node ordering.
26    ///
27    /// Returns the key and the direction of the original packet relative
28    /// to the canonical ordering.
29    #[must_use]
30    pub fn new(home_id: u32, src_node: u8, dst_node: u8) -> (Self, FlowDirection) {
31        if src_node <= dst_node {
32            (
33                Self {
34                    home_id,
35                    node_a: src_node,
36                    node_b: dst_node,
37                },
38                FlowDirection::Forward,
39            )
40        } else {
41            (
42                Self {
43                    home_id,
44                    node_a: dst_node,
45                    node_b: src_node,
46                },
47                FlowDirection::Reverse,
48            )
49        }
50    }
51}
52
53impl std::fmt::Display for ZWaveKey {
54    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
55        write!(
56            f,
57            "ZWave[{:#010X}] node {} <-> node {}",
58            self.home_id, self.node_a, self.node_b
59        )
60    }
61}
62
63/// Extract a Z-Wave key and direction from a parsed packet.
64///
65/// Reads the Z-Wave layer for home ID, source, and destination node IDs.
66pub fn extract_zwave_key(packet: &Packet) -> Result<(ZWaveKey, FlowDirection), FlowError> {
67    if !packet.is_parsed() {
68        return Err(FlowError::PacketNotParsed);
69    }
70
71    let buf = packet.as_bytes();
72
73    let zwave = packet.zwave().ok_or(FlowError::NoTransportLayer)?;
74
75    let home_id = zwave
76        .home_id(buf)
77        .map_err(|e| FlowError::PacketError(e.into()))?;
78    let src = zwave
79        .src(buf)
80        .map_err(|e| FlowError::PacketError(e.into()))?;
81    let dst = zwave
82        .dst(buf)
83        .map_err(|e| FlowError::PacketError(e.into()))?;
84
85    Ok(ZWaveKey::new(home_id, src, dst))
86}
87
88/// Transport layer protocol identifier for flow keys.
89#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
90pub enum TransportProtocol {
91    Tcp,
92    Udp,
93    Icmp,
94    Icmpv6,
95    Other(u8),
96}
97
98impl TransportProtocol {
99    /// Create from IP protocol number.
100    #[must_use]
101    pub fn from_ip_protocol(proto: u8) -> Self {
102        match proto {
103            6 => Self::Tcp,
104            17 => Self::Udp,
105            1 => Self::Icmp,
106            58 => Self::Icmpv6,
107            other => Self::Other(other),
108        }
109    }
110
111    /// Human-readable name.
112    #[must_use]
113    pub fn name(&self) -> &'static str {
114        match self {
115            Self::Tcp => "TCP",
116            Self::Udp => "UDP",
117            Self::Icmp => "ICMP",
118            Self::Icmpv6 => "ICMPv6",
119            Self::Other(_) => "Other",
120        }
121    }
122}
123
124impl std::fmt::Display for TransportProtocol {
125    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
126        match self {
127            Self::Other(n) => write!(f, "Other({n})"),
128            _ => f.write_str(self.name()),
129        }
130    }
131}
132
133/// Direction of a packet relative to the conversation's canonical key.
134///
135/// Forward means the packet's source matches `addr_a` (the smaller address).
136/// Reverse means the packet's source matches `addr_b` (the larger address).
137#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
138pub enum FlowDirection {
139    Forward,
140    Reverse,
141}
142
143/// Bidirectional canonical conversation key.
144///
145/// Uses Wireshark-style canonical ordering: the smaller IP address is always
146/// `addr_a` with its corresponding port as `port_a`. This ensures that both
147/// directions of a conversation hash to the same key.
148#[derive(Debug, Clone, PartialEq, Eq, Hash)]
149pub struct CanonicalKey {
150    /// The smaller IP address (or first if equal, then by port).
151    pub addr_a: IpAddr,
152    /// The larger IP address.
153    pub addr_b: IpAddr,
154    /// Port corresponding to `addr_a`.
155    pub port_a: u16,
156    /// Port corresponding to `addr_b`.
157    pub port_b: u16,
158    /// Transport protocol.
159    pub protocol: TransportProtocol,
160    /// Optional VLAN ID for deinterlacing.
161    pub vlan_id: Option<u16>,
162}
163
164/// Helper to get byte representation of an IP address for comparison.
165fn ip_to_bytes(ip: &IpAddr) -> Vec<u8> {
166    match ip {
167        IpAddr::V4(v4) => v4.octets().to_vec(),
168        IpAddr::V6(v6) => v6.octets().to_vec(),
169    }
170}
171
172impl CanonicalKey {
173    /// Create a new canonical key with deterministic ordering.
174    ///
175    /// Returns the key and the direction of the original packet relative
176    /// to the canonical ordering.
177    #[must_use]
178    pub fn new(
179        src_ip: IpAddr,
180        dst_ip: IpAddr,
181        src_port: u16,
182        dst_port: u16,
183        protocol: TransportProtocol,
184        vlan_id: Option<u16>,
185    ) -> (Self, FlowDirection) {
186        let src_bytes = ip_to_bytes(&src_ip);
187        let dst_bytes = ip_to_bytes(&dst_ip);
188
189        let (addr_a, port_a, addr_b, port_b, direction) = match src_bytes.cmp(&dst_bytes) {
190            std::cmp::Ordering::Less => {
191                (src_ip, src_port, dst_ip, dst_port, FlowDirection::Forward)
192            },
193            std::cmp::Ordering::Greater => {
194                (dst_ip, dst_port, src_ip, src_port, FlowDirection::Reverse)
195            },
196            std::cmp::Ordering::Equal => {
197                // IPs are equal, sort by port
198                if src_port <= dst_port {
199                    (src_ip, src_port, dst_ip, dst_port, FlowDirection::Forward)
200                } else {
201                    (dst_ip, dst_port, src_ip, src_port, FlowDirection::Reverse)
202                }
203            },
204        };
205
206        (
207            Self {
208                addr_a,
209                addr_b,
210                port_a,
211                port_b,
212                protocol,
213                vlan_id,
214            },
215            direction,
216        )
217    }
218}
219
220impl std::fmt::Display for CanonicalKey {
221    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
222        write!(
223            f,
224            "{}:{} <-> {}:{} [{}]",
225            self.addr_a, self.port_a, self.addr_b, self.port_b, self.protocol
226        )
227    }
228}
229
230/// Extract a canonical key and direction from a parsed packet.
231///
232/// Reads the IP layer (IPv4 or IPv6) for addresses and protocol number,
233/// and the transport layer (TCP or UDP) for ports. For ICMP and other
234/// protocols without ports, ports are set to 0.
235pub fn extract_key(packet: &Packet) -> Result<(CanonicalKey, FlowDirection), FlowError> {
236    if !packet.is_parsed() {
237        return Err(FlowError::PacketNotParsed);
238    }
239
240    let buf = packet.as_bytes();
241
242    // Extract IP addresses and protocol
243    let (src_ip, dst_ip, proto) = if let Some(ipv4) = packet.ipv4() {
244        let src = ipv4
245            .src(buf)
246            .map_err(|e| FlowError::PacketError(e.into()))?;
247        let dst = ipv4
248            .dst(buf)
249            .map_err(|e| FlowError::PacketError(e.into()))?;
250        let protocol = ipv4
251            .protocol(buf)
252            .map_err(|e| FlowError::PacketError(e.into()))?;
253        (IpAddr::V4(src), IpAddr::V4(dst), protocol)
254    } else if let Some(idx) = packet.get_layer(LayerKind::Ipv6) {
255        let ipv6 = Ipv6Layer { index: *idx };
256        let src = ipv6
257            .src(buf)
258            .map_err(|e| FlowError::PacketError(e.into()))?;
259        let dst = ipv6
260            .dst(buf)
261            .map_err(|e| FlowError::PacketError(e.into()))?;
262        let next_header = ipv6
263            .next_header(buf)
264            .map_err(|e| FlowError::PacketError(e.into()))?;
265        (IpAddr::V6(src), IpAddr::V6(dst), next_header)
266    } else {
267        return Err(FlowError::NoIpLayer);
268    };
269
270    let transport = TransportProtocol::from_ip_protocol(proto);
271
272    // Extract ports from transport layer
273    let (src_port, dst_port) = match transport {
274        TransportProtocol::Tcp => {
275            let tcp = packet.tcp().ok_or(FlowError::NoTransportLayer)?;
276            let sport = tcp
277                .src_port(buf)
278                .map_err(|e| FlowError::PacketError(e.into()))?;
279            let dport = tcp
280                .dst_port(buf)
281                .map_err(|e| FlowError::PacketError(e.into()))?;
282            (sport, dport)
283        },
284        TransportProtocol::Udp => {
285            let udp = packet.udp().ok_or(FlowError::NoTransportLayer)?;
286            let sport = udp
287                .src_port(buf)
288                .map_err(|e| FlowError::PacketError(e.into()))?;
289            let dport = udp
290                .dst_port(buf)
291                .map_err(|e| FlowError::PacketError(e.into()))?;
292            (sport, dport)
293        },
294        TransportProtocol::Icmp => {
295            // For ICMP, use identifier (for echo/timestamp types) for both ports
296            // (symmetric), or type+code as port substitute for other types.
297            // Using identifier symmetrically ensures request and reply have
298            // the same canonical key regardless of direction.
299            if let Some(icmp_layer) = packet.get_layer(LayerKind::Icmp) {
300                if buf.len() >= icmp_layer.start + 8 {
301                    let icmp_type = buf[icmp_layer.start];
302                    let is_echo = icmp_type == 0 || icmp_type == 8;
303                    if is_echo {
304                        let id = u16::from_be_bytes([
305                            buf[icmp_layer.start + 4],
306                            buf[icmp_layer.start + 5],
307                        ]);
308                        (id, id) // Use identifier symmetrically for both ports
309                    } else {
310                        let code = buf[icmp_layer.start + 1];
311                        (icmp_type as u16, code as u16)
312                    }
313                } else {
314                    (0u16, 0u16)
315                }
316            } else {
317                (0u16, 0u16)
318            }
319        },
320        TransportProtocol::Icmpv6 => {
321            // For ICMPv6, use identifier (for echo/timestamp types) for both ports
322            // (symmetric), or type+code as port substitute for other types.
323            // Using identifier symmetrically ensures request and reply have
324            // the same canonical key regardless of direction.
325            if let Some(icmpv6_layer) = packet.get_layer(LayerKind::Icmpv6) {
326                if buf.len() >= icmpv6_layer.start + 8 {
327                    let icmpv6_type = buf[icmpv6_layer.start];
328                    let is_echo = icmpv6_type == 128 || icmpv6_type == 129;
329                    if is_echo {
330                        let id = u16::from_be_bytes([
331                            buf[icmpv6_layer.start + 4],
332                            buf[icmpv6_layer.start + 5],
333                        ]);
334                        (id, id) // Use identifier symmetrically for both ports
335                    } else {
336                        let code = buf[icmpv6_layer.start + 1];
337                        (icmpv6_type as u16, code as u16)
338                    }
339                } else {
340                    (0u16, 0u16)
341                }
342            } else {
343                (0u16, 0u16)
344            }
345        },
346        // Other protocols have no ports
347        _ => (0u16, 0u16),
348    };
349
350    // Check for VLAN tag
351    let vlan_id = if packet.get_layer(LayerKind::Dot1Q).is_some() {
352        // TODO: Extract actual VLAN ID from Dot1Q layer if needed
353        None
354    } else {
355        None
356    };
357
358    Ok(CanonicalKey::new(
359        src_ip, dst_ip, src_port, dst_port, transport, vlan_id,
360    ))
361}
362
363#[cfg(test)]
364mod tests {
365    use super::*;
366    use std::net::{Ipv4Addr, Ipv6Addr};
367
368    #[test]
369    fn test_canonical_key_forward() {
370        let (key, dir) = CanonicalKey::new(
371            IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1)),
372            IpAddr::V4(Ipv4Addr::new(192, 168, 1, 1)),
373            12345,
374            80,
375            TransportProtocol::Tcp,
376            None,
377        );
378        assert_eq!(dir, FlowDirection::Forward);
379        assert_eq!(key.addr_a, IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1)));
380        assert_eq!(key.addr_b, IpAddr::V4(Ipv4Addr::new(192, 168, 1, 1)));
381        assert_eq!(key.port_a, 12345);
382        assert_eq!(key.port_b, 80);
383    }
384
385    #[test]
386    fn test_canonical_key_reverse() {
387        let (key, dir) = CanonicalKey::new(
388            IpAddr::V4(Ipv4Addr::new(192, 168, 1, 1)),
389            IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1)),
390            80,
391            12345,
392            TransportProtocol::Tcp,
393            None,
394        );
395        assert_eq!(dir, FlowDirection::Reverse);
396        assert_eq!(key.addr_a, IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1)));
397        assert_eq!(key.addr_b, IpAddr::V4(Ipv4Addr::new(192, 168, 1, 1)));
398        assert_eq!(key.port_a, 12345);
399        assert_eq!(key.port_b, 80);
400    }
401
402    #[test]
403    fn test_canonical_key_bidirectional_match() {
404        let (key_fwd, _) = CanonicalKey::new(
405            IpAddr::V4(Ipv4Addr::new(192, 168, 1, 100)),
406            IpAddr::V4(Ipv4Addr::new(81, 209, 179, 69)),
407            50272,
408            80,
409            TransportProtocol::Tcp,
410            None,
411        );
412        let (key_rev, _) = CanonicalKey::new(
413            IpAddr::V4(Ipv4Addr::new(81, 209, 179, 69)),
414            IpAddr::V4(Ipv4Addr::new(192, 168, 1, 100)),
415            80,
416            50272,
417            TransportProtocol::Tcp,
418            None,
419        );
420        assert_eq!(key_fwd, key_rev);
421    }
422
423    #[test]
424    fn test_canonical_key_equal_ips_sort_by_port() {
425        let (key, dir) = CanonicalKey::new(
426            IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1)),
427            IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1)),
428            8080,
429            80,
430            TransportProtocol::Tcp,
431            None,
432        );
433        assert_eq!(dir, FlowDirection::Reverse);
434        assert_eq!(key.port_a, 80);
435        assert_eq!(key.port_b, 8080);
436    }
437
438    #[test]
439    fn test_canonical_key_ipv6() {
440        let src = IpAddr::V6(Ipv6Addr::new(0xfe80, 0, 0, 0, 0, 0, 0, 1));
441        let dst = IpAddr::V6(Ipv6Addr::new(0xfe80, 0, 0, 0, 0, 0, 0, 2));
442        let (key_fwd, _) = CanonicalKey::new(src, dst, 1234, 80, TransportProtocol::Tcp, None);
443        let (key_rev, _) = CanonicalKey::new(dst, src, 80, 1234, TransportProtocol::Tcp, None);
444        assert_eq!(key_fwd, key_rev);
445    }
446
447    #[test]
448    fn test_canonical_key_different_protocols() {
449        let (key_tcp, _) = CanonicalKey::new(
450            IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1)),
451            IpAddr::V4(Ipv4Addr::new(10, 0, 0, 2)),
452            1234,
453            80,
454            TransportProtocol::Tcp,
455            None,
456        );
457        let (key_udp, _) = CanonicalKey::new(
458            IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1)),
459            IpAddr::V4(Ipv4Addr::new(10, 0, 0, 2)),
460            1234,
461            80,
462            TransportProtocol::Udp,
463            None,
464        );
465        assert_ne!(key_tcp, key_udp);
466    }
467
468    #[test]
469    fn test_transport_protocol_from_ip() {
470        assert_eq!(
471            TransportProtocol::from_ip_protocol(6),
472            TransportProtocol::Tcp
473        );
474        assert_eq!(
475            TransportProtocol::from_ip_protocol(17),
476            TransportProtocol::Udp
477        );
478        assert_eq!(
479            TransportProtocol::from_ip_protocol(1),
480            TransportProtocol::Icmp
481        );
482        assert_eq!(
483            TransportProtocol::from_ip_protocol(58),
484            TransportProtocol::Icmpv6
485        );
486        assert_eq!(
487            TransportProtocol::from_ip_protocol(47),
488            TransportProtocol::Other(47)
489        );
490    }
491
492    #[test]
493    fn test_canonical_key_display() {
494        let (key, _) = CanonicalKey::new(
495            IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1)),
496            IpAddr::V4(Ipv4Addr::new(10, 0, 0, 2)),
497            1234,
498            80,
499            TransportProtocol::Tcp,
500            None,
501        );
502        let s = key.to_string();
503        assert!(s.contains("10.0.0.1:1234"));
504        assert!(s.contains("10.0.0.2:80"));
505        assert!(s.contains("TCP"));
506    }
507}