1use futures::{AsyncRead, AsyncWrite, future};
2use std::{
3 mem,
4 pin::Pin,
5 task::{Context, Poll},
6};
7use volans_stream_select::{DialerSelectFuture, ListenerSelectFuture};
8
9use crate::{
10 ConnectedPoint, Negotiated,
11 upgrade::{InboundConnectionUpgrade, OutboundConnectionUpgrade, UpgradeError},
12};
13
14pub fn apply<C, U>(
15 socket: C,
16 upgrade: U,
17 connected_point: ConnectedPoint,
18) -> future::Either<InboundUpgradeApply<C, U>, OutboundUpgradeApply<C, U>>
19where
20 C: AsyncRead + AsyncWrite + Unpin,
21 U: InboundConnectionUpgrade<Negotiated<C>> + OutboundConnectionUpgrade<Negotiated<C>>,
22{
23 match connected_point {
24 ConnectedPoint::Dialer { .. } => {
25 future::Either::Right(OutboundUpgradeApply::new(socket, upgrade))
26 }
27 _ => future::Either::Left(InboundUpgradeApply::new(socket, upgrade)),
28 }
29}
30
31pub struct InboundUpgradeApply<C, U>
32where
33 C: AsyncRead + AsyncWrite + Unpin,
34 U: InboundConnectionUpgrade<Negotiated<C>>,
35{
36 inner: InboundUpgradeApplyState<C, U>,
37}
38
39#[allow(clippy::large_enum_variant)]
40enum InboundUpgradeApplyState<C, U>
41where
42 C: AsyncRead + AsyncWrite + Unpin,
43 U: InboundConnectionUpgrade<Negotiated<C>>,
44{
45 Init {
46 future: ListenerSelectFuture<C, U::Info>,
47 upgrade: U,
48 },
49 Upgrade {
50 future: Pin<Box<U::Future>>,
51 name: String,
52 },
53 Undefined,
54}
55
56impl<C, U> InboundUpgradeApply<C, U>
57where
58 C: AsyncRead + AsyncWrite + Unpin,
59 U: InboundConnectionUpgrade<Negotiated<C>>,
60{
61 pub fn new(socket: C, upgrade: U) -> Self {
62 let future = ListenerSelectFuture::new(socket, upgrade.protocol_info());
63 Self {
64 inner: InboundUpgradeApplyState::Init { future, upgrade },
65 }
66 }
67}
68
69impl<C, U> Unpin for InboundUpgradeApply<C, U>
70where
71 C: AsyncRead + AsyncWrite + Unpin,
72 U: InboundConnectionUpgrade<Negotiated<C>>,
73{
74}
75
76impl<C, U> Future for InboundUpgradeApply<C, U>
77where
78 C: AsyncRead + AsyncWrite + Unpin,
79 U: InboundConnectionUpgrade<Negotiated<C>>,
80{
81 type Output = Result<U::Output, UpgradeError<U::Error>>;
82
83 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
84 loop {
85 match mem::replace(&mut self.inner, InboundUpgradeApplyState::Undefined) {
86 InboundUpgradeApplyState::Init {
87 mut future,
88 upgrade,
89 } => {
90 let (info, io) = match Future::poll(Pin::new(&mut future), cx)? {
91 Poll::Ready(x) => x,
92 Poll::Pending => {
93 self.inner = InboundUpgradeApplyState::Init { future, upgrade };
94 return Poll::Pending;
95 }
96 };
97 self.inner = InboundUpgradeApplyState::Upgrade {
98 future: Box::pin(upgrade.upgrade_inbound(io, info.clone())),
99 name: info.as_ref().to_owned(),
100 };
101 }
102 InboundUpgradeApplyState::Upgrade { mut future, name } => {
103 match Future::poll(Pin::new(&mut future), cx) {
104 Poll::Pending => {
105 self.inner = InboundUpgradeApplyState::Upgrade { future, name };
106 return Poll::Pending;
107 }
108 Poll::Ready(Ok(x)) => {
109 tracing::trace!(upgrade=%name, "Upgraded inbound stream");
110 return Poll::Ready(Ok(x));
111 }
112 Poll::Ready(Err(e)) => {
113 tracing::debug!(upgrade=%name, "Failed to upgrade inbound stream");
114 return Poll::Ready(Err(UpgradeError::Apply(e)));
115 }
116 }
117 }
118 InboundUpgradeApplyState::Undefined => {
119 panic!("InboundUpgradeApplyState::poll called after completion")
120 }
121 }
122 }
123 }
124}
125
126pub struct OutboundUpgradeApply<C, U>
128where
129 C: AsyncRead + AsyncWrite + Unpin,
130 U: OutboundConnectionUpgrade<Negotiated<C>>,
131{
132 inner: OutboundUpgradeApplyState<C, U>,
133}
134
135impl<C, U> OutboundUpgradeApply<C, U>
136where
137 C: AsyncRead + AsyncWrite + Unpin,
138 U: OutboundConnectionUpgrade<Negotiated<C>>,
139{
140 pub fn new(socket: C, upgrade: U) -> Self {
141 let future = DialerSelectFuture::new(socket, upgrade.protocol_info());
142 Self {
143 inner: OutboundUpgradeApplyState::Init { future, upgrade },
144 }
145 }
146}
147
148enum OutboundUpgradeApplyState<C, U>
149where
150 C: AsyncRead + AsyncWrite + Unpin,
151 U: OutboundConnectionUpgrade<Negotiated<C>>,
152{
153 Init {
154 future: DialerSelectFuture<C, <U::InfoIter as IntoIterator>::IntoIter>,
155 upgrade: U,
156 },
157 Upgrade {
158 future: Pin<Box<U::Future>>,
159 name: String,
160 },
161 Undefined,
162}
163
164impl<C, U> Unpin for OutboundUpgradeApply<C, U>
165where
166 C: AsyncRead + AsyncWrite + Unpin,
167 U: OutboundConnectionUpgrade<Negotiated<C>>,
168{
169}
170
171impl<C, U> Future for OutboundUpgradeApply<C, U>
172where
173 C: AsyncRead + AsyncWrite + Unpin,
174 U: OutboundConnectionUpgrade<Negotiated<C>>,
175{
176 type Output = Result<U::Output, UpgradeError<U::Error>>;
177
178 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
179 loop {
180 match mem::replace(&mut self.inner, OutboundUpgradeApplyState::Undefined) {
181 OutboundUpgradeApplyState::Init {
182 mut future,
183 upgrade,
184 } => {
185 let (info, connection) = match Future::poll(Pin::new(&mut future), cx)? {
186 Poll::Ready(x) => x,
187 Poll::Pending => {
188 self.inner = OutboundUpgradeApplyState::Init { future, upgrade };
189 return Poll::Pending;
190 }
191 };
192 self.inner = OutboundUpgradeApplyState::Upgrade {
193 future: Box::pin(upgrade.upgrade_outbound(connection, info.clone())),
194 name: info.as_ref().to_owned(),
195 };
196 }
197 OutboundUpgradeApplyState::Upgrade { mut future, name } => {
198 match Future::poll(Pin::new(&mut future), cx) {
199 Poll::Pending => {
200 self.inner = OutboundUpgradeApplyState::Upgrade { future, name };
201 return Poll::Pending;
202 }
203 Poll::Ready(Ok(x)) => {
204 tracing::trace!(upgrade=%name, "Upgraded outbound stream");
205 return Poll::Ready(Ok(x));
206 }
207 Poll::Ready(Err(e)) => {
208 tracing::debug!(upgrade=%name, "Failed to upgrade outbound stream",);
209 return Poll::Ready(Err(UpgradeError::Apply(e)));
210 }
211 }
212 }
213 OutboundUpgradeApplyState::Undefined => {
214 panic!("OutboundUpgradeApplyState::poll called after completion")
215 }
216 }
217 }
218 }
219}