// volans_core/upgrade/apply.rs

1use futures::{AsyncRead, AsyncWrite, future};
2use std::{
3    mem,
4    pin::Pin,
5    task::{Context, Poll},
6};
7use volans_stream_select::{DialerSelectFuture, ListenerSelectFuture};
8
9use crate::{
10    ConnectedPoint, Negotiated,
11    upgrade::{InboundConnectionUpgrade, OutboundConnectionUpgrade, UpgradeError},
12};
13
14pub fn apply<C, U>(
15    socket: C,
16    upgrade: U,
17    connected_point: ConnectedPoint,
18) -> future::Either<InboundUpgradeApply<C, U>, OutboundUpgradeApply<C, U>>
19where
20    C: AsyncRead + AsyncWrite + Unpin,
21    U: InboundConnectionUpgrade<Negotiated<C>> + OutboundConnectionUpgrade<Negotiated<C>>,
22{
23    match connected_point {
24        ConnectedPoint::Dialer { .. } => {
25            future::Either::Right(OutboundUpgradeApply::new(socket, upgrade))
26        }
27        _ => future::Either::Left(InboundUpgradeApply::new(socket, upgrade)),
28    }
29}
30
31pub struct InboundUpgradeApply<C, U>
32where
33    C: AsyncRead + AsyncWrite + Unpin,
34    U: InboundConnectionUpgrade<Negotiated<C>>,
35{
36    inner: InboundUpgradeApplyState<C, U>,
37}
38
39#[allow(clippy::large_enum_variant)]
40enum InboundUpgradeApplyState<C, U>
41where
42    C: AsyncRead + AsyncWrite + Unpin,
43    U: InboundConnectionUpgrade<Negotiated<C>>,
44{
45    Init {
46        future: ListenerSelectFuture<C, U::Info>,
47        upgrade: U,
48    },
49    Upgrade {
50        future: Pin<Box<U::Future>>,
51        name: String,
52    },
53    Undefined,
54}
55
56impl<C, U> InboundUpgradeApply<C, U>
57where
58    C: AsyncRead + AsyncWrite + Unpin,
59    U: InboundConnectionUpgrade<Negotiated<C>>,
60{
61    pub fn new(socket: C, upgrade: U) -> Self {
62        let future = ListenerSelectFuture::new(socket, upgrade.protocol_info());
63        Self {
64            inner: InboundUpgradeApplyState::Init { future, upgrade },
65        }
66    }
67}
68
69impl<C, U> Unpin for InboundUpgradeApply<C, U>
70where
71    C: AsyncRead + AsyncWrite + Unpin,
72    U: InboundConnectionUpgrade<Negotiated<C>>,
73{
74}
75
76impl<C, U> Future for InboundUpgradeApply<C, U>
77where
78    C: AsyncRead + AsyncWrite + Unpin,
79    U: InboundConnectionUpgrade<Negotiated<C>>,
80{
81    type Output = Result<U::Output, UpgradeError<U::Error>>;
82
83    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
84        loop {
85            match mem::replace(&mut self.inner, InboundUpgradeApplyState::Undefined) {
86                InboundUpgradeApplyState::Init {
87                    mut future,
88                    upgrade,
89                } => {
90                    let (info, io) = match Future::poll(Pin::new(&mut future), cx)? {
91                        Poll::Ready(x) => x,
92                        Poll::Pending => {
93                            self.inner = InboundUpgradeApplyState::Init { future, upgrade };
94                            return Poll::Pending;
95                        }
96                    };
97                    self.inner = InboundUpgradeApplyState::Upgrade {
98                        future: Box::pin(upgrade.upgrade_inbound(io, info.clone())),
99                        name: info.as_ref().to_owned(),
100                    };
101                }
102                InboundUpgradeApplyState::Upgrade { mut future, name } => {
103                    match Future::poll(Pin::new(&mut future), cx) {
104                        Poll::Pending => {
105                            self.inner = InboundUpgradeApplyState::Upgrade { future, name };
106                            return Poll::Pending;
107                        }
108                        Poll::Ready(Ok(x)) => {
109                            tracing::trace!(upgrade=%name, "Upgraded inbound stream");
110                            return Poll::Ready(Ok(x));
111                        }
112                        Poll::Ready(Err(e)) => {
113                            tracing::debug!(upgrade=%name, "Failed to upgrade inbound stream");
114                            return Poll::Ready(Err(UpgradeError::Apply(e)));
115                        }
116                    }
117                }
118                InboundUpgradeApplyState::Undefined => {
119                    panic!("InboundUpgradeApplyState::poll called after completion")
120                }
121            }
122        }
123    }
124}
125
126/// Future returned by `apply_outbound`. Drives the upgrade process.
127pub struct OutboundUpgradeApply<C, U>
128where
129    C: AsyncRead + AsyncWrite + Unpin,
130    U: OutboundConnectionUpgrade<Negotiated<C>>,
131{
132    inner: OutboundUpgradeApplyState<C, U>,
133}
134
135impl<C, U> OutboundUpgradeApply<C, U>
136where
137    C: AsyncRead + AsyncWrite + Unpin,
138    U: OutboundConnectionUpgrade<Negotiated<C>>,
139{
140    pub fn new(socket: C, upgrade: U) -> Self {
141        let future = DialerSelectFuture::new(socket, upgrade.protocol_info());
142        Self {
143            inner: OutboundUpgradeApplyState::Init { future, upgrade },
144        }
145    }
146}
147
148enum OutboundUpgradeApplyState<C, U>
149where
150    C: AsyncRead + AsyncWrite + Unpin,
151    U: OutboundConnectionUpgrade<Negotiated<C>>,
152{
153    Init {
154        future: DialerSelectFuture<C, <U::InfoIter as IntoIterator>::IntoIter>,
155        upgrade: U,
156    },
157    Upgrade {
158        future: Pin<Box<U::Future>>,
159        name: String,
160    },
161    Undefined,
162}
163
164impl<C, U> Unpin for OutboundUpgradeApply<C, U>
165where
166    C: AsyncRead + AsyncWrite + Unpin,
167    U: OutboundConnectionUpgrade<Negotiated<C>>,
168{
169}
170
171impl<C, U> Future for OutboundUpgradeApply<C, U>
172where
173    C: AsyncRead + AsyncWrite + Unpin,
174    U: OutboundConnectionUpgrade<Negotiated<C>>,
175{
176    type Output = Result<U::Output, UpgradeError<U::Error>>;
177
178    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
179        loop {
180            match mem::replace(&mut self.inner, OutboundUpgradeApplyState::Undefined) {
181                OutboundUpgradeApplyState::Init {
182                    mut future,
183                    upgrade,
184                } => {
185                    let (info, connection) = match Future::poll(Pin::new(&mut future), cx)? {
186                        Poll::Ready(x) => x,
187                        Poll::Pending => {
188                            self.inner = OutboundUpgradeApplyState::Init { future, upgrade };
189                            return Poll::Pending;
190                        }
191                    };
192                    self.inner = OutboundUpgradeApplyState::Upgrade {
193                        future: Box::pin(upgrade.upgrade_outbound(connection, info.clone())),
194                        name: info.as_ref().to_owned(),
195                    };
196                }
197                OutboundUpgradeApplyState::Upgrade { mut future, name } => {
198                    match Future::poll(Pin::new(&mut future), cx) {
199                        Poll::Pending => {
200                            self.inner = OutboundUpgradeApplyState::Upgrade { future, name };
201                            return Poll::Pending;
202                        }
203                        Poll::Ready(Ok(x)) => {
204                            tracing::trace!(upgrade=%name, "Upgraded outbound stream");
205                            return Poll::Ready(Ok(x));
206                        }
207                        Poll::Ready(Err(e)) => {
208                            tracing::debug!(upgrade=%name, "Failed to upgrade outbound stream",);
209                            return Poll::Ready(Err(UpgradeError::Apply(e)));
210                        }
211                    }
212                }
213                OutboundUpgradeApplyState::Undefined => {
214                    panic!("OutboundUpgradeApplyState::poll called after completion")
215                }
216            }
217        }
218    }
219}