Skip to main content

tokio_seqpacket/ancillary/
reader.rs

1use std::os::fd::{OwnedFd, BorrowedFd};
2
3use super::FD_SIZE;
4
/// Reader to parse received ancillary messages from a Unix socket.
///
/// The reader owns the objects carried by the messages (such as file descriptors).
/// Objects that are not claimed through [`AncillaryMessageReader::into_messages`]
/// are released when the reader is dropped.
///
/// # Example
/// ```no_run
/// use tokio_seqpacket::UnixSeqpacket;
/// use tokio_seqpacket::ancillary::AncillaryMessage;
/// use std::io::IoSliceMut;
/// use std::os::fd::AsRawFd;
///
/// #[tokio::main]
/// async fn main() -> std::io::Result<()> {
///     let sock = UnixSeqpacket::connect("/tmp/sock").await?;
///
///     let mut ancillary_buffer = [0; 128];
///
///     let mut buf = [1; 8];
///     let mut bufs = [IoSliceMut::new(&mut buf)];
///     let (_read, ancillary) = sock.recv_vectored_with_ancillary(&mut bufs, &mut ancillary_buffer).await?;
///
///     for message in ancillary.messages() {
///         if let AncillaryMessage::FileDescriptors(fds) = message {
///             for fd in fds {
///                 println!("received file descriptor: {}", fd.as_raw_fd());
///             }
///         }
///     }
///     Ok(())
/// }
/// ```
#[derive(Debug)]
pub struct AncillaryMessageReader<'a> {
	/// The received control-message bytes. `into_messages()` leaves this empty after taking them.
	pub(crate) buffer: &'a mut [u8],

	/// Whether the ancillary data was truncated during the receive operation.
	pub(crate) truncated: bool,
}
40
41/// Iterator over ancillary messages from a [`AncillaryMessageReader`].
42#[derive(Copy, Clone)]
43pub struct AncillaryMessages<'a> {
44	buffer: &'a [u8],
45	current: Option<&'a libc::cmsghdr>,
46}
47
48/// Owning iterator over ancillary messages from a [`AncillaryMessageReader`].
49pub struct IntoAncillaryMessages<'a> {
50	buffer: &'a mut [u8],
51	current: Option<&'a libc::cmsghdr>,
52}
53
54/// This enum represent one control message of variable type.
55pub enum AncillaryMessage<'a> {
56	/// Ancillary message holding file descriptors.
57	FileDescriptors(FileDescriptors<'a>),
58
59	/// Ancillary message holding unix credentials.
60	#[cfg(any(doc, target_os = "android", target_os = "linux", target_os = "netbsd",))]
61	Credentials(UnixCredentials<'a>),
62
63	/// Ancillary message uninterpreted data.
64	Other(UnknownMessage<'a>)
65}
66
67/// This enum represent one control message of variable type.
68///
69/// Where applicable, it has taken ownership of the objects in the control message.
70pub enum OwnedAncillaryMessage<'a> {
71	/// Ancillary message holding file descriptors.
72	FileDescriptors(OwnedFileDescriptors<'a>),
73
74	/// Ancillary message holding unix credentials.
75	#[cfg(any(doc, target_os = "android", target_os = "linux", target_os = "netbsd",))]
76	Credentials(UnixCredentials<'a>),
77
78	/// Ancillary message uninterpreted data.
79	Other(UnknownMessage<'a>)
80}
81
/// An `SCM_RIGHTS` payload viewed as borrowed file descriptors.
///
/// The descriptors stay owned by the reader the message came from.
#[derive(Clone, Copy)]
pub struct FileDescriptors<'a> {
	/// Raw payload bytes: the not-yet-iterated descriptors, back to back.
	data: &'a [u8],
}
88
/// An `SCM_RIGHTS` payload whose file descriptors are owned by this value.
///
/// Iterating yields each descriptor as an `OwnedFd`. Descriptors that are not taken are
/// closed when this value is dropped.
pub struct OwnedFileDescriptors<'a> {
	/// Raw payload bytes: the descriptors that have not been handed out yet.
	data: &'a mut [u8],
}
94
/// A control message carrying the Unix credentials of the sending process.
#[derive(Clone, Copy)]
#[cfg(any(doc, target_os = "linux", target_os = "android", target_os = "netbsd"))]
pub struct UnixCredentials<'a> {
	/// Raw payload bytes: the not-yet-iterated credential records, back to back.
	data: &'a [u8],
}
102
/// A control message whose level/type combination this crate does not interpret.
#[derive(Clone, Copy)]
pub struct UnknownMessage<'a> {
	/// Protocol level of the message, copied from the header's `cmsg_level`.
	cmsg_level: i32,

	/// Protocol-specific type of the message, copied from the header's `cmsg_type`.
	cmsg_type: i32,

	/// Uninterpreted payload bytes that follow the header.
	data: &'a [u8],
}
115
116impl<'a> AncillaryMessageReader<'a> {
117	/// Create an ancillary data with the given buffer.
118	///
119	/// # Safety
120	/// The memory buffer must contain valid ancillary messages received from the kernel for a Unix socket.
121	///
122	/// The created reader assumes ownership of objects (such as file descriptors) within the message.
123	/// Because of this, you may only create one ancillary message reader for any ancillary message received from the kernel.
124	/// You must also ensure that no other object assumes ownership of the objects within the message.
125	pub unsafe fn new(buffer: &'a mut [u8], truncated: bool) -> Self {
126		Self { buffer, truncated }
127	}
128
129	/// Returns the number of used bytes.
130	pub fn len(&self) -> usize {
131		self.buffer.len()
132	}
133
134	/// Returns `true` if the ancillary data is empty.
135	pub fn is_empty(&self) -> bool {
136		self.buffer.is_empty()
137	}
138
139	/// Is `true` if during a recv operation the ancillary message was truncated.
140	///
141	/// # Example
142	///
143	/// ```no_run
144	/// use tokio_seqpacket::UnixSeqpacket;
145	/// use tokio_seqpacket::ancillary::AncillaryMessageReader;
146	/// use std::io::IoSliceMut;
147	///
148	/// #[tokio::main]
149	/// async fn main() -> std::io::Result<()> {
150	///     let sock = UnixSeqpacket::connect("/tmp/sock").await?;
151	///
152	///     let mut ancillary_buffer = [0; 128];
153	///
154	///     let mut buf = [1; 8];
155	///     let mut bufs = &mut [IoSliceMut::new(&mut buf)];
156	///     let (_read, ancillary) = sock.recv_vectored_with_ancillary(bufs, &mut ancillary_buffer).await?;
157	///
158	///     println!("Is truncated: {}", ancillary.is_truncated());
159	///     Ok(())
160	/// }
161	/// ```
162	pub fn is_truncated(&self) -> bool {
163		self.truncated
164	}
165
166	/// Returns the iterator of the control messages.
167	pub fn messages(&self) -> AncillaryMessages<'_> {
168		AncillaryMessages { buffer: self.buffer, current: None }
169	}
170
171	/// Consume the ancillary message to take ownership of the contained objects (such as file descriptors).
172	pub fn into_messages(mut self) -> IntoAncillaryMessages<'a> {
173		let buffer = std::mem::take(&mut self.buffer);
174		IntoAncillaryMessages { buffer, current: None }
175	}
176}
177
178impl Drop for AncillaryMessageReader<'_> {
179	fn drop(&mut self) {
180		if !self.is_empty() {
181			drop(IntoAncillaryMessages { buffer: self.buffer, current: None })
182		}
183	}
184}
185
186impl<'a> Iterator for AncillaryMessages<'a> {
187	type Item = AncillaryMessage<'a>;
188
189	fn next(&mut self) -> Option<Self::Item> {
190		if self.buffer.is_empty() {
191			return None;
192		}
193		unsafe {
194			let mut msg: libc::msghdr = std::mem::zeroed();
195			msg.msg_control = self.buffer.as_ptr() as *mut _;
196			msg.msg_controllen = self.buffer.len() as _;
197
198			let cmsg = if let Some(current) = self.current {
199				libc::CMSG_NXTHDR(&msg, current)
200			} else {
201				libc::CMSG_FIRSTHDR(&msg)
202			};
203
204			let cmsg = cmsg.as_ref()?;
205
206			// Most operating systems, but not Linux or emscripten, return the previous pointer
207			// when its length is zero. Therefore, check if the previous pointer is the same as
208			// the current one.
209			if let Some(current) = self.current {
210				if std::ptr::eq(current, cmsg) {
211					return None;
212				}
213			}
214
215			self.current = Some(cmsg);
216			let ancillary_result = AncillaryMessage::try_from_cmsghdr(cmsg);
217			Some(ancillary_result)
218		}
219	}
220}
221
222impl<'a> AncillaryMessage<'a> {
223	#[allow(clippy::unnecessary_cast)]
224	fn try_from_cmsghdr(cmsg: &'a libc::cmsghdr) -> Self {
225		unsafe {
226			let cmsg_len_zero = libc::CMSG_LEN(0) as usize;
227			let data_len = cmsg.cmsg_len as usize - cmsg_len_zero;
228			let data = libc::CMSG_DATA(cmsg).cast();
229			let data = std::slice::from_raw_parts(data, data_len);
230
231			match (cmsg.cmsg_level, cmsg.cmsg_type) {
232				(libc::SOL_SOCKET, libc::SCM_RIGHTS) => Self::FileDescriptors(FileDescriptors { data }),
233				#[cfg(any(target_os = "android", target_os = "linux", target_os = "netbsd"))]
234				(libc::SOL_SOCKET, super::SCM_CREDENTIALS) => Self::Credentials(UnixCredentials { data }),
235				(cmsg_level, cmsg_type) => Self::Other(UnknownMessage { cmsg_level, cmsg_type, data }),
236			}
237		}
238	}
239}
240
241impl<'a> Iterator for IntoAncillaryMessages<'a> {
242	type Item = OwnedAncillaryMessage<'a>;
243
244	fn next(&mut self) -> Option<Self::Item> {
245		if self.buffer.is_empty() {
246			return None;
247		}
248		unsafe {
249			let mut msg: libc::msghdr = std::mem::zeroed();
250			msg.msg_control = self.buffer.as_ptr() as *mut _;
251			msg.msg_controllen = self.buffer.len() as _;
252
253			let cmsg = if let Some(current) = self.current {
254				libc::CMSG_NXTHDR(&msg, current)
255			} else {
256				libc::CMSG_FIRSTHDR(&msg)
257			};
258
259			let cmsg = cmsg.as_ref()?;
260
261			// Most operating systems, but not Linux or emscripten, return the previous pointer
262			// when its length is zero. Therefore, check if the previous pointer is the same as
263			// the current one.
264			if let Some(current) = self.current {
265				if std::ptr::eq(current, cmsg) {
266					return None;
267				}
268			}
269
270			self.current = Some(cmsg);
271			let ancillary_result = OwnedAncillaryMessage::try_from_cmsghdr(cmsg);
272			Some(ancillary_result)
273		}
274	}
275}
276
277impl Drop for IntoAncillaryMessages<'_> {
278	fn drop(&mut self) {
279		for message in self {
280			drop(message)
281		}
282	}
283}
284
285impl<'a> OwnedAncillaryMessage<'a> {
286	#[allow(clippy::unnecessary_cast)]
287	fn try_from_cmsghdr(cmsg: &'a libc::cmsghdr) -> Self {
288		unsafe {
289			let cmsg_len_zero = libc::CMSG_LEN(0) as usize;
290			let data_len = cmsg.cmsg_len as usize - cmsg_len_zero;
291			let data = libc::CMSG_DATA(cmsg).cast();
292			let data = std::slice::from_raw_parts_mut(data, data_len);
293
294			match (cmsg.cmsg_level, cmsg.cmsg_type) {
295				(libc::SOL_SOCKET, libc::SCM_RIGHTS) => Self::FileDescriptors(OwnedFileDescriptors { data }),
296				#[cfg(any(target_os = "android", target_os = "linux", target_os = "netbsd"))]
297				(libc::SOL_SOCKET, super::SCM_CREDENTIALS) => Self::Credentials(UnixCredentials { data }),
298				(cmsg_level, cmsg_type) => Self::Other(UnknownMessage { cmsg_level, cmsg_type, data }),
299			}
300		}
301	}
302}
303
304impl<'a> FileDescriptors<'a> {
305	/// Get the number of file descriptors in the message.
306	pub fn len(&self) -> usize {
307		self.data.len() / FD_SIZE
308	}
309
310	/// Check if the message is empty (contains no file descriptors).
311	pub fn is_empty(&self) -> bool {
312		self.len() == 0
313	}
314
315	/// Get a borrowed file descriptor from the message.
316	///
317	/// Returns `None` if the index is out of bounds.
318	pub fn get(&self, index: usize) -> Option<BorrowedFd<'a>> {
319		if index >= self.len() {
320			None
321		} else {
322			// SAFETY: The memory is valid, and the kernel guaranteed it is a file descriptor.
323			// Additionally, the returned lifetime is linked to the `AncillaryMessageReader` which owns the file descriptor.
324			unsafe {
325				Some(std::ptr::read_unaligned(self.data[index * FD_SIZE..].as_ptr().cast()))
326			}
327		}
328	}
329}
330
331impl<'a> Iterator for FileDescriptors<'a> {
332	type Item = BorrowedFd<'a>;
333
334	fn next(&mut self) -> Option<Self::Item> {
335		let fd = self.get(0)?;
336		self.data = &self.data[FD_SIZE..];
337		Some(fd)
338	}
339
340	fn size_hint(&self) -> (usize, Option<usize>) {
341		(self.len(), Some(self.len()))
342	}
343}
344
345impl<'a> std::iter::ExactSizeIterator for FileDescriptors<'a> {
346	fn len(&self) -> usize {
347		self.len()
348	}
349}
350
351impl<'a> OwnedFileDescriptors<'a> {
352	/// Get the number of file descriptors in the message.
353	pub fn len(&self) -> usize {
354		self.data.len() / FD_SIZE
355	}
356
357	/// Check if the message is empty (contains no file descriptors).
358	pub fn is_empty(&self) -> bool {
359		self.len() == 0
360	}
361
362	/// Advance the iterator.
363	fn advance(&mut self) {
364		let data = std::mem::take(&mut self.data);
365		self.data = &mut data[FD_SIZE..];
366	}
367}
368
369impl<'a> Iterator for OwnedFileDescriptors<'a> {
370	type Item = OwnedFd;
371
372	fn next(&mut self) -> Option<Self::Item> {
373		if Self::is_empty(self) {
374			None
375		} else {
376			// SAFETY: The memory is valid, and the kernel guaranteed it is a file descriptor.
377			// Additionally, the returned lifetime is linked to the `AncillaryMessageReader` which owns the file descriptor.
378			// And we overwrite the original value with -1 before returning the owned fd to ensure we don't try to own it multiple times.
379			unsafe {
380				use std::os::fd::{FromRawFd, RawFd};
381				let raw_fd: RawFd = std::ptr::read_unaligned(self.data.as_mut_ptr().cast());
382				self.advance();
383				Some(OwnedFd::from_raw_fd(raw_fd))
384			}
385		}
386	}
387
388	fn size_hint(&self) -> (usize, Option<usize>) {
389		(self.len(), Some(self.len()))
390	}
391}
392
393impl Drop for OwnedFileDescriptors<'_> {
394	fn drop(&mut self) {
395		for fd in self {
396			drop(fd)
397		}
398	}
399}
400
401impl<'a> std::iter::ExactSizeIterator for OwnedFileDescriptors<'a> {
402	fn len(&self) -> usize {
403		self.len()
404	}
405}
406
407#[cfg(any(target_os = "linux", target_os = "android", target_os = "netbsd"))]
408mod unix_creds_impl {
409	use super::UnixCredentials;
410	use super::super::RawScmCreds;
411	use crate::UCred;
412
413	impl UnixCredentials<'_> {
414		/// Get the number of credentials in the message.
415		pub fn len(&self) -> usize {
416			self.data.len() / std::mem::size_of::<RawScmCreds>()
417		}
418
419		/// Check if the message is empty (contains no credentials).
420		pub fn is_empty(&self) -> bool {
421			self.len() == 0
422		}
423
424		/// Get the credentials at a specific index.
425		pub fn get(&self, index: usize) -> Option<UCred> {
426			if index >= self.len() {
427				None
428			} else {
429				// SAFETY: The memory is valid, and the kernel guaranteed it is a credentials struct.
430				// It probably also guarantees alignment, but just in case not, use read_unaligned.
431				let raw: RawScmCreds = unsafe {
432					std::ptr::read_unaligned(self.data.as_ptr().cast::<RawScmCreds>().add(index))
433				};
434				Some(UCred::from_scm_creds(raw))
435			}
436		}
437	}
438
439	impl Iterator for UnixCredentials<'_> {
440		type Item = UCred;
441
442		fn next(&mut self) -> Option<Self::Item> {
443			let fd = self.get(0)?;
444			self.data = &self.data[std::mem::size_of::<RawScmCreds>()..];
445			Some(fd)
446		}
447
448		fn size_hint(&self) -> (usize, Option<usize>) {
449			(self.len(), Some(self.len()))
450		}
451	}
452
453	impl<'a> std::iter::ExactSizeIterator for UnixCredentials<'a> {
454		fn len(&self) -> usize {
455			self.len()
456		}
457	}
458}
459
460impl<'a> UnknownMessage<'a> {
461	/// Get the cmsg_level of the message.
462	pub fn cmsg_level(&self) -> i32 {
463		self.cmsg_level
464	}
465
466	/// Get the cmsg_type of the message.
467	pub fn cmsg_type(&self) -> i32 {
468		self.cmsg_type
469	}
470
471	/// Get the data of the message.
472	pub fn data(&self) -> &'a [u8] {
473		self.data
474	}
475}