Skip to main content

volans_stream_select/
protocol.rs

1use bytes::{BufMut, Bytes, BytesMut};
2use futures::{AsyncRead, AsyncWrite, Sink, Stream, ready};
3use std::{
4    io,
5    pin::Pin,
6    task::{Context, Poll},
7};
8
9use crate::length_delimited::{LengthDelimited, LengthDelimitedReader};
10
/// Wire payload of [`Message::NotAvailable`]: the two ASCII bytes `n`, `a`.
const MSG_PROTOCOL_NA: &[u8] = "na".as_bytes();
12
/// A protocol name carried by [`Message::Protocol`].
///
/// The `TryFrom` conversions in this module only produce names that start with `/`.
#[derive(Clone, Debug, PartialEq, Eq)]
pub(crate) struct Protocol(String);

impl AsRef<str> for Protocol {
    /// Borrows the protocol name as a string slice.
    fn as_ref(&self) -> &str {
        &self.0
    }
}
20
21impl TryFrom<Bytes> for Protocol {
22    type Error = ProtocolError;
23
24    fn try_from(value: Bytes) -> Result<Self, Self::Error> {
25        if !value.as_ref().starts_with(b"/") {
26            return Err(ProtocolError::InvalidProtocol);
27        }
28        let protocol_as_string =
29            String::from_utf8(value.to_vec()).map_err(|_| ProtocolError::InvalidProtocol)?;
30
31        Ok(Protocol(protocol_as_string))
32    }
33}
34
35impl TryFrom<&[u8]> for Protocol {
36    type Error = ProtocolError;
37
38    fn try_from(value: &[u8]) -> Result<Self, Self::Error> {
39        Self::try_from(Bytes::copy_from_slice(value))
40    }
41}
42
43impl TryFrom<&str> for Protocol {
44    type Error = ProtocolError;
45
46    fn try_from(value: &str) -> Result<Self, Self::Error> {
47        if !value.starts_with('/') {
48            return Err(ProtocolError::InvalidProtocol);
49        }
50
51        Ok(Protocol(value.to_owned()))
52    }
53}
54
55#[derive(Debug, thiserror::Error)]
56pub enum ProtocolError {
57    #[error("I/O error: {0}")]
58    IoError(#[from] io::Error),
59    #[error("Received an invalid message.")]
60    InvalidMessage,
61    #[error("A protocol (name) is invalid.")]
62    InvalidProtocol,
63}
64
65impl From<ProtocolError> for io::Error {
66    fn from(err: ProtocolError) -> Self {
67        match err {
68            ProtocolError::IoError(e) => e,
69            ProtocolError::InvalidMessage => io::Error::new(io::ErrorKind::InvalidData, err),
70            ProtocolError::InvalidProtocol => io::Error::new(io::ErrorKind::InvalidInput, err),
71        }
72    }
73}
74
75#[derive(Debug, Clone, PartialEq, Eq)]
76pub(crate) enum Message {
77    Protocol(Protocol),
78    NotAvailable,
79}
80
81impl Message {
82    fn encode(&self, dst: &mut BytesMut) {
83        match self {
84            Message::NotAvailable => {
85                dst.reserve(MSG_PROTOCOL_NA.len());
86                dst.put(MSG_PROTOCOL_NA);
87            }
88            Message::Protocol(protocol) => {
89                dst.reserve(protocol.as_ref().len());
90                dst.put(protocol.0.as_ref());
91            }
92        }
93    }
94
95    fn decode(mut src: Bytes) -> Result<Self, ProtocolError> {
96        if src == MSG_PROTOCOL_NA {
97            return Ok(Message::NotAvailable);
98        }
99        if src.first() == Some(&b'/') {
100            let protocol = Protocol::try_from(src.split_to(src.len()))?;
101            return Ok(Message::Protocol(protocol));
102        }
103        Err(ProtocolError::InvalidMessage)
104    }
105}
106
107#[pin_project::pin_project]
108pub(crate) struct MessageIO<R> {
109    #[pin]
110    inner: LengthDelimited<R>,
111}
112
113impl<R> MessageIO<R> {
114    pub(crate) fn new(inner: R) -> MessageIO<R>
115    where
116        R: AsyncRead + AsyncWrite,
117    {
118        Self {
119            inner: LengthDelimited::new(inner),
120        }
121    }
122
123    pub(crate) fn into_reader(self) -> MessageReader<R> {
124        MessageReader {
125            inner: self.inner.into_reader(),
126        }
127    }
128
129    pub(crate) fn into_inner(self) -> R {
130        self.inner.into_inner()
131    }
132}
133
134impl<R> Sink<Message> for MessageIO<R>
135where
136    R: AsyncWrite,
137{
138    type Error = ProtocolError;
139
140    fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
141        self.project().inner.poll_ready(cx).map_err(From::from)
142    }
143
144    fn start_send(self: Pin<&mut Self>, item: Message) -> Result<(), Self::Error> {
145        let mut buf = BytesMut::new();
146        item.encode(&mut buf);
147        self.project()
148            .inner
149            .start_send(buf.freeze())
150            .map_err(From::from)
151    }
152
153    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
154        self.project().inner.poll_flush(cx).map_err(From::from)
155    }
156
157    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
158        self.project().inner.poll_close(cx).map_err(From::from)
159    }
160}
161
162impl<R> Stream for MessageIO<R>
163where
164    R: AsyncRead,
165{
166    type Item = Result<Message, ProtocolError>;
167
168    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
169        poll_stream(self.project().inner, cx)
170    }
171}
172
173#[pin_project::pin_project]
174#[derive(Debug)]
175pub(crate) struct MessageReader<R> {
176    #[pin]
177    inner: LengthDelimitedReader<R>,
178}
179
180impl<R> MessageReader<R> {
181    pub(crate) fn into_inner(self) -> R {
182        self.inner.into_inner()
183    }
184}
185
186impl<R> Stream for MessageReader<R>
187where
188    R: AsyncRead,
189{
190    type Item = Result<Message, ProtocolError>;
191
192    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
193        poll_stream(self.project().inner, cx)
194    }
195}
196
197impl<R> AsyncWrite for MessageReader<R>
198where
199    R: AsyncWrite,
200{
201    fn poll_write(
202        self: Pin<&mut Self>,
203        cx: &mut Context<'_>,
204        buf: &[u8],
205    ) -> Poll<Result<usize, io::Error>> {
206        self.project().inner.poll_write(cx, buf)
207    }
208
209    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
210        self.project().inner.poll_flush(cx)
211    }
212
213    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
214        self.project().inner.poll_close(cx)
215    }
216
217    fn poll_write_vectored(
218        self: Pin<&mut Self>,
219        cx: &mut Context<'_>,
220        bufs: &[io::IoSlice<'_>],
221    ) -> Poll<Result<usize, io::Error>> {
222        self.project().inner.poll_write_vectored(cx, bufs)
223    }
224}
225
226fn poll_stream<S>(
227    stream: Pin<&mut S>,
228    cx: &mut Context<'_>,
229) -> Poll<Option<Result<Message, ProtocolError>>>
230where
231    S: Stream<Item = Result<Bytes, io::Error>>,
232{
233    let msg = if let Some(msg) = ready!(stream.poll_next(cx)?) {
234        match Message::decode(msg) {
235            Ok(m) => m,
236            Err(err) => return Poll::Ready(Some(Err(err))),
237        }
238    } else {
239        return Poll::Ready(None);
240    };
241
242    Poll::Ready(Some(Ok(msg)))
243}