volans_swarm/handler/
select.rs

1use std::{
2    cmp,
3    task::{Context, Poll},
4};
5
6use either::Either;
7use futures::{future, ready};
8use volans_core::upgrade::SelectUpgrade;
9
10use crate::{
11    ConnectionHandler, ConnectionHandlerEvent, InboundStreamHandler, InboundUpgradeSend,
12    OutboundStreamHandler, OutboundUpgradeSend, StreamUpgradeError, SubstreamProtocol,
13    upgrade::SendWrapper,
14};
15
/// A pair of connection handlers that is driven as one combined handler.
///
/// The trait implementations on this type route everything that belongs to `first`
/// through the `Left` variant of an `Either` and everything that belongs to `second`
/// through the `Right` variant.
#[derive(Clone, Debug)]
pub struct ConnectionHandlerSelect<THandler1, THandler2> {
    /// Handler addressed by `Left` values.
    first: THandler1,
    /// Handler addressed by `Right` values.
    second: THandler2,
}

impl<THandler1, THandler2> ConnectionHandlerSelect<THandler1, THandler2> {
    /// Builds the combined handler from its two halves.
    ///
    /// `first` becomes the `Left` side and `second` the `Right` side.
    pub fn select(first: THandler1, second: THandler2) -> Self {
        ConnectionHandlerSelect { first, second }
    }
}
27
28impl<THandler1, THandler2> ConnectionHandler for ConnectionHandlerSelect<THandler1, THandler2>
29where
30    THandler1: ConnectionHandler,
31    THandler2: ConnectionHandler,
32{
33    type Action = Either<THandler1::Action, THandler2::Action>;
34    type Event = Either<THandler1::Event, THandler2::Event>;
35
36    fn handle_action(&mut self, action: Self::Action) {
37        match action {
38            Either::Left(action) => self.first.handle_action(action),
39            Either::Right(action) => self.second.handle_action(action),
40        }
41    }
42
43    fn connection_keep_alive(&self) -> bool {
44        self.first.connection_keep_alive() || self.second.connection_keep_alive()
45    }
46
47    fn poll_close(&mut self, cx: &mut Context<'_>) -> Poll<Option<Self::Event>> {
48        if let Some(e) = ready!(self.first.poll_close(cx)) {
49            return Poll::Ready(Some(Either::Left(e)));
50        }
51
52        if let Some(e) = ready!(self.second.poll_close(cx)) {
53            return Poll::Ready(Some(Either::Right(e)));
54        }
55
56        Poll::Ready(None)
57    }
58
59    fn poll(&mut self, cx: &mut Context<'_>) -> Poll<ConnectionHandlerEvent<Self::Event>> {
60        match self.first.poll(cx) {
61            Poll::Ready(event) => return Poll::Ready(event.map_event(Either::Left)),
62            Poll::Pending => {}
63        };
64        match self.second.poll(cx) {
65            Poll::Ready(event) => return Poll::Ready(event.map_event(Either::Right)),
66            Poll::Pending => {}
67        };
68        Poll::Pending
69    }
70}
71
72impl<THandler1, THandler2> InboundStreamHandler for ConnectionHandlerSelect<THandler1, THandler2>
73where
74    THandler1: InboundStreamHandler,
75    THandler2: InboundStreamHandler,
76{
77    type InboundUpgrade = SelectUpgrade<
78        SendWrapper<THandler1::InboundUpgrade>,
79        SendWrapper<THandler2::InboundUpgrade>,
80    >;
81    type InboundUserData = (THandler1::InboundUserData, THandler2::InboundUserData);
82
83    fn listen_protocol(&self) -> SubstreamProtocol<Self::InboundUpgrade, Self::InboundUserData> {
84        let first = self.first.listen_protocol();
85        let second = self.second.listen_protocol();
86        let (upgrade1, info1, timeout1) = first.into_inner();
87        let (upgrade2, info2, timeout2) = second.into_inner();
88        let timeout = cmp::max(timeout1, timeout2);
89        let choice = SelectUpgrade::new(SendWrapper(upgrade1), SendWrapper(upgrade2));
90        SubstreamProtocol::new(choice, (info1, info2)).with_timeout(timeout)
91    }
92
93    fn on_fully_negotiated(
94        &mut self,
95        user_data: Self::InboundUserData,
96        protocol: <Self::InboundUpgrade as InboundUpgradeSend>::Output,
97    ) {
98        match protocol {
99            future::Either::Left(output) => {
100                self.first.on_fully_negotiated(user_data.0, output);
101            }
102            future::Either::Right(output) => {
103                self.second.on_fully_negotiated(user_data.1, output);
104            }
105        }
106    }
107
108    fn on_upgrade_error(
109        &mut self,
110        user_data: Self::InboundUserData,
111        error: <Self::InboundUpgrade as InboundUpgradeSend>::Error,
112    ) {
113        match error {
114            Either::Left(err) => {
115                self.first.on_upgrade_error(user_data.0, err);
116            }
117            Either::Right(err) => {
118                self.second.on_upgrade_error(user_data.1, err);
119            }
120        }
121    }
122}
123
124impl<THandler1, THandler2> OutboundStreamHandler for ConnectionHandlerSelect<THandler1, THandler2>
125where
126    THandler1: OutboundStreamHandler,
127    THandler2: OutboundStreamHandler,
128{
129    type OutboundUpgrade =
130        Either<SendWrapper<THandler1::OutboundUpgrade>, SendWrapper<THandler2::OutboundUpgrade>>;
131    type OutboundUserData = Either<THandler1::OutboundUserData, THandler2::OutboundUserData>;
132
133    fn on_fully_negotiated(
134        &mut self,
135        user_data: Self::OutboundUserData,
136        protocol: <Self::OutboundUpgrade as OutboundUpgradeSend>::Output,
137    ) {
138        match protocol {
139            future::Either::Left(output) => {
140                self.first.on_fully_negotiated(
141                    user_data.left().expect("Dial left info must be present"),
142                    output,
143                );
144            }
145            future::Either::Right(output) => {
146                self.second.on_fully_negotiated(
147                    user_data.right().expect("Dial right info must be present"),
148                    output,
149                );
150            }
151        }
152    }
153
154    fn on_upgrade_error(
155        &mut self,
156        user_data: Self::OutboundUserData,
157        error: StreamUpgradeError<<Self::OutboundUpgrade as OutboundUpgradeSend>::Error>,
158    ) {
159        match user_data {
160            Either::Left(data) => {
161                let error =
162                    error.map_upgrade_err(|e| e.left().expect("Left error must be present"));
163                self.first.on_upgrade_error(data, error);
164            }
165            Either::Right(data) => {
166                let error =
167                    error.map_upgrade_err(|e| e.right().expect("Right error must be present"));
168                self.second.on_upgrade_error(data, error);
169            }
170        }
171    }
172
173    fn poll_outbound_request(
174        &mut self,
175        cx: &mut Context<'_>,
176    ) -> Poll<SubstreamProtocol<Self::OutboundUpgrade, Self::OutboundUserData>> {
177        match self.first.poll_outbound_request(cx) {
178            Poll::Ready(protocol) => {
179                return Poll::Ready(
180                    protocol
181                        .map_upgrade(|u| Either::Left(SendWrapper(u)))
182                        .map_user_data(Either::Left),
183                );
184            }
185            Poll::Pending => {}
186        }
187
188        match self.second.poll_outbound_request(cx) {
189            Poll::Ready(protocol) => {
190                return Poll::Ready(
191                    protocol
192                        .map_upgrade(|u| Either::Right(SendWrapper(u)))
193                        .map_user_data(Either::Right),
194                );
195            }
196            Poll::Pending => {}
197        }
198
199        Poll::Pending
200    }
201}