1use bytes::BytesMut;
4use log::{trace, warn};
5use std::{
6 cell::RefCell,
7 io::Error,
8 net::SocketAddr,
9 rc::Rc,
10 sync::Arc,
11 time::{Duration, Instant},
12};
13use tokio::{
14 io::{AsyncReadExt, AsyncWriteExt},
15 net::ToSocketAddrs,
16 sync::{Notify, broadcast},
17};
18use wg::AsyncWaitGroup;
19
20use sansio::{InboundPipeline, OutboundPipeline, Pipeline};
21use sansio_executor::spawn_local;
22use sansio_transport::{TaggedBytesMut, TransportContext};
23
24mod bootstrap_tcp;
25mod bootstrap_udp;
26
27pub use bootstrap_tcp::{
28 bootstrap_tcp_client::BootstrapTcpClient, bootstrap_tcp_server::BootstrapTcpServer,
29};
30pub use bootstrap_udp::{
31 bootstrap_udp_client::BootstrapUdpClient, bootstrap_udp_server::BootstrapUdpServer,
32};
33
/// A pipeline paired with a [`Notify`] handle.
///
/// When built through `PipelineWithNotify::new`, `write_notify` is registered as the
/// pipeline's write-notify callback (each callback invocation calls `notify_one` on it).
///
/// NOTE(review): the callback is presumably fired when the pipeline has outbound data
/// queued, so a transport task can await `write_notify` instead of polling — confirm
/// against `Pipeline::set_write_notify`.
pub(crate) struct PipelineWithNotify<R, W> {
    /// Shared, reference-counted handle to the pipeline.
    pub(crate) pipeline: Rc<Pipeline<R, W>>,
    /// Notification handle signalled from the pipeline's write-notify callback
    /// (when constructed via `new`).
    pub(crate) write_notify: Arc<Notify>,
}
39
40impl<R, W> PipelineWithNotify<R, W> {
41 pub(crate) fn new(pipeline: Rc<Pipeline<R, W>>) -> Self
43 where
44 R: 'static,
45 W: 'static,
46 {
47 let write_notify = Arc::new(Notify::new());
48 let notify_clone = write_notify.clone();
49
50 pipeline.set_write_notify(Arc::new(move || {
51 notify_clone.notify_one();
52 }));
53
54 Self {
55 pipeline,
56 write_notify,
57 }
58 }
59
60 pub(crate) fn pipeline(&self) -> &Rc<Pipeline<R, W>> {
62 &self.pipeline
63 }
64}
65
/// Boxed closure that, when called, returns a reference-counted [`Pipeline`].
///
/// NOTE(review): presumably invoked once per connection/socket to obtain that
/// connection's pipeline — confirm against the TCP/UDP bootstrap implementations.
pub type PipelineFactoryFn<R, W> = Box<dyn Fn() -> Rc<Pipeline<R, W>>>;
68
69const DEFAULT_TIMEOUT_DURATION: Duration = Duration::from_secs(86400); struct Bootstrap<W> {
72 max_payload_size: usize,
73 pipeline_factory_fn: Option<Rc<PipelineFactoryFn<TaggedBytesMut, W>>>,
74 close_tx: Rc<RefCell<Option<broadcast::Sender<()>>>>,
75 wg: Rc<RefCell<Option<AsyncWaitGroup>>>,
76}
77
78impl<W: 'static> Default for Bootstrap<W> {
79 fn default() -> Self {
80 Self::new()
81 }
82}
83
84impl<W: 'static> Bootstrap<W> {
85 fn new() -> Self {
86 Self {
87 max_payload_size: 2048, pipeline_factory_fn: None,
89 close_tx: Rc::new(RefCell::new(None)),
90 wg: Rc::new(RefCell::new(None)),
91 }
92 }
93
94 fn max_payload_size(&mut self, max_payload_size: usize) -> &mut Self {
95 self.max_payload_size = max_payload_size;
96 self
97 }
98
99 fn pipeline(&mut self, pipeline_factory_fn: PipelineFactoryFn<TaggedBytesMut, W>) -> &mut Self {
100 self.pipeline_factory_fn = Some(Rc::new(Box::new(pipeline_factory_fn)));
101 self
102 }
103
104 async fn stop(&self) {
105 let mut close_tx = self.close_tx.borrow_mut();
106 if let Some(close_tx) = close_tx.take() {
107 let _ = close_tx.send(());
108 }
109 }
110
111 async fn wait_for_stop(&self) {
112 let wg = {
113 let mut wg = self.wg.borrow_mut();
114 wg.take()
115 };
116 if let Some(wg) = wg {
117 wg.wait().await;
118 }
119 }
120
121 async fn graceful_stop(&self) {
122 self.stop().await;
123 self.wait_for_stop().await;
124 }
125}