tokio_seqpacket/ancillary/
writer.rs1use crate::borrow_fd::BorrowFd;
2
3use super::FD_SIZE;
4
/// Appends ancillary (control) messages to a caller-provided byte buffer.
///
/// `buffer` is the storage the encoded messages are written into, and `length`
/// counts how many leading bytes of that storage are already occupied.
#[derive(Debug)]
pub struct AncillaryMessageWriter<'a> {
    /// Backing storage for the encoded control messages.
    pub(crate) buffer: &'a mut [u8],
    /// Number of leading bytes of `buffer` already used by added messages.
    pub(crate) length: usize,
}
38
/// Error returned when a control message could not be added to an ancillary buffer.
///
/// The unit field is private, so values can only be created inside the defining module.
pub struct AddControlMessageError(());
41
42impl<'a> AncillaryMessageWriter<'a> {
43 pub const BUFFER_ALIGN: usize = std::mem::align_of::<libc::cmsghdr>();
45
46 pub fn new(buffer: &'a mut [u8]) -> Self {
60 let buffer = align_buffer_mut(buffer, Self::BUFFER_ALIGN);
61 Self { buffer, length: 0 }
62 }
63
64 pub fn capacity(&self) -> usize {
66 self.buffer.len()
67 }
68
69 pub fn is_empty(&self) -> bool {
71 self.length == 0
72 }
73
74 pub fn len(&self) -> usize {
76 self.length
77 }
78
79 pub fn add_fds<T>(&mut self, fds: &[T]) -> Result<(), AddControlMessageError>
110 where
111 T: BorrowFd<'a>,
112 {
113 use std::os::fd::AsRawFd;
114
115 let byte_len = fds.len() * FD_SIZE;
116 let buffer = reserve_ancillary_data(self.buffer, &mut self.length, byte_len, libc::SOL_SOCKET, libc::SCM_RIGHTS)?;
117
118 for (i, fd) in fds.iter().enumerate() {
119 let bytes = fd.borrow_fd().as_raw_fd().to_ne_bytes();
120 buffer[i * FD_SIZE..][..FD_SIZE].copy_from_slice(&bytes)
121 }
122 Ok(())
123 }
124
125 #[cfg(any(target_os = "android", target_os = "linux", target_os = "netbsd"))]
133 pub fn add_ucreds(&mut self, credentials: &[crate::UCred]) -> Result<(), AddControlMessageError> {
134 use super::RawScmCreds;
135
136 const ELEM_SIZE: usize = std::mem::size_of::<RawScmCreds>();
137
138 let byte_len = credentials.len() * ELEM_SIZE;
139 let buffer = reserve_ancillary_data(self.buffer, &mut self.length, byte_len, libc::SOL_SOCKET, super::SCM_CREDENTIALS)?;
140
141 for (i, cred) in credentials.iter().enumerate() {
142 let raw = &cred.to_scm_creds();
143 unsafe {
147 std::ptr::copy_nonoverlapping(raw as *const _ as *const u8, buffer[i * ELEM_SIZE..].as_mut_ptr(), ELEM_SIZE);
148 }
149 }
150 Ok(())
151 }
152
153 #[cfg(all(doc, not(any(target_os = "android", target_os = "linux", target_os = "netbsd"))))]
161 pub fn add_ucreds(&mut self, credentials: &[crate::UCred]) -> Result<(), AddControlMessageError> {
162 panic!("fake function for doc generation")
163 }
164}
165
166impl std::error::Error for AddControlMessageError {}
167
168impl std::fmt::Debug for AddControlMessageError {
169 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
170 f.debug_struct("AddDataError").finish()
171 }
172}
173
174impl std::fmt::Display for AddControlMessageError {
175 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
176 f.write_str("Not enough space in ancillary buffer")
177 }
178}
179
180impl From<AddControlMessageError> for std::io::Error {
181 fn from(_value: AddControlMessageError) -> Self {
182 std::io::Error::from_raw_os_error(libc::ENOSPC)
183 }
184}
185
186fn reserve_ancillary_data<'a>(
187 buffer: &'a mut [u8],
188 length: &mut usize,
189 byte_len: usize,
190 cmsg_level: libc::c_int,
191 cmsg_type: libc::c_int,
192) -> Result<&'a mut [u8], AddControlMessageError> {
193 let byte_len = u32::try_from(byte_len)
194 .map_err(|_| AddControlMessageError(()))?;
195
196 unsafe {
197 let additional_space = libc::CMSG_SPACE(byte_len) as usize;
198 let new_length = length.checked_add(additional_space)
199 .ok_or(AddControlMessageError(()))?;
200 if new_length > buffer.len() {
201 return Err(AddControlMessageError(()));
202 }
203
204 buffer[*length..new_length].fill(0);
205
206 let mut msg: libc::msghdr = std::mem::zeroed();
207 msg.msg_control = buffer.as_mut_ptr().cast();
208 msg.msg_controllen = new_length as _;
209
210 let mut cmsg = libc::CMSG_FIRSTHDR(&msg);
211 let mut previous_cmsg = cmsg;
212 while !cmsg.is_null() {
213 previous_cmsg = cmsg;
214 cmsg = libc::CMSG_NXTHDR(&msg, cmsg);
215
216 if std::ptr::eq(cmsg, previous_cmsg) {
220 break;
221 }
222 }
223
224 if previous_cmsg.is_null() {
225 return Err(AddControlMessageError(()));
226 }
227
228 *length = new_length;
229 (*previous_cmsg).cmsg_level = cmsg_level;
230 (*previous_cmsg).cmsg_type = cmsg_type;
231 (*previous_cmsg).cmsg_len = libc::CMSG_LEN(byte_len) as _;
232
233 let data = libc::CMSG_DATA(previous_cmsg).cast();
234 Ok(std::slice::from_raw_parts_mut(data, additional_space))
235 }
236}
237
/// Drop leading bytes of `buffer` until its start is aligned to `align`.
///
/// Returns an empty slice when the required skip is larger than the buffer
/// (including when `align_offset` reports that alignment is impossible).
fn align_buffer_mut(buffer: &mut [u8], align: usize) -> &mut [u8] {
    let skip = buffer.as_ptr().align_offset(align);
    match buffer.get_mut(skip..) {
        Some(aligned) => aligned,
        None => &mut [],
    }
}