1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
use bytes::{Bytes, BytesMut, Buf, BufMut};
use std::io::{self, Cursor};
use std::mem;
use tokio_codec::{Decoder, Encoder};
/// A message decoded from a channel-framed byte stream.
///
/// Every frame begins with a one-byte channel id followed by a four-byte
/// big-endian length field; the channel id decides which variant is produced.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum ChannelMessage {
/// Payload received on a lowercase ASCII channel: the channel byte and the payload bytes.
Data(u8, Bytes),
/// Frame on channel `I`. It has no payload; the value is the header's length field.
InputRequest(usize),
/// Frame on channel `L`. It has no payload; the value is the header's length field.
LineRequest(usize),
/// Payload received on channel `S`.
SystemRequest(Bytes),
}
/// Stateless codec whose `Decoder` implementation turns channel frames into
/// [`ChannelMessage`] values.
///
/// The type has no fields, so `ChannelCodec::default()` and
/// `ChannelCodec::new()` are interchangeable.
#[derive(Debug, Default)]
pub struct ChannelCodec {}

impl ChannelCodec {
    /// Creates a new codec.
    pub fn new() -> ChannelCodec {
        ChannelCodec {}
    }
}
impl Decoder for ChannelCodec {
    type Item = ChannelMessage;
    type Error = io::Error;

    /// Attempts to decode one channel frame from the front of `src`.
    ///
    /// Yields `Ok(None)` while the buffer does not yet hold a full header;
    /// once a header is available the rest of the work is delegated to
    /// `decode_channel_payload`.
    fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
        match parse_channel_header(src.as_ref()) {
            Some((channel, length)) => decode_channel_payload(src, channel, length),
            None => Ok(None),
        }
    }
}
/// Size in bytes of a channel frame header: a one-byte channel id followed by
/// a four-byte length field.
const CHANNEL_HEADER_LEN: usize = mem::size_of::<u8>() + mem::size_of::<u32>();

/// Reads the channel id and the big-endian length field from the start of `src`.
///
/// Returns `None` when `src` is shorter than a full header. Nothing is
/// consumed: the function only inspects the borrowed slice.
fn parse_channel_header(src: &[u8]) -> Option<(u8, usize)> {
    if src.len() >= CHANNEL_HEADER_LEN {
        // The length guard above guarantees both reads below have enough bytes.
        let mut reader = Cursor::new(src);
        let channel = reader.get_u8();
        let length = reader.get_u32_be() as usize;
        Some((channel, length))
    } else {
        None
    }
}
/// Decodes one frame from the front of `src`, given the channel id and length
/// field already read from its header.
///
/// Lowercase ASCII channels and `S` are followed by `payload_len` bytes of
/// payload. `I` and `L` have no payload and report `payload_len` as their value.
///
/// Returns `Ok(None)` without consuming anything when a payload-bearing frame
/// is not yet fully buffered. Otherwise the header, and the payload if there
/// is one, are removed from `src`.
///
/// Callers are expected to pass a buffer that already holds a complete header
/// (as `ChannelCodec::decode` does by parsing the header first).
///
/// # Errors
///
/// Returns `InvalidData` when the channel id is not recognised (the header has
/// already been consumed at that point), or when the header length plus
/// `payload_len` does not fit in `usize`.
fn decode_channel_payload(src: &mut BytesMut, ch: u8, payload_len: usize)
    -> Result<Option<ChannelMessage>, io::Error> {
    let has_payload = ch.is_ascii_lowercase() || ch == b'S';
    if has_payload {
        // `payload_len` comes straight off the wire. Where `usize` is 32 bits an
        // unchecked add could wrap, slip past the completeness check below and
        // make `split_to` panic, so reject such a frame explicitly.
        let frame_len = match CHANNEL_HEADER_LEN.checked_add(payload_len) {
            Some(total) => total,
            None => {
                let msg = format!("payload length exceeds addressable size: {}", payload_len);
                return Err(io::Error::new(io::ErrorKind::InvalidData, msg));
            }
        };
        if src.len() < frame_len {
            return Ok(None);
        }
    }
    src.advance(CHANNEL_HEADER_LEN);
    match ch {
        b'I' => Ok(Some(ChannelMessage::InputRequest(payload_len))),
        b'L' => Ok(Some(ChannelMessage::LineRequest(payload_len))),
        b'S' => {
            let payload = src.split_to(payload_len).freeze();
            Ok(Some(ChannelMessage::SystemRequest(payload)))
        }
        c if has_payload => {
            let payload = src.split_to(payload_len).freeze();
            Ok(Some(ChannelMessage::Data(c, payload)))
        }
        c => {
            let msg = format!("unknown required channel: {}", c);
            Err(io::Error::new(io::ErrorKind::InvalidData, msg))
        }
    }
}
/// A message that [`BlockCodec`] can encode into an output buffer.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum BlockMessage {
/// Bytes written as-is and followed by a single `\n`.
Command(Bytes),
/// Bytes written after a four-byte big-endian length prefix.
Data(Bytes),
}
/// Stateless codec whose `Encoder` implementation serialises
/// [`BlockMessage`] values.
///
/// The type has no fields, so `BlockCodec::default()` and `BlockCodec::new()`
/// are interchangeable.
#[derive(Debug, Default)]
pub struct BlockCodec {}

impl BlockCodec {
    /// Creates a new codec.
    pub fn new() -> BlockCodec {
        BlockCodec {}
    }
}
impl Encoder for BlockCodec {
    type Item = BlockMessage;
    type Error = io::Error;

    /// Appends the wire form of `msg` to `dst`.
    ///
    /// Data blocks go through `encode_data_block` and commands through
    /// `encode_command_line`; any error from those helpers is returned as-is.
    fn encode(&mut self, msg: BlockMessage, dst: &mut BytesMut) -> Result<(), Self::Error> {
        match msg {
            BlockMessage::Data(payload) => encode_data_block(payload, dst),
            BlockMessage::Command(line) => encode_command_line(line, dst),
        }
    }
}
/// Appends `cmd` to `dst`, followed by a single `\n` terminator.
///
/// Capacity for the command and its terminator is reserved up front. The
/// function always returns `Ok(())`.
fn encode_command_line(cmd: Bytes, dst: &mut BytesMut) -> Result<(), io::Error> {
    // One extra byte for the trailing newline.
    let needed = cmd.len() + 1;
    dst.reserve(needed);
    dst.put(cmd);
    dst.put_u8(b'\n');
    Ok(())
}
/// Appends `data` to `dst`, prefixed with its length as a four-byte
/// big-endian integer.
///
/// # Errors
///
/// Returns `InvalidInput` when `data` is longer than `u32::max_value()` bytes,
/// since the length prefix cannot represent it. Nothing is written in that case.
fn encode_data_block(data: Bytes, dst: &mut BytesMut) -> Result<(), io::Error> {
    let len = data.len();
    if len <= u32::max_value() as usize {
        dst.reserve(mem::size_of::<u32>() + len);
        // The guard above makes this narrowing cast lossless.
        dst.put_u32_be(len as u32);
        dst.put(data);
        Ok(())
    } else {
        let msg = format!("data length exceeds protocol limit: {}", len);
        Err(io::Error::new(io::ErrorKind::InvalidInput, msg))
    }
}