Skip to main content

windows_wfp/
events.rs

1//! WFP Network Event Subscription (Learning Mode)
2//!
3//! Subscribes to WFP network events to monitor blocked connections.
4//! Used for learning mode where blocked traffic is logged for auto-whitelisting.
5
6use crate::engine::WfpEngine;
7use crate::errors::{WfpError, WfpResult};
8use std::ffi::{c_void, OsString};
9use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
10use std::os::windows::ffi::OsStringExt;
11use std::path::PathBuf;
12use std::sync::mpsc;
13use std::time::{Duration, SystemTime};
14use windows::core::GUID;
15use windows::Win32::Foundation::{ERROR_SUCCESS, FILETIME, HANDLE};
16use windows::Win32::NetworkManagement::WindowsFilteringPlatform::{
17    FwpmNetEventSubscribe0, FwpmNetEventUnsubscribe0, FWPM_NET_EVENT1, FWPM_NET_EVENT_CALLBACK0,
18    FWPM_NET_EVENT_SUBSCRIPTION0,
19};
20
21/// Network event from WFP
22///
23/// Represents a network event captured by the Windows Filtering Platform.
24/// Used in learning mode to identify applications that were blocked.
25#[derive(Debug, Clone)]
26pub struct NetworkEvent {
27    /// When the event occurred
28    pub timestamp: SystemTime,
29
30    /// Type of event (Classify Drop, Classify Allow, etc.)
31    pub event_type: NetworkEventType,
32
33    /// Application path that triggered the event (if available)
34    pub app_path: Option<PathBuf>,
35
36    /// IP protocol (TCP=6, UDP=17, etc.)
37    pub protocol: u8,
38
39    /// Local IP address
40    pub local_addr: Option<IpAddr>,
41
42    /// Remote IP address
43    pub remote_addr: Option<IpAddr>,
44
45    /// Local port
46    pub local_port: u16,
47
48    /// Remote port
49    pub remote_port: u16,
50
51    /// Filter ID that triggered the event (for CLASSIFY_DROP)
52    pub filter_id: Option<u64>,
53
54    /// Layer ID where event occurred
55    pub layer_id: Option<u16>,
56}
57
/// Kind of WFP network event.
///
/// The explicit discriminants mirror the raw event-type values this crate maps via
/// `NetworkEventType::from(u32)`. Because `Other` carries data, the enum cannot be
/// converted back with an `as` cast.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u32)]
pub enum NetworkEventType {
    /// Traffic was dropped by a filter.
    ClassifyDrop = 3,

    /// Traffic was permitted by a filter (Windows 8 and later).
    ClassifyAllow = 6,

    /// App container capability drop (Windows 8 and later).
    CapabilityDrop = 7,

