zlink_core/connection/read_connection.rs
1//! Contains connection related API.
2
3use core::{fmt::Debug, str::from_utf8_unchecked};
4
5use crate::{Result, varlink_service};
6
7use super::{
8 BUFFER_SIZE, Call, MAX_BUFFER_SIZE,
9 reply::{self, Reply},
10 socket::ReadHalf,
11};
12#[cfg(feature = "std")]
13use alloc::collections::VecDeque;
14use alloc::vec::Vec;
15use serde::Deserialize;
16use serde_json::Deserializer;
17
18#[cfg(feature = "std")]
19use std::os::fd::OwnedFd;
20
/// What the `receive_*` methods hand back for a received message of type `T`.
///
/// Without the `std` feature there is no file-descriptor passing, so this is simply `T`.
#[cfg(not(feature = "std"))]
type RecvResult<T> = T;
/// With the `std` feature, the message is paired with the file descriptors the connection
/// associated with it (an empty `Vec` when there are none).
#[cfg(feature = "std")]
type RecvResult<T> = (T, Vec<OwnedFd>);
26
/// A connection that can only be used for reading.
///
/// # Cancel safety
///
/// All async methods of this type are cancel safe unless explicitly stated otherwise in their
/// documentation.
#[derive(Debug)]
pub struct ReadConnection<Read: ReadHalf> {
    // Read half of the underlying socket that all data is received from.
    socket: Read,
    // End of the received data in `buffer`; the next socket read writes to `buffer[read_pos..]`.
    read_pos: usize,
    // Offset in `buffer` of the next message to be deserialized.
    msg_pos: usize,
    // Receive buffer. Starts at `BUFFER_SIZE` bytes and grows in `BUFFER_SIZE` steps when it
    // fills up; growing once it has reached `MAX_BUFFER_SIZE` is refused with an error.
    buffer: Vec<u8>,
    // Unique identifier of the connection, returned by `id()` and included in trace logs.
    id: usize,
    // FDs from each socket read that returned any, oldest first. The front entry is handed out
    // together with the next message returned to the caller.
    #[cfg(feature = "std")]
    pending_fds: VecDeque<Vec<OwnedFd>>,
    // Number of `recvmsg` calls that returned FDs. Used by `Connection` to drain
    // `WriteConnection::held_fds` on macOS. See that field's comment for details.
    #[cfg(all(feature = "std", target_os = "macos"))]
    pub(super) fd_recvs: usize,
}
47
48impl<Read: ReadHalf> ReadConnection<Read> {
49 /// Create a new connection.
50 pub(super) fn new(socket: Read, id: usize) -> Self {
51 Self {
52 socket,
53 read_pos: 0,
54 msg_pos: 0,
55 id,
56 buffer: alloc::vec![0; BUFFER_SIZE],
57 #[cfg(feature = "std")]
58 pending_fds: VecDeque::new(),
59 #[cfg(all(feature = "std", target_os = "macos"))]
60 fd_recvs: 0,
61 }
62 }
63
64 /// The unique identifier of the connection.
65 #[inline]
66 pub fn id(&self) -> usize {
67 self.id
68 }
69
70 /// Receives a method call reply.
71 ///
72 /// The generic parameters needs some explanation:
73 ///
74 /// * `ReplyParams` is the type of the successful reply. This should be a type that can
75 /// deserialize itself from the `parameters` field of the reply.
76 /// * `ReplyError` is the type of the error reply. This should be a type that can deserialize
77 /// itself from the whole reply object itself and must fail when there is no `error` field in
78 /// the object. This can be easily achieved using the `zlink::ReplyError` derive:
79 ///
80 /// ```rust
81 /// use zlink_core::ReplyError;
82 ///
83 /// #[derive(Debug, ReplyError)]
84 /// #[zlink(
85 /// interface = "org.example.ftl",
86 /// // Not needed in the real code because you'll use `ReplyError` through `zlink` crate.
87 /// crate = "zlink_core",
88 /// )]
89 /// enum MyError {
90 /// Alpha { param1: u32, param2: String },
91 /// Bravo,
92 /// Charlie { param1: String },
93 /// }
94 /// ```
95 ///
96 /// Returns the reply and any file descriptors received (std only).
97 pub async fn receive_reply<'r, ReplyParams, ReplyError>(
98 &'r mut self,
99 ) -> Result<RecvResult<reply::Result<ReplyParams, ReplyError>>>
100 where
101 ReplyParams: Deserialize<'r> + Debug,
102 ReplyError: Deserialize<'r> + Debug,
103 {
104 #[derive(Debug, Deserialize)]
105 #[serde(untagged)]
106 enum ReplyMsg<'m, ReplyParams, ReplyError> {
107 #[serde(borrow)]
108 Varlink(varlink_service::Error<'m>),
109 Error(ReplyError),
110 Reply(Reply<ReplyParams>),
111 }
112
113 let recv_result = self
114 .read_message::<ReplyMsg<'_, ReplyParams, ReplyError>>()
115 .await?;
116
117 #[cfg(feature = "std")]
118 let (msg, fds) = recv_result;
119 #[cfg(not(feature = "std"))]
120 let msg = recv_result;
121
122 let result = match msg {
123 // Varlink service interface error need to be returned as the top-level error.
124 ReplyMsg::Varlink(e) => Err(crate::Error::VarlinkService(e.into())),
125 ReplyMsg::Error(e) => Ok(Err(e)),
126 ReplyMsg::Reply(reply) => Ok(Ok(reply)),
127 };
128
129 #[cfg(feature = "std")]
130 return result.map(|r| (r, fds));
131 #[cfg(not(feature = "std"))]
132 return result;
133 }
134
135 /// Receive a method call over the socket.
136 ///
137 /// The generic `Method` is the type of the method name and its input parameters. This should be
138 /// a type that can deserialize itself from a complete method call message, i-e an object
139 /// containing `method` and `parameter` fields. This can be easily achieved using the
140 /// `serde::Deserialize` derive (See the code snippet in [`super::WriteConnection::send_call`]
141 /// documentation for an example).
142 ///
143 /// Returns the call and any file descriptors received (std only).
144 pub async fn receive_call<'m, Method>(&'m mut self) -> Result<RecvResult<Call<Method>>>
145 where
146 Method: Deserialize<'m> + Debug,
147 {
148 self.read_message::<Call<Method>>().await
149 }
150
151 // Reads at least one full message from the socket and return a single message bytes.
152 async fn read_message<'m, M>(&'m mut self) -> Result<RecvResult<M>>
153 where
154 M: Deserialize<'m> + Debug,
155 {
156 self.read_from_socket().await?;
157
158 let mut stream = Deserializer::from_slice(&self.buffer[self.msg_pos..]).into_iter::<M>();
159 let msg = stream.next();
160 let null_index = self.msg_pos + stream.byte_offset();
161 let buffer = &self.buffer[self.msg_pos..null_index];
162 if self.buffer[null_index + 1] == b'\0' {
163 // This means we're reading the last message and can now reset the indices.
164 self.read_pos = 0;
165 self.msg_pos = 0;
166 } else {
167 self.msg_pos = null_index + 1;
168 }
169
170 match msg {
171 Some(Ok(msg)) => {
172 // SAFETY: Since the parsing from JSON already succeeded, we can be sure that the
173 // buffer contains a valid UTF-8 string.
174 trace!("connection {}: received a message: {}", self.id, unsafe {
175 from_utf8_unchecked(buffer)
176 });
177
178 #[cfg(feature = "std")]
179 {
180 let fds = self.pending_fds.pop_front().unwrap_or_default();
181 Ok((msg, fds))
182 }
183 #[cfg(not(feature = "std"))]
184 Ok(msg)
185 }
186 Some(Err(e)) => Err(e.into()),
187 None => Err(crate::Error::UnexpectedEof),
188 }
189 }
190
191 // Reads at least one full message from the socket.
192 async fn read_from_socket(&mut self) -> Result<()> {
193 if self.msg_pos > 0 {
194 // This means we already have at least one message in the buffer so no need to read.
195 return Ok(());
196 }
197
198 loop {
199 #[cfg(feature = "std")]
200 let (bytes_read, fds) = self.socket.read(&mut self.buffer[self.read_pos..]).await?;
201 #[cfg(not(feature = "std"))]
202 let bytes_read = self.socket.read(&mut self.buffer[self.read_pos..]).await?;
203
204 if bytes_read == 0 {
205 return Err(crate::Error::UnexpectedEof);
206 }
207 self.read_pos += bytes_read;
208 #[cfg(feature = "std")]
209 if !fds.is_empty() {
210 self.pending_fds.push_back(fds);
211 // Track receipt so `Connection` can drain `WriteConnection::held_fds`.
212 #[cfg(target_os = "macos")]
213 {
214 self.fd_recvs += 1;
215 }
216 }
217
218 if self.read_pos == self.buffer.len() {
219 if self.read_pos >= MAX_BUFFER_SIZE {
220 return Err(crate::Error::BufferOverflow);
221 }
222
223 self.buffer.extend(core::iter::repeat_n(0, BUFFER_SIZE));
224 }
225
226 // This marks end of all messages. After this loop is finished, we'll have 2 consecutive
227 // null bytes at the end. This is then used by the callers to determine that they've
228 // read all messages and can now reset the `read_pos`.
229 self.buffer[self.read_pos] = b'\0';
230
231 if self.buffer[self.read_pos - 1] == b'\0' {
232 // One or more full messages were read.
233 break;
234 }
235 }
236
237 Ok(())
238 }
239
240 /// The underlying read half of the socket.
241 pub fn read_half(&self) -> &Read {
242 &self.socket
243 }
244}