use crate::{
builder::{BeforeReceiveFn, CodecFn, NameFn, SelectVersionFn, SessionHandleFn},
traits::{Codec, ProtocolSpawn, ServiceProtocol, SessionProtocol},
yamux::config::Config as YamuxConfig,
ProtocolId, SessionId,
};
#[cfg(windows)]
use std::os::windows::io::{
AsRawSocket, AsSocket, BorrowedSocket, FromRawSocket, IntoRawSocket, RawSocket,
};
#[cfg(unix)]
use std::os::{
fd::AsFd,
unix::io::{AsRawFd, BorrowedFd, FromRawFd, IntoRawFd, RawFd},
};
use std::{sync::Arc, time::Duration};
#[cfg(feature = "tls")]
use tokio_rustls::rustls::{ClientConfig, ServerConfig};
/// Default value of both `SessionConfig::send_buffer_size` and
/// `SessionConfig::recv_buffer_size` (24 * 1024 * 1024; presumably bytes).
const MAX_BUF_SIZE: usize = 24 * 1024 * 1024;
/// Service-wide settings. Baseline values come from its `Default` impl.
///
/// NOTE(review): field meanings below are inferred from names and default values
/// only; confirm against the code that consumes this config.
pub(crate) struct ServiceConfig {
/// Timeout (defaults to 10 seconds); which operations it bounds is not visible here.
pub timeout: Duration,
/// Per-session settings: buffer sizes, channel size and yamux configuration.
pub session_config: SessionConfig,
/// Maximum frame length (defaults to `1024 * 1024 * 8`); presumably bytes — confirm.
pub max_frame_length: usize,
/// Defaults to `false`; presumably controls whether buffered data is retained — confirm.
pub keep_buffer: bool,
/// UPnP toggle (defaults to `false`); only exists on non-wasm32 targets with the `upnp` feature.
#[cfg(all(not(target_arch = "wasm32"), feature = "upnp"))]
pub upnp: bool,
/// Upper bound on connections, judging by the name (defaults to 65535).
pub max_connection_number: usize,
/// Per-transport socket configuration hooks.
pub tcp_config: TcpConfig,
/// Optional TLS server/client configuration (`tls` feature only; defaults to `None`).
#[cfg(feature = "tls")]
pub tls_config: Option<TlsConfig>,
}
impl Default for ServiceConfig {
fn default() -> Self {
ServiceConfig {
timeout: Duration::from_secs(10),
session_config: SessionConfig::default(),
max_frame_length: 1024 * 1024 * 8,
keep_buffer: false,
#[cfg(all(not(target_arch = "wasm32"), feature = "upnp"))]
upnp: false,
max_connection_number: 65535,
tcp_config: Default::default(),
#[cfg(feature = "tls")]
tls_config: None,
}
}
}
/// Per-session settings (the type is `Copy`).
#[derive(Clone, Copy)]
pub(crate) struct SessionConfig {
/// Yamux multiplexer configuration. Its `max_stream_window_size` is also the
/// divisor used by `recv_event_size` / `send_event_size`.
pub yamux_config: YamuxConfig,
/// Send buffer size (defaults to `MAX_BUF_SIZE`).
pub send_buffer_size: usize,
/// Receive buffer size (defaults to `MAX_BUF_SIZE`).
pub recv_buffer_size: usize,
/// Defaults to 128; presumably the capacity of a per-session channel — confirm.
pub channel_size: usize,
}
impl SessionConfig {
    /// Number of whole yamux stream windows that fit into `recv_buffer_size`,
    /// plus one. NOTE(review): presumably used as an event-queue capacity — confirm.
    ///
    /// Panics (integer division by zero) if `max_stream_window_size` is 0.
    pub const fn recv_event_size(&self) -> usize {
        Self::windows_plus_one(
            self.recv_buffer_size,
            self.yamux_config.max_stream_window_size as usize,
        )
    }

