ssh_agent_lib/
codec.rs

1//! SSH agent protocol framing codec.
2
3use std::marker::PhantomData;
4use std::mem::size_of;
5
6use byteorder::{BigEndian, ReadBytesExt};
7use ssh_encoding::{Decode, Encode};
8use tokio_util::bytes::{Buf, BufMut, BytesMut};
9use tokio_util::codec::{Decoder, Encoder};
10
11use super::error::AgentError;
12use super::proto::ProtoError;
13
14/// SSH framing codec.
15///
16/// This codec first reads an `u32` which indicates the length of the incoming
17/// message. Then decodes the message using specified `Input` type.
18///
19/// The reverse transformation which appends the length of the encoded data
20/// is also implemented for the given `Output` type.
21#[derive(Debug)]
22pub struct Codec<Input, Output>(PhantomData<Input>, PhantomData<Output>)
23where
24    Input: Decode,
25    Output: Encode,
26    AgentError: From<Input::Error>;
27
28impl<Input, Output> Default for Codec<Input, Output>
29where
30    Input: Decode,
31    Output: Encode,
32    AgentError: From<Input::Error>,
33{
34    fn default() -> Self {
35        Self(PhantomData, PhantomData)
36    }
37}
38
39impl<Input, Output> Decoder for Codec<Input, Output>
40where
41    Input: Decode,
42    Output: Encode,
43    AgentError: From<Input::Error>,
44{
45    type Item = Input;
46    type Error = AgentError;
47
48    fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
49        let mut bytes = &src[..];
50
51        if bytes.len() < size_of::<u32>() {
52            return Ok(None);
53        }
54
55        let length = bytes.read_u32::<BigEndian>()? as usize;
56
57        if bytes.len() < length {
58            return Ok(None);
59        }
60
61        let message = Self::Item::decode(&mut bytes)?;
62        src.advance(size_of::<u32>() + length);
63        Ok(Some(message))
64    }
65}
66
67impl<Input, Output> Encoder<Output> for Codec<Input, Output>
68where
69    Input: Decode,
70    Output: Encode,
71    AgentError: From<Input::Error>,
72{
73    type Error = AgentError;
74
75    fn encode(&mut self, item: Output, dst: &mut BytesMut) -> Result<(), Self::Error> {
76        let mut bytes = Vec::new();
77
78        let len = item.encoded_len().map_err(ProtoError::SshEncoding)? as u32;
79        len.encode(&mut bytes).map_err(ProtoError::SshEncoding)?;
80
81        item.encode(&mut bytes).map_err(ProtoError::SshEncoding)?;
82        dst.put(&*bytes);
83
84        Ok(())
85    }
86}