Skip to main content

tokio_seqpacket/ancillary/
writer.rs

1use crate::borrow_fd::BorrowFd;
2
3use super::FD_SIZE;
4
/// Writer to help you construct ancillary messages for Unix sockets.
///
/// The writer uses a pre-allocated buffer and will never (re)-allocate.
/// Control messages are appended one after another; use [`AncillaryMessageWriter::len()`]
/// and [`AncillaryMessageWriter::capacity()`] to see how much of the buffer is in use.
///
/// # Example
/// ```no_run
/// use tokio_seqpacket::UnixSeqpacket;
/// use tokio_seqpacket::ancillary::AncillaryMessageWriter;
/// use std::io::IoSlice;
///
/// #[tokio::main]
/// async fn main() -> std::io::Result<()> {
///     let sock = UnixSeqpacket::connect("/tmp/sock").await?;
///     let file = std::fs::File::create("/tmp/my-file")?;
///
///     let mut ancillary_buffer = [0; 128];
///     let mut ancillary = AncillaryMessageWriter::new(&mut ancillary_buffer);
///     ancillary.add_fds(&[&file])?;
///
///     let buf = [1; 8];
///     let mut bufs = [IoSlice::new(&buf)];
///     sock.send_vectored_with_ancillary(&mut bufs, &mut ancillary).await?;
///
///     Ok(())
/// }
/// ```
#[derive(Debug)]
pub struct AncillaryMessageWriter<'a> {
	/// The control message buffer; only the first `length` bytes hold control messages.
	pub(crate) buffer: &'a mut [u8],
	/// Number of bytes at the start of `buffer` that are already filled with control messages.
	pub(crate) length: usize,
}
38
/// Failed to add a control message to an ancillary message buffer.
///
/// This typically means the buffer does not have enough space left for the control message.
pub struct AddControlMessageError(());
41
42impl<'a> AncillaryMessageWriter<'a> {
43	/// Alignment requirement for the control messages added to the buffer.
44	pub const BUFFER_ALIGN: usize = std::mem::align_of::<libc::cmsghdr>();
45
46	/// Create an ancillary data with the given buffer.
47	///
48	/// Some bytes at the start of the buffer may be left unused to enforce alignment to [`Self::BUFFER_ALIGN`].
49	/// You can use [`Self::capacity()`] to check how much of the buffer can be used for control messages.
50	///
51	/// # Example
52	///
53	/// ```no_run
54	/// # #![allow(unused_mut)]
55	/// use tokio_seqpacket::ancillary::AncillaryMessageWriter;
56	/// let mut ancillary_buffer = [0; 128];
57	/// let mut ancillary = AncillaryMessageWriter::new(&mut ancillary_buffer);
58	/// ```
59	pub fn new(buffer: &'a mut [u8]) -> Self {
60		let buffer = align_buffer_mut(buffer, Self::BUFFER_ALIGN);
61		Self { buffer, length: 0 }
62	}
63
64	/// Returns the capacity of the buffer.
65	pub fn capacity(&self) -> usize {
66		self.buffer.len()
67	}
68
69	/// Returns `true` if the ancillary data is empty.
70	pub fn is_empty(&self) -> bool {
71		self.length == 0
72	}
73
74	/// Returns the number of used bytes.
75	pub fn len(&self) -> usize {
76		self.length
77	}
78
79	/// Add file descriptors to the ancillary data.
80	///
81	/// The function returns `Ok(())` if there was enough space in the buffer.
82	/// If there was not enough space then no file descriptors was appended.
83	///
84	/// This adds a single control message with level `SOL_SOCKET` and type `SCM_RIGHTS`.
85	///
86	/// # Example
87	///
88	/// ```no_run
89	/// use tokio_seqpacket::UnixSeqpacket;
90	/// use tokio_seqpacket::ancillary::AncillaryMessageWriter;
91	/// use std::os::unix::io::AsFd;
92	/// use std::io::IoSlice;
93	///
94	/// #[tokio::main]
95	/// async fn main() -> std::io::Result<()> {
96	///     let sock = UnixSeqpacket::connect("/tmp/sock").await?;
97	///     let file = std::fs::File::open("/my/file")?;
98	///
99	///     let mut ancillary_buffer = [0; 128];
100	///     let mut ancillary = AncillaryMessageWriter::new(&mut ancillary_buffer);
101	///     ancillary.add_fds(&[file.as_fd()]);
102	///
103	///     let buf = [1; 8];
104	///     let mut bufs = &mut [IoSlice::new(&buf)];
105	///     sock.send_vectored_with_ancillary(bufs, &mut ancillary).await?;
106	///     Ok(())
107	/// }
108	/// ```
109	pub fn add_fds<T>(&mut self, fds: &[T]) -> Result<(), AddControlMessageError>
110		where
111			T: BorrowFd<'a>,
112	{
113		use std::os::fd::AsRawFd;
114
115		let byte_len = fds.len() * FD_SIZE;
116		let buffer = reserve_ancillary_data(self.buffer, &mut self.length, byte_len, libc::SOL_SOCKET, libc::SCM_RIGHTS)?;
117
118		for (i, fd) in fds.iter().enumerate() {
119			let bytes = fd.borrow_fd().as_raw_fd().to_ne_bytes();
120			buffer[i * FD_SIZE..][..FD_SIZE].copy_from_slice(&bytes)
121		}
122		Ok(())
123	}
124
125	/// Add Unix credentials to the ancillary data.
126	///
127	/// The function returns `Ok(())` if there is enough space in the buffer.
128	/// If there is not enough space, then no credentials are appended.
129	///
130	/// This function adds a single control message with level `SOL_SOCKET` and type `SCM_CREDENTIALS` on most platforms.
131	/// On NetBSD the message has type `SCM_CREDS`.
132	#[cfg(any(target_os = "android", target_os = "linux", target_os = "netbsd"))]
133	pub fn add_ucreds(&mut self, credentials: &[crate::UCred]) -> Result<(), AddControlMessageError> {
134		use super::RawScmCreds;
135
136		const ELEM_SIZE: usize = std::mem::size_of::<RawScmCreds>();
137
138		let byte_len = credentials.len() * ELEM_SIZE;
139		let buffer = reserve_ancillary_data(self.buffer, &mut self.length, byte_len, libc::SOL_SOCKET, super::SCM_CREDENTIALS)?;
140
141		for (i, cred) in credentials.iter().enumerate() {
142			let raw = &cred.to_scm_creds();
143			// SAFETY: The pointers are guaranteed valid and non-overlapping,
144			// since they come from distinct &mut self and &[SocketCred] references.
145			// The buffer is guaranteed to be large enough by `reserve_ancillary_data`.
146			unsafe {
147				std::ptr::copy_nonoverlapping(raw as *const _ as *const u8, buffer[i * ELEM_SIZE..].as_mut_ptr(), ELEM_SIZE);
148			}
149		}
150		Ok(())
151	}
152
153	/// Add Unix credentials to the ancillary data.
154	///
155	/// The function returns `Ok(())` if there is enough space in the buffer.
156	/// If there is not enough space, then no credentials are appended.
157	///
158	/// This function adds a single control message with level `SOL_SOCKET` and type `SCM_CREDENTIALS` on most platforms.
159	/// On NetBSD the message has type `SCM_CREDS`.
160	#[cfg(all(doc, not(any(target_os = "android", target_os = "linux", target_os = "netbsd"))))]
161	pub fn add_ucreds(&mut self, credentials: &[crate::UCred]) -> Result<(), AddControlMessageError> {
162		panic!("fake function for doc generation")
163	}
164}
165
166impl std::error::Error for AddControlMessageError {}
167
168impl std::fmt::Debug for AddControlMessageError {
169	fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
170		f.debug_struct("AddDataError").finish()
171	}
172}
173
174impl std::fmt::Display for AddControlMessageError {
175	fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
176		f.write_str("Not enough space in ancillary buffer")
177	}
178}
179
180impl From<AddControlMessageError> for std::io::Error {
181	fn from(_value: AddControlMessageError) -> Self {
182		std::io::Error::from_raw_os_error(libc::ENOSPC)
183	}
184}
185
186fn reserve_ancillary_data<'a>(
187	buffer: &'a mut [u8],
188	length: &mut usize,
189	byte_len: usize,
190	cmsg_level: libc::c_int,
191	cmsg_type: libc::c_int,
192) -> Result<&'a mut [u8], AddControlMessageError> {
193	let byte_len = u32::try_from(byte_len)
194		.map_err(|_| AddControlMessageError(()))?;
195
196	unsafe {
197		let additional_space = libc::CMSG_SPACE(byte_len) as usize;
198		let new_length = length.checked_add(additional_space)
199			.ok_or(AddControlMessageError(()))?;
200		if new_length > buffer.len() {
201			return Err(AddControlMessageError(()));
202		}
203
204		buffer[*length..new_length].fill(0);
205
206		let mut msg: libc::msghdr = std::mem::zeroed();
207		msg.msg_control = buffer.as_mut_ptr().cast();
208		msg.msg_controllen = new_length as _;
209
210		let mut cmsg = libc::CMSG_FIRSTHDR(&msg);
211		let mut previous_cmsg = cmsg;
212		while !cmsg.is_null() {
213			previous_cmsg = cmsg;
214			cmsg = libc::CMSG_NXTHDR(&msg, cmsg);
215
216			// Most operating systems, but not Linux or emscripten, return the previous pointer
217			// when its length is zero. Therefore, check if the previous pointer is the same as
218			// the current one.
219			if std::ptr::eq(cmsg, previous_cmsg) {
220				break;
221			}
222		}
223
224		if previous_cmsg.is_null() {
225			return Err(AddControlMessageError(()));
226		}
227
228		*length = new_length;
229		(*previous_cmsg).cmsg_level = cmsg_level;
230		(*previous_cmsg).cmsg_type = cmsg_type;
231		(*previous_cmsg).cmsg_len = libc::CMSG_LEN(byte_len) as _;
232
233		let data = libc::CMSG_DATA(previous_cmsg).cast();
234		Ok(std::slice::from_raw_parts_mut(data, additional_space))
235	}
236}
237
/// Skip the leading bytes of `buffer` needed to make its start aligned to `align` bytes.
///
/// Returns the aligned tail of `buffer`. An empty slice is returned when the required offset
/// is larger than `buffer` (including when `align_offset` reports `usize::MAX`).
///
/// # Panics
///
/// Panics if `align` is not a power of two (inherited from `<*const u8>::align_offset`).
fn align_buffer_mut(buffer: &mut [u8], align: usize) -> &mut [u8] {
	let skip = buffer.as_ptr().align_offset(align);
	buffer.get_mut(skip..).unwrap_or_default()
}