use crate::{
builder::{BeforeReceiveFn, CodecFn, NameFn, SelectVersionFn, SessionHandleFn},
traits::{Codec, ProtocolSpawn, ServiceProtocol, SessionProtocol},
yamux::config::Config as YamuxConfig,
ProtocolId, SessionId,
};
use std::{net::SocketAddr, sync::Arc, time::Duration};
/// Default send and receive buffer size for a session (24 MiB); used by `SessionConfig::default`.
const MAX_BUF_SIZE: usize = 24 << 20;
/// Service-wide configuration; see the `Default` impl below for the default values.
pub(crate) struct ServiceConfig {
    /// Timeout duration (default: 10 seconds). Where it is applied is decided outside this file.
    pub timeout: Duration,
    /// Per-session configuration (buffer sizes and yamux settings).
    pub session_config: SessionConfig,
    /// Maximum frame length, presumably in bytes (default: 8 MiB).
    pub max_frame_length: usize,
    /// Keep-buffer switch (default: `false`).
    pub keep_buffer: bool,
    /// UPnP switch (default: `false`); only present on non-wasm targets with the `upnp` feature.
    #[cfg(all(not(target_arch = "wasm32"), feature = "upnp"))]
    pub upnp: bool,
    /// Maximum number of connections (default: 65535).
    pub max_connection_number: usize,
    /// Optional TCP bind address (default: `None`).
    pub tcp_bind_addr: Option<SocketAddr>,
    /// Optional WebSocket bind address (default: `None`); only present with the `ws` feature.
    #[cfg(feature = "ws")]
    pub ws_bind_addr: Option<SocketAddr>,
}

impl Default for ServiceConfig {
    fn default() -> Self {
        const DEFAULT_MAX_FRAME_LENGTH: usize = 8 << 20;

        Self {
            timeout: Duration::from_secs(10),
            session_config: SessionConfig::default(),
            max_frame_length: DEFAULT_MAX_FRAME_LENGTH,
            keep_buffer: false,
            #[cfg(all(not(target_arch = "wasm32"), feature = "upnp"))]
            upnp: false,
            max_connection_number: 65_535,
            tcp_bind_addr: None,
            #[cfg(feature = "ws")]
            ws_bind_addr: None,
        }
    }
}
/// Per-session buffer sizes together with the yamux configuration for the session.
#[derive(Clone, Copy)]
pub(crate) struct SessionConfig {
    /// Yamux configuration; its `max_stream_window_size` is the divisor used by
    /// `recv_event_size` / `send_event_size`.
    pub yamux_config: YamuxConfig,
    /// Send buffer size (default: `MAX_BUF_SIZE`).
    pub send_buffer_size: usize,
    /// Receive buffer size (default: `MAX_BUF_SIZE`).
    pub recv_buffer_size: usize,
}

impl SessionConfig {
    /// `buffer_size / max_stream_window_size` (integer division) plus one.
    ///
    /// Panics (division by zero) if `max_stream_window_size` is 0.
    const fn events_for(&self, buffer_size: usize) -> usize {
        let window = self.yamux_config.max_stream_window_size as usize;
        buffer_size / window + 1
    }

    /// Number of window-sized events for the receive buffer, plus one.
    pub const fn recv_event_size(&self) -> usize {
        self.events_for(self.recv_buffer_size)
    }

    /// Number of window-sized events for the send buffer, plus one.
    pub const fn send_event_size(&self) -> usize {
        self.events_for(self.send_buffer_size)
    }
}

impl Default for SessionConfig {
    /// Both buffers default to `MAX_BUF_SIZE`; yamux uses its own defaults.
    fn default() -> Self {
        Self {
            yamux_config: YamuxConfig::default(),
            send_buffer_size: MAX_BUF_SIZE,
            recv_buffer_size: MAX_BUF_SIZE,
        }
    }
}
/// Selector describing which protocol(s) an operation is aimed at; interpreted by its consumer.
pub enum TargetProtocol {
    /// No restriction by protocol id (presumably: every protocol).
    All,
    /// Exactly one protocol id.
    Single(ProtocolId),
    /// A predicate over protocol ids (presumably selects ids for which it returns `true`).
    Filter(Box<dyn Fn(&ProtocolId) -> bool + Send>),
}

