Skip to main content

volans_stream_select/
dialer_select.rs

1use futures::{AsyncRead, AsyncWrite, Sink, Stream};
2
3use crate::{
4    Negotiated, NegotiationError,
5    protocol::{Message, MessageIO, Protocol},
6};
7use std::{
8    iter, mem,
9    pin::Pin,
10    task::{Context, Poll},
11};
12
13#[pin_project::pin_project]
14pub struct DialerSelectFuture<R, I: Iterator> {
15    protocols: iter::Peekable<I>,
16    state: State<R, I::Item>,
17    lazy: bool,
18}
19
20impl<R, I> DialerSelectFuture<R, I>
21where
22    R: AsyncRead + AsyncWrite,
23    I: Iterator,
24    I::Item: AsRef<str>,
25{
26    pub fn new(io: R, protocols: I) -> Self {
27        DialerSelectFuture {
28            protocols: protocols.peekable(),
29            state: State::Initial {
30                io: MessageIO::new(io),
31            },
32            lazy: false,
33        }
34    }
35}
36
37enum State<R, P> {
38    Initial { io: MessageIO<R> },
39    SendProtocol { io: MessageIO<R>, protocol: P },
40    FlushProtocol { io: MessageIO<R>, protocol: P },
41    AwaitProtocol { io: MessageIO<R>, protocol: P },
42    Done,
43}
44
45impl<R, I> Future for DialerSelectFuture<R, I>
46where
47    R: AsyncRead + AsyncWrite + Unpin,
48    I: Iterator,
49    I::Item: AsRef<str>,
50{
51    type Output = Result<(I::Item, Negotiated<R>), NegotiationError>;
52
53    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
54        let this = self.project();
55        loop {
56            match mem::replace(this.state, State::Done) {
57                State::Initial { mut io } => {
58                    match Pin::new(&mut io).poll_ready(cx)? {
59                        Poll::Ready(()) => {}
60                        Poll::Pending => {
61                            *this.state = State::Initial { io };
62                            return Poll::Pending;
63                        }
64                    };
65                    let protocol = this.protocols.next().ok_or(NegotiationError::Failed)?;
66                    *this.state = State::SendProtocol { io, protocol };
67                }
68                State::SendProtocol { mut io, protocol } => {
69                    tracing::trace!("Sending protocol: {}", protocol.as_ref());
70                    match Pin::new(&mut io).poll_ready(cx)? {
71                        Poll::Ready(()) => {}
72                        Poll::Pending => {
73                            *this.state = State::SendProtocol { io, protocol };
74                            return Poll::Pending;
75                        }
76                    };
77                    let p = Protocol::try_from(protocol.as_ref())?;
78                    // 发送协议到 IO
79                    if let Err(err) = Pin::new(&mut io).start_send(Message::Protocol(p.clone())) {
80                        return Poll::Ready(Err(From::from(err)));
81                    }
82                    if this.protocols.peek().is_some() {
83                        // 如果还有更多协议,进入发送协议状态
84                        *this.state = State::FlushProtocol { io, protocol };
85                    } else if *this.lazy {
86                        // 如果没有更多协议,直接进入等待状态
87                        tracing::trace!("Expecting protocol: {}", p.as_ref());
88                        let io = Negotiated::expecting(io.into_reader(), p);
89                        return Poll::Ready(Ok((protocol, io)));
90                    } else {
91                        // 如果没有更多协议,进入等待状态
92                        *this.state = State::FlushProtocol { io, protocol };
93                    }
94                }
95                State::FlushProtocol { mut io, protocol } => {
96                    match Pin::new(&mut io).poll_flush(cx)? {
97                        Poll::Ready(()) => {}
98                        Poll::Pending => {
99                            *this.state = State::FlushProtocol { io, protocol };
100                            return Poll::Pending;
101                        }
102                    };
103                    // 进入等待状态
104                    *this.state = State::AwaitProtocol { io, protocol };
105                }
106                State::AwaitProtocol { mut io, protocol } => {
107                    let msg = match Pin::new(&mut io).poll_next(cx)? {
108                        Poll::Ready(Some(msg)) => msg,
109                        Poll::Ready(None) => {
110                            tracing::debug!("No message received, connection closed");
111                            return Poll::Ready(Err(NegotiationError::Failed));
112                        }
113                        Poll::Pending => {
114                            *this.state = State::AwaitProtocol { io, protocol };
115                            return Poll::Pending;
116                        }
117                    };
118                    match msg {
119                        Message::Protocol(p) if p.as_ref() == protocol.as_ref() => {
120                            // 协议匹配成功,返回 Negotiated
121                            let io = Negotiated::completed(io.into_inner());
122                            return Poll::Ready(Ok((protocol, io)));
123                        }
124                        Message::NotAvailable => {
125                            // 不支持的协议,继续协商下一个协议
126                            tracing::debug!("Protocol not available, trying next protocol");
127                            let protocol = this.protocols.next().ok_or(NegotiationError::Failed)?;
128                            *this.state = State::SendProtocol { io, protocol }
129                        }
130                        _ => {
131                            // 协议不匹配,继续等待下一个协议
132                            *this.state = State::Initial { io };
133                        }
134                    }
135                }
136                _ => panic!("Unexpected state in DialerSelectFuture"),
137            }
138        }
139    }
140}