Skip to main content

stomp_agnostic/transport/
server.rs

1use crate::frame::parse_frame;
2use crate::transport::{ReadData, ReadError, WriteError};
3use crate::{FromServer, Message, ToServer};
4use async_trait::async_trait;
5use bytes::{Buf, Bytes, BytesMut};
6use std::fmt::Debug;
7use winnow::Partial;
8use winnow::error::ErrMode;
9use winnow::stream::Offset;
10
11#[async_trait]
12pub trait ServerTransport {
13    /// A side channel to shuffle arbitrary data that is not part of the STOMP communication,
14    /// e.g. WebSocket Ping/Pong.
15    type ProtocolSideChannel;
16
17    async fn write(&mut self, message: Message<FromServer>) -> Result<(), WriteError>;
18    async fn read(&mut self) -> Result<ReadData<Self::ProtocolSideChannel>, ReadError>;
19}
20
21/// A parsed response, either a [Message] coming from the server, or a custom protocol signal
22/// in the `Custom` variant.
23#[derive(Debug)]
24pub enum ClientData<T>
25where
26    T: Debug,
27{
28    Message(Message<ToServer>),
29    Custom(T),
30}
31
32pub(crate) struct BufferedTransport<T>
33where
34    T: ServerTransport,
35    T::ProtocolSideChannel: Debug,
36{
37    transport: T,
38    buffer: BytesMut,
39}
40
41impl<T> BufferedTransport<T>
42where
43    T: ServerTransport,
44    T::ProtocolSideChannel: Debug,
45{
46    pub(crate) fn new(transport: T) -> Self {
47        Self {
48            transport,
49            buffer: BytesMut::with_capacity(4096),
50        }
51    }
52
53    fn append(&mut self, data: Bytes) {
54        self.buffer.extend_from_slice(&data);
55    }
56
57    fn decode(&mut self) -> Result<Option<Message<ToServer>>, ReadError> {
58        // Create a partial view of the buffer for parsing
59        let buf = &mut Partial::new(self.buffer.chunk());
60
61        // Attempt to parse a frame from the buffer
62        let item = match parse_frame(buf) {
63            Ok(frame) => Message::<ToServer>::from_frame(frame),
64            // Need more data
65            Err(ErrMode::Incomplete(_)) => return Ok(None),
66            Err(e) => return Err(ReadError::Parser(e)),
67        };
68
69        // Calculate how many bytes were consumed
70        let len = buf.offset_from(&Partial::new(self.buffer.chunk()));
71
72        // Advance the buffer past the consumed bytes
73        self.buffer.advance(len);
74
75        // Return the parsed message (or error)
76        item.map_err(|e| e.into()).map(Some)
77    }
78
79    pub(crate) async fn send(&mut self, message: Message<FromServer>) -> Result<(), WriteError> {
80        self.transport.write(message).await
81    }
82
83    pub(crate) async fn next(&mut self) -> Result<ClientData<T::ProtocolSideChannel>, ReadError> {
84        loop {
85            let response = self.transport.read().await?;
86            match response {
87                ReadData::Binary(buffer) => {
88                    self.append(buffer);
89                }
90                ReadData::Custom(custom) => {
91                    return Ok(ClientData::Custom(custom));
92                }
93            }
94
95            if let Some(message) = self.decode()? {
96                return Ok(ClientData::Message(message));
97            }
98        }
99    }
100
101    pub(crate) fn into_transport(self) -> T {
102        self.transport
103    }
104
105    pub(crate) fn as_mut_inner(&mut self) -> &mut T {
106        &mut self.transport
107    }
108}