1use std::future::Future;
10use std::net::SocketAddr;
11use std::str::from_utf8;
12use std::sync::{Arc, Weak};
13
14use async_trait::async_trait;
15use cfg_if::cfg_if;
16use log::{debug, warn};
17
18use crate::bytes::{ByteBuffer, DynamicByteBuffer, StaticByteBuffer};
19pub use crate::flow::decoy::{DecoyFactory, decoy_factory, random_decoy_factory};
20pub use crate::flow::probe::{ActiveProbeHandler, ProbeFactory, ProbeFlowSender, probe_factory};
21use crate::settings::Settings;
22use crate::settings::consts::DEFAULT_TYPHOON_ID_LENGTH;
23pub use crate::tailer::{ClientConnectionHandler, ServerConnectionHandler};
24use crate::tailer::{IdentityType, Tailer};
25use crate::utils::random::{SupportRng, get_rng};
26pub use crate::utils::sync::AsyncExecutor;
27
28cfg_if! {
29 if #[cfg(feature = "tokio")] {
30 use tokio::spawn;
31 use tokio::runtime::Handle;
32 use tokio::task::block_in_place;
33 } else if #[cfg(feature = "async-std")] {
34 use async_io::block_on as async_io_block_on;
35 }
36}
37
/// Parses a `MAJOR.MINOR.PATCH` version out of a (possibly NUL-padded) byte buffer.
///
/// Bytes from the first NUL onward are ignored and surrounding whitespace is trimmed.
/// Any semver pre-release (`-…`) or build-metadata (`+…`) suffix is discarded, so
/// `"1.2.3-rc.1+abc"` parses as `(1, 2, 3)`.
///
/// This never fails. Invalid UTF-8, missing components, and components that are not valid
/// `u64` values are all reported as `0`.
fn parse_version(bytes: &[u8]) -> (u64, u64, u64) {
    // Peers send the version padded with trailing zero bytes up to a fixed length.
    let end = bytes.iter().position(|&b| b == 0).unwrap_or(bytes.len());
    let text = from_utf8(&bytes[..end]).unwrap_or("").trim();
    // Strip both pre-release and build metadata. Splitting on '-' alone would leave
    // e.g. "3+build" as the patch component, which fails to parse and silently becomes 0.
    let core = text.split(|c: char| c == '-' || c == '+').next().unwrap_or(text);
    let mut parts = core.split('.');
    let major = parts.next().and_then(|p| p.parse::<u64>().ok()).unwrap_or(0);
    let minor = parts.next().and_then(|p| p.parse::<u64>().ok()).unwrap_or(0);
    let patch = parts.next().and_then(|p| p.parse::<u64>().ok()).unwrap_or(0);
    (major, minor, patch)
}
50
51impl IdentityType for StaticByteBuffer {
52 fn from_bytes(bytes: &[u8]) -> Self {
53 assert_eq!(bytes.len(), DEFAULT_TYPHOON_ID_LENGTH, "invalid bytes identity length: expected {}, got {}", DEFAULT_TYPHOON_ID_LENGTH, bytes.len());
54 Self::from_slice(bytes)
55 }
56
57 fn to_bytes(&self) -> &[u8] {
58 self.slice()
59 }
60
61 fn length() -> usize {
62 DEFAULT_TYPHOON_ID_LENGTH
63 }
64}
65
/// [`AsyncExecutor`] that delegates to the Tokio runtime of the calling context.
///
/// The type is stateless. `spawn` uses `tokio::spawn`, and `block_on` uses the current
/// runtime handle.
#[cfg(feature = "tokio")]
#[derive(Clone)]
pub struct TokioExecutor;

#[cfg(feature = "tokio")]
impl AsyncExecutor for TokioExecutor {
    fn new() -> Self {
        TokioExecutor
    }

    /// Spawns `future` as a detached Tokio task; the join handle is discarded.
    fn spawn<F: Future<Output = ()> + Send + 'static>(&self, future: F) {
        drop(spawn(future));
    }

    /// Blocks the current thread until `future` completes on the current Tokio runtime.
    ///
    /// NOTE(review): per Tokio docs, `block_in_place` panics on a `current_thread` runtime,
    /// and `Handle::current` panics outside a runtime context. Callers presumably run on a
    /// multi-threaded runtime; confirm.
    fn block_on<F: Future<Output = ()>>(&self, future: F) {
        block_in_place(move || {
            let handle = Handle::current();
            handle.block_on(future)
        });
    }
}
85
/// [`AsyncExecutor`] backed by an [`async_executor::Executor`], used with the
/// `async-std` feature.
///
/// Clones share the same underlying executor through an [`Arc`].
#[cfg(feature = "async-std")]
#[derive(Clone)]
pub struct AsyncStdExecutor {
    executor: Arc<async_executor::Executor<'static>>,
}

