1mod handler;
2
3use crate::AddPeerOpt;
4use futures::StreamExt;
5use futures_timer::Delay;
6
7use connexa::prelude::{
8 swarm::{
9 self, behaviour::ConnectionEstablished, dial_opts::DialOpts, AddressChange,
10 ConnectionClosed, ConnectionDenied, ConnectionId, DialError, DialFailure, FromSwarm,
11 NetworkBehaviour, NewExternalAddrOfPeer, THandler, THandlerInEvent, ToSwarm,
12 },
13 transport::{transport::PortUse, ConnectedPoint, Endpoint},
14 Multiaddr, PeerId, Protocol,
15};
16
17use pollable_map::futures::FutureMap;
18use std::convert::Infallible;
19use std::fmt::Debug;
20use std::task::Waker;
21use std::time::Duration;
22use std::{
23 collections::{hash_map::Entry, HashMap, HashSet, VecDeque},
24 task::{Context, Poll},
25};
26
/// Tunables for [`Behaviour`].
///
/// Every flag defaults to `false`.
#[derive(Clone, Copy, Debug, Default)]
pub struct Config {
    /// Record the remote address of every established connection in the
    /// peer's address book (any trailing `/p2p/<id>` component is stripped).
    pub store_on_connection: bool,
    /// Treat every peer as keep-alive: connections are protected as soon as
    /// they are established, and peers added with addresses are marked
    /// keep-alive even when their `AddPeerOpt` does not ask for it.
    pub keep_connection_alive: bool,
}
34
35#[derive(Default)]
36pub struct Behaviour {
37 events: VecDeque<ToSwarm<<Self as NetworkBehaviour>::ToSwarm, THandlerInEvent<Self>>>,
38 connections: HashMap<PeerId, HashSet<ConnectionId>>,
39 peer_addresses: HashMap<PeerId, HashSet<Multiaddr>>,
40 peer_keepalive: HashSet<PeerId>,
41 can_reconnect: HashMap<PeerId, (Duration, u8, bool)>,
42 peer_reconnect_attempts: HashMap<PeerId, u8>,
43 reconnect_peers: FutureMap<PeerId, Delay>,
44 config: Config,
45 waker: Option<Waker>,
46}
47
48impl Debug for Behaviour {
49 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
50 f.debug_struct("Behaviour").finish()
51 }
52}
53
54impl Behaviour {
55 pub fn with_config(config: Config) -> Self {
56 Self {
57 config,
58 ..Default::default()
59 }
60 }
61 pub fn add_address(&mut self, opt: impl Into<AddPeerOpt>) -> bool {
62 let opt = opt.into();
63
64 let peer_id = opt.peer_id();
65 let addresses = opt.addresses().to_vec();
66
67 if !addresses.is_empty() {
68 self.peer_addresses
69 .entry(*peer_id)
70 .or_default()
71 .extend(addresses);
72
73 if let Some(opts) = opt.to_dial_opts() {
74 self.events.push_back(ToSwarm::Dial { opts });
75 }
76 }
77
78 if (opt.can_keep_alive() || self.config.keep_connection_alive)
79 && self.peer_addresses.contains_key(peer_id)
80 {
81 self.keep_peer_alive(peer_id);
82 }
83
84 if let Some((duration, attempts)) = opt.reconnect_opt() {
85 self.can_reconnect
86 .insert(*peer_id, (duration, attempts, false));
87 }
88
89 if let Some(waker) = self.waker.take() {
90 waker.wake();
91 }
92
93 true
94 }
95
96 pub fn remove_address(&mut self, peer_id: &PeerId, addr: &Multiaddr) -> bool {
97 if let Entry::Occupied(mut e) = self.peer_addresses.entry(*peer_id) {
98 let entry = e.get_mut();
99
100 if !entry.remove(addr) {
101 return false;
102 }
103
104 if entry.is_empty() {
105 e.remove();
106 self.dont_keep_peer_alive(peer_id);
107 }
108 }
109 if let Some(waker) = self.waker.take() {
110 waker.wake();
111 }
112 true
113 }
114
115 pub fn remove_peer(&mut self, peer_id: &PeerId) -> bool {
116 let removed = self.peer_addresses.remove(peer_id).is_some();
117 if removed {
118 self.dont_keep_peer_alive(peer_id);
119 if let Some(waker) = self.waker.take() {
120 waker.wake();
121 }
122 }
123 removed
124 }
125
126 pub fn contains(&self, peer_id: &PeerId, addr: &Multiaddr) -> bool {
127 self.peer_addresses
128 .get(peer_id)
129 .map(|list| list.contains(addr))
130 .unwrap_or_default()
131 }
132
133 pub fn get_peer_addresses(&self, peer_id: &PeerId) -> Option<Vec<Multiaddr>> {
134 self.peer_addresses
135 .get(peer_id)
136 .cloned()
137 .map(Vec::from_iter)
138 }
139
140 pub fn iter(&self) -> impl Iterator<Item = (&PeerId, &HashSet<Multiaddr>)> {
141 self.peer_addresses.iter()
142 }
143
144 fn keep_peer_alive(&mut self, peer_id: &PeerId) {
145 self.peer_keepalive.insert(*peer_id);
146 self.set_peer_protected_status(peer_id, handler::In::Protect);
147 }
148
149 fn dont_keep_peer_alive(&mut self, peer_id: &PeerId) {
150 self.peer_keepalive.remove(peer_id);
151 self.set_peer_protected_status(peer_id, handler::In::Unprotect);
152 }
153
154 fn set_peer_protected_status(&mut self, peer_id: &PeerId, event: handler::In) {
155 if let Some(conns) = self.connections.get(peer_id) {
156 self.events.extend(
157 conns
158 .iter()
159 .copied()
160 .map(|connection_id| ToSwarm::NotifyHandler {
161 peer_id: *peer_id,
162 handler: swarm::NotifyHandler::One(connection_id),
163 event,
164 }),
165 )
166 }
167 }
168
169 fn on_connection_established(
170 &mut self,
171 ConnectionEstablished {
172 peer_id,
173 connection_id,
174 endpoint,
175 ..
176 }: ConnectionEstablished,
177 ) {
178 self.connections
179 .entry(peer_id)
180 .or_default()
181 .insert(connection_id);
182
183 self.reconnect_peers.remove(&peer_id);
184 self.peer_reconnect_attempts.remove(&peer_id);
185
186 if let Entry::Occupied(mut e) = self.can_reconnect.entry(peer_id) {
187 let (_, _, backoff) = e.get_mut();
188 *backoff = false;
189 }
190
191 if self.config.keep_connection_alive && !self.peer_keepalive.contains(&peer_id) {
192 self.keep_peer_alive(&peer_id);
193 }
194
195 if !self.config.store_on_connection {
196 return;
197 }
198
199 let mut addr = address_from_connection_point(endpoint);
200
201 if matches!(addr.iter().last(), Some(Protocol::P2p(_))) {
202 addr.pop();
203 }
204
205 self.peer_addresses.entry(peer_id).or_default().insert(addr);
206 }
207
208 fn on_connection_closed(
209 &mut self,
210 ConnectionClosed {
211 peer_id,
212 connection_id,
213 remaining_established,
214 ..
215 }: ConnectionClosed,
216 ) {
217 if let Entry::Occupied(mut entry) = self.connections.entry(peer_id) {
218 let list = entry.get_mut();
219 list.remove(&connection_id);
220 if list.is_empty() && remaining_established == 0 {
221 entry.remove();
222 if let Some((duration, attempts, backoff)) = self.can_reconnect.get(&peer_id) {
223 if *attempts == 0 || *backoff {
224 return;
225 }
226 self.reconnect_peers.insert(peer_id, Delay::new(*duration));
227 self.peer_reconnect_attempts.insert(peer_id, 0);
228 if let Some(waker) = self.waker.take() {
229 waker.wake();
230 }
231 }
232 }
233 }
234 }
235
236 fn on_address_change(
237 &mut self,
238 AddressChange {
239 peer_id, old, new, ..
240 }: AddressChange,
241 ) {
242 let mut old = address_from_connection_point(old);
243
244 if matches!(old.iter().last(), Some(Protocol::P2p(_))) {
245 old.pop();
246 }
247
248 let mut new = address_from_connection_point(new);
249
250 if matches!(new.iter().last(), Some(Protocol::P2p(_))) {
251 new.pop();
252 }
253
254 if let Entry::Occupied(mut e) = self.peer_addresses.entry(peer_id) {
255 let entry = e.get_mut();
256 entry.insert(new);
257 entry.remove(&old);
258 }
259 }
260
261 fn on_dial_failure(
262 &mut self,
263 DialFailure {
264 peer_id,
265 error,
266 connection_id,
267 }: DialFailure,
268 ) {
269 let Some(peer_id) = peer_id else {
270 return;
271 };
272
273 match error {
274 DialError::LocalPeerId { .. } => {
275 tracing::error!(%peer_id, %connection_id, "local peer id is not allowed to dial");
276 self.reconnect_peers.remove(&peer_id);
277 self.peer_reconnect_attempts.remove(&peer_id);
278 self.peer_keepalive.remove(&peer_id);
279 self.peer_addresses.remove(&peer_id);
280 return;
281 }
282 DialError::NoAddresses => {
283 tracing::error!(%peer_id, %connection_id, "no addresses to dial");
284 self.reconnect_peers.remove(&peer_id);
285 self.peer_reconnect_attempts.remove(&peer_id);
286 return;
287 }
288 DialError::DialPeerConditionFalse(_) => {}
289 DialError::Aborted => {}
290 DialError::WrongPeerId { .. } => {
291 tracing::error!(%peer_id, %connection_id, "wrong peer id");
292 self.reconnect_peers.remove(&peer_id);
293 self.peer_reconnect_attempts.remove(&peer_id);
294 self.peer_keepalive.remove(&peer_id);
295 self.peer_addresses.remove(&peer_id);
296 return;
297 }
298 DialError::Denied { .. } => {}
299 DialError::Transport(_) => {}
300 }
301
302 if let Some((duration, attempts, backoff)) = self.can_reconnect.get_mut(&peer_id) {
303 let current_attempts = self.peer_reconnect_attempts.entry(peer_id).or_insert(0);
304 if *current_attempts >= *attempts {
305 let current_attempts = *current_attempts;
306 self.peer_reconnect_attempts.remove(&peer_id);
307 self.reconnect_peers.remove(&peer_id);
308 *backoff = true;
309 tracing::debug!(%peer_id, current_attempts, max_attempts = attempts, "unable to reconnect. backing off on attempts at reconnection");
310 return;
311 }
312
313 if *backoff {
314 return;
315 }
316
317 *current_attempts += 1;
318
319 tracing::info!(%peer_id, next_attempt = *current_attempts, max_attempts = attempts, "attempting reconnection to peer");
320
321 if !self.reconnect_peers.contains_key(&peer_id) {
325 self.reconnect_peers.insert(peer_id, Delay::new(*duration));
326 } else {
327 let timer = self
328 .reconnect_peers
329 .get_mut(&peer_id)
330 .expect("timer available");
331 timer.reset(*duration);
332 }
333
334 if let Some(waker) = self.waker.take() {
335 waker.wake();
336 }
337 }
338 }
339
340 fn on_external_addr_of_peer(
341 &mut self,
342 NewExternalAddrOfPeer { peer_id, addr }: NewExternalAddrOfPeer,
343 ) {
344 self.peer_addresses
345 .entry(peer_id)
346 .or_default()
347 .insert(addr.clone());
348 }
349}
350
351impl NetworkBehaviour for Behaviour {
352 type ConnectionHandler = handler::Handler;
353 type ToSwarm = Infallible;
354
355 fn handle_established_inbound_connection(
356 &mut self,
357 _: ConnectionId,
358 peer_id: PeerId,
359 _: &Multiaddr,
360 _: &Multiaddr,
361 ) -> Result<THandler<Self>, ConnectionDenied> {
362 let keepalive = self.peer_keepalive.contains(&peer_id);
363 Ok(handler::Handler::new(keepalive))
364 }
365
366 fn handle_pending_outbound_connection(
367 &mut self,
368 _: ConnectionId,
369 peer_id: Option<PeerId>,
370 _: &[Multiaddr],
371 _: Endpoint,
372 ) -> Result<Vec<Multiaddr>, ConnectionDenied> {
373 let Some(peer_id) = peer_id else {
374 return Ok(vec![]);
375 };
376
377 let list = self
378 .peer_addresses
379 .get(&peer_id)
380 .cloned()
381 .map(Vec::from_iter)
382 .unwrap_or_default();
383
384 Ok(list)
385 }
386
387 fn handle_established_outbound_connection(
388 &mut self,
389 _: ConnectionId,
390 peer_id: PeerId,
391 _: &Multiaddr,
392 _: Endpoint,
393 _: PortUse,
394 ) -> Result<THandler<Self>, ConnectionDenied> {
395 let keepalive = self.peer_keepalive.contains(&peer_id);
396 Ok(handler::Handler::new(keepalive))
397 }
398
399 fn on_swarm_event(&mut self, event: FromSwarm) {
400 match event {
401 FromSwarm::AddressChange(ev) => self.on_address_change(ev),
402 FromSwarm::ConnectionEstablished(ev) => self.on_connection_established(ev),
403 FromSwarm::ConnectionClosed(ev) => self.on_connection_closed(ev),
404 FromSwarm::DialFailure(ev) => self.on_dial_failure(ev),
405 FromSwarm::NewExternalAddrOfPeer(ev) => self.on_external_addr_of_peer(ev),
406 FromSwarm::ListenFailure(_) => {}
407 FromSwarm::NewListener(_) => {}
408 FromSwarm::NewListenAddr(_) => {}
409 FromSwarm::ExpiredListenAddr(_) => {}
410 FromSwarm::ListenerError(_) => {}
411 FromSwarm::ListenerClosed(_) => {}
412 FromSwarm::NewExternalAddrCandidate(_) => {}
413 FromSwarm::ExternalAddrConfirmed(_) => {}
414 FromSwarm::ExternalAddrExpired(_) => {}
415 _ => {}
416 }
417 }
418
419 fn on_connection_handler_event(
420 &mut self,
421 _: PeerId,
422 _: ConnectionId,
423 _: swarm::THandlerOutEvent<Self>,
424 ) {
425 }
426
427 fn poll(&mut self, cx: &mut Context) -> Poll<ToSwarm<Self::ToSwarm, THandlerInEvent<Self>>> {
428 if let Some(event) = self.events.pop_front() {
429 return Poll::Ready(event);
430 }
431
432 while let Poll::Ready(Some((peer_id, _))) = self.reconnect_peers.poll_next_unpin(cx) {
433 let opts = DialOpts::peer_id(peer_id).build();
434 self.events.push_back(ToSwarm::Dial { opts });
435 }
436
437 Poll::Pending
438 }
439}
440
441fn address_from_connection_point(connection_point: &ConnectedPoint) -> Multiaddr {
442 match connection_point {
443 ConnectedPoint::Dialer { address, .. } => address.clone(),
444 ConnectedPoint::Listener { local_addr, .. } if connection_point.is_relayed() => {
445 local_addr.clone()
446 }
447 ConnectedPoint::Listener { send_back_addr, .. } => send_back_addr.clone(),
448 }
449}
450
// Integration tests: each builds real TCP/noise/yamux swarms on localhost
// (see `build_swarm`) and drives them with `futures::select!`.
#[cfg(test)]
mod test {
    use std::time::Duration;

