1use crate::engine::WfpEngine;
7use crate::errors::{WfpError, WfpResult};
8use std::ffi::{c_void, OsString};
9use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
10use std::os::windows::ffi::OsStringExt;
11use std::path::PathBuf;
12use std::sync::mpsc;
13use std::time::{Duration, SystemTime};
14use windows::core::GUID;
15use windows::Win32::Foundation::{ERROR_SUCCESS, FILETIME, HANDLE};
16use windows::Win32::NetworkManagement::WindowsFilteringPlatform::{
17 FwpmNetEventSubscribe0, FwpmNetEventUnsubscribe0, FWPM_NET_EVENT1, FWPM_NET_EVENT_CALLBACK0,
18 FWPM_NET_EVENT_SUBSCRIPTION0,
19};
20
21#[derive(Debug, Clone)]
26pub struct NetworkEvent {
27 pub timestamp: SystemTime,
29
30 pub event_type: NetworkEventType,
32
33 pub app_path: Option<PathBuf>,
35
36 pub protocol: u8,
38
39 pub local_addr: Option<IpAddr>,
41
42 pub remote_addr: Option<IpAddr>,
44
45 pub local_port: u16,
47
48 pub remote_port: u16,
50
51 pub filter_id: Option<u64>,
53
54 pub layer_id: Option<u16>,
56}
57
/// Kind of a WFP network event, keyed by its raw numeric event type.
///
/// Three event types get dedicated variants; every other raw value is kept verbatim in
/// [`NetworkEventType::Other`] so no information is lost.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u32)]
pub enum NetworkEventType {
    /// Raw event type `3` (classify drop).
    ClassifyDrop = 3,

    /// Raw event type `6` (classify allow).
    ClassifyAllow = 6,

    /// Raw event type `7` (capability drop).
    CapabilityDrop = 7,

    /// Any raw event type without a dedicated variant.
    Other(u32),
}

