1#![deny(missing_docs)]
3
/// Largest websocket message, in bytes, that the server will accept.
const MAX_MSG_BYTES: i32 = 20 * 1000;
6
7use std::collections::HashMap;
8use std::io::{Error, Result};
9use std::net::{IpAddr, Ipv6Addr};
10use std::sync::{Arc, Mutex};
11
12mod config;
13pub use config::*;
14
15mod maybe_tls;
16pub use maybe_tls::*;
17
18mod ip_deny;
19mod ip_rate;
20pub use ip_rate::*;
21
22mod cslot;
23pub use cslot::*;
24
25mod cmd;
26
27pub mod ws {
29 pub enum Payload {
31 Vec(Vec<u8>),
33
34 BytesMut(bytes::BytesMut),
36 }
37
38 impl std::ops::Deref for Payload {
39 type Target = [u8];
40
41 #[inline(always)]
42 fn deref(&self) -> &Self::Target {
43 match self {
44 Payload::Vec(v) => v.as_slice(),
45 Payload::BytesMut(b) => b.as_ref(),
46 }
47 }
48 }
49
50 impl Payload {
51 #[inline(always)]
53 pub fn to_mut(&mut self) -> &mut [u8] {
54 match self {
55 Payload::Vec(ref mut owned) => owned,
56 Payload::BytesMut(b) => b.as_mut(),
57 }
58 }
59 }
60
61 use futures::future::BoxFuture;
62
63 pub trait SbdWebsocket: Send + Sync + 'static {
65 fn recv(&self) -> BoxFuture<'static, std::io::Result<Payload>>;
67
68 fn send(
70 &self,
71 payload: Payload,
72 ) -> BoxFuture<'static, std::io::Result<()>>;
73
74 fn close(&self) -> BoxFuture<'static, ()>;
76 }
77}
78
79pub use ws::{Payload, SbdWebsocket};
80
81#[derive(Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
83pub struct PubKey(pub Arc<[u8; 32]>);
84
85impl PubKey {
86 pub fn verify(&self, sig: &[u8; 64], data: &[u8]) -> bool {
88 use ed25519_dalek::Verifier;
89 if let Ok(k) = ed25519_dalek::VerifyingKey::from_bytes(&self.0) {
90 k.verify(data, &ed25519_dalek::Signature::from_bytes(sig))
91 .is_ok()
92 } else {
93 false
94 }
95 }
96}
97
98pub struct SbdServer {
100 task_list: Vec<tokio::task::JoinHandle<()>>,
101 bind_addrs: Vec<std::net::SocketAddr>,
102 _cslot: cslot::CSlot,
103}
104
105impl Drop for SbdServer {
106 fn drop(&mut self) {
107 for task in self.task_list.iter() {
108 task.abort();
109 }
110 }
111}
112
/// Map any ip address into the IPv6 address space.
///
/// IPv4 addresses become IPv4-mapped IPv6 addresses (`::ffff:a.b.c.d`).
/// IPv6 addresses are returned unchanged.
pub fn to_canonical_ip(ip: IpAddr) -> Arc<Ipv6Addr> {
    let v6 = match ip {
        IpAddr::V6(v6) => v6,
        IpAddr::V4(v4) => v4.to_ipv6_mapped(),
    };
    Arc::new(v6)
}
120
121pub async fn preflight_ip_check(
124 config: &Config,
125 ip_rate: &IpRate,
126 addr: std::net::SocketAddr,
127) -> Option<Arc<Ipv6Addr>> {
128 let raw_ip = to_canonical_ip(addr.ip());
129
130 let use_trusted_ip = config.trusted_ip_header.is_some();
131
132 if !use_trusted_ip {
133 if ip_rate.is_blocked(&raw_ip).await {
139 return None;
140 }
141
142 if !ip_rate.is_ok(&raw_ip, 1).await {
144 return None;
145 }
146 }
147
148 Some(raw_ip)
149}
150
151pub async fn handle_upgraded(
153 config: Arc<Config>,
154 ip_rate: Arc<IpRate>,
155 weak_cslot: WeakCSlot,
156 ws: Arc<impl SbdWebsocket>,
157 pub_key: PubKey,
158 calc_ip: Arc<Ipv6Addr>,
159 maybe_auth: Option<(Option<Arc<str>>, AuthTokenTracker)>,
160) {
161 let use_trusted_ip = config.trusted_ip_header.is_some();
162
163 if &pub_key.0[..28] == cmd::CMD_PREFIX {
165 return;
166 }
167
168 if use_trusted_ip {
169 if ip_rate.is_blocked(&calc_ip).await {
172 return;
173 }
174
175 if !ip_rate.is_ok(&calc_ip, 1).await {
177 return;
178 }
179 }
180
181 if let Some(cslot) = weak_cslot.upgrade() {
182 cslot
183 .insert(&config, calc_ip, pub_key, ws, maybe_auth)
184 .await;
185 }
186}
187
188async fn handle_auth(
189 axum::extract::State(app_state): axum::extract::State<AppState>,
190 body: bytes::Bytes,
191) -> axum::response::Response {
192 use AuthenticateTokenError::*;
193
194 match process_authenticate_token(
195 &app_state.config,
196 &app_state.token_tracker,
197 body,
198 )
199 .await
200 {
201 Ok(token) => axum::response::IntoResponse::into_response(axum::Json(
202 serde_json::json!({
203 "authToken": *token,
204 }),
205 )),
206 Err(Unauthorized) => {
207 tracing::debug!("/authenticate: UNAUTHORIZED");
208 axum::response::IntoResponse::into_response((
209 axum::http::StatusCode::UNAUTHORIZED,
210 "Unauthorized",
211 ))
212 }
213 Err(HookServerError(err)) => {
214 tracing::debug!(?err, "/authenticate: BAD_GATEWAY");
215 axum::response::IntoResponse::into_response((
216 axum::http::StatusCode::BAD_GATEWAY,
217 format!("BAD_GATEWAY: {err:?}"),
218 ))
219 }
220 Err(OtherError(err)) => {
221 tracing::warn!(?err, "/authenticate: INTERNAL_SERVER_ERROR");
222 axum::response::IntoResponse::into_response((
223 axum::http::StatusCode::INTERNAL_SERVER_ERROR,
224 format!("INTERNAL_SERVER_ERROR: {err:?}"),
225 ))
226 }
227 }
228}
229
/// Failure modes of [`process_authenticate_token`].
#[derive(Debug)]
pub enum AuthenticateTokenError {
    /// The supplied authentication material was not accepted.
    Unauthorized,

