wireguard_control/backends/
kernel.rs

1use crate::{
2    device::AllowedIp, Backend, Device, DeviceUpdate, InterfaceName, Key, PeerConfig,
3    PeerConfigBuilder, PeerInfo, PeerStats,
4};
5use netlink_packet_core::{
6    NetlinkMessage, NetlinkPayload, NLM_F_ACK, NLM_F_CREATE, NLM_F_DUMP, NLM_F_EXCL, NLM_F_REQUEST,
7};
8use netlink_packet_generic::GenlMessage;
9use netlink_packet_route::{
10    link::{self, InfoKind, LinkInfo, LinkMessage},
11    RouteNetlinkMessage,
12};
13use netlink_packet_utils::traits::Emitable;
14use netlink_packet_wireguard::{
15    self,
16    constants::{
17        AF_INET, AF_INET6, WGDEVICE_F_REPLACE_PEERS, WGPEER_F_REMOVE_ME,
18        WGPEER_F_REPLACE_ALLOWEDIPS,
19    },
20    nlas::{WgAllowedIp, WgAllowedIpAttrs, WgDeviceAttrs, WgPeer, WgPeerAttrs},
21    Wireguard, WireguardCmd,
22};
23use netlink_request::{max_genl_payload_length, netlink_request_genl, netlink_request_rtnl};
24
25use std::{convert::TryFrom, io};
26
/// Looks up the first attribute in `$nlas` that is the `$attr_enum::$variant` variant and
/// evaluates to `Some(&payload)` for it, or `None` when no such attribute is present.
///
/// `$nlas` may be anything that offers `.iter()` over attribute values.
macro_rules! get_nla_value {
    ($nlas:expr, $attr_enum:ident, $variant:ident) => {
        $nlas.iter().find_map(|candidate| {
            if let $attr_enum::$variant(inner) = candidate {
                Some(inner)
            } else {
                None
            }
        })
    };
}
35
36impl TryFrom<WgAllowedIp> for AllowedIp {
37    type Error = io::Error;
38
39    fn try_from(attrs: WgAllowedIp) -> Result<Self, Self::Error> {
40        let address = *get_nla_value!(attrs, WgAllowedIpAttrs, IpAddr)
41            .ok_or_else(|| io::ErrorKind::NotFound)?;
42        let cidr = *get_nla_value!(attrs, WgAllowedIpAttrs, Cidr)
43            .ok_or_else(|| io::ErrorKind::NotFound)?;
44        Ok(AllowedIp { address, cidr })
45    }
46}
47
48impl AllowedIp {
49    fn to_nla(&self) -> WgAllowedIp {
50        WgAllowedIp(vec![
51            WgAllowedIpAttrs::Family(if self.address.is_ipv4() {
52                AF_INET
53            } else {
54                AF_INET6
55            }),
56            WgAllowedIpAttrs::IpAddr(self.address),
57            WgAllowedIpAttrs::Cidr(self.cidr),
58        ])
59    }
60}
61
62impl PeerConfigBuilder {
63    fn to_nla(&self) -> WgPeer {
64        let mut attrs = vec![WgPeerAttrs::PublicKey(self.public_key.0)];
65        let mut flags = 0u32;
66        if let Some(endpoint) = self.endpoint {
67            attrs.push(WgPeerAttrs::Endpoint(endpoint));
68        }
69        if let Some(ref key) = self.preshared_key {
70            attrs.push(WgPeerAttrs::PresharedKey(key.0));
71        }
72        if let Some(i) = self.persistent_keepalive_interval {
73            attrs.push(WgPeerAttrs::PersistentKeepalive(i));
74        }
75        let allowed_ips: Vec<_> = self.allowed_ips.iter().map(AllowedIp::to_nla).collect();
76        attrs.push(WgPeerAttrs::AllowedIps(allowed_ips));
77        if self.remove_me {
78            flags |= WGPEER_F_REMOVE_ME;
79        }
80        if self.replace_allowed_ips {
81            flags |= WGPEER_F_REPLACE_ALLOWEDIPS;
82        }
83        if flags != 0 {
84            attrs.push(WgPeerAttrs::Flags(flags));
85        }
86        WgPeer(attrs)
87    }
88}
89
90impl TryFrom<WgPeer> for PeerInfo {
91    type Error = io::Error;
92
93    fn try_from(attrs: WgPeer) -> Result<Self, Self::Error> {
94        let public_key = get_nla_value!(attrs, WgPeerAttrs, PublicKey)
95            .map(|key| Key(*key))
96            .ok_or(io::ErrorKind::NotFound)?;
97        let preshared_key = get_nla_value!(attrs, WgPeerAttrs, PresharedKey).map(|key| Key(*key));
98        let endpoint = get_nla_value!(attrs, WgPeerAttrs, Endpoint).cloned();
99        let persistent_keepalive_interval =
100            get_nla_value!(attrs, WgPeerAttrs, PersistentKeepalive).cloned();
101        let allowed_ips = get_nla_value!(attrs, WgPeerAttrs, AllowedIps)
102            .cloned()
103            .unwrap_or_default()
104            .into_iter()
105            .map(AllowedIp::try_from)
106            .collect::<Result<Vec<_>, _>>()?;
107        let last_handshake_time = get_nla_value!(attrs, WgPeerAttrs, LastHandshake).cloned();
108        let rx_bytes = get_nla_value!(attrs, WgPeerAttrs, RxBytes)
109            .cloned()
110            .unwrap_or_default();
111        let tx_bytes = get_nla_value!(attrs, WgPeerAttrs, TxBytes)
112            .cloned()
113            .unwrap_or_default();
114        Ok(PeerInfo {
115            config: PeerConfig {
116                public_key,
117                preshared_key,
118                endpoint,
119                persistent_keepalive_interval,
120                allowed_ips,
121            },
122            stats: PeerStats {
123                last_handshake_time,
124                rx_bytes,
125                tx_bytes,
126            },
127        })
128    }
129}
130
131impl<'a> TryFrom<&'a [WgDeviceAttrs]> for Device {
132    type Error = io::Error;
133
134    fn try_from(nlas: &'a [WgDeviceAttrs]) -> Result<Self, Self::Error> {
135        let name = get_nla_value!(nlas, WgDeviceAttrs, IfName)
136            .ok_or_else(|| io::ErrorKind::NotFound)?
137            .parse()?;
138        let public_key = get_nla_value!(nlas, WgDeviceAttrs, PublicKey).map(|key| Key(*key));
139        let private_key = get_nla_value!(nlas, WgDeviceAttrs, PrivateKey).map(|key| Key(*key));
140        let listen_port = get_nla_value!(nlas, WgDeviceAttrs, ListenPort).cloned();
141        let fwmark = get_nla_value!(nlas, WgDeviceAttrs, Fwmark).cloned();
142        let peers = nlas
143            .iter()
144            .filter_map(|nla| match nla {
145                WgDeviceAttrs::Peers(peers) => Some(peers.clone()),
146                _ => None,
147            })
148            .flatten()
149            .map(PeerInfo::try_from)
150            .collect::<Result<Vec<_>, _>>()?;
151        Ok(Device {
152            name,
153            public_key,
154            private_key,
155            listen_port,
156            fwmark,
157            peers,
158            linked_name: None,
159            backend: Backend::Kernel,
160        })
161    }
162}
163
164pub fn enumerate() -> Result<Vec<InterfaceName>, io::Error> {
165    let link_responses = netlink_request_rtnl(
166        RouteNetlinkMessage::GetLink(LinkMessage::default()),
167        Some(NLM_F_DUMP | NLM_F_REQUEST),
168    )?;
169    let links = link_responses
170        .into_iter()
171        // Filter out non-link messages
172        .filter_map(|response| match response {
173            NetlinkMessage {
174                payload: NetlinkPayload::InnerMessage(RouteNetlinkMessage::NewLink(link)),
175                ..
176            } => Some(link),
177            _ => None,
178        })
179        .filter(|link| {
180            for nla in link.attributes.iter() {
181                if let link::LinkAttribute::LinkInfo(infos) = nla {
182                    return infos.iter().any(|info| info == &LinkInfo::Kind(InfoKind::Wireguard))
183                }
184            }
185            false
186        })
187        .filter_map(|link| link.attributes.iter().find_map(|nla| match nla {
188            link::LinkAttribute::IfName(name) => Some(name.clone()),
189            _ => None,
190        }))
191        .filter_map(|name| name.parse().ok())
192        .collect::<Vec<_>>();
193
194    Ok(links)
195}
196
197fn add_del(iface: &InterfaceName, add: bool) -> io::Result<()> {
198    let mut message = LinkMessage::default();
199    message.attributes.push(link::LinkAttribute::IfName(
200        iface.as_str_lossy().to_string(),
201    ));
202    message
203        .attributes
204        .push(link::LinkAttribute::LinkInfo(vec![LinkInfo::Kind(
205            link::InfoKind::Wireguard,
206        )]));
207    let extra_flags = if add { NLM_F_CREATE | NLM_F_EXCL } else { 0 };
208    let rtnl_message = if add {
209        RouteNetlinkMessage::NewLink(message)
210    } else {
211        RouteNetlinkMessage::DelLink(message)
212    };
213    match netlink_request_rtnl(rtnl_message, Some(NLM_F_REQUEST | NLM_F_ACK | extra_flags)) {
214        Err(e) if e.kind() != io::ErrorKind::AlreadyExists => Err(e),
215        _ => Ok(()),
216    }
217}
218
219pub fn apply(builder: &DeviceUpdate, iface: &InterfaceName) -> io::Result<()> {
220    add_del(iface, true)?;
221    let mut payload = ApplyPayload::new(iface);
222    if let Some(Key(k)) = builder.private_key {
223        payload.push(WgDeviceAttrs::PrivateKey(k))?;
224    }
225    if let Some(f) = builder.fwmark {
226        payload.push(WgDeviceAttrs::Fwmark(f))?;
227    }
228    if let Some(f) = builder.listen_port {
229        payload.push(WgDeviceAttrs::ListenPort(f))?;
230    }
231    if builder.replace_peers {
232        payload.push(WgDeviceAttrs::Flags(WGDEVICE_F_REPLACE_PEERS))?;
233    }
234
235    builder
236        .peers
237        .iter()
238        .map(|peer| payload.push_peer(peer.to_nla()))
239        .collect::<Result<Vec<_>, _>>()?;
240
241    for message in payload.finish() {
242        netlink_request_genl(message, Some(NLM_F_REQUEST | NLM_F_ACK))?;
243    }
244    Ok(())
245}
246
247struct ApplyPayload {
248    iface: String,
249    nlas: Vec<WgDeviceAttrs>,
250    current_buffer_len: usize,
251    queue: Vec<GenlMessage<Wireguard>>,
252}
253
254impl ApplyPayload {
255    fn new(iface: &InterfaceName) -> Self {
256        let iface_str = iface.as_str_lossy().to_string();
257        let nlas = vec![WgDeviceAttrs::IfName(iface_str.clone())];
258        let current_buffer_len = nlas.as_slice().buffer_len();
259        Self {
260            iface: iface_str,
261            nlas,
262            queue: vec![],
263            current_buffer_len,
264        }
265    }
266
267    fn flush_nlas(&mut self) {
268        // // cleanup: clear out any empty peer lists.
269        self.nlas
270            .retain(|nla| !matches!(nla, WgDeviceAttrs::Peers(peers) if peers.is_empty()));
271
272        let name = WgDeviceAttrs::IfName(self.iface.clone());
273        let template = vec![name];
274
275        if !self.nlas.is_empty() && self.nlas != template {
276            self.current_buffer_len = template.as_slice().buffer_len();
277            let message = GenlMessage::from_payload(Wireguard {
278                cmd: WireguardCmd::SetDevice,
279                nlas: std::mem::replace(&mut self.nlas, template),
280            });
281            self.queue.push(message);
282        }
283    }
284
285    /// Push a device attribute which will be optimally packed into 1 or more netlink messages
286    pub fn push(&mut self, nla: WgDeviceAttrs) -> io::Result<()> {
287        let max_payload_len = max_genl_payload_length();
288
289        let nla_buffer_len = nla.buffer_len();
290        if (self.current_buffer_len + nla_buffer_len) > max_payload_len {
291            self.flush_nlas();
292        }
293
294        // If the NLA *still* doesn't fit...
295        if (self.current_buffer_len + nla_buffer_len) > max_payload_len {
296            return Err(io::Error::new(
297                io::ErrorKind::InvalidInput,
298                format!("encoded NLA ({nla_buffer_len} bytes) is too large: {nla:?}"),
299            ));
300        }
301        self.nlas.push(nla);
302        self.current_buffer_len += nla_buffer_len;
303        Ok(())
304    }
305
306    /// A helper function to assist in breaking up large peer lists across multiple netlink messages
307    pub fn push_peer(&mut self, peer: WgPeer) -> io::Result<()> {
308        const EMPTY_PEERS: WgDeviceAttrs = WgDeviceAttrs::Peers(vec![]);
309        let max_payload_len = max_genl_payload_length();
310        let mut needs_peer_nla = !self
311            .nlas
312            .iter()
313            .any(|nla| matches!(nla, WgDeviceAttrs::Peers(_)));
314        let peer_buffer_len = peer.buffer_len();
315        let mut additional_buffer_len = peer_buffer_len;
316        if needs_peer_nla {
317            additional_buffer_len += EMPTY_PEERS.buffer_len();
318        }
319        if (self.current_buffer_len + additional_buffer_len) > max_payload_len {
320            self.flush_nlas();
321            needs_peer_nla = true;
322        }
323
324        if needs_peer_nla {
325            self.push(EMPTY_PEERS)?;
326        }
327
328        // If the peer *still* doesn't fit...
329        if (self.current_buffer_len + peer_buffer_len) > max_payload_len {
330            return Err(io::Error::new(
331                io::ErrorKind::InvalidInput,
332                format!("encoded peer ({peer_buffer_len} bytes) is too large: {peer:?}"),
333            ));
334        }
335
336        let peers_nla = self
337            .nlas
338            .iter_mut()
339            .find_map(|nla| match nla {
340                WgDeviceAttrs::Peers(peers) => Some(peers),
341                _ => None,
342            })
343            .expect("WgDeviceAttrs::Peers missing from NLAs when it should exist.");
344
345        peers_nla.push(peer);
346        self.current_buffer_len += peer_buffer_len;
347
348        Ok(())
349    }
350
351    pub fn finish(mut self) -> Vec<GenlMessage<Wireguard>> {
352        self.flush_nlas();
353        self.queue
354    }
355}
356
357pub fn get_by_name(name: &InterfaceName) -> Result<Device, io::Error> {
358    let genlmsg: GenlMessage<Wireguard> = GenlMessage::from_payload(Wireguard {
359        cmd: WireguardCmd::GetDevice,
360        nlas: vec![WgDeviceAttrs::IfName(name.as_str_lossy().to_string())],
361    });
362    let responses = netlink_request_genl(genlmsg, Some(NLM_F_REQUEST | NLM_F_DUMP | NLM_F_ACK))?;
363    log::debug!(
364        "get_by_name: got {} response message(s) from netlink request",
365        responses.len()
366    );
367
368    let nlas = responses.into_iter().try_fold(vec![], |mut nlas, nlmsg| {
369        let mut message = match nlmsg {
370            NetlinkMessage {
371                payload: NetlinkPayload::InnerMessage(message),
372                ..
373            } => message,
374            _ => {
375                return Err(io::Error::new(
376                    io::ErrorKind::InvalidData,
377                    format!("unexpected netlink payload: {nlmsg:?}"),
378                ))
379            },
380        };
381        nlas.append(&mut message.payload.nlas);
382        Ok(nlas)
383    })?;
384    let device = Device::try_from(&nlas[..])?;
385    log::debug!(
386        "get_by_name: parsed wireguard device {} with {} peer(s)",
387        device.name,
388        device.peers.len(),
389    );
390    Ok(device)
391}
392
393pub fn delete_interface(iface: &InterfaceName) -> io::Result<()> {
394    add_del(iface, false)
395}
396
#[cfg(test)]
mod tests {
    use super::*;
    use netlink_packet_wireguard::nlas::WgAllowedIp;
    use netlink_request::max_netlink_buffer_length;
    use std::str::FromStr;

