use std::{
collections::HashMap,
net::{IpAddr, SocketAddr},
str::FromStr,
};
use super::cerr;
/// Enumerates the system's network interfaces, merging every `getifaddrs(3)` entry
/// that shares an interface name into one [`InterfaceData`].
///
/// # Errors
///
/// Propagates any error from creating the underlying interface iterator.
///
/// # Panics
///
/// Panics if two entries for the same name both carry a MAC address.
pub fn interfaces() -> std::io::Result<HashMap<InterfaceName, InterfaceData>> {
    let grouped = InterfaceIterator::new()?.fold(
        HashMap::<InterfaceName, InterfaceData>::default(),
        |mut acc, data| {
            let current = acc.entry(data.name).or_default();
            current.socket_addrs.extend(data.socket_addr);
            // An interface is expected to report its link-layer address at most once.
            assert!(!(current.mac.is_some() && data.mac.is_some()));
            current.mac = current.mac.or(data.mac);
            acc
        },
    );
    Ok(grouped)
}
/// Addresses collected for a single network interface.
#[derive(Default, Debug)]
pub struct InterfaceData {
    /// Every IP socket address reported for the interface.
    socket_addrs: Vec<SocketAddr>,
    /// Link-layer (MAC) address, if one was reported.
    mac: Option<[u8; 6]>,
}

impl InterfaceData {
    /// Returns `true` if any socket address of this interface has the IP `address`.
    /// Ports are ignored.
    pub fn has_ip_addr(&self, address: IpAddr) -> bool {
        self.socket_addrs
            .iter()
            .map(SocketAddr::ip)
            .any(|ip| ip == address)
    }

    /// Returns the interface's MAC address, if known.
    pub fn mac(&self) -> Option<[u8; 6]> {
        self.mac
    }
}
/// Iterator over the linked list produced by `getifaddrs(3)`.
///
/// The list is owned by this value and released with `freeifaddrs(3)` on drop.
struct InterfaceIterator {
    /// Head of the list, kept so it can be freed in `Drop`. Null when the OS
    /// reported no entries at all.
    base: *mut libc::ifaddrs,
    /// Next entry to yield; null once the list is exhausted.
    next: *mut libc::ifaddrs,
}

impl InterfaceIterator {
    /// Takes a snapshot of the system's interface addresses via `getifaddrs(3)`.
    ///
    /// An empty list (null head) is valid and yields an iterator that produces
    /// nothing, rather than panicking.
    ///
    /// # Errors
    ///
    /// Returns the error produced by `cerr` when `getifaddrs` fails.
    pub fn new() -> std::io::Result<Self> {
        let mut addrs: *mut libc::ifaddrs = std::ptr::null_mut();
        // SAFETY: `addrs` is a valid out-pointer. On success, getifaddrs stores the
        // head of a heap-allocated list (possibly null), which we now own.
        cerr(unsafe { libc::getifaddrs(&mut addrs) })?;
        Ok(Self {
            base: addrs,
            next: addrs,
        })
    }
}

impl Drop for InterfaceIterator {
    fn drop(&mut self) {
        if !self.base.is_null() {
            // SAFETY: `base` came from a successful getifaddrs call and is freed
            // exactly once, here.
            unsafe { libc::freeifaddrs(self.base) };
        }
    }
}
/// One decoded entry of the `getifaddrs(3)` list.
struct InterfaceDataInternal {
    /// Name of the interface this entry belongs to.
    name: InterfaceName,
    /// First six bytes of the link-layer address; set only for `AF_PACKET` entries.
    mac: Option<[u8; 6]>,
    /// The entry's address; set only when `sockaddr_to_socket_addr` recognises it.
    socket_addr: Option<SocketAddr>,
}

impl Iterator for InterfaceIterator {
    type Item = InterfaceDataInternal;

