1use crate::NegotiatedSubstream;
25use crate::protocols_handler::{
26 KeepAlive,
27 IntoProtocolsHandler,
28 ProtocolsHandler,
29 ProtocolsHandlerEvent,
30 ProtocolsHandlerUpgrErr,
31 SubstreamProtocol
32};
33use crate::upgrade::{
34 InboundUpgradeSend,
35 OutboundUpgradeSend,
36 UpgradeInfoSend
37};
38use futures::{future::BoxFuture, prelude::*};
39use tetsy_libp2p_core::{ConnectedPoint, Multiaddr, PeerId};
40use tetsy_libp2p_core::upgrade::{ProtocolName, UpgradeError, NegotiationError, ProtocolError};
41use rand::Rng;
42use std::{
43 cmp,
44 collections::{HashMap, HashSet},
45 error,
46 fmt,
47 hash::Hash,
48 iter::{self, FromIterator},
49 task::{Context, Poll},
50 time::Duration
51};
52
/// A collection of protocols handlers of type `H`, each stored under a key
/// of type `K`.
#[derive(Clone)]
pub struct MultiHandler<K, H> {
    /// The wrapped handlers, addressed by their key.
    handlers: HashMap<K, H>,
}

impl<K, H> fmt::Debug for MultiHandler<K, H>
where
    K: fmt::Debug + Eq + Hash,
    H: fmt::Debug
{
    /// Formats the value as a struct named `MultiHandler` with a single
    /// `handlers` field.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut out = f.debug_struct("MultiHandler");
        out.field("handlers", &self.handlers);
        out.finish()
    }
}
70
71impl<K, H> MultiHandler<K, H>
72where
73 K: Hash + Eq,
74 H: ProtocolsHandler
75{
76 pub fn try_from_iter<I>(iter: I) -> Result<Self, DuplicateProtonameError>
80 where
81 I: IntoIterator<Item = (K, H)>
82 {
83 let m = MultiHandler { handlers: HashMap::from_iter(iter) };
84 uniq_proto_names(m.handlers.values().map(|h| h.listen_protocol().into_upgrade().0))?;
85 Ok(m)
86 }
87}
88
89impl<K, H> ProtocolsHandler for MultiHandler<K, H>
90where
91 K: Clone + Hash + Eq + Send + 'static,
92 H: ProtocolsHandler,
93 H::InboundProtocol: InboundUpgradeSend,
94 H::OutboundProtocol: OutboundUpgradeSend
95{
96 type InEvent = (K, <H as ProtocolsHandler>::InEvent);
97 type OutEvent = (K, <H as ProtocolsHandler>::OutEvent);
98 type Error = <H as ProtocolsHandler>::Error;
99 type InboundProtocol = Upgrade<K, <H as ProtocolsHandler>::InboundProtocol>;
100 type OutboundProtocol = <H as ProtocolsHandler>::OutboundProtocol;
101 type InboundOpenInfo = Info<K, <H as ProtocolsHandler>::InboundOpenInfo>;
102 type OutboundOpenInfo = (K, <H as ProtocolsHandler>::OutboundOpenInfo);
103
104 fn listen_protocol(&self) -> SubstreamProtocol<Self::InboundProtocol, Self::InboundOpenInfo> {
105 let (upgrade, info, timeout) = self.handlers.iter()
106 .map(|(key, handler)| {
107 let proto = handler.listen_protocol();
108 let timeout = *proto.timeout();
109 let (upgrade, info) = proto.into_upgrade();
110 (key.clone(), (upgrade, info, timeout))
111 })
112 .fold((Upgrade::new(), Info::new(), Duration::from_secs(0)),
113 |(mut upg, mut inf, mut timeout), (k, (u, i, t))| {
114 upg.upgrades.push((k.clone(), u));
115 inf.infos.push((k, i));
116 timeout = cmp::max(timeout, t);
117 (upg, inf, timeout)
118 }
119 );
120 SubstreamProtocol::new(upgrade, info).with_timeout(timeout)
121 }
122
123 fn inject_fully_negotiated_outbound (
124 &mut self,
125 protocol: <Self::OutboundProtocol as OutboundUpgradeSend>::Output,
126 (key, arg): Self::OutboundOpenInfo
127 ) {
128 if let Some(h) = self.handlers.get_mut(&key) {
129 h.inject_fully_negotiated_outbound(protocol, arg)
130 } else {
131 log::error!("inject_fully_negotiated_outbound: no handler for key")
132 }
133 }
134
135 fn inject_fully_negotiated_inbound (
136 &mut self,
137 (key, arg): <Self::InboundProtocol as InboundUpgradeSend>::Output,
138 mut info: Self::InboundOpenInfo
139 ) {
140 if let Some(h) = self.handlers.get_mut(&key) {
141 if let Some(i) = info.take(&key) {
142 h.inject_fully_negotiated_inbound(arg, i)
143 }
144 } else {
145 log::error!("inject_fully_negotiated_inbound: no handler for key")
146 }
147 }
148
149 fn inject_event(&mut self, (key, event): Self::InEvent) {
150 if let Some(h) = self.handlers.get_mut(&key) {
151 h.inject_event(event)
152 } else {
153 log::error!("inject_event: no handler for key")
154 }
155 }
156
157 fn inject_address_change(&mut self, addr: &Multiaddr) {
158 for h in self.handlers.values_mut() {
159 h.inject_address_change(addr)
160 }
161 }
162
163 fn inject_dial_upgrade_error (
164 &mut self,
165 (key, arg): Self::OutboundOpenInfo,
166 error: ProtocolsHandlerUpgrErr<<Self::OutboundProtocol as OutboundUpgradeSend>::Error>
167 ) {
168 if let Some(h) = self.handlers.get_mut(&key) {
169 h.inject_dial_upgrade_error(arg, error)
170 } else {
171 log::error!("inject_dial_upgrade_error: no handler for protocol")
172 }
173 }
174
175 fn inject_listen_upgrade_error(
176 &mut self,
177 mut info: Self::InboundOpenInfo,
178 error: ProtocolsHandlerUpgrErr<<Self::InboundProtocol as InboundUpgradeSend>::Error>
179 ) {
180 match error {
181 ProtocolsHandlerUpgrErr::Timer =>
182 for (k, h) in &mut self.handlers {
183 if let Some(i) = info.take(k) {
184 h.inject_listen_upgrade_error(i, ProtocolsHandlerUpgrErr::Timer)
185 }
186 }
187 ProtocolsHandlerUpgrErr::Timeout =>
188 for (k, h) in &mut self.handlers {
189 if let Some(i) = info.take(k) {
190 h.inject_listen_upgrade_error(i, ProtocolsHandlerUpgrErr::Timeout)
191 }
192 }
193 ProtocolsHandlerUpgrErr::Upgrade(UpgradeError::Select(NegotiationError::Failed)) =>
194 for (k, h) in &mut self.handlers {
195 if let Some(i) = info.take(k) {
196 h.inject_listen_upgrade_error(i, ProtocolsHandlerUpgrErr::Upgrade(UpgradeError::Select(NegotiationError::Failed)))
197 }
198 }
199 ProtocolsHandlerUpgrErr::Upgrade(UpgradeError::Select(NegotiationError::ProtocolError(e))) =>
200 match e {
201 ProtocolError::IoError(e) =>
202 for (k, h) in &mut self.handlers {
203 if let Some(i) = info.take(k) {
204 let e = NegotiationError::ProtocolError(ProtocolError::IoError(e.kind().into()));
205 h.inject_listen_upgrade_error(i, ProtocolsHandlerUpgrErr::Upgrade(UpgradeError::Select(e)))
206 }
207 }
208 ProtocolError::InvalidMessage =>
209 for (k, h) in &mut self.handlers {
210 if let Some(i) = info.take(k) {
211 let e = NegotiationError::ProtocolError(ProtocolError::InvalidMessage);
212 h.inject_listen_upgrade_error(i, ProtocolsHandlerUpgrErr::Upgrade(UpgradeError::Select(e)))
213 }
214 }
215 ProtocolError::InvalidProtocol =>
216 for (k, h) in &mut self.handlers {
217 if let Some(i) = info.take(k) {
218 let e = NegotiationError::ProtocolError(ProtocolError::InvalidProtocol);
219 h.inject_listen_upgrade_error(i, ProtocolsHandlerUpgrErr::Upgrade(UpgradeError::Select(e)))
220 }
221 }
222 ProtocolError::TooManyProtocols =>
223 for (k, h) in &mut self.handlers {
224 if let Some(i) = info.take(k) {
225 let e = NegotiationError::ProtocolError(ProtocolError::TooManyProtocols);
226 h.inject_listen_upgrade_error(i, ProtocolsHandlerUpgrErr::Upgrade(UpgradeError::Select(e)))
227 }
228 }
229 }
230 ProtocolsHandlerUpgrErr::Upgrade(UpgradeError::Apply((k, e))) =>
231 if let Some(h) = self.handlers.get_mut(&k) {
232 if let Some(i) = info.take(&k) {
233 h.inject_listen_upgrade_error(i, ProtocolsHandlerUpgrErr::Upgrade(UpgradeError::Apply(e)))
234 }
235 }
236 }
237 }
238
239 fn connection_keep_alive(&self) -> KeepAlive {
240 self.handlers.values()
241 .map(|h| h.connection_keep_alive())
242 .max()
243 .unwrap_or(KeepAlive::No)
244 }
245
246 fn poll(&mut self, cx: &mut Context<'_>)
247 -> Poll<ProtocolsHandlerEvent<Self::OutboundProtocol, Self::OutboundOpenInfo, Self::OutEvent, Self::Error>>
248 {
249 if self.handlers.is_empty() {
252 return Poll::Pending;
253 }
254
255 let pos = rand::thread_rng().gen_range(0, self.handlers.len());
257
258 for (k, h) in self.handlers.iter_mut().skip(pos) {
259 if let Poll::Ready(e) = h.poll(cx) {
260 let e = e.map_outbound_open_info(|i| (k.clone(), i)).map_custom(|p| (k.clone(), p));
261 return Poll::Ready(e)
262 }
263 }
264
265 for (k, h) in self.handlers.iter_mut().take(pos) {
266 if let Poll::Ready(e) = h.poll(cx) {
267 let e = e.map_outbound_open_info(|i| (k.clone(), i)).map_custom(|p| (k.clone(), p));
268 return Poll::Ready(e)
269 }
270 }
271
272 Poll::Pending
273 }
274}
275
/// A collection of values of type `H`, each stored under a key of type `K`,
/// that is turned into a `MultiHandler` by its `IntoProtocolsHandler`
/// implementation.
#[derive(Clone)]
pub struct IntoMultiHandler<K, H> {
    /// The not yet converted handlers, addressed by their key.
    handlers: HashMap<K, H>,
}

impl<K, H> fmt::Debug for IntoMultiHandler<K, H>
where
    K: fmt::Debug + Eq + Hash,
    H: fmt::Debug
{
    /// Formats the value as a struct named `IntoMultiHandler` with a single
    /// `handlers` field.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut out = f.debug_struct("IntoMultiHandler");
        out.field("handlers", &self.handlers);
        out.finish()
    }
}
293
294
295impl<K, H> IntoMultiHandler<K, H>
296where
297 K: Hash + Eq,
298 H: IntoProtocolsHandler
299{
300 pub fn try_from_iter<I>(iter: I) -> Result<Self, DuplicateProtonameError>
304 where
305 I: IntoIterator<Item = (K, H)>
306 {
307 let m = IntoMultiHandler { handlers: HashMap::from_iter(iter) };
308 uniq_proto_names(m.handlers.values().map(|h| h.inbound_protocol()))?;
309 Ok(m)
310 }
311}
312
313impl<K, H> IntoProtocolsHandler for IntoMultiHandler<K, H>
314where
315 K: Clone + Eq + Hash + Send + 'static,
316 H: IntoProtocolsHandler
317{
318 type Handler = MultiHandler<K, H::Handler>;
319
320 fn into_handler(self, p: &PeerId, c: &ConnectedPoint) -> Self::Handler {
321 MultiHandler {
322 handlers: self.handlers.into_iter()
323 .map(|(k, h)| (k, h.into_handler(p, c)))
324 .collect()
325 }
326 }
327
328 fn inbound_protocol(&self) -> <Self::Handler as ProtocolsHandler>::InboundProtocol {
329 Upgrade {
330 upgrades: self.handlers.iter()
331 .map(|(k, h)| (k.clone(), h.inbound_protocol()))
332 .collect()
333 }
334 }
335}
336
/// A protocol name `H` paired with an index.
///
/// Within this file the index is the position, inside `Upgrade::upgrades`,
/// of the upgrade that advertised the name.
#[derive(Debug, Clone)]
pub struct IndexedProtoName<H>(usize, H);
340
341impl<H: ProtocolName> ProtocolName for IndexedProtoName<H> {
342 fn protocol_name(&self) -> &[u8] {
343 self.1.protocol_name()
344 }
345}
346
/// A list of values of type `I`, each paired with a key of type `K`.
///
/// `MultiHandler` uses it as its `InboundOpenInfo`, holding one entry per
/// wrapped handler.
#[derive(Clone)]
pub struct Info<K, I> {
    /// The `(key, value)` entries, in insertion order.
    infos: Vec<(K, I)>,
}

