1#![deny(missing_docs)]
18
/// Maximum size, in bytes, of a single websocket message; `handle_ws`
/// passes it to `WebSocketUpgrade::max_message_size`.
const MAX_MSG_BYTES: i32 = 20 * 1_000;
21
22use base64::Engine;
23use opentelemetry::global;
24use std::collections::HashMap;
25use std::io::{Error, Result};
26use std::net::{IpAddr, Ipv6Addr};
27use std::sync::{Arc, Mutex};
28
29mod config;
30pub use config::*;
31
32mod maybe_tls;
33pub use maybe_tls::*;
34
35mod ip_deny;
36mod ip_rate;
37pub use ip_rate::*;
38
39mod cslot;
40pub use cslot::*;
41
42mod cmd;
43
44mod metrics;
45pub use metrics::*;
46
47pub mod ws {
49 pub enum Payload {
51 Vec(Vec<u8>),
53
54 BytesMut(bytes::BytesMut),
56 }
57
58 impl std::ops::Deref for Payload {
59 type Target = [u8];
60
61 #[inline(always)]
62 fn deref(&self) -> &Self::Target {
63 match self {
64 Payload::Vec(v) => v.as_slice(),
65 Payload::BytesMut(b) => b.as_ref(),
66 }
67 }
68 }
69
70 impl Payload {
71 #[inline(always)]
73 pub fn to_mut(&mut self) -> &mut [u8] {
74 match self {
75 Payload::Vec(ref mut owned) => owned,
76 Payload::BytesMut(b) => b.as_mut(),
77 }
78 }
79 }
80
81 use futures::future::BoxFuture;
82
83 pub trait SbdWebsocket: Send + Sync + 'static {
85 fn recv(&self) -> BoxFuture<'static, std::io::Result<Payload>>;
87
88 fn send(
90 &self,
91 payload: Payload,
92 ) -> BoxFuture<'static, std::io::Result<()>>;
93
94 fn close(&self) -> BoxFuture<'static, ()>;
96 }
97}
98
99pub use ws::{Payload, SbdWebsocket};
100
/// A 32-byte ed25519 public key, held behind an `Arc` so clones only bump a
/// reference count.
#[derive(Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct PubKey(pub Arc<[u8; 32]>);
104
105impl PubKey {
106 pub fn verify(&self, sig: &[u8; 64], data: &[u8]) -> bool {
108 use ed25519_dalek::Verifier;
109 if let Ok(k) = ed25519_dalek::VerifyingKey::from_bytes(&self.0) {
110 k.verify(data, &ed25519_dalek::Signature::from_bytes(sig))
111 .is_ok()
112 } else {
113 false
114 }
115 }
116}
117
/// A running sbd server; dropping it aborts every task it spawned.
pub struct SbdServer {
    // Background tasks (the ip-rate prune task and one task per listener);
    // aborted in `Drop`.
    task_list: Vec<tokio::task::JoinHandle<()>>,
    // Local addresses the listeners actually bound to.
    bind_addrs: Vec<std::net::SocketAddr>,

