wireguard_uapi/linux/socket/
wg_socket.rs1use crate::get;
2use crate::linux::attr::WgDeviceAttribute;
3use crate::linux::cmd::WgCmd;
4use crate::linux::consts::NLA_NETWORK_ORDER;
5use crate::linux::consts::{WG_GENL_NAME, WG_GENL_VERSION};
6use crate::linux::err::{ConnectError, GetDeviceError, SetDeviceError};
7use crate::linux::set;
8use crate::linux::set::create_set_device_messages;
9use crate::linux::socket::parse::*;
10use crate::linux::socket::NlWgMsgType;
11use crate::linux::DeviceInterface;
12use libc::IFNAMSIZ;
13use neli::{
14 consts::{
15 nl::{NlmF, NlmFFlags, Nlmsg},
16 socket::NlFamily,
17 },
18 genl::{Genlmsghdr, Nlattr},
19 nl::{NlPayload, Nlmsghdr},
20 socket::NlSocketHandle,
21 types::GenlBuffer,
22};
23use std::convert::TryFrom;
24
25pub struct WgSocket {
26 sock: NlSocketHandle,
27 family_id: NlWgMsgType,
28}
29
30impl WgSocket {
31 pub fn connect() -> Result<Self, ConnectError> {
32 let family_id = {
33 NlSocketHandle::new(NlFamily::Generic)?
34 .resolve_genl_family(WG_GENL_NAME)
35 .map_err(ConnectError::ResolveFamilyError)?
36 };
37
38 let pid = None;
40 let groups = &[];
41 let wgsock = NlSocketHandle::connect(NlFamily::Generic, pid, groups)?;
42
43 Ok(Self {
44 sock: wgsock,
45 family_id,
46 })
47 }
48
49 pub fn get_device(
50 &mut self,
51 interface: DeviceInterface,
52 ) -> Result<get::Device, GetDeviceError> {
53 let attr = match interface {
54 DeviceInterface::Name(name) => {
55 Some(name.len())
56 .filter(|&len| 0 < len && len < IFNAMSIZ)
57 .ok_or(GetDeviceError::InvalidInterfaceName)?;
58 Nlattr::new(
59 false,
60 NLA_NETWORK_ORDER,
61 WgDeviceAttribute::Ifname,
62 name.as_ref(),
63 )?
64 }
65 DeviceInterface::Index(index) => {
66 Nlattr::new(false, NLA_NETWORK_ORDER, WgDeviceAttribute::Ifindex, index)?
67 }
68 };
69 let genlhdr = {
70 let cmd = WgCmd::GetDevice;
71 let version = WG_GENL_VERSION;
72 let mut attrs = GenlBuffer::new();
73
74 attrs.push(attr);
75 Genlmsghdr::new(cmd, version, attrs)
76 };
77 let nlhdr = {
78 let size = None;
79 let nl_type = self.family_id;
80 let flags = NlmFFlags::new(&[NlmF::Request, NlmF::Ack, NlmF::Dump]);
81 let seq = None;
82 let pid = None;
83 let payload = NlPayload::Payload(genlhdr);
84 Nlmsghdr::new(size, nl_type, flags, seq, pid, payload)
85 };
86
87 self.sock.send(nlhdr)?;
88
89 let mut iter = self
90 .sock
91 .iter::<Nlmsg, Genlmsghdr<WgCmd, WgDeviceAttribute>>(false);
92
93 let mut device = None;
94 while let Some(Ok(response)) = iter.next() {
95 match response.nl_type {
96 Nlmsg::Error => return Err(GetDeviceError::AccessError),
97 Nlmsg::Done => break,
98 _ => (),
99 };
100
101 let handle = response.get_payload()?.get_attr_handle();
102 device = Some(match device {
103 Some(device) => extend_device(device, handle)?,
104 None => get::Device::try_from(handle)?,
105 });
106 }
107
108 device.ok_or(GetDeviceError::AccessError)
109 }
110
111 pub fn set_device(&mut self, device: set::Device) -> Result<(), SetDeviceError> {
124 for nl_message in create_set_device_messages(device, self.family_id)? {
125 self.sock.send(nl_message)?;
126 self.sock.recv()?;
127 }
128
129 Ok(())
130 }
131}