zlink_core/connection/write_connection.rs

//! Contains the connection-related API.

use core::fmt::Debug;

use mayheap::Vec;
use serde::Serialize;

use super::{socket::WriteHalf, Call, Reply, BUFFER_SIZE};

/// The write half of a connection.
///
/// The low-level API to send messages.
///
/// # Cancel safety
///
/// All async methods of this type are cancel safe unless explicitly stated otherwise in its
/// documentation.
#[derive(Debug)]
pub struct WriteConnection<Write: WriteHalf> {
    socket: Write,
    buffer: Vec<u8, BUFFER_SIZE>,
    pos: usize,
    id: usize,
}

impl<Write: WriteHalf> WriteConnection<Write> {
    /// Create a new connection.
    pub(super) fn new(socket: Write, id: usize) -> Self {
        Self {
            socket,
            id,
            buffer: Vec::from_slice(&[0; BUFFER_SIZE]).unwrap(),
            pos: 0,
        }
    }

    /// The unique identifier of the connection.
    #[inline]
    pub fn id(&self) -> usize {
        self.id
    }

    /// Send a method call.
    ///
    /// The generic `Method` is the type of the method name and its input parameters. This should
    /// be a type that can serialize itself to a complete method call message, i.e. an object
    /// containing `method` and `parameters` fields. This can be easily achieved using the
    /// `serde::Serialize` derive:
    ///
    /// ```rust
    /// use serde::{Serialize, Deserialize};
    /// use serde_prefix_all::prefix_all;
    ///
    /// #[prefix_all("org.example.ftl.")]
    /// #[derive(Debug, Serialize, Deserialize)]
    /// #[serde(tag = "method", content = "parameters")]
    /// enum MyMethods<'m> {
    ///     // The name needs to be the fully-qualified name of the method.
    ///     Alpha { param1: u32, param2: &'m str },
    ///     Bravo,
    ///     Charlie { param1: &'m str },
    /// }
    /// ```
    pub async fn send_call<Method>(&mut self, call: &Call<Method>) -> crate::Result<()>
    where
        Method: Serialize + Debug,
    {
        trace!("connection {}: sending call: {:?}", self.id, call);
        self.write(call).await
    }

    /// Send a reply over the socket.
    ///
    /// The generic parameter `Params` is the type of the parameters of a successful reply. This
    /// should be a type that can serialize itself as the `parameters` field of the reply.
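    ///
    /// For example, a purely illustrative parameters type (the `Coordinates` struct below is not
    /// part of this crate) could be a plain serializable struct:
    ///
    /// ```rust
    /// use serde::{Serialize, Deserialize};
    ///
    /// #[derive(Debug, Serialize, Deserialize)]
    /// struct Coordinates {
    ///     latitude: f64,
    ///     longitude: f64,
    /// }
    /// ```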
    pub async fn send_reply<Params>(&mut self, reply: &Reply<Params>) -> crate::Result<()>
    where
        Params: Serialize + Debug,
    {
        trace!("connection {}: sending reply: {:?}", self.id, reply);
        self.write(reply).await
    }

    /// Send an error reply over the socket.
    ///
    /// The generic parameter `ReplyError` is the type of the error reply. This should be a type
    /// that can serialize itself to the whole reply object, containing `error` and `parameters`
    /// fields. This can be easily achieved using the `serde::Serialize` derive (see the code
    /// snippet in the [`super::ReadConnection::receive_reply`] documentation for an example).
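    ///
    /// As a rough sketch along the same lines as the method-call example above (the `MyError`
    /// enum here is illustrative and not part of this crate):
    ///
    /// ```rust
    /// use serde::{Serialize, Deserialize};
    /// use serde_prefix_all::prefix_all;
    ///
    /// #[prefix_all("org.example.ftl.")]
    /// #[derive(Debug, Serialize, Deserialize)]
    /// #[serde(tag = "error", content = "parameters")]
    /// enum MyError<'e> {
    ///     // The name needs to be the fully-qualified name of the error.
    ///     NotFound { key: &'e str },
    ///     PermissionDenied,
    /// }
    /// ```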
    pub async fn send_error<ReplyError>(&mut self, error: &ReplyError) -> crate::Result<()>
    where
        ReplyError: Serialize + Debug,
    {
        trace!("connection {}: sending error: {:?}", self.id, error);
        self.write(error).await
    }

    /// Enqueue a call to be sent over the socket.
    ///
    /// Similar to [`WriteConnection::send_call`], except that the call is not sent immediately but
    /// enqueued for later sending. This is useful when you want to send multiple calls in a
    /// batch. Use [`WriteConnection::flush`] to actually send the enqueued calls.
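    ///
    /// A typical batching pattern might look like this (a non-compiled sketch, since it assumes
    /// an already-established `conn` and pre-built `call1`/`call2` values):
    ///
    /// ```rust,ignore
    /// conn.enqueue_call(&call1)?;
    /// conn.enqueue_call(&call2)?;
    /// // Nothing goes over the socket until `flush` is called.
    /// conn.flush().await?;
    /// ```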
    pub fn enqueue_call<Method>(&mut self, call: &Call<Method>) -> crate::Result<()>
    where
        Method: Serialize + Debug,
    {
        trace!("connection {}: enqueuing call: {:?}", self.id, call);
        self.enqueue(call)
    }

    /// Send out the enqueued calls.
    pub async fn flush(&mut self) -> crate::Result<()> {
        if self.pos == 0 {
            return Ok(());
        }

        trace!("connection {}: flushing {} bytes", self.id, self.pos);
        self.socket.write(&self.buffer[..self.pos]).await?;
        self.pos = 0;
        Ok(())
    }

    /// The underlying write half of the socket.
    pub fn write_half(&self) -> &Write {
        &self.socket
    }

    async fn write<T>(&mut self, value: &T) -> crate::Result<()>
    where
        T: Serialize + ?Sized + Debug,
    {
        self.enqueue(value)?;
        self.flush().await
    }

    fn enqueue<T>(&mut self, value: &T) -> crate::Result<()>
    where
        T: Serialize + ?Sized + Debug,
    {
        let len = loop {
            match to_slice_at_pos(value, &mut self.buffer, self.pos) {
                Ok(len) => break len,
                #[cfg(feature = "std")]
                Err(crate::Error::Json(e)) if e.is_io() => {
                    // This can only happen if `serde-json` failed to write all the bytes, which
                    // means we're running out of space or are already out of space.
                    self.grow_buffer()?;
                }
                Err(e) => return Err(e),
            }
        };

        // Add a null terminator after this message, growing the buffer first if it is full.
        if self.pos + len == self.buffer.len() {
            #[cfg(feature = "std")]
            {
                self.grow_buffer()?;
            }
            #[cfg(not(feature = "std"))]
            {
                return Err(crate::Error::BufferOverflow);
            }
        }
        self.buffer[self.pos + len] = b'\0';
        self.pos += len + 1;
        Ok(())
    }

    #[cfg(feature = "std")]
    fn grow_buffer(&mut self) -> crate::Result<()> {
        if self.buffer.len() >= super::MAX_BUFFER_SIZE {
            return Err(crate::Error::BufferOverflow);
        }

        self.buffer.extend_from_slice(&[0; BUFFER_SIZE])?;

        Ok(())
    }
}

fn to_slice_at_pos<T>(value: &T, buf: &mut [u8], pos: usize) -> crate::Result<usize>
where
    T: Serialize + ?Sized,
{
    #[cfg(feature = "std")]
    {
        let mut cursor = std::io::Cursor::new(&mut buf[pos..]);
        serde_json::to_writer(&mut cursor, value)?;

        Ok(cursor.position() as usize)
    }

    #[cfg(not(feature = "std"))]
    {
        serde_json_core::to_slice(value, &mut buf[pos..]).map_err(Into::into)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    use crate::test_utils::mock_socket::TestWriteHalf;

    #[tokio::test]
    async fn write() {
        const WRITE_LEN: usize =
            // Every `0u8` is one byte.
            BUFFER_SIZE +
            // `,` separators.
            (BUFFER_SIZE - 1) +
            // `[` and `]`.
            2 +
            // null byte from enqueue.
            1;
        let mut write_conn = WriteConnection::new(TestWriteHalf::new(WRITE_LEN), 1);
        // An item that serializes into `> BUFFER_SIZE * 2` bytes.
        let item: Vec<u8, BUFFER_SIZE> = Vec::from_slice(&[0u8; BUFFER_SIZE]).unwrap();
        let res = write_conn.write(&item).await;
        #[cfg(feature = "std")]
        {
            res.unwrap();
            assert_eq!(write_conn.buffer.len(), BUFFER_SIZE * 3);
            assert_eq!(write_conn.pos, 0); // Reset after flush.
        }
        #[cfg(not(feature = "std"))]
        {
            assert!(matches!(
                res,
                Err(crate::Error::JsonSerialize(
                    serde_json_core::ser::Error::BufferFull
                ))
            ));
            assert_eq!(write_conn.buffer.len(), BUFFER_SIZE);
        }
    }

    #[tokio::test]
    async fn enqueue_and_flush() {
        // Test enqueuing multiple small items.
        let mut write_conn = WriteConnection::new(TestWriteHalf::new(5), 1); // "42\03\0"

        write_conn.enqueue(&42u32).unwrap();
        write_conn.enqueue(&3u32).unwrap();
        assert_eq!(write_conn.pos, 5); // "42\03\0"

        write_conn.flush().await.unwrap();
        assert_eq!(write_conn.pos, 0); // Reset after flush.
    }

    #[tokio::test]
    async fn enqueue_null_terminators() {
        // Test that null terminators are properly placed.
        let mut write_conn = WriteConnection::new(TestWriteHalf::new(4), 1); // "1\02\0"

        write_conn.enqueue(&1u32).unwrap();
        assert_eq!(write_conn.buffer[write_conn.pos - 1], b'\0');

        write_conn.enqueue(&2u32).unwrap();
        assert_eq!(write_conn.buffer[write_conn.pos - 1], b'\0');

        write_conn.flush().await.unwrap();
    }

    #[cfg(feature = "std")]
    #[tokio::test]
    async fn enqueue_buffer_extension() {
        // Test buffer extension when enqueuing large items.
        let mut write_conn = WriteConnection::new(TestWriteHalf::new(0), 1);
        let initial_len = write_conn.buffer.len();

        // Fill up the buffer.
        let large_item: Vec<u8, BUFFER_SIZE> = Vec::from_slice(&[0u8; BUFFER_SIZE]).unwrap();
        write_conn.enqueue(&large_item).unwrap();

        assert!(write_conn.buffer.len() > initial_len);
    }

    #[cfg(not(feature = "std"))]
    #[tokio::test]
    async fn enqueue_buffer_overflow() {
        // Test buffer overflow error without std feature.
        let mut write_conn = WriteConnection::new(TestWriteHalf::new(0), 1);

        // Try to enqueue an item that doesn't fit.
        let large_item: Vec<u8, BUFFER_SIZE> = Vec::from_slice(&[0u8; BUFFER_SIZE]).unwrap();
        let res = write_conn.enqueue(&large_item);

        assert!(matches!(
            res,
            Err(crate::Error::JsonSerialize(
                serde_json_core::ser::Error::BufferFull
            ))
        ));
    }

    #[tokio::test]
    async fn flush_empty_buffer() {
        // Test that flushing an empty buffer is a no-op.
        let mut write_conn = WriteConnection::new(TestWriteHalf::new(0), 1);

        // Should not call write since buffer is empty.
        write_conn.flush().await.unwrap();
        assert_eq!(write_conn.pos, 0);
    }

    #[tokio::test]
    async fn multiple_flushes() {
        // Test multiple flushes in a row.
        let mut write_conn = WriteConnection::new(TestWriteHalf::new(2), 1); // "1\0"

        write_conn.enqueue(&1u32).unwrap();
        write_conn.flush().await.unwrap();
        assert_eq!(write_conn.pos, 0);

        // Second flush should be a no-op.
        write_conn.flush().await.unwrap();
        assert_eq!(write_conn.pos, 0);
    }

    #[tokio::test]
    async fn enqueue_after_flush() {
        // Test that enqueuing works properly after a flush.
        let mut write_conn = WriteConnection::new(TestWriteHalf::new(2), 1); // "2\0"

        write_conn.enqueue(&1u32).unwrap();
        write_conn.flush().await.unwrap();

        // Should be able to enqueue again after flush.
        write_conn.enqueue(&2u32).unwrap();
        assert_eq!(write_conn.pos, 2); // "2\0"

        write_conn.flush().await.unwrap();
        assert_eq!(write_conn.pos, 0);
    }

    #[tokio::test]
    async fn call_pipelining() {
        use super::super::Call;
        use serde::{Deserialize, Serialize};

        #[derive(Debug, Serialize, Deserialize)]
        struct TestMethod {
            name: &'static str,
            value: u32,
        }

        let mut write_conn = WriteConnection::new(TestWriteHalf::new(0), 1);

        // Test pipelining multiple method calls.
        let call1 = Call::new(TestMethod {
            name: "method1",
            value: 1,
        });
        write_conn.enqueue_call(&call1).unwrap();

        let call2 = Call::new(TestMethod {
            name: "method2",
            value: 2,
        });
        write_conn.enqueue_call(&call2).unwrap();

        let call3 = Call::new(TestMethod {
            name: "method3",
            value: 3,
        });
        write_conn.enqueue_call(&call3).unwrap();

        assert!(write_conn.pos > 0);

        // Verify that all calls are properly queued with null terminators.
        let buffer = &write_conn.buffer[..write_conn.pos];
        let mut null_positions = [0usize; 3];
        let mut null_count = 0;

        for (i, &byte) in buffer.iter().enumerate() {
            if byte == b'\0' {
                assert!(null_count < 3, "Found more than 3 null terminators");
                null_positions[null_count] = i;
                null_count += 1;
            }
        }

        // Should have exactly 3 null terminators for 3 calls.
        assert_eq!(null_count, 3);

        // Verify each null terminator is at the end of a complete JSON object.
        for i in 0..null_count {
            let pos = null_positions[i];
            assert!(
                pos > 0,
                "Null terminator at position {pos} should not be at start"
            );
            let preceding_byte = buffer[pos - 1];
            assert!(
                preceding_byte == b'}' || preceding_byte == b'"' || preceding_byte.is_ascii_digit(),
                "Null terminator at position {pos} should be after valid JSON ending, found byte: {preceding_byte}"
            );
        }

        // Verify the last null terminator is at the very end.
        assert_eq!(null_positions[2], write_conn.pos - 1);
    }

    #[tokio::test]
    async fn pipelining_vs_individual_sends() {
        use super::super::Call;
        use serde::{Deserialize, Serialize};

        #[derive(Debug, Serialize, Deserialize)]
        struct TestMethod {
            operation: &'static str,
            id: u32,
        }

        // Use consolidated counting write half from test_utils.
        use crate::test_utils::mock_socket::CountingWriteHalf;

        // Test individual sends (3 write calls expected).
        let counting_write = CountingWriteHalf::new();
        let mut write_conn_individual = WriteConnection::new(counting_write, 1);

        for i in 1..=3 {
            let call = Call::new(TestMethod {
                operation: "fetch",
                id: i,
            });
            write_conn_individual.send_call(&call).await.unwrap();
        }
        assert_eq!(write_conn_individual.socket.count(), 3);

        // Test pipelined sends (1 write call expected).
        let counting_write = CountingWriteHalf::new();
        let mut write_conn_pipelined = WriteConnection::new(counting_write, 2);

        for i in 1..=3 {
            let call = Call::new(TestMethod {
                operation: "fetch",
                id: i,
            });
            write_conn_pipelined.enqueue_call(&call).unwrap();
        }
        write_conn_pipelined.flush().await.unwrap();
        assert_eq!(write_conn_pipelined.socket.count(), 1);
    }
}