tcp_handler/streams/
mod.rs

//! Ready-to-use [`TcpHandler`] implementations.
//!
//! These structs wrap the functions in [`crate::protocols`].
4
5pub mod raw;
6#[cfg(feature = "compression")]
7#[cfg_attr(docsrs, doc(cfg(feature = "compression")))]
8pub mod compress;
9#[cfg(feature = "encryption")]
10#[cfg_attr(docsrs, doc(cfg(feature = "encryption")))]
11pub mod encrypt;
12#[cfg(feature = "compress_encryption")]
13#[cfg_attr(docsrs, doc(cfg(feature = "compress_encryption")))]
14pub mod compress_encrypt;
15
16use async_trait::async_trait;
17use bytes::{Buf, BytesMut};
18use crate::protocols::common::PacketError;
19
20/// The handler trait, providing send and receive methods.
21///
22/// This trait uses [`async_trait`] so the parameters require [`Send`] mark.
23#[async_trait]
24pub trait TcpHandler {
25    /// Send a message to the remote.
26    async fn handler_send<B: Buf + Send>(&mut self, message: &mut B) -> Result<(), PacketError>;
27
28    /// Receive a message from the remote.
29    async fn handler_recv(&mut self) -> Result<BytesMut, PacketError>;
30
31    /// Send and receive a message.
32    #[inline]
33    async fn handler_send_recv<B: Buf + Send>(&mut self, message: &mut B) -> Result<BytesMut, PacketError> {
34        self.handler_send(message).await?;
35        self.handler_recv().await
36    }
37}
38
/// Implements [`TcpHandler`] and the matching convenience constructors for a handler
/// struct that is generic over exactly `<R, W>` (a reader and a writer).
///
/// * `impl_tcp_handler!(@ Name)` — only the [`TcpHandler`] impl, delegating to the
///   struct's inherent `send` / `recv` methods.
/// * `impl_tcp_handler!(server Name)` — the `@` impl, plus `get_client_version`
///   (reads the struct's `version` field) and, with the `stream_net` feature,
///   `from_stream`.
/// * `impl_tcp_handler!(client Name)` — the `@` impl, plus, with the `stream_net`
///   feature, `from_stream` and `connect`.
///
/// The expansion calls the struct's inherent `new` constructor; its exact parameters
/// are defined by the invoking module.
macro_rules! impl_tcp_handler {
    (@ $struct: ident) => {
        #[::async_trait::async_trait]
        impl<R: ::tokio::io::AsyncRead + Unpin + Send, W: ::tokio::io::AsyncWrite + Unpin + Send> $crate::streams::TcpHandler for $struct<R, W> {
            #[inline]
            async fn handler_send<B: ::bytes::Buf + Send>(&mut self, message: &mut B) -> Result<(), $crate::protocols::common::PacketError> {
                self.send(message).await
            }

            #[inline]
            async fn handler_recv(&mut self) -> Result<::bytes::BytesMut, $crate::protocols::common::PacketError> {
                self.recv().await
            }
        }
    };
    (server $server: ident) => {
        impl_tcp_handler!(@ $server);

        impl<R: ::tokio::io::AsyncRead + Unpin, W: ::tokio::io::AsyncWrite + Unpin> $server<R, W> {
            /// Get the client's application version.
            #[inline]
            pub fn get_client_version(&self) -> &str {
                &self.version
            }
        }

        #[cfg(feature = "stream_net")]
        impl $server<::tokio::io::BufReader<::tokio::net::tcp::OwnedReadHalf>, ::tokio::io::BufWriter<::tokio::net::tcp::OwnedWriteHalf>> {
            #[cfg_attr(docsrs, doc(cfg(feature = "stream_net")))]
            #[doc = concat!("Construct the `", stringify!($server), "` from [`tokio::net::TcpStream`].")]
            pub async fn from_stream<P: FnOnce(&str) -> bool>(stream: ::tokio::net::TcpStream, identifier: &str, version_prediction: P, version: &str) -> Result<Self, $crate::protocols::common::StarterError> {
                // Split into owned read/write halves and buffer each one before handing
                // them to the regular constructor.
                let (reader, writer) = stream.into_split();
                let reader = ::tokio::io::BufReader::new(reader);
                let writer = ::tokio::io::BufWriter::new(writer);
                Self::new(reader, writer, identifier, version_prediction, version).await
            }
        }
    };
    (client $client: ident) => {
        impl_tcp_handler!(@ $client);

        // `connect` and `from_stream` target the same concrete stream types, so they
        // live in a single impl block (mirroring the server arm).
        #[cfg(feature = "stream_net")]
        impl $client<::tokio::io::BufReader<::tokio::net::tcp::OwnedReadHalf>, ::tokio::io::BufWriter<::tokio::net::tcp::OwnedWriteHalf>> {
            #[cfg_attr(docsrs, doc(cfg(feature = "stream_net")))]
            #[doc = concat!("Connect to `addr` and construct the `", stringify!($client), "` using [`", stringify!($client), "::from_stream`].")]
            pub async fn connect<A: ::tokio::net::ToSocketAddrs>(addr: A, identifier: &str, version: &str) -> Result<Self, $crate::protocols::common::StarterError> {
                let stream = ::tokio::net::TcpStream::connect(addr).await?;
                Self::from_stream(stream, identifier, version).await
            }

            #[cfg_attr(docsrs, doc(cfg(feature = "stream_net")))]
            #[doc = concat!("Construct the `", stringify!($client), "` from [`tokio::net::TcpStream`].")]
            pub async fn from_stream(stream: ::tokio::net::TcpStream, identifier: &str, version: &str) -> Result<Self, $crate::protocols::common::StarterError> {
                // Split into owned read/write halves and buffer each one before handing
                // them to the regular constructor.
                let (reader, writer) = stream.into_split();
                let reader = ::tokio::io::BufReader::new(reader);
                let writer = ::tokio::io::BufWriter::new(writer);
                Self::new(reader, writer, identifier, version).await
            }
        }
    };
}
101use impl_tcp_handler;
102
#[cfg(test)]
mod tests {
    use anyhow::Result;
    use tokio::io::{AsyncRead, AsyncWrite, duplex, split};

