1use crate::{
2 device::AllowedIp, Backend, Device, DeviceUpdate, InterfaceName, Key, PeerConfig,
3 PeerConfigBuilder, PeerInfo, PeerStats,
4};
5use netlink_packet_core::{
6 NetlinkMessage, NetlinkPayload, NLM_F_ACK, NLM_F_CREATE, NLM_F_DUMP, NLM_F_EXCL, NLM_F_REQUEST,
7};
8use netlink_packet_generic::GenlMessage;
9use netlink_packet_route::{
10 link::{self, InfoKind, LinkInfo, LinkMessage},
11 RouteNetlinkMessage,
12};
13use netlink_packet_utils::traits::Emitable;
14use netlink_packet_wireguard::{
15 self,
16 constants::{
17 AF_INET, AF_INET6, WGDEVICE_F_REPLACE_PEERS, WGPEER_F_REMOVE_ME,
18 WGPEER_F_REPLACE_ALLOWEDIPS,
19 },
20 nlas::{WgAllowedIp, WgAllowedIpAttrs, WgDeviceAttrs, WgPeer, WgPeerAttrs},
21 Wireguard, WireguardCmd,
22};
23use netlink_request::{max_genl_payload_length, netlink_request_genl, netlink_request_rtnl};
24
25use std::{convert::TryFrom, io};
26
/// Looks up the first attribute in `$nlas` that is the `$e::$v` variant and
/// evaluates to `Some(&payload)`, or `None` when no attribute matches.
///
/// `$nlas` only needs an `.iter()` method, e.g.
/// `get_nla_value!(nlas, WgDeviceAttrs, IfName)`.
macro_rules! get_nla_value {
    ($nlas:expr, $e:ident, $v:ident) => {
        $nlas.iter().find_map(|nla| {
            if let $e::$v(inner) = nla {
                Some(inner)
            } else {
                None
            }
        })
    };
}
35
36impl TryFrom<WgAllowedIp> for AllowedIp {
37 type Error = io::Error;
38
39 fn try_from(attrs: WgAllowedIp) -> Result<Self, Self::Error> {
40 let address = *get_nla_value!(attrs, WgAllowedIpAttrs, IpAddr)
41 .ok_or_else(|| io::ErrorKind::NotFound)?;
42 let cidr = *get_nla_value!(attrs, WgAllowedIpAttrs, Cidr)
43 .ok_or_else(|| io::ErrorKind::NotFound)?;
44 Ok(AllowedIp { address, cidr })
45 }
46}
47
48impl AllowedIp {
49 fn to_nla(&self) -> WgAllowedIp {
50 WgAllowedIp(vec![
51 WgAllowedIpAttrs::Family(if self.address.is_ipv4() {
52 AF_INET
53 } else {
54 AF_INET6
55 }),
56 WgAllowedIpAttrs::IpAddr(self.address),
57 WgAllowedIpAttrs::Cidr(self.cidr),
58 ])
59 }
60}
61
62impl PeerConfigBuilder {
63 fn to_nla(&self) -> WgPeer {
64 let mut attrs = vec![WgPeerAttrs::PublicKey(self.public_key.0)];
65 let mut flags = 0u32;
66 if let Some(endpoint) = self.endpoint {
67 attrs.push(WgPeerAttrs::Endpoint(endpoint));
68 }
69 if let Some(ref key) = self.preshared_key {
70 attrs.push(WgPeerAttrs::PresharedKey(key.0));
71 }
72 if let Some(i) = self.persistent_keepalive_interval {
73 attrs.push(WgPeerAttrs::PersistentKeepalive(i));
74 }
75 let allowed_ips: Vec<_> = self.allowed_ips.iter().map(AllowedIp::to_nla).collect();
76 attrs.push(WgPeerAttrs::AllowedIps(allowed_ips));
77 if self.remove_me {
78 flags |= WGPEER_F_REMOVE_ME;
79 }
80 if self.replace_allowed_ips {
81 flags |= WGPEER_F_REPLACE_ALLOWEDIPS;
82 }
83 if flags != 0 {
84 attrs.push(WgPeerAttrs::Flags(flags));
85 }
86 WgPeer(attrs)
87 }
88}
89
90impl TryFrom<WgPeer> for PeerInfo {
91 type Error = io::Error;
92
93 fn try_from(attrs: WgPeer) -> Result<Self, Self::Error> {
94 let public_key = get_nla_value!(attrs, WgPeerAttrs, PublicKey)
95 .map(|key| Key(*key))
96 .ok_or(io::ErrorKind::NotFound)?;
97 let preshared_key = get_nla_value!(attrs, WgPeerAttrs, PresharedKey).map(|key| Key(*key));
98 let endpoint = get_nla_value!(attrs, WgPeerAttrs, Endpoint).cloned();
99 let persistent_keepalive_interval =
100 get_nla_value!(attrs, WgPeerAttrs, PersistentKeepalive).cloned();
101 let allowed_ips = get_nla_value!(attrs, WgPeerAttrs, AllowedIps)
102 .cloned()
103 .unwrap_or_default()
104 .into_iter()
105 .map(AllowedIp::try_from)
106 .collect::<Result<Vec<_>, _>>()?;
107 let last_handshake_time = get_nla_value!(attrs, WgPeerAttrs, LastHandshake).cloned();
108 let rx_bytes = get_nla_value!(attrs, WgPeerAttrs, RxBytes)
109 .cloned()
110 .unwrap_or_default();
111 let tx_bytes = get_nla_value!(attrs, WgPeerAttrs, TxBytes)
112 .cloned()
113 .unwrap_or_default();
114 Ok(PeerInfo {
115 config: PeerConfig {
116 public_key,
117 preshared_key,
118 endpoint,
119 persistent_keepalive_interval,
120 allowed_ips,
121 },
122 stats: PeerStats {
123 last_handshake_time,
124 rx_bytes,
125 tx_bytes,
126 },
127 })
128 }
129}
130
131impl<'a> TryFrom<&'a [WgDeviceAttrs]> for Device {
132 type Error = io::Error;
133
134 fn try_from(nlas: &'a [WgDeviceAttrs]) -> Result<Self, Self::Error> {
135 let name = get_nla_value!(nlas, WgDeviceAttrs, IfName)
136 .ok_or_else(|| io::ErrorKind::NotFound)?
137 .parse()?;
138 let public_key = get_nla_value!(nlas, WgDeviceAttrs, PublicKey).map(|key| Key(*key));
139 let private_key = get_nla_value!(nlas, WgDeviceAttrs, PrivateKey).map(|key| Key(*key));
140 let listen_port = get_nla_value!(nlas, WgDeviceAttrs, ListenPort).cloned();
141 let fwmark = get_nla_value!(nlas, WgDeviceAttrs, Fwmark).cloned();
142 let peers = nlas
143 .iter()
144 .filter_map(|nla| match nla {
145 WgDeviceAttrs::Peers(peers) => Some(peers.clone()),
146 _ => None,
147 })
148 .flatten()
149 .map(PeerInfo::try_from)
150 .collect::<Result<Vec<_>, _>>()?;
151 Ok(Device {
152 name,
153 public_key,
154 private_key,
155 listen_port,
156 fwmark,
157 peers,
158 linked_name: None,
159 backend: Backend::Kernel,
160 })
161 }
162}
163
164pub fn enumerate() -> Result<Vec<InterfaceName>, io::Error> {
165 let link_responses = netlink_request_rtnl(
166 RouteNetlinkMessage::GetLink(LinkMessage::default()),
167 Some(NLM_F_DUMP | NLM_F_REQUEST),
168 )?;
169 let links = link_responses
170 .into_iter()
171 .filter_map(|response| match response {
173 NetlinkMessage {
174 payload: NetlinkPayload::InnerMessage(RouteNetlinkMessage::NewLink(link)),
175 ..
176 } => Some(link),
177 _ => None,
178 })
179 .filter(|link| {
180 for nla in link.attributes.iter() {
181 if let link::LinkAttribute::LinkInfo(infos) = nla {
182 return infos.iter().any(|info| info == &LinkInfo::Kind(InfoKind::Wireguard))
183 }
184 }
185 false
186 })
187 .filter_map(|link| link.attributes.iter().find_map(|nla| match nla {
188 link::LinkAttribute::IfName(name) => Some(name.clone()),
189 _ => None,
190 }))
191 .filter_map(|name| name.parse().ok())
192 .collect::<Vec<_>>();
193
194 Ok(links)
195}
196
197fn add_del(iface: &InterfaceName, add: bool) -> io::Result<()> {
198 let mut message = LinkMessage::default();
199 message.attributes.push(link::LinkAttribute::IfName(
200 iface.as_str_lossy().to_string(),
201 ));
202 message
203 .attributes
204 .push(link::LinkAttribute::LinkInfo(vec![LinkInfo::Kind(
205 link::InfoKind::Wireguard,
206 )]));
207 let extra_flags = if add { NLM_F_CREATE | NLM_F_EXCL } else { 0 };
208 let rtnl_message = if add {
209 RouteNetlinkMessage::NewLink(message)
210 } else {
211 RouteNetlinkMessage::DelLink(message)
212 };
213 match netlink_request_rtnl(rtnl_message, Some(NLM_F_REQUEST | NLM_F_ACK | extra_flags)) {
214 Err(e) if e.kind() != io::ErrorKind::AlreadyExists => Err(e),
215 _ => Ok(()),
216 }
217}
218
219pub fn apply(builder: &DeviceUpdate, iface: &InterfaceName) -> io::Result<()> {
220 add_del(iface, true)?;
221 let mut payload = ApplyPayload::new(iface);
222 if let Some(Key(k)) = builder.private_key {
223 payload.push(WgDeviceAttrs::PrivateKey(k))?;
224 }
225 if let Some(f) = builder.fwmark {
226 payload.push(WgDeviceAttrs::Fwmark(f))?;
227 }
228 if let Some(f) = builder.listen_port {
229 payload.push(WgDeviceAttrs::ListenPort(f))?;
230 }
231 if builder.replace_peers {
232 payload.push(WgDeviceAttrs::Flags(WGDEVICE_F_REPLACE_PEERS))?;
233 }
234
235 builder
236 .peers
237 .iter()
238 .map(|peer| payload.push_peer(peer.to_nla()))
239 .collect::<Result<Vec<_>, _>>()?;
240
241 for message in payload.finish() {
242 netlink_request_genl(message, Some(NLM_F_REQUEST | NLM_F_ACK))?;
243 }
244 Ok(())
245}
246
247struct ApplyPayload {
248 iface: String,
249 nlas: Vec<WgDeviceAttrs>,
250 current_buffer_len: usize,
251 queue: Vec<GenlMessage<Wireguard>>,
252}
253
254impl ApplyPayload {
255 fn new(iface: &InterfaceName) -> Self {
256 let iface_str = iface.as_str_lossy().to_string();
257 let nlas = vec![WgDeviceAttrs::IfName(iface_str.clone())];
258 let current_buffer_len = nlas.as_slice().buffer_len();
259 Self {
260 iface: iface_str,
261 nlas,
262 queue: vec![],
263 current_buffer_len,
264 }
265 }
266
267 fn flush_nlas(&mut self) {
268 self.nlas
270 .retain(|nla| !matches!(nla, WgDeviceAttrs::Peers(peers) if peers.is_empty()));
271
272 let name = WgDeviceAttrs::IfName(self.iface.clone());
273 let template = vec![name];
274
275 if !self.nlas.is_empty() && self.nlas != template {
276 self.current_buffer_len = template.as_slice().buffer_len();
277 let message = GenlMessage::from_payload(Wireguard {
278 cmd: WireguardCmd::SetDevice,
279 nlas: std::mem::replace(&mut self.nlas, template),
280 });
281 self.queue.push(message);
282 }
283 }
284
285 pub fn push(&mut self, nla: WgDeviceAttrs) -> io::Result<()> {
287 let max_payload_len = max_genl_payload_length();
288
289 let nla_buffer_len = nla.buffer_len();
290 if (self.current_buffer_len + nla_buffer_len) > max_payload_len {
291 self.flush_nlas();
292 }
293
294 if (self.current_buffer_len + nla_buffer_len) > max_payload_len {
296 return Err(io::Error::new(
297 io::ErrorKind::InvalidInput,
298 format!("encoded NLA ({nla_buffer_len} bytes) is too large: {nla:?}"),
299 ));
300 }
301 self.nlas.push(nla);
302 self.current_buffer_len += nla_buffer_len;
303 Ok(())
304 }
305
306 pub fn push_peer(&mut self, peer: WgPeer) -> io::Result<()> {
308 const EMPTY_PEERS: WgDeviceAttrs = WgDeviceAttrs::Peers(vec![]);
309 let max_payload_len = max_genl_payload_length();
310 let mut needs_peer_nla = !self
311 .nlas
312 .iter()
313 .any(|nla| matches!(nla, WgDeviceAttrs::Peers(_)));
314 let peer_buffer_len = peer.buffer_len();
315 let mut additional_buffer_len = peer_buffer_len;
316 if needs_peer_nla {
317 additional_buffer_len += EMPTY_PEERS.buffer_len();
318 }
319 if (self.current_buffer_len + additional_buffer_len) > max_payload_len {
320 self.flush_nlas();
321 needs_peer_nla = true;
322 }
323
324 if needs_peer_nla {
325 self.push(EMPTY_PEERS)?;
326 }
327
328 if (self.current_buffer_len + peer_buffer_len) > max_payload_len {
330 return Err(io::Error::new(
331 io::ErrorKind::InvalidInput,
332 format!("encoded peer ({peer_buffer_len} bytes) is too large: {peer:?}"),
333 ));
334 }
335
336 let peers_nla = self
337 .nlas
338 .iter_mut()
339 .find_map(|nla| match nla {
340 WgDeviceAttrs::Peers(peers) => Some(peers),
341 _ => None,
342 })
343 .expect("WgDeviceAttrs::Peers missing from NLAs when it should exist.");
344
345 peers_nla.push(peer);
346 self.current_buffer_len += peer_buffer_len;
347
348 Ok(())
349 }
350
351 pub fn finish(mut self) -> Vec<GenlMessage<Wireguard>> {
352 self.flush_nlas();
353 self.queue
354 }
355}
356
357pub fn get_by_name(name: &InterfaceName) -> Result<Device, io::Error> {
358 let genlmsg: GenlMessage<Wireguard> = GenlMessage::from_payload(Wireguard {
359 cmd: WireguardCmd::GetDevice,
360 nlas: vec![WgDeviceAttrs::IfName(name.as_str_lossy().to_string())],
361 });
362 let responses = netlink_request_genl(genlmsg, Some(NLM_F_REQUEST | NLM_F_DUMP | NLM_F_ACK))?;
363 log::debug!(
364 "get_by_name: got {} response message(s) from netlink request",
365 responses.len()
366 );
367
368 let nlas = responses.into_iter().try_fold(vec![], |mut nlas, nlmsg| {
369 let mut message = match nlmsg {
370 NetlinkMessage {
371 payload: NetlinkPayload::InnerMessage(message),
372 ..
373 } => message,
374 _ => {
375 return Err(io::Error::new(
376 io::ErrorKind::InvalidData,
377 format!("unexpected netlink payload: {nlmsg:?}"),
378 ))
379 },
380 };
381 nlas.append(&mut message.payload.nlas);
382 Ok(nlas)
383 })?;
384 let device = Device::try_from(&nlas[..])?;
385 log::debug!(
386 "get_by_name: parsed wireguard device {} with {} peer(s)",
387 device.name,
388 device.peers.len(),
389 );
390 Ok(device)
391}
392
393pub fn delete_interface(iface: &InterfaceName) -> io::Result<()> {
394 add_del(iface, false)
395}
396
#[cfg(test)]
mod tests {
    use super::*;
    use netlink_packet_wireguard::nlas::WgAllowedIp;
    use netlink_request::max_netlink_buffer_length;
    use std::str::FromStr;

