Skip to main content

volans_core/transport/
apply.rs

1use std::{
2    marker::PhantomData,
3    pin::Pin,
4    task::{Context, Poll},
5};
6
7use futures::{AsyncRead, AsyncWrite, TryFuture, future, ready};
8
9use crate::{
10    Listener, ListenerEvent, Multiaddr, Negotiated, Transport, TransportError,
11    upgrade::{
12        InboundConnectionUpgrade, InboundUpgradeApply, OutboundConnectionUpgrade,
13        OutboundUpgradeApply, UpgradeError,
14    },
15};
16
/// Transport wrapper that runs a connection upgrade `U` on top of every
/// connection produced by the inner transport `T`.
#[derive(Debug, Copy, Clone)]
pub struct UpgradeApply<T, U> {
    /// The wrapped transport that establishes the raw connections.
    transport: T,
    /// The upgrade to apply; a clone of it is handed to each dial / listener.
    upgrade: U,
}

impl<T, U> UpgradeApply<T, U> {
    /// Pairs `transport` with the `upgrade` that should run on its connections.
    pub(crate) fn new(transport: T, upgrade: U) -> Self {
        Self { transport, upgrade }
    }
}
28
29impl<T, C, D, E, U> Transport for UpgradeApply<T, U>
30where
31    T: Transport<Output = C>,
32    C: AsyncRead + AsyncWrite + Unpin,
33    U: Clone,
34    U: InboundConnectionUpgrade<Negotiated<C>, Output = D, Error = E>,
35    U: OutboundConnectionUpgrade<Negotiated<C>, Output = D, Error = E>,
36    E: std::error::Error,
37{
38    type Output = D;
39    type Error = UpgradeApplyError<T::Error, E>;
40    type Dial = DialUpgradeFuture<T::Dial, U, C>;
41    type Incoming = ListenerUpgradeFuture<T::Incoming, U, C>;
42    type Listener = UpgradeApplyListener<T, U>;
43
44    fn dial(&self, addr: Multiaddr) -> Result<Self::Dial, TransportError<Self::Error>> {
45        let fut = self
46            .transport
47            .dial(addr)
48            .map_err(|e| e.map(UpgradeApplyError::Transport))?;
49        Ok(DialUpgradeFuture {
50            future: Box::pin(fut),
51            upgrade: future::Either::Left(Some(self.upgrade.clone())),
52        })
53    }
54
55    fn listen(&self, addr: Multiaddr) -> Result<Self::Listener, TransportError<Self::Error>> {
56        let inner = self
57            .transport
58            .listen(addr)
59            .map_err(|e| e.map(UpgradeApplyError::Transport))?;
60
61        Ok(UpgradeApplyListener {
62            inner,
63            upgrade: self.upgrade.clone(),
64            _phantom: PhantomData,
65        })
66    }
67}
68
/// Listener returned by `UpgradeApply`'s `Transport::listen`.
///
/// It wraps the inner transport's listener and turns each incoming upgrade
/// event into a `ListenerUpgradeFuture`, which applies `upgrade` once the raw
/// connection is established.
#[pin_project::pin_project]
#[derive(Clone, Debug)]
pub struct UpgradeApplyListener<T, U>
where
    T: Transport,
{
    /// The underlying transport's listener; structurally pinned so it can be
    /// polled through the projection.
    #[pin]
    inner: T::Listener,
    /// Upgrade template; cloned for every incoming upgrade event.
    upgrade: U,
    // Zero-sized marker tying the struct to `T`; holds no data.
    _phantom: PhantomData<T>,
}
80
81impl<T, U, C, D> Listener for UpgradeApplyListener<T, U>
82where
83    T: Transport<Output = C>,
84    U: Clone,
85    C: AsyncRead + AsyncWrite + Unpin,
86    U: InboundConnectionUpgrade<Negotiated<C>, Output = D>,
87    U::Error: std::error::Error,
88{
89    type Output = D;
90    type Error = UpgradeApplyError<T::Error, U::Error>;
91    type Upgrade = ListenerUpgradeFuture<T::Incoming, U, C>;
92
93    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
94        let this = self.project();
95        this.inner
96            .poll_close(cx)
97            .map_err(UpgradeApplyError::Transport)
98    }
99
100    fn poll_event(
101        self: Pin<&mut Self>,
102        cx: &mut Context<'_>,
103    ) -> Poll<ListenerEvent<Self::Upgrade, Self::Error>> {
104        let this = self.project();
105        this.inner.poll_event(cx).map(|event| {
106            event
107                .map_upgrade(move |u| ListenerUpgradeFuture {
108                    future: Box::pin(u),
109                    upgrade: future::Either::Left(Some(this.upgrade.clone())),
110                })
111                .map_err(UpgradeApplyError::Transport)
112        })
113    }
114}
115
/// Error produced by `UpgradeApply` and by its dial / listener futures.
///
/// `TErr` is the error type reported by the wrapped transport (or its
/// listener). `TUpgrErr` is the error type of the applied connection upgrade.
///
/// NOTE(review): the inner errors are only rendered through `Display`. No
/// field is marked `#[source]`, so `Error::source()` does not expose them.
#[derive(Debug, thiserror::Error)]
pub enum UpgradeApplyError<TErr, TUpgrErr> {
    /// The wrapped transport or its listener failed. This covers dialing,
    /// listening, closing, or producing the raw connection.
    #[error("Transport error: {0}")]
    Transport(TErr),
    /// The raw connection was obtained, but applying the upgrade to it failed.
    #[error("Upgrade error: {0}")]
    Upgrade(UpgradeError<TUpgrErr>),
}
123
124pub struct ListenerUpgradeFuture<F, U, C>
125where
126    U: InboundConnectionUpgrade<Negotiated<C>>,
127    C: AsyncRead + AsyncWrite + Unpin,
128{
129    future: Pin<Box<F>>,
130    upgrade: future::Either<Option<U>, InboundUpgradeApply<C, U>>,
131}
132
133impl<F, U, C> Unpin for ListenerUpgradeFuture<F, U, C>
134where
135    U: InboundConnectionUpgrade<Negotiated<C>>,
136    C: AsyncRead + AsyncWrite + Unpin,
137{
138}
139
140impl<F, U, C, D> Future for ListenerUpgradeFuture<F, U, C>
141where
142    F: TryFuture<Ok = C>,
143    C: AsyncRead + AsyncWrite + Unpin,
144    U: InboundConnectionUpgrade<Negotiated<C>, Output = D>,
145    U::Error: std::error::Error,
146{
147    type Output = Result<D, UpgradeApplyError<F::Error, U::Error>>;
148
149    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
150        let this = &mut *self;
151        loop {
152            this.upgrade = match this.upgrade {
153                future::Either::Left(ref mut upgrade) => {
154                    let c = ready!(
155                        TryFuture::try_poll(this.future.as_mut(), cx)
156                            .map_err(UpgradeApplyError::Transport)?
157                    );
158                    let u = upgrade
159                        .take()
160                        .expect("ListenerUpgradeFuture should have upgrade set");
161                    // 使用 InboundUpgradeApply to apply the upgrade
162                    future::Either::Right(InboundUpgradeApply::new(c, u))
163                }
164                future::Either::Right(ref mut upgrade) => {
165                    let res = ready!(
166                        Future::poll(Pin::new(upgrade), cx).map_err(UpgradeApplyError::Upgrade)?
167                    );
168                    return Poll::Ready(Ok(res));
169                }
170            }
171        }
172    }
173}
174
175pub struct DialUpgradeFuture<F, U, C>
176where
177    U: OutboundConnectionUpgrade<Negotiated<C>>,
178    C: AsyncRead + AsyncWrite + Unpin,
179{
180    future: Pin<Box<F>>,
181    upgrade: future::Either<Option<U>, OutboundUpgradeApply<C, U>>,
182}
183
184impl<F, U, C> Unpin for DialUpgradeFuture<F, U, C>
185where
186    U: OutboundConnectionUpgrade<Negotiated<C>>,
187    C: AsyncRead + AsyncWrite + Unpin,
188{
189}
190
191impl<F, U, C, D> Future for DialUpgradeFuture<F, U, C>
192where
193    F: TryFuture<Ok = C>,
194    C: AsyncRead + AsyncWrite + Unpin,
195    U: OutboundConnectionUpgrade<Negotiated<C>, Output = D>,
196    U::Error: std::error::Error,
197{
198    type Output = Result<D, UpgradeApplyError<F::Error, U::Error>>;
199
200    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
201        let this = &mut *self;
202        loop {
203            this.upgrade = match this.upgrade {
204                // 上层 Upgrade
205                future::Either::Left(ref mut upgrade) => {
206                    let c = ready!(
207                        TryFuture::try_poll(this.future.as_mut(), cx)
208                            .map_err(UpgradeApplyError::Transport)?
209                    );
210                    let u = upgrade
211                        .take()
212                        .expect("DialUpgradeFuture should have upgrade set");
213                    // 使用 OutboundUpgradeApply to apply the upgrade
214                    future::Either::Right(OutboundUpgradeApply::new(c, u))
215                }
216                future::Either::Right(ref mut upgrade) => {
217                    let res = ready!(
218                        TryFuture::try_poll(Pin::new(upgrade), cx)
219                            .map_err(UpgradeApplyError::Upgrade)?
220                    );
221                    return Poll::Ready(Ok(res));
222                }
223            }
224        }
225    }
226}