    /// Communicating with the authentication hook server failed.
    HookServerError(std::io::Error),

    /// Any other failure while processing the request.
    OtherError(std::io::Error),
}

impl std::fmt::Display for AuthenticateTokenError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Unauthorized => f.write_str("unauthorized"),
            Self::HookServerError(err) => {
                write!(f, "authentication hook server error: {err}")
            }
            Self::OtherError(err) => write!(f, "authentication error: {err}"),
        }
    }
}

impl std::error::Error for AuthenticateTokenError {
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            Self::Unauthorized => None,
            Self::HookServerError(err) | Self::OtherError(err) => Some(err),
        }
    }
}
239
240pub async fn process_authenticate_token(
242 config: &Config,
243 token_tracker: &AuthTokenTracker,
244 auth_material: bytes::Bytes,
245) -> std::result::Result<Arc<str>, AuthenticateTokenError> {
246 use AuthenticateTokenError::*;
247
248 let token: Arc<str> = if let Some(url) = &config.authentication_hook_server
249 {
250 let url = url.clone();
251 tokio::task::spawn_blocking(move || {
252 ureq::put(&url)
253 .set("Content-Type", "application/octet-stream")
254 .send(&auth_material[..])
255 .map_err(|err| HookServerError(std::io::Error::other(err)))?
256 .into_string()
257 .map_err(HookServerError)
261 })
262 .await
263 .map_err(|_| OtherError(std::io::Error::other("tokio task died")))??
264 } else {
265 use base64::prelude::*;
267 use rand::Rng;
268
269 let mut bytes = [0; 32];
270 rand::thread_rng().fill(&mut bytes);
271 BASE64_URL_SAFE_NO_PAD.encode(&bytes[..])
272 }
273 .into();
274
275 token_tracker.register_token(token.clone());
276
277 Ok(token)
278}
279
280#[derive(Clone)]
281struct WebsocketImpl {
282 write: Arc<
283 tokio::sync::Mutex<
284 futures::stream::SplitSink<
285 axum::extract::ws::WebSocket,
286 axum::extract::ws::Message,
287 >,
288 >,
289 >,
290 read: Arc<
291 tokio::sync::Mutex<
292 futures::stream::SplitStream<axum::extract::ws::WebSocket>,
293 >,
294 >,
295}
296
297impl SbdWebsocket for WebsocketImpl {
298 fn recv(&self) -> futures::future::BoxFuture<'static, Result<Payload>> {
299 let this = self.clone();
300 Box::pin(async move {
301 let mut read = this.read.lock().await;
302 use futures::stream::StreamExt;
303 loop {
304 match read.next().await {
305 None => return Err(Error::other("closed")),
306 Some(r) => {
307 let msg = r.map_err(Error::other)?;
308 match msg {
309 axum::extract::ws::Message::Text(s) => {
310 return Ok(Payload::Vec(s.as_bytes().to_vec()))
311 }
312 axum::extract::ws::Message::Binary(v) => {
313 return Ok(Payload::Vec(v[..].to_vec()))
314 }
315 axum::extract::ws::Message::Ping(_)
316 | axum::extract::ws::Message::Pong(_) => (),
317 axum::extract::ws::Message::Close(_) => {
318 return Err(Error::other("closed"))
319 }
320 }
321 }
322 }
323 }
324 })
325 }
326
327 fn send(
328 &self,
329 payload: Payload,
330 ) -> futures::future::BoxFuture<'static, Result<()>> {
331 use futures::SinkExt;
332 let this = self.clone();
333 Box::pin(async move {
334 let mut write = this.write.lock().await;
335 let v = match payload {
336 Payload::Vec(v) => v,
337 Payload::BytesMut(b) => b.to_vec(),
338 };
339 write
340 .send(axum::extract::ws::Message::Binary(
341 bytes::Bytes::copy_from_slice(&v),
342 ))
343 .await
344 .map_err(Error::other)?;
345 write.flush().await.map_err(Error::other)?;
346 Ok(())
347 })
348 }
349
350 fn close(&self) -> futures::future::BoxFuture<'static, ()> {
351 use futures::SinkExt;
352 let this = self.clone();
353 Box::pin(async move {
354 let _ = this.write.lock().await.close().await;
355 })
356 }
357}
358
359impl WebsocketImpl {
360 fn new(ws: axum::extract::ws::WebSocket) -> Self {
361 use futures::StreamExt;
362 let (tx, rx) = ws.split();
363 Self {
364 write: Arc::new(tokio::sync::Mutex::new(tx)),
365 read: Arc::new(tokio::sync::Mutex::new(rx)),
366 }
367 }
368}
369
370async fn handle_ws(
371 axum::extract::Path(pub_key): axum::extract::Path<String>,
372 headers: axum::http::HeaderMap,
373 ws: axum::extract::WebSocketUpgrade,
374 axum::extract::ConnectInfo(addr): axum::extract::ConnectInfo<
375 std::net::SocketAddr,
376 >,
377 axum::extract::State(app_state): axum::extract::State<AppState>,
378) -> impl axum::response::IntoResponse {
379 use axum::response::IntoResponse;
380 use base64::Engine;
381
382 let token: Option<Arc<str>> = headers
383 .get("Authorization")
384 .and_then(|t| t.to_str().ok().map(<Arc<str>>::from));
385
386 let maybe_auth = Some((token.clone(), app_state.token_tracker.clone()));
387
388 if !app_state
389 .token_tracker
390 .check_is_token_valid(&app_state.config, token)
391 {
392 return axum::response::IntoResponse::into_response((
393 axum::http::StatusCode::UNAUTHORIZED,
394 "Unauthorized",
395 ));
396 }
397
398 let pk = match base64::prelude::BASE64_URL_SAFE_NO_PAD.decode(pub_key) {
399 Ok(pk) if pk.len() == 32 => {
400 let mut sized_pk = [0; 32];
401 sized_pk.copy_from_slice(&pk);
402 PubKey(Arc::new(sized_pk))
403 }
404 _ => return axum::http::StatusCode::BAD_REQUEST.into_response(),
405 };
406
407 let mut calc_ip = to_canonical_ip(addr.ip());
408
409 if let Some(trusted_ip_header) = &app_state.config.trusted_ip_header {
410 if let Some(header) =
411 headers.get(trusted_ip_header).and_then(|h| h.to_str().ok())
412 {
413 if let Ok(ip) = header.parse::<IpAddr>() {
414 calc_ip = to_canonical_ip(ip);
415 }
416 }
417 }
418
419 ws.max_message_size(MAX_MSG_BYTES as usize).on_upgrade(
420 move |socket| async move {
421 handle_upgraded(
422 app_state.config.clone(),
423 app_state.ip_rate.clone(),
424 app_state.cslot.clone(),
425 Arc::new(WebsocketImpl::new(socket)),
426 pk,
427 calc_ip,
428 maybe_auth,
429 )
430 .await;
431 },
432 )
433}
434
435#[derive(Clone, Default)]
437pub struct AuthTokenTracker {
438 token_map: Arc<Mutex<HashMap<Arc<str>, std::time::Instant>>>,
439}
440
441impl AuthTokenTracker {
442 pub fn register_token(&self, token: Arc<str>) {
444 self.token_map
445 .lock()
446 .unwrap()
447 .insert(token, std::time::Instant::now());
448 }
449
450 pub fn check_is_token_valid(
457 &self,
458 config: &Config,
459 token: Option<Arc<str>>,
460 ) -> bool {
461 let token: Arc<str> = if let Some(token) = token {
462 if !token.starts_with("Bearer ") {
465 return false;
466 }
467 token.trim_start_matches("Bearer ").into()
468 } else if config.authentication_hook_server.is_none() {
469 return true;
472 } else {
473 return false;
475 };
476
477 let mut lock = self.token_map.lock().unwrap();
478
479 let idle_dur = config.idle_dur();
480
481 lock.retain(|_t, e| e.elapsed() < idle_dur);
482
483 if let std::collections::hash_map::Entry::Occupied(mut e) =
484 lock.entry(token)
485 {
486 e.insert(std::time::Instant::now());
487 true
488 } else {
489 false
490 }
491 }
492}
493
494#[derive(Clone)]
495struct AppState {
496 config: Arc<Config>,
497 token_tracker: AuthTokenTracker,
498 ip_rate: Arc<IpRate>,
499 cslot: WeakCSlot,
500}
501
502impl AppState {
503 pub fn new(
504 config: Arc<Config>,
505 ip_rate: Arc<IpRate>,
506 cslot: WeakCSlot,
507 ) -> Self {
508 Self {
509 config,
510 token_tracker: AuthTokenTracker::default(),
511 ip_rate,
512 cslot,
513 }
514 }
515}
516
517impl SbdServer {
518 pub async fn new(config: Arc<Config>) -> Result<Self> {
520 let tls_config = if let (Some(cert), Some(pk)) =
521 (&config.cert_pem_file, &config.priv_key_pem_file)
522 {
523 Some(Arc::new(maybe_tls::TlsConfig::new(cert, pk).await?))
524 } else {
525 None
526 };
527
528 let mut task_list = Vec::new();
529 let mut bind_addrs = Vec::new();
530
531 let ip_rate = Arc::new(IpRate::new(config.clone()));
532 task_list.push(spawn_prune_task(ip_rate.clone()));
533
534 let cslot = CSlot::new(config.clone(), ip_rate.clone());
535 let weak_cslot = cslot.weak();
536
537 let app: axum::Router<()> = axum::Router::new()
538 .route("/authenticate", axum::routing::put(handle_auth))
539 .route("/{pub_key}", axum::routing::any(handle_ws))
540 .layer(axum::extract::DefaultBodyLimit::max(1024))
541 .with_state(AppState::new(
542 config.clone(),
543 ip_rate.clone(),
544 weak_cslot.clone(),
545 ));
546
547 let app =
548 app.into_make_service_with_connect_info::<std::net::SocketAddr>();
549
550 let mut found_port_zero: Option<u16> = None;
551
552 for bind in config.bind.iter() {
553 let mut a: std::net::SocketAddr =
554 bind.parse().map_err(Error::other)?;
555 if let Some(found_port_zero) = &found_port_zero {
556 if a.port() == 0 {
557 a.set_port(*found_port_zero);
558 }
559 }
560
561 let h = axum_server::Handle::new();
562
563 if let Some(tls_config) = &tls_config {
564 let tls_config =
565 axum_server::tls_rustls::RustlsConfig::from_config(
566 tls_config.config(),
567 );
568 let server = axum_server::bind_rustls(a, tls_config)
569 .handle(h.clone())
570 .serve(app.clone());
571 task_list.push(tokio::task::spawn(async move {
572 if let Err(err) = server.await {
573 tracing::error!(?err);
574 }
575 }));
576 } else {
577 let server =
578 axum_server::bind(a).handle(h.clone()).serve(app.clone());
579 task_list.push(tokio::task::spawn(async move {
580 if let Err(err) = server.await {
581 tracing::error!(?err);
582 }
583 }));
584 }
585
586 if let Some(addr) = h.listening().await {
587 if found_port_zero.is_none() && a.port() == 0 {
588 found_port_zero = Some(addr.port());
589 }
590 bind_addrs.push(addr);
591 }
592 }
593
594 Ok(Self {
595 task_list,
596 bind_addrs,
597 _cslot: cslot,
598 })
599 }
600
601 pub fn bind_addrs(&self) -> &[std::net::SocketAddr] {
603 self.bind_addrs.as_slice()
604 }
605}
606
607#[cfg(test)]
608mod test;