1use std::net::IpAddr;
2
3use crate::Packet;
4use crate::layer::LayerKind;
5use crate::layer::ipv6::Ipv6Layer;
6
7use super::error::FlowError;
8
9#[derive(Debug, Clone, PartialEq, Eq, Hash)]
15pub struct ZWaveKey {
16 pub home_id: u32,
18 pub node_a: u8,
20 pub node_b: u8,
22}
23
24impl ZWaveKey {
25 #[must_use]
30 pub fn new(home_id: u32, src_node: u8, dst_node: u8) -> (Self, FlowDirection) {
31 if src_node <= dst_node {
32 (
33 Self {
34 home_id,
35 node_a: src_node,
36 node_b: dst_node,
37 },
38 FlowDirection::Forward,
39 )
40 } else {
41 (
42 Self {
43 home_id,
44 node_a: dst_node,
45 node_b: src_node,
46 },
47 FlowDirection::Reverse,
48 )
49 }
50 }
51}
52
53impl std::fmt::Display for ZWaveKey {
54 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
55 write!(
56 f,
57 "ZWave[{:#010X}] node {} <-> node {}",
58 self.home_id, self.node_a, self.node_b
59 )
60 }
61}
62
63pub fn extract_zwave_key(packet: &Packet) -> Result<(ZWaveKey, FlowDirection), FlowError> {
67 if !packet.is_parsed() {
68 return Err(FlowError::PacketNotParsed);
69 }
70
71 let buf = packet.as_bytes();
72
73 let zwave = packet.zwave().ok_or(FlowError::NoTransportLayer)?;
74
75 let home_id = zwave
76 .home_id(buf)
77 .map_err(|e| FlowError::PacketError(e.into()))?;
78 let src = zwave
79 .src(buf)
80 .map_err(|e| FlowError::PacketError(e.into()))?;
81 let dst = zwave
82 .dst(buf)
83 .map_err(|e| FlowError::PacketError(e.into()))?;
84
85 Ok(ZWaveKey::new(home_id, src, dst))
86}
87
/// Transport-layer protocol of a flow, decoded from the IPv4 `protocol` or
/// IPv6 `next_header` byte.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum TransportProtocol {
    /// TCP (IP protocol number 6).
    Tcp,
    /// UDP (IP protocol number 17).
    Udp,
    /// ICMP for IPv4 (IP protocol number 1).
    Icmp,
    /// ICMPv6 (IP protocol number 58).
    Icmpv6,
    /// Any other protocol; carries the raw protocol number.
    Other(u8),
}

impl TransportProtocol {
    /// Maps an IP protocol number to a [`TransportProtocol`].
    ///
    /// Numbers without a dedicated variant become [`TransportProtocol::Other`].
    #[must_use]
    pub fn from_ip_protocol(proto: u8) -> Self {
        match proto {
            1 => Self::Icmp,
            6 => Self::Tcp,
            17 => Self::Udp,
            58 => Self::Icmpv6,
            unknown => Self::Other(unknown),
        }
    }

    /// Short display label; every `Other(_)` value shares the label `"Other"`.
    #[must_use]
    pub fn name(&self) -> &'static str {
        match *self {
            Self::Icmp => "ICMP",
            Self::Icmpv6 => "ICMPv6",
            Self::Tcp => "TCP",
            Self::Udp => "UDP",
            Self::Other(_) => "Other",
        }
    }
}

