tcp_channel_client/
lib.rs

1pub mod error;
2use async_channel::Sender;
3use error::{Error, Result};
4use log::*;
5use std::future::Future;
6use std::marker::PhantomData;
7use std::net::SocketAddr;
8use std::sync::atomic::{AtomicBool, Ordering};
9use std::sync::Arc;
10use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, ReadHalf};
11use tokio::net::{TcpStream, ToSocketAddrs};
12
13pub struct TcpClient<T> {
14    disconnect: AtomicBool,
15    sender: Sender<State>,
16    _ph: PhantomData<T>,
17}
18
19impl TcpClient<TcpStream> {
20    #[inline]
21    pub async fn connect<
22        T: ToSocketAddrs,
23        F: Future<Output = anyhow::Result<bool>> + Send + 'static,
24        A: Send + 'static,
25    >(
26        addr: T,
27        input: impl FnOnce(A, Arc<TcpClient<TcpStream>>, ReadHalf<TcpStream>) -> F + Send + 'static,
28        token: A,
29    ) -> Result<Arc<TcpClient<TcpStream>>> {
30        let stream = TcpStream::connect(addr).await?;
31        let target = stream.peer_addr()?;
32        Self::init(input, token, stream, target)
33    }
34}
35
/// Command queued for a client's background writer task.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum State {
    /// Shut down the write half of the stream and stop the writer task.
    Disconnect,
    /// Write the bytes without an explicit flush.
    Send(Vec<u8>),
    /// Write the bytes, then flush the stream.
    SendFlush(Vec<u8>),
    /// Flush any buffered output.
    Flush,
}
42
43impl<T> TcpClient<T>
44where
45    T: AsyncRead + AsyncWrite + Send + Sync + 'static,
46{
47    #[inline]
48    pub async fn connect_stream_type<
49        H: ToSocketAddrs,
50        F: Future<Output = anyhow::Result<bool>> + Send + 'static,
51        S: Future<Output = anyhow::Result<T>> + Send + 'static,
52        A: Send + 'static,
53    >(
54        addr: H,
55        stream_init: impl FnOnce(TcpStream) -> S + Send + 'static,
56        input: impl FnOnce(A, Arc<TcpClient<T>>, ReadHalf<T>) -> F + Send + 'static,
57        token: A,
58    ) -> Result<Arc<TcpClient<T>>> {
59        let stream = TcpStream::connect(addr).await?;
60        let target = stream.peer_addr()?;
61        let stream = stream_init(stream).await?;
62        Self::init(input, token, stream, target)
63    }
64
65    #[inline]
66    fn init<F: Future<Output = anyhow::Result<bool>> + Send + 'static, A: Send + 'static>(
67        f: impl FnOnce(A, Arc<TcpClient<T>>, ReadHalf<T>) -> F + Send + 'static,
68        token: A,
69        stream: T,
70        target: SocketAddr,
71    ) -> Result<Arc<TcpClient<T>>> {
72        let (reader, mut sender) = tokio::io::split(stream);
73
74        let (tx, rx) = async_channel::bounded(4096);
75
76        let client = Arc::new(TcpClient {
77            disconnect: AtomicBool::new(false),
78            sender: tx,
79            _ph: Default::default(),
80        });
81        let read_client = client.clone();
82        tokio::spawn(async move {
83            let disconnect_client = read_client.clone();
84            let need_disconnect = f(token, read_client, reader).await.unwrap_or_else(|err| {
85                error!("reader error:{}", err);
86                true
87            });
88
89            if need_disconnect {
90                if let Err(er) = disconnect_client.disconnect().await {
91                    error!("disconnect to{} err:{}", target, er);
92                } else {
93                    debug!("disconnect to {}", target)
94                }
95            } else {
96                debug!("{} reader is close", target);
97            }
98        });
99
100        tokio::spawn(async move {
101            loop {
102                if let Ok(state) = rx.recv().await {
103                    match state {
104                        State::Disconnect => {
105                            let _ = sender.shutdown().await;
106                            return;
107                        }
108                        State::Send(data) => {
109                            if sender.write(&data).await.is_err() {
110                                return;
111                            }
112                        }
113                        State::SendFlush(data) => {
114                            if sender.write(&data).await.is_err() {
115                                return;
116                            }
117
118                            if sender.flush().await.is_err() {
119                                return;
120                            }
121                        }
122                        State::Flush => {
123                            if sender.flush().await.is_err() {
124                                return;
125                            }
126                        }
127                    }
128                } else {
129                    return;
130                }
131            }
132        });
133
134        Ok(client)
135    }
136
137    #[inline]
138    pub async fn disconnect(&self) -> Result<()> {
139        if !self.disconnect.load(Ordering::Acquire) {
140            self.sender.send(State::Disconnect).await?;
141            self.disconnect.store(true, Ordering::Release);
142        }
143        Ok(())
144    }
145    #[inline]
146    pub async fn send(&self, buff: Vec<u8>) -> Result<()> {
147        if !self.disconnect.load(Ordering::Acquire) {
148            Ok(self.sender.send(State::Send(buff)).await?)
149        } else {
150            Err(Error::SendError("Disconnect".to_string()))
151        }
152    }
153
154    #[inline]
155    pub async fn send_all(&self, buff: Vec<u8>) -> Result<()> {
156        if !self.disconnect.load(Ordering::Acquire) {
157            Ok(self.sender.send(State::SendFlush(buff)).await?)
158        } else {
159            Err(Error::SendError("Disconnect".to_string()))
160        }
161    }
162
163    #[inline]
164    pub async fn flush(&self) -> Result<()> {
165        if !self.disconnect.load(Ordering::Acquire) {
166            Ok(self.sender.send(State::Flush).await?)
167        } else {
168            Err(Error::SendError("Disconnect".to_string()))
169        }
170    }
171}