wireguard_netstack/
config.rs

//! WireGuard configuration file parser.
//!
//! Parses standard WireGuard configuration files (`.conf`) using serde.
4
5use base64::{engine::general_purpose::STANDARD, Engine};
6use serde::Deserialize;
7use std::net::Ipv4Addr;
8use std::path::Path;
9
10use crate::dns::{DohResolver, DohServerConfig};
11use crate::error::{Error, Result};
12use crate::wireguard::WireGuardConfig;
13
14/// Raw WireGuard configuration as parsed from the INI file.
15#[derive(Debug, Deserialize)]
16#[serde(rename_all = "PascalCase")]
17struct RawWgConfig {
18    interface: InterfaceSection,
19    peer: PeerSection,
20}
21
22#[derive(Debug, Deserialize)]
23#[serde(rename_all = "PascalCase")]
24struct InterfaceSection {
25    private_key: String,
26    address: String,
27    #[serde(rename = "DNS")]
28    dns: Option<String>,
29}
30
31#[derive(Debug, Deserialize)]
32#[serde(rename_all = "PascalCase")]
33struct PeerSection {
34    public_key: String,
35    endpoint: String,
36    #[serde(default)]
37    preshared_key: Option<String>,
38    #[serde(default)]
39    persistent_keepalive: Option<u16>,
40    #[allow(dead_code)]
41    allowed_ips: Option<String>,
42}
43
/// Parsed WireGuard configuration file.
///
/// `Debug` is implemented by hand so that `private_key` and `preshared_key`
/// are printed as `<redacted>`; the struct can be logged without leaking secrets.
#[derive(Clone)]
pub struct WgConfigFile {
    /// Interface private key, decoded from the base64 `PrivateKey` entry.
    pub private_key: [u8; 32],
    /// Interface address (tunnel IP), without any CIDR suffix.
    pub address: Ipv4Addr,
    /// DNS server taken from the `DNS` entry, if any.
    pub dns: Option<Ipv4Addr>,
    /// Peer public key, decoded from the base64 `PublicKey` entry.
    pub peer_public_key: [u8; 32],
    /// Peer endpoint hostname or IP (IPv6 without brackets).
    pub endpoint_host: String,
    /// Peer endpoint port.
    pub endpoint_port: u16,
    /// Preshared key, decoded from the base64 `PresharedKey` entry (optional).
    pub preshared_key: Option<[u8; 32]>,
    /// Persistent keepalive interval in seconds (optional).
    pub persistent_keepalive: Option<u16>,
}

