tcp_channel_client/
lib.rs1pub mod error;
2use async_channel::Sender;
3use error::{Error, Result};
4use log::*;
5use std::future::Future;
6use std::marker::PhantomData;
7use std::net::SocketAddr;
8use std::sync::atomic::{AtomicBool, Ordering};
9use std::sync::Arc;
10use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, ReadHalf};
11use tokio::net::{TcpStream, ToSocketAddrs};
12
13pub struct TcpClient<T> {
14 disconnect: AtomicBool,
15 sender: Sender<State>,
16 _ph: PhantomData<T>,
17}
18
19impl TcpClient<TcpStream> {
20 #[inline]
21 pub async fn connect<
22 T: ToSocketAddrs,
23 F: Future<Output = anyhow::Result<bool>> + Send + 'static,
24 A: Send + 'static,
25 >(
26 addr: T,
27 input: impl FnOnce(A, Arc<TcpClient<TcpStream>>, ReadHalf<TcpStream>) -> F + Send + 'static,
28 token: A,
29 ) -> Result<Arc<TcpClient<TcpStream>>> {
30 let stream = TcpStream::connect(addr).await?;
31 let target = stream.peer_addr()?;
32 Self::init(input, token, stream, target)
33 }
34}
35
/// Command passed from a client handle to its background writer task.
pub enum State {
    /// Shut down the write half of the stream and stop the writer task.
    Disconnect,
    /// Write the given bytes without flushing.
    Send(Vec<u8>),
    /// Write the given bytes, then flush the stream.
    SendFlush(Vec<u8>),
    /// Flush any buffered output.
    Flush,
}
42
43impl<T> TcpClient<T>
44where
45 T: AsyncRead + AsyncWrite + Send + Sync + 'static,
46{
47 #[inline]
48 pub async fn connect_stream_type<
49 H: ToSocketAddrs,
50 F: Future<Output = anyhow::Result<bool>> + Send + 'static,
51 S: Future<Output = anyhow::Result<T>> + Send + 'static,
52 A: Send + 'static,
53 >(
54 addr: H,
55 stream_init: impl FnOnce(TcpStream) -> S + Send + 'static,
56 input: impl FnOnce(A, Arc<TcpClient<T>>, ReadHalf<T>) -> F + Send + 'static,
57 token: A,
58 ) -> Result<Arc<TcpClient<T>>> {
59 let stream = TcpStream::connect(addr).await?;
60 let target = stream.peer_addr()?;
61 let stream = stream_init(stream).await?;
62 Self::init(input, token, stream, target)
63 }
64
65 #[inline]
66 fn init<F: Future<Output = anyhow::Result<bool>> + Send + 'static, A: Send + 'static>(
67 f: impl FnOnce(A, Arc<TcpClient<T>>, ReadHalf<T>) -> F + Send + 'static,
68 token: A,
69 stream: T,
70 target: SocketAddr,
71 ) -> Result<Arc<TcpClient<T>>> {
72 let (reader, mut sender) = tokio::io::split(stream);
73
74 let (tx, rx) = async_channel::bounded(4096);
75
76 let client = Arc::new(TcpClient {
77 disconnect: AtomicBool::new(false),
78 sender: tx,
79 _ph: Default::default(),
80 });
81 let read_client = client.clone();
82 tokio::spawn(async move {
83 let disconnect_client = read_client.clone();
84 let need_disconnect = f(token, read_client, reader).await.unwrap_or_else(|err| {
85 error!("reader error:{}", err);
86 true
87 });
88
89 if need_disconnect {
90 if let Err(er) = disconnect_client.disconnect().await {
91 error!("disconnect to{} err:{}", target, er);
92 } else {
93 debug!("disconnect to {}", target)
94 }
95 } else {
96 debug!("{} reader is close", target);
97 }
98 });
99
100 tokio::spawn(async move {
101 loop {
102 if let Ok(state) = rx.recv().await {
103 match state {
104 State::Disconnect => {
105 let _ = sender.shutdown().await;
106 return;
107 }
108 State::Send(data) => {
109 if sender.write(&data).await.is_err() {
110 return;
111 }
112 }
113 State::SendFlush(data) => {
114 if sender.write(&data).await.is_err() {
115 return;
116 }
117
118 if sender.flush().await.is_err() {
119 return;
120 }
121 }
122 State::Flush => {
123 if sender.flush().await.is_err() {
124 return;
125 }
126 }
127 }
128 } else {
129 return;
130 }
131 }
132 });
133
134 Ok(client)
135 }
136
137 #[inline]
138 pub async fn disconnect(&self) -> Result<()> {
139 if !self.disconnect.load(Ordering::Acquire) {
140 self.sender.send(State::Disconnect).await?;
141 self.disconnect.store(true, Ordering::Release);
142 }
143 Ok(())
144 }
145 #[inline]
146 pub async fn send(&self, buff: Vec<u8>) -> Result<()> {
147 if !self.disconnect.load(Ordering::Acquire) {
148 Ok(self.sender.send(State::Send(buff)).await?)
149 } else {
150 Err(Error::SendError("Disconnect".to_string()))
151 }
152 }
153
154 #[inline]
155 pub async fn send_all(&self, buff: Vec<u8>) -> Result<()> {
156 if !self.disconnect.load(Ordering::Acquire) {
157 Ok(self.sender.send(State::SendFlush(buff)).await?)
158 } else {
159 Err(Error::SendError("Disconnect".to_string()))
160 }
161 }
162
163 #[inline]
164 pub async fn flush(&self) -> Result<()> {
165 if !self.disconnect.load(Ordering::Acquire) {
166 Ok(self.sender.send(State::Flush).await?)
167 } else {
168 Err(Error::SendError("Disconnect".to_string()))
169 }
170 }
171}