impl From<u32> for NetworkEventType {
    /// Looks `raw` up in the table of known event types, falling back to `Other(raw)`.
    fn from(raw: u32) -> Self {
        const KNOWN: [(u32, NetworkEventType); 3] = [
            (3, NetworkEventType::ClassifyDrop),
            (6, NetworkEventType::ClassifyAllow),
            (7, NetworkEventType::CapabilityDrop),
        ];

        KNOWN
            .iter()
            .find(|(code, _)| *code == raw)
            .map_or(NetworkEventType::Other(raw), |&(_, kind)| kind)
    }
}
85
86pub struct WfpEventSubscription {
91 engine: *const WfpEngine,
92 subscription_handle: HANDLE,
93 _callback: Box<FWPM_NET_EVENT_CALLBACK0>, receiver: mpsc::Receiver<NetworkEvent>,
95}
96
97impl WfpEventSubscription {
98 pub fn new(engine: &WfpEngine) -> WfpResult<Self> {
107 let (sender, receiver) = mpsc::channel();
108
109 let sender_box = Box::new(sender);
111 let context = Box::into_raw(sender_box) as *const c_void;
112
113 let callback: FWPM_NET_EVENT_CALLBACK0 = Some(event_callback);
115
116 let subscription = FWPM_NET_EVENT_SUBSCRIPTION0 {
118 enumTemplate: std::ptr::null_mut(), flags: 0,
120 sessionKey: GUID::zeroed(),
121 };
122
123 let mut subscription_handle = HANDLE::default();
124
125 unsafe {
126 let result = FwpmNetEventSubscribe0(
127 engine.handle(),
128 &subscription,
129 callback,
130 Some(context),
131 &mut subscription_handle,
132 );
133
134 if result != ERROR_SUCCESS.0 {
135 let _ = Box::from_raw(context as *mut mpsc::Sender<NetworkEvent>);
137 return Err(WfpError::Other(format!(
138 "Failed to subscribe to WFP events: error code {}",
139 result
140 )));
141 }
142 }
143
144 Ok(Self {
145 engine: engine as *const WfpEngine,
146 subscription_handle,
147 _callback: Box::new(callback), receiver,
149 })
150 }
151
152 pub fn try_recv(&self) -> Result<NetworkEvent, mpsc::TryRecvError> {
154 self.receiver.try_recv()
155 }
156
157 pub fn recv(&self) -> Result<NetworkEvent, mpsc::RecvError> {
159 self.receiver.recv()
160 }
161
162 pub fn iter(&self) -> mpsc::Iter<'_, NetworkEvent> {
164 self.receiver.iter()
165 }
166}
167
168impl Drop for WfpEventSubscription {
169 fn drop(&mut self) {
170 if !self.subscription_handle.is_invalid() && !self.engine.is_null() {
171 unsafe {
172 let _ = FwpmNetEventUnsubscribe0((*self.engine).handle(), self.subscription_handle);
173 }
174 }
175 }
176}
177
178unsafe extern "system" fn event_callback(context: *mut c_void, event_ptr: *const FWPM_NET_EVENT1) {
187 if context.is_null() || event_ptr.is_null() {
189 return;
190 }
191
192 let sender = &*(context as *const mpsc::Sender<NetworkEvent>);
194
195 let event = &*event_ptr;
197 if let Some(network_event) = parse_network_event(event) {
198 let _ = sender.send(network_event);
200 }
201}
202
203unsafe fn parse_network_event(event: &FWPM_NET_EVENT1) -> Option<NetworkEvent> {
209 let header = &event.header;
210 let event_type = NetworkEventType::from(event.r#type.0 as u32);
211
212 let timestamp = filetime_to_systemtime(&header.timeStamp);
214
215 let app_path = if !header.appId.data.is_null() {
217 parse_wide_string(header.appId.data as *const u16).map(PathBuf::from)
218 } else {
219 None
220 };
221
222 let (local_addr, remote_addr) = if header.ipVersion.0 == 0 {
224 unsafe {
226 let local = parse_ipv4_union(&header.Anonymous1);
227 let remote = parse_ipv4_union_remote(&header.Anonymous2);
228 (local, remote)
229 }
230 } else if header.ipVersion.0 == 1 {
231 unsafe {
233 let local = parse_ipv6_union(&header.Anonymous1);
234 let remote = parse_ipv6_union_remote(&header.Anonymous2);
235 (local, remote)
236 }
237 } else {
238 (None, None)
239 };
240
241 let (filter_id, layer_id) = if event_type == NetworkEventType::ClassifyDrop {
243 unsafe {
244 if !event.Anonymous.classifyDrop.is_null() {
245 let drop_info = &*event.Anonymous.classifyDrop;
246 (Some(drop_info.filterId), Some(drop_info.layerId))
247 } else {
248 (None, None)
249 }
250 }
251 } else {
252 (None, None)
253 };
254
255 Some(NetworkEvent {
256 timestamp,
257 event_type,
258 app_path,
259 protocol: header.ipProtocol,
260 local_addr,
261 remote_addr,
262 local_port: header.localPort,
263 remote_port: header.remotePort,
264 filter_id,
265 layer_id,
266 })
267}
268
269fn filetime_to_systemtime(ft: &FILETIME) -> SystemTime {
271 const WINDOWS_TO_UNIX_EPOCH: u64 = 116444736000000000;
276
277 let intervals = ((ft.dwHighDateTime as u64) << 32) | (ft.dwLowDateTime as u64);
278
279 if intervals >= WINDOWS_TO_UNIX_EPOCH {
280 let unix_intervals = intervals - WINDOWS_TO_UNIX_EPOCH;
281 let secs = unix_intervals / 10_000_000;
282 let nanos = ((unix_intervals % 10_000_000) * 100) as u32;
283
284 SystemTime::UNIX_EPOCH + Duration::new(secs, nanos)
285 } else {
286 SystemTime::UNIX_EPOCH
287 }
288}
289
290unsafe fn parse_wide_string(ptr: *const u16) -> Option<OsString> {
292 if ptr.is_null() {
293 return None;
294 }
295
296 let mut len = 0;
298 while *ptr.add(len) != 0 {
299 len += 1;
300 }
301
302 if len == 0 {
303 return None;
304 }
305
306 let slice = std::slice::from_raw_parts(ptr, len);
308 Some(OsString::from_wide(slice))
309}
310
311unsafe fn parse_ipv4_union(
313 addr_union: &windows::Win32::NetworkManagement::WindowsFilteringPlatform::FWPM_NET_EVENT_HEADER1_0,
314) -> Option<IpAddr> {
315 let addr_u32 = addr_union.localAddrV4;
317 let bytes = addr_u32.to_ne_bytes();
318 Some(IpAddr::V4(Ipv4Addr::from(bytes)))
319}
320
321unsafe fn parse_ipv6_union(
323 addr_union: &windows::Win32::NetworkManagement::WindowsFilteringPlatform::FWPM_NET_EVENT_HEADER1_0,
324) -> Option<IpAddr> {
325 let bytes = addr_union.localAddrV6.byteArray16;
327 Some(IpAddr::V6(Ipv6Addr::from(bytes)))
328}
329
330unsafe fn parse_ipv4_union_remote(
332 addr_union: &windows::Win32::NetworkManagement::WindowsFilteringPlatform::FWPM_NET_EVENT_HEADER1_1,
333) -> Option<IpAddr> {
334 let addr_u32 = addr_union.remoteAddrV4;
335 let bytes = addr_u32.to_ne_bytes();
336 Some(IpAddr::V4(Ipv4Addr::from(bytes)))
337}
338
339unsafe fn parse_ipv6_union_remote(
341 addr_union: &windows::Win32::NetworkManagement::WindowsFilteringPlatform::FWPM_NET_EVENT_HEADER1_1,
342) -> Option<IpAddr> {
343 let bytes = addr_union.remoteAddrV6.byteArray16;
344 Some(IpAddr::V6(Ipv6Addr::from(bytes)))
345}
346
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `FILETIME` from a raw count of 100 ns intervals since 1601-01-01.
    fn filetime_from_intervals(intervals: u64) -> FILETIME {
        FILETIME {
            dwLowDateTime: (intervals & 0xFFFF_FFFF) as u32,
            dwHighDateTime: (intervals >> 32) as u32,
        }
    }

    #[test]
    fn test_event_type_conversion() {
        let cases = [
            (3u32, NetworkEventType::ClassifyDrop),
            (6, NetworkEventType::ClassifyAllow),
            (7, NetworkEventType::CapabilityDrop),
            (99, NetworkEventType::Other(99)),
        ];
        for &(raw, expected) in cases.iter() {
            assert_eq!(NetworkEventType::from(raw), expected, "raw value {}", raw);
        }
    }

    #[test]
    fn test_event_type_boundaries() {
        for &raw in [0u32, 2, 4, 5, 8].iter() {
            assert_eq!(NetworkEventType::from(raw), NetworkEventType::Other(raw));
        }
    }

    #[test]
    fn test_filetime_to_systemtime_unix_epoch() {
        let at_epoch = filetime_from_intervals(116_444_736_000_000_000);
        assert_eq!(filetime_to_systemtime(&at_epoch), SystemTime::UNIX_EPOCH);
    }

    #[test]
    fn test_filetime_to_systemtime_before_unix_epoch() {
        let zero = filetime_from_intervals(0);
        assert_eq!(filetime_to_systemtime(&zero), SystemTime::UNIX_EPOCH);
    }

    #[test]
    fn test_filetime_to_systemtime_known_date() {
        // 2000-01-01T00:00:00Z
        let since_epoch = filetime_to_systemtime(&filetime_from_intervals(125_911_584_000_000_000))
            .duration_since(SystemTime::UNIX_EPOCH)
            .expect("timestamp lies after the Unix epoch");
        assert_eq!(since_epoch.as_secs(), 946_684_800);
    }

    #[test]
    fn test_network_event_struct_creation() {
        let event = NetworkEvent {
            timestamp: SystemTime::UNIX_EPOCH,
            event_type: NetworkEventType::ClassifyDrop,
            app_path: Some(PathBuf::from(r"C:\test.exe")),
            protocol: 6,
            local_addr: Some(IpAddr::V4(Ipv4Addr::LOCALHOST)),
            remote_addr: Some(IpAddr::V4(Ipv4Addr::new(8, 8, 8, 8))),
            local_port: 12345,
            remote_port: 443,
            filter_id: Some(42),
            layer_id: Some(1),
        };

        let NetworkEvent {
            event_type,
            protocol,
            local_port,
            remote_port,
            ref app_path,
            ..
        } = event;
        assert_eq!(event_type, NetworkEventType::ClassifyDrop);
        assert_eq!(protocol, 6);
        assert_eq!(local_port, 12345);
        assert_eq!(remote_port, 443);
        assert!(app_path.is_some());
    }

    #[test]
    fn test_network_event_clone() {
        let original = NetworkEvent {
            timestamp: SystemTime::UNIX_EPOCH,
            event_type: NetworkEventType::ClassifyAllow,
            app_path: None,
            protocol: 17,
            local_addr: None,
            remote_addr: None,
            local_port: 0,
            remote_port: 0,
            filter_id: None,
            layer_id: None,
        };

        let duplicate = original.clone();
        assert_eq!(duplicate.event_type, NetworkEventType::ClassifyAllow);
        assert_eq!(duplicate.protocol, 17);
        assert!(duplicate.app_path.is_none());
    }
}