    /// Any other raw event type, kept verbatim. Its implicit discriminant (8) is
    /// unrelated to the wrapped value.
    Other(u32),
}
74
75impl From<u32> for NetworkEventType {
76    fn from(value: u32) -> Self {
77        match value {
78            3 => NetworkEventType::ClassifyDrop,
79            6 => NetworkEventType::ClassifyAllow,
80            7 => NetworkEventType::CapabilityDrop,
81            other => NetworkEventType::Other(other),
82        }
83    }
84}
85
86/// WFP Event Subscription Handle
87///
88/// RAII wrapper for WFP event subscription. Automatically unsubscribes on drop.
89/// Events are delivered via an mpsc channel for thread-safe processing.
90pub struct WfpEventSubscription {
91    engine: *const WfpEngine,
92    subscription_handle: HANDLE,
93    _callback: Box<FWPM_NET_EVENT_CALLBACK0>, // Keep callback alive
94    receiver: mpsc::Receiver<NetworkEvent>,
95}
96
97impl WfpEventSubscription {
98    /// Subscribe to WFP network events
99    ///
100    /// Creates a new event subscription that monitors network events.
101    /// Events are delivered via the returned receiver channel.
102    ///
103    /// # Errors
104    ///
105    /// Returns error if subscription fails (permissions, invalid engine, etc.)
106    pub fn new(engine: &WfpEngine) -> WfpResult<Self> {
107        let (sender, receiver) = mpsc::channel();
108
109        // Box the channel sender so it has a stable address for the context pointer
110        let sender_box = Box::new(sender);
111        let context = Box::into_raw(sender_box) as *const c_void;
112
113        // Create the callback function
114        let callback: FWPM_NET_EVENT_CALLBACK0 = Some(event_callback);
115
116        // Create subscription for all events
117        let subscription = FWPM_NET_EVENT_SUBSCRIPTION0 {
118            enumTemplate: std::ptr::null_mut(), // Subscribe to all events
119            flags: 0,
120            sessionKey: GUID::zeroed(),
121        };
122
123        let mut subscription_handle = HANDLE::default();
124
125        unsafe {
126            let result = FwpmNetEventSubscribe0(
127                engine.handle(),
128                &subscription,
129                callback,
130                Some(context),
131                &mut subscription_handle,
132            );
133
134            if result != ERROR_SUCCESS.0 {
135                // Clean up the boxed sender on error
136                let _ = Box::from_raw(context as *mut mpsc::Sender<NetworkEvent>);
137                return Err(WfpError::Other(format!(
138                    "Failed to subscribe to WFP events: error code {}",
139                    result
140                )));
141            }
142        }
143
144        Ok(Self {
145            engine: engine as *const WfpEngine,
146            subscription_handle,
147            _callback: Box::new(callback), // Keep callback alive to prevent GC
148            receiver,
149        })
150    }
151
152    /// Try to receive a network event (non-blocking)
153    pub fn try_recv(&self) -> Result<NetworkEvent, mpsc::TryRecvError> {
154        self.receiver.try_recv()
155    }
156
157    /// Receive a network event (blocking)
158    pub fn recv(&self) -> Result<NetworkEvent, mpsc::RecvError> {
159        self.receiver.recv()
160    }
161
162    /// Get an iterator over pending events
163    pub fn iter(&self) -> mpsc::Iter<'_, NetworkEvent> {
164        self.receiver.iter()
165    }
166}
167
168impl Drop for WfpEventSubscription {
169    fn drop(&mut self) {
170        if !self.subscription_handle.is_invalid() && !self.engine.is_null() {
171            unsafe {
172                let _ = FwpmNetEventUnsubscribe0((*self.engine).handle(), self.subscription_handle);
173            }
174        }
175    }
176}
177
178/// Native callback function invoked by WFP (runs on WFP worker thread)
179///
180/// # Safety
181///
182/// This function receives raw pointers from WFP and must carefully:
183/// - Validate the event pointer is not null
184/// - Parse the FWPM_NET_EVENT1 structure
185/// - Send the parsed event to the channel without blocking
186unsafe extern "system" fn event_callback(context: *mut c_void, event_ptr: *const FWPM_NET_EVENT1) {
187    // Validate pointers
188    if context.is_null() || event_ptr.is_null() {
189        return;
190    }
191
192    // Recover the sender from the context
193    let sender = &*(context as *const mpsc::Sender<NetworkEvent>);
194
195    // Parse the event
196    let event = &*event_ptr;
197    if let Some(network_event) = parse_network_event(event) {
198        // Send to channel (non-blocking - drops event if channel is full)
199        let _ = sender.send(network_event);
200    }
201}
202
203/// Parse FWPM_NET_EVENT1 into NetworkEvent
204///
205/// # Safety
206///
207/// The event pointer must be valid and point to a complete FWPM_NET_EVENT1 structure.
208unsafe fn parse_network_event(event: &FWPM_NET_EVENT1) -> Option<NetworkEvent> {
209    let header = &event.header;
210    let event_type = NetworkEventType::from(event.r#type.0 as u32);
211
212    // Parse timestamp (FILETIME to SystemTime)
213    let timestamp = filetime_to_systemtime(&header.timeStamp);
214
215    // Parse application path (wide string) - appId.data is *mut u8, need to cast
216    let app_path = if !header.appId.data.is_null() {
217        parse_wide_string(header.appId.data as *const u16).map(PathBuf::from)
218    } else {
219        None
220    };
221
222    // Parse IP addresses based on IP version (ipVersion: 0=V4, 1=V6)
223    let (local_addr, remote_addr) = if header.ipVersion.0 == 0 {
224        // IPv4
225        unsafe {
226            let local = parse_ipv4_union(&header.Anonymous1);
227            let remote = parse_ipv4_union_remote(&header.Anonymous2);
228            (local, remote)
229        }
230    } else if header.ipVersion.0 == 1 {
231        // IPv6
232        unsafe {
233            let local = parse_ipv6_union(&header.Anonymous1);
234            let remote = parse_ipv6_union_remote(&header.Anonymous2);
235            (local, remote)
236        }
237    } else {
238        (None, None)
239    };
240
241    // Parse filter ID and layer ID for CLASSIFY_DROP events
242    let (filter_id, layer_id) = if event_type == NetworkEventType::ClassifyDrop {
243        unsafe {
244            if !event.Anonymous.classifyDrop.is_null() {
245                let drop_info = &*event.Anonymous.classifyDrop;
246                (Some(drop_info.filterId), Some(drop_info.layerId))
247            } else {
248                (None, None)
249            }
250        }
251    } else {
252        (None, None)
253    };
254
255    Some(NetworkEvent {
256        timestamp,
257        event_type,
258        app_path,
259        protocol: header.ipProtocol,
260        local_addr,
261        remote_addr,
262        local_port: header.localPort,
263        remote_port: header.remotePort,
264        filter_id,
265        layer_id,
266    })
267}
268
269/// Convert FILETIME to SystemTime
270fn filetime_to_systemtime(ft: &FILETIME) -> SystemTime {
271    // FILETIME is 100-nanosecond intervals since January 1, 1601 (UTC)
272    // SystemTime is based on UNIX_EPOCH (January 1, 1970)
273
274    // Difference between Windows epoch (1601) and UNIX epoch (1970) in 100-ns intervals
275    const WINDOWS_TO_UNIX_EPOCH: u64 = 116444736000000000;
276
277    let intervals = ((ft.dwHighDateTime as u64) << 32) | (ft.dwLowDateTime as u64);
278
279    if intervals >= WINDOWS_TO_UNIX_EPOCH {
280        let unix_intervals = intervals - WINDOWS_TO_UNIX_EPOCH;
281        let secs = unix_intervals / 10_000_000;
282        let nanos = ((unix_intervals % 10_000_000) * 100) as u32;
283
284        SystemTime::UNIX_EPOCH + Duration::new(secs, nanos)
285    } else {
286        SystemTime::UNIX_EPOCH
287    }
288}
289
290/// Parse wide string (null-terminated UTF-16)
291unsafe fn parse_wide_string(ptr: *const u16) -> Option<OsString> {
292    if ptr.is_null() {
293        return None;
294    }
295
296    // Find the null terminator
297    let mut len = 0;
298    while *ptr.add(len) != 0 {
299        len += 1;
300    }
301
302    if len == 0 {
303        return None;
304    }
305
306    // Convert to OsString
307    let slice = std::slice::from_raw_parts(ptr, len);
308    Some(OsString::from_wide(slice))
309}
310
311/// Parse IPv4 address from union (reads first 4 bytes as u32) - HEADER1_0 version
312unsafe fn parse_ipv4_union(
313    addr_union: &windows::Win32::NetworkManagement::WindowsFilteringPlatform::FWPM_NET_EVENT_HEADER1_0,
314) -> Option<IpAddr> {
315    // Union contains localAddrV4 as u32
316    let addr_u32 = addr_union.localAddrV4;
317    let bytes = addr_u32.to_ne_bytes();
318    Some(IpAddr::V4(Ipv4Addr::from(bytes)))
319}
320
321/// Parse IPv6 address from union (reads 16-byte array) - HEADER1_0 version
322unsafe fn parse_ipv6_union(
323    addr_union: &windows::Win32::NetworkManagement::WindowsFilteringPlatform::FWPM_NET_EVENT_HEADER1_0,
324) -> Option<IpAddr> {
325    // Union contains localAddrV6 as byte[16]
326    let bytes = addr_union.localAddrV6.byteArray16;
327    Some(IpAddr::V6(Ipv6Addr::from(bytes)))
328}
329
330/// Parse IPv4 address from remote union (HEADER1_1 version)
331unsafe fn parse_ipv4_union_remote(
332    addr_union: &windows::Win32::NetworkManagement::WindowsFilteringPlatform::FWPM_NET_EVENT_HEADER1_1,
333) -> Option<IpAddr> {
334    let addr_u32 = addr_union.remoteAddrV4;
335    let bytes = addr_u32.to_ne_bytes();
336    Some(IpAddr::V4(Ipv4Addr::from(bytes)))
337}
338
339/// Parse IPv6 address from remote union (HEADER1_1 version)
340unsafe fn parse_ipv6_union_remote(
341    addr_union: &windows::Win32::NetworkManagement::WindowsFilteringPlatform::FWPM_NET_EVENT_HEADER1_1,
342) -> Option<IpAddr> {
343    let bytes = addr_union.remoteAddrV6.byteArray16;
344    Some(IpAddr::V6(Ipv6Addr::from(bytes)))
345}
346
#[cfg(test)]
mod tests {
    use super::*;

    /// Split a raw count of 100-ns intervals since 1601-01-01 into a `FILETIME`.
    fn filetime_from(intervals: u64) -> FILETIME {
        FILETIME {
            dwLowDateTime: (intervals & 0xFFFF_FFFF) as u32,
            dwHighDateTime: (intervals >> 32) as u32,
        }
    }

    #[test]
    fn test_event_type_conversion() {
        let cases = [
            (3u32, NetworkEventType::ClassifyDrop),
            (6u32, NetworkEventType::ClassifyAllow),
            (7u32, NetworkEventType::CapabilityDrop),
            (99u32, NetworkEventType::Other(99)),
        ];
        for (raw, expected) in cases {
            assert_eq!(NetworkEventType::from(raw), expected);
        }
    }

    #[test]
    fn test_event_type_boundaries() {
        // Neighbours of the known codes must all fall through to `Other`.
        for raw in [0u32, 2, 4, 5, 8] {
            assert_eq!(NetworkEventType::from(raw), NetworkEventType::Other(raw));
        }
    }

    #[test]
    fn test_filetime_to_systemtime_unix_epoch() {
        // 116444736000000000 intervals after 1601-01-01 is 1970-01-01T00:00:00Z.
        let epoch = filetime_from(116_444_736_000_000_000);
        assert_eq!(filetime_to_systemtime(&epoch), SystemTime::UNIX_EPOCH);
    }

    #[test]
    fn test_filetime_to_systemtime_before_unix_epoch() {
        assert_eq!(filetime_to_systemtime(&filetime_from(0)), SystemTime::UNIX_EPOCH);
    }

    #[test]
    fn test_filetime_to_systemtime_known_date() {
        // 2000-01-01T00:00:00Z.
        let since_epoch = filetime_to_systemtime(&filetime_from(125_911_584_000_000_000))
            .duration_since(SystemTime::UNIX_EPOCH)
            .unwrap();
        assert_eq!(since_epoch.as_secs(), 946_684_800);
    }

    #[test]
    fn test_network_event_struct_creation() {
        let blocked = NetworkEvent {
            timestamp: SystemTime::UNIX_EPOCH,
            event_type: NetworkEventType::ClassifyDrop,
            app_path: Some(PathBuf::from(r"C:\test.exe")),
            protocol: 6,
            local_addr: Some(IpAddr::V4(Ipv4Addr::LOCALHOST)),
            remote_addr: Some(IpAddr::V4(Ipv4Addr::new(8, 8, 8, 8))),
            local_port: 12345,
            remote_port: 443,
            filter_id: Some(42),
            layer_id: Some(1),
        };

        assert_eq!(blocked.event_type, NetworkEventType::ClassifyDrop);
        assert_eq!(blocked.protocol, 6);
        assert_eq!(blocked.local_port, 12345);
        assert_eq!(blocked.remote_port, 443);
        assert!(blocked.app_path.is_some());
    }

    #[test]
    fn test_network_event_clone() {
        let allowed = NetworkEvent {
            timestamp: SystemTime::UNIX_EPOCH,
            event_type: NetworkEventType::ClassifyAllow,
            app_path: None,
            protocol: 17,
            local_addr: None,
            remote_addr: None,
            local_port: 0,
            remote_port: 0,
            filter_id: None,
            layer_id: None,
        };

        let duplicate = allowed.clone();
        assert_eq!(duplicate.event_type, NetworkEventType::ClassifyAllow);
        assert_eq!(duplicate.protocol, 17);
        assert!(duplicate.app_path.is_none());
    }
}