//! Sbd server library.
//!
//! Provides [`SbdServer`], an axum-based websocket server. Clients connect
//! on `/{pub_key}`, where `pub_key` is the unpadded URL-safe base64 encoding
//! of a 32-byte public key, and may obtain an auth token via
//! `PUT /authenticate`.
#![deny(missing_docs)]

/// Maximum size, in bytes, of a single websocket message accepted from a
/// client (applied as the websocket upgrade's max message size).
const MAX_MSG_BYTES: i32 = 20_000;
6
7use std::collections::HashMap;
8use std::io::{Error, Result};
9use std::net::{IpAddr, Ipv6Addr};
10use std::sync::{Arc, Mutex};
11
12mod config;
13pub use config::*;
14
15mod maybe_tls;
16pub use maybe_tls::*;
17
18mod ip_deny;
19mod ip_rate;
20pub use ip_rate::*;
21
22mod cslot;
23pub use cslot::*;
24
25mod cmd;
26
27pub mod ws {
29 pub enum Payload {
31 Vec(Vec<u8>),
33
34 BytesMut(bytes::BytesMut),
36 }
37
38 impl std::ops::Deref for Payload {
39 type Target = [u8];
40
41 #[inline(always)]
42 fn deref(&self) -> &Self::Target {
43 match self {
44 Payload::Vec(v) => v.as_slice(),
45 Payload::BytesMut(b) => b.as_ref(),
46 }
47 }
48 }
49
50 impl Payload {
51 #[inline(always)]
53 pub fn to_mut(&mut self) -> &mut [u8] {
54 match self {
55 Payload::Vec(ref mut owned) => owned,
56 Payload::BytesMut(b) => b.as_mut(),
57 }
58 }
59 }
60
61 use futures::future::BoxFuture;
62
63 pub trait SbdWebsocket: Send + Sync + 'static {
65 fn recv(&self) -> BoxFuture<'static, std::io::Result<Payload>>;
67
68 fn send(
70 &self,
71 payload: Payload,
72 ) -> BoxFuture<'static, std::io::Result<()>>;
73
74 fn close(&self) -> BoxFuture<'static, ()>;
76 }
77}
78
79pub use ws::{Payload, SbdWebsocket};
80
/// A 32-byte public key identifying a client.
///
/// Cloning is cheap: the key bytes are shared behind an [`Arc`].
#[derive(Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct PubKey(pub Arc<[u8; 32]>);
84
85impl PubKey {
86 pub fn verify(&self, sig: &[u8; 64], data: &[u8]) -> bool {
88 use ed25519_dalek::Verifier;
89 if let Ok(k) = ed25519_dalek::VerifyingKey::from_bytes(&self.0) {
90 k.verify(data, &ed25519_dalek::Signature::from_bytes(sig))
91 .is_ok()
92 } else {
93 false
94 }
95 }
96}
97
98pub struct SbdServer {
100 task_list: Vec<tokio::task::JoinHandle<()>>,
101 bind_addrs: Vec<std::net::SocketAddr>,
102
103 _cslot: CSlot,
106}
107
108impl Drop for SbdServer {
109 fn drop(&mut self) {
110 for task in self.task_list.iter() {
111 task.abort();
112 }
113 }
114}
115
/// Normalize an IP address to IPv6 form.
///
/// IPv4 addresses become their IPv4-mapped IPv6 equivalent
/// (`::ffff:a.b.c.d`); IPv6 addresses are returned unchanged. This lets both
/// address families be handled as a single type.
pub fn to_canonical_ip(ip: IpAddr) -> Arc<Ipv6Addr> {
    let v6 = match ip {
        IpAddr::V4(v4) => v4.to_ipv6_mapped(),
        IpAddr::V6(v6) => v6,
    };
    Arc::new(v6)
}
123
124pub async fn preflight_ip_check(
127 config: &Config,
128 ip_rate: &IpRate,
129 addr: std::net::SocketAddr,
130) -> Option<Arc<Ipv6Addr>> {
131 let raw_ip = to_canonical_ip(addr.ip());
132
133 let use_trusted_ip = config.trusted_ip_header.is_some();
134
135 if !use_trusted_ip {
136 if ip_rate.is_blocked(&raw_ip).await {
142 return None;
143 }
144
145 if !ip_rate.is_ok(&raw_ip, 1).await {
147 return None;
148 }
149 }
150
151 Some(raw_ip)
152}
153
154pub async fn handle_upgraded(
156 config: Arc<Config>,
157 ip_rate: Arc<IpRate>,
158 weak_cslot: WeakCSlot,
159 ws: Arc<impl SbdWebsocket>,
160 pub_key: PubKey,
161 calc_ip: Arc<Ipv6Addr>,
162 maybe_auth: Option<(Option<Arc<str>>, AuthTokenTracker)>,
163) {
164 let use_trusted_ip = config.trusted_ip_header.is_some();
165
166 if &pub_key.0[..28] == cmd::CMD_PREFIX {
168 return;
169 }
170
171 if use_trusted_ip {
172 if ip_rate.is_blocked(&calc_ip).await {
175 return;
176 }
177
178 if !ip_rate.is_ok(&calc_ip, 1).await {
180 return;
181 }
182 }
183
184 if let Some(cslot) = weak_cslot.upgrade() {
185 cslot
186 .insert(&config, calc_ip, pub_key, ws, maybe_auth)
187 .await;
188 }
189}
190
191async fn handle_auth(
193 axum::extract::State(app_state): axum::extract::State<AppState>,
194 body: bytes::Bytes,
195) -> axum::response::Response {
196 use AuthenticateTokenError::*;
197
198 match process_authenticate_token(
200 &app_state.config,
201 &app_state.token_tracker,
202 body,
203 )
204 .await
205 {
206 Ok(token) => axum::response::IntoResponse::into_response(axum::Json(
207 serde_json::json!({
208 "authToken": *token,
209 }),
210 )),
211 Err(Unauthorized) => {
212 tracing::debug!("/authenticate: UNAUTHORIZED");
213 axum::response::IntoResponse::into_response((
214 axum::http::StatusCode::UNAUTHORIZED,
215 "Unauthorized",
216 ))
217 }
218 Err(HookServerError(err)) => {
219 tracing::debug!(?err, "/authenticate: BAD_GATEWAY");
220 axum::response::IntoResponse::into_response((
221 axum::http::StatusCode::BAD_GATEWAY,
222 format!("BAD_GATEWAY: {err:?}"),
223 ))
224 }
225 Err(OtherError(err)) => {
226 tracing::warn!(?err, "/authenticate: INTERNAL_SERVER_ERROR");
227 axum::response::IntoResponse::into_response((
228 axum::http::StatusCode::INTERNAL_SERVER_ERROR,
229 format!("INTERNAL_SERVER_ERROR: {err:?}"),
230 ))
231 }
232 }
233}
234
/// Error returned by [`process_authenticate_token`].
#[derive(Debug)]
pub enum AuthenticateTokenError {
    /// The authentication hook server rejected the material with HTTP 401.
    Unauthorized,
    /// The request to the authentication hook server failed for a reason
    /// other than a 401 (transport error, other status, unreadable body).
    HookServerError(Error),
    /// Any other failure, e.g. the blocking request task not completing or
    /// the hook server response not being the expected JSON.
    OtherError(Error),
}
244
245pub async fn process_authenticate_token(
247 config: &Config,
248 token_tracker: &AuthTokenTracker,
249 auth_material: bytes::Bytes,
250) -> std::result::Result<Arc<str>, AuthenticateTokenError> {
251 use AuthenticateTokenError::*;
252
253 let token: Arc<str> = if let Some(url) = &config.authentication_hook_server
254 {
255 let url = url.clone();
258 let token = tokio::task::spawn_blocking(move || {
259 ureq::put(&url)
260 .header("Content-Type", "application/octet-stream")
261 .send(&auth_material[..])
262 .map_err(|err| match err {
263 ureq::Error::StatusCode(401) => Unauthorized,
264 oth => HookServerError(Error::other(oth)),
265 })?
266 .into_body()
267 .read_to_string()
268 .map_err(Error::other)
269 .map_err(HookServerError)
273 })
274 .await
275 .map_err(|_| OtherError(Error::other("tokio task died")))??;
276
277 #[derive(serde::Deserialize)]
278 #[serde(rename_all = "camelCase")]
279 struct Token {
280 auth_token: String,
281 }
282
283 let token: Token = serde_json::from_str(&token)
284 .map_err(|err| OtherError(Error::other(err)))?;
285
286 token.auth_token
287 } else {
288 use base64::prelude::*;
291 use rand::Rng;
292
293 let mut bytes = [0; 32];
294 rand::thread_rng().fill(&mut bytes);
295 BASE64_URL_SAFE_NO_PAD.encode(&bytes[..])
296 }
297 .into();
298
299 token_tracker.register_token(token.clone());
301
302 Ok(token)
303}
304
305#[derive(Clone)]
307struct WebsocketImpl {
308 write: Arc<
309 tokio::sync::Mutex<
310 futures::stream::SplitSink<
311 axum::extract::ws::WebSocket,
312 axum::extract::ws::Message,
313 >,
314 >,
315 >,
316 read: Arc<
317 tokio::sync::Mutex<
318 futures::stream::SplitStream<axum::extract::ws::WebSocket>,
319 >,
320 >,
321}
322
323impl SbdWebsocket for WebsocketImpl {
324 fn recv(&self) -> futures::future::BoxFuture<'static, Result<Payload>> {
325 let this = self.clone();
326 Box::pin(async move {
327 let mut read = this.read.lock().await;
328 use futures::stream::StreamExt;
329 loop {
330 match read.next().await {
331 None => return Err(Error::other("closed")),
332 Some(r) => {
333 let msg = r.map_err(Error::other)?;
334 match msg {
335 axum::extract::ws::Message::Text(s) => {
336 return Ok(Payload::Vec(s.as_bytes().to_vec()))
337 }
338 axum::extract::ws::Message::Binary(v) => {
339 return Ok(Payload::Vec(v[..].to_vec()))
340 }
341 axum::extract::ws::Message::Ping(_)
342 | axum::extract::ws::Message::Pong(_) => (),
343 axum::extract::ws::Message::Close(_) => {
344 return Err(Error::other("closed"))
345 }
346 }
347 }
348 }
349 }
350 })
351 }
352
353 fn send(
354 &self,
355 payload: Payload,
356 ) -> futures::future::BoxFuture<'static, Result<()>> {
357 use futures::SinkExt;
358 let this = self.clone();
359 Box::pin(async move {
360 let mut write = this.write.lock().await;
361 let v = match payload {
362 Payload::Vec(v) => v,
363 Payload::BytesMut(b) => b.to_vec(),
364 };
365 write
366 .send(axum::extract::ws::Message::Binary(
367 bytes::Bytes::copy_from_slice(&v),
368 ))
369 .await
370 .map_err(Error::other)?;
371 write.flush().await.map_err(Error::other)?;
372 Ok(())
373 })
374 }
375
376 fn close(&self) -> futures::future::BoxFuture<'static, ()> {
377 use futures::SinkExt;
378 let this = self.clone();
379 Box::pin(async move {
380 let _ = this.write.lock().await.close().await;
381 })
382 }
383}
384
385impl WebsocketImpl {
386 fn new(ws: axum::extract::ws::WebSocket) -> Self {
387 use futures::StreamExt;
388 let (tx, rx) = ws.split();
389 Self {
390 write: Arc::new(tokio::sync::Mutex::new(tx)),
391 read: Arc::new(tokio::sync::Mutex::new(rx)),
392 }
393 }
394}
395
396async fn handle_ws(
398 axum::extract::Path(pub_key): axum::extract::Path<String>,
399 headers: axum::http::HeaderMap,
400 ws: axum::extract::WebSocketUpgrade,
401 axum::extract::ConnectInfo(addr): axum::extract::ConnectInfo<
402 std::net::SocketAddr,
403 >,
404 axum::extract::State(app_state): axum::extract::State<AppState>,
405) -> impl axum::response::IntoResponse {
406 use axum::response::IntoResponse;
407 use base64::Engine;
408
409 let token: Option<Arc<str>> = headers
411 .get("Authorization")
412 .and_then(|t| t.to_str().ok().map(<Arc<str>>::from));
413
414 let maybe_auth = Some((token.clone(), app_state.token_tracker.clone()));
415
416 if !app_state
418 .token_tracker
419 .check_is_token_valid(&app_state.config, token)
420 {
421 return axum::response::IntoResponse::into_response((
422 axum::http::StatusCode::UNAUTHORIZED,
423 "Unauthorized",
424 ));
425 }
426
427 let pk = match base64::prelude::BASE64_URL_SAFE_NO_PAD.decode(pub_key) {
429 Ok(pk) if pk.len() == 32 => {
430 let mut sized_pk = [0; 32];
431 sized_pk.copy_from_slice(&pk);
432 PubKey(Arc::new(sized_pk))
433 }
434 _ => return axum::http::StatusCode::BAD_REQUEST.into_response(),
435 };
436
437 let mut calc_ip = to_canonical_ip(addr.ip());
438
439 if let Some(trusted_ip_header) = &app_state.config.trusted_ip_header {
441 if let Some(header) =
442 headers.get(trusted_ip_header).and_then(|h| h.to_str().ok())
443 {
444 if let Ok(ip) = header.parse::<IpAddr>() {
445 calc_ip = to_canonical_ip(ip);
446 }
447 }
448 }
449
450 ws.max_message_size(MAX_MSG_BYTES as usize).on_upgrade(
452 move |socket| async move {
453 handle_upgraded(
454 app_state.config.clone(),
455 app_state.ip_rate.clone(),
456 app_state.cslot.clone(),
457 Arc::new(WebsocketImpl::new(socket)),
458 pk,
459 calc_ip,
460 maybe_auth,
461 )
462 .await;
463 },
464 )
465}
466
/// Tracks issued auth tokens and the last time each one was used.
///
/// Clones share the same underlying map. Tokens idle for longer than the
/// configured idle duration are pruned whenever validity is checked.
#[derive(Clone, Default)]
pub struct AuthTokenTracker {
    token_map: Arc<Mutex<HashMap<Arc<str>, std::time::Instant>>>,
}
472
473impl AuthTokenTracker {
474 pub fn register_token(&self, token: Arc<str>) {
476 self.token_map
477 .lock()
478 .unwrap()
479 .insert(token, std::time::Instant::now());
480 }
481
482 pub fn check_is_token_valid(
489 &self,
490 config: &Config,
491 token: Option<Arc<str>>,
492 ) -> bool {
493 let token: Arc<str> = if let Some(token) = token {
494 if !token.starts_with("Bearer ") {
497 return false;
498 }
499 token.trim_start_matches("Bearer ").into()
500 } else if config.authentication_hook_server.is_none() {
501 return true;
504 } else {
505 return false;
507 };
508
509 let mut lock = self.token_map.lock().unwrap();
510
511 let idle_dur = config.idle_dur();
512
513 lock.retain(|_t, e| e.elapsed() < idle_dur);
514
515 if let std::collections::hash_map::Entry::Occupied(mut e) =
516 lock.entry(token)
517 {
518 e.insert(std::time::Instant::now());
519 true
520 } else {
521 false
522 }
523 }
524}
525
526#[derive(Clone)]
527struct AppState {
528 config: Arc<Config>,
529 token_tracker: AuthTokenTracker,
530 ip_rate: Arc<IpRate>,
531 cslot: WeakCSlot,
532}
533
534impl AppState {
535 pub fn new(
536 config: Arc<Config>,
537 ip_rate: Arc<IpRate>,
538 cslot: WeakCSlot,
539 ) -> Self {
540 Self {
541 config,
542 token_tracker: AuthTokenTracker::default(),
543 ip_rate,
544 cslot,
545 }
546 }
547}
548
549impl SbdServer {
550 pub async fn new(config: Arc<Config>) -> Result<Self> {
552 let tls_config = if let (Some(cert), Some(pk)) =
553 (&config.cert_pem_file, &config.priv_key_pem_file)
554 {
555 Some(Arc::new(TlsConfig::new(cert, pk).await?))
556 } else {
557 None
558 };
559
560 let mut task_list = Vec::new();
561 let mut bind_addrs = Vec::new();
562
563 let ip_rate = Arc::new(IpRate::new(config.clone()));
564 task_list.push(spawn_prune_task(ip_rate.clone()));
565
566 let cslot = CSlot::new(config.clone(), ip_rate.clone());
567 let weak_cslot = cslot.weak();
568
569 let app: axum::Router<()> = axum::Router::new()
571 .route("/authenticate", axum::routing::put(handle_auth))
572 .route("/{pub_key}", axum::routing::any(handle_ws))
573 .layer(axum::extract::DefaultBodyLimit::max(1024))
574 .with_state(AppState::new(
575 config.clone(),
576 ip_rate.clone(),
577 weak_cslot.clone(),
578 ));
579
580 let app =
581 app.into_make_service_with_connect_info::<std::net::SocketAddr>();
582
583 let mut found_port_zero: Option<u16> = None;
584
585 for bind in config.bind.iter() {
587 let mut a: std::net::SocketAddr =
588 bind.parse().map_err(Error::other)?;
589 if let Some(found_port_zero) = &found_port_zero {
590 if a.port() == 0 {
591 a.set_port(*found_port_zero);
592 }
593 }
594
595 let h = axum_server::Handle::new();
596
597 if let Some(tls_config) = &tls_config {
598 let tls_config =
599 axum_server::tls_rustls::RustlsConfig::from_config(
600 tls_config.config(),
601 );
602 let server = axum_server::bind_rustls(a, tls_config)
603 .handle(h.clone())
604 .serve(app.clone());
605 task_list.push(tokio::task::spawn(async move {
606 if let Err(err) = server.await {
607 tracing::error!(?err);
608 }
609 }));
610 } else {
611 let server =
612 axum_server::bind(a).handle(h.clone()).serve(app.clone());
613 task_list.push(tokio::task::spawn(async move {
614 if let Err(err) = server.await {
615 tracing::error!(?err);
616 }
617 }));
618 }
619
620 if let Some(addr) = h.listening().await {
621 if found_port_zero.is_none() && a.port() == 0 {
622 found_port_zero = Some(addr.port());
623 }
624 bind_addrs.push(addr);
625 }
626 }
627
628 Ok(Self {
629 task_list,
630 bind_addrs,
631 _cslot: cslot,
632 })
633 }
634
635 pub fn bind_addrs(&self) -> &[std::net::SocketAddr] {
637 self.bind_addrs.as_slice()
638 }
639}
640
641#[cfg(test)]
642mod test;