1use core::fmt::Debug;
4
5use mayheap::Vec;
6use serde::Serialize;
7
8use super::{socket::WriteHalf, Call, Reply, BUFFER_SIZE};
9
10#[derive(Debug)]
19pub struct WriteConnection<Write: WriteHalf> {
20 socket: Write,
21 buffer: Vec<u8, BUFFER_SIZE>,
22 pos: usize,
23 id: usize,
24}
25
26impl<Write: WriteHalf> WriteConnection<Write> {
27 pub(super) fn new(socket: Write, id: usize) -> Self {
29 Self {
30 socket,
31 id,
32 buffer: Vec::from_slice(&[0; BUFFER_SIZE]).unwrap(),
33 pos: 0,
34 }
35 }
36
37 #[inline]
39 pub fn id(&self) -> usize {
40 self.id
41 }
42
43 pub async fn send_call<Method>(&mut self, call: &Call<Method>) -> crate::Result<()>
65 where
66 Method: Serialize + Debug,
67 {
68 trace!("connection {}: sending call: {:?}", self.id, call);
69 self.write(call).await
70 }
71
72 pub async fn send_reply<Params>(&mut self, reply: &Reply<Params>) -> crate::Result<()>
77 where
78 Params: Serialize + Debug,
79 {
80 trace!("connection {}: sending reply: {:?}", self.id, reply);
81 self.write(reply).await
82 }
83
84 pub async fn send_error<ReplyError>(&mut self, error: &ReplyError) -> crate::Result<()>
91 where
92 ReplyError: Serialize + Debug,
93 {
94 trace!("connection {}: sending error: {:?}", self.id, error);
95 self.write(error).await
96 }
97
98 pub fn enqueue_call<Method>(&mut self, call: &Call<Method>) -> crate::Result<()>
104 where
105 Method: Serialize + Debug,
106 {
107 trace!("connection {}: enqueuing call: {:?}", self.id, call);
108 self.enqueue(call)
109 }
110
111 pub async fn flush(&mut self) -> crate::Result<()> {
113 if self.pos == 0 {
114 return Ok(());
115 }
116
117 trace!("connection {}: flushing {} bytes", self.id, self.pos);
118 self.socket.write(&self.buffer[..self.pos]).await?;
119 self.pos = 0;
120 Ok(())
121 }
122
123 pub fn write_half(&self) -> &Write {
125 &self.socket
126 }
127
128 async fn write<T>(&mut self, value: &T) -> crate::Result<()>
129 where
130 T: Serialize + ?Sized + Debug,
131 {
132 self.enqueue(value)?;
133 self.flush().await
134 }
135
136 fn enqueue<T>(&mut self, value: &T) -> crate::Result<()>
137 where
138 T: Serialize + ?Sized + Debug,
139 {
140 let len = loop {
141 match to_slice_at_pos(value, &mut self.buffer, self.pos) {
142 Ok(len) => break len,
143 #[cfg(feature = "std")]
144 Err(crate::Error::Json(e)) if e.is_io() => {
145 self.grow_buffer()?;
148 }
149 Err(e) => return Err(e),
150 }
151 };
152
153 if self.pos + len == self.buffer.len() {
155 #[cfg(feature = "std")]
156 {
157 self.grow_buffer()?;
158 }
159 #[cfg(not(feature = "std"))]
160 {
161 return Err(crate::Error::BufferOverflow);
162 }
163 }
164 self.buffer[self.pos + len] = b'\0';
165 self.pos += len + 1;
166 Ok(())
167 }
168
169 #[cfg(feature = "std")]
170 fn grow_buffer(&mut self) -> crate::Result<()> {
171 if self.buffer.len() >= super::MAX_BUFFER_SIZE {
172 return Err(crate::Error::BufferOverflow);
173 }
174
175 self.buffer.extend_from_slice(&[0; BUFFER_SIZE])?;
176
177 Ok(())
178 }
179}
180
181fn to_slice_at_pos<T>(value: &T, buf: &mut [u8], pos: usize) -> crate::Result<usize>
182where
183 T: Serialize + ?Sized,
184{
185 #[cfg(feature = "std")]
186 {
187 let mut cursor = std::io::Cursor::new(&mut buf[pos..]);
188 serde_json::to_writer(&mut cursor, value)?;
189
190 Ok(cursor.position() as usize)
191 }
192
193 #[cfg(not(feature = "std"))]
194 {
195 serde_json_core::to_slice(value, &mut buf[pos..]).map_err(Into::into)
196 }
197}
198
#[cfg(test)]
mod tests {
    use super::*;

