zbus/connection/socket/
mod.rs1#[cfg(feature = "p2p")]
2pub mod channel;
3#[cfg(feature = "p2p")]
4pub use channel::Channel;
5
6mod split;
7pub use split::{BoxedSplit, Split};
8
9#[cfg(unix)]
10pub(crate) mod command;
11#[cfg(unix)]
12pub(crate) use command::Command;
13mod tcp;
14mod unix;
15mod vsock;
16
17#[cfg(not(feature = "tokio"))]
18use async_io::Async;
19#[cfg(not(feature = "tokio"))]
20use std::sync::Arc;
21use std::{io, mem};
22use tracing::trace;
23
24use crate::{
25 Message,
26 conn::AuthMechanism,
27 fdo::ConnectionCredentials,
28 message::{
29 PrimaryHeader,
30 header::{MAX_MESSAGE_SIZE, MIN_MESSAGE_SIZE},
31 },
32 padding_for_8_bytes,
33};
34#[cfg(unix)]
35use std::os::fd::{AsFd, BorrowedFd, OwnedFd};
36use zvariant::{
37 Endian,
38 serialized::{self, Context},
39};
40
/// Outcome of one `recvmsg` call on Unix: the number of bytes read, together with any file
/// descriptors that arrived alongside those bytes.
#[cfg(unix)]
type RecvmsgResult = std::io::Result<(usize, Vec<std::os::fd::OwnedFd>)>;

