// zlink_core/connection/write_connection.rs
use core::fmt::Debug;

#[cfg(feature = "std")]
use alloc::collections::VecDeque;
use alloc::vec::Vec;
#[cfg(feature = "std")]
use std::os::fd::OwnedFd;

use serde::Serialize;

use super::{socket::WriteHalf, Call, Reply, BUFFER_SIZE};

15#[derive(Debug)]
24pub struct WriteConnection<Write: WriteHalf> {
25 pub(super) socket: Write,
26 pub(super) buffer: Vec<u8>,
27 pub(super) pos: usize,
28 id: usize,
29 #[cfg(feature = "std")]
30 pending_fds: VecDeque<MessageFds>,
31}
32
33impl<Write: WriteHalf> WriteConnection<Write> {
34 pub(super) fn new(socket: Write, id: usize) -> Self {
36 Self {
37 socket,
38 id,
39 buffer: alloc::vec![0; BUFFER_SIZE],
40 pos: 0,
41 #[cfg(feature = "std")]
42 pending_fds: VecDeque::new(),
43 }
44 }
45
46 #[inline]
48 pub fn id(&self) -> usize {
49 self.id
50 }
51
52 pub async fn send_call<Method>(
76 &mut self,
77 call: &Call<Method>,
78 #[cfg(feature = "std")] fds: Vec<OwnedFd>,
79 ) -> crate::Result<()>
80 where
81 Method: Serialize + Debug,
82 {
83 trace!("connection {}: sending call: {:?}", self.id, call);
84 #[cfg(feature = "std")]
85 {
86 self.write(call, fds).await
87 }
88 #[cfg(not(feature = "std"))]
89 {
90 self.write(call).await
91 }
92 }
93
94 pub async fn send_reply<Params>(
101 &mut self,
102 reply: &Reply<Params>,
103 #[cfg(feature = "std")] fds: Vec<OwnedFd>,
104 ) -> crate::Result<()>
105 where
106 Params: Serialize + Debug,
107 {
108 trace!("connection {}: sending reply: {:?}", self.id, reply);
109 #[cfg(feature = "std")]
110 {
111 self.write(reply, fds).await
112 }
113 #[cfg(not(feature = "std"))]
114 {
115 self.write(reply).await
116 }
117 }
118
119 pub async fn send_error<ReplyError>(
128 &mut self,
129 error: &ReplyError,
130 #[cfg(feature = "std")] fds: Vec<OwnedFd>,
131 ) -> crate::Result<()>
132 where
133 ReplyError: Serialize + Debug,
134 {
135 trace!("connection {}: sending error: {:?}", self.id, error);
136 #[cfg(feature = "std")]
137 {
138 self.write(error, fds).await
139 }
140 #[cfg(not(feature = "std"))]
141 {
142 self.write(error).await
143 }
144 }
145
146 pub fn enqueue_call<Method>(
154 &mut self,
155 call: &Call<Method>,
156 #[cfg(feature = "std")] fds: Vec<OwnedFd>,
157 ) -> crate::Result<()>
158 where
159 Method: Serialize + Debug,
160 {
161 trace!("connection {}: enqueuing call: {:?}", self.id, call);
162 #[cfg(feature = "std")]
163 {
164 self.enqueue(call, fds)
165 }
166 #[cfg(not(feature = "std"))]
167 {
168 self.enqueue(call)
169 }
170 }
171
172 pub async fn flush(&mut self) -> crate::Result<()> {
174 if self.pos == 0 {
175 return Ok(());
176 }
177
178 #[allow(unused_mut)]
179 let mut sent_pos = 0;
180
181 #[cfg(feature = "std")]
182 {
183 while !self.pending_fds.is_empty() {
185 let pending = self.pending_fds.front().unwrap();
187 let fd_offset = pending.offset;
188 let msg_len = pending.len;
189
190 if sent_pos < fd_offset {
192 trace!(
193 "connection {}: flushing {} bytes before FD message",
194 self.id,
195 fd_offset - sent_pos
196 );
197 self.socket
198 .write(&self.buffer[sent_pos..fd_offset], &[] as &[OwnedFd])
199 .await?;
200 }
201
202 let msg_end = fd_offset + msg_len;
204 let pending = self.pending_fds.pop_front().unwrap();
205 let fds = &pending.fds;
206 trace!(
207 "connection {}: flushing {} bytes with {} FDs",
208 self.id,
209 msg_len,
210 fds.len()
211 );
212 self.socket
213 .write(&self.buffer[fd_offset..msg_end], fds)
214 .await?;
215 sent_pos = msg_end;
216 }
217 }
218
219 if sent_pos < self.pos {
221 trace!(
222 "connection {}: flushing {} bytes",
223 self.id,
224 self.pos - sent_pos
225 );
226 #[cfg(feature = "std")]
227 {
228 self.socket
229 .write(&self.buffer[sent_pos..self.pos], &[] as &[OwnedFd])
230 .await?;
231 }
232 #[cfg(not(feature = "std"))]
233 {
234 self.socket.write(&self.buffer[sent_pos..self.pos]).await?;
235 }
236 }
237
238 self.pos = 0;
239 Ok(())
240 }
241
242 pub fn write_half(&self) -> &Write {
244 &self.socket
245 }
246
247 pub(super) async fn write<T>(
248 &mut self,
249 value: &T,
250 #[cfg(feature = "std")] fds: Vec<OwnedFd>,
251 ) -> crate::Result<()>
252 where
253 T: Serialize + ?Sized + Debug,
254 {
255 #[cfg(feature = "std")]
256 {
257 self.enqueue(value, fds)?;
258 }
259 #[cfg(not(feature = "std"))]
260 {
261 self.enqueue(value)?;
262 }
263 self.flush().await
264 }
265
266 pub(super) fn enqueue<T>(
267 &mut self,
268 value: &T,
269 #[cfg(feature = "std")] fds: Vec<OwnedFd>,
270 ) -> crate::Result<()>
271 where
272 T: Serialize + ?Sized + Debug,
273 {
274 #[cfg(feature = "std")]
275 let start_pos = self.pos;
276
277 let len = loop {
278 match crate::json_ser::to_slice(value, &mut self.buffer[self.pos..]) {
279 Ok(len) => break len,
280 Err(crate::json_ser::Error::BufferTooSmall) => {
281 self.grow_buffer()?;
283 }
284 Err(crate::json_ser::Error::KeyMustBeAString) => {
285 return Err(crate::Error::Json(serde::ser::Error::custom(
288 "key must be a string",
289 )));
290 }
291 }
292 };
293
294 if self.pos + len == self.buffer.len() {
296 self.grow_buffer()?;
297 }
298 self.buffer[self.pos + len] = b'\0';
299 self.pos += len + 1;
300
301 #[cfg(feature = "std")]
303 if !fds.is_empty() {
304 self.pending_fds.push_back(MessageFds {
305 offset: start_pos,
306 len: len + 1, fds,
308 });
309 }
310
311 Ok(())
312 }
313
314 fn grow_buffer(&mut self) -> crate::Result<()> {
315 if self.buffer.len() >= super::MAX_BUFFER_SIZE {
316 return Err(crate::Error::BufferOverflow);
317 }
318
319 self.buffer.extend_from_slice(&[0; BUFFER_SIZE]);
320
321 Ok(())
322 }
323}
324
#[cfg(feature = "std")]
/// File descriptors queued to be sent together with one buffered message.
#[derive(Debug)]
struct MessageFds {
    /// The file descriptors to send with the message.
    fds: Vec<OwnedFd>,
    /// Offset of the message's first byte in the write buffer.
    offset: usize,
    /// Length of the message in bytes, including its NUL terminator.
    len: usize,
}