1use std::{
2 marker::PhantomData,
3 pin::Pin,
4 task::{Context, Poll},
5};
6
7use futures::{AsyncRead, AsyncWrite, TryFuture, future, ready};
8
9use crate::{
10 Listener, ListenerEvent, Multiaddr, Negotiated, Transport, TransportError,
11 upgrade::{
12 InboundConnectionUpgrade, InboundUpgradeApply, OutboundConnectionUpgrade,
13 OutboundUpgradeApply, UpgradeError,
14 },
15};
16
/// Pairs an inner transport with an upgrade to run on its connections.
///
/// The `Transport` implementation for this type hands a clone of `upgrade`
/// to every dial attempt and to every listener it creates.
#[derive(Clone, Copy, Debug)]
pub struct UpgradeApply<T, U> {
    /// The wrapped transport used for dialing and listening.
    transport: T,
    /// The upgrade that is cloned for each dial and each listener.
    upgrade: U,
}

impl<T, U> UpgradeApply<T, U> {
    /// Builds an [`UpgradeApply`] from a transport and the upgrade to apply on top of it.
    pub(crate) fn new(transport: T, upgrade: U) -> Self {
        Self { transport, upgrade }
    }
}
28
29impl<T, C, D, E, U> Transport for UpgradeApply<T, U>
30where
31 T: Transport<Output = C>,
32 C: AsyncRead + AsyncWrite + Unpin,
33 U: Clone,
34 U: InboundConnectionUpgrade<Negotiated<C>, Output = D, Error = E>,
35 U: OutboundConnectionUpgrade<Negotiated<C>, Output = D, Error = E>,
36 E: std::error::Error,
37{
38 type Output = D;
39 type Error = UpgradeApplyError<T::Error, E>;
40 type Dial = DialUpgradeFuture<T::Dial, U, C>;
41 type Incoming = ListenerUpgradeFuture<T::Incoming, U, C>;
42 type Listener = UpgradeApplyListener<T, U>;
43
44 fn dial(&self, addr: Multiaddr) -> Result<Self::Dial, TransportError<Self::Error>> {
45 let fut = self
46 .transport
47 .dial(addr)
48 .map_err(|e| e.map(UpgradeApplyError::Transport))?;
49 Ok(DialUpgradeFuture {
50 future: Box::pin(fut),
51 upgrade: future::Either::Left(Some(self.upgrade.clone())),
52 })
53 }
54
55 fn listen(&self, addr: Multiaddr) -> Result<Self::Listener, TransportError<Self::Error>> {
56 let inner = self
57 .transport
58 .listen(addr)
59 .map_err(|e| e.map(UpgradeApplyError::Transport))?;
60
61 Ok(UpgradeApplyListener {
62 inner,
63 upgrade: self.upgrade.clone(),
64 _phantom: PhantomData,
65 })
66 }
67}
68
69#[pin_project::pin_project]
70#[derive(Clone, Debug)]
71pub struct UpgradeApplyListener<T, U>
72where
73 T: Transport,
74{
75 #[pin]
76 inner: T::Listener,
77 upgrade: U,
78 _phantom: PhantomData<T>,
79}
80
81impl<T, U, C, D> Listener for UpgradeApplyListener<T, U>
82where
83 T: Transport<Output = C>,
84 U: Clone,
85 C: AsyncRead + AsyncWrite + Unpin,
86 U: InboundConnectionUpgrade<Negotiated<C>, Output = D>,
87 U::Error: std::error::Error,
88{
89 type Output = D;
90 type Error = UpgradeApplyError<T::Error, U::Error>;
91 type Upgrade = ListenerUpgradeFuture<T::Incoming, U, C>;
92
93 fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
94 let this = self.project();
95 this.inner
96 .poll_close(cx)
97 .map_err(UpgradeApplyError::Transport)
98 }
99
100 fn poll_event(
101 self: Pin<&mut Self>,
102 cx: &mut Context<'_>,
103 ) -> Poll<ListenerEvent<Self::Upgrade, Self::Error>> {
104 let this = self.project();
105 this.inner.poll_event(cx).map(|event| {
106 event
107 .map_upgrade(move |u| ListenerUpgradeFuture {
108 future: Box::pin(u),
109 upgrade: future::Either::Left(Some(this.upgrade.clone())),
110 })
111 .map_err(UpgradeApplyError::Transport)
112 })
113 }
114}
115
116#[derive(Debug, thiserror::Error)]
117pub enum UpgradeApplyError<TErr, TUpgrErr> {
118 #[error("Transport error: {0}")]
119 Transport(TErr),
120 #[error("Upgrade error: {0}")]
121 Upgrade(UpgradeError<TUpgrErr>),
122}
123
124pub struct ListenerUpgradeFuture<F, U, C>
125where
126 U: InboundConnectionUpgrade<Negotiated<C>>,
127 C: AsyncRead + AsyncWrite + Unpin,
128{
129 future: Pin<Box<F>>,
130 upgrade: future::Either<Option<U>, InboundUpgradeApply<C, U>>,
131}
132
133impl<F, U, C> Unpin for ListenerUpgradeFuture<F, U, C>
134where
135 U: InboundConnectionUpgrade<Negotiated<C>>,
136 C: AsyncRead + AsyncWrite + Unpin,
137{
138}
139
140impl<F, U, C, D> Future for ListenerUpgradeFuture<F, U, C>
141where
142 F: TryFuture<Ok = C>,
143 C: AsyncRead + AsyncWrite + Unpin,
144 U: InboundConnectionUpgrade<Negotiated<C>, Output = D>,
145 U::Error: std::error::Error,
146{
147 type Output = Result<D, UpgradeApplyError<F::Error, U::Error>>;
148
149 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
150 let this = &mut *self;
151 loop {
152 this.upgrade = match this.upgrade {
153 future::Either::Left(ref mut upgrade) => {
154 let c = ready!(
155 TryFuture::try_poll(this.future.as_mut(), cx)
156 .map_err(UpgradeApplyError::Transport)?
157 );
158 let u = upgrade
159 .take()
160 .expect("ListenerUpgradeFuture should have upgrade set");
161 future::Either::Right(InboundUpgradeApply::new(c, u))
163 }
164 future::Either::Right(ref mut upgrade) => {
165 let res = ready!(
166 Future::poll(Pin::new(upgrade), cx).map_err(UpgradeApplyError::Upgrade)?
167 );
168 return Poll::Ready(Ok(res));
169 }
170 }
171 }
172 }
173}
174
175pub struct DialUpgradeFuture<F, U, C>
176where
177 U: OutboundConnectionUpgrade<Negotiated<C>>,
178 C: AsyncRead + AsyncWrite + Unpin,
179{
180 future: Pin<Box<F>>,
181 upgrade: future::Either<Option<U>, OutboundUpgradeApply<C, U>>,
182}
183
184impl<F, U, C> Unpin for DialUpgradeFuture<F, U, C>
185where
186 U: OutboundConnectionUpgrade<Negotiated<C>>,
187 C: AsyncRead + AsyncWrite + Unpin,
188{
189}
190
191impl<F, U, C, D> Future for DialUpgradeFuture<F, U, C>
192where
193 F: TryFuture<Ok = C>,
194 C: AsyncRead + AsyncWrite + Unpin,
195 U: OutboundConnectionUpgrade<Negotiated<C>, Output = D>,
196 U::Error: std::error::Error,
197{
198 type Output = Result<D, UpgradeApplyError<F::Error, U::Error>>;
199
200 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
201 let this = &mut *self;
202 loop {
203 this.upgrade = match this.upgrade {
204 future::Either::Left(ref mut upgrade) => {
206 let c = ready!(
207 TryFuture::try_poll(this.future.as_mut(), cx)
208 .map_err(UpgradeApplyError::Transport)?
209 );
210 let u = upgrade
211 .take()
212 .expect("DialUpgradeFuture should have upgrade set");
213 future::Either::Right(OutboundUpgradeApply::new(c, u))
215 }
216 future::Either::Right(ref mut upgrade) => {
217 let res = ready!(
218 TryFuture::try_poll(Pin::new(upgrade), cx)
219 .map_err(UpgradeApplyError::Upgrade)?
220 );
221 return Poll::Ready(Ok(res));
222 }
223 }
224 }
225 }
226}