use std::convert::Infallible;
use binrw::{
binrw,
meta::{ReadEndian, WriteEndian},
BinRead, BinWrite,
};
use crate::Error;
/// A binary packet as laid out on the wire (big-endian):
///
/// `len: u32 | padding_len: u8 | payload | padding | mac`
///
/// `len` counts the `padding_len` byte, the payload and the padding, but not
/// the MAC. The MAC length is not encoded in the packet; it is supplied at
/// read time through the `mac_len` import. Both length fields are computed
/// from the vectors when writing.
#[binrw]
#[derive(Debug)]
#[brw(big)]
#[br(import(mac_len: usize))]
pub struct Packet {
    // Length of everything after this field except the MAC: the padding-length
    // byte, the payload and the padding. This must agree with the read-side
    // `count` arithmetic on `payload` below, so that packets round-trip.
    // (`payload.len() + padding.len() + 1` cannot overflow `usize`, since each
    // `Vec` length is at most `isize::MAX`.)
    #[bw(
        assert(
            payload.len() + padding.len() + 1 <= u32::MAX as usize,
            "payload size is too large"
        ),
        calc = (payload.len() + padding.len() + 1) as u32
    )]
    len: u32,

    // On read, reject a padding length that (together with its own byte)
    // would not fit inside `len`; otherwise the payload `count` below would
    // underflow on malformed input.
    #[br(assert((padding_len as u32) < len, "padding length exceeds packet length"))]
    #[bw(
        assert(padding.len() <= u8::MAX as usize, "padding size is too large"),
        calc = padding.len() as u8
    )]
    padding_len: u8,

    // Safe subtraction: the assert on `padding_len` guarantees `padding_len + 1 <= len`.
    #[br(count = len - padding_len as u32 - 1)]
    pub payload: Vec<u8>,

    #[br(count = padding_len)]
    pub padding: Vec<u8>,

    // Trailing MAC; its length comes from the caller (see `mac_len` import).
    #[br(count = mac_len)]
    pub mac: Vec<u8>,
}
impl Packet {
    /// Decrypts this packet with `cipher` and parses the recovered payload as a `T`.
    ///
    /// # Errors
    ///
    /// Returns `Error::Cipher` if `cipher` fails to decrypt the packet, or the
    /// converted parse error if the payload does not decode as a `T`.
    pub fn decrypt<T, C>(self, cipher: &mut C) -> Result<T, Error<C::Err>>
    where
        for<'r> T: BinRead<Args<'r> = ()> + ReadEndian,
        C: Cipher,
    {
        let payload = cipher.decrypt(self).map_err(Error::Cipher)?;
        Ok(T::read(&mut std::io::Cursor::new(payload))?)
    }

    /// Serializes `message` and hands the bytes to `cipher` to build a packet.
    ///
    /// # Errors
    ///
    /// Returns the converted serialization error if `message` fails to write,
    /// or `Error::Cipher` if `cipher` fails to encrypt the payload.
    pub fn encrypt<T, C>(message: T, cipher: &mut C) -> Result<Self, Error<C::Err>>
    where
        for<'w> T: BinWrite<Args<'w> = ()> + WriteEndian,
        C: Cipher,
    {
        let mut payload = std::io::Cursor::new(Vec::new());
        message.write(&mut payload)?;
        cipher.encrypt(payload.into_inner()).map_err(Error::Cipher)
    }

    /// Reads one packet from `reader`, expecting `cipher.size()` trailing MAC bytes.
    ///
    /// # Errors
    ///
    /// Returns the converted read error if the bytes do not form a valid packet.
    pub fn from_reader<R, C>(reader: &mut R, cipher: &C) -> Result<Self, Error<Infallible>>
    where
        R: std::io::Read + std::io::Seek,
        C: Cipher,
    {
        Ok(Self::read_args(reader, (cipher.size(),))?)
    }