impl std::fmt::Display for TransportProtocol {
    /// Writes the label from [`TransportProtocol::name`]; `Other` also shows
    /// the raw number, e.g. `Other(47)`. Width/fill options are not applied.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        if let Self::Other(n) = self {
            write!(f, "Other({n})")
        } else {
            f.write_str(self.name())
        }
    }
}
132
/// Direction of a packet relative to its canonical flow key.
///
/// `Forward` means the packet's source is the key's first endpoint
/// (`addr_a`/`port_a`, or `node_a` for Z-Wave); `Reverse` means it is the second.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum FlowDirection {
    /// The sender is the key's first (lower-sorting) endpoint.
    Forward,
    /// The sender is the key's second (higher-sorting) endpoint.
    Reverse,
}
142
143#[derive(Debug, Clone, PartialEq, Eq, Hash)]
149pub struct CanonicalKey {
150 pub addr_a: IpAddr,
152 pub addr_b: IpAddr,
154 pub port_a: u16,
156 pub port_b: u16,
158 pub protocol: TransportProtocol,
160 pub vlan_id: Option<u16>,
162}
163
164fn ip_to_bytes(ip: &IpAddr) -> Vec<u8> {
166 match ip {
167 IpAddr::V4(v4) => v4.octets().to_vec(),
168 IpAddr::V6(v6) => v6.octets().to_vec(),
169 }
170}
171
172impl CanonicalKey {
173 #[must_use]
178 pub fn new(
179 src_ip: IpAddr,
180 dst_ip: IpAddr,
181 src_port: u16,
182 dst_port: u16,
183 protocol: TransportProtocol,
184 vlan_id: Option<u16>,
185 ) -> (Self, FlowDirection) {
186 let src_bytes = ip_to_bytes(&src_ip);
187 let dst_bytes = ip_to_bytes(&dst_ip);
188
189 let (addr_a, port_a, addr_b, port_b, direction) = match src_bytes.cmp(&dst_bytes) {
190 std::cmp::Ordering::Less => {
191 (src_ip, src_port, dst_ip, dst_port, FlowDirection::Forward)
192 },
193 std::cmp::Ordering::Greater => {
194 (dst_ip, dst_port, src_ip, src_port, FlowDirection::Reverse)
195 },
196 std::cmp::Ordering::Equal => {
197 if src_port <= dst_port {
199 (src_ip, src_port, dst_ip, dst_port, FlowDirection::Forward)
200 } else {
201 (dst_ip, dst_port, src_ip, src_port, FlowDirection::Reverse)
202 }
203 },
204 };
205
206 (
207 Self {
208 addr_a,
209 addr_b,
210 port_a,
211 port_b,
212 protocol,
213 vlan_id,
214 },
215 direction,
216 )
217 }
218}
219
220impl std::fmt::Display for CanonicalKey {
221 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
222 write!(
223 f,
224 "{}:{} <-> {}:{} [{}]",
225 self.addr_a, self.port_a, self.addr_b, self.port_b, self.protocol
226 )
227 }
228}
229
230pub fn extract_key(packet: &Packet) -> Result<(CanonicalKey, FlowDirection), FlowError> {
236 if !packet.is_parsed() {
237 return Err(FlowError::PacketNotParsed);
238 }
239
240 let buf = packet.as_bytes();
241
242 let (src_ip, dst_ip, proto) = if let Some(ipv4) = packet.ipv4() {
244 let src = ipv4
245 .src(buf)
246 .map_err(|e| FlowError::PacketError(e.into()))?;
247 let dst = ipv4
248 .dst(buf)
249 .map_err(|e| FlowError::PacketError(e.into()))?;
250 let protocol = ipv4
251 .protocol(buf)
252 .map_err(|e| FlowError::PacketError(e.into()))?;
253 (IpAddr::V4(src), IpAddr::V4(dst), protocol)
254 } else if let Some(idx) = packet.get_layer(LayerKind::Ipv6) {
255 let ipv6 = Ipv6Layer { index: *idx };
256 let src = ipv6
257 .src(buf)
258 .map_err(|e| FlowError::PacketError(e.into()))?;
259 let dst = ipv6
260 .dst(buf)
261 .map_err(|e| FlowError::PacketError(e.into()))?;
262 let next_header = ipv6
263 .next_header(buf)
264 .map_err(|e| FlowError::PacketError(e.into()))?;
265 (IpAddr::V6(src), IpAddr::V6(dst), next_header)
266 } else {
267 return Err(FlowError::NoIpLayer);
268 };
269
270 let transport = TransportProtocol::from_ip_protocol(proto);
271
272 let (src_port, dst_port) = match transport {
274 TransportProtocol::Tcp => {
275 let tcp = packet.tcp().ok_or(FlowError::NoTransportLayer)?;
276 let sport = tcp
277 .src_port(buf)
278 .map_err(|e| FlowError::PacketError(e.into()))?;
279 let dport = tcp
280 .dst_port(buf)
281 .map_err(|e| FlowError::PacketError(e.into()))?;
282 (sport, dport)
283 },
284 TransportProtocol::Udp => {
285 let udp = packet.udp().ok_or(FlowError::NoTransportLayer)?;
286 let sport = udp
287 .src_port(buf)
288 .map_err(|e| FlowError::PacketError(e.into()))?;
289 let dport = udp
290 .dst_port(buf)
291 .map_err(|e| FlowError::PacketError(e.into()))?;
292 (sport, dport)
293 },
294 TransportProtocol::Icmp => {
295 if let Some(icmp_layer) = packet.get_layer(LayerKind::Icmp) {
300 if buf.len() >= icmp_layer.start + 8 {
301 let icmp_type = buf[icmp_layer.start];
302 let is_echo = icmp_type == 0 || icmp_type == 8;
303 if is_echo {
304 let id = u16::from_be_bytes([
305 buf[icmp_layer.start + 4],
306 buf[icmp_layer.start + 5],
307 ]);
308 (id, id) } else {
310 let code = buf[icmp_layer.start + 1];
311 (icmp_type as u16, code as u16)
312 }
313 } else {
314 (0u16, 0u16)
315 }
316 } else {
317 (0u16, 0u16)
318 }
319 },
320 TransportProtocol::Icmpv6 => {
321 if let Some(icmpv6_layer) = packet.get_layer(LayerKind::Icmpv6) {
326 if buf.len() >= icmpv6_layer.start + 8 {
327 let icmpv6_type = buf[icmpv6_layer.start];
328 let is_echo = icmpv6_type == 128 || icmpv6_type == 129;
329 if is_echo {
330 let id = u16::from_be_bytes([
331 buf[icmpv6_layer.start + 4],
332 buf[icmpv6_layer.start + 5],
333 ]);
334 (id, id) } else {
336 let code = buf[icmpv6_layer.start + 1];
337 (icmpv6_type as u16, code as u16)
338 }
339 } else {
340 (0u16, 0u16)
341 }
342 } else {
343 (0u16, 0u16)
344 }
345 },
346 _ => (0u16, 0u16),
348 };
349
350 let vlan_id = if packet.get_layer(LayerKind::Dot1Q).is_some() {
352 None
354 } else {
355 None
356 };
357
358 Ok(CanonicalKey::new(
359 src_ip, dst_ip, src_port, dst_port, transport, vlan_id,
360 ))
361}
362
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::{Ipv4Addr, Ipv6Addr};

    #[test]
    fn test_canonical_key_forward() {
        let (key, dir) = CanonicalKey::new(
            IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1)),
            IpAddr::V4(Ipv4Addr::new(192, 168, 1, 1)),
            12345,
            80,
            TransportProtocol::Tcp,
            None,
        );
        assert_eq!(dir, FlowDirection::Forward);
        assert_eq!(key.addr_a, IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1)));
        assert_eq!(key.addr_b, IpAddr::V4(Ipv4Addr::new(192, 168, 1, 1)));
        assert_eq!(key.port_a, 12345);
        assert_eq!(key.port_b, 80);
    }

    #[test]
    fn test_canonical_key_reverse() {
        let (key, dir) = CanonicalKey::new(
            IpAddr::V4(Ipv4Addr::new(192, 168, 1, 1)),
            IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1)),
            80,
            12345,
            TransportProtocol::Tcp,
            None,
        );
        assert_eq!(dir, FlowDirection::Reverse);
        assert_eq!(key.addr_a, IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1)));
        assert_eq!(key.addr_b, IpAddr::V4(Ipv4Addr::new(192, 168, 1, 1)));
        assert_eq!(key.port_a, 12345);
        assert_eq!(key.port_b, 80);
    }

    #[test]
    fn test_canonical_key_bidirectional_match() {
        let (key_fwd, _) = CanonicalKey::new(
            IpAddr::V4(Ipv4Addr::new(192, 168, 1, 100)),
            IpAddr::V4(Ipv4Addr::new(81, 209, 179, 69)),
            50272,
            80,
            TransportProtocol::Tcp,
            None,
        );
        let (key_rev, _) = CanonicalKey::new(
            IpAddr::V4(Ipv4Addr::new(81, 209, 179, 69)),
            IpAddr::V4(Ipv4Addr::new(192, 168, 1, 100)),
            80,
            50272,
            TransportProtocol::Tcp,
            None,
        );
        assert_eq!(key_fwd, key_rev);
    }

    #[test]
    fn test_canonical_key_equal_ips_sort_by_port() {
        let (key, dir) = CanonicalKey::new(
            IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1)),
            IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1)),
            8080,
            80,
            TransportProtocol::Tcp,
            None,
        );
        assert_eq!(dir, FlowDirection::Reverse);
        assert_eq!(key.port_a, 80);
        assert_eq!(key.port_b, 8080);
    }

    #[test]
    fn test_canonical_key_identical_endpoints_is_forward() {
        // Same address and same port: nothing to swap, so the packet is Forward.
        let ip = IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1));
        let (key, dir) = CanonicalKey::new(ip, ip, 5000, 5000, TransportProtocol::Udp, None);
        assert_eq!(dir, FlowDirection::Forward);
        assert_eq!(key.addr_a, ip);
        assert_eq!(key.addr_b, ip);
        assert_eq!(key.port_a, 5000);
        assert_eq!(key.port_b, 5000);
    }

    #[test]
    fn test_canonical_key_ipv6() {
        let src = IpAddr::V6(Ipv6Addr::new(0xfe80, 0, 0, 0, 0, 0, 0, 1));
        let dst = IpAddr::V6(Ipv6Addr::new(0xfe80, 0, 0, 0, 0, 0, 0, 2));
        let (key_fwd, _) = CanonicalKey::new(src, dst, 1234, 80, TransportProtocol::Tcp, None);
        let (key_rev, _) = CanonicalKey::new(dst, src, 80, 1234, TransportProtocol::Tcp, None);
        assert_eq!(key_fwd, key_rev);
    }

    #[test]
    fn test_canonical_key_ipv6_direction() {
        let low = IpAddr::V6(Ipv6Addr::new(0x2001, 0xdb8, 0, 0, 0, 0, 0, 1));
        let high = IpAddr::V6(Ipv6Addr::new(0x2001, 0xdb8, 0, 0, 0, 0, 0, 0x100));
        let (_, dir_fwd) = CanonicalKey::new(low, high, 1, 2, TransportProtocol::Udp, None);
        let (key, dir_rev) = CanonicalKey::new(high, low, 2, 1, TransportProtocol::Udp, None);
        assert_eq!(dir_fwd, FlowDirection::Forward);
        assert_eq!(dir_rev, FlowDirection::Reverse);
        assert_eq!(key.addr_a, low);
        assert_eq!(key.port_a, 1);
    }

    #[test]
    fn test_canonical_key_different_protocols() {
        let (key_tcp, _) = CanonicalKey::new(
            IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1)),
            IpAddr::V4(Ipv4Addr::new(10, 0, 0, 2)),
            1234,
            80,
            TransportProtocol::Tcp,
            None,
        );
        let (key_udp, _) = CanonicalKey::new(
            IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1)),
            IpAddr::V4(Ipv4Addr::new(10, 0, 0, 2)),
            1234,
            80,
            TransportProtocol::Udp,
            None,
        );
        assert_ne!(key_tcp, key_udp);
    }

    #[test]
    fn test_canonical_key_vlan_distinguishes_flows() {
        let a = IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1));
        let b = IpAddr::V4(Ipv4Addr::new(10, 0, 0, 2));
        let (key_10, _) = CanonicalKey::new(a, b, 1, 2, TransportProtocol::Tcp, Some(10));
        let (key_20, _) = CanonicalKey::new(a, b, 1, 2, TransportProtocol::Tcp, Some(20));
        assert_ne!(key_10, key_20);
    }

    #[test]
    fn test_transport_protocol_from_ip() {
        assert_eq!(
            TransportProtocol::from_ip_protocol(6),
            TransportProtocol::Tcp
        );
        assert_eq!(
            TransportProtocol::from_ip_protocol(17),
            TransportProtocol::Udp
        );
        assert_eq!(
            TransportProtocol::from_ip_protocol(1),
            TransportProtocol::Icmp
        );
        assert_eq!(
            TransportProtocol::from_ip_protocol(58),
            TransportProtocol::Icmpv6
        );
        assert_eq!(
            TransportProtocol::from_ip_protocol(47),
            TransportProtocol::Other(47)
        );
    }

    #[test]
    fn test_transport_protocol_display() {
        assert_eq!(TransportProtocol::Tcp.to_string(), "TCP");
        assert_eq!(TransportProtocol::Icmpv6.to_string(), "ICMPv6");
        assert_eq!(TransportProtocol::Other(47).to_string(), "Other(47)");
        assert_eq!(TransportProtocol::Other(47).name(), "Other");
    }

    #[test]
    fn test_canonical_key_display() {
        let (key, _) = CanonicalKey::new(
            IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1)),
            IpAddr::V4(Ipv4Addr::new(10, 0, 0, 2)),
            1234,
            80,
            TransportProtocol::Tcp,
            None,
        );
        let s = key.to_string();
        assert!(s.contains("10.0.0.1:1234"));
        assert!(s.contains("10.0.0.2:80"));
        assert!(s.contains("TCP"));
    }

    #[test]
    fn test_zwave_key_bidirectional_match() {
        let (fwd, fwd_dir) = ZWaveKey::new(0xDEAD_BEEF, 2, 5);
        let (rev, rev_dir) = ZWaveKey::new(0xDEAD_BEEF, 5, 2);
        assert_eq!(fwd, rev);
        assert_eq!(fwd_dir, FlowDirection::Forward);
        assert_eq!(rev_dir, FlowDirection::Reverse);
        assert_eq!((fwd.node_a, fwd.node_b), (2, 5));
    }

    #[test]
    fn test_zwave_key_distinct_home_ids() {
        let (net_1, _) = ZWaveKey::new(1, 2, 5);
        let (net_2, _) = ZWaveKey::new(2, 2, 5);
        assert_ne!(net_1, net_2);
    }

    #[test]
    fn test_zwave_key_display() {
        let (key, _) = ZWaveKey::new(0x1234, 7, 3);
        assert_eq!(key.to_string(), "ZWave[0x00001234] node 3 <-> node 7");
    }
}