    /// Returns a payload for `wg0` that already carries a private key, fwmark,
    /// listen port and the replace-peers flag.
    fn device_payload() -> ApplyPayload {
        let mut payload = ApplyPayload::new(&InterfaceName::from_str("wg0").unwrap());
        for attr in [
            WgDeviceAttrs::PrivateKey([1u8; 32]),
            WgDeviceAttrs::Fwmark(111),
            WgDeviceAttrs::ListenPort(12345),
            WgDeviceAttrs::Flags(WGDEVICE_F_REPLACE_PEERS),
        ] {
            payload.push(attr).unwrap();
        }
        payload
    }

    /// Returns a peer with a key, keepalive, endpoint, the replace-allowed-IPs
    /// flag and a single 10.1.1.1/24 allowed IP.
    fn sample_peer() -> WgPeer {
        WgPeer(vec![
            WgPeerAttrs::PublicKey([2u8; 32]),
            WgPeerAttrs::PersistentKeepalive(25),
            WgPeerAttrs::Endpoint("1.1.1.1:51820".parse().unwrap()),
            WgPeerAttrs::Flags(WGPEER_F_REPLACE_ALLOWEDIPS),
            WgPeerAttrs::AllowedIps(vec![WgAllowedIp(vec![
                WgAllowedIpAttrs::Family(AF_INET),
                WgAllowedIpAttrs::IpAddr([10, 1, 1, 1].into()),
                WgAllowedIpAttrs::Cidr(24),
            ])]),
        ])
    }

    #[test]
    fn test_simple_payload() {
        let mut payload = device_payload();
        payload.push_peer(sample_peer()).unwrap();
        assert_eq!(payload.finish().len(), 1);
    }

    #[test]
    fn test_massive_payload() {
        let mut payload = device_payload();
        for i in 0..10_000usize {
            let mut peer = sample_peer();
            // Pad each peer with 0..=255 extra bytes so that message split
            // points land at varying offsets.
            peer.0.push(WgPeerAttrs::Unspec(vec![1u8; i % 256]));
            payload.push_peer(peer).unwrap();
        }

        let messages = payload.finish();
        println!("generated {} messages", messages.len());
        assert!(messages.len() > 1);
        let max_buffer_len = max_netlink_buffer_length();
        assert!(messages
            .into_iter()
            .all(|message| NetlinkMessage::from(message).buffer_len() <= max_buffer_len));
    }
}