impl std::fmt::Debug for WgConfigFile {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("WgConfigFile")
            .field("private_key", &"<redacted>")
            .field("address", &self.address)
            .field("dns", &self.dns)
            .field("peer_public_key", &self.peer_public_key)
            .field("endpoint_host", &self.endpoint_host)
            .field("endpoint_port", &self.endpoint_port)
            // Show whether a PSK is present without revealing it.
            .field("preshared_key", &self.preshared_key.map(|_| "<redacted>"))
            .field("persistent_keepalive", &self.persistent_keepalive)
            .finish()
    }
}
64
65impl WgConfigFile {
66    /// Parse a WireGuard configuration file from the given path.
67    pub fn from_file<P: AsRef<Path>>(path: P) -> Result<Self> {
68        let content = std::fs::read_to_string(path.as_ref()).map_err(|e| {
69            Error::ConfigParse(format!("Failed to read config file {:?}: {}", path.as_ref(), e))
70        })?;
71        Self::parse(&content)
72    }
73
74    /// Parse a WireGuard configuration from a string.
75    pub fn parse(content: &str) -> Result<Self> {
76        let raw: RawWgConfig =
77            serde_ini::from_str(content).map_err(|e| Error::ConfigParse(e.to_string()))?;
78
79        // Decode private key
80        let private_key = decode_key(&raw.interface.private_key)?;
81
82        // Parse address (strip CIDR notation if present)
83        let ip_str = raw
84            .interface
85            .address
86            .split('/')
87            .next()
88            .unwrap_or(&raw.interface.address);
89        let address: Ipv4Addr = ip_str
90            .parse()
91            .map_err(|_| Error::InvalidAddress(raw.interface.address.clone()))?;
92
93        // Parse DNS (take first if comma-separated)
94        let dns = raw
95            .interface
96            .dns
97            .as_ref()
98            .and_then(|d| d.split(',').next())
99            .map(|s| s.trim().parse())
100            .transpose()
101            .map_err(|_| Error::InvalidAddress("Invalid DNS address".into()))?;
102
103        // Decode peer public key
104        let peer_public_key = decode_key(&raw.peer.public_key)?;
105
106        // Parse endpoint
107        let (endpoint_host, endpoint_port) = parse_endpoint(&raw.peer.endpoint)?;
108
109        // Decode preshared key if present
110        let preshared_key = raw
111            .peer
112            .preshared_key
113            .as_ref()
114            .map(|k| decode_key(k))
115            .transpose()?;
116
117        Ok(Self {
118            private_key,
119            address,
120            dns,
121            peer_public_key,
122            endpoint_host,
123            endpoint_port,
124            preshared_key,
125            persistent_keepalive: raw.peer.persistent_keepalive,
126        })
127    }
128
129    /// Convert to WireGuardConfig, resolving the endpoint hostname via DoH if needed.
130    /// Uses the default Cloudflare DNS for resolution.
131    pub async fn into_wireguard_config(self) -> Result<WireGuardConfig> {
132        self.into_wireguard_config_with_dns(DohServerConfig::default()).await
133    }
134
135    /// Convert to WireGuardConfig, resolving the endpoint hostname via DoH with custom DNS.
136    ///
137    /// # Arguments
138    ///
139    /// * `dns_config` - The DNS server configuration to use for resolving the endpoint hostname.
140    ///
141    /// # Example
142    ///
143    /// ```no_run
144    /// use wireguard_netstack::{WgConfigFile, DohServerConfig};
145    ///
146    /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
147    /// let config = WgConfigFile::from_file("wg.conf")?
148    ///     .into_wireguard_config_with_dns(DohServerConfig::google())
149    ///     .await?;
150    /// # Ok(())
151    /// # }
152    /// ```
153    pub async fn into_wireguard_config_with_dns(self, dns_config: DohServerConfig) -> Result<WireGuardConfig> {
154        // First try to parse as IP:port, otherwise resolve via DoH
155        let peer_endpoint =
156            match format!("{}:{}", self.endpoint_host, self.endpoint_port).parse() {
157                Ok(addr) => addr,
158                Err(_) => {
159                    // Resolve using DNS-over-HTTPS (direct mode, before tunnel is up)
160                    log::info!(
161                        "Resolving WireGuard endpoint '{}' via DoH ({})...",
162                        self.endpoint_host,
163                        dns_config.hostname
164                    );
165                    let doh_resolver = DohResolver::new_direct_with_config(dns_config);
166                    doh_resolver
167                        .resolve_addr(&self.endpoint_host, self.endpoint_port)
168                        .await?
169                }
170            };
171
172        log::info!("WireGuard endpoint resolved to: {}", peer_endpoint);
173
174        Ok(WireGuardConfig {
175            private_key: self.private_key,
176            peer_public_key: self.peer_public_key,
177            peer_endpoint,
178            tunnel_ip: self.address,
179            preshared_key: self.preshared_key,
180            keepalive_seconds: self.persistent_keepalive.or(Some(25)), // Default to 25s if not specified
181        })
182    }
183}
184
185/// Decode a base64-encoded 32-byte key.
186fn decode_key(b64: &str) -> Result<[u8; 32]> {
187    let bytes = STANDARD
188        .decode(b64)
189        .map_err(|_| Error::InvalidKey(b64.to_string()))?;
190    bytes
191        .try_into()
192        .map_err(|v: Vec<u8>| Error::InvalidKey(format!("Key must be 32 bytes, got {} bytes", v.len())))
193}
194
195/// Parse an endpoint string (host:port).
196fn parse_endpoint(endpoint: &str) -> Result<(String, u16)> {
197    // Handle IPv6 addresses in brackets: [::1]:51820
198    if endpoint.starts_with('[') {
199        if let Some(bracket_end) = endpoint.find(']') {
200            let host = endpoint[1..bracket_end].to_string();
201            let port_str = endpoint[bracket_end + 1..].trim_start_matches(':');
202            let port: u16 = port_str
203                .parse()
204                .map_err(|_| Error::InvalidEndpoint(endpoint.to_string()))?;
205            return Ok((host, port));
206        }
207    }
208
209    // Handle hostname:port or IPv4:port
210    if let Some((host, port_str)) = endpoint.rsplit_once(':') {
211        let port: u16 = port_str
212            .parse()
213            .map_err(|_| Error::InvalidEndpoint(endpoint.to_string()))?;
214        Ok((host.to_string(), port))
215    } else {
216        Err(Error::InvalidEndpoint(format!(
217            "Invalid endpoint format (expected host:port): {}",
218            endpoint
219        )))
220    }
221}
222
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal config with a CIDR address, a single DNS server and a hostname endpoint.
    const SAMPLE_CONFIG: &str = r#"
[Interface]
PrivateKey = eC3sErLXd5A7z3FTJnrb55uuxlazlDM40HQmWZrb6Vc=
Address = 192.168.3.4/32
DNS = 192.168.3.1

[Peer]
PublicKey = EISEG38ycR6D7nK0m+mnacAM9HfXzdqcO1mO5LNs6jU=
AllowedIPs = 0.0.0.0/0
Endpoint = direct.casarizzotti.com:51820
"#;

    #[test]
    fn test_parse_config() {
        let parsed = WgConfigFile::parse(SAMPLE_CONFIG).unwrap();

        assert_eq!(parsed.address, Ipv4Addr::new(192, 168, 3, 4));
        assert_eq!(parsed.dns, Some(Ipv4Addr::new(192, 168, 3, 1)));
        assert_eq!(parsed.endpoint_host, "direct.casarizzotti.com");
        assert_eq!(parsed.endpoint_port, 51820);
    }

    #[test]
    fn test_parse_endpoint() {
        // (input, expected host); every case uses port 51820.
        let cases = [
            ("1.2.3.4:51820", "1.2.3.4"),         // IPv4
            ("example.com:51820", "example.com"), // hostname
            ("[::1]:51820", "::1"),               // bracketed IPv6
        ];

        for (input, expected_host) in cases {
            let (host, port) = parse_endpoint(input).unwrap();
            assert_eq!(host, expected_host, "host for {}", input);
            assert_eq!(port, 51820, "port for {}", input);
        }
    }
}