1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
use crate::get;
use crate::linux::attr::WgDeviceAttribute;
use crate::linux::cmd::WgCmd;
use crate::linux::consts::{WG_GENL_NAME, WG_GENL_VERSION};
use crate::linux::err::{ConnectError, GetDeviceError, SetDeviceError};
use crate::linux::set;
use crate::linux::set::create_set_device_messages;
use crate::linux::socket::parse::*;
use crate::linux::socket::NlWgMsgType;
use crate::linux::DeviceInterface;
use libc::IFNAMSIZ;
use neli::consts::{NlFamily, NlmF, Nlmsg};
use neli::genl::Genlmsghdr;
use neli::nl::Nlmsghdr;
use neli::nlattr::Nlattr;
use neli::socket::NlSocket;
use neli::Nl;
use neli::StreamWriteBuffer;

/// A generic netlink socket used to issue WireGuard commands to the kernel.
///
/// Construct one with [`WgSocket::connect`].
pub struct WgSocket {
    /// Bound generic netlink socket that every request is sent on and every reply read from.
    sock: NlSocket,
    /// Resolved id of the WireGuard generic netlink family; used as the `nl_type` of requests.
    family_id: NlWgMsgType,
}

impl WgSocket {
    /// Opens a generic netlink socket for issuing WireGuard commands.
    ///
    /// The numeric id of the `WG_GENL_NAME` generic netlink family is looked up first, using a
    /// short-lived helper socket, and stored so later requests can be addressed to it. A second
    /// socket (with sequence tracking enabled) is then created and bound with a kernel-selected
    /// port id.
    ///
    /// # Errors
    ///
    /// Returns a [`ConnectError`] if either socket cannot be created or bound, or
    /// [`ConnectError::ResolveFamilyError`] if the WireGuard family cannot be resolved.
    pub fn connect() -> Result<Self, ConnectError> {
        // The helper socket is a temporary and is dropped at the end of this statement.
        let family_id = NlSocket::new(NlFamily::Generic, true)?
            .resolve_genl_family(WG_GENL_NAME)
            .map_err(ConnectError::ResolveFamilyError)?;

        let track_seq = true;
        let mut wgsock = NlSocket::new(NlFamily::Generic, track_seq)?;

        // Autoselect a PID; no multicast groups are requested.
        let pid = None;
        let groups = None;
        wgsock.bind(pid, groups)?;

        Ok(Self {
            sock: wgsock,
            family_id,
        })
    }

    /// Fetches the current state of a WireGuard device, identified by name or by index.
    ///
    /// Sends a `WgCmd::GetDevice` dump request and assembles the device from every reply
    /// message: the first reply is parsed into a device and each later one extends it.
    ///
    /// # Errors
    ///
    /// - [`GetDeviceError::InvalidInterfaceName`] if a name is empty or not shorter than
    ///   `IFNAMSIZ` bytes.
    /// - [`GetDeviceError::AccessError`] if the kernel answers with an error message, or if no
    ///   device data could be received at all.
    /// - A converted netlink error if receiving fails after part of the device was already
    ///   parsed (rather than returning a truncated device).
    /// - Serialization, send, or parse errors converted into [`GetDeviceError`].
    pub fn get_device(
        &mut self,
        interface: DeviceInterface,
    ) -> Result<get::Device, GetDeviceError> {
        let mut mem = StreamWriteBuffer::new_growable(None);
        let attr = match interface {
            DeviceInterface::Name(name) => {
                // Linux interface names hold at most IFNAMSIZ - 1 bytes (IFNAMSIZ counts the
                // trailing NUL).
                let len = name.len();
                if len == 0 || len >= IFNAMSIZ {
                    return Err(GetDeviceError::InvalidInterfaceName);
                }
                name.as_ref().serialize(&mut mem)?;
                Nlattr::new(None, WgDeviceAttribute::Ifname, mem.as_ref())?
            }
            DeviceInterface::Index(index) => {
                index.serialize(&mut mem)?;
                Nlattr::new(None, WgDeviceAttribute::Ifindex, mem.as_ref())?
            }
        };
        let genlhdr = {
            let cmd = WgCmd::GetDevice;
            let version = WG_GENL_VERSION;
            let attrs = vec![attr];
            Genlmsghdr::new(cmd, version, attrs)?
        };
        let nlhdr = {
            let size = None;
            let nl_type = self.family_id;
            let flags = vec![NlmF::Request, NlmF::Ack, NlmF::Dump];
            let seq = None;
            let pid = None;
            let payload = genlhdr;
            Nlmsghdr::new(size, nl_type, flags, seq, pid, payload)
        };

        self.sock.send_nl(nlhdr)?;

        let responses = self
            .sock
            .iter::<Nlmsg, Genlmsghdr<WgCmd, WgDeviceAttribute>>();

        let mut device = None;
        for response in responses {
            let response = match response {
                Ok(response) => response,
                // Nothing parsed yet: report an access failure, like an `Nlmsg::Error` reply.
                Err(_) if device.is_none() => return Err(GetDeviceError::AccessError),
                // Part of the device has already been parsed; returning it would silently hand
                // the caller an incomplete device, so surface the receive error instead.
                Err(err) => return Err(err.into()),
            };

            match response.nl_type {
                Nlmsg::Error => return Err(GetDeviceError::AccessError),
                Nlmsg::Done => break,
                _ => (),
            }

            let handle = response.nl_payload.get_attr_handle();
            device = Some(match device {
                Some(device) => extend_device(device, handle)?,
                None => parse_device(handle)?,
            });
        }

        device.ok_or(GetDeviceError::AccessError)
    }

    /// Applies the configuration in `device` to an existing WireGuard device interface.
    ///
    /// This assumes that the device interface has already been created. Otherwise an error will
    /// be returned. You can create a new device interface with
    /// [`RouteSocket::add_device`](./struct.RouteSocket.html#method.add_device).
    ///
    /// The configuration may be split across several netlink messages; each one is sent and
    /// acknowledged by the kernel before the next one is sent.
    ///
    /// The peers in this device won't be reachable at their allowed IPs until routes for those
    /// IPs are added to the device interface through a Netlink Route message. This library
    /// doesn't have a built-in way to do that right now. Here's how it would be done with the
    /// `ip` command:
    ///
    /// ```sh
    ///  sudo ip -4 route add 127.3.1.1/32 dev wgtest0
    /// ```
    ///
    /// # Errors
    ///
    /// Returns a [`SetDeviceError`] if the messages cannot be built, or if sending a message or
    /// receiving its acknowledgement fails.
    pub fn set_device(&mut self, device: set::Device) -> Result<(), SetDeviceError> {
        for nl_message in create_set_device_messages(device, self.family_id)? {
            self.sock.send_nl(nl_message)?;
            self.sock.recv_ack()?;
        }

        Ok(())
    }
}