1use std::net::IpAddr;
2
3use crate::Packet;
4use crate::layer::LayerKind;
5use crate::layer::ipv6::Ipv6Layer;
6
7use super::error::FlowError;
8
/// Transport-layer protocol of a flow, classified from an IP protocol number.
///
/// Values are normally produced by [`TransportProtocol::from_ip_protocol`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TransportProtocol {
    /// TCP (IP protocol number 6).
    Tcp,
    /// UDP (IP protocol number 17).
    Udp,
    /// ICMP (IP protocol number 1).
    Icmp,
    /// ICMPv6 (IP protocol number 58).
    Icmpv6,
    /// Any protocol without a dedicated variant; carries the raw IP protocol number.
    Other(u8),
}
18
19impl TransportProtocol {
20 pub fn from_ip_protocol(proto: u8) -> Self {
22 match proto {
23 6 => Self::Tcp,
24 17 => Self::Udp,
25 1 => Self::Icmp,
26 58 => Self::Icmpv6,
27 other => Self::Other(other),
28 }
29 }
30
31 pub fn name(&self) -> &'static str {
33 match self {
34 Self::Tcp => "TCP",
35 Self::Udp => "UDP",
36 Self::Icmp => "ICMP",
37 Self::Icmpv6 => "ICMPv6",
38 Self::Other(_) => "Other",
39 }
40 }
41}
42
43impl std::fmt::Display for TransportProtocol {
44 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
45 match self {
46 Self::Other(n) => write!(f, "Other({n})"),
47 _ => f.write_str(self.name()),
48 }
49 }
50}
51
/// Direction of a packet relative to the canonical ordering of its flow key.
///
/// Returned alongside the key by [`CanonicalKey::new`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum FlowDirection {
    /// The packet's source endpoint is stored as endpoint A of the key.
    Forward,
    /// The packet's source endpoint is stored as endpoint B of the key
    /// (the endpoints were swapped during canonicalisation).
    Reverse,
}
61
/// Direction-independent flow key.
///
/// [`CanonicalKey::new`] orders the two endpoints so that both directions of
/// the same conversation yield an equal key. All fields take part in the
/// derived `PartialEq`/`Eq`/`Hash` implementations.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct CanonicalKey {
    /// Address of endpoint A (the endpoint that sorts first).
    pub addr_a: IpAddr,
    /// Address of endpoint B.
    pub addr_b: IpAddr,
    /// Port belonging to endpoint A.
    pub port_a: u16,
    /// Port belonging to endpoint B.
    pub port_b: u16,
    /// Transport protocol of the flow.
    pub protocol: TransportProtocol,
    /// VLAN identifier associated with the flow, if any.
    pub vlan_id: Option<u16>,
}
82
83fn ip_to_bytes(ip: &IpAddr) -> Vec<u8> {
85 match ip {
86 IpAddr::V4(v4) => v4.octets().to_vec(),
87 IpAddr::V6(v6) => v6.octets().to_vec(),
88 }
89}
90
91impl CanonicalKey {
92 pub fn new(
97 src_ip: IpAddr,
98 dst_ip: IpAddr,
99 src_port: u16,
100 dst_port: u16,
101 protocol: TransportProtocol,
102 vlan_id: Option<u16>,
103 ) -> (Self, FlowDirection) {
104 let src_bytes = ip_to_bytes(&src_ip);
105 let dst_bytes = ip_to_bytes(&dst_ip);
106
107 let (addr_a, port_a, addr_b, port_b, direction) = match src_bytes.cmp(&dst_bytes) {
108 std::cmp::Ordering::Less => {
109 (src_ip, src_port, dst_ip, dst_port, FlowDirection::Forward)
110 },
111 std::cmp::Ordering::Greater => {
112 (dst_ip, dst_port, src_ip, src_port, FlowDirection::Reverse)
113 },
114 std::cmp::Ordering::Equal => {
115 if src_port <= dst_port {
117 (src_ip, src_port, dst_ip, dst_port, FlowDirection::Forward)
118 } else {
119 (dst_ip, dst_port, src_ip, src_port, FlowDirection::Reverse)
120 }
121 },
122 };
123
124 (
125 Self {
126 addr_a,
127 addr_b,
128 port_a,
129 port_b,
130 protocol,
131 vlan_id,
132 },
133 direction,
134 )
135 }
136}
137
138impl std::fmt::Display for CanonicalKey {
139 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
140 write!(
141 f,
142 "{}:{} <-> {}:{} [{}]",
143 self.addr_a, self.port_a, self.addr_b, self.port_b, self.protocol
144 )
145 }
146}
147
148pub fn extract_key(packet: &Packet) -> Result<(CanonicalKey, FlowDirection), FlowError> {
154 if !packet.is_parsed() {
155 return Err(FlowError::PacketNotParsed);
156 }
157
158 let buf = packet.as_bytes();
159
160 let (src_ip, dst_ip, proto) = if let Some(ipv4) = packet.ipv4() {
162 let src = ipv4
163 .src(buf)
164 .map_err(|e| FlowError::PacketError(e.into()))?;
165 let dst = ipv4
166 .dst(buf)
167 .map_err(|e| FlowError::PacketError(e.into()))?;
168 let protocol = ipv4
169 .protocol(buf)
170 .map_err(|e| FlowError::PacketError(e.into()))?;
171 (IpAddr::V4(src), IpAddr::V4(dst), protocol)
172 } else if let Some(idx) = packet.get_layer(LayerKind::Ipv6) {
173 let ipv6 = Ipv6Layer { index: *idx };
174 let src = ipv6
175 .src(buf)
176 .map_err(|e| FlowError::PacketError(e.into()))?;
177 let dst = ipv6
178 .dst(buf)
179 .map_err(|e| FlowError::PacketError(e.into()))?;
180 let next_header = ipv6
181 .next_header(buf)
182 .map_err(|e| FlowError::PacketError(e.into()))?;
183 (IpAddr::V6(src), IpAddr::V6(dst), next_header)
184 } else {
185 return Err(FlowError::NoIpLayer);
186 };
187
188 let transport = TransportProtocol::from_ip_protocol(proto);
189
190 let (src_port, dst_port) = match transport {
192 TransportProtocol::Tcp => {
193 let tcp = packet.tcp().ok_or(FlowError::NoTransportLayer)?;
194 let sport = tcp
195 .src_port(buf)
196 .map_err(|e| FlowError::PacketError(e.into()))?;
197 let dport = tcp
198 .dst_port(buf)
199 .map_err(|e| FlowError::PacketError(e.into()))?;
200 (sport, dport)
201 },
202 TransportProtocol::Udp => {
203 let udp = packet.udp().ok_or(FlowError::NoTransportLayer)?;
204 let sport = udp
205 .src_port(buf)
206 .map_err(|e| FlowError::PacketError(e.into()))?;
207 let dport = udp
208 .dst_port(buf)
209 .map_err(|e| FlowError::PacketError(e.into()))?;
210 (sport, dport)
211 },
212 _ => (0u16, 0u16),
214 };
215
216 let vlan_id = if packet.get_layer(LayerKind::Dot1Q).is_some() {
218 None
220 } else {
221 None
222 };
223
224 Ok(CanonicalKey::new(
225 src_ip, dst_ip, src_port, dst_port, transport, vlan_id,
226 ))
227}
228
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::{Ipv4Addr, Ipv6Addr};

    /// Shorthand for building an IPv4 `IpAddr`.
    fn v4(a: u8, b: u8, c: u8, d: u8) -> IpAddr {
        IpAddr::V4(Ipv4Addr::new(a, b, c, d))
    }

    // Source address sorts lower: the endpoints keep their order.
    #[test]
    fn test_canonical_key_forward() {
        let low = v4(10, 0, 0, 1);
        let high = v4(192, 168, 1, 1);

        let (key, dir) = CanonicalKey::new(low, high, 12345, 80, TransportProtocol::Tcp, None);

        assert_eq!(dir, FlowDirection::Forward);
        assert_eq!(key.addr_a, low);
        assert_eq!(key.addr_b, high);
        assert_eq!(key.port_a, 12345);
        assert_eq!(key.port_b, 80);
    }

    // Source address sorts higher: the endpoints are swapped.
    #[test]
    fn test_canonical_key_reverse() {
        let low = v4(10, 0, 0, 1);
        let high = v4(192, 168, 1, 1);

        let (key, dir) = CanonicalKey::new(high, low, 80, 12345, TransportProtocol::Tcp, None);

        assert_eq!(dir, FlowDirection::Reverse);
        assert_eq!(key.addr_a, low);
        assert_eq!(key.addr_b, high);
        assert_eq!(key.port_a, 12345);
        assert_eq!(key.port_b, 80);
    }

    // Both directions of one conversation must produce the same key.
    #[test]
    fn test_canonical_key_bidirectional_match() {
        let client = v4(192, 168, 1, 100);
        let server = v4(81, 209, 179, 69);

        let (key_fwd, _) =
            CanonicalKey::new(client, server, 50272, 80, TransportProtocol::Tcp, None);
        let (key_rev, _) =
            CanonicalKey::new(server, client, 80, 50272, TransportProtocol::Tcp, None);

        assert_eq!(key_fwd, key_rev);
    }

    // Identical addresses: the lower port decides which endpoint is A.
    #[test]
    fn test_canonical_key_equal_ips_sort_by_port() {
        let host = v4(10, 0, 0, 1);

        let (key, dir) = CanonicalKey::new(host, host, 8080, 80, TransportProtocol::Tcp, None);

        assert_eq!(dir, FlowDirection::Reverse);
        assert_eq!(key.port_a, 80);
        assert_eq!(key.port_b, 8080);
    }

    #[test]
    fn test_canonical_key_ipv6() {
        let src = IpAddr::V6(Ipv6Addr::new(0xfe80, 0, 0, 0, 0, 0, 0, 1));
        let dst = IpAddr::V6(Ipv6Addr::new(0xfe80, 0, 0, 0, 0, 0, 0, 2));

        let (key_fwd, _) = CanonicalKey::new(src, dst, 1234, 80, TransportProtocol::Tcp, None);
        let (key_rev, _) = CanonicalKey::new(dst, src, 80, 1234, TransportProtocol::Tcp, None);

        assert_eq!(key_fwd, key_rev);
    }

    // Same endpoints, different protocol: the keys must differ.
    #[test]
    fn test_canonical_key_different_protocols() {
        let a = v4(10, 0, 0, 1);
        let b = v4(10, 0, 0, 2);

        let (key_tcp, _) = CanonicalKey::new(a, b, 1234, 80, TransportProtocol::Tcp, None);
        let (key_udp, _) = CanonicalKey::new(a, b, 1234, 80, TransportProtocol::Udp, None);

        assert_ne!(key_tcp, key_udp);
    }

    #[test]
    fn test_transport_protocol_from_ip() {
        let expectations = [
            (6, TransportProtocol::Tcp),
            (17, TransportProtocol::Udp),
            (1, TransportProtocol::Icmp),
            (58, TransportProtocol::Icmpv6),
            (47, TransportProtocol::Other(47)),
        ];

        for (number, expected) in expectations {
            assert_eq!(TransportProtocol::from_ip_protocol(number), expected);
        }
    }

    #[test]
    fn test_canonical_key_display() {
        let (key, _) = CanonicalKey::new(
            v4(10, 0, 0, 1),
            v4(10, 0, 0, 2),
            1234,
            80,
            TransportProtocol::Tcp,
            None,
        );

        let rendered = key.to_string();
        for fragment in ["10.0.0.1:1234", "10.0.0.2:80", "TCP"] {
            assert!(rendered.contains(fragment));
        }
    }
}