ya_sb_proto/
codec.rs

1use prost::Message;
2use std::convert::TryInto;
3
4use crate::gsb_api::*;
5use thiserror::Error;
6
7use bytes::{Buf, BufMut};
8use tokio_util::codec::{Decoder, Encoder};
9
10#[derive(Debug, Error)]
11pub enum ProtocolError {
12    #[error("Unrecognized message type")]
13    UnrecognizedMessageType,
14    #[error("Cannot decode message header: not enough bytes")]
15    HeaderNotEnoughBytes,
16    #[error("{0}")]
17    Io(#[from] std::io::Error),
18    #[error("encode error: {0}")]
19    Encode(#[from] prost::EncodeError),
20    #[error("decode {0}")]
21    Decode(#[from] prost::DecodeError),
22    #[error("channel Receiver error")]
23    RecvError,
24    #[error("packet too big")]
25    MsgTooBig,
26}
27
28trait Encodable {
29    // This trait exists because prost::Message has template methods
30
31    fn encode_(&self, buf: &mut bytes::BytesMut) -> Result<(), ProtocolError>;
32    fn encoded_len_(&self) -> usize;
33}
34
35impl<T: Message> Encodable for T {
36    fn encode_(&self, mut buf: &mut bytes::BytesMut) -> Result<(), ProtocolError> {
37        Ok(self.encode(&mut buf)?)
38    }
39
40    fn encoded_len_(&self) -> usize {
41        self.encoded_len()
42    }
43}
44
45pub type GsbMessage = packet::Packet;
46
47impl GsbMessage {
48    pub fn pong() -> GsbMessage {
49        packet::Packet::Pong(Pong {})
50    }
51}
52
/// Implements `From<T> for packet::Packet` for each listed message type `T`,
/// wrapping the value in the `packet::Packet` variant of the same name.
///
/// `From` is implemented rather than `Into`. The standard blanket impl then
/// provides `Into<packet::Packet>` too, so existing `.into()` call sites keep
/// compiling and `packet::Packet::from(msg)` also becomes available.
macro_rules! into_packet {
    ($($t:ident),*) => {
        $(
        impl From<$t> for packet::Packet {
            fn from(msg: $t) -> packet::Packet {
                packet::Packet::$t(msg)
            }
        }
        )*
    };
}
65
66into_packet! {
67    RegisterRequest,
68    RegisterReply,
69    UnregisterRequest,
70    UnregisterReply,
71    CallRequest,
72    CallReply,
73    SubscribeRequest,
74    SubscribeReply,
75    UnsubscribeRequest,
76    UnsubscribeReply,
77    BroadcastRequest,
78    BroadcastReply,
79    Ping,
80    Pong
81}
82
83fn decode_header(src: &mut bytes::BytesMut) -> Result<Option<u32>, ProtocolError> {
84    if src.len() < 4 {
85        Ok(None)
86    } else {
87        let mut buf = src.split_to(4);
88        Ok(Some(buf.get_u32()))
89    }
90}
91
92fn decode_message(
93    src: &mut bytes::BytesMut,
94    msg_length: u32,
95) -> Result<Option<GsbMessage>, ProtocolError> {
96    let msg_length = msg_length
97        .try_into()
98        .map_err(|_| ProtocolError::MsgTooBig)?;
99    if src.len() < msg_length {
100        Ok(None)
101    } else {
102        let buf = src.split_to(msg_length);
103        let packet = Packet::decode(buf.as_ref())?;
104        match packet.packet {
105            Some(msg) => Ok(Some(msg)),
106            None => Err(ProtocolError::UnrecognizedMessageType),
107        }
108    }
109}
110
111fn encode_message(dst: &mut bytes::BytesMut, msg: GsbMessage) -> Result<(), ProtocolError> {
112    let packet = Packet { packet: Some(msg) };
113    let len = packet.encoded_len();
114    dst.put_u32(len as u32);
115    packet.encode(dst)?;
116    Ok(())
117}
118
/// Incremental decoder for length-prefixed [`GsbMessage`] frames.
#[derive(Default)]
pub struct GsbMessageDecoder {
    /// Length of a frame whose header has been read but whose body has not
    /// fully arrived yet; `None` while waiting for the next header.
    msg_header: Option<u32>,
}

impl GsbMessageDecoder {
    /// Creates a decoder that is waiting for a frame header.
    pub fn new() -> Self {
        Self::default()
    }
}
129
130impl Decoder for GsbMessageDecoder {
131    type Item = GsbMessage;
132    type Error = ProtocolError;
133
134    fn decode(&mut self, src: &mut bytes::BytesMut) -> Result<Option<Self::Item>, Self::Error> {
135        if self.msg_header.is_none() {
136            self.msg_header = decode_header(src)?;
137        }
138        match self.msg_header {
139            None => Ok(None),
140            Some(msg_length) => match decode_message(src, msg_length)? {
141                None => {
142                    src.reserve(msg_length as usize);
143                    Ok(None)
144                }
145                Some(msg) => {
146                    self.msg_header = None;
147                    Ok(Some(msg))
148                }
149            },
150        }
151    }
152}
153
154#[derive(Default)]
155pub struct GsbMessageEncoder;
156
157impl Encoder<GsbMessage> for GsbMessageEncoder {
158    type Error = ProtocolError;
159
160    fn encode(&mut self, item: GsbMessage, dst: &mut bytes::BytesMut) -> Result<(), Self::Error> {
161        encode_message(dst, item)
162    }
163}
164
165#[derive(Default)]
166pub struct GsbMessageCodec {
167    encoder: GsbMessageEncoder,
168    decoder: GsbMessageDecoder,
169}
170
171impl Encoder<GsbMessage> for GsbMessageCodec {
172    type Error = ProtocolError;
173
174    fn encode(&mut self, item: GsbMessage, dst: &mut bytes::BytesMut) -> Result<(), Self::Error> {
175        self.encoder.encode(item, dst)
176    }
177}
178
179impl Decoder for GsbMessageCodec {
180    type Item = GsbMessage;
181    type Error = ProtocolError;
182
183    fn decode(&mut self, src: &mut bytes::BytesMut) -> Result<Option<Self::Item>, Self::Error> {
184        self.decoder.decode(src)
185    }
186}