tcp_channel_server/
peer.rs

1use crate::error::Result;
2use async_channel::Sender;
3use std::io::ErrorKind;
4use std::marker::PhantomData;
5use std::net::SocketAddr;
6use std::sync::atomic::{AtomicBool, Ordering};
7use std::sync::Arc;
8use tokio::io::WriteHalf;
9use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
10
/// Commands consumed by the background writer task spawned in `TCPPeer::new`.
pub enum State {
    /// Shut down the write half of the stream and stop the writer task.
    Disconnect,
    /// Write the given bytes to the stream without an explicit flush.
    Send(Vec<u8>),
    /// Write the given bytes to the stream, then flush it.
    SendFlush(Vec<u8>),
    /// Flush the stream without writing any new data.
    Flush,
}
17
/// Handle to one connection whose writes are performed by a background task.
///
/// All outgoing operations are turned into [`State`] messages and pushed
/// through `sender`; the task spawned in `TCPPeer::new` applies them to the
/// write half of the stream in the order they were queued.
pub struct TCPPeer<T> {
    /// Socket address supplied to `TCPPeer::new` for this peer.
    pub addr: SocketAddr,
    /// Sending side of the command channel read by the writer task.
    pub sender: Sender<State>,
    /// Set to `true` once `disconnect()` has queued a `State::Disconnect`.
    disconnect: AtomicBool,
    /// Zero-sized marker carrying the stream type `T`; no `T` value is stored.
    _ph: PhantomData<T>,
}
24
25impl<T> TCPPeer<T>
26where
27    T: AsyncRead + AsyncWrite + Send + 'static,
28{
29    /// 创建一个TCP PEER
30    #[inline]
31    pub fn new(addr: SocketAddr, mut sender: WriteHalf<T>) -> Arc<TCPPeer<T>> {
32        let (tx, rx) = async_channel::bounded(4096);
33
34        tokio::spawn(async move {
35            while let Ok(state) = rx.recv().await {
36                match state {
37                    State::Disconnect => {
38                        let _ = sender.shutdown().await;
39                        return;
40                    }
41                    State::Send(data) => {
42                        if sender.write(&data).await.is_err() {
43                            return;
44                        }
45                    }
46                    State::SendFlush(data) => {
47                        if sender.write(&data).await.is_err() {
48                            return;
49                        }
50                        if sender.flush().await.is_err() {
51                            return;
52                        }
53                    }
54                    State::Flush => {
55                        if sender.flush().await.is_err() {
56                            return;
57                        }
58                    }
59                }
60            }
61        });
62
63        Arc::new(TCPPeer {
64            addr,
65            sender: tx,
66            disconnect: AtomicBool::new(false),
67            _ph: Default::default(),
68        })
69    }
70
71    /// ipaddress
72    #[inline]
73    pub fn addr(&self) -> SocketAddr {
74        self.addr
75    }
76
77    /// 是否断线
78    #[inline]
79    pub fn is_disconnect(&self) -> bool {
80        self.disconnect.load(Ordering::Acquire)
81    }
82
83    /// 发送
84    #[inline]
85    pub async fn send(&self, buff: Vec<u8>) -> Result<()> {
86        if !self.disconnect.load(Ordering::Acquire) {
87            Ok(self.sender.clone().send(State::Send(buff)).await?)
88        } else {
89            Err(std::io::Error::from(ErrorKind::ConnectionReset).into())
90        }
91    }
92
93    /// 发送全部
94    #[inline]
95    pub async fn send_all(&self, buff: Vec<u8>) -> Result<()> {
96        if !self.disconnect.load(Ordering::Acquire) {
97            Ok(self.sender.clone().send(State::SendFlush(buff)).await?)
98        } else {
99            Err(std::io::Error::from(ErrorKind::ConnectionReset).into())
100        }
101    }
102
103    /// flush
104    #[inline]
105    pub async fn flush(&self) -> Result<()> {
106        if !self.disconnect.load(Ordering::Acquire) {
107            Ok(self.sender.send(State::Flush).await?)
108        } else {
109            Err(std::io::Error::from(ErrorKind::ConnectionReset).into())
110        }
111    }
112
113    /// 掐线
114    #[inline]
115    pub async fn disconnect(&self) -> Result<()> {
116        if !self.disconnect.load(Ordering::Acquire) {
117            self.sender.send(State::Disconnect).await?;
118            self.disconnect.store(true, Ordering::Release);
119        }
120        Ok(())
121    }
122}