impl From<ProtocolId> for TargetProtocol {
    /// Wraps the id in [`TargetProtocol::Single`].
    fn from(id: ProtocolId) -> Self {
        Self::Single(id)
    }
}

impl From<usize> for TargetProtocol {
    /// Converts the raw number into a `ProtocolId` and wraps it in [`TargetProtocol::Single`].
    fn from(id: usize) -> Self {
        let protocol_id: ProtocolId = id.into();
        Self::Single(protocol_id)
    }
}
/// Selector describing which session(s) an operation is aimed at; interpreted by its consumer.
pub enum TargetSession {
    /// No restriction by session id (presumably: every session).
    All,
    /// Exactly one session id.
    Single(SessionId),
    /// A predicate over session ids (presumably selects ids for which it returns `true`).
    Filter(Box<dyn Fn(&SessionId) -> bool + Send>),
}

impl From<SessionId> for TargetSession {
    /// Wraps the id in [`TargetSession::Single`].
    fn from(id: SessionId) -> Self {
        Self::Single(id)
    }
}

impl From<usize> for TargetSession {
    /// Converts the raw number into a `SessionId` and wraps it in [`TargetSession::Single`].
    fn from(id: usize) -> Self {
        let session_id: SessionId = id.into();
        Self::Single(session_id)
    }
}
/// Per-protocol runtime metadata: the shared [`Meta`] description plus the
/// handles, hooks and blocking flags registered for the protocol.
pub struct ProtocolMeta {
    pub(crate) inner: Arc<Meta>,
    pub(crate) service_handle: ProtocolHandle<Box<dyn ServiceProtocol + Send + 'static + Unpin>>,
    pub(crate) session_handle: SessionHandleFn,
    pub(crate) before_send: Option<Box<dyn Fn(bytes::Bytes) -> bytes::Bytes + Send + 'static>>,
    pub(crate) flag: BlockingFlag,
}

impl ProtocolMeta {
    /// The protocol id stored in the shared metadata.
    #[inline]
    pub fn id(&self) -> ProtocolId {
        let meta: &Meta = &self.inner;
        meta.id
    }

    /// The protocol name, produced by calling the stored name function with the id.
    #[inline]
    pub fn name(&self) -> String {
        let meta: &Meta = &self.inner;
        (meta.name)(meta.id)
    }

    /// An owned copy of the supported version strings.
    #[inline]
    pub fn support_versions(&self) -> Vec<String> {
        self.inner.support_versions.to_vec()
    }

    /// A codec obtained by calling the stored codec function.
    #[inline]
    pub fn codec(&self) -> Box<dyn Codec + Send + 'static> {
        let meta: &Meta = &self.inner;
        (meta.codec)()
    }

    /// Moves the service handle out, leaving `ProtocolHandle::None` behind,
    /// so every call after the first returns `ProtocolHandle::None`.
    #[inline]
    pub fn service_handle(
        &mut self,
    ) -> ProtocolHandle<Box<dyn ServiceProtocol + Send + 'static + Unpin>> {
        std::mem::replace(&mut self.service_handle, ProtocolHandle::None)
    }

    /// Calls the stored session-handle function and returns what it produces.
    #[inline]
    pub fn session_handle(
        &mut self,
    ) -> ProtocolHandle<Box<dyn SessionProtocol + Send + 'static + Unpin>> {
        let make_handle = &mut self.session_handle;
        make_handle()
    }

    /// A copy of this protocol's [`BlockingFlag`].
    #[inline]
    pub fn blocking_flag(&self) -> BlockingFlag {
        self.flag
    }
}
/// Immutable protocol description shared (via `Arc`) by `ProtocolMeta`.
pub(crate) struct Meta {
    /// Protocol id; returned by `ProtocolMeta::id`.
    pub(crate) id: ProtocolId,
    /// Name function; `ProtocolMeta::name` calls it with `id`.
    pub(crate) name: NameFn,
    /// Version strings; `ProtocolMeta::support_versions` returns a clone.
    pub(crate) support_versions: Vec<String>,
    /// Codec function; `ProtocolMeta::codec` calls it to obtain a boxed codec.
    pub(crate) codec: CodecFn,
    /// Version-selection callback (type from `builder::SelectVersionFn`).
    pub(crate) select_version: SelectVersionFn,
    /// Hook of type `builder::BeforeReceiveFn` (presumably applied to inbound data — confirm in builder).
    pub(crate) before_receive: BeforeReceiveFn,
    /// Optional `ProtocolSpawn` implementation.
    pub(crate) spawn: Option<Box<dyn ProtocolSpawn + Send + Sync + 'static>>,
}
/// An optional protocol handle: either absent or a callback value of type `T`.
pub enum ProtocolHandle<T> {
    /// No handle.
    None,
    /// A callback handle.
    Callback(T),
}

