Skip to main content

volans_stream_select/
negotiated.rs

1use crate::{
2    ProtocolError,
3    protocol::{Message, MessageReader, Protocol},
4};
5use futures::{AsyncRead, AsyncWrite, Stream, ready};
6use pin_project::pin_project;
7use std::{
8    io, mem,
9    pin::Pin,
10    task::{Context, Poll},
11};
12
/// An I/O stream whose protocol negotiation may still be waiting for the
/// remote to send back the expected protocol.
///
/// Reads first drive the negotiation to completion; writes are forwarded to
/// the underlying I/O in both the expecting and completed states.
#[pin_project]
#[derive(Debug)]
pub struct Negotiated<R> {
    /// Current negotiation state; pinned because it owns the pinned I/O.
    #[pin]
    state: State<R>,
}
19
20impl<R> Negotiated<R> {
21    pub(crate) fn completed(io: R) -> Self {
22        Negotiated {
23            state: State::Completed { io },
24        }
25    }
26
27    pub(crate) fn expecting(io: MessageReader<R>, protocol: Protocol) -> Self {
28        Negotiated {
29            state: State::Expecting { io, protocol },
30        }
31    }
32
33    pub fn complete(self) -> NegotiatedComplete<R> {
34        NegotiatedComplete { inner: Some(self) }
35    }
36}
37
/// Internal negotiation state of a [`Negotiated`] stream.
#[pin_project(project = StateProj)]
#[derive(Debug)]
enum State<R> {
    /// Waiting to read a `Message::Protocol` equal to `protocol` from the
    /// remote; messages are read through the `MessageReader` wrapper.
    Expecting {
        #[pin]
        io: MessageReader<R>,
        protocol: Protocol,
    },
    /// Negotiation finished; `io` is the unwrapped underlying stream.
    Completed {
        #[pin]
        io: R,
    },