    use crate::test_utils::mock_socket::TestWriteHalf;

    #[tokio::test]
    async fn write() {
        // Bytes produced by writing a JSON array of `BUFFER_SIZE` zeros.
        const WRITE_LEN: usize =
            // The `0` digits.
            BUFFER_SIZE +
            // The commas separating them.
            (BUFFER_SIZE - 1) +
            // The `[` and `]` brackets.
            2 +
            // The null terminator.
            1;
        let mut write_conn = WriteConnection::new(TestWriteHalf::new(WRITE_LEN), 1);
        // Its JSON form is larger than the initial `BUFFER_SIZE`-byte buffer.
        let item: Vec<u8, BUFFER_SIZE> = Vec::from_slice(&[0u8; BUFFER_SIZE]).unwrap();
        let res = write_conn.write(&item).await;
        #[cfg(feature = "std")]
        {
            // With `std` the buffer grows until the message and its terminator fit.
            res.unwrap();
            assert_eq!(write_conn.buffer.len(), BUFFER_SIZE * 3);
            assert_eq!(write_conn.pos, 0);
        }
        #[cfg(not(feature = "std"))]
        {
            // Without `std` the buffer is fixed, so serialization runs out of space.
            assert!(matches!(
                res,
                Err(crate::Error::JsonSerialize(
                    serde_json_core::ser::Error::BufferFull
                ))
            ));
            assert_eq!(write_conn.buffer.len(), BUFFER_SIZE);
        }
    }

    #[tokio::test]
    async fn enqueue_and_flush() {
        // "42\0" followed by "3\0": 5 bytes in total.
        let mut write_conn = WriteConnection::new(TestWriteHalf::new(5), 1);
        write_conn.enqueue(&42u32).unwrap();
        write_conn.enqueue(&3u32).unwrap();
        assert_eq!(write_conn.pos, 5);

        write_conn.flush().await.unwrap();
        assert_eq!(write_conn.pos, 0);
    }

    #[tokio::test]
    async fn enqueue_null_terminators() {
        // "1\0" followed by "2\0": 4 bytes in total.
        let mut write_conn = WriteConnection::new(TestWriteHalf::new(4), 1);

        write_conn.enqueue(&1u32).unwrap();
        assert_eq!(write_conn.buffer[write_conn.pos - 1], b'\0');

        write_conn.enqueue(&2u32).unwrap();
        assert_eq!(write_conn.buffer[write_conn.pos - 1], b'\0');

        write_conn.flush().await.unwrap();
    }

    #[cfg(feature = "std")]
    #[tokio::test]
    async fn enqueue_buffer_extension() {
        let mut write_conn = WriteConnection::new(TestWriteHalf::new(0), 1);
        let initial_len = write_conn.buffer.len();

        // Serializes to more than `BUFFER_SIZE` bytes, forcing the buffer to grow.
        let large_item: Vec<u8, BUFFER_SIZE> = Vec::from_slice(&[0u8; BUFFER_SIZE]).unwrap();
        write_conn.enqueue(&large_item).unwrap();

        assert!(write_conn.buffer.len() > initial_len);
    }

    #[cfg(not(feature = "std"))]
    #[tokio::test]
    async fn enqueue_buffer_overflow() {
        let mut write_conn = WriteConnection::new(TestWriteHalf::new(0), 1);

        // Serializes to more than `BUFFER_SIZE` bytes, which the fixed buffer cannot hold.
        let large_item: Vec<u8, BUFFER_SIZE> = Vec::from_slice(&[0u8; BUFFER_SIZE]).unwrap();
        let res = write_conn.enqueue(&large_item);

        assert!(matches!(
            res,
            Err(crate::Error::JsonSerialize(
                serde_json_core::ser::Error::BufferFull
            ))
        ));
    }

    #[tokio::test]
    async fn flush_empty_buffer() {
        let mut write_conn = WriteConnection::new(TestWriteHalf::new(0), 1);

        // Nothing is enqueued, so flushing must not write anything.
        write_conn.flush().await.unwrap();
        assert_eq!(write_conn.pos, 0);
    }