impl<T> ProtocolHandle<T> {
    /// `true` for the `Callback` variant.
    #[inline]
    pub fn is_callback(&self) -> bool {
        match self {
            ProtocolHandle::Callback(_) => true,
            ProtocolHandle::None => false,
        }
    }

    /// `true` for the `None` variant (the only alternative to `Callback`).
    #[inline]
    pub fn is_none(&self) -> bool {
        !self.is_callback()
    }
}
/// Four independent flags named after protocol events, packed into the low nibble of a `u8`:
/// bit 3 `connected`, bit 2 `disconnected`, bit 1 `received`, bit 0 `notify`.
///
/// All four are set by default. What a set flag means is decided by the consumer
/// (presumably whether that callback runs in blocking mode — confirm at the call site).
#[derive(Copy, Clone, Debug)]
pub struct BlockingFlag(u8);

impl BlockingFlag {
    const CONNECTED: u8 = 0b1000;
    const DISCONNECTED: u8 = 0b0100;
    const RECEIVED: u8 = 0b0010;
    const NOTIFY: u8 = 0b0001;
    const ALL: u8 = Self::CONNECTED | Self::DISCONNECTED | Self::RECEIVED | Self::NOTIFY;

    /// Clears `bit`. Masking with `ALL & !bit` also zeroes bits 4-7, which can
    /// only be set through `From<u8>`.
    #[inline]
    fn clear(&mut self, bit: u8) {
        self.0 &= Self::ALL & !bit;
    }

    /// Whether `bit` is set.
    #[inline]
    const fn is_set(self, bit: u8) -> bool {
        self.0 & bit != 0
    }

    /// Clears the `connected` flag.
    #[inline]
    pub fn disable_connected(&mut self) {
        self.clear(Self::CONNECTED)
    }

    /// Clears the `disconnected` flag.
    #[inline]
    pub fn disable_disconnected(&mut self) {
        self.clear(Self::DISCONNECTED)
    }

    /// Clears the `received` flag.
    #[inline]
    pub fn disable_received(&mut self) {
        self.clear(Self::RECEIVED)
    }

    /// Clears the `notify` flag.
    #[inline]
    pub fn disable_notify(&mut self) {
        self.clear(Self::NOTIFY)
    }

    /// Sets all four flags; bits 4-7 are left untouched.
    #[inline]
    pub fn enable_all(&mut self) {
        self.0 |= Self::ALL
    }

    /// Clears every bit.
    #[inline]
    pub fn disable_all(&mut self) {
        self.0 = 0
    }

    /// Whether the `connected` flag is set.
    #[inline]
    pub const fn connected(self) -> bool {
        self.is_set(Self::CONNECTED)
    }

    /// Whether the `disconnected` flag is set.
    #[inline]
    pub const fn disconnected(self) -> bool {
        self.is_set(Self::DISCONNECTED)
    }

    /// Whether the `received` flag is set.
    #[inline]
    pub const fn received(self) -> bool {
        self.is_set(Self::RECEIVED)
    }

    /// Whether the `notify` flag is set.
    #[inline]
    pub const fn notify(self) -> bool {
        self.is_set(Self::NOTIFY)
    }
}

