1use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
4
5pub mod config;
6
7pub use config::*;
8
/// A message-framed connection over a byte stream `T`.
///
/// When `T` implements `AsyncRead + AsyncWrite + Unpin`, messages are exchanged with
/// [`Connection::send`], [`Connection::send_err`] and [`Connection::recv`].
pub struct Connection<T> {
    /// The wrapped transport; exposed through the `AsRef`/`AsMut` impls.
    stream: T,
    /// Per-connection settings; `max_message_size` caps the payload length accepted by
    /// `recv` and emitted by `send`/`send_err`.
    pub config: Config,
}
14impl<T> AsRef<T> for Connection<T> {
15 fn as_ref(&self) -> &T {
16 &self.stream
17 }
18}
19impl<T> AsMut<T> for Connection<T> {
20 fn as_mut(&mut self) -> &mut T {
21 &mut self.stream
22 }
23}
24impl<T> From<T> for Connection<T> {
25 fn from(value: T) -> Self {
26 Self::with_config(value, Config::default())
27 }
28}
29impl<T> Connection<T> {
30 pub fn with_config(stream: T, config: Config) -> Self {
32 Self { stream, config }
33 }
34
35 pub fn new(stream: T) -> Self {
37 stream.into()
38 }
39}
40impl<T> Connection<T>
41where
42 T: AsyncRead + AsyncWrite + Unpin,
43{
44 pub async fn recv(&mut self) -> std::io::Result<Vec<u8>> {
46 let channel = self.stream.read_u8().await?;
47 let len = self.stream.read_u64_le().await? as usize;
48 if len > self.config.max_message_size {
49 for _ in 0..len {
50 self.stream.read_u8().await?;
51 }
52 self.send_err("MsgTooLarge").await?;
53 return Err(std::io::ErrorKind::OutOfMemory.into());
54 }
55 let mut buf = vec![0u8; len];
56 self.stream.read_exact(&mut buf).await?;
57 match channel {
58 0 => Ok(buf),
59 1 => Err(std::io::Error::new(
60 std::io::ErrorKind::Other,
61 String::from_utf8_lossy(&buf).to_string(),
62 )),
63 _ => Err(std::io::Error::new(
64 std::io::ErrorKind::Unsupported,
65 format!("unsupported channel: {}", channel),
66 )),
67 }
68 }
69
70 pub async fn send(&mut self, message: &[u8]) -> std::io::Result<()> {
72 self.send_raw(0, message).await
73 }
74
75 pub async fn send_err(&mut self, message: &str) -> std::io::Result<()> {
77 self.send_raw(1, message.as_bytes()).await
78 }
79
80 async fn send_raw(&mut self, channel: u8, message: &[u8]) -> std::io::Result<()> {
82 if message.len() > self.config.max_message_size {
83 return Err(std::io::ErrorKind::OutOfMemory.into());
84 }
85 self.stream.write_u8(channel).await?;
86 self.stream.write_u64_le(message.len() as _).await?;
87 self.stream.write_all(message).await?;
88 Ok(())
89 }
90}