    // Held to keep the connection slot alive; the request handlers only
    // hold a `WeakCSlot` to it.
    _cslot: CSlot,
}
127
128impl Drop for SbdServer {
129 fn drop(&mut self) {
130 for task in self.task_list.iter() {
131 task.abort();
132 }
133 }
134}
135
/// Normalize an IP address to IPv6 form, wrapped in an `Arc`.
///
/// IPv4 addresses become IPv4-mapped IPv6 addresses (`::ffff:a.b.c.d`);
/// IPv6 addresses are returned unchanged.
pub fn to_canonical_ip(ip: IpAddr) -> Arc<Ipv6Addr> {
    let v6 = match ip {
        IpAddr::V6(v6) => v6,
        IpAddr::V4(v4) => v4.to_ipv6_mapped(),
    };
    Arc::new(v6)
}
143
144pub async fn preflight_ip_check(
147 config: &Config,
148 ip_rate: &IpRate,
149 addr: std::net::SocketAddr,
150) -> Option<Arc<Ipv6Addr>> {
151 let raw_ip = to_canonical_ip(addr.ip());
152
153 let use_trusted_ip = config.trusted_ip_header.is_some();
154
155 if !use_trusted_ip {
156 if ip_rate.is_blocked(&raw_ip).await {
162 return None;
163 }
164
165 if !ip_rate.is_ok(&raw_ip, 1).await {
167 return None;
168 }
169 }
170
171 Some(raw_ip)
172}
173
174pub async fn handle_upgraded(
176 config: Arc<Config>,
177 ip_rate: Arc<IpRate>,
178 weak_cslot: WeakCSlot,
179 ws: Arc<impl SbdWebsocket>,
180 pub_key: PubKey,
181 calc_ip: Arc<Ipv6Addr>,
182 maybe_auth: Option<(Option<Arc<str>>, AuthTokenTracker)>,
183) {
184 let use_trusted_ip = config.trusted_ip_header.is_some();
185
186 if &pub_key.0[..28] == cmd::CMD_PREFIX {
188 return;
189 }
190
191 if use_trusted_ip {
192 if ip_rate.is_blocked(&calc_ip).await {
195 return;
196 }
197
198 if !ip_rate.is_ok(&calc_ip, 1).await {
200 return;
201 }
202 }
203
204 if let Some(cslot) = weak_cslot.upgrade() {
205 cslot
206 .insert(&config, calc_ip, pub_key, ws, maybe_auth)
207 .await;
208 }
209}
210
211async fn handle_auth(
213 axum::extract::State(app_state): axum::extract::State<AppState>,
214 body: bytes::Bytes,
215) -> axum::response::Response {
216 use AuthenticateTokenError::*;
217
218 match process_authenticate_token(
220 &app_state.config,
221 &app_state.token_tracker,
222 app_state.auth_failures,
223 body,
224 )
225 .await
226 {
227 Ok(token) => axum::response::IntoResponse::into_response(axum::Json(
228 serde_json::json!({
229 "authToken": *token,
230 }),
231 )),
232 Err(Unauthorized) => {
233 tracing::debug!("/authenticate: UNAUTHORIZED");
234 axum::response::IntoResponse::into_response((
235 axum::http::StatusCode::UNAUTHORIZED,
236 "Unauthorized",
237 ))
238 }
239 Err(HookServerError(err)) => {
240 tracing::debug!(?err, "/authenticate: BAD_GATEWAY");
241 axum::response::IntoResponse::into_response((
242 axum::http::StatusCode::BAD_GATEWAY,
243 format!("BAD_GATEWAY: {err:?}"),
244 ))
245 }
246 Err(OtherError(err)) => {
247 tracing::warn!(?err, "/authenticate: INTERNAL_SERVER_ERROR");
248 axum::response::IntoResponse::into_response((
249 axum::http::StatusCode::INTERNAL_SERVER_ERROR,
250 format!("INTERNAL_SERVER_ERROR: {err:?}"),
251 ))
252 }
253 }
254}
255
/// Reasons `process_authenticate_token` can fail to issue a token.
///
/// Derives `Debug` so callers can log or `unwrap` results carrying it.
#[derive(Debug)]
pub enum AuthenticateTokenError {
    /// The authentication hook server responded with HTTP 401.
    Unauthorized,
    /// The hook server request failed for another reason, or its response
    /// body could not be read.
    HookServerError(Error),
    /// Any other failure, e.g. the hook response was not the expected JSON
    /// or the blocking task running the request died.
    OtherError(Error),
}
265
266pub async fn process_authenticate_token(
268 config: &Config,
269 token_tracker: &AuthTokenTracker,
270 auth_failures: opentelemetry::metrics::Counter<u64>,
271 auth_material: bytes::Bytes,
272) -> std::result::Result<Arc<str>, AuthenticateTokenError> {
273 use AuthenticateTokenError::*;
274
275 let token: Arc<str> = if let Some(url) = &config.authentication_hook_server
276 {
277 let url = url.clone();
280 let token = tokio::task::spawn_blocking(move || {
281 ureq::put(&url)
282 .header("Content-Type", "application/octet-stream")
283 .send(&auth_material[..])
284 .map_err(|err| {
285 auth_failures.add(1, &[]);
286
287 match err {
288 ureq::Error::StatusCode(401) => Unauthorized,
289 oth => HookServerError(Error::other(oth)),
290 }
291 })?
292 .into_body()
293 .read_to_string()
294 .map_err(Error::other)
295 .map_err(HookServerError)
299 })
300 .await
301 .map_err(|_| OtherError(Error::other("tokio task died")))??;
302
303 #[derive(serde::Deserialize)]
304 #[serde(rename_all = "camelCase")]
305 struct Token {
306 auth_token: String,
307 }
308
309 let token: Token = serde_json::from_str(&token)
310 .map_err(|err| OtherError(Error::other(err)))?;
311
312 token.auth_token
313 } else {
314 use base64::prelude::*;
317 use rand::Rng;
318
319 let mut bytes = [0; 32];
320 rand::thread_rng().fill(&mut bytes);
321 BASE64_URL_SAFE_NO_PAD.encode(&bytes[..])
322 }
323 .into();
324
325 token_tracker.register_token(token.clone());
327
328 Ok(token)
329}
330
/// `SbdWebsocket` implementation wrapping an axum websocket.
///
/// The socket is split into a write half and a read half, each behind its
/// own async mutex. Cloning shares both halves and the metric handles.
#[derive(Clone)]
struct WebsocketImpl {
    // Sending half of the split socket.
    write: Arc<
        tokio::sync::Mutex<
            futures::stream::SplitSink<
                axum::extract::ws::WebSocket,
                axum::extract::ws::Message,
            >,
        >,
    >,
    // Receiving half of the split socket.
    read: Arc<
        tokio::sync::Mutex<
            futures::stream::SplitStream<axum::extract::ws::WebSocket>,
        >,
    >,
    // Attributes attached to every byte-count measurement (built in `new`
    // from the client's base64url-encoded pub key).
    attr: Vec<opentelemetry::KeyValue>,
    // Counts bytes sent to the client.
    bytes_send: opentelemetry::metrics::Counter<u64>,
    // Counts bytes received from the client.
    bytes_recv: opentelemetry::metrics::Counter<u64>,
}
351
352impl SbdWebsocket for WebsocketImpl {
353 fn recv(&self) -> futures::future::BoxFuture<'static, Result<Payload>> {
354 let this = self.clone();
355 Box::pin(async move {
356 let mut read = this.read.lock().await;
357 use futures::stream::StreamExt;
358 loop {
359 match read.next().await {
360 None => return Err(Error::other("closed")),
361 Some(r) => {
362 let msg = r.map_err(Error::other)?;
363 match msg {
364 axum::extract::ws::Message::Text(s) => {
365 this.bytes_recv.add(s.len() as u64, &this.attr);
366 return Ok(Payload::Vec(s.as_bytes().to_vec()));
367 }
368 axum::extract::ws::Message::Binary(v) => {
369 this.bytes_recv.add(v.len() as u64, &this.attr);
370 return Ok(Payload::Vec(v[..].to_vec()));
371 }
372 axum::extract::ws::Message::Ping(_)
373 | axum::extract::ws::Message::Pong(_) => (),
374 axum::extract::ws::Message::Close(_) => {
375 return Err(Error::other("closed"))
376 }
377 }
378 }
379 }
380 }
381 })
382 }
383
384 fn send(
385 &self,
386 payload: Payload,
387 ) -> futures::future::BoxFuture<'static, Result<()>> {
388 use futures::SinkExt;
389 let this = self.clone();
390 Box::pin(async move {
391 let mut write = this.write.lock().await;
392 let v = match payload {
393 Payload::Vec(v) => v,
394 Payload::BytesMut(b) => b.to_vec(),
395 };
396 this.bytes_send.add(v.len() as u64, &this.attr);
397 write
398 .send(axum::extract::ws::Message::Binary(
399 bytes::Bytes::copy_from_slice(&v),
400 ))
401 .await
402 .map_err(Error::other)?;
403 write.flush().await.map_err(Error::other)?;
404 Ok(())
405 })
406 }
407
408 fn close(&self) -> futures::future::BoxFuture<'static, ()> {
409 use futures::SinkExt;
410 let this = self.clone();
411 Box::pin(async move {
412 let _ = this.write.lock().await.close().await;
413 })
414 }
415}
416
417impl WebsocketImpl {
418 fn new(
419 ws: axum::extract::ws::WebSocket,
420 pk: PubKey,
421 meter: &opentelemetry::metrics::Meter,
422 ) -> Self {
423 use futures::StreamExt;
424
425 let bytes_send = meter
426 .u64_counter("sbd.server.bytes_send")
427 .with_description("Number of bytes sent to client")
428 .with_unit("bytes")
429 .build();
430 let bytes_recv = meter
431 .u64_counter("sbd.server.bytes_recv")
432 .with_description("Number of bytes received from client")
433 .with_unit("bytes")
434 .build();
435
436 let (tx, rx) = ws.split();
437 Self {
438 write: Arc::new(tokio::sync::Mutex::new(tx)),
439 read: Arc::new(tokio::sync::Mutex::new(rx)),
440 attr: vec![opentelemetry::KeyValue::new(
441 "pub_key",
442 base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(*pk.0),
443 )],
444 bytes_send,
445 bytes_recv,
446 }
447 }
448}
449
450async fn handle_ws(
452 axum::extract::Path(pub_key): axum::extract::Path<String>,
453 headers: axum::http::HeaderMap,
454 ws: axum::extract::WebSocketUpgrade,
455 axum::extract::ConnectInfo(addr): axum::extract::ConnectInfo<
456 std::net::SocketAddr,
457 >,
458 axum::extract::State(app_state): axum::extract::State<AppState>,
459) -> impl axum::response::IntoResponse {
460 use axum::response::IntoResponse;
461 use base64::Engine;
462
463 let token: Option<Arc<str>> = headers
465 .get("Authorization")
466 .and_then(|t| t.to_str().ok().map(<Arc<str>>::from));
467
468 let maybe_auth = Some((token.clone(), app_state.token_tracker.clone()));
469
470 if !app_state
472 .token_tracker
473 .check_is_token_valid(&app_state.config, token)
474 {
475 app_state
479 .auth_failures
480 .add(1, &[opentelemetry::KeyValue::new("pub_key", pub_key)]);
481
482 return axum::response::IntoResponse::into_response((
483 axum::http::StatusCode::UNAUTHORIZED,
484 "Unauthorized",
485 ));
486 }
487
488 let pk = match base64::prelude::BASE64_URL_SAFE_NO_PAD.decode(pub_key) {
490 Ok(pk) if pk.len() == 32 => {
491 let mut sized_pk = [0; 32];
492 sized_pk.copy_from_slice(&pk);
493 PubKey(Arc::new(sized_pk))
494 }
495 _ => return axum::http::StatusCode::BAD_REQUEST.into_response(),
496 };
497
498 let mut calc_ip = to_canonical_ip(addr.ip());
499
500 if let Some(trusted_ip_header) = &app_state.config.trusted_ip_header {
502 if let Some(header) =
503 headers.get(trusted_ip_header).and_then(|h| h.to_str().ok())
504 {
505 if let Ok(ip) = header.parse::<IpAddr>() {
506 calc_ip = to_canonical_ip(ip);
507 }
508 }
509 }
510
511 ws.max_message_size(MAX_MSG_BYTES as usize).on_upgrade(
513 move |socket| async move {
514 handle_upgraded(
515 app_state.config.clone(),
516 app_state.ip_rate.clone(),
517 app_state.cslot.clone(),
518 Arc::new(WebsocketImpl::new(
519 socket,
520 pk.clone(),
521 &app_state.meter,
522 )),
523 pk,
524 calc_ip,
525 maybe_auth,
526 )
527 .await;
528 },
529 )
530}
531
532#[derive(Clone, Default)]
534pub struct AuthTokenTracker {
535 token_map: Arc<Mutex<HashMap<Arc<str>, std::time::Instant>>>,
536}
537
538impl AuthTokenTracker {
539 pub fn register_token(&self, token: Arc<str>) {
541 self.token_map
542 .lock()
543 .unwrap()
544 .insert(token, std::time::Instant::now());
545 }
546
547 pub fn check_is_token_valid(
554 &self,
555 config: &Config,
556 token: Option<Arc<str>>,
557 ) -> bool {
558 let token: Arc<str> = if let Some(token) = token {
559 if !token.starts_with("Bearer ") {
562 return false;
563 }
564 token.trim_start_matches("Bearer ").into()
565 } else if config.authentication_hook_server.is_none() {
566 return true;
569 } else {
570 return false;
572 };
573
574 let mut lock = self.token_map.lock().unwrap();
575
576 let idle_dur = config.idle_dur();
577
578 lock.retain(|_t, e| e.elapsed() < idle_dur);
579
580 if let std::collections::hash_map::Entry::Occupied(mut e) =
581 lock.entry(token)
582 {
583 e.insert(std::time::Instant::now());
584 true
585 } else {
586 false
587 }
588 }
589}
590
/// Shared state cloned into every axum handler.
#[derive(Clone)]
struct AppState {
    // Server configuration.
    config: Arc<Config>,
    // Tokens issued via `/authenticate`, checked on websocket connect.
    token_tracker: AuthTokenTracker,
    // Per-ip block / rate checks.
    ip_rate: Arc<IpRate>,
    // Weak handle to the connection slot owned by `SbdServer`.
    cslot: WeakCSlot,
    // Counter of failed authentication attempts.
    auth_failures: opentelemetry::metrics::Counter<u64>,
    // Meter used to create per-connection instruments.
    meter: opentelemetry::metrics::Meter,
}
600
601impl AppState {
602 pub fn new(
603 config: Arc<Config>,
604 ip_rate: Arc<IpRate>,
605 cslot: WeakCSlot,
606 meter: opentelemetry::metrics::Meter,
607 ) -> Self {
608 Self {
609 config,
610 token_tracker: AuthTokenTracker::default(),
611 ip_rate,
612 cslot,
613 auth_failures: meter
614 .u64_counter("sbd.server.auth_failures")
615 .with_description("Number of failed authentication attempts")
616 .with_unit("count")
617 .build(),
618 meter,
619 }
620 }
621}
622
623impl SbdServer {
624 pub async fn new(config: Arc<Config>) -> Result<Self> {
626 let tls_config = if let (Some(cert), Some(pk)) =
627 (&config.cert_pem_file, &config.priv_key_pem_file)
628 {
629 Some(Arc::new(TlsConfig::new(cert, pk).await?))
630 } else {
631 None
632 };
633
634 let sbd_server_meter = global::meter("sbd-server");
635
636 let mut task_list = Vec::new();
637 let mut bind_addrs = Vec::new();
638
639 let ip_rate = Arc::new(IpRate::new(config.clone()));
640 task_list.push(spawn_prune_task(ip_rate.clone()));
641
642 let cslot = CSlot::new(
643 config.clone(),
644 ip_rate.clone(),
645 sbd_server_meter.clone(),
646 );
647 let weak_cslot = cslot.weak();
648
649 let app: axum::Router<()> = axum::Router::new()
651 .route("/authenticate", axum::routing::put(handle_auth))
652 .route("/{pub_key}", axum::routing::any(handle_ws))
653 .layer(axum::extract::DefaultBodyLimit::max(1024))
654 .with_state(AppState::new(
655 config.clone(),
656 ip_rate.clone(),
657 weak_cslot.clone(),
658 sbd_server_meter,
659 ));
660
661 let app =
662 app.into_make_service_with_connect_info::<std::net::SocketAddr>();
663
664 let mut found_port_zero: Option<u16> = None;
665
666 for bind in config.bind.iter() {
668 let mut a: std::net::SocketAddr =
669 bind.parse().map_err(Error::other)?;
670 if let Some(found_port_zero) = &found_port_zero {
671 if a.port() == 0 {
672 a.set_port(*found_port_zero);
673 }
674 }
675
676 let h = axum_server::Handle::new();
677
678 if let Some(tls_config) = &tls_config {
679 let tls_config =
680 axum_server::tls_rustls::RustlsConfig::from_config(
681 tls_config.config(),
682 );
683 let server = axum_server::bind_rustls(a, tls_config)
684 .handle(h.clone())
685 .serve(app.clone());
686 task_list.push(tokio::task::spawn(async move {
687 if let Err(err) = server.await {
688 tracing::error!(?err);
689 }
690 }));
691 } else {
692 let server =
693 axum_server::bind(a).handle(h.clone()).serve(app.clone());
694 task_list.push(tokio::task::spawn(async move {
695 if let Err(err) = server.await {
696 tracing::error!(?err);
697 }
698 }));
699 }
700
701 if let Some(addr) = h.listening().await {
702 if found_port_zero.is_none() && a.port() == 0 {
703 found_port_zero = Some(addr.port());
704 }
705 bind_addrs.push(addr);
706 }
707 }
708
709 Ok(Self {
710 task_list,
711 bind_addrs,
712 _cslot: cslot,
713 })
714 }
715
716 pub fn bind_addrs(&self) -> &[std::net::SocketAddr] {
718 self.bind_addrs.as_slice()
719 }
720}
721
722#[cfg(test)]
723mod test;