1#![deny(missing_docs)]
3
// Upper bound, in bytes, on a single websocket message. `handle_ws` passes
// this (cast to `usize`) to `WebSocketUpgrade::max_message_size`.
// NOTE(review): kept as `i32`, presumably because other modules consume it
// as a signed value — confirm before changing the type.
const MAX_MSG_BYTES: i32 = 20_000;
6
7use std::collections::HashMap;
8use std::io::{Error, Result};
9use std::net::{IpAddr, Ipv6Addr};
10use std::sync::{Arc, Mutex};
11
12mod config;
13pub use config::*;
14
15mod maybe_tls;
16pub use maybe_tls::*;
17
18mod ip_deny;
19mod ip_rate;
20pub use ip_rate::*;
21
22mod cslot;
23pub use cslot::*;
24
25mod cmd;
26
27pub mod ws {
29 pub enum Payload {
31 Vec(Vec<u8>),
33
34 BytesMut(bytes::BytesMut),
36 }
37
38 impl std::ops::Deref for Payload {
39 type Target = [u8];
40
41 #[inline(always)]
42 fn deref(&self) -> &Self::Target {
43 match self {
44 Payload::Vec(v) => v.as_slice(),
45 Payload::BytesMut(b) => b.as_ref(),
46 }
47 }
48 }
49
50 impl Payload {
51 #[inline(always)]
53 pub fn to_mut(&mut self) -> &mut [u8] {
54 match self {
55 Payload::Vec(ref mut owned) => owned,
56 Payload::BytesMut(b) => b.as_mut(),
57 }
58 }
59 }
60
61 use futures::future::BoxFuture;
62
63 pub trait SbdWebsocket: Send + Sync + 'static {
65 fn recv(&self) -> BoxFuture<'static, std::io::Result<Payload>>;
67
68 fn send(
70 &self,
71 payload: Payload,
72 ) -> BoxFuture<'static, std::io::Result<()>>;
73
74 fn close(&self) -> BoxFuture<'static, ()>;
76 }
77}
78
79pub use ws::{Payload, SbdWebsocket};
80
81#[derive(Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
83pub struct PubKey(pub Arc<[u8; 32]>);
84
85impl PubKey {
86 pub fn verify(&self, sig: &[u8; 64], data: &[u8]) -> bool {
88 use ed25519_dalek::Verifier;
89 if let Ok(k) = ed25519_dalek::VerifyingKey::from_bytes(&self.0) {
90 k.verify(data, &ed25519_dalek::Signature::from_bytes(sig))
91 .is_ok()
92 } else {
93 false
94 }
95 }
96}
97
98pub struct SbdServer {
100 task_list: Vec<tokio::task::JoinHandle<()>>,
101 bind_addrs: Vec<std::net::SocketAddr>,
102 _cslot: cslot::CSlot,
103}
104
105impl Drop for SbdServer {
106 fn drop(&mut self) {
107 for task in self.task_list.iter() {
108 task.abort();
109 }
110 }
111}
112
/// Convert any ip address into a shared [`Ipv6Addr`].
///
/// IPv4 addresses become IPv4-mapped IPv6 addresses (`::ffff:a.b.c.d`);
/// IPv6 addresses are passed through unchanged.
pub fn to_canonical_ip(ip: IpAddr) -> Arc<Ipv6Addr> {
    let canonical = match ip {
        IpAddr::V4(v4) => v4.to_ipv6_mapped(),
        IpAddr::V6(v6) => v6,
    };
    Arc::new(canonical)
}
120
121pub async fn preflight_ip_check(
124 config: &Config,
125 ip_rate: &IpRate,
126 addr: std::net::SocketAddr,
127) -> Option<Arc<Ipv6Addr>> {
128 let raw_ip = to_canonical_ip(addr.ip());
129
130 let use_trusted_ip = config.trusted_ip_header.is_some();
131
132 if !use_trusted_ip {
133 if ip_rate.is_blocked(&raw_ip).await {
139 return None;
140 }
141
142 if !ip_rate.is_ok(&raw_ip, 1).await {
144 return None;
145 }
146 }
147
148 Some(raw_ip)
149}
150
151pub async fn handle_upgraded(
153 config: Arc<Config>,
154 ip_rate: Arc<IpRate>,
155 weak_cslot: WeakCSlot,
156 ws: Arc<impl SbdWebsocket>,
157 pub_key: PubKey,
158 calc_ip: Arc<Ipv6Addr>,
159 maybe_auth: Option<(Option<Arc<str>>, AuthTokenTracker)>,
160) {
161 let use_trusted_ip = config.trusted_ip_header.is_some();
162
163 if &pub_key.0[..28] == cmd::CMD_PREFIX {
165 return;
166 }
167
168 if use_trusted_ip {
169 if ip_rate.is_blocked(&calc_ip).await {
172 return;
173 }
174
175 if !ip_rate.is_ok(&calc_ip, 1).await {
177 return;
178 }
179 }
180
181 if let Some(cslot) = weak_cslot.upgrade() {
182 cslot
183 .insert(&config, calc_ip, pub_key, ws, maybe_auth)
184 .await;
185 }
186}
187
188async fn handle_auth(
189 axum::extract::State(app_state): axum::extract::State<AppState>,
190 body: bytes::Bytes,
191) -> axum::response::Response {
192 use AuthenticateTokenError::*;
193
194 match process_authenticate_token(
195 &app_state.config,
196 &app_state.token_tracker,
197 body,
198 )
199 .await
200 {
201 Ok(token) => axum::response::IntoResponse::into_response(axum::Json(
202 serde_json::json!({
203 "authToken": *token,
204 }),
205 )),
206 Err(Unauthorized) => {
207 tracing::debug!("/authenticate: UNAUTHORIZED");
208 axum::response::IntoResponse::into_response((
209 axum::http::StatusCode::UNAUTHORIZED,
210 "Unauthorized",
211 ))
212 }
213 Err(HookServerError(err)) => {
214 tracing::debug!(?err, "/authenticate: BAD_GATEWAY");
215 axum::response::IntoResponse::into_response((
216 axum::http::StatusCode::BAD_GATEWAY,
217 format!("BAD_GATEWAY: {err:?}"),
218 ))
219 }
220 Err(OtherError(err)) => {
221 tracing::warn!(?err, "/authenticate: INTERNAL_SERVER_ERROR");
222 axum::response::IntoResponse::into_response((
223 axum::http::StatusCode::INTERNAL_SERVER_ERROR,
224 format!("INTERNAL_SERVER_ERROR: {err:?}"),
225 ))
226 }
227 }
228}
229
/// Failure modes of [`process_authenticate_token`].
pub enum AuthenticateTokenError {
    /// The authentication hook server answered with HTTP status 401.
    Unauthorized,
    /// The request to the authentication hook server failed (any error
    /// other than a 401 status), or its response body could not be read.
    HookServerError(std::io::Error),
    /// Any other failure: the blocking request task died, or the hook
    /// server response could not be parsed as the expected JSON.
    OtherError(std::io::Error),
}
239
240pub async fn process_authenticate_token(
242 config: &Config,
243 token_tracker: &AuthTokenTracker,
244 auth_material: bytes::Bytes,
245) -> std::result::Result<Arc<str>, AuthenticateTokenError> {
246 use AuthenticateTokenError::*;
247
248 let token: Arc<str> = if let Some(url) = &config.authentication_hook_server
249 {
250 let url = url.clone();
251 let token = tokio::task::spawn_blocking(move || {
252 ureq::put(&url)
253 .set("Content-Type", "application/octet-stream")
254 .send(&auth_material[..])
255 .map_err(|err| match err {
256 ureq::Error::Status(401, _) => Unauthorized,
257 oth => HookServerError(std::io::Error::other(oth)),
258 })?
259 .into_string()
260 .map_err(HookServerError)
264 })
265 .await
266 .map_err(|_| OtherError(std::io::Error::other("tokio task died")))??;
267
268 #[derive(serde::Deserialize)]
269 #[serde(rename_all = "camelCase")]
270 struct Token {
271 auth_token: String,
272 }
273
274 let token: Token = serde_json::from_str(&token)
275 .map_err(|err| OtherError(std::io::Error::other(err)))?;
276
277 token.auth_token
278 } else {
279 use base64::prelude::*;
281 use rand::Rng;
282
283 let mut bytes = [0; 32];
284 rand::thread_rng().fill(&mut bytes);
285 BASE64_URL_SAFE_NO_PAD.encode(&bytes[..])
286 }
287 .into();
288
289 token_tracker.register_token(token.clone());
290
291 Ok(token)
292}
293
/// [`SbdWebsocket`] implementation backed by an axum websocket.
///
/// The socket is split into a write half and a read half, each behind its
/// own async mutex, so clones of this struct share the same connection.
#[derive(Clone)]
struct WebsocketImpl {
    // Sink half of the split socket; locked for the duration of each
    // send / close.
    write: Arc<
        tokio::sync::Mutex<
            futures::stream::SplitSink<
                axum::extract::ws::WebSocket,
                axum::extract::ws::Message,
            >,
        >,
    >,
    // Stream half of the split socket; locked for the duration of each
    // recv.
    read: Arc<
        tokio::sync::Mutex<
            futures::stream::SplitStream<axum::extract::ws::WebSocket>,
        >,
    >,
}
310
311impl SbdWebsocket for WebsocketImpl {
312 fn recv(&self) -> futures::future::BoxFuture<'static, Result<Payload>> {
313 let this = self.clone();
314 Box::pin(async move {
315 let mut read = this.read.lock().await;
316 use futures::stream::StreamExt;
317 loop {
318 match read.next().await {
319 None => return Err(Error::other("closed")),
320 Some(r) => {
321 let msg = r.map_err(Error::other)?;
322 match msg {
323 axum::extract::ws::Message::Text(s) => {
324 return Ok(Payload::Vec(s.as_bytes().to_vec()))
325 }
326 axum::extract::ws::Message::Binary(v) => {
327 return Ok(Payload::Vec(v[..].to_vec()))
328 }
329 axum::extract::ws::Message::Ping(_)
330 | axum::extract::ws::Message::Pong(_) => (),
331 axum::extract::ws::Message::Close(_) => {
332 return Err(Error::other("closed"))
333 }
334 }
335 }
336 }
337 }
338 })
339 }
340
341 fn send(
342 &self,
343 payload: Payload,
344 ) -> futures::future::BoxFuture<'static, Result<()>> {
345 use futures::SinkExt;
346 let this = self.clone();
347 Box::pin(async move {
348 let mut write = this.write.lock().await;
349 let v = match payload {
350 Payload::Vec(v) => v,
351 Payload::BytesMut(b) => b.to_vec(),
352 };
353 write
354 .send(axum::extract::ws::Message::Binary(
355 bytes::Bytes::copy_from_slice(&v),
356 ))
357 .await
358 .map_err(Error::other)?;
359 write.flush().await.map_err(Error::other)?;
360 Ok(())
361 })
362 }
363
364 fn close(&self) -> futures::future::BoxFuture<'static, ()> {
365 use futures::SinkExt;
366 let this = self.clone();
367 Box::pin(async move {
368 let _ = this.write.lock().await.close().await;
369 })
370 }
371}
372
373impl WebsocketImpl {
374 fn new(ws: axum::extract::ws::WebSocket) -> Self {
375 use futures::StreamExt;
376 let (tx, rx) = ws.split();
377 Self {
378 write: Arc::new(tokio::sync::Mutex::new(tx)),
379 read: Arc::new(tokio::sync::Mutex::new(rx)),
380 }
381 }
382}
383
384async fn handle_ws(
385 axum::extract::Path(pub_key): axum::extract::Path<String>,
386 headers: axum::http::HeaderMap,
387 ws: axum::extract::WebSocketUpgrade,
388 axum::extract::ConnectInfo(addr): axum::extract::ConnectInfo<
389 std::net::SocketAddr,
390 >,
391 axum::extract::State(app_state): axum::extract::State<AppState>,
392) -> impl axum::response::IntoResponse {
393 use axum::response::IntoResponse;
394 use base64::Engine;
395
396 let token: Option<Arc<str>> = headers
397 .get("Authorization")
398 .and_then(|t| t.to_str().ok().map(<Arc<str>>::from));
399
400 let maybe_auth = Some((token.clone(), app_state.token_tracker.clone()));
401
402 if !app_state
403 .token_tracker
404 .check_is_token_valid(&app_state.config, token)
405 {
406 return axum::response::IntoResponse::into_response((
407 axum::http::StatusCode::UNAUTHORIZED,
408 "Unauthorized",
409 ));
410 }
411
412 let pk = match base64::prelude::BASE64_URL_SAFE_NO_PAD.decode(pub_key) {
413 Ok(pk) if pk.len() == 32 => {
414 let mut sized_pk = [0; 32];
415 sized_pk.copy_from_slice(&pk);
416 PubKey(Arc::new(sized_pk))
417 }
418 _ => return axum::http::StatusCode::BAD_REQUEST.into_response(),
419 };
420
421 let mut calc_ip = to_canonical_ip(addr.ip());
422
423 if let Some(trusted_ip_header) = &app_state.config.trusted_ip_header {
424 if let Some(header) =
425 headers.get(trusted_ip_header).and_then(|h| h.to_str().ok())
426 {
427 if let Ok(ip) = header.parse::<IpAddr>() {
428 calc_ip = to_canonical_ip(ip);
429 }
430 }
431 }
432
433 ws.max_message_size(MAX_MSG_BYTES as usize).on_upgrade(
434 move |socket| async move {
435 handle_upgraded(
436 app_state.config.clone(),
437 app_state.ip_rate.clone(),
438 app_state.cslot.clone(),
439 Arc::new(WebsocketImpl::new(socket)),
440 pk,
441 calc_ip,
442 maybe_auth,
443 )
444 .await;
445 },
446 )
447}
448
449#[derive(Clone, Default)]
451pub struct AuthTokenTracker {
452 token_map: Arc<Mutex<HashMap<Arc<str>, std::time::Instant>>>,
453}
454
455impl AuthTokenTracker {
456 pub fn register_token(&self, token: Arc<str>) {
458 self.token_map
459 .lock()
460 .unwrap()
461 .insert(token, std::time::Instant::now());
462 }
463
464 pub fn check_is_token_valid(
471 &self,
472 config: &Config,
473 token: Option<Arc<str>>,
474 ) -> bool {
475 let token: Arc<str> = if let Some(token) = token {
476 if !token.starts_with("Bearer ") {
479 return false;
480 }
481 token.trim_start_matches("Bearer ").into()
482 } else if config.authentication_hook_server.is_none() {
483 return true;
486 } else {
487 return false;
489 };
490
491 let mut lock = self.token_map.lock().unwrap();
492
493 let idle_dur = config.idle_dur();
494
495 lock.retain(|_t, e| e.elapsed() < idle_dur);
496
497 if let std::collections::hash_map::Entry::Occupied(mut e) =
498 lock.entry(token)
499 {
500 e.insert(std::time::Instant::now());
501 true
502 } else {
503 false
504 }
505 }
506}
507
508#[derive(Clone)]
509struct AppState {
510 config: Arc<Config>,
511 token_tracker: AuthTokenTracker,
512 ip_rate: Arc<IpRate>,
513 cslot: WeakCSlot,
514}
515
516impl AppState {
517 pub fn new(
518 config: Arc<Config>,
519 ip_rate: Arc<IpRate>,
520 cslot: WeakCSlot,
521 ) -> Self {
522 Self {
523 config,
524 token_tracker: AuthTokenTracker::default(),
525 ip_rate,
526 cslot,
527 }
528 }
529}
530
531impl SbdServer {
532 pub async fn new(config: Arc<Config>) -> Result<Self> {
534 let tls_config = if let (Some(cert), Some(pk)) =
535 (&config.cert_pem_file, &config.priv_key_pem_file)
536 {
537 Some(Arc::new(maybe_tls::TlsConfig::new(cert, pk).await?))
538 } else {
539 None
540 };
541
542 let mut task_list = Vec::new();
543 let mut bind_addrs = Vec::new();
544
545 let ip_rate = Arc::new(IpRate::new(config.clone()));
546 task_list.push(spawn_prune_task(ip_rate.clone()));
547
548 let cslot = CSlot::new(config.clone(), ip_rate.clone());
549 let weak_cslot = cslot.weak();
550
551 let app: axum::Router<()> = axum::Router::new()
552 .route("/authenticate", axum::routing::put(handle_auth))
553 .route("/{pub_key}", axum::routing::any(handle_ws))
554 .layer(axum::extract::DefaultBodyLimit::max(1024))
555 .with_state(AppState::new(
556 config.clone(),
557 ip_rate.clone(),
558 weak_cslot.clone(),
559 ));
560
561 let app =
562 app.into_make_service_with_connect_info::<std::net::SocketAddr>();
563
564 let mut found_port_zero: Option<u16> = None;
565
566 for bind in config.bind.iter() {
567 let mut a: std::net::SocketAddr =
568 bind.parse().map_err(Error::other)?;
569 if let Some(found_port_zero) = &found_port_zero {
570 if a.port() == 0 {
571 a.set_port(*found_port_zero);
572 }
573 }
574
575 let h = axum_server::Handle::new();
576
577 if let Some(tls_config) = &tls_config {
578 let tls_config =
579 axum_server::tls_rustls::RustlsConfig::from_config(
580 tls_config.config(),
581 );
582 let server = axum_server::bind_rustls(a, tls_config)
583 .handle(h.clone())
584 .serve(app.clone());
585 task_list.push(tokio::task::spawn(async move {
586 if let Err(err) = server.await {
587 tracing::error!(?err);
588 }
589 }));
590 } else {
591 let server =
592 axum_server::bind(a).handle(h.clone()).serve(app.clone());
593 task_list.push(tokio::task::spawn(async move {
594 if let Err(err) = server.await {
595 tracing::error!(?err);
596 }
597 }));
598 }
599
600 if let Some(addr) = h.listening().await {
601 if found_port_zero.is_none() && a.port() == 0 {
602 found_port_zero = Some(addr.port());
603 }
604 bind_addrs.push(addr);
605 }
606 }
607
608 Ok(Self {
609 task_list,
610 bind_addrs,
611 _cslot: cslot,
612 })
613 }
614
615 pub fn bind_addrs(&self) -> &[std::net::SocketAddr] {
617 self.bind_addrs.as_slice()
618 }
619}
620
621#[cfg(test)]
622mod test;