wireguard_netstack/
config.rs

//! WireGuard configuration file parser.
//!
//! Parses standard WireGuard configuration files (`.conf`, INI format) via
//! `serde_ini`. One `[Interface]` section and one `[Peer]` section are supported.
4
5use base64::{engine::general_purpose::STANDARD, Engine};
6use serde::Deserialize;
7use std::net::Ipv4Addr;
8use std::path::Path;
9
10use crate::dns::{DohResolver, DohServerConfig};
11use crate::error::{Error, Result};
12use crate::wireguard::WireGuardConfig;
13
14/// Raw WireGuard configuration as parsed from the INI file.
15#[derive(Debug, Deserialize)]
16#[serde(rename_all = "PascalCase")]
17struct RawWgConfig {
18    interface: InterfaceSection,
19    peer: PeerSection,
20}
21
22#[derive(Debug, Deserialize)]
23#[serde(rename_all = "PascalCase")]
24struct InterfaceSection {
25    private_key: String,
26    address: String,
27    #[serde(rename = "DNS")]
28    dns: Option<String>,
29    #[serde(rename = "MTU")]
30    mtu: Option<u16>,
31}
32
33#[derive(Debug, Deserialize)]
34#[serde(rename_all = "PascalCase")]
35struct PeerSection {
36    public_key: String,
37    endpoint: String,
38    #[serde(default)]
39    preshared_key: Option<String>,
40    #[serde(default)]
41    persistent_keepalive: Option<u16>,
42    #[allow(dead_code)]
43    allowed_ips: Option<String>,
44}
45
/// Parsed WireGuard configuration file.
///
/// `Debug` is implemented by hand so that the private key and the preshared
/// key are never printed. Formatting a config with `{:?}` (for example in
/// logs) shows `<redacted>` in their place.
#[derive(Clone)]
pub struct WgConfigFile {
    /// Interface private key, decoded from base64.
    pub private_key: [u8; 32],
    /// Interface address (tunnel IP).
    pub address: Ipv4Addr,
    /// DNS server (optional).
    pub dns: Option<Ipv4Addr>,
    /// MTU for the tunnel interface (optional). When `None`, the default is
    /// chosen downstream; the original docs state it is 460. Confirm in
    /// `WireGuardConfig`.
    pub mtu: Option<u16>,
    /// Peer public key, decoded from base64.
    pub peer_public_key: [u8; 32],
    /// Peer endpoint hostname or IP (IPv6 without brackets).
    pub endpoint_host: String,
    /// Peer endpoint port.
    pub endpoint_port: u16,
    /// Preshared key (optional), decoded from base64.
    pub preshared_key: Option<[u8; 32]>,
    /// Persistent keepalive interval in seconds (optional).
    pub persistent_keepalive: Option<u16>,
}

