Skip to main content

windows_wfp/
condition.rs

1//! WFP filter condition types
2//!
3//! Types used as filter conditions in WFP rules: IP addresses with masks,
4//! and IP protocol numbers.
5
6use std::fmt;
7use std::net::IpAddr;
8
/// IP address with CIDR prefix length
///
/// Used for local and remote IP address conditions in WFP filters.
///
/// # Examples
///
/// ```
/// use windows_wfp::IpAddrMask;
/// use std::net::IpAddr;
///
/// // Match a single host
/// let host = IpAddrMask::new("192.168.1.1".parse().unwrap(), 32);
/// assert!(!host.is_ipv6());
///
/// // Match a /24 subnet
/// let subnet = IpAddrMask::from_cidr("192.168.1.0/24").unwrap();
/// assert_eq!(subnet.prefix_len, 24);
///
/// // Match an IPv6 address
/// let ipv6 = IpAddrMask::new("::1".parse().unwrap(), 128);
/// assert!(ipv6.is_ipv6());
/// ```
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct IpAddrMask {
    /// IP address (IPv4 or IPv6)
    pub addr: IpAddr,
    /// CIDR prefix length (0-32 for IPv4, 0-128 for IPv6).
    ///
    /// This range is enforced by [`IpAddrMask::from_cidr`] only; values set via
    /// [`IpAddrMask::new`] or by direct field assignment are not checked.
    pub prefix_len: u8,
}

impl IpAddrMask {
    /// Create a new IP address with mask.
    ///
    /// No validation is performed: `prefix_len` is stored as given even if it
    /// exceeds the maximum for the address family (32 for IPv4, 128 for IPv6).
    /// Use [`IpAddrMask::from_cidr`] when the input must be checked.
    #[must_use]
    pub fn new(addr: IpAddr, prefix_len: u8) -> Self {
        Self { addr, prefix_len }
    }

    /// Parse from CIDR notation (e.g., `"192.168.1.0/24"` or `"::1/128"`).
    ///
    /// The input must consist of an IPv4 or IPv6 address, exactly one `/`, and
    /// an unsigned decimal prefix length that does not exceed 32 (IPv4) or
    /// 128 (IPv6).
    ///
    /// # Errors
    ///
    /// Returns a descriptive error string if:
    /// - the input does not contain exactly one `/`;
    /// - the address part is not a valid IPv4/IPv6 address;
    /// - the prefix part is not an unsigned decimal number fitting in a `u8`
    ///   (a leading `+` is rejected);
    /// - the prefix length exceeds the maximum for the address family.
    pub fn from_cidr(s: &str) -> Result<Self, String> {
        // Require exactly two `/`-separated parts without allocating a Vec.
        let mut parts = s.split('/');
        let (addr_part, prefix_part) = match (parts.next(), parts.next(), parts.next()) {
            (Some(addr_part), Some(prefix_part), None) => (addr_part, prefix_part),
            _ => return Err(format!("Invalid CIDR notation: {}", s)),
        };

        let addr: IpAddr = addr_part
            .parse()
            .map_err(|e| format!("Invalid IP address: {}", e))?;

        // `u8::from_str` accepts an optional leading `+`, which is not valid
        // CIDR syntax, so reject it explicitly before parsing.
        if prefix_part.starts_with('+') {
            return Err(format!("Invalid prefix length: {}", prefix_part));
        }
        let prefix_len: u8 = prefix_part
            .parse()
            .map_err(|e| format!("Invalid prefix length: {}", e))?;

        let max_prefix = match addr {
            IpAddr::V4(_) => 32,
            IpAddr::V6(_) => 128,
        };

        if prefix_len > max_prefix {
            return Err(format!(
                "Prefix length {} exceeds maximum {} for {:?}",
                prefix_len, max_prefix, addr
            ));
        }

        Ok(Self { addr, prefix_len })
    }

    /// Returns true if this is an IPv6 address
    #[must_use]
    pub fn is_ipv6(&self) -> bool {
        self.addr.is_ipv6()
    }
}
81
/// IP protocol numbers (IANA assigned)
///
/// Standard protocol numbers used in WFP filter conditions.
/// Values match the IANA protocol number assignments.
///
/// # Examples
///
/// ```
/// use windows_wfp::Protocol;
///
/// let tcp = Protocol::Tcp;
/// assert_eq!(tcp as u8, 6);
///
/// let udp = Protocol::Udp;
/// assert_eq!(udp as u8, 17);
///
/// // `Display` honours width/alignment flags.
/// assert_eq!(format!("{:<4}|", Protocol::Tcp), "TCP |");
/// ```
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[repr(u8)]
pub enum Protocol {
    /// IPv6 Hop-by-Hop Option (protocol 0)
    Hopopt = 0,
    /// Internet Control Message Protocol v4 (protocol 1)
    Icmp = 1,
    /// Internet Group Management Protocol (protocol 2)
    Igmp = 2,
    /// Transmission Control Protocol (protocol 6)
    Tcp = 6,
    /// User Datagram Protocol (protocol 17)
    Udp = 17,
    /// Generic Routing Encapsulation (protocol 47)
    Gre = 47,
    /// Encapsulating Security Payload / IPsec (protocol 50)
    Esp = 50,
    /// Authentication Header / IPsec (protocol 51)
    Ah = 51,
    /// Internet Control Message Protocol v6 (protocol 58)
    Icmpv6 = 58,
}