    /// Builds an in-memory connection for handler tests.
    ///
    /// Returns `(client_reader, client_writer, server_reader, server_writer)`, all backed by
    /// one `tokio::io::duplex` pair created with a 1024-byte maximum buffer size. Bytes
    /// written on one side's writer are read from the other side's reader.
    pub async fn create() -> Result<(impl AsyncRead + Unpin, impl AsyncWrite + Unpin, impl AsyncRead + Unpin, impl AsyncWrite + Unpin)> {
        let (client_end, server_end) = duplex(1024);
        let (client_reader, client_writer) = split(client_end);
        let (server_reader, server_writer) = split(server_end);
        Ok((client_reader, client_writer, server_reader, server_writer))
    }

    /// Round-trip check.
    ///
    /// Encodes `$msg` as a string into a fresh buffer, sends it via `$sender`, receives on
    /// `$receiver`, decodes a string and asserts it equals `$msg`.
    ///
    /// Uses `?` and `.await`, so it must be expanded inside an async function returning a
    /// compatible `Result`. The caller must have in scope the traits providing `writer` /
    /// `reader` (from `bytes`) and `write_string` / `read_string`.
    macro_rules! check_send_recv {
        ($sender: expr, $receiver: expr, $msg: literal) => { {
            let mut encoder = ::bytes::BytesMut::new().writer();
            encoder.write_string($msg)?;
            let mut outgoing = encoder.into_inner();
            $sender.send(&mut outgoing).await?;

            let incoming = $receiver.recv().await?;
            let mut decoder = incoming.reader();
            let echoed = decoder.read_string()?;
            assert_eq!($msg, echoed);
        } };
    }
    pub(crate) use check_send_recv;
}