impl std::fmt::Debug for WgConfigFile {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Secrets are replaced by a marker; presence of the PSK is still visible.
        f.debug_struct("WgConfigFile")
            .field("private_key", &"<redacted>")
            .field("address", &self.address)
            .field("dns", &self.dns)
            .field("mtu", &self.mtu)
            .field("peer_public_key", &self.peer_public_key)
            .field("endpoint_host", &self.endpoint_host)
            .field("endpoint_port", &self.endpoint_port)
            .field("preshared_key", &self.preshared_key.map(|_| "<redacted>"))
            .field("persistent_keepalive", &self.persistent_keepalive)
            .finish()
    }
}
68
69impl WgConfigFile {
70    /// Parse a WireGuard configuration file from the given path.
71    pub fn from_file<P: AsRef<Path>>(path: P) -> Result<Self> {
72        let content = std::fs::read_to_string(path.as_ref()).map_err(|e| {
73            Error::ConfigParse(format!("Failed to read config file {:?}: {}", path.as_ref(), e))
74        })?;
75        Self::parse(&content)
76    }
77
78    /// Parse a WireGuard configuration from a string.
79    pub fn parse(content: &str) -> Result<Self> {
80        let raw: RawWgConfig =
81            serde_ini::from_str(content).map_err(|e| Error::ConfigParse(e.to_string()))?;
82
83        // Decode private key
84        let private_key = decode_key(&raw.interface.private_key)?;
85
86        // Parse address (strip CIDR notation if present)
87        let ip_str = raw
88            .interface
89            .address
90            .split('/')
91            .next()
92            .unwrap_or(&raw.interface.address);
93        let address: Ipv4Addr = ip_str
94            .parse()
95            .map_err(|_| Error::InvalidAddress(raw.interface.address.clone()))?;
96
97        // Parse DNS (take first if comma-separated)
98        let dns = raw
99            .interface
100            .dns
101            .as_ref()
102            .and_then(|d| d.split(',').next())
103            .map(|s| s.trim().parse())
104            .transpose()
105            .map_err(|_| Error::InvalidAddress("Invalid DNS address".into()))?;
106
107        // Decode peer public key
108        let peer_public_key = decode_key(&raw.peer.public_key)?;
109
110        // Parse endpoint
111        let (endpoint_host, endpoint_port) = parse_endpoint(&raw.peer.endpoint)?;
112
113        // Decode preshared key if present
114        let preshared_key = raw
115            .peer
116            .preshared_key
117            .as_ref()
118            .map(|k| decode_key(k))
119            .transpose()?;
120
121        Ok(Self {
122            private_key,
123            address,
124            dns,
125            mtu: raw.interface.mtu,
126            peer_public_key,
127            endpoint_host,
128            endpoint_port,
129            preshared_key,
130            persistent_keepalive: raw.peer.persistent_keepalive,
131        })
132    }
133
134    /// Convert to WireGuardConfig, resolving the endpoint hostname via DoH if needed.
135    /// Uses the default Cloudflare DNS for resolution.
136    pub async fn into_wireguard_config(self) -> Result<WireGuardConfig> {
137        self.into_wireguard_config_with_dns(DohServerConfig::default()).await
138    }
139
140    /// Convert to WireGuardConfig, resolving the endpoint hostname via DoH with custom DNS.
141    ///
142    /// # Arguments
143    ///
144    /// * `dns_config` - The DNS server configuration to use for resolving the endpoint hostname.
145    ///
146    /// # Example
147    ///
148    /// ```no_run
149    /// use wireguard_netstack::{WgConfigFile, DohServerConfig};
150    ///
151    /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
152    /// let config = WgConfigFile::from_file("wg.conf")?
153    ///     .into_wireguard_config_with_dns(DohServerConfig::google())
154    ///     .await?;
155    /// # Ok(())
156    /// # }
157    /// ```
158    pub async fn into_wireguard_config_with_dns(self, dns_config: DohServerConfig) -> Result<WireGuardConfig> {
159        // First try to parse as IP:port, otherwise resolve via DoH
160        let peer_endpoint =
161            match format!("{}:{}", self.endpoint_host, self.endpoint_port).parse() {
162                Ok(addr) => addr,
163                Err(_) => {
164                    // Resolve using DNS-over-HTTPS (direct mode, before tunnel is up)
165                    log::info!(
166                        "Resolving WireGuard endpoint '{}' via DoH ({})...",
167                        self.endpoint_host,
168                        dns_config.hostname
169                    );
170                    let doh_resolver = DohResolver::new_direct_with_config(dns_config);
171                    doh_resolver
172                        .resolve_addr(&self.endpoint_host, self.endpoint_port)
173                        .await?
174                }
175            };
176
177        log::info!("WireGuard endpoint resolved to: {}", peer_endpoint);
178
179        Ok(WireGuardConfig {
180            private_key: self.private_key,
181            peer_public_key: self.peer_public_key,
182            peer_endpoint,
183            tunnel_ip: self.address,
184            preshared_key: self.preshared_key,
185            keepalive_seconds: self.persistent_keepalive.or(Some(25)), // Default to 25s if not specified
186            mtu: self.mtu,
187        })
188    }
189}
190
191/// Decode a base64-encoded 32-byte key.
192fn decode_key(b64: &str) -> Result<[u8; 32]> {
193    let bytes = STANDARD
194        .decode(b64)
195        .map_err(|_| Error::InvalidKey(b64.to_string()))?;
196    bytes
197        .try_into()
198        .map_err(|v: Vec<u8>| Error::InvalidKey(format!("Key must be 32 bytes, got {} bytes", v.len())))
199}
200
201/// Parse an endpoint string (host:port).
202fn parse_endpoint(endpoint: &str) -> Result<(String, u16)> {
203    // Handle IPv6 addresses in brackets: [::1]:51820
204    if endpoint.starts_with('[') {
205        if let Some(bracket_end) = endpoint.find(']') {
206            let host = endpoint[1..bracket_end].to_string();
207            let port_str = endpoint[bracket_end + 1..].trim_start_matches(':');
208            let port: u16 = port_str
209                .parse()
210                .map_err(|_| Error::InvalidEndpoint(endpoint.to_string()))?;
211            return Ok((host, port));
212        }
213    }
214
215    // Handle hostname:port or IPv4:port
216    if let Some((host, port_str)) = endpoint.rsplit_once(':') {
217        let port: u16 = port_str
218            .parse()
219            .map_err(|_| Error::InvalidEndpoint(endpoint.to_string()))?;
220        Ok((host.to_string(), port))
221    } else {
222        Err(Error::InvalidEndpoint(format!(
223            "Invalid endpoint format (expected host:port): {}",
224            endpoint
225        )))
226    }
227}
228
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a config with synthetic keys: private = 32 x 0x01, public = 32 x 0x02.
    ///
    /// `extra_interface` / `extra_peer` are appended verbatim to the respective
    /// sections (each line must end with `\n`).
    fn sample_config(extra_interface: &str, extra_peer: &str, endpoint: &str) -> String {
        format!(
            "[Interface]\nPrivateKey = {}\nAddress = 192.168.3.4/32\nDNS = 192.168.3.1\n{}\n[Peer]\nPublicKey = {}\nAllowedIPs = 0.0.0.0/0\nEndpoint = {}\n{}",
            STANDARD.encode([1u8; 32]),
            extra_interface,
            STANDARD.encode([2u8; 32]),
            endpoint,
            extra_peer,
        )
    }

    #[test]
    fn test_parse_config() {
        let config = WgConfigFile::parse(&sample_config("", "", "vpn.example.com:51820")).unwrap();
        assert_eq!(config.private_key, [1u8; 32]);
        assert_eq!(config.peer_public_key, [2u8; 32]);
        assert_eq!(config.address, Ipv4Addr::new(192, 168, 3, 4));
        assert_eq!(config.dns, Some(Ipv4Addr::new(192, 168, 3, 1)));
        assert_eq!(config.endpoint_host, "vpn.example.com");
        assert_eq!(config.endpoint_port, 51820);
        assert_eq!(config.mtu, None);
        assert_eq!(config.preshared_key, None);
        assert_eq!(config.persistent_keepalive, None);
    }

    #[test]
    fn test_parse_optional_fields() {
        let peer_extra = format!(
            "PresharedKey = {}\nPersistentKeepalive = 15\n",
            STANDARD.encode([3u8; 32])
        );
        let config =
            WgConfigFile::parse(&sample_config("MTU = 1420\n", &peer_extra, "1.2.3.4:51820"))
                .unwrap();
        assert_eq!(config.mtu, Some(1420));
        assert_eq!(config.preshared_key, Some([3u8; 32]));
        assert_eq!(config.persistent_keepalive, Some(15));
        assert_eq!(config.endpoint_host, "1.2.3.4");
    }

    #[test]
    fn test_parse_rejects_short_public_key() {
        let bad = sample_config("", "", "vpn.example.com:51820")
            .replace(&STANDARD.encode([2u8; 32]), &STANDARD.encode([2u8; 16]));
        assert!(WgConfigFile::parse(&bad).is_err());
    }

    #[test]
    fn test_decode_key() {
        assert_eq!(decode_key(&STANDARD.encode([5u8; 32])).unwrap(), [5u8; 32]);
        assert!(decode_key(&STANDARD.encode([5u8; 16])).is_err());
        assert!(decode_key("not base64!").is_err());
    }

    #[test]
    fn test_parse_endpoint() {
        // IPv4
        let (host, port) = parse_endpoint("1.2.3.4:51820").unwrap();
        assert_eq!(host, "1.2.3.4");
        assert_eq!(port, 51820);

        // Hostname
        let (host, port) = parse_endpoint("example.com:51820").unwrap();
        assert_eq!(host, "example.com");
        assert_eq!(port, 51820);

        // IPv6
        let (host, port) = parse_endpoint("[::1]:51820").unwrap();
        assert_eq!(host, "::1");
        assert_eq!(port, 51820);

        // Malformed
        assert!(parse_endpoint("example.com").is_err());
        assert!(parse_endpoint("example.com:notaport").is_err());
        assert!(parse_endpoint("example.com:70000").is_err());
        assert!(parse_endpoint("[::1]").is_err());
    }
}