    #[tokio::test]
    async fn multiple_flushes() {
        let mut write_conn = WriteConnection::new(TestWriteHalf::new(2), 1);
        write_conn.enqueue(&1u32).unwrap();

        write_conn.flush().await.unwrap();
        assert_eq!(write_conn.pos, 0);

        // A second flush finds nothing left to write.
        write_conn.flush().await.unwrap();
        assert_eq!(write_conn.pos, 0);
    }

    #[tokio::test]
    async fn enqueue_after_flush() {
        let mut write_conn = WriteConnection::new(TestWriteHalf::new(2), 1);
        write_conn.enqueue(&1u32).unwrap();
        write_conn.flush().await.unwrap();

        // After a flush, enqueuing starts again at the beginning of the buffer.
        write_conn.enqueue(&2u32).unwrap();
        assert_eq!(write_conn.pos, 2);

        write_conn.flush().await.unwrap();
        assert_eq!(write_conn.pos, 0);
    }

    #[tokio::test]
    async fn call_pipelining() {
        use super::super::Call;
        use serde::{Deserialize, Serialize};

        #[derive(Debug, Serialize, Deserialize)]
        struct TestMethod {
            name: &'static str,
            value: u32,
        }

        let mut write_conn = WriteConnection::new(TestWriteHalf::new(0), 1);

        let call1 = Call::new(TestMethod {
            name: "method1",
            value: 1,
        });
        write_conn.enqueue_call(&call1).unwrap();

        let call2 = Call::new(TestMethod {
            name: "method2",
            value: 2,
        });
        write_conn.enqueue_call(&call2).unwrap();

        let call3 = Call::new(TestMethod {
            name: "method3",
            value: 3,
        });
        write_conn.enqueue_call(&call3).unwrap();

        assert!(write_conn.pos > 0);

        let buffer = &write_conn.buffer[..write_conn.pos];

        // Record where each null terminator sits; there must be exactly one per call.
        let mut null_positions = [0usize; 3];
        let mut null_count = 0;
        for (i, &byte) in buffer.iter().enumerate() {
            if byte == b'\0' {
                assert!(null_count < 3, "Found more than 3 null terminators");
                null_positions[null_count] = i;
                null_count += 1;
            }
        }

        assert_eq!(null_count, 3);

        // Every slot of `null_positions` is filled (asserted just above), so
        // iterate it directly. Each terminator must follow the end of a JSON value.
        for &pos in &null_positions {
            assert!(
                pos > 0,
                "Null terminator at position {pos} should not be at start"
            );
            let preceding_byte = buffer[pos - 1];
            assert!(
                preceding_byte == b'}' || preceding_byte == b'"' || preceding_byte.is_ascii_digit(),
                "Null terminator at position {pos} should be after valid JSON ending, found byte: {preceding_byte}"
            );
        }

        // The last call's terminator is the final enqueued byte.
        assert_eq!(null_positions[2], write_conn.pos - 1);
    }

    #[tokio::test]
    async fn pipelining_vs_individual_sends() {
        use super::super::Call;
        use serde::{Deserialize, Serialize};

        #[derive(Debug, Serialize, Deserialize)]
        struct TestMethod {
            operation: &'static str,
            id: u32,
        }

        use crate::test_utils::mock_socket::CountingWriteHalf;

        // Sending each call individually writes to the socket once per call.
        let counting_write = CountingWriteHalf::new();
        let mut write_conn_individual = WriteConnection::new(counting_write, 1);

        for i in 1..=3 {
            let call = Call::new(TestMethod {
                operation: "fetch",
                id: i,
            });
            write_conn_individual.send_call(&call).await.unwrap();
        }
        assert_eq!(write_conn_individual.socket.count(), 3);

        // Enqueuing the calls and flushing once writes to the socket a single time.
        let counting_write = CountingWriteHalf::new();
        let mut write_conn_pipelined = WriteConnection::new(counting_write, 2);

        for i in 1..=3 {
            let call = Call::new(TestMethod {
                operation: "fetch",
                id: i,
            });
            write_conn_pipelined.enqueue_call(&call).unwrap();
        }
        write_conn_pipelined.flush().await.unwrap();
        assert_eq!(write_conn_pipelined.socket.count(), 1);
    }
}