1use std::fmt::{Debug, Formatter};
2use std::io::{Error, ErrorKind, IoSliceMut};
3use std::mem::MaybeUninit;
4use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
5use std::os::unix::io::{AsRawFd, RawFd};
6use std::{io, mem, ptr};
7
8use socket2::{Domain, Protocol, SockAddr, Socket, Type};
9
10use crate::PktInfo;
11
12unsafe fn setsockopt<T>(
13 socket: libc::c_int,
14 level: libc::c_int,
15 name: libc::c_int,
16 value: T,
17) -> io::Result<()>
18where
19 T: Copy,
20{
21 let value = &value as *const T as *const libc::c_void;
22 if libc::setsockopt(
23 socket,
24 level,
25 name,
26 value,
27 mem::size_of::<T>() as libc::socklen_t,
28 ) == 0
29 {
30 Ok(())
31 } else {
32 Err(Error::last_os_error())
33 }
34}
35
36pub struct PktInfoUdpSocket {
38 socket: Socket,
39 domain: Domain,
40}
41
42impl Debug for PktInfoUdpSocket {
43 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
44 self.socket.fmt(f)
45 }
46}
47
48impl AsRawFd for PktInfoUdpSocket {
49 fn as_raw_fd(&self) -> RawFd {
50 self.socket.as_raw_fd()
51 }
52}
53
54impl PktInfoUdpSocket {
55 pub fn new(domain: Domain) -> io::Result<PktInfoUdpSocket> {
56 let socket = Socket::new(domain, Type::DGRAM, Some(Protocol::UDP))?;
57
58 match domain {
59 Domain::IPV4 => unsafe {
60 setsockopt(socket.as_raw_fd(), libc::IPPROTO_IP, libc::IP_PKTINFO, 1)?;
61 },
62 Domain::IPV6 => unsafe {
63 setsockopt(
64 socket.as_raw_fd(),
65 libc::IPPROTO_IPV6,
66 libc::IPV6_RECVPKTINFO,
67 1,
68 )?;
69 },
70 _ => return Err(Error::from(ErrorKind::Unsupported)),
71 }
72
73 Ok(PktInfoUdpSocket { socket, domain })
74 }
75
76 pub fn domain(&self) -> Domain {
77 self.domain
78 }
79 pub fn set_reuse_address(&self, reuse: bool) -> io::Result<()> {
80 self.socket.set_reuse_address(reuse)
81 }
82
83 pub fn set_reuse_port(&self, reuse: bool) -> io::Result<()> {
84 self.socket.set_reuse_port(reuse)
85 }
86
87 pub fn join_multicast_v4(&self, addr: &Ipv4Addr, interface: &Ipv4Addr) -> io::Result<()> {
88 self.socket.join_multicast_v4(addr, interface)
89 }
90
91 pub fn set_multicast_if_v4(&self, interface: &Ipv4Addr) -> io::Result<()> {
92 self.socket.set_multicast_if_v4(interface)
93 }
94
95 pub fn set_multicast_loop_v4(&self, loop_v4: bool) -> io::Result<()> {
96 self.socket.set_multicast_loop_v4(loop_v4)
97 }
98
99 pub fn join_multicast_v6(&self, addr: &Ipv6Addr, interface: u32) -> io::Result<()> {
100 self.socket.join_multicast_v6(addr, interface)
101 }
102
103 pub fn set_multicast_if_v6(&self, interface: u32) -> io::Result<()> {
104 self.socket.set_multicast_if_v6(interface)
105 }
106
107 pub fn set_multicast_loop_v6(&self, loop_v6: bool) -> io::Result<()> {
108 self.socket.set_multicast_loop_v6(loop_v6)
109 }
110
111 pub fn set_nonblocking(&self, reuse: bool) -> io::Result<()> {
112 self.socket.set_nonblocking(reuse)
113 }
114
115 pub fn bind(&self, addr: &SockAddr) -> io::Result<()> {
116 self.socket.bind(addr)
117 }
118
119 pub fn send(&self, buf: &[u8]) -> io::Result<usize> {
120 self.socket.send(buf)
121 }
122
123 pub fn send_to(&self, buf: &[u8], addr: &SockAddr) -> io::Result<usize> {
124 self.socket.send_to(buf, addr)
125 }
126
127 pub fn recv(&self, buf: &mut [u8]) -> io::Result<(usize, PktInfo)> {
128 let mut addr_src: MaybeUninit<libc::sockaddr_storage> = MaybeUninit::uninit();
129 let mut msg_iov = IoSliceMut::new(buf);
130 let mut cmsg = {
131 let space = unsafe {
132 libc::CMSG_SPACE(mem::size_of::<libc::in_pktinfo>() as libc::c_uint) as usize
133 };
134 Vec::<u8>::with_capacity(space)
135 };
136
137 let mut mhdr = unsafe {
138 let mut mhdr = MaybeUninit::<libc::msghdr>::zeroed();
139 let p = mhdr.as_mut_ptr();
140 (*p).msg_name = addr_src.as_mut_ptr() as *mut libc::c_void;
141 (*p).msg_namelen = mem::size_of::<libc::sockaddr_storage>() as libc::socklen_t;
142 (*p).msg_iov = &mut msg_iov as *mut IoSliceMut as *mut libc::iovec;
143 (*p).msg_iovlen = 1;
144 (*p).msg_control = cmsg.as_mut_ptr() as *mut libc::c_void;
145 (*p).msg_controllen = cmsg.capacity() as _;
146 (*p).msg_flags = 0;
147 mhdr.assume_init()
148 };
149
150 let bytes_recv =
151 unsafe { libc::recvmsg(self.socket.as_raw_fd(), &mut mhdr as *mut libc::msghdr, 0) };
152 if bytes_recv <= 0 {
153 return Err(Error::last_os_error());
154 }
155
156 let addr_src = unsafe {
157 SockAddr::new(
158 addr_src.assume_init(),
159 mem::size_of::<libc::sockaddr_storage>() as _,
160 )
161 }
162 .as_socket()
163 .unwrap();
164
165 let mut header = if mhdr.msg_controllen > 0 {
166 debug_assert!(!mhdr.msg_control.is_null());
167 debug_assert!(cmsg.capacity() >= mhdr.msg_controllen as usize);
168
169 Some(unsafe {
170 libc::CMSG_FIRSTHDR(&mhdr as *const libc::msghdr)
171 .as_ref()
172 .unwrap()
173 })
174 } else {
175 None
176 };
177
178 let mut info: Option<PktInfo> = None;
179 while info.is_none() && header.is_some() {
180 let h = header.unwrap();
181 let p = unsafe { libc::CMSG_DATA(h) };
182
183 match (h.cmsg_level, h.cmsg_type) {
184 (libc::IPPROTO_IP, libc::IP_PKTINFO) => {
185 let pktinfo = unsafe { ptr::read_unaligned(p as *const libc::in_pktinfo) };
186 info = Some(PktInfo {
187 if_index: pktinfo.ipi_ifindex as _,
188 addr_src,
189 addr_dst: IpAddr::V4(Ipv4Addr::from(u32::from_be(pktinfo.ipi_addr.s_addr))),
190 })
191 }
192 (libc::IPPROTO_IPV6, libc::IPV6_PKTINFO) => {
193 let pktinfo = unsafe { ptr::read_unaligned(p as *const libc::in6_pktinfo) };
194
195 info = Some(PktInfo {
196 if_index: pktinfo.ipi6_ifindex as _,
197 addr_src,
198 addr_dst: IpAddr::V6(Ipv6Addr::from(pktinfo.ipi6_addr.s6_addr)),
199 })
200 }
201 _ => {
202 header = unsafe {
203 let p = libc::CMSG_NXTHDR(&mhdr as *const _, h as *const _);
204 p.as_ref()
205 };
206 }
207 }
208 }
209
210 match info {
211 None => Err(Error::new(
212 ErrorKind::NotFound,
213 "Failed to read PKTINFO from socket",
214 )),
215 Some(info) => Ok((bytes_recv as _, info)),
216 }
217 }
218}