volans_stream_select/
dialer_select.rs1use futures::{AsyncRead, AsyncWrite, Sink, Stream};
2
3use crate::{
4 Negotiated, NegotiationError,
5 protocol::{Message, MessageIO, Protocol},
6};
7use std::{
8 iter, mem,
9 pin::Pin,
10 task::{Context, Poll},
11};
12
13#[pin_project::pin_project]
14pub struct DialerSelectFuture<R, I: Iterator> {
15 protocols: iter::Peekable<I>,
16 state: State<R, I::Item>,
17 lazy: bool,
18}
19
20impl<R, I> DialerSelectFuture<R, I>
21where
22 R: AsyncRead + AsyncWrite,
23 I: Iterator,
24 I::Item: AsRef<str>,
25{
26 pub fn new(io: R, protocols: I) -> Self {
27 DialerSelectFuture {
28 protocols: protocols.peekable(),
29 state: State::Initial {
30 io: MessageIO::new(io),
31 },
32 lazy: false,
33 }
34 }
35}
36
37enum State<R, P> {
38 Initial { io: MessageIO<R> },
39 SendProtocol { io: MessageIO<R>, protocol: P },
40 FlushProtocol { io: MessageIO<R>, protocol: P },
41 AwaitProtocol { io: MessageIO<R>, protocol: P },
42 Done,
43}
44
45impl<R, I> Future for DialerSelectFuture<R, I>
46where
47 R: AsyncRead + AsyncWrite + Unpin,
48 I: Iterator,
49 I::Item: AsRef<str>,
50{
51 type Output = Result<(I::Item, Negotiated<R>), NegotiationError>;
52
53 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
54 let this = self.project();
55 loop {
56 match mem::replace(this.state, State::Done) {
57 State::Initial { mut io } => {
58 match Pin::new(&mut io).poll_ready(cx)? {
59 Poll::Ready(()) => {}
60 Poll::Pending => {
61 *this.state = State::Initial { io };
62 return Poll::Pending;
63 }
64 };
65 let protocol = this.protocols.next().ok_or(NegotiationError::Failed)?;
66 *this.state = State::SendProtocol { io, protocol };
67 }
68 State::SendProtocol { mut io, protocol } => {
69 tracing::trace!("Sending protocol: {}", protocol.as_ref());
70 match Pin::new(&mut io).poll_ready(cx)? {
71 Poll::Ready(()) => {}
72 Poll::Pending => {
73 *this.state = State::SendProtocol { io, protocol };
74 return Poll::Pending;
75 }
76 };
77 let p = Protocol::try_from(protocol.as_ref())?;
78 if let Err(err) = Pin::new(&mut io).start_send(Message::Protocol(p.clone())) {
80 return Poll::Ready(Err(From::from(err)));
81 }
82 if this.protocols.peek().is_some() {
83 *this.state = State::FlushProtocol { io, protocol };
85 } else if *this.lazy {
86 tracing::trace!("Expecting protocol: {}", p.as_ref());
88 let io = Negotiated::expecting(io.into_reader(), p);
89 return Poll::Ready(Ok((protocol, io)));
90 } else {
91 *this.state = State::FlushProtocol { io, protocol };
93 }
94 }
95 State::FlushProtocol { mut io, protocol } => {
96 match Pin::new(&mut io).poll_flush(cx)? {
97 Poll::Ready(()) => {}
98 Poll::Pending => {
99 *this.state = State::FlushProtocol { io, protocol };
100 return Poll::Pending;
101 }
102 };
103 *this.state = State::AwaitProtocol { io, protocol };
105 }
106 State::AwaitProtocol { mut io, protocol } => {
107 let msg = match Pin::new(&mut io).poll_next(cx)? {
108 Poll::Ready(Some(msg)) => msg,
109 Poll::Ready(None) => {
110 tracing::debug!("No message received, connection closed");
111 return Poll::Ready(Err(NegotiationError::Failed));
112 }
113 Poll::Pending => {
114 *this.state = State::AwaitProtocol { io, protocol };
115 return Poll::Pending;
116 }
117 };
118 match msg {
119 Message::Protocol(p) if p.as_ref() == protocol.as_ref() => {
120 let io = Negotiated::completed(io.into_inner());
122 return Poll::Ready(Ok((protocol, io)));
123 }
124 Message::NotAvailable => {
125 tracing::debug!("Protocol not available, trying next protocol");
127 let protocol = this.protocols.next().ok_or(NegotiationError::Failed)?;
128 *this.state = State::SendProtocol { io, protocol }
129 }
130 _ => {
131 *this.state = State::Initial { io };
133 }
134 }
135 }
136 _ => panic!("Unexpected state in DialerSelectFuture"),
137 }
138 }
139 }
140}