websocket_server_async/
websocket_server.rs

1use crate::peer::{IPeer, WSPeer};
2use crate::stream::MaybeRustlsStream;
3use anyhow::{bail, Result};
4use aqueue::Actor;
5use futures_util::stream::SplitStream;
6use futures_util::StreamExt;
7use log::*;
8use std::future::Future;
9use std::marker::PhantomData;
10use std::net::SocketAddr;
11use std::sync::Arc;
12use std::time::Duration;
13use tokio::io::{AsyncRead, AsyncWrite};
14use tokio::net::{TcpListener, TcpStream, ToSocketAddrs};
15use tokio::task::JoinHandle;
16use tokio_rustls::TlsAcceptor;
17use tokio_tungstenite::tungstenite::protocol::WebSocketConfig;
18use tokio_tungstenite::{accept_async_with_config, WebSocketStream};
19
/// Connection filter consulted with each client's address right after it is
/// accepted: return `true` to serve the client, `false` to drop it.
pub type ConnectEventType = fn(addr: SocketAddr) -> bool;
21
22/// websocket server
23pub struct WebSocketServer<I, R, T> {
24    listener: Option<TcpListener>,
25    tls_acceptor: Option<TlsAcceptor>,
26    connect_event: Option<ConnectEventType>,
27    input_event: Arc<I>,
28    config: Option<WebSocketConfig>,
29    load_timeout_secs: u64,
30    _phantom1: PhantomData<R>,
31    _phantom2: PhantomData<T>,
32}
33
34unsafe impl<I, R, T> Send for WebSocketServer<I, R, T> {}
35unsafe impl<I, R, T> Sync for WebSocketServer<I, R, T> {}
36
37impl<I, R, T> WebSocketServer<I, R, T>
38where
39    I: Fn(SplitStream<WebSocketStream<MaybeRustlsStream<TcpStream>>>, Arc<Actor<WSPeer>>, T) -> R
40        + Send
41        + Sync
42        + 'static,
43    R: Future<Output = Result<()>> + Send + 'static,
44    T: Clone + Send + 'static,
45{
46    /// 创建一个websocket server
47    pub(crate) async fn new<A: ToSocketAddrs>(
48        addr: A,
49        input: I,
50        connect_event: Option<ConnectEventType>,
51        config: Option<WebSocketConfig>,
52        tls_acceptor: Option<TlsAcceptor>,
53        load_timeout_secs: u64,
54    ) -> Result<Arc<Actor<WebSocketServer<I, R, T>>>> {
55        let listener = TcpListener::bind(addr).await?;
56        Ok(Arc::new(Actor::new(WebSocketServer {
57            listener: Some(listener),
58            tls_acceptor,
59            connect_event,
60            input_event: Arc::new(input),
61            config,
62            load_timeout_secs,
63            _phantom1: Default::default(),
64            _phantom2: Default::default(),
65        })))
66    }
67
68    #[inline]
69    async fn accept<S>(stream: S, tls_acceptor: Option<TlsAcceptor>) -> Result<MaybeRustlsStream<S>>
70    where
71        S: AsyncRead + AsyncWrite + Unpin,
72    {
73        if let Some(acceptor) = tls_acceptor {
74            Ok(MaybeRustlsStream::ServerTls(acceptor.accept(stream).await?))
75        } else {
76            Ok(MaybeRustlsStream::Plain(stream))
77        }
78    }
79
80    /// 启动websocket server
81    pub async fn start(&mut self, token: T) -> Result<JoinHandle<Result<()>>> {
82        if let Some(listener) = self.listener.take() {
83            let connect_event = self.connect_event.take();
84            let input_event = self.input_event.clone();
85            let config = self.config;
86            let load_timeout_secs = self.load_timeout_secs;
87            let tls = self.tls_acceptor.clone();
88            let join: JoinHandle<Result<()>> = tokio::spawn(async move {
89                loop {
90                    let (socket, addr) = listener.accept().await?;
91                    if let Some(ref connect_event) = connect_event {
92                        if !connect_event(addr) {
93                            warn!("addr:{} not connect", addr);
94                            continue;
95                        }
96                    }
97                    trace!("start read:{}", addr);
98                    let input = input_event.clone();
99                    let peer_token = token.clone();
100                    let tls_acceptor = tls.clone();
101                    tokio::spawn(async move {
102                        let socket = match Self::accept(socket, tls_acceptor).await {
103                            Ok(socket) => socket,
104                            Err(err) => {
105                                error!("rustls error:{}", err);
106                                return;
107                            }
108                        };
109
110                        match tokio::time::timeout(
111                            Duration::from_secs(load_timeout_secs),
112                            accept_async_with_config(socket, config),
113                        )
114                        .await
115                        {
116                            Ok(Ok(ws_stream)) => {
117                                let (sender, reader) = ws_stream.split();
118                                let peer = WSPeer::new(addr, sender);
119                                if let Err(err) = (*input)(reader, peer.clone(), peer_token).await {
120                                    error!("input data error:{}", err);
121                                }
122                                if let Err(er) = peer.disconnect().await {
123                                    debug!("disconnect client:{:?} err:{}", peer.addr(), er);
124                                } else {
125                                    debug!("{} disconnect", peer.addr())
126                                }
127                            }
128                            Ok(Err(err)) => {
129                                error!(
130                                    "ipaddress:{} init websocket error:{:?} disconnect!",
131                                    addr, err
132                                );
133                            }
134                            Err(_) => {
135                                error!(
136                                    "ipaddress:{} accept websocket init  timeout disconnect!",
137                                    addr
138                                );
139                            }
140                        }
141                    });
142                }
143            });
144
145            return Ok(join);
146        }
147        bail!("not listener or repeat start")
148    }
149}
150
151#[async_trait::async_trait]
152pub trait IWebSocketServer<T> {
153    async fn start(&self, token: T) -> Result<JoinHandle<Result<()>>>;
154    async fn start_block(&self, token: T) -> Result<()>;
155}
156
157#[async_trait::async_trait]
158impl<I, R, T> IWebSocketServer<T> for Actor<WebSocketServer<I, R, T>>
159where
160    I: Fn(SplitStream<WebSocketStream<MaybeRustlsStream<TcpStream>>>, Arc<Actor<WSPeer>>, T) -> R
161        + Send
162        + Sync
163        + 'static,
164    R: Future<Output = Result<()>> + Send + 'static,
165    T: Clone + Send + Sync + 'static,
166{
167    async fn start(&self, token: T) -> Result<JoinHandle<Result<()>>> {
168        self.inner_call(|inner| async move { inner.get_mut().start(token).await })
169            .await
170    }
171
172    async fn start_block(&self, token: T) -> Result<()> {
173        Self::start(self, token).await?.await??;
174        Ok(())
175    }
176}