    /// Number of whole yamux stream windows that fit into `send_buffer_size`,
    /// plus one.
    ///
    /// Panics (integer division by zero) if `max_stream_window_size` is 0.
    pub const fn send_event_size(&self) -> usize {
        Self::windows_plus_one(
            self.send_buffer_size,
            self.yamux_config.max_stream_window_size as usize,
        )
    }

    /// `buffer / window + 1`, the formula shared by both event-size getters.
    const fn windows_plus_one(buffer: usize, window: usize) -> usize {
        buffer / window + 1
    }
}
impl Default for SessionConfig {
    /// Default yamux settings, `MAX_BUF_SIZE` for both send and receive
    /// buffers, and a channel size of 128.
    fn default() -> Self {
        Self {
            yamux_config: YamuxConfig::default(),
            send_buffer_size: MAX_BUF_SIZE,
            recv_buffer_size: MAX_BUF_SIZE,
            channel_size: 128,
        }
    }
}
/// User hook that receives a `TcpSocket`, may adjust it (or fail with an
/// `io::Error`), and hands it back.
///
/// Stored behind an `Arc` and required to be `Send + Sync`, so cloning a
/// `TcpConfig` shares the same hook. NOTE(review): presumably invoked before
/// the socket is bound/connected — confirm at the call sites.
pub(crate) type TcpSocketConfig =
Arc<dyn Fn(TcpSocket) -> Result<TcpSocket, std::io::Error> + Send + Sync + 'static>;
/// Socket hooks, one per transport. Each defaults to `Ok`, which returns the
/// socket unchanged.
#[derive(Clone)]
pub(crate) struct TcpConfig {
/// Hook for plain TCP sockets.
pub tcp: TcpSocketConfig,
/// Hook for WebSocket transport sockets (`ws` feature only).
#[cfg(feature = "ws")]
pub ws: TcpSocketConfig,
/// Hook for TLS transport sockets (`tls` feature only).
#[cfg(feature = "tls")]
pub tls: TcpSocketConfig,
}
impl Default for TcpConfig {
fn default() -> Self {
TcpConfig {
tcp: Arc::new(Ok),
#[cfg(feature = "ws")]
ws: Arc::new(Ok),
#[cfg(feature = "tls")]
tls: Arc::new(Ok),
}
}
}
/// Socket handed to `TcpSocketConfig` hooks; wraps a `socket2::Socket`.
///
/// On wasm32 targets the `inner` field is compiled out, leaving an empty struct.
pub struct TcpSocket {
#[cfg(not(target_arch = "wasm32"))]
pub(crate) inner: socket2::Socket,
}
#[cfg(unix)]
impl AsRawFd for TcpSocket {
    /// Raw descriptor of the wrapped socket; ownership stays with `self`.
    fn as_raw_fd(&self) -> RawFd {
        AsRawFd::as_raw_fd(&self.inner)
    }
}
#[cfg(unix)]
impl FromRawFd for TcpSocket {
/// Wraps an existing descriptor in a `TcpSocket`, which takes ownership of it.
///
/// # Safety
///
/// Same contract as `socket2::Socket::from_raw_fd` / `FromRawFd`: `fd` must be
/// an open socket descriptor owned by the caller, and the caller must not use
/// or close it afterwards (the returned value is responsible for closing it).
unsafe fn from_raw_fd(fd: RawFd) -> TcpSocket {
let inner = socket2::Socket::from_raw_fd(fd);
TcpSocket { inner }
}
}
#[cfg(unix)]
impl IntoRawFd for TcpSocket {
fn into_raw_fd(self) -> RawFd {
self.inner.into_raw_fd()
}
}
#[cfg(unix)]
impl AsFd for TcpSocket {
    /// Borrows the wrapped descriptor for the lifetime of `&self`.
    fn as_fd(&self) -> BorrowedFd<'_> {
        AsFd::as_fd(&self.inner)
    }
}
#[cfg(windows)]
impl IntoRawSocket for TcpSocket {
    /// Consumes the wrapper and transfers ownership of the socket handle to the caller.
    fn into_raw_socket(self) -> RawSocket {
        let TcpSocket { inner } = self;
        inner.into_raw_socket()
    }
}
#[cfg(windows)]
impl AsRawSocket for TcpSocket {
    /// Raw handle of the wrapped socket; ownership stays with `self`.
    fn as_raw_socket(&self) -> RawSocket {
        AsRawSocket::as_raw_socket(&self.inner)
    }
}
#[cfg(windows)]
impl FromRawSocket for TcpSocket {
/// Wraps an existing socket handle in a `TcpSocket`, which takes ownership of it.
///
/// # Safety
///
/// Same contract as `socket2::Socket::from_raw_socket` / `FromRawSocket`:
/// `socket` must be an open socket handle owned by the caller, and the caller
/// must not use or close it afterwards (the returned value is responsible for
/// closing it).
unsafe fn from_raw_socket(socket: RawSocket) -> TcpSocket {
let inner = socket2::Socket::from_raw_socket(socket);
TcpSocket { inner }
}
}
#[cfg(windows)]
impl AsSocket for TcpSocket {
    /// Borrows the wrapped socket handle for the lifetime of `&self`.
    ///
    /// Delegates to `socket2::Socket`'s own `AsSocket` impl, mirroring the Unix
    /// `AsFd` impl. Unlike the previous `unsafe { BorrowedSocket::borrow_raw(..) }`
    /// form, this leaves no lifetime invariant to be upheld by hand.
    fn as_socket(&self) -> BorrowedSocket<'_> {
        self.inner.as_socket()
    }
}
/// Optional rustls server and client configurations.
///
/// Both are kept behind `Arc`, so cloning a `TlsConfig` does not copy the
/// rustls configs; `Default` leaves both unset.
#[derive(Clone, Default)]
#[cfg(feature = "tls")]
#[cfg_attr(docsrs, doc(cfg(feature = "tls")))]
pub struct TlsConfig {
/// rustls server-side configuration, if any.
pub(crate) tls_server_config: Option<Arc<ServerConfig>>,
/// rustls client-side configuration, if any.
pub(crate) tls_client_config: Option<Arc<ClientConfig>>,
}
#[cfg(feature = "tls")]
#[cfg_attr(docsrs, doc(cfg(feature = "tls")))]
impl TlsConfig {
    /// Builds a `TlsConfig` from optional rustls server and client configs,
    /// wrapping whichever are present in an `Arc`.
    pub fn new(server_config: Option<ServerConfig>, client_config: Option<ClientConfig>) -> Self {
        Self {
            tls_server_config: server_config.map(Arc::new),
            tls_client_config: client_config.map(Arc::new),
        }
    }
}
/// Selects a set of protocols (presumably the target of a send/open-style
/// operation — confirm at the call sites).
pub enum TargetProtocol {
/// Every protocol.
All,
/// Exactly one protocol.
Single(ProtocolId),
/// Every protocol for which the predicate returns `true`.
Filter(Box<dyn Fn(&ProtocolId) -> bool + Sync + Send + 'static>),
}
impl From<ProtocolId> for TargetProtocol {
fn from(id: ProtocolId) -> Self {
TargetProtocol::Single(id)
}
}
impl From<usize> for TargetProtocol {
fn from(id: usize) -> Self {
TargetProtocol::Single(id.into())
}
}
/// Selects a set of sessions (presumably the recipients of an operation —
/// confirm at the call sites).
pub enum TargetSession {
/// Every session.
All,
/// Exactly one session.
Single(SessionId),
/// An explicit sequence of session ids.
Multi(Box<dyn Iterator<Item = SessionId> + Send + 'static>),
/// Every session for which the predicate returns `true`; being `FnMut`, the
/// predicate may keep state between calls.
Filter(Box<dyn FnMut(&SessionId) -> bool + Send + 'static>),
}
impl From<SessionId> for TargetSession {
fn from(id: SessionId) -> Self {
TargetSession::Single(id)
}
}
impl From<usize> for TargetSession {
fn from(id: usize) -> Self {
TargetSession::Single(id.into())
}
}
/// Registration data for one protocol: shared metadata plus its service- and
/// session-level handles.
pub struct ProtocolMeta {
/// Shared protocol metadata (id, name, versions, codec factory, ...).
pub(crate) inner: Arc<Meta>,
/// Service-level handle; `service_handle()` moves it out and leaves `ProtocolHandle::None`.
pub(crate) service_handle: ProtocolHandle<Box<dyn ServiceProtocol + Send + 'static + Unpin>>,
/// Factory called by `session_handle()` to produce a session-level handle.
pub(crate) session_handle: SessionHandleFn,
/// Optional transform over outgoing `Bytes`; presumably applied just before sending — confirm.
pub(crate) before_send: Option<Box<dyn Fn(bytes::Bytes) -> bytes::Bytes + Send + 'static>>,
}
impl ProtocolMeta {
    /// Numeric id of this protocol.
    #[inline]
    pub fn id(&self) -> ProtocolId {
        self.inner.id
    }

    /// Protocol name, produced by the registered name callback from the id.
    #[inline]
    pub fn name(&self) -> String {
        let meta = &self.inner;
        (meta.name)(meta.id)
    }

    /// Owned copy of the supported version strings.
    #[inline]
    pub fn support_versions(&self) -> Vec<String> {
        self.inner.support_versions.to_vec()
    }

    /// Builds a new codec instance through the registered codec factory.
    #[inline]
    pub fn codec(&self) -> Box<dyn Codec + Send + 'static> {
        let meta = &self.inner;
        (meta.codec)()
    }

    /// Takes the service-level handle out, leaving `ProtocolHandle::None` in its
    /// place, so every call after the first returns `None`.
    #[inline]
    pub fn service_handle(
        &mut self,
    ) -> ProtocolHandle<Box<dyn ServiceProtocol + Send + 'static + Unpin>> {
        std::mem::replace(&mut self.service_handle, ProtocolHandle::None)
    }

