//! SBD server.
//!
//! Accepts TCP connections (optionally wrapped in TLS), upgrades them to
//! websockets whose clients are identified by an ed25519 public key, applies
//! per-IP rate limiting, and inserts admitted connections into the connection
//! slot. Entry point: [`SbdServer::new`].
#![deny(missing_docs)]
/// Presumably the maximum size, in bytes, of a single message (inferred from
/// the name).
///
/// NOTE(review): not referenced in this file; presumably enforced by the
/// websocket backend or connection-slot modules — confirm. It is an `i32`
/// rather than a `usize`, so comparisons against buffer lengths need a
/// conversion at the use site.
const MAX_MSG_BYTES: i32 = 20_000;
use std::io::{Error, Result};
use std::sync::Arc;
mod config;
pub use config::*;
mod maybe_tls;
use maybe_tls::*;
mod ip_deny;
mod ip_rate;
mod cslot;
mod cmd;
/// Websocket abstraction over the backend selected at compile time.
///
/// Exactly one backend is re-exported: `tungstenite` when that cargo feature
/// is enabled, otherwise `fastwebsockets`. NOTE(review): with neither feature
/// enabled no `WebSocket` type is exported, and the crate fails to compile
/// wherever it is referenced.
pub mod ws {
    /// Message bytes that may be borrowed or owned.
    ///
    /// Dereferences to the payload bytes regardless of representation; use
    /// [`Payload::to_mut`] to obtain mutable access.
    pub enum Payload<'a> {
        /// Immutably borrowed bytes. Copied into an owned [`Payload::Vec`]
        /// the first time mutable access is requested.
        Slice(&'a [u8]),
        /// Mutably borrowed bytes; mutated in place.
        SliceMut(&'a mut [u8]),
        /// Owned bytes in a `Vec`.
        Vec(Vec<u8>),
        /// Owned bytes in a `bytes::BytesMut` buffer.
        BytesMut(bytes::BytesMut),
    }

    impl std::ops::Deref for Payload<'_> {
        type Target = [u8];

        #[inline(always)]
        fn deref(&self) -> &Self::Target {
            match self {
                Payload::Slice(b) => b,
                Payload::SliceMut(b) => b,
                Payload::Vec(v) => v.as_slice(),
                Payload::BytesMut(b) => b.as_ref(),
            }
        }
    }

    impl Payload<'_> {
        /// Get mutable access to the payload bytes.
        ///
        /// A [`Payload::Slice`] cannot be written through, so it is first
        /// copied into an owned [`Payload::Vec`] (replacing `self`). Every
        /// other variant is mutated in place without copying.
        #[inline(always)]
        pub fn to_mut(&mut self) -> &mut [u8] {
            // Promote the read-only borrow to owned storage up front, so the
            // match below only ever sees writable variants.
            if let Payload::Slice(borrowed) = self {
                *self = Payload::Vec(borrowed.to_vec());
            }
            match self {
                Payload::Slice(_) => {
                    unreachable!("Slice was promoted to Vec above")
                }
                Payload::SliceMut(b) => b,
                Payload::Vec(v) => v.as_mut_slice(),
                Payload::BytesMut(b) => b.as_mut(),
            }
        }
    }

    // Backend selection: `tungstenite` wins if both features are enabled.
    #[cfg(feature = "tungstenite")]
    mod ws_tungstenite;
    #[cfg(feature = "tungstenite")]
    pub use ws_tungstenite::*;
    #[cfg(all(not(feature = "tungstenite"), feature = "fastwebsockets"))]
    mod ws_fastwebsockets;
    #[cfg(all(not(feature = "tungstenite"), feature = "fastwebsockets"))]
    pub use ws_fastwebsockets::*;
}
use ws::*;
/// An ed25519 public key as 32 raw bytes, behind an `Arc` so clones are cheap.
#[derive(Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct PubKey(pub Arc<[u8; 32]>);

