zlink_core/connection/
read_connection.rs1use core::{fmt::Debug, str::from_utf8_unchecked};
4
5use crate::{varlink_service, Result};
6
7use super::{
8 reply::{self, Reply},
9 socket::ReadHalf,
10 Call, BUFFER_SIZE, MAX_BUFFER_SIZE,
11};
12#[cfg(feature = "std")]
13use alloc::collections::VecDeque;
14use alloc::vec::Vec;
15use serde::Deserialize;
16use serde_json::Deserializer;
17
18#[cfg(feature = "std")]
19use std::os::fd::OwnedFd;
20
/// Value produced by the receive methods when `std` is enabled: the deserialized message
/// together with the owned file descriptors handed out alongside it.
#[cfg(feature = "std")]
type RecvResult<Msg> = (Msg, Vec<OwnedFd>);
/// Value produced by the receive methods without `std`: just the deserialized message.
#[cfg(not(feature = "std"))]
type RecvResult<Msg> = Msg;
26
27#[derive(Debug)]
34pub struct ReadConnection<Read: ReadHalf> {
35 socket: Read,
36 read_pos: usize,
37 msg_pos: usize,
38 buffer: Vec<u8>,
39 id: usize,
40 #[cfg(feature = "std")]
41 pending_fds: VecDeque<Vec<OwnedFd>>,
42}
43
44impl<Read: ReadHalf> ReadConnection<Read> {
45 pub(super) fn new(socket: Read, id: usize) -> Self {
47 Self {
48 socket,
49 read_pos: 0,
50 msg_pos: 0,
51 id,
52 buffer: alloc::vec![0; BUFFER_SIZE],
53 #[cfg(feature = "std")]
54 pending_fds: VecDeque::new(),
55 }
56 }
57
58 #[inline]
60 pub fn id(&self) -> usize {
61 self.id
62 }
63
64 pub async fn receive_reply<'r, ReplyParams, ReplyError>(
92 &'r mut self,
93 ) -> Result<RecvResult<reply::Result<ReplyParams, ReplyError>>>
94 where
95 ReplyParams: Deserialize<'r> + Debug,
96 ReplyError: Deserialize<'r> + Debug,
97 {
98 #[derive(Debug, Deserialize)]
99 #[serde(untagged)]
100 enum ReplyMsg<'m, ReplyParams, ReplyError> {
101 #[serde(borrow)]
102 Varlink(varlink_service::Error<'m>),
103 Error(ReplyError),
104 Reply(Reply<ReplyParams>),
105 }
106
107 let recv_result = self
108 .read_message::<ReplyMsg<'_, ReplyParams, ReplyError>>()
109 .await?;
110
111 #[cfg(feature = "std")]
112 let (msg, fds) = recv_result;
113 #[cfg(not(feature = "std"))]
114 let msg = recv_result;
115
116 let result = match msg {
117 ReplyMsg::Varlink(e) => Err(crate::Error::VarlinkService(e.into_owned())),
119 ReplyMsg::Error(e) => Ok(Err(e)),
120 ReplyMsg::Reply(reply) => Ok(Ok(reply)),
121 };
122
123 #[cfg(feature = "std")]
124 return result.map(|r| (r, fds));
125 #[cfg(not(feature = "std"))]
126 return result;
127 }
128
129 pub async fn receive_call<'m, Method>(&'m mut self) -> Result<RecvResult<Call<Method>>>
139 where
140 Method: Deserialize<'m> + Debug,
141 {
142 self.read_message::<Call<Method>>().await
143 }
144
145 async fn read_message<'m, M>(&'m mut self) -> Result<RecvResult<M>>
147 where
148 M: Deserialize<'m> + Debug,
149 {
150 self.read_from_socket().await?;
151
152 let mut stream = Deserializer::from_slice(&self.buffer[self.msg_pos..]).into_iter::<M>();
153 let msg = stream.next();
154 let null_index = self.msg_pos + stream.byte_offset();
155 let buffer = &self.buffer[self.msg_pos..null_index];
156 if self.buffer[null_index + 1] == b'\0' {
157 self.read_pos = 0;
159 self.msg_pos = 0;
160 } else {
161 self.msg_pos = null_index + 1;
162 }
163
164 match msg {
165 Some(Ok(msg)) => {
166 trace!("connection {}: received a message: {}", self.id, unsafe {
169 from_utf8_unchecked(buffer)
170 });
171
172 #[cfg(feature = "std")]
173 {
174 let fds = self.pending_fds.pop_front().unwrap_or_default();
175 Ok((msg, fds))
176 }
177 #[cfg(not(feature = "std"))]
178 Ok(msg)
179 }
180 Some(Err(e)) => Err(e.into()),
181 None => Err(crate::Error::UnexpectedEof),
182 }
183 }
184
185 async fn read_from_socket(&mut self) -> Result<()> {
187 if self.msg_pos > 0 {
188 return Ok(());
190 }
191
192 loop {
193 #[cfg(feature = "std")]
194 let (bytes_read, fds) = self.socket.read(&mut self.buffer[self.read_pos..]).await?;
195 #[cfg(not(feature = "std"))]
196 let bytes_read = self.socket.read(&mut self.buffer[self.read_pos..]).await?;
197
198 if bytes_read == 0 {
199 return Err(crate::Error::UnexpectedEof);
200 }
201 self.read_pos += bytes_read;
202 #[cfg(feature = "std")]
203 if !fds.is_empty() {
204 self.pending_fds.push_back(fds);
205 }
206
207 if self.read_pos == self.buffer.len() {
208 if self.read_pos >= MAX_BUFFER_SIZE {
209 return Err(crate::Error::BufferOverflow);
210 }
211
212 self.buffer.extend(core::iter::repeat_n(0, BUFFER_SIZE));
213 }
214
215 self.buffer[self.read_pos] = b'\0';
219
220 if self.buffer[self.read_pos - 1] == b'\0' {
221 break;
223 }
224 }
225
226 Ok(())
227 }
228
229 pub fn read_half(&self) -> &Read {
231 &self.socket
232 }
233}