    /// Calls the stored session-handle factory and returns its result
    /// (presumably a fresh handle per session — confirm at the call sites).
    #[inline]
    pub fn session_handle(
        &mut self,
    ) -> ProtocolHandle<Box<dyn SessionProtocol + Send + 'static + Unpin>> {
        (self.session_handle)()
    }
}
/// Protocol metadata shared (via `Arc`) by `ProtocolMeta`.
pub(crate) struct Meta {
/// Numeric protocol id.
pub(crate) id: ProtocolId,
/// Callback that maps the id to the protocol name (see `ProtocolMeta::name`).
pub(crate) name: NameFn,
/// Version strings this protocol supports.
pub(crate) support_versions: Vec<String>,
/// Factory producing a fresh codec (see `ProtocolMeta::codec`).
pub(crate) codec: CodecFn,
/// Presumably picks the version to use during protocol negotiation — confirm.
pub(crate) select_version: SelectVersionFn,
/// Presumably a hook run on received data before it is handled — confirm.
pub(crate) before_receive: BeforeReceiveFn,
/// Optional spawn-style protocol implementation.
pub(crate) spawn: Option<Box<dyn ProtocolSpawn + Send + Sync + 'static>>,
}
/// Either no handle, or a handle of type `T`.
///
/// The `T: Sized` bound is already implied for generic parameters; it is
/// redundant but harmless.
pub enum ProtocolHandle<T: Sized> {
/// No handle registered, or it has already been taken.
None,
/// A registered handle.
Callback(T),
}
impl<T> ProtocolHandle<T> {
    /// `true` when a handle is present.
    #[inline]
    pub fn is_callback(&self) -> bool {
        match self {
            ProtocolHandle::Callback(_) => true,
            ProtocolHandle::None => false,
        }
    }

