volans_stream_select/
protocol.rs1use bytes::{BufMut, Bytes, BytesMut};
2use futures::{AsyncRead, AsyncWrite, Sink, Stream, ready};
3use std::{
4 io,
5 pin::Pin,
6 task::{Context, Poll},
7};
8
9use crate::length_delimited::{LengthDelimited, LengthDelimitedReader};
10
/// Frame payload meaning "protocol not available".
///
/// `Message::encode` writes it for `Message::NotAvailable`, and `Message::decode`
/// maps an exactly matching frame back to that variant.
const MSG_PROTOCOL_NA: &[u8] = b"na";
12
/// A protocol name.
///
/// The `TryFrom` conversions in this module only produce values whose name
/// starts with `/`.
#[derive(Clone, Debug, PartialEq, Eq)]
pub(crate) struct Protocol(String);

impl AsRef<str> for Protocol {
    /// Borrows the protocol name as a string slice.
    fn as_ref(&self) -> &str {
        self.0.as_str()
    }
}
20
21impl TryFrom<Bytes> for Protocol {
22 type Error = ProtocolError;
23
24 fn try_from(value: Bytes) -> Result<Self, Self::Error> {
25 if !value.as_ref().starts_with(b"/") {
26 return Err(ProtocolError::InvalidProtocol);
27 }
28 let protocol_as_string =
29 String::from_utf8(value.to_vec()).map_err(|_| ProtocolError::InvalidProtocol)?;
30
31 Ok(Protocol(protocol_as_string))
32 }
33}
34
35impl TryFrom<&[u8]> for Protocol {
36 type Error = ProtocolError;
37
38 fn try_from(value: &[u8]) -> Result<Self, Self::Error> {
39 Self::try_from(Bytes::copy_from_slice(value))
40 }
41}
42
43impl TryFrom<&str> for Protocol {
44 type Error = ProtocolError;
45
46 fn try_from(value: &str) -> Result<Self, Self::Error> {
47 if !value.starts_with('/') {
48 return Err(ProtocolError::InvalidProtocol);
49 }
50
51 Ok(Protocol(value.to_owned()))
52 }
53}
54
55#[derive(Debug, thiserror::Error)]
56pub enum ProtocolError {
57 #[error("I/O error: {0}")]
58 IoError(#[from] io::Error),
59 #[error("Received an invalid message.")]
60 InvalidMessage,
61 #[error("A protocol (name) is invalid.")]
62 InvalidProtocol,
63}
64
65impl From<ProtocolError> for io::Error {
66 fn from(err: ProtocolError) -> Self {
67 match err {
68 ProtocolError::IoError(e) => e,
69 ProtocolError::InvalidMessage => io::Error::new(io::ErrorKind::InvalidData, err),
70 ProtocolError::InvalidProtocol => io::Error::new(io::ErrorKind::InvalidInput, err),
71 }
72 }
73}
74
75#[derive(Debug, Clone, PartialEq, Eq)]
76pub(crate) enum Message {
77 Protocol(Protocol),
78 NotAvailable,
79}
80
81impl Message {
82 fn encode(&self, dst: &mut BytesMut) {
83 match self {
84 Message::NotAvailable => {
85 dst.reserve(MSG_PROTOCOL_NA.len());
86 dst.put(MSG_PROTOCOL_NA);
87 }
88 Message::Protocol(protocol) => {
89 dst.reserve(protocol.as_ref().len());
90 dst.put(protocol.0.as_ref());
91 }
92 }
93 }
94
95 fn decode(mut src: Bytes) -> Result<Self, ProtocolError> {
96 if src == MSG_PROTOCOL_NA {
97 return Ok(Message::NotAvailable);
98 }
99 if src.first() == Some(&b'/') {
100 let protocol = Protocol::try_from(src.split_to(src.len()))?;
101 return Ok(Message::Protocol(protocol));
102 }
103 Err(ProtocolError::InvalidMessage)
104 }
105}
106
107#[pin_project::pin_project]
108pub(crate) struct MessageIO<R> {
109 #[pin]
110 inner: LengthDelimited<R>,
111}
112
113impl<R> MessageIO<R> {
114 pub(crate) fn new(inner: R) -> MessageIO<R>
115 where
116 R: AsyncRead + AsyncWrite,
117 {
118 Self {
119 inner: LengthDelimited::new(inner),
120 }
121 }
122
123 pub(crate) fn into_reader(self) -> MessageReader<R> {
124 MessageReader {
125 inner: self.inner.into_reader(),
126 }
127 }
128
129 pub(crate) fn into_inner(self) -> R {
130 self.inner.into_inner()
131 }
132}
133
134impl<R> Sink<Message> for MessageIO<R>
135where
136 R: AsyncWrite,
137{
138 type Error = ProtocolError;
139
140 fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
141 self.project().inner.poll_ready(cx).map_err(From::from)
142 }
143
144 fn start_send(self: Pin<&mut Self>, item: Message) -> Result<(), Self::Error> {
145 let mut buf = BytesMut::new();
146 item.encode(&mut buf);
147 self.project()
148 .inner
149 .start_send(buf.freeze())
150 .map_err(From::from)
151 }
152
153 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
154 self.project().inner.poll_flush(cx).map_err(From::from)
155 }
156
157 fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
158 self.project().inner.poll_close(cx).map_err(From::from)
159 }
160}
161
162impl<R> Stream for MessageIO<R>
163where
164 R: AsyncRead,
165{
166 type Item = Result<Message, ProtocolError>;
167
168 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
169 poll_stream(self.project().inner, cx)
170 }
171}
172
173#[pin_project::pin_project]
174#[derive(Debug)]
175pub(crate) struct MessageReader<R> {
176 #[pin]
177 inner: LengthDelimitedReader<R>,
178}
179
180impl<R> MessageReader<R> {
181 pub(crate) fn into_inner(self) -> R {
182 self.inner.into_inner()
183 }
184}
185
186impl<R> Stream for MessageReader<R>
187where
188 R: AsyncRead,
189{
190 type Item = Result<Message, ProtocolError>;
191
192 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
193 poll_stream(self.project().inner, cx)
194 }
195}
196
197impl<R> AsyncWrite for MessageReader<R>
198where
199 R: AsyncWrite,
200{
201 fn poll_write(
202 self: Pin<&mut Self>,
203 cx: &mut Context<'_>,
204 buf: &[u8],
205 ) -> Poll<Result<usize, io::Error>> {
206 self.project().inner.poll_write(cx, buf)
207 }
208
209 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
210 self.project().inner.poll_flush(cx)
211 }
212
213 fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
214 self.project().inner.poll_close(cx)
215 }
216
217 fn poll_write_vectored(
218 self: Pin<&mut Self>,
219 cx: &mut Context<'_>,
220 bufs: &[io::IoSlice<'_>],
221 ) -> Poll<Result<usize, io::Error>> {
222 self.project().inner.poll_write_vectored(cx, bufs)
223 }
224}
225
226fn poll_stream<S>(
227 stream: Pin<&mut S>,
228 cx: &mut Context<'_>,
229) -> Poll<Option<Result<Message, ProtocolError>>>
230where
231 S: Stream<Item = Result<Bytes, io::Error>>,
232{
233 let msg = if let Some(msg) = ready!(stream.poll_next(cx)?) {
234 match Message::decode(msg) {
235 Ok(m) => m,
236 Err(err) => return Poll::Ready(Some(Err(err))),
237 }
238 } else {
239 return Poll::Ready(None);
240 };
241
242 Poll::Ready(Some(Ok(msg)))
243}