1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
//! Ready-to-use [`TcpHandler`] implementations.
//!
//! The structs in the submodules below wrap the functions in [`crate::protocols`].

pub mod raw;
#[cfg(feature = "compression")]
#[cfg_attr(docsrs, doc(cfg(feature = "compression")))]
pub mod compress;
#[cfg(feature = "encryption")]
#[cfg_attr(docsrs, doc(cfg(feature = "encryption")))]
pub mod encrypt;
#[cfg(feature = "compress_encryption")]
#[cfg_attr(docsrs, doc(cfg(feature = "compress_encryption")))]
pub mod compress_encrypt;

use async_trait::async_trait;
use bytes::{Buf, BytesMut};
use crate::protocols::common::PacketError;

/// A connection handler that can send messages to, and receive messages from, the remote.
///
/// The trait is declared with [`async_trait`], so generic parameters captured by the
/// returned futures (such as the message buffer) must be [`Send`].
#[async_trait]
pub trait TcpHandler {
    /// Send `message` to the remote.
    ///
    /// # Errors
    ///
    /// Returns a [`PacketError`] if the message could not be sent.
    async fn handler_send<B>(&mut self, message: &mut B) -> Result<(), PacketError>
    where
        B: Buf + Send;

    /// Receive the next message from the remote.
    ///
    /// # Errors
    ///
    /// Returns a [`PacketError`] if no message could be received.
    async fn handler_recv(&mut self) -> Result<BytesMut, PacketError>;

    /// Send `message`, then receive the next message from the remote.
    ///
    /// Convenience wrapper over [`TcpHandler::handler_send`] followed by
    /// [`TcpHandler::handler_recv`]. Nothing is received if sending fails.
    ///
    /// # Errors
    ///
    /// Propagates the first [`PacketError`] produced by either step.
    #[inline]
    async fn handler_send_recv<B>(&mut self, message: &mut B) -> Result<BytesMut, PacketError>
    where
        B: Buf + Send,
    {
        self.handler_send(message).await?;
        self.handler_recv().await
    }
}