    use connexa::prelude::{
        swarm::{dial_opts::DialOpts, Swarm, SwarmBuilder, SwarmEvent},
        Multiaddr, PeerId,
    };
    use futures::{FutureExt, StreamExt};

    use crate::{AddPeerOpt, NetworkBehaviour};

    /// Polls both swarms until `swarm1` reports an established connection,
    /// asserting that it is to `peer_id`.
    async fn wait_on_connection<B: NetworkBehaviour>(
        swarm1: &mut Swarm<B>,
        swarm2: &mut Swarm<B>,
        peer_id: PeerId,
    ) {
        loop {
            futures::select! {
                event = swarm1.select_next_some() => {
                    if let SwarmEvent::ConnectionEstablished { peer_id: peer, .. } = event {
                        assert_eq!(peer, peer_id);
                        break;
                    }
                }
                // swarm2 is polled too so its side of the connection makes progress;
                // its events are ignored.
                _ = swarm2.next() => {}
            }
        }
    }

    /// Dialing by peer id alone connects once the peer's address was added to
    /// the behaviour.
    #[tokio::test]
    async fn dial_with_peer_id() -> anyhow::Result<()> {
        let (_, _, mut swarm1) = build_swarm(false).await;
        let (peer2, addr2, mut swarm2) = build_swarm(false).await;

        let opts = AddPeerOpt::with_peer_id(peer2).add_address(addr2);

        swarm1.behaviour_mut().add_address(opts);

        swarm1.dial(peer2)?;

        wait_on_connection(&mut swarm1, &mut swarm2, peer2).await;
        Ok(())
    }

