Skip to main content

windows_wfp/
events.rs

//! WFP Network Event Subscription (Learning Mode)
//!
//! Subscribes to WFP network events in order to monitor blocked connections.
//! Used by learning mode, where blocked traffic is logged so it can be auto-whitelisted.
5
6use crate::engine::WfpEngine;
7use crate::errors::{WfpError, WfpResult};
8use std::ffi::{c_void, OsString};
9use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
10use std::os::windows::ffi::OsStringExt;
11use std::path::PathBuf;
12use std::sync::mpsc;
13use std::time::{Duration, SystemTime};
14use windows::core::GUID;
15use windows::Win32::Foundation::{ERROR_SUCCESS, FILETIME, HANDLE};
16use windows::Win32::NetworkManagement::WindowsFilteringPlatform::{
17    FwpmNetEventSubscribe0, FwpmNetEventUnsubscribe0, FWPM_NET_EVENT1, FWPM_NET_EVENT_CALLBACK0,
18    FWPM_NET_EVENT_SUBSCRIPTION0,
19};
20
21/// Network event from WFP
22///
23/// Represents a network event captured by the Windows Filtering Platform.
24/// Used in learning mode to identify applications that were blocked.
25#[derive(Debug, Clone)]
26pub struct NetworkEvent {
27    /// When the event occurred
28    pub timestamp: SystemTime,
29
30    /// Type of event (Classify Drop, Classify Allow, etc.)
31    pub event_type: NetworkEventType,
32
33    /// Application path that triggered the event (if available)
34    pub app_path: Option<PathBuf>,
35
36    /// IP protocol (TCP=6, UDP=17, etc.)
37    pub protocol: u8,
38
39    /// Local IP address
40    pub local_addr: Option<IpAddr>,
41
42    /// Remote IP address
43    pub remote_addr: Option<IpAddr>,
44
45    /// Local port
46    pub local_port: u16,
47
48    /// Remote port
49    pub remote_port: u16,
50
51    /// Filter ID that triggered the event (for CLASSIFY_DROP)
52    pub filter_id: Option<u64>,
53
54    /// Layer ID where event occurred
55    pub layer_id: Option<u16>,
56}
57
/// Kind of WFP network event.
///
/// The explicit discriminants mirror the raw `FWPM_NET_EVENT_TYPE` codes that WFP
/// reports. Because `Other` carries data, an `as` cast cannot read them. Use
/// `u32::from(kind)` to recover the raw code, and `NetworkEventType::from(raw)` to go
/// the other way; the two conversions round-trip for every `u32`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[repr(u32)]
pub enum NetworkEventType {
    /// Connection was blocked by a filter.
    ClassifyDrop = 3,

    /// Connection was allowed by a filter (Win8+).
    ClassifyAllow = 6,

    /// App container capability drop (Win8+).
    CapabilityDrop = 7,

    /// Any other raw event code, preserved verbatim.
    Other(u32),
}

impl From<u32> for NetworkEventType {
    /// Maps a raw WFP event code to a variant; unknown codes become `Other(code)`.
    fn from(value: u32) -> Self {
        match value {
            3 => NetworkEventType::ClassifyDrop,
            6 => NetworkEventType::ClassifyAllow,
            7 => NetworkEventType::CapabilityDrop,
            other => NetworkEventType::Other(other),
        }
    }
}

impl From<NetworkEventType> for u32 {
    /// Returns the raw WFP event code for a variant (inverse of `From<u32>`).
    fn from(kind: NetworkEventType) -> Self {
        match kind {
            NetworkEventType::ClassifyDrop => 3,
            NetworkEventType::ClassifyAllow => 6,
            NetworkEventType::CapabilityDrop => 7,
            NetworkEventType::Other(raw) => raw,
        }
    }
}