/// Generate the boilerplate shared by the handler structs in this module's submodules.
///
/// Each handler is expected to be a struct generic over a reader `R` and a writer `W`
/// that provides inherent `send` / `recv` methods and an async `new` constructor.
///
/// * `@ Name` — internal arm: implements [`TcpHandler`] for `Name<R, W>` by forwarding to
///   `Name::send` / `Name::recv`.
/// * `server Name` — the `@` arm, plus `get_client_version` (returns the struct's `version`
///   field) and, with the `stream_net` feature, `from_stream`.
/// * `client Name` — the `@` arm, plus, with the `stream_net` feature, `connect` and
///   `from_stream`.
macro_rules! impl_tcp_handler {
    (@ $handler: ident) => {
        #[::async_trait::async_trait]
        impl<R: ::tokio::io::AsyncRead + Unpin + Send, W: ::tokio::io::AsyncWrite + Unpin + Send> $crate::streams::TcpHandler for $handler<R, W> {
            #[inline]
            async fn handler_send<B: ::bytes::Buf + Send>(&mut self, message: &mut B) -> Result<(), $crate::protocols::common::PacketError> {
                self.send(message).await
            }

            #[inline]
            async fn handler_recv(&mut self) -> Result<::bytes::BytesMut, $crate::protocols::common::PacketError> {
                self.recv().await
            }
        }
    };
    (server $server: ident) => {
        // Recurse through an absolute `$crate` path so the expansion does not depend on
        // the name under which the calling module imported this macro.
        $crate::streams::impl_tcp_handler!(@ $server);

        impl<R: ::tokio::io::AsyncRead + Unpin, W: ::tokio::io::AsyncWrite + Unpin> $server<R, W> {
            /// Get the client's application version.
            #[inline]
            pub fn get_client_version(&self) -> &str {
                &self.version
            }
        }

        #[cfg(feature = "stream_net")]
        impl $server<::tokio::io::BufReader<::tokio::net::tcp::OwnedReadHalf>, ::tokio::io::BufWriter<::tokio::net::tcp::OwnedWriteHalf>> {
            #[cfg_attr(docsrs, doc(cfg(feature = "stream_net")))]
            #[doc = concat!("Construct the `", stringify!($server), "` from [`tokio::net::TcpStream`].")]
            #[doc = ""]
            #[doc = concat!("The stream is split into buffered read/write halves and passed to [`", stringify!($server), "::new`], whose error is returned unchanged.")]
            pub async fn from_stream<P: FnOnce(&str) -> bool>(stream: ::tokio::net::TcpStream, identifier: &str, version_prediction: P, version: &str) -> Result<Self, $crate::protocols::common::StarterError> {
                let (reader, writer) = stream.into_split();
                let reader = ::tokio::io::BufReader::new(reader);
                let writer = ::tokio::io::BufWriter::new(writer);
                Self::new(reader, writer, identifier, version_prediction, version).await
            }
        }
    };
    (client $client: ident) => {
        // See the `server` arm for why the recursion goes through `$crate`.
        $crate::streams::impl_tcp_handler!(@ $client);

        // A single impl block for both TCP constructors: they share the same concrete
        // reader/writer types and the same feature gate.
        #[cfg(feature = "stream_net")]
        impl $client<::tokio::io::BufReader<::tokio::net::tcp::OwnedReadHalf>, ::tokio::io::BufWriter<::tokio::net::tcp::OwnedWriteHalf>> {
            #[cfg_attr(docsrs, doc(cfg(feature = "stream_net")))]
            #[doc = concat!("Connect to `addr`, then construct the `", stringify!($client), "` using [`", stringify!($client), "::from_stream`].")]
            #[doc = ""]
            #[doc = "Fails if the TCP connection cannot be established or if construction fails."]
            pub async fn connect<A: ::tokio::net::ToSocketAddrs>(addr: A, identifier: &str, version: &str) -> Result<Self, $crate::protocols::common::StarterError> {
                let stream = ::tokio::net::TcpStream::connect(addr).await?;
                Self::from_stream(stream, identifier, version).await
            }

            #[cfg_attr(docsrs, doc(cfg(feature = "stream_net")))]
            #[doc = concat!("Construct the `", stringify!($client), "` from [`tokio::net::TcpStream`].")]
            #[doc = ""]
            #[doc = concat!("The stream is split into buffered read/write halves and passed to [`", stringify!($client), "::new`], whose error is returned unchanged.")]
            pub async fn from_stream(stream: ::tokio::net::TcpStream, identifier: &str, version: &str) -> Result<Self, $crate::protocols::common::StarterError> {
                let (reader, writer) = stream.into_split();
                let reader = ::tokio::io::BufReader::new(reader);
                let writer = ::tokio::io::BufWriter::new(writer);
                Self::new(reader, writer, identifier, version).await
            }
        }
    };
}
use impl_tcp_handler;

#[cfg(test)]
mod tests {
    use anyhow::Result;
    use tokio::io::{AsyncRead, AsyncWrite, duplex, split};

    /// Build an in-memory connection for handler tests.
    ///
    /// Returns `(client_reader, client_writer, server_reader, server_writer)`, taken from the
    /// two ends of a [`tokio::io::duplex`] pipe.
    pub async fn create() -> Result<(impl AsyncRead + Unpin, impl AsyncWrite + Unpin, impl AsyncRead + Unpin, impl AsyncWrite + Unpin)> {
        // Buffer size of the in-memory pipe, in bytes.
        const PIPE_CAPACITY: usize = 1024;

        let (client_end, server_end) = duplex(PIPE_CAPACITY);
        let (client_read, client_write) = split(client_end);
        let (server_read, server_write) = split(server_end);
        Ok((client_read, client_write, server_read, server_write))
    }

    /// Send the string literal `$msg` from `$sender` and assert `$receiver` decodes it unchanged.
    ///
    /// NOTE(review): the call site must have `bytes::{Buf, BufMut}` and the traits providing
    /// `write_string` / `read_string` in scope, and must return a `Result` for `?`.
    macro_rules! check_send_recv {
        ($sender: expr, $receiver: expr, $msg: literal) => {{
            // Encode the message into a fresh buffer and hand it to the sender.
            let mut encoder = ::bytes::BytesMut::new().writer();
            encoder.write_string($msg)?;
            $sender.send(&mut encoder.into_inner()).await?;

            // Decode whatever the receiver got and compare with the original.
            let incoming = $receiver.recv().await?;
            let mut decoder = incoming.reader();
            let decoded = decoder.read_string()?;
            assert_eq!($msg, decoded);
        }};
    }
    pub(crate) use check_send_recv;
}