    /// Placeholder swapped in via `mem::replace` while `poll_negotiated`
    /// takes ownership of the `Expecting` state. It is left in place when
    /// negotiation fails, so the stream cannot be used afterwards.
    Invalid,
}
53
54impl<R> Negotiated<R> {
55    fn poll_negotiated(
56        mut self: Pin<&mut Self>,
57        cx: &mut Context<'_>,
58    ) -> Poll<Result<(), NegotiationError>>
59    where
60        R: AsyncRead + AsyncWrite + Unpin,
61    {
62        match self.as_mut().poll_flush(cx) {
63            Poll::Ready(Ok(())) => {}
64            Poll::Pending => return Poll::Pending,
65            Poll::Ready(Err(e)) => {
66                if e.kind() != io::ErrorKind::WriteZero {
67                    return Poll::Ready(Err(e.into()));
68                }
69            }
70        }
71        let mut this = self.project();
72        if let StateProj::Completed { .. } = this.state.as_mut().project() {
73            return Poll::Ready(Ok(()));
74        }
75        loop {
76            match mem::replace(&mut *this.state, State::Invalid) {
77                State::Expecting { mut io, protocol } => {
78                    let msg = match Pin::new(&mut io).poll_next(cx)? {
79                        Poll::Ready(Some(msg)) => msg,
80                        Poll::Ready(None) => {
81                            return Poll::Ready(Err(io::Error::new(
82                                io::ErrorKind::UnexpectedEof,
83                                "unexpected end of stream",
84                            )
85                            .into()));
86                        }
87                        Poll::Pending => {
88                            *this.state = State::Expecting { io, protocol };
89                            return Poll::Pending;
90                        }
91                    };
92                    tracing::trace!("Received message: {:?}", msg);
93                    if let Message::Protocol(p) = &msg {
94                        if p.as_ref() == protocol.as_ref() {
95                            tracing::trace!("Negotiated protocol completed: {}", p.as_ref());
96                            *this.state = State::Completed {
97                                io: io.into_inner(),
98                            };
99                            return Poll::Ready(Ok(()));
100                        }
101                    }
102                    return Poll::Ready(Err(NegotiationError::Failed));
103                }
104                _ => panic!("Negotiated state should not be in Invalid state"),
105            }
106        }
107    }
108}
109
/// Error returned when protocol negotiation on a [`Negotiated`] stream fails.
#[derive(Debug, thiserror::Error)]
pub enum NegotiationError {
    /// Reading negotiation messages failed; plain I/O errors are also wrapped
    /// into this variant (via `ProtocolError`).
    #[error("Invalid Protocol, {0}")]
    ProtocolError(#[from] ProtocolError),
    /// The remote sent something other than the expected protocol.
    #[error("Protocol negotiation failed.")]
    Failed,
}
117
118impl From<io::Error> for NegotiationError {
119    fn from(err: io::Error) -> NegotiationError {
120        ProtocolError::from(err).into()
121    }
122}
123
124impl From<NegotiationError> for io::Error {
125    fn from(err: NegotiationError) -> io::Error {
126        if let NegotiationError::ProtocolError(e) = err {
127            return e.into();
128        }
129        io::Error::other(err)
130    }
131}
132
133impl<R> AsyncRead for Negotiated<R>
134where
135    R: AsyncRead + AsyncWrite + Unpin,
136{
137    fn poll_read(
138        mut self: Pin<&mut Self>,
139        cx: &mut Context<'_>,
140        buf: &mut [u8],
141    ) -> Poll<io::Result<usize>> {
142        loop {
143            if let StateProj::Completed { io } = self.as_mut().project().state.project() {
144                return io.poll_read(cx, buf);
145            }
146            match self.as_mut().poll_negotiated(cx) {
147                Poll::Ready(Ok(())) => {}
148                Poll::Pending => return Poll::Pending,
149                Poll::Ready(Err(e)) => return Poll::Ready(Err(e.into())),
150            }
151        }
152    }
153
154    fn poll_read_vectored(
155        mut self: Pin<&mut Self>,
156        cx: &mut Context<'_>,
157        bufs: &mut [io::IoSliceMut<'_>],
158    ) -> Poll<io::Result<usize>> {
159        loop {
160            if let StateProj::Completed { io } = self.as_mut().project().state.project() {
161                return io.poll_read_vectored(cx, bufs);
162            }
163            //
164            match self.as_mut().poll_negotiated(cx) {
165                Poll::Ready(Ok(())) => {}
166                Poll::Pending => return Poll::Pending,
167                Poll::Ready(Err(e)) => return Poll::Ready(Err(e.into())),
168            }
169        }
170    }
171}
172
173impl<R> AsyncWrite for Negotiated<R>
174where
175    R: AsyncWrite + Unpin,
176{
177    fn poll_write(
178        self: Pin<&mut Self>,
179        cx: &mut Context<'_>,
180        buf: &[u8],
181    ) -> Poll<io::Result<usize>> {
182        match self.project().state.project() {
183            StateProj::Completed { io } => io.poll_write(cx, buf),
184            StateProj::Expecting { io, .. } => io.poll_write(cx, buf),
185            StateProj::Invalid => panic!("Negotiated state should not be in Invalid state"),
186        }
187    }
188
189    fn poll_write_vectored(
190        self: Pin<&mut Self>,
191        cx: &mut Context<'_>,
192        bufs: &[io::IoSlice<'_>],
193    ) -> Poll<io::Result<usize>> {
194        match self.project().state.project() {
195            StateProj::Completed { io } => io.poll_write_vectored(cx, bufs),
196            StateProj::Expecting { io, .. } => io.poll_write_vectored(cx, bufs),
197            StateProj::Invalid => panic!("Negotiated state should not be in Invalid state"),
198        }
199    }
200
201    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
202        match self.project().state.project() {
203            StateProj::Completed { io } => io.poll_flush(cx),
204            StateProj::Expecting { io, .. } => io.poll_flush(cx),
205            StateProj::Invalid => panic!("Negotiated state should not be in Invalid state"),
206        }
207    }
208
209    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
210        ready!(self.as_mut().poll_flush(cx))?;
211        match self.project().state.project() {
212            StateProj::Completed { io } => io.poll_close(cx),
213            StateProj::Expecting { io, .. } => io.poll_close(cx),
214            StateProj::Invalid => panic!("Negotiated state should not be in Invalid state"),
215        }
216    }
217}
218
219#[derive(Debug)]
220pub struct NegotiatedComplete<R> {
221    inner: Option<Negotiated<R>>,
222}
223
224impl<R> Future for NegotiatedComplete<R>
225where
226    R: AsyncRead + AsyncWrite + Unpin,
227{
228    type Output = Result<Negotiated<R>, NegotiationError>;
229
230    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
231        let mut io = self
232            .inner
233            .take()
234            .expect("NegotiatedFuture called after completion.");
235        match Pin::new(&mut io).poll_negotiated(cx) {
236            Poll::Ready(Ok(())) => Poll::Ready(Ok(io)),
237            Poll::Pending => {
238                self.get_mut().inner = Some(io);
239                Poll::Pending
240            }
241            Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
242        }
243    }
244}