1#![deny(missing_docs)]
3
4const MAX_MSG_BYTES: i32 = 20_000;
6
7use std::io::{Error, Result};
8use std::net::{IpAddr, Ipv6Addr};
9use std::sync::Arc;
10
11mod config;
12pub use config::*;
13
14mod maybe_tls;
15pub use maybe_tls::*;
16
17mod ip_deny;
18mod ip_rate;
19pub use ip_rate::*;
20
21mod cslot;
22pub use cslot::*;
23
24mod cmd;
25
26pub mod ws {
28 pub enum Payload {
30 Vec(Vec<u8>),
32
33 BytesMut(bytes::BytesMut),
35 }
36
37 impl std::ops::Deref for Payload {
38 type Target = [u8];
39
40 #[inline(always)]
41 fn deref(&self) -> &Self::Target {
42 match self {
43 Payload::Vec(v) => v.as_slice(),
44 Payload::BytesMut(b) => b.as_ref(),
45 }
46 }
47 }
48
49 impl Payload {
50 #[inline(always)]
52 pub fn to_mut(&mut self) -> &mut [u8] {
53 match self {
54 Payload::Vec(ref mut owned) => owned,
55 Payload::BytesMut(b) => b.as_mut(),
56 }
57 }
58 }
59
60 #[cfg(feature = "tungstenite")]
61 mod ws_tungstenite;
62
63 use futures::future::BoxFuture;
64 #[cfg(feature = "tungstenite")]
65 pub use ws_tungstenite::*;
66
67 #[cfg(all(not(feature = "tungstenite"), feature = "fastwebsockets"))]
68 mod ws_fastwebsockets;
69 #[cfg(all(not(feature = "tungstenite"), feature = "fastwebsockets"))]
70 pub use ws_fastwebsockets::*;
71
72 pub trait SbdWebsocket: Send + Sync + 'static {
74 fn recv(&self) -> BoxFuture<'static, std::io::Result<Payload>>;
76
77 fn send(
79 &self,
80 payload: Payload,
81 ) -> BoxFuture<'static, std::io::Result<()>>;
82
83 fn close(&self) -> BoxFuture<'static, ()>;
85 }
86}
87
88pub use ws::{Payload, SbdWebsocket};
89
90#[derive(Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
92pub struct PubKey(pub Arc<[u8; 32]>);
93
94impl PubKey {
95 pub fn verify(&self, sig: &[u8; 64], data: &[u8]) -> bool {
97 use ed25519_dalek::Verifier;
98 if let Ok(k) = ed25519_dalek::VerifyingKey::from_bytes(&self.0) {
99 k.verify(data, &ed25519_dalek::Signature::from_bytes(sig))
100 .is_ok()
101 } else {
102 false
103 }
104 }
105}
106
107pub struct SbdServer {
109 task_list: Vec<tokio::task::JoinHandle<()>>,
110 bind_addrs: Vec<std::net::SocketAddr>,
111 _cslot: cslot::CSlot,
112}
113
114impl Drop for SbdServer {
115 fn drop(&mut self) {
116 for task in self.task_list.iter() {
117 task.abort();
118 }
119 }
120}
121
122pub fn to_canonical_ip(ip: IpAddr) -> Arc<Ipv6Addr> {
124 Arc::new(match ip {
125 IpAddr::V4(ip) => ip.to_ipv6_mapped(),
126 IpAddr::V6(ip) => ip,
127 })
128}
129
130pub async fn preflight_ip_check(
133 config: &Config,
134 ip_rate: &IpRate,
135 addr: std::net::SocketAddr,
136) -> Option<Arc<Ipv6Addr>> {
137 let raw_ip = to_canonical_ip(addr.ip());
138
139 let use_trusted_ip = config.trusted_ip_header.is_some();
140
141 if !use_trusted_ip {
142 if ip_rate.is_blocked(&raw_ip).await {
148 return None;
149 }
150
151 if !ip_rate.is_ok(&raw_ip, 1).await {
153 return None;
154 }
155 }
156
157 Some(raw_ip)
158}
159
160pub async fn handle_upgraded(
162 config: Arc<Config>,
163 ip_rate: Arc<IpRate>,
164 weak_cslot: WeakCSlot,
165 ws: Arc<impl SbdWebsocket>,
166 pub_key: PubKey,
167 calc_ip: Arc<Ipv6Addr>,
168) {
169 let use_trusted_ip = config.trusted_ip_header.is_some();
170
171 if &pub_key.0[..28] == cmd::CMD_PREFIX {
173 return;
174 }
175
176 let ws = Arc::new(ws);
177
178 if use_trusted_ip {
179 if ip_rate.is_blocked(&calc_ip).await {
182 return;
183 }
184
185 if !ip_rate.is_ok(&calc_ip, 1).await {
187 return;
188 }
189 }
190
191 if let Some(cslot) = weak_cslot.upgrade() {
192 cslot.insert(&config, calc_ip, pub_key, ws).await;
193 }
194}
195
196async fn check_accept_connection(
197 _connect_permit: tokio::sync::OwnedSemaphorePermit,
198 config: Arc<Config>,
199 tls_config: Option<Arc<TlsConfig>>,
200 ip_rate: Arc<IpRate>,
201 tcp: tokio::net::TcpStream,
202 addr: std::net::SocketAddr,
203 weak_cslot: WeakCSlot,
204) {
205 let _ = tokio::time::timeout(config.idle_dur(), async {
206 let Some(mut calc_ip) =
207 preflight_ip_check(&config, &ip_rate, addr).await
208 else {
209 return;
210 };
211
212 let socket = if let Some(tls_config) = &tls_config {
213 match MaybeTlsStream::tls(tls_config, tcp).await {
214 Err(_) => return,
215 Ok(tls) => tls,
216 }
217 } else {
218 MaybeTlsStream::Tcp(tcp)
219 };
220
221 let (ws, pub_key, ip) =
222 match ws::WebSocket::upgrade(config.clone(), socket).await {
223 Ok(r) => r,
224 Err(_) => return,
225 };
226
227 if let Some(ip) = ip {
228 calc_ip = Arc::new(ip);
229 }
230
231 handle_upgraded(
232 config,
233 ip_rate,
234 weak_cslot,
235 Arc::new(ws),
236 pub_key,
237 calc_ip,
238 )
239 .await;
240 })
241 .await;
242}
243
244async fn bind_all<I: IntoIterator<Item = std::net::SocketAddr>>(
245 i: I,
246) -> Vec<tokio::net::TcpListener> {
247 let mut listeners = Vec::new();
248 for a in i.into_iter() {
249 if let Ok(tcp) = tokio::net::TcpListener::bind(a).await {
250 listeners.push(tcp);
251 }
252 }
253 listeners
254}
255
256impl SbdServer {
257 pub async fn new(config: Arc<Config>) -> Result<Self> {
259 let tls_config = if let (Some(cert), Some(pk)) =
260 (&config.cert_pem_file, &config.priv_key_pem_file)
261 {
262 Some(Arc::new(maybe_tls::TlsConfig::new(cert, pk).await?))
263 } else {
264 None
265 };
266
267 let mut task_list = Vec::new();
268 let mut bind_addrs = Vec::new();
269
270 let ip_rate = Arc::new(IpRate::new(config.clone()));
271 task_list.push(spawn_prune_task(ip_rate.clone()));
272
273 let cslot = CSlot::new(config.clone(), ip_rate.clone());
274
275 let connect_limit = Arc::new(tokio::sync::Semaphore::new(1024));
279
280 let mut bind_port_zero = Vec::new();
281 let mut bind_explicit_port = Vec::new();
282
283 for bind in config.bind.iter() {
284 let a: std::net::SocketAddr = bind.parse().map_err(Error::other)?;
285
286 if a.port() == 0 {
287 bind_port_zero.push(a);
288 } else {
289 bind_explicit_port.push(a);
290 }
291 }
292
293 let (mut listeners, mut l2) = tokio::join!(
294 async {
295 if bind_port_zero.is_empty() {
297 return Vec::new();
298 }
299
300 'top: for _ in 0..2 {
302 let mut listeners = Vec::new();
303
304 let mut a_iter = bind_port_zero.iter();
305
306 let a = a_iter.next().unwrap();
307 if let Ok(tcp) = tokio::net::TcpListener::bind(a).await {
308 let port = tcp.local_addr().unwrap().port();
309 listeners.push(tcp);
310
311 for a in a_iter {
312 let mut a = *a;
313 a.set_port(port);
314 match tokio::net::TcpListener::bind(a).await {
315 Err(_) => continue 'top,
316 Ok(tcp) => listeners.push(tcp),
317 }
318 }
319
320 return listeners;
321 }
322 }
323
324 bind_all(bind_port_zero).await
326 },
327 async { bind_all(bind_explicit_port).await },
328 );
329
330 listeners.append(&mut l2);
331
332 if listeners.is_empty() {
333 return Err(Error::other("failed to bind any listeners"));
334 }
335
336 let weak_cslot = cslot.weak();
337 for tcp in listeners {
338 bind_addrs.push(tcp.local_addr()?);
339
340 let tls_config = tls_config.clone();
341 let connect_limit = connect_limit.clone();
342 let config = config.clone();
343 let weak_cslot = weak_cslot.clone();
344 let ip_rate = ip_rate.clone();
345 task_list.push(tokio::task::spawn(async move {
346 loop {
347 if let Ok((tcp, addr)) = tcp.accept().await {
348 let connect_permit =
351 match connect_limit.clone().try_acquire_owned() {
352 Ok(permit) => permit,
353 _ => continue,
354 };
355
356 tokio::task::spawn(check_accept_connection(
359 connect_permit,
360 config.clone(),
361 tls_config.clone(),
362 ip_rate.clone(),
363 tcp,
364 addr,
365 weak_cslot.clone(),
366 ));
367 }
368 }
369 }));
370 }
371
372 Ok(Self {
373 task_list,
374 bind_addrs,
375 _cslot: cslot,
376 })
377 }
378
379 pub fn bind_addrs(&self) -> &[std::net::SocketAddr] {
381 self.bind_addrs.as_slice()
382 }
383}
384
385#[cfg(test)]
386mod test;