    /// Asynchronously reads one packet from `reader`, expecting `cipher.size()`
    /// trailing MAC bytes.
    ///
    /// The whole packet is buffered in memory first and then parsed with the
    /// synchronous reader.
    ///
    /// # Errors
    ///
    /// Returns the converted I/O error if the stream ends early or fails, or the
    /// converted read error if the bytes do not form a valid packet.
    #[cfg(feature = "futures")]
    #[cfg_attr(docsrs, doc(cfg(feature = "futures")))]
    pub async fn from_async_reader<R, C>(
        reader: &mut R,
        cipher: &C,
    ) -> Result<Self, Error<Infallible>>
    where
        R: futures::io::AsyncRead + Unpin,
        C: Cipher,
    {
        use futures::io::AsyncReadExt;

        // Read the 4-byte big-endian length prefix first to learn how many more
        // bytes belong to this packet.
        let mut prefix = [0u8; 4];
        reader.read_exact(&mut prefix).await?;
        let len = u32::from_be_bytes(prefix);

        // NOTE(review): `len` is taken from the peer unchecked, so this can
        // allocate up to ~4 GiB; consider enforcing a maximum packet size.
        let size = prefix.len() + len as usize + cipher.size();
        let mut buf = Vec::with_capacity(size);
        buf.extend_from_slice(&prefix);
        buf.resize(size, 0);

        // Fill only the bytes after the prefix: the prefix must be preserved so
        // the parser below sees the complete packet, and reading into the whole
        // buffer would over-read the stream by `prefix.len()` bytes.
        reader.read_exact(&mut buf[prefix.len()..]).await?;

        Ok(Self::read_args(
            &mut std::io::Cursor::new(buf),
            (cipher.size(),),
        )?)
    }

    /// Writes this packet to `writer`.
    ///
    /// # Errors
    ///
    /// Returns the converted write error if serialization or the underlying
    /// write fails.
    pub fn to_writer<W>(&self, writer: &mut W) -> Result<(), Error<Infallible>>
    where
        W: std::io::Write + std::io::Seek,
    {
        Ok(self.write(writer)?)
    }

    /// Asynchronously writes this packet to `writer`.
    ///
    /// The packet is serialized into an in-memory buffer first, then sent with a
    /// single `write_all`. The writer is not flushed.
    ///
    /// # Errors
    ///
    /// Returns the converted write error if serialization or the underlying
    /// write fails.
    #[cfg(feature = "futures")]
    #[cfg_attr(docsrs, doc(cfg(feature = "futures")))]
    pub async fn to_async_writer<W>(&self, writer: &mut W) -> Result<(), Error<Infallible>>
    where
        W: futures::io::AsyncWrite + Unpin,
    {
        use futures::io::AsyncWriteExt;

        // 4-byte length + 1-byte padding length + payload + padding + MAC.
        // Only reserve capacity: presizing with zeros would leave stale bytes
        // behind if this estimate ever exceeded the serialized size.
        let size = 4 + 1 + self.payload.len() + self.padding.len() + self.mac.len();
        let mut buf = std::io::Cursor::new(Vec::with_capacity(size));
        self.write(&mut buf)?;
        Ok(writer.write_all(&buf.into_inner()).await?)
    }
}
/// A packet cipher: turns serialized payloads into [`Packet`]s and recovers
/// payloads from received packets.
pub trait Cipher {
    /// Error reported when encryption or decryption fails.
    type Err;

    /// Builds a packet from a serialized message payload.
    fn encrypt(&mut self, payload: Vec<u8>) -> Result<Packet, Self::Err>;

    /// Recovers the serialized message payload carried by `packet`.
    fn decrypt(&mut self, packet: Packet) -> Result<Vec<u8>, Self::Err>;

    /// Number of trailing MAC bytes to read after each packet.
    fn size(&self) -> usize;
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Compile-time check that `Cipher` remains usable as a trait object.
    #[test]
    fn assert_cipher_is_object_safe() {
        /// Stub cipher; its transforms are never invoked by this test.
        struct NullCipher;

        impl Cipher for NullCipher {
            type Err = ();

            fn size(&self) -> usize {
                16
            }

            fn decrypt(&mut self, _packet: Packet) -> Result<Vec<u8>, Self::Err> {
                todo!()
            }

            fn encrypt(&mut self, _payload: Vec<u8>) -> Result<Packet, Self::Err> {
                todo!()
            }
        }

        // Coercing to `&dyn Cipher` only compiles if the trait is object safe.
        fn accepts_trait_object(_cipher: &dyn Cipher<Err = ()>) {}
        accepts_trait_object(&NullCipher);
    }
}