1use std::marker::PhantomData;
4use std::mem::size_of;
5
6use byteorder::{BigEndian, ReadBytesExt};
7use ssh_encoding::{Decode, Encode};
8use tokio_util::bytes::{Buf, BufMut, BytesMut};
9use tokio_util::codec::{Decoder, Encoder};
10
11use super::error::AgentError;
12use super::proto::ProtoError;
13
14#[derive(Debug)]
22pub struct Codec<Input, Output>(PhantomData<Input>, PhantomData<Output>)
23where
24 Input: Decode,
25 Output: Encode,
26 AgentError: From<Input::Error>;
27
28impl<Input, Output> Default for Codec<Input, Output>
29where
30 Input: Decode,
31 Output: Encode,
32 AgentError: From<Input::Error>,
33{
34 fn default() -> Self {
35 Self(PhantomData, PhantomData)
36 }
37}
38
39impl<Input, Output> Decoder for Codec<Input, Output>
40where
41 Input: Decode,
42 Output: Encode,
43 AgentError: From<Input::Error>,
44{
45 type Item = Input;
46 type Error = AgentError;
47
48 fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
49 let mut bytes = &src[..];
50
51 if bytes.len() < size_of::<u32>() {
52 return Ok(None);
53 }
54
55 let length = bytes.read_u32::<BigEndian>()? as usize;
56
57 if bytes.len() < length {
58 return Ok(None);
59 }
60
61 let message = Self::Item::decode(&mut bytes)?;
62 src.advance(size_of::<u32>() + length);
63 Ok(Some(message))
64 }
65}
66
67impl<Input, Output> Encoder<Output> for Codec<Input, Output>
68where
69 Input: Decode,
70 Output: Encode,
71 AgentError: From<Input::Error>,
72{
73 type Error = AgentError;
74
75 fn encode(&mut self, item: Output, dst: &mut BytesMut) -> Result<(), Self::Error> {
76 let mut bytes = Vec::new();
77
78 let len = item.encoded_len().map_err(ProtoError::SshEncoding)? as u32;
79 len.encode(&mut bytes).map_err(ProtoError::SshEncoding)?;
80
81 item.encode(&mut bytes).map_err(ProtoError::SshEncoding)?;
82 dst.put(&*bytes);
83
84 Ok(())
85 }
86}