#[cfg(feature = "async-std")]
impl AsyncExecutor for AsyncStdExecutor {
    /// Creates a wrapper around a fresh executor that is not shared with anything else.
    fn new() -> Self {
        Self {
            executor: Arc::new(async_executor::Executor::new()),
        }
    }

    /// Queues `future` as a detached task on the wrapped executor.
    ///
    /// The task only makes progress while some thread is running the executor, for
    /// example inside `block_on` below.
    fn spawn<F: Future<Output = ()> + Send + 'static>(&self, future: F) {
        self.executor.spawn(future).detach();
    }

    /// Blocks the current thread until `future` completes. While waiting, it also runs
    /// tasks queued on the wrapped executor.
    fn block_on<F: Future<Output = ()>>(&self, future: F) {
        // An `async_executor::Executor` does not run itself: tasks queued via `spawn` are
        // only polled inside `Executor::run`/`tick`. Blocking on the bare future would
        // leave every task spawned on an executor from `new()` stuck forever. Running the
        // future through the executor drives those tasks while we wait, and it is harmless
        // when the executor is also being run on other threads.
        async_io_block_on(self.executor.run(future));
    }
}
108}
109
/// Wraps an existing shared executor, so tasks spawned through the wrapper are queued
/// on `executor` itself.
#[cfg(feature = "async-std")]
impl From<Arc<async_executor::Executor<'static>>> for AsyncStdExecutor {
    fn from(shared: Arc<async_executor::Executor<'static>>) -> Self {
        AsyncStdExecutor { executor: shared }
    }
}
118
119cfg_if! {
122 if #[cfg(feature = "tokio")] {
123 pub type DefaultExecutor = TokioExecutor;
125 } else if #[cfg(feature = "async-std")] {
126 pub type DefaultExecutor = AsyncStdExecutor;
128 }
129}
130
131pub type DefaultSettings = Settings<DefaultExecutor>;
132
133pub type DefaultTailer = Tailer<StaticByteBuffer>;
134
135pub struct DefaultServerConnectionHandler;
138
139impl ServerConnectionHandler<StaticByteBuffer> for DefaultServerConnectionHandler {
140 fn generate(&self, _initial_data: &[u8]) -> StaticByteBuffer {
141 StaticByteBuffer::from_slice(get_rng().random_byte_buffer::<DEFAULT_TYPHOON_ID_LENGTH>().slice())
142 }
143
144 fn initial_data(&self, _identity: &StaticByteBuffer) -> StaticByteBuffer {
145 StaticByteBuffer::from_slice(&[])
146 }
147
148 fn verify_version(&self, version_bytes: &[u8]) -> bool {
149 let (cli_major, cli_minor, cli_patch) = parse_version(version_bytes);
150 let (srv_major, srv_minor, srv_patch) = parse_version(env!("CARGO_PKG_VERSION").as_bytes());
151 if cli_major != srv_major {
152 warn!("client version major mismatch (client={cli_major}.{cli_minor}.{cli_patch}, server={srv_major}.{srv_minor}.{srv_patch}), rejecting handshake");
153 false
154 } else if cli_minor != srv_minor {
155 warn!("client version minor mismatch (client={cli_major}.{cli_minor}.{cli_patch}, server={srv_major}.{srv_minor}.{srv_patch})");
156 true
157 } else if cli_patch != srv_patch {
158 debug!("client version patch mismatch (client={cli_major}.{cli_minor}.{cli_patch}, server={srv_major}.{srv_minor}.{srv_patch})");
159 true
160 } else {
161 true
162 }
163 }
164}
165
166#[derive(Default)]
169pub struct NoopProbeHandler;
170
171#[async_trait]
172impl<AE: AsyncExecutor + 'static> ActiveProbeHandler<AE> for NoopProbeHandler {
173 async fn start(&mut self, _: Weak<dyn ProbeFlowSender>, _: Arc<Settings<AE>>) {}
174 async fn process(&mut self, _: DynamicByteBuffer, _: Option<SocketAddr>) {}
175}
176
177pub struct DefaultClientConnectionHandler;
180
181impl ClientConnectionHandler for DefaultClientConnectionHandler {
182 fn initial_data(&self) -> StaticByteBuffer {
183 StaticByteBuffer::from_slice(&[])
184 }
185
186 fn version(&self, length: usize) -> StaticByteBuffer {
187 let ver = env!("CARGO_PKG_VERSION").as_bytes();
188 let copy_len = ver.len().min(length);
189 let mut buf = vec![0u8; length];
190 buf[..copy_len].copy_from_slice(&ver[..copy_len]);
191 StaticByteBuffer::from(buf)
193 }
194}