tcp_channel_server/
peer.rs1use crate::error::Result;
2use async_channel::Sender;
3use std::io::ErrorKind;
4use std::marker::PhantomData;
5use std::net::SocketAddr;
6use std::sync::atomic::{AtomicBool, Ordering};
7use std::sync::Arc;
8use tokio::io::WriteHalf;
9use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
10
/// Command sent to a peer's background writer task.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum State {
    /// Shut down the write half; the writer task stops afterwards.
    Disconnect,
    /// Write the bytes without flushing.
    Send(Vec<u8>),
    /// Write the bytes, then flush.
    SendFlush(Vec<u8>),
    /// Flush any buffered output.
    Flush,
}
17
18pub struct TCPPeer<T> {
19 pub addr: SocketAddr,
20 pub sender: Sender<State>,
21 disconnect: AtomicBool,
22 _ph: PhantomData<T>,
23}
24
25impl<T> TCPPeer<T>
26where
27 T: AsyncRead + AsyncWrite + Send + 'static,
28{
29 #[inline]
31 pub fn new(addr: SocketAddr, mut sender: WriteHalf<T>) -> Arc<TCPPeer<T>> {
32 let (tx, rx) = async_channel::bounded(4096);
33
34 tokio::spawn(async move {
35 while let Ok(state) = rx.recv().await {
36 match state {
37 State::Disconnect => {
38 let _ = sender.shutdown().await;
39 return;
40 }
41 State::Send(data) => {
42 if sender.write(&data).await.is_err() {
43 return;
44 }
45 }
46 State::SendFlush(data) => {
47 if sender.write(&data).await.is_err() {
48 return;
49 }
50 if sender.flush().await.is_err() {
51 return;
52 }
53 }
54 State::Flush => {
55 if sender.flush().await.is_err() {
56 return;
57 }
58 }
59 }
60 }
61 });
62
63 Arc::new(TCPPeer {
64 addr,
65 sender: tx,
66 disconnect: AtomicBool::new(false),
67 _ph: Default::default(),
68 })
69 }
70
71 #[inline]
73 pub fn addr(&self) -> SocketAddr {
74 self.addr
75 }
76
77 #[inline]
79 pub fn is_disconnect(&self) -> bool {
80 self.disconnect.load(Ordering::Acquire)
81 }
82
83 #[inline]
85 pub async fn send(&self, buff: Vec<u8>) -> Result<()> {
86 if !self.disconnect.load(Ordering::Acquire) {
87 Ok(self.sender.clone().send(State::Send(buff)).await?)
88 } else {
89 Err(std::io::Error::from(ErrorKind::ConnectionReset).into())
90 }
91 }
92
93 #[inline]
95 pub async fn send_all(&self, buff: Vec<u8>) -> Result<()> {
96 if !self.disconnect.load(Ordering::Acquire) {
97 Ok(self.sender.clone().send(State::SendFlush(buff)).await?)
98 } else {
99 Err(std::io::Error::from(ErrorKind::ConnectionReset).into())
100 }
101 }
102
103 #[inline]
105 pub async fn flush(&self) -> Result<()> {
106 if !self.disconnect.load(Ordering::Acquire) {
107 Ok(self.sender.send(State::Flush).await?)
108 } else {
109 Err(std::io::Error::from(ErrorKind::ConnectionReset).into())
110 }
111 }
112
113 #[inline]
115 pub async fn disconnect(&self) -> Result<()> {
116 if !self.disconnect.load(Ordering::Acquire) {
117 self.sender.send(State::Disconnect).await?;
118 self.disconnect.store(true, Ordering::Release);
119 }
120 Ok(())
121 }
122}