impl<K: Eq, I> Info<K, I> {
    /// Creates an `Info` without entries.
    fn new() -> Self {
        Self { infos: Vec::new() }
    }

    /// Removes and returns the value of the first entry whose key equals
    /// `k`, or returns `None` if no entry has that key.
    pub fn take(&mut self, k: &K) -> Option<I> {
        let index = self.infos.iter().position(|(key, _)| key == k)?;
        let (_, value) = self.infos.remove(index);
        Some(value)
    }
}
365
/// A list of upgrades of type `H`, each paired with a key of type `K`.
///
/// `MultiHandler` uses it as its `InboundProtocol`.
#[derive(Clone)]
pub struct Upgrade<K, H> {
    /// The `(key, upgrade)` entries, in insertion order.
    upgrades: Vec<(K, H)>,
}

impl<K, H> Upgrade<K, H> {
    /// Creates an `Upgrade` without entries.
    fn new() -> Self {
        Self { upgrades: Vec::new() }
    }
}

impl<K, H> fmt::Debug for Upgrade<K, H>
where
    K: fmt::Debug + Eq + Hash,
    H: fmt::Debug
{
    /// Formats the value as a struct named `Upgrade` with a single
    /// `upgrades` field.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut out = f.debug_struct("Upgrade");
        out.field("upgrades", &self.upgrades);
        out.finish()
    }
}
389
390impl<K, H> UpgradeInfoSend for Upgrade<K, H>
391where
392 H: UpgradeInfoSend,
393 K: Send + 'static
394{
395 type Info = IndexedProtoName<H::Info>;
396 type InfoIter = std::vec::IntoIter<Self::Info>;
397
398 fn protocol_info(&self) -> Self::InfoIter {
399 self.upgrades.iter().enumerate()
400 .map(|(i, (_, h))| iter::repeat(i).zip(h.protocol_info()))
401 .flatten()
402 .map(|(i, h)| IndexedProtoName(i, h))
403 .collect::<Vec<_>>()
404 .into_iter()
405 }
406}
407
408impl<K, H> InboundUpgradeSend for Upgrade<K, H>
409where
410 H: InboundUpgradeSend,
411 K: Send + 'static
412{
413 type Output = (K, <H as InboundUpgradeSend>::Output);
414 type Error = (K, <H as InboundUpgradeSend>::Error);
415 type Future = BoxFuture<'static, Result<Self::Output, Self::Error>>;
416
417 fn upgrade_inbound(mut self, resource: NegotiatedSubstream, info: Self::Info) -> Self::Future {
418 let IndexedProtoName(index, info) = info;
419 let (key, upgrade) = self.upgrades.remove(index);
420 upgrade.upgrade_inbound(resource, info)
421 .map(move |out| {
422 match out {
423 Ok(o) => Ok((key, o)),
424 Err(e) => Err((key, e))
425 }
426 })
427 .boxed()
428 }
429}
430
431impl<K, H> OutboundUpgradeSend for Upgrade<K, H>
432where
433 H: OutboundUpgradeSend,
434 K: Send + 'static
435{
436 type Output = (K, <H as OutboundUpgradeSend>::Output);
437 type Error = (K, <H as OutboundUpgradeSend>::Error);
438 type Future = BoxFuture<'static, Result<Self::Output, Self::Error>>;
439
440 fn upgrade_outbound(mut self, resource: NegotiatedSubstream, info: Self::Info) -> Self::Future {
441 let IndexedProtoName(index, info) = info;
442 let (key, upgrade) = self.upgrades.remove(index);
443 upgrade.upgrade_outbound(resource, info)
444 .map(move |out| {
445 match out {
446 Ok(o) => Ok((key, o)),
447 Err(e) => Err((key, e))
448 }
449 })
450 .boxed()
451 }
452}
453
454fn uniq_proto_names<I, T>(iter: I) -> Result<(), DuplicateProtonameError>
456where
457 I: Iterator<Item = T>,
458 T: UpgradeInfoSend
459{
460 let mut set = HashSet::new();
461 for infos in iter {
462 for i in infos.protocol_info() {
463 let v = Vec::from(i.protocol_name());
464 if set.contains(&v) {
465 return Err(DuplicateProtonameError(v))
466 } else {
467 set.insert(v);
468 }
469 }
470 }
471 Ok(())
472}
473
/// Error reporting that a protocol name was found more than once.
///
/// Holds the bytes of the offending name.
#[derive(Debug, Clone)]
pub struct DuplicateProtonameError(Vec<u8>);

impl DuplicateProtonameError {
    /// Returns the bytes of the duplicated protocol name.
    pub fn protocol_name(&self) -> &[u8] {
        self.0.as_slice()
    }
}

impl fmt::Display for DuplicateProtonameError {
    /// Writes the name as text if it is valid UTF-8, otherwise as the debug
    /// representation of its bytes.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match std::str::from_utf8(&self.0) {
            Ok(name) => write!(f, "duplicate protocol name: {}", name),
            Err(_) => write!(f, "duplicate protocol name: {:?}", self.0),
        }
    }
}

impl error::Error for DuplicateProtonameError {}