    /// Decodes the next `ifaddrs` entry, or returns `None` at the end of the list.
    ///
    /// # Panics
    ///
    /// Panics if the OS reports an interface name that is not valid UTF-8, or one
    /// that `InterfaceName::from_str` rejects.
    fn next(&mut self) -> Option<<Self as Iterator>::Item> {
        // SAFETY: `self.next` is null or points to a live entry of the list owned by `self`.
        let ifaddr = unsafe { self.next.as_ref() }?;
        self.next = ifaddr.ifa_next;

        // SAFETY: `ifa_name` is a NUL-terminated string owned by the same list.
        let ifname = unsafe { std::ffi::CStr::from_ptr(ifaddr.ifa_name) };
        let name = match std::str::from_utf8(ifname.to_bytes()) {
            Err(_) => unreachable!("interface names must be ascii"),
            Ok(name) => InterfaceName::from_str(name).expect("name from os"),
        };

        // `ifa_addr` may be a null pointer (see getifaddrs(3)). Such an entry has
        // neither a MAC nor an IP address and must not be dereferenced.
        // SAFETY: a non-null `ifa_addr` points to a sockaddr owned by the list.
        let (mac, socket_addr) = match unsafe { ifaddr.ifa_addr.as_ref() } {
            None => (None, None),
            Some(addr) => {
                let mac = if addr.sa_family as libc::c_int == libc::AF_PACKET {
                    // SAFETY: AF_PACKET entries carry a `sockaddr_ll`. `read_unaligned`
                    // makes no alignment assumption about the list's storage.
                    let sockaddr_ll: libc::sockaddr_ll = unsafe {
                        std::ptr::read_unaligned(ifaddr.ifa_addr.cast::<libc::sockaddr_ll>())
                    };
                    let mut mac = [0u8; 6];
                    mac.copy_from_slice(&sockaddr_ll.sll_addr[..6]);
                    Some(mac)
                } else {
                    None
                };
                // SAFETY: `ifa_addr` is non-null here and its layout matches its family.
                let socket_addr = unsafe { sockaddr_to_socket_addr(ifaddr.ifa_addr) };
                (mac, socket_addr)
            }
        };

        Some(InterfaceDataInternal {
            name,
            mac,
            socket_addr,
        })
    }
}
#[derive(Clone, Copy, PartialEq, Eq, Hash)]
pub struct InterfaceName {
bytes: [u8; libc::IFNAMSIZ],
}
impl InterfaceName {
#[cfg(test)]
pub const LOOPBACK: Self = Self {
bytes: *b"lo\0\0\0\0\0\0\0\0\0\0\0\0\0\0",
};
#[cfg(test)]
pub const INVALID: Self = Self {
bytes: *b"123412341234123\0",
};
pub fn as_str(&self) -> &str {
std::str::from_utf8(self.bytes.as_slice())
.unwrap_or_default()
.trim_end_matches('\0')
}
pub fn as_cstr(&self) -> &std::ffi::CStr {
let first_null = self.bytes.iter().position(|b| *b == 0).unwrap();
std::ffi::CStr::from_bytes_with_nul(&self.bytes[..=first_null]).unwrap()
}
pub fn to_ifr_name(self) -> [libc::c_char; libc::IFNAMSIZ] {
let mut it = self.bytes.iter().copied();
[0; libc::IFNAMSIZ].map(|_| it.next().unwrap_or(0) as libc::c_char)
}
pub fn from_socket_addr(local_addr: SocketAddr) -> std::io::Result<Option<Self>> {
let matches_inferface = |interface: &InterfaceDataInternal| match interface.socket_addr {
None => false,
Some(address) => address.ip() == local_addr.ip(),
};
match InterfaceIterator::new()?.find(matches_inferface) {
Some(interface) => Ok(Some(interface.name)),
None => Ok(None),
}
}
pub fn get_index(&self) -> Option<libc::c_uint> {
match unsafe { libc::if_nametoindex(self.as_cstr().as_ptr()) } {
0 => None,
n => Some(n),
}
}
}
impl std::fmt::Debug for InterfaceName {
    /// Formats as `InterfaceName("eth0")`, showing the textual name rather than
    /// the raw padded byte buffer.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name: &str = self.as_str();
        let mut tuple = f.debug_tuple("InterfaceName");
        tuple.field(&name);
        tuple.finish()
    }
}
impl std::fmt::Display for InterfaceName {
    /// Writes the bare interface name (e.g. `eth0`), honouring width, fill,
    /// alignment and precision flags exactly as formatting a `&str` does.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = self.as_str();
        f.pad(name)
    }
}
impl std::str::FromStr for InterfaceName {
    type Err = ();