    /// After the connection is closed and the peer is removed from the
    /// behaviour, dialing it by peer id alone is expected to fail.
    #[tokio::test]
    async fn remove_peer_address() -> anyhow::Result<()> {
        let (_, _, mut swarm1) = build_swarm(false).await;
        let (peer2, addr2, mut swarm2) = build_swarm(false).await;
        let opts = AddPeerOpt::with_peer_id(peer2).add_address(addr2);
        swarm1.behaviour_mut().add_address(opts);

        swarm1.dial(peer2)?;

        wait_on_connection(&mut swarm1, &mut swarm2, peer2).await;

        swarm1.disconnect_peer_id(peer2).expect("Shouldnt fail");

        // Wait for swarm1 to observe the disconnect before removing the peer.
        loop {
            futures::select! {
                event = swarm1.select_next_some() => {
                    if let SwarmEvent::ConnectionClosed { peer_id, .. } = event {
                        assert_eq!(peer_id, peer2);
                        break;
                    }
                }
                _ = swarm2.next() => {}
            }
        }

        swarm1.behaviour_mut().remove_peer(&peer2);

        assert!(swarm1.dial(peer2).is_err());

        Ok(())
    }

    /// With both sides marking each other keep-alive, the connection must
    /// survive past the 3-second idle timeout set in `build_swarm` (watched
    /// for 4 seconds).
    #[tokio::test]
    async fn dial_and_keepalive() -> anyhow::Result<()> {
        let (peer1, addr1, mut swarm1) = build_swarm(false).await;
        let (peer2, addr2, mut swarm2) = build_swarm(false).await;
        let opts_1 = AddPeerOpt::with_peer_id(peer2)
            .add_address(addr2)
            .keepalive();
        swarm1.behaviour_mut().add_address(opts_1);

        let opts_2 = AddPeerOpt::with_peer_id(peer1)
            .add_address(addr1)
            .keepalive();
        swarm2.behaviour_mut().add_address(opts_2);

        swarm1.dial(peer2)?;

        let mut peer_a_connected = false;
        let mut peer_b_connected = false;

        // Wait until both swarms report the connection.
        loop {
            futures::select! {
                event = swarm1.select_next_some() => {
                    if let SwarmEvent::ConnectionEstablished { peer_id, .. } = event {
                        assert_eq!(peer_id, peer2);
                        peer_b_connected = true;
                    }
                }
                event = swarm2.select_next_some() => {
                    if let SwarmEvent::ConnectionEstablished { peer_id, .. } = event {
                        assert_eq!(peer_id, peer1);
                        peer_a_connected = true;
                    }
                }
            }

            if peer_a_connected && peer_b_connected {
                break;
            }
        }

        // Longer than the 3s idle timeout configured in `build_swarm`.
        let mut timer = futures_timer::Delay::new(Duration::from_secs(4)).fuse();

        // Any ConnectionClosed before the timer fires fails the test.
        loop {
            futures::select! {
                _ = &mut timer => {
                    break;
                }
                event = swarm1.select_next_some() => {
                    if let SwarmEvent::ConnectionClosed { peer_id, .. } = event {
                        assert_eq!(peer_id, peer2);
                        unreachable!("connection shouldnt have closed")
                    }
                }
                event = swarm2.select_next_some() => {
                    if let SwarmEvent::ConnectionClosed { peer_id, .. } = event {
                        assert_eq!(peer_id, peer1);
                        unreachable!("connection shouldnt have closed")
                    }
                }
            }
        }

        Ok(())
    }