impl std::fmt::Display for NetworkEventType {
    /// Writes the variant name, or `Other(<code>)` for unrecognised codes.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            NetworkEventType::ClassifyDrop => "ClassifyDrop",
            NetworkEventType::ClassifyAllow => "ClassifyAllow",
            NetworkEventType::CapabilityDrop => "CapabilityDrop",
            NetworkEventType::Other(code) => return write!(f, "Other({})", code),
        };
        f.write_str(name)
    }
}
96
97/// WFP Event Subscription Handle
98///
99/// RAII wrapper for WFP event subscription. Automatically unsubscribes on drop.
100/// Events are delivered via an mpsc channel for thread-safe processing.
101pub struct WfpEventSubscription {
102    engine: *const WfpEngine,
103    subscription_handle: HANDLE,
104    _callback: Box<FWPM_NET_EVENT_CALLBACK0>, // Keep callback alive
105    receiver: mpsc::Receiver<NetworkEvent>,
106    /// Raw pointer to the boxed `Sender<NetworkEvent>` used as the WFP callback context.
107    /// Must be freed after `FwpmNetEventUnsubscribe0` to avoid a memory leak.
108    sender_context: *mut c_void,
109}
110
111impl WfpEventSubscription {
112    /// Subscribe to WFP network events
113    ///
114    /// Creates a new event subscription that monitors network events.
115    /// Events are delivered via the returned receiver channel.
116    ///
117    /// # Errors
118    ///
119    /// Returns error if subscription fails (permissions, invalid engine, etc.)
120    pub fn new(engine: &WfpEngine) -> WfpResult<Self> {
121        let (sender, receiver) = mpsc::channel();
122
123        // Box the channel sender so it has a stable address for the context pointer
124        let sender_box = Box::new(sender);
125        let context = Box::into_raw(sender_box) as *mut c_void;
126
127        // Create the callback function
128        let callback: FWPM_NET_EVENT_CALLBACK0 = Some(event_callback);
129
130        // Create subscription for all events
131        let subscription = FWPM_NET_EVENT_SUBSCRIPTION0 {
132            enumTemplate: std::ptr::null_mut(), // Subscribe to all events
133            flags: 0,
134            sessionKey: GUID::zeroed(),
135        };
136
137        let mut subscription_handle = HANDLE::default();
138
139        unsafe {
140            let result = FwpmNetEventSubscribe0(
141                engine.handle(),
142                &subscription,
143                callback,
144                Some(context as *const c_void),
145                &mut subscription_handle,
146            );
147
148            if result != ERROR_SUCCESS.0 {
149                // Clean up the boxed sender on error
150                drop(Box::from_raw(context as *mut mpsc::Sender<NetworkEvent>));
151                return Err(WfpError::Other(format!(
152                    "Failed to subscribe to WFP events: error code {}",
153                    result
154                )));
155            }
156        }
157
158        Ok(Self {
159            engine: engine as *const WfpEngine,
160            subscription_handle,
161            _callback: Box::new(callback), // Keep callback alive to prevent GC
162            receiver,
163            sender_context: context,
164        })
165    }
166
167    /// Try to receive a network event (non-blocking)
168    pub fn try_recv(&self) -> Result<NetworkEvent, mpsc::TryRecvError> {
169        self.receiver.try_recv()
170    }
171
172    /// Receive a network event (blocking)
173    pub fn recv(&self) -> Result<NetworkEvent, mpsc::RecvError> {
174        self.receiver.recv()
175    }
176
177    /// Get an iterator over pending events
178    pub fn iter(&self) -> mpsc::Iter<'_, NetworkEvent> {
179        self.receiver.iter()
180    }
181}
182
183impl Drop for WfpEventSubscription {
184    fn drop(&mut self) {
185        if !self.subscription_handle.is_invalid() && !self.engine.is_null() {
186            unsafe {
187                // Unsubscribe first so no more callbacks can fire before we free the context
188                let _ = FwpmNetEventUnsubscribe0((*self.engine).handle(), self.subscription_handle);
189            }
190        }
191        // Free the boxed Sender that was passed as the WFP callback context.
192        // Safe to do now because WfpmNetEventUnsubscribe0 guarantees no further callbacks.
193        if !self.sender_context.is_null() {
194            unsafe {
195                drop(Box::from_raw(
196                    self.sender_context as *mut mpsc::Sender<NetworkEvent>,
197                ));
198            }
199        }
200    }
201}
202
203/// Native callback function invoked by WFP (runs on WFP worker thread)
204///
205/// # Safety
206///
207/// This function receives raw pointers from WFP and must carefully:
208/// - Validate the event pointer is not null
209/// - Parse the FWPM_NET_EVENT1 structure
210/// - Send the parsed event to the channel without blocking
211unsafe extern "system" fn event_callback(context: *mut c_void, event_ptr: *const FWPM_NET_EVENT1) {
212    // Validate pointers
213    if context.is_null() || event_ptr.is_null() {
214        return;
215    }
216
217    // Recover the sender from the context
218    let sender = &*(context as *const mpsc::Sender<NetworkEvent>);
219
220    // Parse the event
221    let event = &*event_ptr;
222    if let Some(network_event) = parse_network_event(event) {
223        // Send to channel (non-blocking - drops event if channel is full)
224        let _ = sender.send(network_event);
225    }
226}
227
228/// Parse FWPM_NET_EVENT1 into NetworkEvent
229///
230/// # Safety
231///
232/// The event pointer must be valid and point to a complete FWPM_NET_EVENT1 structure.
233unsafe fn parse_network_event(event: &FWPM_NET_EVENT1) -> Option<NetworkEvent> {
234    let header = &event.header;
235    let event_type = NetworkEventType::from(event.r#type.0 as u32);
236
237    // Parse timestamp (FILETIME to SystemTime)
238    let timestamp = filetime_to_systemtime(&header.timeStamp);
239
240    // Parse application path (wide string) - appId.data is *mut u8, need to cast
241    let app_path = if !header.appId.data.is_null() {
242        parse_wide_string(header.appId.data as *const u16).map(PathBuf::from)
243    } else {
244        None
245    };
246
247    // Parse IP addresses based on IP version (ipVersion: 0=V4, 1=V6)
248    let (local_addr, remote_addr) = if header.ipVersion.0 == 0 {
249        // IPv4
250        unsafe {
251            let local = parse_ipv4_union(&header.Anonymous1);
252            let remote = parse_ipv4_union_remote(&header.Anonymous2);
253            (local, remote)
254        }
255    } else if header.ipVersion.0 == 1 {
256        // IPv6
257        unsafe {
258            let local = parse_ipv6_union(&header.Anonymous1);
259            let remote = parse_ipv6_union_remote(&header.Anonymous2);
260            (local, remote)
261        }
262    } else {
263        (None, None)
264    };
265
266    // Parse filter ID and layer ID for CLASSIFY_DROP events
267    let (filter_id, layer_id) = if event_type == NetworkEventType::ClassifyDrop {
268        unsafe {
269            if !event.Anonymous.classifyDrop.is_null() {
270                let drop_info = &*event.Anonymous.classifyDrop;
271                (Some(drop_info.filterId), Some(drop_info.layerId))
272            } else {
273                (None, None)
274            }
275        }
276    } else {
277        (None, None)
278    };
279
280    Some(NetworkEvent {
281        timestamp,
282        event_type,
283        app_path,
284        protocol: header.ipProtocol,
285        local_addr,
286        remote_addr,
287        local_port: header.localPort,
288        remote_port: header.remotePort,
289        filter_id,
290        layer_id,
291    })
292}
293
294/// Convert FILETIME to SystemTime
295fn filetime_to_systemtime(ft: &FILETIME) -> SystemTime {
296    // FILETIME is 100-nanosecond intervals since January 1, 1601 (UTC)
297    // SystemTime is based on UNIX_EPOCH (January 1, 1970)
298
299    // Difference between Windows epoch (1601) and UNIX epoch (1970) in 100-ns intervals
300    const WINDOWS_TO_UNIX_EPOCH: u64 = 116444736000000000;
301
302    let intervals = ((ft.dwHighDateTime as u64) << 32) | (ft.dwLowDateTime as u64);
303
304    if intervals >= WINDOWS_TO_UNIX_EPOCH {
305        let unix_intervals = intervals - WINDOWS_TO_UNIX_EPOCH;
306        let secs = unix_intervals / 10_000_000;
307        let nanos = ((unix_intervals % 10_000_000) * 100) as u32;
308
309        SystemTime::UNIX_EPOCH + Duration::new(secs, nanos)
310    } else {
311        SystemTime::UNIX_EPOCH
312    }
313}
314
315/// Parse wide string (null-terminated UTF-16)
316unsafe fn parse_wide_string(ptr: *const u16) -> Option<OsString> {
317    if ptr.is_null() {
318        return None;
319    }
320
321    // Find the null terminator
322    let mut len = 0;
323    while *ptr.add(len) != 0 {
324        len += 1;
325    }
326
327    if len == 0 {
328        return None;
329    }
330
331    // Convert to OsString
332    let slice = std::slice::from_raw_parts(ptr, len);
333    Some(OsString::from_wide(slice))
334}
335
336/// Parse IPv4 address from union (reads first 4 bytes as u32) - HEADER1_0 version
337unsafe fn parse_ipv4_union(
338    addr_union: &windows::Win32::NetworkManagement::WindowsFilteringPlatform::FWPM_NET_EVENT_HEADER1_0,
339) -> Option<IpAddr> {
340    // Union contains localAddrV4 as u32
341    let addr_u32 = addr_union.localAddrV4;
342    let bytes = addr_u32.to_ne_bytes();
343    Some(IpAddr::V4(Ipv4Addr::from(bytes)))
344}
345
346/// Parse IPv6 address from union (reads 16-byte array) - HEADER1_0 version
347unsafe fn parse_ipv6_union(
348    addr_union: &windows::Win32::NetworkManagement::WindowsFilteringPlatform::FWPM_NET_EVENT_HEADER1_0,
349) -> Option<IpAddr> {
350    // Union contains localAddrV6 as byte[16]
351    let bytes = addr_union.localAddrV6.byteArray16;
352    Some(IpAddr::V6(Ipv6Addr::from(bytes)))
353}
354
355/// Parse IPv4 address from remote union (HEADER1_1 version)
356unsafe fn parse_ipv4_union_remote(
357    addr_union: &windows::Win32::NetworkManagement::WindowsFilteringPlatform::FWPM_NET_EVENT_HEADER1_1,
358) -> Option<IpAddr> {
359    let addr_u32 = addr_union.remoteAddrV4;
360    let bytes = addr_u32.to_ne_bytes();
361    Some(IpAddr::V4(Ipv4Addr::from(bytes)))
362}
363
364/// Parse IPv6 address from remote union (HEADER1_1 version)
365unsafe fn parse_ipv6_union_remote(
366    addr_union: &windows::Win32::NetworkManagement::WindowsFilteringPlatform::FWPM_NET_EVENT_HEADER1_1,
367) -> Option<IpAddr> {
368    let bytes = addr_union.remoteAddrV6.byteArray16;
369    Some(IpAddr::V6(Ipv6Addr::from(bytes)))
370}
371
#[cfg(test)]
mod tests {
    use super::*;

    /// Splits a raw count of 100-ns intervals since 1601 into a `FILETIME`.
    fn filetime(intervals: u64) -> FILETIME {
        FILETIME {
            dwLowDateTime: intervals as u32,
            dwHighDateTime: (intervals >> 32) as u32,
        }
    }

    #[test]
    fn test_event_type_conversion() {
        let known: [(u32, NetworkEventType); 4] = [
            (3, NetworkEventType::ClassifyDrop),
            (6, NetworkEventType::ClassifyAllow),
            (7, NetworkEventType::CapabilityDrop),
            (99, NetworkEventType::Other(99)),
        ];
        for (raw, expected) in known {
            assert_eq!(NetworkEventType::from(raw), expected, "raw code {}", raw);
        }
    }

    #[test]
    fn test_event_type_boundaries() {
        // Codes adjacent to the recognised ones must fall through to `Other`.
        for raw in [0u32, 2, 4, 5, 8] {
            assert_eq!(NetworkEventType::from(raw), NetworkEventType::Other(raw));
        }
    }

    #[test]
    fn test_filetime_to_systemtime_unix_epoch() {
        // 1970-01-01 expressed in 100-ns intervals since 1601.
        let converted = filetime_to_systemtime(&filetime(116_444_736_000_000_000));
        assert_eq!(converted, SystemTime::UNIX_EPOCH);
    }

    #[test]
    fn test_filetime_to_systemtime_before_unix_epoch() {
        assert_eq!(filetime_to_systemtime(&filetime(0)), SystemTime::UNIX_EPOCH);
    }

    #[test]
    fn test_filetime_to_systemtime_known_date() {
        // 2000-01-01T00:00:00Z
        let since_epoch = filetime_to_systemtime(&filetime(125_911_584_000_000_000))
            .duration_since(SystemTime::UNIX_EPOCH)
            .unwrap();
        assert_eq!(since_epoch.as_secs(), 946_684_800);
    }

    #[test]
    fn test_network_event_struct_creation() {
        let blocked = NetworkEvent {
            timestamp: SystemTime::UNIX_EPOCH,
            event_type: NetworkEventType::ClassifyDrop,
            app_path: Some(PathBuf::from(r"C:\test.exe")),
            protocol: 6,
            local_addr: Some(IpAddr::V4(Ipv4Addr::LOCALHOST)),
            remote_addr: Some(IpAddr::V4(Ipv4Addr::new(8, 8, 8, 8))),
            local_port: 12345,
            remote_port: 443,
            filter_id: Some(42),
            layer_id: Some(1),
        };

        assert_eq!(blocked.event_type, NetworkEventType::ClassifyDrop);
        assert_eq!(blocked.protocol, 6);
        assert_eq!((blocked.local_port, blocked.remote_port), (12345, 443));
        assert!(blocked.app_path.is_some());
    }

    #[test]
    fn test_network_event_clone() {
        let original = NetworkEvent {
            timestamp: SystemTime::UNIX_EPOCH,
            event_type: NetworkEventType::ClassifyAllow,
            app_path: None,
            protocol: 17,
            local_addr: None,
            remote_addr: None,
            local_port: 0,
            remote_port: 0,
            filter_id: None,
            layer_id: None,
        };

        let copy = original.clone();
        assert_eq!(copy.event_type, NetworkEventType::ClassifyAllow);
        assert_eq!(copy.protocol, 17);
        assert!(copy.app_path.is_none());
    }

    #[test]
    fn test_network_event_type_display() {
        let expectations = [
            (NetworkEventType::ClassifyDrop, "ClassifyDrop"),
            (NetworkEventType::ClassifyAllow, "ClassifyAllow"),
            (NetworkEventType::CapabilityDrop, "CapabilityDrop"),
            (NetworkEventType::Other(42), "Other(42)"),
        ];
        for (kind, text) in expectations {
            assert_eq!(kind.to_string(), text);
        }
    }
}