1use prost::Message;
2use std::convert::TryInto;
3
4use crate::gsb_api::*;
5use thiserror::Error;
6
7use bytes::{Buf, BufMut};
8use tokio_util::codec::{Decoder, Encoder};
9
/// Errors produced while encoding or decoding GSB protocol frames.
///
/// The `#[from]` variants let `?` convert the underlying I/O and prost
/// errors into a `ProtocolError` automatically.
#[derive(Debug, Error)]
pub enum ProtocolError {
    /// A frame decoded successfully but carried no concrete packet variant.
    #[error("Unrecognized message type")]
    UnrecognizedMessageType,
    /// Not enough bytes were available to decode a message header.
    #[error("Cannot decode message header: not enough bytes")]
    HeaderNotEnoughBytes,
    /// Wrapper around an underlying I/O error.
    #[error("{0}")]
    Io(#[from] std::io::Error),
    /// Protobuf serialization failed.
    #[error("encode error: {0}")]
    Encode(#[from] prost::EncodeError),
    /// Protobuf deserialization failed.
    #[error("decode {0}")]
    Decode(#[from] prost::DecodeError),
    /// A channel receiver reported an error.
    #[error("channel Receiver error")]
    RecvError,
    /// A message length could not be represented in the integer type
    /// required to process it.
    #[error("packet too big")]
    MsgTooBig,
}
27
/// Private encoding interface: write a value into a `BytesMut` and report
/// its encoded length.
trait Encodable {
    /// Writes the encoded form of `self` into `buf`.
    ///
    /// # Errors
    /// Returns a [`ProtocolError`] when encoding fails.
    fn encode_(&self, buf: &mut bytes::BytesMut) -> Result<(), ProtocolError>;
    /// Returns the encoded length of `self`.
    fn encoded_len_(&self) -> usize;
}
34
35impl<T: Message> Encodable for T {
36 fn encode_(&self, mut buf: &mut bytes::BytesMut) -> Result<(), ProtocolError> {
37 Ok(self.encode(&mut buf)?)
38 }
39
40 fn encoded_len_(&self) -> usize {
41 self.encoded_len()
42 }
43}
44
45pub type GsbMessage = packet::Packet;
46
47impl GsbMessage {
48 pub fn pong() -> GsbMessage {
49 packet::Packet::Pong(Pong {})
50 }
51}
52
53macro_rules! into_packet {
54 ($($t:ident),*) => {
55 $(
56 #[allow(clippy::from_over_into)]
57 impl Into<packet::Packet> for $t {
58 fn into(self) -> packet::Packet {
59 packet::Packet::$t(self)
60 }
61 }
62 )*
63 };
64}
65
66into_packet! {
67 RegisterRequest,
68 RegisterReply,
69 UnregisterRequest,
70 UnregisterReply,
71 CallRequest,
72 CallReply,
73 SubscribeRequest,
74 SubscribeReply,
75 UnsubscribeRequest,
76 UnsubscribeReply,
77 BroadcastRequest,
78 BroadcastReply,
79 Ping,
80 Pong
81}
82
83fn decode_header(src: &mut bytes::BytesMut) -> Result<Option<u32>, ProtocolError> {
84 if src.len() < 4 {
85 Ok(None)
86 } else {
87 let mut buf = src.split_to(4);
88 Ok(Some(buf.get_u32()))
89 }
90}
91
92fn decode_message(
93 src: &mut bytes::BytesMut,
94 msg_length: u32,
95) -> Result<Option<GsbMessage>, ProtocolError> {
96 let msg_length = msg_length
97 .try_into()
98 .map_err(|_| ProtocolError::MsgTooBig)?;
99 if src.len() < msg_length {
100 Ok(None)
101 } else {
102 let buf = src.split_to(msg_length);
103 let packet = Packet::decode(buf.as_ref())?;
104 match packet.packet {
105 Some(msg) => Ok(Some(msg)),
106 None => Err(ProtocolError::UnrecognizedMessageType),
107 }
108 }
109}
110
111fn encode_message(dst: &mut bytes::BytesMut, msg: GsbMessage) -> Result<(), ProtocolError> {
112 let packet = Packet { packet: Some(msg) };
113 let len = packet.encoded_len();
114 dst.put_u32(len as u32);
115 packet.encode(dst)?;
116 Ok(())
117}
118
/// Streaming decoder state for length-prefixed GSB frames.
#[derive(Default)]
pub struct GsbMessageDecoder {
    /// Length of the message body currently awaited; `None` when the next
    /// bytes to read are a frame header.
    msg_header: Option<u32>,
}

impl GsbMessageDecoder {
    /// Creates a decoder that expects a frame header next.
    pub fn new() -> Self {
        Self::default()
    }
}
129
130impl Decoder for GsbMessageDecoder {
131 type Item = GsbMessage;
132 type Error = ProtocolError;
133
134 fn decode(&mut self, src: &mut bytes::BytesMut) -> Result<Option<Self::Item>, Self::Error> {
135 if self.msg_header.is_none() {
136 self.msg_header = decode_header(src)?;
137 }
138 match self.msg_header {
139 None => Ok(None),
140 Some(msg_length) => match decode_message(src, msg_length)? {
141 None => {
142 src.reserve(msg_length as usize);
143 Ok(None)
144 }
145 Some(msg) => {
146 self.msg_header = None;
147 Ok(Some(msg))
148 }
149 },
150 }
151 }
152}
153
/// Stateless encoder that writes `GsbMessage`s as frames.
#[derive(Default)]
pub struct GsbMessageEncoder;

impl Encoder<GsbMessage> for GsbMessageEncoder {
    type Error = ProtocolError;

    /// Appends `item` to `dst` by delegating to `encode_message`.
    ///
    /// # Errors
    /// Returns whatever [`ProtocolError`] `encode_message` reports.
    fn encode(&mut self, item: GsbMessage, dst: &mut bytes::BytesMut) -> Result<(), Self::Error> {
        encode_message(dst, item)
    }
}
164
/// Combined codec bundling a [`GsbMessageEncoder`] and a
/// [`GsbMessageDecoder`] in one value.
#[derive(Default)]
pub struct GsbMessageCodec {
    /// Handles the outgoing direction.
    encoder: GsbMessageEncoder,
    /// Handles the incoming direction; keeps per-stream decoding state.
    decoder: GsbMessageDecoder,
}
170
171impl Encoder<GsbMessage> for GsbMessageCodec {
172 type Error = ProtocolError;
173
174 fn encode(&mut self, item: GsbMessage, dst: &mut bytes::BytesMut) -> Result<(), Self::Error> {
175 self.encoder.encode(item, dst)
176 }
177}
178
179impl Decoder for GsbMessageCodec {
180 type Item = GsbMessage;
181 type Error = ProtocolError;
182
183 fn decode(&mut self, src: &mut bytes::BytesMut) -> Result<Option<Self::Item>, Self::Error> {
184 self.decoder.decode(src)
185 }
186}