impl Default for BlockingFlag {
    /// All four flags set.
    fn default() -> Self {
        Self(Self::ALL)
    }
}

impl From<u8> for BlockingFlag {
    /// Uses the raw byte as-is, including any bits above the low nibble.
    fn from(inner: u8) -> BlockingFlag {
        Self(inner)
    }
}
/// Lifecycle state tracked by the service.
///
/// * `Running(n)` carries a counter adjusted by `increase` / `decrease`; the
///   state counts as shut down once the counter is zero. What it counts is
///   decided by the callers (presumably open listeners/sessions — confirm).
/// * `Forever` never reports shutdown until `pre_shutdown` is called.
/// * `PreShutdown` always reports shutdown and ignores counter updates.
#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)]
pub enum State {
    Running(usize),
    Forever,
    PreShutdown,
}

impl State {
    /// `Forever` when `forever` is true, otherwise `Running(0)`.
    pub fn new(forever: bool) -> Self {
        if forever {
            State::Forever
        } else {
            State::Running(0)
        }
    }

    /// True for `Running(0)` and for `PreShutdown`.
    #[inline]
    pub fn is_shutdown(&self) -> bool {
        matches!(self, State::Running(0) | State::PreShutdown)
    }

    /// Switches to `PreShutdown` unconditionally.
    #[inline]
    pub fn pre_shutdown(&mut self) {
        *self = State::PreShutdown
    }

    /// Increments the `Running` counter; no-op in the other states.
    #[inline]
    pub fn increase(&mut self) {
        match self {
            State::Running(num) => *num += 1,
            State::PreShutdown | State::Forever => (),
        }
    }

    /// Decrements the `Running` counter, saturating at zero; no-op in the other states.
    ///
    /// Saturating matters: a plain `-= 1` on `Running(0)` panics in debug builds
    /// and wraps to `usize::MAX` in release builds. After such a wrap,
    /// `is_shutdown` could never become true again.
    #[inline]
    pub fn decrease(&mut self) {
        match self {
            State::Running(num) => *num = num.saturating_sub(1),
            State::PreShutdown | State::Forever => (),
        }
    }

    /// The `Running` counter, or `None` for `Forever` / `PreShutdown`.
    #[inline]
    pub fn into_inner(self) -> Option<usize> {
        match self {
            State::Running(num) => Some(num),
            State::PreShutdown | State::Forever => None,
        }
    }
}
#[cfg(test)]
mod test {
    use super::{BlockingFlag, State};

    /// Counter tracking in non-forever mode, then transition to `PreShutdown`.
    #[test]
    fn test_state_no_forever() {
        let mut state = State::new(false);
        state.increase();
        state.increase();
        assert_eq!(state, State::Running(2));
        state.decrease();
        state.decrease();
        assert_eq!(state, State::Running(0));
        state.increase();
        state.increase();
        state.increase();
        state.increase();
        state.pre_shutdown();
        assert_eq!(state, State::PreShutdown);
    }

    /// `Forever` ignores counter updates until `pre_shutdown`.
    #[test]
    fn test_state_forever() {
        let mut state = State::new(true);
        state.increase();
        state.increase();
        assert_eq!(state, State::Forever);
        state.decrease();
        state.decrease();
        assert_eq!(state, State::Forever);
        state.increase();
        state.increase();
        state.increase();
        state.increase();
        state.pre_shutdown();
        assert_eq!(state, State::PreShutdown);
    }

    /// Each flag can be cleared individually; `enable_all` / `disable_all` set and clear everything.
    #[test]
    fn test_proto_flag() {
        let mut p = BlockingFlag::default();
        assert!(p.connected());
        assert!(p.disconnected());
        assert!(p.received());
        assert!(p.notify());
        p.disable_connected();
        assert!(!p.connected());
        p.disable_disconnected();
        assert!(!p.disconnected());
        p.disable_received();
        assert!(!p.received());
        p.disable_notify();
        assert!(!p.notify());
        p.enable_all();
        p.disable_all();
        assert!(!p.connected());
        assert!(!p.disconnected());
        assert!(!p.received());
        assert!(!p.notify());
    }
}