1use crate::{
2 ProtocolError,
3 protocol::{Message, MessageReader, Protocol},
4};
5use futures::{AsyncRead, AsyncWrite, Stream, ready};
6use pin_project::pin_project;
7use std::{
8 io, mem,
9 pin::Pin,
10 task::{Context, Poll},
11};
12
/// An I/O stream that may still be waiting for the remote to confirm the
/// protocol it is expected to speak.
///
/// While confirmation is outstanding, reads first drive negotiation (reading and
/// checking the remote's reply) before any application data is returned. Writes,
/// flushes and closes are forwarded to whichever I/O object the current state
/// holds, without waiting for confirmation.
#[pin_project]
#[derive(Debug)]
pub struct Negotiated<R> {
    /// Current negotiation state; pinned so the wrapped I/O is polled in place.
    #[pin]
    state: State<R>,
}
19
20impl<R> Negotiated<R> {
21 pub(crate) fn completed(io: R) -> Self {
22 Negotiated {
23 state: State::Completed { io },
24 }
25 }
26
27 pub(crate) fn expecting(io: MessageReader<R>, protocol: Protocol) -> Self {
28 Negotiated {
29 state: State::Expecting { io, protocol },
30 }
31 }
32
33 pub fn complete(self) -> NegotiatedComplete<R> {
34 NegotiatedComplete { inner: Some(self) }
35 }
36}
37
/// Internal negotiation state of a [`Negotiated`] stream.
#[pin_project(project = StateProj)]
#[derive(Debug)]
enum State<R> {
    /// Still waiting for the remote to confirm `protocol`.
    Expecting {
        /// Reader that yields negotiation messages; unwrapped with `into_inner`
        /// once the confirmation arrives.
        #[pin]
        io: MessageReader<R>,
        /// The protocol the remote's next `Message::Protocol` must match.
        protocol: Protocol,
    },
    /// Negotiation finished; `io` is the raw underlying stream.
    Completed {
        #[pin]
        io: R,
    },