    /// `true` when no handle is present; the exact complement of `is_callback`.
    #[inline]
    pub fn is_none(&self) -> bool {
        !self.is_callback()
    }
}
/// How a connection's handshake is performed: with a secio key provider, or
/// no handshake at all.
///
/// `Clone` is derived. This yields exactly the former hand-written impl,
/// including the `T: Clone` bound.
#[non_exhaustive]
#[derive(Clone)]
pub enum HandshakeType<T> {
    /// Secio handshake using the given key provider.
    Secio(T),
    /// No handshake.
    Noop,
}

impl<K> From<K> for HandshakeType<K>
where
    K: secio::KeyProvider,
{
    /// Any secio key provider selects the secio handshake.
    fn from(value: K) -> Self {
        HandshakeType::Secio(value)
    }
}
/// Counter-based liveness state.
///
/// NOTE(review): the counter presumably tracks live tasks/sessions keeping the
/// service alive — confirm against callers.
#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)]
pub enum State {
    /// Counted state; reports shutdown once the count is back to zero.
    Running(usize),
    /// Never reports shutdown from counting; `increase`/`decrease` are no-ops.
    Forever,
    /// Shutdown requested; always reports shutdown and ignores counter updates.
    PreShutdown,
}

impl State {
    /// `Forever` when `forever` is true, otherwise `Running(0)`.
    pub fn new(forever: bool) -> Self {
        if forever {
            State::Forever
        } else {
            State::Running(0)
        }
    }