    /// A payload for `wg0` pre-loaded with the device-level attributes shared by all tests.
    fn device_payload() -> ApplyPayload {
        let mut payload = ApplyPayload::new(&InterfaceName::from_str("wg0").unwrap());
        let device_attrs = [
            WgDeviceAttrs::PrivateKey([1u8; 32]),
            WgDeviceAttrs::Fwmark(111),
            WgDeviceAttrs::ListenPort(12345),
            WgDeviceAttrs::Flags(WGDEVICE_F_REPLACE_PEERS),
        ];
        for attr in device_attrs {
            payload.push(attr).unwrap();
        }
        payload
    }

    /// The attribute list of a representative peer with a single IPv4 allowed IP.
    fn peer_attrs() -> Vec<WgPeerAttrs> {
        vec![
            WgPeerAttrs::PublicKey([2u8; 32]),
            WgPeerAttrs::PersistentKeepalive(25),
            WgPeerAttrs::Endpoint("1.1.1.1:51820".parse().unwrap()),
            WgPeerAttrs::Flags(WGPEER_F_REPLACE_ALLOWEDIPS),
            WgPeerAttrs::AllowedIps(vec![WgAllowedIp(vec![
                WgAllowedIpAttrs::Family(AF_INET),
                WgAllowedIpAttrs::IpAddr([10, 1, 1, 1].into()),
                WgAllowedIpAttrs::Cidr(24),
            ])]),
        ]
    }

    /// A small configuration must be packed into exactly one message.
    #[test]
    fn test_simple_payload() {
        let mut payload = device_payload();
        payload.push_peer(WgPeer(peer_attrs())).unwrap();
        assert_eq!(payload.finish().len(), 1);
    }

    /// Ten thousand peers of varying size must be split across several messages, none of
    /// which exceeds the maximum netlink buffer length.
    #[test]
    fn test_massive_payload() {
        let mut payload = device_payload();

        for i in 0..10_000usize {
            let mut attrs = peer_attrs();
            // Vary the encoded peer size so message boundaries fall at irregular offsets.
            attrs.push(WgPeerAttrs::Unspec(vec![1u8; i % 256]));
            payload.push_peer(WgPeer(attrs)).unwrap();
        }

        let messages = payload.finish();
        println!("generated {} messages", messages.len());
        assert!(messages.len() > 1);

        let max_buffer_len = max_netlink_buffer_length();
        for message in messages {
            assert!(NetlinkMessage::from(message).buffer_len() <= max_buffer_len);
        }
    }
}