impl PubKey {
    /// Check that `sig` is a valid ed25519 signature over `data` made by
    /// this key.
    ///
    /// Returns `false` both when the stored bytes do not decode to a valid
    /// ed25519 verifying key and when the signature fails to verify; no
    /// error detail is surfaced.
    pub fn verify(&self, sig: &[u8; 64], data: &[u8]) -> bool {
        use ed25519_dalek::Verifier;

        let Ok(key) = ed25519_dalek::VerifyingKey::from_bytes(&self.0) else {
            return false;
        };
        let signature = ed25519_dalek::Signature::from_bytes(sig);
        key.verify(data, &signature).is_ok()
    }
}
/// A running SBD server.
///
/// Owns its background tasks (one accept loop per bound listener, plus the
/// IP-rate pruning loop) and the connection slot. Dropping the server aborts
/// those background tasks. Per-connection admission tasks are not tracked
/// here; they only hold a weak reference to the connection slot.
pub struct SbdServer {
    /// Background tasks, aborted on drop.
    task_list: Vec<tokio::task::JoinHandle<()>>,
    /// Local addresses of every bound listener.
    bind_addrs: Vec<std::net::SocketAddr>,
    /// Held only to keep the connection slot alive: accept tasks keep just a
    /// weak reference, so dropping this stops new insertions.
    _cslot: cslot::CSlot,
}