    /// `true` for `Running(0)` and `PreShutdown`, `false` otherwise.
    #[inline]
    pub fn is_shutdown(&self) -> bool {
        match self {
            State::Running(0) | State::PreShutdown => true,
            State::Running(_) | State::Forever => false,
        }
    }

    /// Switches to `PreShutdown` unconditionally.
    #[inline]
    pub fn pre_shutdown(&mut self) {
        *self = State::PreShutdown
    }

    /// Increments the count while `Running`; no-op in the other states.
    #[inline]
    pub fn increase(&mut self) {
        match self {
            State::Running(num) => *num += 1,
            State::PreShutdown | State::Forever => (),
        }
    }

    /// Decrements the count while `Running`; no-op in the other states.
    ///
    /// Calling this on `Running(0)` is an unbalanced decrement: it trips a
    /// debug assertion in debug builds. In release builds the count stays at 0
    /// instead of wrapping to `usize::MAX`, which would otherwise prevent
    /// `is_shutdown` from ever reporting `true` again.
    #[inline]
    pub fn decrease(&mut self) {
        match self {
            State::Running(num) => {
                debug_assert!(*num > 0, "State::decrease called on Running(0)");
                *num = num.saturating_sub(1);
            }
            State::PreShutdown | State::Forever => (),
        }
    }

    /// The current count when `Running`, `None` for `Forever` / `PreShutdown`.
    #[inline]
    pub fn into_inner(self) -> Option<usize> {
        match self {
            State::Running(num) => Some(num),
            State::PreShutdown | State::Forever => None,
        }
    }
}
#[cfg(test)]
mod test {
    use super::State;

    #[test]
    fn test_state_no_forever() {
        let mut state = State::new(false);
        // Nothing counted yet, so a fresh non-forever state already reports shutdown.
        assert!(state.is_shutdown());
        state.increase();
        state.increase();
        assert_eq!(state, State::Running(2));
        assert!(!state.is_shutdown());
        assert_eq!(state.into_inner(), Some(2));
        state.decrease();
        state.decrease();
        assert_eq!(state, State::Running(0));
        assert!(state.is_shutdown());
        for _ in 0..4 {
            state.increase();
        }
        assert!(!state.is_shutdown());
        state.pre_shutdown();
        assert_eq!(state, State::PreShutdown);
        assert!(state.is_shutdown());
        assert_eq!(state.into_inner(), None);
    }

    #[test]
    fn test_state_forever() {
        let mut state = State::new(true);
        state.increase();
        state.increase();
        assert_eq!(state, State::Forever);
        state.decrease();
        state.decrease();
        assert_eq!(state, State::Forever);
        // Counting never drives a forever state into shutdown.
        assert!(!state.is_shutdown());
        assert_eq!(state.into_inner(), None);
        for _ in 0..4 {
            state.increase();
        }
        state.pre_shutdown();
        assert_eq!(state, State::PreShutdown);
        assert!(state.is_shutdown());
    }
}