tokio_hglib/
codec.rs

//! Utilities for processing the Mercurial command-server protocol.
2
3use bytes::{Buf, BufMut, Bytes, BytesMut};
4use std::io::{self, Cursor};
5use std::mem;
6use tokio_util::codec::{Decoder, Encoder};
7
8/// Message sent from Mercurial command server.
9#[derive(Clone, Debug, Eq, PartialEq)]
10pub enum ChannelMessage {
11    /// Data sent from server to the specified channel.
12    Data(u8, Bytes),
13    /// Server requesting data up to the specified size in bytes.
14    InputRequest(usize),
15    /// Server requesting a single-line input up to the specified size in bytes.
16    LineRequest(usize),
17    /// Server requesting a shell command execution at client side. (cHg extension)
18    SystemRequest(Bytes),
19}
20
/// Decoder that parses and splits the byte stream sent by a Mercurial
/// command server into [`ChannelMessage`] frames.
///
/// The decoder has no fields; all framing state lives in the read buffer
/// passed to `decode`.
#[derive(Debug, Default)]
pub struct ChannelDecoder {}

impl ChannelDecoder {
    /// Creates a new decoder. Equivalent to [`ChannelDecoder::default`].
    pub fn new() -> ChannelDecoder {
        ChannelDecoder {}
    }
}
30
31impl Decoder for ChannelDecoder {
32    type Item = ChannelMessage;
33    type Error = io::Error; // TODO: maybe introduce a dedicated error type
34
35    fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
36        parse_channel_header(src.as_ref())
37            .map_or(Ok(None), |(ch, len)| decode_channel_payload(src, ch, len))
38    }
39}
40
/// Size in bytes of a channel frame header: a `u32` length field plus the
/// one-byte channel identifier that precedes it.
const CHANNEL_HEADER_LEN: usize = mem::size_of::<u32>() + mem::size_of::<u8>();
42
43fn parse_channel_header(src: &[u8]) -> Option<(u8, usize)> {
44    if src.len() < CHANNEL_HEADER_LEN {
45        return None;
46    }
47    let mut buf = Cursor::new(src);
48    let ch = buf.get_u8();
49    let len = buf.get_u32();
50    Some((ch, len as usize))
51}
52
53fn decode_channel_payload(
54    src: &mut BytesMut,
55    ch: u8,
56    payload_len: usize,
57) -> Result<Option<ChannelMessage>, io::Error> {
58    let has_payload = ch.is_ascii_lowercase() || ch == b'S';
59    if has_payload && src.len() < CHANNEL_HEADER_LEN + payload_len {
60        return Ok(None);
61    }
62    src.advance(CHANNEL_HEADER_LEN);
63    match ch {
64        b'I' => Ok(Some(ChannelMessage::InputRequest(payload_len))),
65        b'L' => Ok(Some(ChannelMessage::LineRequest(payload_len))),
66        b'S' => {
67            let payload = src.split_to(payload_len).freeze();
68            Ok(Some(ChannelMessage::SystemRequest(payload)))
69        }
70        c if has_payload => {
71            let payload = src.split_to(payload_len).freeze();
72            Ok(Some(ChannelMessage::Data(c, payload)))
73        }
74        c => {
75            let msg = format!("unknown required channel: {}", c);
76            Err(io::Error::new(io::ErrorKind::InvalidData, msg))
77        }
78    }
79}
80
81/// Request and response sent to Mercurial command server.
82#[derive(Clone, Debug, Eq, PartialEq)]
83pub enum BlockMessage {
84    /// Command to initiate new request, without argument.
85    Command(Bytes),
86    /// Command argument or data sent to server.
87    Data(Bytes),
88}
89
/// Encoder that serializes [`BlockMessage`]s into the bytes sent to a
/// Mercurial command server.
///
/// The encoder has no fields and therefore carries no state between calls.
#[derive(Debug, Default)]
pub struct BlockEncoder {}

impl BlockEncoder {
    /// Creates a new encoder. Equivalent to [`BlockEncoder::default`].
    pub fn new() -> BlockEncoder {
        BlockEncoder {}
    }
}
99
100impl Encoder<BlockMessage> for BlockEncoder {
101    type Error = io::Error; // TODO: maybe introduce a dedicated error type
102
103    fn encode(&mut self, msg: BlockMessage, dst: &mut BytesMut) -> Result<(), Self::Error> {
104        match msg {
105            BlockMessage::Command(cmd) => encode_command_line(cmd, dst),
106            BlockMessage::Data(data) => encode_data_block(data, dst),
107        }
108    }
109}
110
111fn encode_command_line(cmd: Bytes, dst: &mut BytesMut) -> Result<(), io::Error> {
112    // TODO: error out if cmd contains '\n'?
113    dst.reserve(cmd.len() + 1);
114    dst.put(cmd);
115    dst.put_u8(b'\n');
116    Ok(())
117}
118
119fn encode_data_block(data: Bytes, dst: &mut BytesMut) -> Result<(), io::Error> {
120    if data.len() > u32::max_value() as usize {
121        let msg = format!("data length exceeds protocol limit: {}", data.len());
122        return Err(io::Error::new(io::ErrorKind::InvalidInput, msg));
123    }
124    dst.reserve(mem::size_of::<u32>() + data.len());
125    dst.put_u32(data.len() as u32);
126    dst.put(data);
127    Ok(())
128}
129
130/// Unified codec for client-side operations.
131///
132/// With this codec, a framed stream can be constructed without splitting
133/// it into read/write halves.
134#[derive(Debug)]
135pub struct ClientCodec {
136    dec: ChannelDecoder,
137    enc: BlockEncoder,
138}
139
140impl ClientCodec {
141    pub fn new() -> ClientCodec {
142        ClientCodec {
143            dec: ChannelDecoder::new(),
144            enc: BlockEncoder::new(),
145        }
146    }
147}
148
149impl Decoder for ClientCodec {
150    type Item = ChannelMessage;
151    type Error = io::Error;
152
153    fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
154        self.dec.decode(src)
155    }
156}
157
158impl Encoder<BlockMessage> for ClientCodec {
159    type Error = io::Error;
160
161    fn encode(&mut self, msg: BlockMessage, dst: &mut BytesMut) -> Result<(), Self::Error> {
162        self.enc.encode(msg, dst)
163    }
164}