socket_pktinfo/
unix.rs

1use std::fmt::{Debug, Formatter};
2use std::io::{Error, ErrorKind, IoSliceMut};
3use std::mem::MaybeUninit;
4use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
5use std::os::unix::io::{AsRawFd, RawFd};
6use std::{io, mem, ptr};
7
8use socket2::{Domain, Protocol, SockAddr, Socket, Type};
9
10use crate::PktInfo;
11
12unsafe fn setsockopt<T>(
13    socket: libc::c_int,
14    level: libc::c_int,
15    name: libc::c_int,
16    value: T,
17) -> io::Result<()>
18where
19    T: Copy,
20{
21    let value = &value as *const T as *const libc::c_void;
22    if libc::setsockopt(
23        socket,
24        level,
25        name,
26        value,
27        mem::size_of::<T>() as libc::socklen_t,
28    ) == 0
29    {
30        Ok(())
31    } else {
32        Err(Error::last_os_error())
33    }
34}
35
36//
37pub struct PktInfoUdpSocket {
38    socket: Socket,
39    domain: Domain,
40}
41
42impl Debug for PktInfoUdpSocket {
43    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
44        self.socket.fmt(f)
45    }
46}
47
48impl AsRawFd for PktInfoUdpSocket {
49    fn as_raw_fd(&self) -> RawFd {
50        self.socket.as_raw_fd()
51    }
52}
53
54impl PktInfoUdpSocket {
55    pub fn new(domain: Domain) -> io::Result<PktInfoUdpSocket> {
56        let socket = Socket::new(domain, Type::DGRAM, Some(Protocol::UDP))?;
57
58        match domain {
59            Domain::IPV4 => unsafe {
60                setsockopt(socket.as_raw_fd(), libc::IPPROTO_IP, libc::IP_PKTINFO, 1)?;
61            },
62            Domain::IPV6 => unsafe {
63                setsockopt(
64                    socket.as_raw_fd(),
65                    libc::IPPROTO_IPV6,
66                    libc::IPV6_RECVPKTINFO,
67                    1,
68                )?;
69            },
70            _ => return Err(Error::from(ErrorKind::Unsupported)),
71        }
72
73        Ok(PktInfoUdpSocket { socket, domain })
74    }
75
76    pub fn domain(&self) -> Domain {
77        self.domain
78    }
79    pub fn set_reuse_address(&self, reuse: bool) -> io::Result<()> {
80        self.socket.set_reuse_address(reuse)
81    }
82
83    pub fn set_reuse_port(&self, reuse: bool) -> io::Result<()> {
84        self.socket.set_reuse_port(reuse)
85    }
86
87    pub fn join_multicast_v4(&self, addr: &Ipv4Addr, interface: &Ipv4Addr) -> io::Result<()> {
88        self.socket.join_multicast_v4(addr, interface)
89    }
90
91    pub fn set_multicast_if_v4(&self, interface: &Ipv4Addr) -> io::Result<()> {
92        self.socket.set_multicast_if_v4(interface)
93    }
94
95    pub fn set_multicast_loop_v4(&self, loop_v4: bool) -> io::Result<()> {
96        self.socket.set_multicast_loop_v4(loop_v4)
97    }
98
99    pub fn join_multicast_v6(&self, addr: &Ipv6Addr, interface: u32) -> io::Result<()> {
100        self.socket.join_multicast_v6(addr, interface)
101    }
102
103    pub fn set_multicast_if_v6(&self, interface: u32) -> io::Result<()> {
104        self.socket.set_multicast_if_v6(interface)
105    }
106
107    pub fn set_multicast_loop_v6(&self, loop_v6: bool) -> io::Result<()> {
108        self.socket.set_multicast_loop_v6(loop_v6)
109    }
110
111    pub fn set_nonblocking(&self, reuse: bool) -> io::Result<()> {
112        self.socket.set_nonblocking(reuse)
113    }
114
115    pub fn bind(&self, addr: &SockAddr) -> io::Result<()> {
116        self.socket.bind(addr)
117    }
118
119    pub fn send(&self, buf: &[u8]) -> io::Result<usize> {
120        self.socket.send(buf)
121    }
122
123    pub fn send_to(&self, buf: &[u8], addr: &SockAddr) -> io::Result<usize> {
124        self.socket.send_to(buf, addr)
125    }
126
127    pub fn recv(&self, buf: &mut [u8]) -> io::Result<(usize, PktInfo)> {
128        let mut addr_src: MaybeUninit<libc::sockaddr_storage> = MaybeUninit::uninit();
129        let mut msg_iov = IoSliceMut::new(buf);
130        let mut cmsg = {
131            let space = unsafe {
132                libc::CMSG_SPACE(mem::size_of::<libc::in_pktinfo>() as libc::c_uint) as usize
133            };
134            Vec::<u8>::with_capacity(space)
135        };
136
137        let mut mhdr = unsafe {
138            let mut mhdr = MaybeUninit::<libc::msghdr>::zeroed();
139            let p = mhdr.as_mut_ptr();
140            (*p).msg_name = addr_src.as_mut_ptr() as *mut libc::c_void;
141            (*p).msg_namelen = mem::size_of::<libc::sockaddr_storage>() as libc::socklen_t;
142            (*p).msg_iov = &mut msg_iov as *mut IoSliceMut as *mut libc::iovec;
143            (*p).msg_iovlen = 1;
144            (*p).msg_control = cmsg.as_mut_ptr() as *mut libc::c_void;
145            (*p).msg_controllen = cmsg.capacity() as _;
146            (*p).msg_flags = 0;
147            mhdr.assume_init()
148        };
149
150        let bytes_recv =
151            unsafe { libc::recvmsg(self.socket.as_raw_fd(), &mut mhdr as *mut libc::msghdr, 0) };
152        if bytes_recv <= 0 {
153            return Err(Error::last_os_error());
154        }
155
156        let addr_src = unsafe {
157            SockAddr::new(
158                addr_src.assume_init(),
159                mem::size_of::<libc::sockaddr_storage>() as _,
160            )
161        }
162        .as_socket()
163        .unwrap();
164
165        let mut header = if mhdr.msg_controllen > 0 {
166            debug_assert!(!mhdr.msg_control.is_null());
167            debug_assert!(cmsg.capacity() >= mhdr.msg_controllen as usize);
168
169            Some(unsafe {
170                libc::CMSG_FIRSTHDR(&mhdr as *const libc::msghdr)
171                    .as_ref()
172                    .unwrap()
173            })
174        } else {
175            None
176        };
177
178        let mut info: Option<PktInfo> = None;
179        while info.is_none() && header.is_some() {
180            let h = header.unwrap();
181            let p = unsafe { libc::CMSG_DATA(h) };
182
183            match (h.cmsg_level, h.cmsg_type) {
184                (libc::IPPROTO_IP, libc::IP_PKTINFO) => {
185                    let pktinfo = unsafe { ptr::read_unaligned(p as *const libc::in_pktinfo) };
186                    info = Some(PktInfo {
187                        if_index: pktinfo.ipi_ifindex as _,
188                        addr_src,
189                        addr_dst: IpAddr::V4(Ipv4Addr::from(u32::from_be(pktinfo.ipi_addr.s_addr))),
190                    })
191                }
192                (libc::IPPROTO_IPV6, libc::IPV6_PKTINFO) => {
193                    let pktinfo = unsafe { ptr::read_unaligned(p as *const libc::in6_pktinfo) };
194
195                    info = Some(PktInfo {
196                        if_index: pktinfo.ipi6_ifindex as _,
197                        addr_src,
198                        addr_dst: IpAddr::V6(Ipv6Addr::from(pktinfo.ipi6_addr.s6_addr)),
199                    })
200                }
201                _ => {
202                    header = unsafe {
203                        let p = libc::CMSG_NXTHDR(&mhdr as *const _, h as *const _);
204                        p.as_ref()
205                    };
206                }
207            }
208        }
209
210        match info {
211            None => Err(Error::new(
212                ErrorKind::NotFound,
213                "Failed to read PKTINFO from socket",
214            )),
215            Some(info) => Ok((bytes_recv as _, info)),
216        }
217    }
218}