    /// With `store_on_connection` enabled, the address used to dial the peer
    /// is recorded in the behaviour's address book.
    #[tokio::test]
    async fn store_address() -> anyhow::Result<()> {
        let (_, _, mut swarm1) = build_swarm(true).await;
        let (peer2, addr2, mut swarm2) = build_swarm(true).await;

        // Dial with explicit addresses so the behaviour has none stored beforehand.
        let opt = DialOpts::peer_id(peer2)
            .addresses(vec![addr2.clone()])
            .build();

        swarm1.dial(opt)?;

        wait_on_connection(&mut swarm1, &mut swarm2, peer2).await;

        let addrs = swarm1
            .behaviour()
            .get_peer_addresses(&peer2)
            .expect("Exist");

        for addr in addrs {
            assert_eq!(addr, addr2);
        }
        Ok(())
    }

    /// Builds a TCP + noise + yamux swarm with a fresh identity and a
    /// 3-second idle connection timeout, listening on an ephemeral
    /// 127.0.0.1 port.
    ///
    /// Returns the swarm's peer id, its first reported listen address, and
    /// the swarm. Panics if the first swarm event is not `NewListenAddr`.
    async fn build_swarm(
        store_on_connection: bool,
    ) -> (PeerId, Multiaddr, Swarm<super::Behaviour>) {
        use connexa::prelude::transport::{noise, tcp, yamux};

        let mut swarm = SwarmBuilder::with_new_identity()
            .with_tokio()
            .with_tcp(
                tcp::Config::default(),
                noise::Config::new,
                yamux::Config::default,
            )
            .expect("")
            .with_behaviour(|_| {
                super::Behaviour::with_config(super::Config {
                    store_on_connection,
                    ..Default::default()
                })
            })
            .expect("")
            .with_swarm_config(|c| c.with_idle_connection_timeout(Duration::from_secs(3)))
            .build();

        Swarm::listen_on(&mut swarm, "/ip4/127.0.0.1/tcp/0".parse().unwrap()).unwrap();

        if let Some(SwarmEvent::NewListenAddr { address, .. }) = swarm.next().await {
            let peer_id = swarm.local_peer_id();
            return (*peer_id, address, swarm);
        }

        panic!("no new addrs")
    }
}