    /// Placeholder holding no I/O. `poll_negotiated` swaps it in (via
    /// `mem::replace`) while it owns the `Expecting` contents. Only the `Pending`
    /// and success paths write a real state back, so after a failed negotiation
    /// the state stays `Invalid` and every I/O method on the stream panics.
    Invalid,
}
53
54impl<R> Negotiated<R> {
55 fn poll_negotiated(
56 mut self: Pin<&mut Self>,
57 cx: &mut Context<'_>,
58 ) -> Poll<Result<(), NegotiationError>>
59 where
60 R: AsyncRead + AsyncWrite + Unpin,
61 {
62 match self.as_mut().poll_flush(cx) {
63 Poll::Ready(Ok(())) => {}
64 Poll::Pending => return Poll::Pending,
65 Poll::Ready(Err(e)) => {
66 if e.kind() != io::ErrorKind::WriteZero {
67 return Poll::Ready(Err(e.into()));
68 }
69 }
70 }
71 let mut this = self.project();
72 if let StateProj::Completed { .. } = this.state.as_mut().project() {
73 return Poll::Ready(Ok(()));
74 }
75 loop {
76 match mem::replace(&mut *this.state, State::Invalid) {
77 State::Expecting { mut io, protocol } => {
78 let msg = match Pin::new(&mut io).poll_next(cx)? {
79 Poll::Ready(Some(msg)) => msg,
80 Poll::Ready(None) => {
81 return Poll::Ready(Err(io::Error::new(
82 io::ErrorKind::UnexpectedEof,
83 "unexpected end of stream",
84 )
85 .into()));
86 }
87 Poll::Pending => {
88 *this.state = State::Expecting { io, protocol };
89 return Poll::Pending;
90 }
91 };
92 tracing::trace!("Received message: {:?}", msg);
93 if let Message::Protocol(p) = &msg {
94 if p.as_ref() == protocol.as_ref() {
95 tracing::trace!("Negotiated protocol completed: {}", p.as_ref());
96 *this.state = State::Completed {
97 io: io.into_inner(),
98 };
99 return Poll::Ready(Ok(()));
100 }
101 }
102 return Poll::Ready(Err(NegotiationError::Failed));
103 }
104 _ => panic!("Negotiated state should not be in Invalid state"),
105 }
106 }
107 }
108}
109
/// Error that ends protocol negotiation on a [`Negotiated`] stream.
#[derive(Debug, thiserror::Error)]
pub enum NegotiationError {
    /// Error reported by the message layer. Plain `io::Error`s are also converted
    /// into this variant, by way of `ProtocolError::from`.
    #[error("Invalid Protocol, {0}")]
    ProtocolError(#[from] ProtocolError),
    /// The remote's message was not a confirmation of the expected protocol.
    #[error("Protocol negotiation failed.")]
    Failed,
}
117
118impl From<io::Error> for NegotiationError {
119 fn from(err: io::Error) -> NegotiationError {
120 ProtocolError::from(err).into()
121 }
122}
123
124impl From<NegotiationError> for io::Error {
125 fn from(err: NegotiationError) -> io::Error {
126 if let NegotiationError::ProtocolError(e) = err {
127 return e.into();
128 }
129 io::Error::other(err)
130 }
131}
132
133impl<R> AsyncRead for Negotiated<R>
134where
135 R: AsyncRead + AsyncWrite + Unpin,
136{
137 fn poll_read(
138 mut self: Pin<&mut Self>,
139 cx: &mut Context<'_>,
140 buf: &mut [u8],
141 ) -> Poll<io::Result<usize>> {
142 loop {
143 if let StateProj::Completed { io } = self.as_mut().project().state.project() {
144 return io.poll_read(cx, buf);
145 }
146 match self.as_mut().poll_negotiated(cx) {
147 Poll::Ready(Ok(())) => {}
148 Poll::Pending => return Poll::Pending,
149 Poll::Ready(Err(e)) => return Poll::Ready(Err(e.into())),
150 }
151 }
152 }
153
154 fn poll_read_vectored(
155 mut self: Pin<&mut Self>,
156 cx: &mut Context<'_>,
157 bufs: &mut [io::IoSliceMut<'_>],
158 ) -> Poll<io::Result<usize>> {
159 loop {
160 if let StateProj::Completed { io } = self.as_mut().project().state.project() {
161 return io.poll_read_vectored(cx, bufs);
162 }
163 match self.as_mut().poll_negotiated(cx) {
165 Poll::Ready(Ok(())) => {}
166 Poll::Pending => return Poll::Pending,
167 Poll::Ready(Err(e)) => return Poll::Ready(Err(e.into())),
168 }
169 }
170 }
171}
172
173impl<R> AsyncWrite for Negotiated<R>
174where
175 R: AsyncWrite + Unpin,
176{
177 fn poll_write(
178 self: Pin<&mut Self>,
179 cx: &mut Context<'_>,
180 buf: &[u8],
181 ) -> Poll<io::Result<usize>> {
182 match self.project().state.project() {
183 StateProj::Completed { io } => io.poll_write(cx, buf),
184 StateProj::Expecting { io, .. } => io.poll_write(cx, buf),
185 StateProj::Invalid => panic!("Negotiated state should not be in Invalid state"),
186 }
187 }
188
189 fn poll_write_vectored(
190 self: Pin<&mut Self>,
191 cx: &mut Context<'_>,
192 bufs: &[io::IoSlice<'_>],
193 ) -> Poll<io::Result<usize>> {
194 match self.project().state.project() {
195 StateProj::Completed { io } => io.poll_write_vectored(cx, bufs),
196 StateProj::Expecting { io, .. } => io.poll_write_vectored(cx, bufs),
197 StateProj::Invalid => panic!("Negotiated state should not be in Invalid state"),
198 }
199 }
200
201 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
202 match self.project().state.project() {
203 StateProj::Completed { io } => io.poll_flush(cx),
204 StateProj::Expecting { io, .. } => io.poll_flush(cx),
205 StateProj::Invalid => panic!("Negotiated state should not be in Invalid state"),
206 }
207 }
208
209 fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
210 ready!(self.as_mut().poll_flush(cx))?;
211 match self.project().state.project() {
212 StateProj::Completed { io } => io.poll_close(cx),
213 StateProj::Expecting { io, .. } => io.poll_close(cx),
214 StateProj::Invalid => panic!("Negotiated state should not be in Invalid state"),
215 }
216 }
217}
218
219#[derive(Debug)]
220pub struct NegotiatedComplete<R> {
221 inner: Option<Negotiated<R>>,
222}
223
224impl<R> Future for NegotiatedComplete<R>
225where
226 R: AsyncRead + AsyncWrite + Unpin,
227{
228 type Output = Result<Negotiated<R>, NegotiationError>;
229
230 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
231 let mut io = self
232 .inner
233 .take()
234 .expect("NegotiatedFuture called after completion.");
235 match Pin::new(&mut io).poll_negotiated(cx) {
236 Poll::Ready(Ok(())) => Poll::Ready(Ok(io)),
237 Poll::Pending => {
238 self.get_mut().inner = Some(io);
239 Poll::Pending
240 }
241 Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
242 }
243 }
244}