wireguard_control/backends/userspace.rs

1use crate::{Backend, Device, DeviceUpdate, InterfaceName, Key, PeerInfo};
2
3use std::{
4    fmt::Write as _,
5    fs,
6    io::{self, prelude::*, BufReader},
7    os::unix::net::UnixStream,
8    path::{Path, PathBuf},
9    process::{Command, Output},
10    time::{Duration, SystemTime},
11};
12
/// Preferred location of the userspace WireGuard control sockets and name files.
static VAR_RUN_PATH: &str = "/var/run/wireguard";
/// Fallback location, checked only when [`VAR_RUN_PATH`] does not exist.
static RUN_PATH: &str = "/run/wireguard";

/// Return the first existing socket directory, trying [`VAR_RUN_PATH`] before [`RUN_PATH`].
///
/// # Errors
/// `NotFound` if neither directory exists.
fn get_base_folder() -> io::Result<PathBuf> {
    [VAR_RUN_PATH, RUN_PATH]
        .iter()
        .map(|dir| Path::new(dir))
        .find(|candidate| candidate.exists())
        .map(Path::to_path_buf)
        .ok_or_else(|| {
            io::Error::new(
                io::ErrorKind::NotFound,
                "WireGuard socket directory not found.",
            )
        })
}
28
29fn get_namefile(name: &InterfaceName) -> io::Result<PathBuf> {
30    Ok(get_base_folder()?.join(format!("{}.name", name.as_str_lossy())))
31}
32
33fn get_socketfile(name: &InterfaceName) -> io::Result<PathBuf> {
34    if cfg!(target_os = "linux") {
35        Ok(get_base_folder()?.join(format!("{name}.sock")))
36    } else {
37        Ok(get_base_folder()?.join(format!("{}.sock", resolve_tun(name)?)))
38    }
39}
40
41fn open_socket(name: &InterfaceName) -> io::Result<UnixStream> {
42    UnixStream::connect(get_socketfile(name)?)
43}
44
45pub fn resolve_tun(name: &InterfaceName) -> io::Result<String> {
46    let namefile = get_namefile(name)?;
47    Ok(fs::read_to_string(namefile)
48        .map_err(|_| io::Error::new(io::ErrorKind::NotFound, "WireGuard name file can't be read"))?
49        .trim()
50        .to_string())
51}
52
53pub fn delete_interface(name: &InterfaceName) -> io::Result<()> {
54    fs::remove_file(get_socketfile(name)?).ok();
55    fs::remove_file(get_namefile(name)?).ok();
56
57    Ok(())
58}
59
60pub fn enumerate() -> Result<Vec<InterfaceName>, io::Error> {
61    use std::ffi::OsStr;
62
63    let mut interfaces = vec![];
64    for entry in fs::read_dir(get_base_folder()?)? {
65        let path = entry?.path();
66        if path.extension() == Some(OsStr::new("name")) {
67            let stem = path
68                .file_stem()
69                .and_then(|stem| stem.to_str())
70                .and_then(|name| name.parse::<InterfaceName>().ok())
71                .filter(|iface| open_socket(iface).is_ok());
72            if let Some(iface) = stem {
73                interfaces.push(iface);
74            }
75        }
76    }
77
78    Ok(interfaces)
79}
80
81struct ConfigParser {
82    device_info: Device,
83    current_peer: Option<PeerInfo>,
84}
85
86impl From<ConfigParser> for Device {
87    fn from(parser: ConfigParser) -> Self {
88        parser.device_info
89    }
90}
91
92impl ConfigParser {
93    /// Returns `None` if an invalid device name was provided.
94    fn new(name: &InterfaceName) -> Self {
95        let device_info = Device {
96            name: *name,
97            public_key: None,
98            private_key: None,
99            fwmark: None,
100            listen_port: None,
101            peers: vec![],
102            linked_name: resolve_tun(name).ok(),
103            backend: Backend::Userspace,
104        };
105
106        Self {
107            device_info,
108            current_peer: None,
109        }
110    }
111
112    fn add_line(&mut self, line: &str) -> Result<(), std::io::Error> {
113        use io::ErrorKind::InvalidData;
114
115        let split: Vec<&str> = line.splitn(2, '=').collect();
116        match &split[..] {
117            [key, value] => self.add_pair(key, value),
118            _ => Err(InvalidData.into()),
119        }
120    }
121
122    fn add_pair(&mut self, key: &str, value: &str) -> Result<(), std::io::Error> {
123        use io::ErrorKind::InvalidData;
124
125        match key {
126            "private_key" => {
127                self.device_info.private_key = Some(Key::from_hex(value).map_err(|_| InvalidData)?);
128                self.device_info.public_key = self
129                    .device_info
130                    .private_key
131                    .as_ref()
132                    .map(|k| k.get_public());
133            },
134            "listen_port" => {
135                self.device_info.listen_port = Some(value.parse().map_err(|_| InvalidData)?)
136            },
137            "fwmark" => self.device_info.fwmark = Some(value.parse().map_err(|_| InvalidData)?),
138            "public_key" => {
139                let new_peer =
140                    PeerInfo::from_public_key(Key::from_hex(value).map_err(|_| InvalidData)?);
141
142                if let Some(finished_peer) = self.current_peer.replace(new_peer) {
143                    self.device_info.peers.push(finished_peer);
144                }
145            },
146            "preshared_key" => {
147                self.current_peer
148                    .as_mut()
149                    .ok_or(InvalidData)?
150                    .config
151                    .preshared_key = Some(Key::from_hex(value).map_err(|_| InvalidData)?);
152            },
153            "tx_bytes" => {
154                self.current_peer
155                    .as_mut()
156                    .ok_or(InvalidData)?
157                    .stats
158                    .tx_bytes = value.parse().map_err(|_| InvalidData)?
159            },
160            "rx_bytes" => {
161                self.current_peer
162                    .as_mut()
163                    .ok_or(InvalidData)?
164                    .stats
165                    .rx_bytes = value.parse().map_err(|_| InvalidData)?
166            },
167            "last_handshake_time_sec" => {
168                let handshake_seconds: u64 = value.parse().map_err(|_| InvalidData)?;
169
170                if handshake_seconds > 0 {
171                    self.current_peer
172                        .as_mut()
173                        .ok_or(InvalidData)?
174                        .stats
175                        .last_handshake_time =
176                        Some(SystemTime::UNIX_EPOCH + Duration::from_secs(handshake_seconds));
177                }
178            },
179            "allowed_ip" => {
180                self.current_peer
181                    .as_mut()
182                    .ok_or(InvalidData)?
183                    .config
184                    .allowed_ips
185                    .push(value.parse().map_err(|_| InvalidData)?);
186            },
187            "persistent_keepalive_interval" => {
188                self.current_peer
189                    .as_mut()
190                    .ok_or(InvalidData)?
191                    .config
192                    .persistent_keepalive_interval = Some(value.parse().map_err(|_| InvalidData)?);
193            },
194            "endpoint" => {
195                self.current_peer
196                    .as_mut()
197                    .ok_or(InvalidData)?
198                    .config
199                    .endpoint = Some(value.parse().map_err(|_| InvalidData)?);
200            },
201            "errno" => {
202                // "errno" indicates an end of the stream, along with the error return code.
203                if value != "0" {
204                    return Err(std::io::Error::from_raw_os_error(
205                        value
206                            .parse()
207                            .expect("Unable to parse userspace wg errno return code"),
208                    ));
209                }
210
211                if let Some(finished_peer) = self.current_peer.take() {
212                    self.device_info.peers.push(finished_peer);
213                }
214            },
215            "protocol_version" | "last_handshake_time_nsec" => {},
216            _ => println!("got unsupported info: {key}={value}"),
217        }
218
219        Ok(())
220    }
221}
222
223pub fn get_by_name(name: &InterfaceName) -> Result<Device, io::Error> {
224    let mut sock = open_socket(name)?;
225    sock.write_all(b"get=1\n\n")?;
226    let mut reader = BufReader::new(sock);
227    let mut buf = String::new();
228
229    let mut parser = ConfigParser::new(name);
230
231    loop {
232        match reader.read_line(&mut buf)? {
233            0 | 1 if buf == "\n" => break,
234            _ => {
235                parser.add_line(buf.trim_end())?;
236                buf.clear();
237            },
238        };
239    }
240
241    Ok(parser.into())
242}
243
/// Choose the userspace WireGuard executable to launch.
///
/// In the spirit of wg-quick(8), `wireguard-go` is the default. `WG_USERSPACE_IMPLEMENTATION`
/// takes precedence when set to a valid Unicode value. Otherwise
/// `WG_QUICK_USERSPACE_IMPLEMENTATION` is honoured.
fn get_userspace_implementation() -> String {
    ["WG_USERSPACE_IMPLEMENTATION", "WG_QUICK_USERSPACE_IMPLEMENTATION"]
        .iter()
        .find_map(|var_name| std::env::var(var_name).ok())
        .unwrap_or_else(|| String::from("wireguard-go"))
}
256
257fn start_userspace_wireguard(iface: &InterfaceName) -> io::Result<Output> {
258    let userspace_implementation = get_userspace_implementation();
259    let mut command = Command::new(&userspace_implementation);
260
261    let output_res = if cfg!(target_os = "linux") {
262        command.args(&[iface.to_string()]).output()
263    } else {
264        command
265            .env("WG_TUN_NAME_FILE", format!("{VAR_RUN_PATH}/{iface}.name"))
266            .args(["utun"])
267            .output()
268    };
269
270    match output_res {
271        Ok(output) => {
272            if output.status.success() {
273                Ok(output)
274            } else {
275                Err(io::ErrorKind::AddrNotAvailable.into())
276            }
277        },
278        Err(e) if e.kind() == io::ErrorKind::NotFound => Err(io::Error::new(
279            io::ErrorKind::NotFound,
280            format!("Cannot find \"{userspace_implementation}\". Specify a custom path with WG_USERSPACE_IMPLEMENTATION."),
281        )),
282        Err(e) => Err(e),
283    }
284}
285
286pub fn apply(builder: &DeviceUpdate, iface: &InterfaceName) -> io::Result<()> {
287    // If we can't open a configuration socket to an existing interface, try starting it.
288    let mut sock = match open_socket(iface) {
289        Err(_) => {
290            fs::create_dir_all(VAR_RUN_PATH)?;
291            // Clear out any old namefiles if they didn't lead to a connected socket.
292            let _ = fs::remove_file(get_namefile(iface)?);
293            start_userspace_wireguard(iface)?;
294            std::thread::sleep(Duration::from_millis(100));
295            open_socket(iface)
296                .map_err(|e| io::Error::new(e.kind(), format!("failed to open socket ({e})")))?
297        },
298        Ok(sock) => sock,
299    };
300
301    let mut request = String::from("set=1\n");
302
303    if let Some(ref k) = builder.private_key {
304        writeln!(request, "private_key={}", hex::encode(k.as_bytes())).ok();
305    }
306
307    if let Some(f) = builder.fwmark {
308        writeln!(request, "fwmark={f}").ok();
309    }
310
311    if let Some(f) = builder.listen_port {
312        writeln!(request, "listen_port={f}").ok();
313    }
314
315    if builder.replace_peers {
316        writeln!(request, "replace_peers=true").ok();
317    }
318
319    for peer in &builder.peers {
320        writeln!(
321            request,
322            "public_key={}",
323            hex::encode(peer.public_key.as_bytes())
324        )
325        .ok();
326
327        if peer.replace_allowed_ips {
328            writeln!(request, "replace_allowed_ips=true").ok();
329        }
330
331        if peer.remove_me {
332            writeln!(request, "remove=true").ok();
333        }
334
335        if let Some(ref k) = peer.preshared_key {
336            writeln!(request, "preshared_key={}", hex::encode(k.as_bytes())).ok();
337        }
338
339        if let Some(endpoint) = peer.endpoint {
340            writeln!(request, "endpoint={endpoint}").ok();
341        }
342
343        if let Some(keepalive_interval) = peer.persistent_keepalive_interval {
344            writeln!(
345                request,
346                "persistent_keepalive_interval={keepalive_interval}"
347            )
348            .ok();
349        }
350
351        for allowed_ip in &peer.allowed_ips {
352            writeln!(
353                request,
354                "allowed_ip={}/{}",
355                allowed_ip.address, allowed_ip.cidr
356            )
357            .ok();
358        }
359    }
360
361    request.push('\n');
362
363    sock.write_all(request.as_bytes())?;
364
365    let mut reader = BufReader::new(sock);
366    let mut line = String::new();
367
368    reader.read_line(&mut line)?;
369    let split: Vec<&str> = line.trim_end().splitn(2, '=').collect();
370    match &split[..] {
371        ["errno", "0"] => Ok(()),
372        ["errno", val] => {
373            println!("ERROR {val}");
374            Err(io::ErrorKind::InvalidInput.into())
375        },
376        _ => Err(io::ErrorKind::Other.into()),
377    }
378}