/// Outcome of one `recvmsg` call on non-Unix targets: just the number of bytes read.
#[cfg(not(unix))]
type RecvmsgResult = std::io::Result<usize>;
46
47pub trait Socket {
61 type ReadHalf: ReadHalf;
62 type WriteHalf: WriteHalf;
63
64 fn split(self) -> Split<Self::ReadHalf, Self::WriteHalf>
66 where
67 Self: Sized;
68}
69
70#[async_trait::async_trait]
74pub trait ReadHalf: std::fmt::Debug + Send + Sync + 'static {
75 async fn receive_message(
94 &mut self,
95 seq: u64,
96 already_received_bytes: &mut Vec<u8>,
97 #[cfg(unix)] already_received_fds: &mut Vec<OwnedFd>,
98 ) -> crate::Result<Message> {
99 #[cfg(unix)]
100 let mut fds = vec![];
101 let mut bytes = if already_received_bytes.len() < MIN_MESSAGE_SIZE {
102 let mut bytes = vec![];
103 if !already_received_bytes.is_empty() {
104 mem::swap(already_received_bytes, &mut bytes);
105 }
106 let mut pos = bytes.len();
107 bytes.resize(MIN_MESSAGE_SIZE, 0);
108 while pos < MIN_MESSAGE_SIZE {
115 let res = self.recvmsg(&mut bytes[pos..]).await?;
116 let len = {
117 #[cfg(unix)]
118 {
119 fds.extend(res.1);
120 res.0
121 }
122 #[cfg(not(unix))]
123 {
124 res
125 }
126 };
127 pos += len;
128 if len == 0 {
129 return Err(std::io::Error::new(
130 std::io::ErrorKind::UnexpectedEof,
131 "failed to receive message",
132 )
133 .into());
134 }
135 }
136
137 bytes
138 } else {
139 already_received_bytes.drain(..MIN_MESSAGE_SIZE).collect()
140 };
141
142 let (primary_header, fields_len) = PrimaryHeader::read(&bytes)?;
143 let header_len = MIN_MESSAGE_SIZE + fields_len as usize;
144 let body_padding = padding_for_8_bytes(header_len);
145 let body_len = primary_header.body_len() as usize;
146 let total_len = header_len + body_padding + body_len;
147 if total_len > MAX_MESSAGE_SIZE {
148 return Err(crate::Error::ExcessData);
149 }
150
151 if !already_received_bytes.is_empty() {
154 let pending = total_len - bytes.len();
156 let to_take = std::cmp::min(pending, already_received_bytes.len());
157 bytes.extend(already_received_bytes.drain(..to_take));
158 }
159 let mut pos = bytes.len();
160 bytes.resize(total_len, 0);
161
162 while pos < total_len {
164 let res = self.recvmsg(&mut bytes[pos..]).await?;
165 let read = {
166 #[cfg(unix)]
167 {
168 fds.extend(res.1);
169 res.0
170 }
171 #[cfg(not(unix))]
172 {
173 res
174 }
175 };
176 pos += read;
177 if read == 0 {
178 return Err(crate::Error::InputOutput(
179 std::io::Error::new(
180 std::io::ErrorKind::UnexpectedEof,
181 "failed to receive message",
182 )
183 .into(),
184 ));
185 }
186 }
187
188 let endian = Endian::from(primary_header.endian_sig());
190
191 #[cfg(unix)]
192 if !already_received_fds.is_empty() {
193 use crate::message::header::PRIMARY_HEADER_SIZE;
194
195 let ctxt = Context::new_dbus(endian, PRIMARY_HEADER_SIZE);
196 let encoded_fields =
197 serialized::Data::new(&bytes[PRIMARY_HEADER_SIZE..header_len], ctxt);
198 let fields: crate::message::Fields<'_> = encoded_fields.deserialize()?.0;
199 let num_required_fds = match fields.unix_fds {
200 Some(num_fds) => num_fds as usize,
201 _ => 0,
202 };
203 let num_pending = num_required_fds
204 .checked_sub(fds.len())
205 .ok_or_else(|| crate::Error::ExcessData)?;
206 if num_pending == 0 {
208 return Err(crate::Error::MissingParameter("Missing file descriptors"));
209 }
210 let mut already_received: Vec<_> = already_received_fds.drain(..num_pending).collect();
212 mem::swap(&mut already_received, &mut fds);
213 fds.extend(already_received);
214 }
215
216 let ctxt = Context::new_dbus(endian, 0);
217 #[cfg(unix)]
218 let bytes = serialized::Data::new_fds(bytes, ctxt, fds);
219 #[cfg(not(unix))]
220 let bytes = serialized::Data::new(bytes, ctxt);
221 Message::from_raw_parts(bytes, seq)
222 }
223
224 async fn recvmsg(&mut self, _buf: &mut [u8]) -> RecvmsgResult {
232 unimplemented!("`ReadHalf` implementers must either override `read_message` or `recvmsg`");
233 }
234
235 fn can_pass_unix_fd(&self) -> bool {
239 false
240 }
241
242 async fn peer_credentials(&mut self) -> io::Result<ConnectionCredentials> {
244 Ok(ConnectionCredentials::default())
245 }
246
247 fn auth_mechanism(&self) -> AuthMechanism {
251 AuthMechanism::External
252 }
253}
254
255#[async_trait::async_trait]
259pub trait WriteHalf: std::fmt::Debug + Send + Sync + 'static {
260 async fn send_message(&mut self, msg: &Message) -> crate::Result<()> {
267 let data = msg.data();
268 let serial = msg.primary_header().serial_num();
269
270 trace!("Sending message: {:?}", msg);
271 let mut pos = 0;
272 while pos < data.len() {
273 #[cfg(unix)]
274 let fds = if pos == 0 {
275 data.fds().iter().map(|f| f.as_fd()).collect()
276 } else {
277 vec![]
278 };
279 pos += self
280 .sendmsg(
281 &data[pos..],
282 #[cfg(unix)]
283 &fds,
284 )
285 .await?;
286 }
287 trace!("Sent message with serial: {}", serial);
288
289 Ok(())
290 }
291
292 async fn sendmsg(
307 &mut self,
308 _buffer: &[u8],
309 #[cfg(unix)] _fds: &[BorrowedFd<'_>],
310 ) -> io::Result<usize> {
311 unimplemented!("`WriteHalf` implementers must either override `send_message` or `sendmsg`");
312 }
313
314 #[cfg(any(target_os = "freebsd", target_os = "dragonfly"))]
319 async fn send_zero_byte(&mut self) -> io::Result<Option<usize>> {
320 Ok(None)
321 }
322
323 async fn close(&mut self) -> io::Result<()>;
327
328 fn can_pass_unix_fd(&self) -> bool {
332 false
333 }
334
335 async fn peer_credentials(&mut self) -> io::Result<ConnectionCredentials> {
337 Ok(ConnectionCredentials::default())
338 }
339}
340
341#[async_trait::async_trait]
342impl ReadHalf for Box<dyn ReadHalf> {
343 fn can_pass_unix_fd(&self) -> bool {
344 (**self).can_pass_unix_fd()
345 }
346
347 async fn receive_message(
348 &mut self,
349 seq: u64,
350 already_received_bytes: &mut Vec<u8>,
351 #[cfg(unix)] already_received_fds: &mut Vec<OwnedFd>,
352 ) -> crate::Result<Message> {
353 (**self)
354 .receive_message(
355 seq,
356 already_received_bytes,
357 #[cfg(unix)]
358 already_received_fds,
359 )
360 .await
361 }
362
363 async fn recvmsg(&mut self, buf: &mut [u8]) -> RecvmsgResult {
364 (**self).recvmsg(buf).await
365 }
366
367 async fn peer_credentials(&mut self) -> io::Result<ConnectionCredentials> {
368 (**self).peer_credentials().await
369 }
370
371 fn auth_mechanism(&self) -> AuthMechanism {
372 (**self).auth_mechanism()
373 }
374}
375
376#[async_trait::async_trait]
377impl WriteHalf for Box<dyn WriteHalf> {
378 async fn send_message(&mut self, msg: &Message) -> crate::Result<()> {
379 (**self).send_message(msg).await
380 }
381
382 async fn sendmsg(
383 &mut self,
384 buffer: &[u8],
385 #[cfg(unix)] fds: &[BorrowedFd<'_>],
386 ) -> io::Result<usize> {
387 (**self)
388 .sendmsg(
389 buffer,
390 #[cfg(unix)]
391 fds,
392 )
393 .await
394 }
395
396 #[cfg(any(target_os = "freebsd", target_os = "dragonfly"))]
397 async fn send_zero_byte(&mut self) -> io::Result<Option<usize>> {
398 (**self).send_zero_byte().await
399 }
400
401 async fn close(&mut self) -> io::Result<()> {
402 (**self).close().await
403 }
404
405 fn can_pass_unix_fd(&self) -> bool {
406 (**self).can_pass_unix_fd()
407 }
408
409 async fn peer_credentials(&mut self) -> io::Result<ConnectionCredentials> {
410 (**self).peer_credentials().await
411 }
412}
413
414#[cfg(not(feature = "tokio"))]
415impl<T> Socket for Async<T>
416where
417 T: std::fmt::Debug + Send + Sync,
418 Arc<Async<T>>: ReadHalf + WriteHalf,
419{
420 type ReadHalf = Arc<Async<T>>;
421 type WriteHalf = Arc<Async<T>>;
422
423 fn split(self) -> Split<Self::ReadHalf, Self::WriteHalf> {
424 let arc = Arc::new(self);
425
426 Split {
427 read: arc.clone(),
428 write: arc,
429 }
430 }
431}