    /// Parses an interface name such as `"eth0"`.
    ///
    /// # Errors
    ///
    /// Returns `Err(())` if `s`:
    /// - is empty;
    /// - is `IFNAMSIZ` bytes or longer, leaving no room for the terminating NUL; or
    /// - contains a NUL byte.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let mut bytes = [0; libc::IFNAMSIZ];
        if s.is_empty() || s.len() >= bytes.len() {
            return Err(());
        }
        // An embedded NUL would make `as_cstr` (and therefore `if_nametoindex`) see a
        // shorter name than `as_str` and `Eq` do, so the value would silently refer to
        // a different interface.
        if s.bytes().any(|b| b == 0) {
            return Err(());
        }
        bytes[..s.len()].copy_from_slice(s.as_bytes());
        Ok(Self { bytes })
    }
}
#[cfg(feature = "serde")]
impl<'de> serde::Deserialize<'de> for InterfaceName {
    /// Deserializes from a string, applying the same validation as `FromStr`.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        let raw: String = serde::Deserialize::deserialize(deserializer)?;
        match raw.parse::<InterfaceName>() {
            Ok(name) => Ok(name),
            Err(()) => Err(serde::de::Error::custom("invalid interface name")),
        }
    }
}
/// Which IP version a socket operates in.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LinuxNetworkMode {
    /// IPv4.
    Ipv4,
    /// IPv6.
    Ipv6,
}