impl Protocol {
    /// Get the IANA protocol number.
    ///
    /// This is the enum's `#[repr(u8)]` discriminant, so it is usable in
    /// `const` contexts.
    #[must_use]
    pub const fn as_u8(self) -> u8 {
        self as u8
    }
}

impl fmt::Display for Protocol {
    /// Writes the protocol's conventional short name (e.g. `"TCP"`, `"ICMPv6"`).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let name = match self {
            Protocol::Hopopt => "HOPOPT",
            Protocol::Icmp => "ICMP",
            Protocol::Igmp => "IGMP",
            Protocol::Tcp => "TCP",
            Protocol::Udp => "UDP",
            Protocol::Gre => "GRE",
            Protocol::Esp => "ESP",
            Protocol::Ah => "AH",
            Protocol::Icmpv6 => "ICMPv6",
        };
        // `pad` (unlike `write!`) applies width, fill, alignment and precision
        // flags, so e.g. `format!("{:>6}", proto)` aligns as callers expect.
        f.pad(name)
    }
}
143
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses an address literal used as fixed test input.
    fn ip(text: &str) -> IpAddr {
        text.parse().unwrap()
    }

    #[test]
    fn test_ip_addr_mask_new() {
        let built = IpAddrMask::new(ip("10.0.0.1"), 24);
        assert!(!built.is_ipv6());
        assert_eq!(built.prefix_len, 24);
    }

    #[test]
    fn test_ip_addr_mask_from_cidr_v4() {
        let parsed = IpAddrMask::from_cidr("192.168.1.0/24").unwrap();
        // Derived `PartialEq` compares both `addr` and `prefix_len`.
        assert_eq!(parsed, IpAddrMask::new(ip("192.168.1.0"), 24));
    }

    #[test]
    fn test_ip_addr_mask_from_cidr_v6() {
        let parsed = IpAddrMask::from_cidr("fe80::1/64").unwrap();
        assert_eq!(parsed.prefix_len, 64);
        assert!(parsed.is_ipv6());
    }

    #[test]
    fn test_ip_addr_mask_invalid_cidr() {
        let rejected = ["not-an-ip/24", "192.168.1.0", "192.168.1.0/33", "::1/129"];
        for &input in &rejected {
            assert!(
                IpAddrMask::from_cidr(input).is_err(),
                "expected {:?} to be rejected",
                input
            );
        }
    }

    #[test]
    fn test_protocol_values() {
        let table = [
            (Protocol::Hopopt, 0u8),
            (Protocol::Icmp, 1),
            (Protocol::Igmp, 2),
            (Protocol::Tcp, 6),
            (Protocol::Udp, 17),
            (Protocol::Gre, 47),
            (Protocol::Esp, 50),
            (Protocol::Ah, 51),
            (Protocol::Icmpv6, 58),
        ];
        for &(proto, number) in &table {
            assert_eq!(proto.as_u8(), number, "wrong IANA number for {:?}", proto);
        }
    }

    #[test]
    fn test_ip_addr_mask_boundary_prefixes_v4() {
        assert_eq!(IpAddrMask::from_cidr("0.0.0.0/0").unwrap().prefix_len, 0);
        assert_eq!(IpAddrMask::from_cidr("10.0.0.1/32").unwrap().prefix_len, 32);
    }

    #[test]
    fn test_ip_addr_mask_boundary_prefixes_v6() {
        let any = IpAddrMask::from_cidr("::/0").unwrap();
        assert!(any.is_ipv6());
        assert_eq!(any.prefix_len, 0);

        assert_eq!(IpAddrMask::from_cidr("::1/128").unwrap().prefix_len, 128);
    }

    #[test]
    fn test_ip_addr_mask_equality() {
        let base = IpAddrMask::new(ip("10.0.0.1"), 24);
        assert_eq!(base, IpAddrMask::new(ip("10.0.0.1"), 24));
        assert_ne!(base, IpAddrMask::new(ip("10.0.0.1"), 16));
    }

    #[test]
    fn test_ip_addr_mask_multiple_slashes() {
        let outcome = IpAddrMask::from_cidr("10.0.0.1/24/8");
        assert!(outcome.is_err());
    }

    #[test]
    fn test_ip_addr_mask_empty_string() {
        let outcome = IpAddrMask::from_cidr("");
        assert!(outcome.is_err());
    }

    #[test]
    fn test_ip_addr_mask_v4_is_not_ipv6() {
        assert!(!IpAddrMask::new(ip("192.168.0.1"), 24).is_ipv6());
    }

    #[test]
    fn test_ip_addr_mask_v6_is_ipv6() {
        assert!(IpAddrMask::new(ip("::1"), 128).is_ipv6());
    }

    #[test]
    fn test_protocol_copy() {
        let original = Protocol::Tcp;
        // `Protocol: Copy`, so `original` stays usable after this binding.
        let duplicate = original;
        assert_eq!(original, duplicate);
    }

    #[test]
    fn test_protocol_display() {
        let names = [
            (Protocol::Tcp, "TCP"),
            (Protocol::Udp, "UDP"),
            (Protocol::Icmp, "ICMP"),
            (Protocol::Icmpv6, "ICMPv6"),
            (Protocol::Gre, "GRE"),
            (Protocol::Esp, "ESP"),
            (Protocol::Ah, "AH"),
            (Protocol::Igmp, "IGMP"),
            (Protocol::Hopopt, "HOPOPT"),
        ];
        for &(proto, expected) in &names {
            assert_eq!(proto.to_string(), expected);
        }
    }
}