impl Drop for SbdServer {
    fn drop(&mut self) {
        // Dropping a tokio `JoinHandle` detaches its task rather than
        // stopping it, so every task must be aborted explicitly.
        self.task_list
            .iter()
            .for_each(tokio::task::JoinHandle::abort);
    }
}
/// Admit or reject one freshly accepted TCP connection.
///
/// All steps are bounded by `config.idle_dur()`; on timeout the connection is
/// simply dropped.
/// 1. Per-IP rate check on the socket peer address. This is skipped when a
///    trusted IP header is configured; the check is then deferred to step 4.
/// 2. TLS handshake, when `tls_config` is set.
/// 3. Websocket upgrade. This yields the client's public key and, optionally,
///    an address (presumably taken from the trusted header). Keys starting
///    with `cmd::CMD_PREFIX` are rejected.
/// 4. The deferred per-IP rate check, run on the effective address.
/// 5. Insertion into the connection slot, if the server still exists.
///
/// `_connect_permit` is held until this function returns, so the connection
/// counts against the server's accept-concurrency semaphore for that long.
async fn check_accept_connection(
    _connect_permit: tokio::sync::OwnedSemaphorePermit,
    config: Arc<Config>,
    tls_config: Option<Arc<maybe_tls::TlsConfig>>,
    ip_rate: Arc<ip_rate::IpRate>,
    tcp: tokio::net::TcpStream,
    addr: std::net::SocketAddr,
    weak_cslot: cslot::WeakCSlot,
) {
    /// `true` when `ip` is not blocked and passes the rate check (with a
    /// cost/amount of `1`). The rate check is only consulted for addresses
    /// that are not blocked.
    async fn ip_admitted(
        limiter: &Arc<ip_rate::IpRate>,
        ip: &Arc<std::net::Ipv6Addr>,
    ) -> bool {
        !limiter.is_blocked(ip).await && limiter.is_ok(ip, 1).await
    }

    // Normalize to IPv6 so v4 and v6 peers share a single address space.
    let raw_ip = Arc::new(match addr.ip() {
        std::net::IpAddr::V6(v6) => v6,
        std::net::IpAddr::V4(v4) => v4.to_ipv6_mapped(),
    });
    let trust_header = config.trusted_ip_header.is_some();

    let handshake = async {
        // Without a trusted header, the socket peer is the client itself:
        // reject it before spending any work on TLS or the upgrade.
        if !trust_header && !ip_admitted(&ip_rate, &raw_ip).await {
            return;
        }

        let socket = match &tls_config {
            None => MaybeTlsStream::Tcp(tcp),
            Some(tls_config) => {
                match MaybeTlsStream::tls(tls_config, tcp).await {
                    Ok(tls) => tls,
                    Err(_) => return,
                }
            }
        };

        let Ok((ws, pub_key, header_ip)) =
            ws::WebSocket::upgrade(config.clone(), socket).await
        else {
            return;
        };

        // Keys in the command prefix space are not accepted as clients.
        if &pub_key.0[..28] == cmd::CMD_PREFIX {
            return;
        }

        let ws = Arc::new(ws);

        // Prefer the address reported by the upgrade; otherwise fall back to
        // the socket peer address.
        let calc_ip = match header_ip {
            Some(ip) => Arc::new(ip),
            None => raw_ip.clone(),
        };

        if trust_header && !ip_admitted(&ip_rate, &calc_ip).await {
            return;
        }

        if let Some(cslot) = weak_cslot.upgrade() {
            cslot.insert(&config, calc_ip, pub_key, ws).await;
        }
    };

    let _ = tokio::time::timeout(config.idle_dur(), handshake).await;
}
/// Bind a TCP listener on each address in `i`, in order.
///
/// Best effort: addresses that fail to bind are skipped silently, so the
/// result may be shorter than the input, or empty. `SbdServer::new` treats an
/// empty overall result as an error.
async fn bind_all<I: IntoIterator<Item = std::net::SocketAddr>>(
    i: I,
) -> Vec<tokio::net::TcpListener> {
    let mut bound = Vec::new();
    for addr in i {
        match tokio::net::TcpListener::bind(addr).await {
            Ok(listener) => bound.push(listener),
            Err(_) => continue,
        }
    }
    bound
}
impl SbdServer {
pub async fn new(config: Arc<Config>) -> Result<Self> {
let tls_config = if let (Some(cert), Some(pk)) =
(&config.cert_pem_file, &config.priv_key_pem_file)
{
Some(Arc::new(maybe_tls::TlsConfig::new(cert, pk).await?))
} else {
None
};
let mut task_list = Vec::new();
let mut bind_addrs = Vec::new();
let ip_rate = Arc::new(ip_rate::IpRate::new(config.clone()));
{
let ip_rate = Arc::downgrade(&ip_rate);
task_list.push(tokio::task::spawn(async move {
loop {
tokio::time::sleep(std::time::Duration::from_secs(5)).await;
if let Some(ip_rate) = ip_rate.upgrade() {
ip_rate.prune();
} else {
break;
}
}
}));
}
let cslot = cslot::CSlot::new(config.clone(), ip_rate.clone());
let connect_limit = Arc::new(tokio::sync::Semaphore::new(1024));
let mut bind_port_zero = Vec::new();
let mut bind_explicit_port = Vec::new();
for bind in config.bind.iter() {
let a: std::net::SocketAddr = bind.parse().map_err(Error::other)?;
if a.port() == 0 {
bind_port_zero.push(a);
} else {
bind_explicit_port.push(a);
}
}
let (mut listeners, mut l2) = tokio::join!(
async {
if bind_port_zero.is_empty() {
return Vec::new();
}
'top: for _ in 0..2 {
let mut listeners = Vec::new();
let mut a_iter = bind_port_zero.iter();
let a = a_iter.next().unwrap();
if let Ok(tcp) = tokio::net::TcpListener::bind(a).await {
let port = tcp.local_addr().unwrap().port();
listeners.push(tcp);
for a in a_iter {
let mut a = *a;
a.set_port(port);
match tokio::net::TcpListener::bind(a).await {
Err(_) => continue 'top,
Ok(tcp) => listeners.push(tcp),
}
}
return listeners;
}
}
bind_all(bind_port_zero).await
},
async { bind_all(bind_explicit_port).await },
);
listeners.append(&mut l2);
if listeners.is_empty() {
return Err(Error::other("failed to bind any listeners"));
}
let weak_cslot = cslot.weak();
for tcp in listeners {
bind_addrs.push(tcp.local_addr()?);
let tls_config = tls_config.clone();
let connect_limit = connect_limit.clone();
let config = config.clone();
let weak_cslot = weak_cslot.clone();
let ip_rate = ip_rate.clone();
task_list.push(tokio::task::spawn(async move {
loop {
if let Ok((tcp, addr)) = tcp.accept().await {
let connect_permit =
match connect_limit.clone().try_acquire_owned() {
Ok(permit) => permit,
_ => continue,
};
tokio::task::spawn(check_accept_connection(
connect_permit,
config.clone(),
tls_config.clone(),
ip_rate.clone(),
tcp,
addr,
weak_cslot.clone(),
));
}
}
}));
}
Ok(Self {
task_list,
bind_addrs,
_cslot: cslot,
})
}
pub fn bind_addrs(&self) -> &[std::net::SocketAddr] {
self.bind_addrs.as_slice()
}
}
#[cfg(test)]
mod test;