impl LinuxNetworkMode {
    /// Returns the unspecified (wildcard) address for this IP version:
    /// `0.0.0.0` for IPv4, `::` for IPv6.
    pub fn unspecified_ip_addr(&self) -> IpAddr {
        match *self {
            Self::Ipv4 => std::net::Ipv4Addr::UNSPECIFIED.into(),
            Self::Ipv6 => std::net::Ipv6Addr::UNSPECIFIED.into(),
        }
    }
}
/// Converts a raw C socket address into a [`SocketAddr`].
///
/// Returns `None` if `sockaddr` is null, or if its `sa_family` is neither `AF_INET`
/// nor `AF_INET6`.
///
/// # Safety
///
/// `sockaddr` must either be null or point to readable memory that:
/// - is aligned for `libc::sockaddr`; and
/// - matches its `sa_family`: a complete `sockaddr_in` for `AF_INET` and a complete
///   `sockaddr_in6` for `AF_INET6`.
unsafe fn sockaddr_to_socket_addr(sockaddr: *const libc::sockaddr) -> Option<SocketAddr> {
    // SAFETY: per the contract, a non-null `sockaddr` is aligned and readable;
    // `as_ref` maps null to `None` instead of dereferencing it.
    let family = unsafe { sockaddr.as_ref() }?.sa_family as libc::c_int;
    match family {
        libc::AF_INET => {
            // SAFETY: the family is AF_INET, so the caller guarantees a full
            // `sockaddr_in`. `read_unaligned` imposes no extra alignment requirement.
            let inaddr: libc::sockaddr_in =
                unsafe { std::ptr::read_unaligned(sockaddr.cast::<libc::sockaddr_in>()) };
            // `s_addr` and `sin_port` hold network-byte-order values, so their
            // in-memory (native) bytes are the big-endian representation.
            let ip = std::net::Ipv4Addr::from(inaddr.sin_addr.s_addr.to_ne_bytes());
            let port = u16::from_be_bytes(inaddr.sin_port.to_ne_bytes());
            Some(SocketAddr::V4(std::net::SocketAddrV4::new(ip, port)))
        }
        libc::AF_INET6 => {
            // SAFETY: the family is AF_INET6, so the caller guarantees a full
            // `sockaddr_in6`. `read_unaligned` imposes no extra alignment requirement.
            let inaddr: libc::sockaddr_in6 =
                unsafe { std::ptr::read_unaligned(sockaddr.cast::<libc::sockaddr_in6>()) };
            let ip = std::net::Ipv6Addr::from(inaddr.sin6_addr.s6_addr);
            let port = u16::from_be_bytes(inaddr.sin6_port.to_ne_bytes());
            // Flow info and scope id are passed through without byte-order conversion.
            Some(SocketAddr::V6(std::net::SocketAddrV6::new(
                ip,
                port,
                inaddr.sin6_flowinfo,
                inaddr.sin6_scope_id,
            )))
        }
        _ => None,
    }
}
/// Converts a `sockaddr_storage` into a [`SocketAddr`].
///
/// Returns `None` unless the stored family is `AF_INET` or `AF_INET6`.
pub fn sockaddr_storage_to_socket_addr(
    sockaddr_storage: &libc::sockaddr_storage,
) -> Option<SocketAddr> {
    let sockaddr: *const libc::sockaddr =
        (sockaddr_storage as *const libc::sockaddr_storage).cast();
    // SAFETY: the reference guarantees a non-null, aligned pointer to a
    // `sockaddr_storage`, which is large enough to hold any address family.
    unsafe { sockaddr_to_socket_addr(sockaddr) }
}
// These tests query the host's real interface list. They assume a Linux host that
// has a loopback interface `lo` with 127.0.0.1 assigned.
#[cfg(test)]
mod tests {
use std::net::Ipv4Addr;
use super::*;
#[test]
fn interface_name_from_string() {
// Empty names, and names with no room for the terminating NUL, are rejected.
assert!(InterfaceName::from_str("").is_err());
assert!(InterfaceName::from_str("a string that is too long").is_err());
let input = "enp0s31f6";
assert_eq!(InterfaceName::from_str(input).unwrap().as_str(), input);
// `to_ifr_name` yields the name NUL-padded to IFNAMSIZ `c_char`s.
let ifr_name = (*b"enp0s31f6\0\0\0\0\0\0\0").map(|b| b as libc::c_char);
assert_eq!(
InterfaceName::from_str(input).unwrap().to_ifr_name(),
ifr_name
);
}
#[test]
fn test_mac_address_iterator() {
// At least one entry must carry a link-layer (AF_PACKET) address.
let v: Vec<_> = InterfaceIterator::new()
.unwrap()
.filter_map(|d| d.mac)
.collect();
assert!(!v.is_empty());
}
#[test]
fn test_interface_name_iterator() {
// The loopback interface is expected to appear among the reported names.
let v: Vec<_> = InterfaceIterator::new().unwrap().map(|d| d.name).collect();
assert!(v.contains(&InterfaceName::LOOPBACK));
}
#[test]
fn test_socket_addr_iterator() {
// Interface addresses are expected to carry port 0, so 127.0.0.1:0 should be present.
let v: Vec<_> = InterfaceIterator::new()
.unwrap()
.filter_map(|d| d.socket_addr)
.collect();
let localhost_0 = SocketAddr::from((Ipv4Addr::LOCALHOST, 0));
assert!(v.contains(&localhost_0));
}
#[test]
fn interface_index_ipv4() {
assert!(InterfaceName::LOOPBACK.get_index().is_some());
}
// NOTE(review): identical to `interface_index_ipv4`; nothing IPv6-specific is exercised here.
#[test]
fn interface_index_ipv6() {
assert!(InterfaceName::LOOPBACK.get_index().is_some());
}
// `INVALID` is a maximal-length name assumed not to exist on the host.
#[test]
fn interface_index_invalid() {
assert!(InterfaceName::INVALID.get_index().is_none());
}
}