1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
//! Utilities for processing the Mercurial command-server protocol.

use bytes::{Bytes, BytesMut, Buf, BufMut};
use std::io::{self, Cursor};
use std::mem;
use tokio_codec::{Decoder, Encoder};

/// A single frame received from the Mercurial command server.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum ChannelMessage {
    /// Payload written by the server: the channel identifier byte followed by
    /// the bytes sent on that channel.
    Data(u8, Bytes),
    /// The server asks the client for input of at most this many bytes.
    InputRequest(usize),
    /// The server asks the client for a single line of input of at most this
    /// many bytes.
    LineRequest(usize),
    /// The server asks the client to execute a shell command on its side
    /// (cHg extension); carries the raw request payload.
    SystemRequest(Bytes),
}

/// Decoder to parse and split messages sent from the Mercurial command server.
///
/// The codec carries no state of its own (it has no fields), so values built
/// with `new()` and `default()` are interchangeable.
#[derive(Debug, Default)]
pub struct ChannelCodec {
}

impl ChannelCodec {
    /// Creates a new channel decoder.
    pub fn new() -> ChannelCodec {
        ChannelCodec {}
    }
}

impl Decoder for ChannelCodec {
    type Item = ChannelMessage;
    type Error = io::Error;  // TODO: maybe introduce a dedicated error type

    /// Extracts at most one `ChannelMessage` from the front of `src`.
    ///
    /// Yields `Ok(None)` while `src` does not yet hold a complete header;
    /// otherwise the payload handling is delegated to
    /// `decode_channel_payload`.
    fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
        let header = parse_channel_header(src.as_ref());
        match header {
            Some((channel, length)) => decode_channel_payload(src, channel, length),
            None => Ok(None),
        }
    }
}

/// Byte length of a channel frame header: one channel-identifier byte
/// followed by a 32-bit payload length.
const CHANNEL_HEADER_LEN: usize = mem::size_of::<u8>() + mem::size_of::<u32>();

/// Reads the channel byte and big-endian u32 length at the start of `src`.
///
/// Returns `None` when fewer than `CHANNEL_HEADER_LEN` bytes are available.
/// Nothing is consumed; the caller decides whether to advance the buffer.
fn parse_channel_header(src: &[u8]) -> Option<(u8, usize)> {
    if src.len() >= CHANNEL_HEADER_LEN {
        let mut reader = Cursor::new(src);
        let channel = reader.get_u8();
        let length = reader.get_u32_be() as usize;
        Some((channel, length))
    } else {
        None
    }
}

/// Decodes one channel frame whose header sits at the start of `src`.
///
/// `ch` and `payload_len` are the values read from that header by
/// `parse_channel_header`.
///
/// - Lowercase channels and `S` carry `payload_len` bytes of payload.
///   `Ok(None)` is returned, with nothing consumed, until the whole frame
///   is buffered.
/// - `I` and `L` carry no payload; `payload_len` is the input size the
///   server is requesting.
///
/// # Errors
///
/// Returns `InvalidData` for any other channel byte. In that case the header
/// bytes have already been consumed from `src`.
fn decode_channel_payload(src: &mut BytesMut, ch: u8, payload_len: usize)
                          -> Result<Option<ChannelMessage>, io::Error> {
    let has_payload = ch.is_ascii_lowercase() || ch == b'S';
    // Count the bytes buffered after the header instead of computing
    // `CHANNEL_HEADER_LEN + payload_len`. `payload_len` is an arbitrary u32
    // taken from the wire, and that sum can overflow a 32-bit usize.
    let available = match src.len().checked_sub(CHANNEL_HEADER_LEN) {
        Some(n) => n,
        // Header not fully buffered yet (the decoder normally rules this out).
        None => return Ok(None),
    };
    if has_payload && available < payload_len {
        return Ok(None);
    }
    src.advance(CHANNEL_HEADER_LEN);
    match ch {
        b'I' => Ok(Some(ChannelMessage::InputRequest(payload_len))),
        b'L' => Ok(Some(ChannelMessage::LineRequest(payload_len))),
        b'S' => {
            let payload = src.split_to(payload_len).freeze();
            Ok(Some(ChannelMessage::SystemRequest(payload)))
        }
        c if has_payload => {
            let payload = src.split_to(payload_len).freeze();
            Ok(Some(ChannelMessage::Data(c, payload)))
        }
        c => {
            let msg = format!("unknown required channel: {}", c);
            Err(io::Error::new(io::ErrorKind::InvalidData, msg))
        }
    }
}

/// A message the client writes to the Mercurial command server.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum BlockMessage {
    /// Name of a command that starts a new request, without its arguments;
    /// written as a newline-terminated line.
    Command(Bytes),
    /// Command argument or other data for the server; written as a
    /// length-prefixed block.
    Data(Bytes),
}

/// Encoder to build messages sent to the Mercurial command server.
///
/// The codec carries no state of its own (it has no fields), so values built
/// with `new()` and `default()` are interchangeable.
#[derive(Debug, Default)]
pub struct BlockCodec {
}

impl BlockCodec {
    /// Creates a new block encoder.
    pub fn new() -> BlockCodec {
        BlockCodec {}
    }
}

impl Encoder for BlockCodec {
    type Item = BlockMessage;
    type Error = io::Error;  // TODO: maybe introduce a dedicated error type

    /// Serializes `msg` onto the end of `dst`.
    ///
    /// Data goes out as a length-prefixed block (`encode_data_block`) and
    /// commands as a newline-terminated line (`encode_command_line`).
    fn encode(&mut self, msg: BlockMessage, dst: &mut BytesMut) -> Result<(), Self::Error> {
        match msg {
            BlockMessage::Data(payload) => encode_data_block(payload, dst),
            BlockMessage::Command(name) => encode_command_line(name, dst),
        }
    }
}

/// Appends `cmd` followed by a `\n` terminator to `dst`.
///
/// # Errors
///
/// Returns `InvalidInput`, without writing anything to `dst`, if `cmd`
/// itself contains a newline. Commands are newline-terminated on the wire,
/// so an embedded `\n` would end the command early and the remaining bytes
/// would be taken as another command.
fn encode_command_line(cmd: Bytes, dst: &mut BytesMut) -> Result<(), io::Error> {
    if cmd.contains(&b'\n') {
        let msg = format!("command must not contain newline: {:?}", cmd);
        return Err(io::Error::new(io::ErrorKind::InvalidInput, msg));
    }
    dst.reserve(cmd.len() + 1);
    dst.put(cmd);
    dst.put_u8(b'\n');
    Ok(())
}

/// Appends a data block to `dst`: the payload length as a big-endian u32,
/// followed by the payload bytes themselves.
///
/// # Errors
///
/// Returns `InvalidInput` when `data` is longer than `u32::max_value()`
/// bytes, because its length cannot be represented in the 32-bit prefix.
fn encode_data_block(data: Bytes, dst: &mut BytesMut) -> Result<(), io::Error> {
    let data_len = data.len();
    if data_len <= u32::max_value() as usize {
        dst.reserve(mem::size_of::<u32>() + data_len);
        dst.put_u32_be(data_len as u32);
        dst.put(data);
        Ok(())
    } else {
        let msg = format!("data length exceeds protocol limit: {}", data_len);
        Err(io::Error::new(io::ErrorKind::InvalidInput, msg))
    }
}