1use bytes::{Buf, BufMut, Bytes, BytesMut};
4use std::io::{self, Cursor};
5use std::mem;
6use tokio_util::codec::{Decoder, Encoder};
7
8#[derive(Clone, Debug, Eq, PartialEq)]
10pub enum ChannelMessage {
11 Data(u8, Bytes),
13 InputRequest(usize),
15 LineRequest(usize),
17 SystemRequest(Bytes),
19}
20
/// Stateless decoder that turns channel frames into `ChannelMessage` values.
///
/// The type has no fields, so `ChannelDecoder::new()` and
/// `ChannelDecoder::default()` produce equivalent values.
#[derive(Debug, Default)]
pub struct ChannelDecoder {}

impl ChannelDecoder {
    /// Creates a new decoder.
    pub fn new() -> ChannelDecoder {
        ChannelDecoder {}
    }
}
30
31impl Decoder for ChannelDecoder {
32 type Item = ChannelMessage;
33 type Error = io::Error; fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
36 parse_channel_header(src.as_ref())
37 .map_or(Ok(None), |(ch, len)| decode_channel_payload(src, ch, len))
38 }
39}
40
/// Size in bytes of a channel frame header: a one-byte channel identifier
/// followed by a 32-bit length field.
const CHANNEL_HEADER_LEN: usize = mem::size_of::<u8>() + mem::size_of::<u32>();
42
43fn parse_channel_header(src: &[u8]) -> Option<(u8, usize)> {
44 if src.len() < CHANNEL_HEADER_LEN {
45 return None;
46 }
47 let mut buf = Cursor::new(src);
48 let ch = buf.get_u8();
49 let len = buf.get_u32();
50 Some((ch, len as usize))
51}
52
53fn decode_channel_payload(
54 src: &mut BytesMut,
55 ch: u8,
56 payload_len: usize,
57) -> Result<Option<ChannelMessage>, io::Error> {
58 let has_payload = ch.is_ascii_lowercase() || ch == b'S';
59 if has_payload && src.len() < CHANNEL_HEADER_LEN + payload_len {
60 return Ok(None);
61 }
62 src.advance(CHANNEL_HEADER_LEN);
63 match ch {
64 b'I' => Ok(Some(ChannelMessage::InputRequest(payload_len))),
65 b'L' => Ok(Some(ChannelMessage::LineRequest(payload_len))),
66 b'S' => {
67 let payload = src.split_to(payload_len).freeze();
68 Ok(Some(ChannelMessage::SystemRequest(payload)))
69 }
70 c if has_payload => {
71 let payload = src.split_to(payload_len).freeze();
72 Ok(Some(ChannelMessage::Data(c, payload)))
73 }
74 c => {
75 let msg = format!("unknown required channel: {}", c);
76 Err(io::Error::new(io::ErrorKind::InvalidData, msg))
77 }
78 }
79}
80
81#[derive(Clone, Debug, Eq, PartialEq)]
83pub enum BlockMessage {
84 Command(Bytes),
86 Data(Bytes),
88}
89
/// Stateless encoder that serializes `BlockMessage` values.
///
/// The type has no fields, so `BlockEncoder::new()` and
/// `BlockEncoder::default()` produce equivalent values.
#[derive(Debug, Default)]
pub struct BlockEncoder {}

impl BlockEncoder {
    /// Creates a new encoder.
    pub fn new() -> BlockEncoder {
        BlockEncoder {}
    }
}
99
100impl Encoder<BlockMessage> for BlockEncoder {
101 type Error = io::Error; fn encode(&mut self, msg: BlockMessage, dst: &mut BytesMut) -> Result<(), Self::Error> {
104 match msg {
105 BlockMessage::Command(cmd) => encode_command_line(cmd, dst),
106 BlockMessage::Data(data) => encode_data_block(data, dst),
107 }
108 }
109}
110
111fn encode_command_line(cmd: Bytes, dst: &mut BytesMut) -> Result<(), io::Error> {
112 dst.reserve(cmd.len() + 1);
114 dst.put(cmd);
115 dst.put_u8(b'\n');
116 Ok(())
117}
118
119fn encode_data_block(data: Bytes, dst: &mut BytesMut) -> Result<(), io::Error> {
120 if data.len() > u32::max_value() as usize {
121 let msg = format!("data length exceeds protocol limit: {}", data.len());
122 return Err(io::Error::new(io::ErrorKind::InvalidInput, msg));
123 }
124 dst.reserve(mem::size_of::<u32>() + data.len());
125 dst.put_u32(data.len() as u32);
126 dst.put(data);
127 Ok(())
128}
129
130#[derive(Debug)]
135pub struct ClientCodec {
136 dec: ChannelDecoder,
137 enc: BlockEncoder,
138}
139
140impl ClientCodec {
141 pub fn new() -> ClientCodec {
142 ClientCodec {
143 dec: ChannelDecoder::new(),
144 enc: BlockEncoder::new(),
145 }
146 }
147}
148
149impl Decoder for ClientCodec {
150 type Item = ChannelMessage;
151 type Error = io::Error;
152
153 fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
154 self.dec.decode(src)
155 }
156}
157
158impl Encoder<BlockMessage> for ClientCodec {
159 type Error = io::Error;
160
161 fn encode(&mut self, msg: BlockMessage, dst: &mut BytesMut) -> Result<(), Self::Error> {
162 self.enc.encode(msg, dst)
163 }
164}