1use bytes::{Buf, BufMut, Bytes, BytesMut};
11use prost::Message;
12use thiserror::Error;
13use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
14
/// Largest payload, in bytes, a frame may carry (64 MiB).
///
/// Checked when a frame is built from a message and when a header is parsed.
pub const MAX_FRAME_SIZE: usize = 64 << 20;

/// Size in bytes of the frame header: a `u32` payload length followed by a `u16` message type.
pub const HEADER_SIZE: usize = std::mem::size_of::<u32>() + std::mem::size_of::<u16>();
21
/// Kind of a frame, carried in the header as a big-endian `u16` tag.
///
/// The explicit discriminants are the on-wire tag values; the `TryFrom<u16>` impl maps a tag
/// back to its variant.
#[repr(u16)]
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum MessageType {
    /// Tag `1`; produced by `Frame::request`.
    Request = 1,
    /// Tag `2`; produced by `Frame::response`.
    Response = 2,
    /// Tag `3`.
    StreamStart = 3,
    /// Tag `4`; produced by `Frame::stream_data`.
    StreamData = 4,
    /// Tag `5`.
    StreamEnd = 5,
    /// Tag `6`; produced by `Frame::error`.
    Error = 6,
}
39
40impl TryFrom<u16> for MessageType {
41 type Error = FrameError;
42
43 fn try_from(value: u16) -> Result<Self, <Self as TryFrom<u16>>::Error> {
44 match value {
45 1 => Ok(MessageType::Request),
46 2 => Ok(MessageType::Response),
47 3 => Ok(MessageType::StreamStart),
48 4 => Ok(MessageType::StreamData),
49 5 => Ok(MessageType::StreamEnd),
50 6 => Ok(MessageType::Error),
51 _ => Err(FrameError::InvalidMessageType(value)),
52 }
53 }
54}
55
56#[derive(Debug, Error)]
58pub enum FrameError {
59 #[error("frame too large: {0} bytes (max: {MAX_FRAME_SIZE})")]
60 FrameTooLarge(usize),
61
62 #[error("invalid message type: {0}")]
63 InvalidMessageType(u16),
64
65 #[error("IO error: {0}")]
66 Io(#[from] std::io::Error),
67
68 #[error("protobuf decode error: {0}")]
69 Decode(#[from] prost::DecodeError),
70
71 #[error("connection closed")]
72 ConnectionClosed,
73}
74
75#[derive(Debug, Clone)]
77pub struct Frame {
78 pub message_type: MessageType,
79 pub payload: Bytes,
80}
81
82impl Frame {
83 pub fn request<M: Message>(msg: &M) -> Result<Self, FrameError> {
85 Self::new(MessageType::Request, msg)
86 }
87
88 pub fn response<M: Message>(msg: &M) -> Result<Self, FrameError> {
90 Self::new(MessageType::Response, msg)
91 }
92
93 pub fn error<M: Message>(msg: &M) -> Result<Self, FrameError> {
95 Self::new(MessageType::Error, msg)
96 }
97
98 pub fn stream_data<M: Message>(msg: &M) -> Result<Self, FrameError> {
100 Self::new(MessageType::StreamData, msg)
101 }
102
103 pub fn new<M: Message>(message_type: MessageType, msg: &M) -> Result<Self, FrameError> {
105 let payload = msg.encode_to_vec();
106 if payload.len() > MAX_FRAME_SIZE {
107 return Err(FrameError::FrameTooLarge(payload.len()));
108 }
109 Ok(Self {
110 message_type,
111 payload: Bytes::from(payload),
112 })
113 }
114
115 pub fn decode<M: Message + Default>(&self) -> Result<M, FrameError> {
117 Ok(M::decode(self.payload.clone())?)
118 }
119
120 pub fn encode(&self) -> Bytes {
122 let mut buf = BytesMut::with_capacity(HEADER_SIZE + self.payload.len());
123 buf.put_u32(self.payload.len() as u32);
124 buf.put_u16(self.message_type as u16);
125 buf.put(self.payload.clone());
126 buf.freeze()
127 }
128
129 pub fn decode_from_bytes(mut bytes: Bytes) -> Result<Self, FrameError> {
131 if bytes.len() < HEADER_SIZE {
132 return Err(FrameError::Io(std::io::Error::new(
133 std::io::ErrorKind::UnexpectedEof,
134 "incomplete frame header",
135 )));
136 }
137
138 let length = bytes.get_u32() as usize;
139 let message_type = MessageType::try_from(bytes.get_u16())?;
140
141 if length > MAX_FRAME_SIZE {
142 return Err(FrameError::FrameTooLarge(length));
143 }
144
145 if bytes.len() < length {
146 return Err(FrameError::Io(std::io::Error::new(
147 std::io::ErrorKind::UnexpectedEof,
148 "incomplete frame payload",
149 )));
150 }
151
152 let payload = bytes.split_to(length);
153 Ok(Self {
154 message_type,
155 payload,
156 })
157 }
158}
159
160pub async fn write_frame<W: AsyncWrite + Unpin>(
162 writer: &mut W,
163 frame: &Frame,
164) -> Result<(), FrameError> {
165 let encoded = frame.encode();
166 writer.write_all(&encoded).await?;
167 Ok(())
168}
169
170pub async fn read_frame<R: AsyncRead + Unpin>(reader: &mut R) -> Result<Frame, FrameError> {
172 let mut header = [0u8; HEADER_SIZE];
174 match reader.read_exact(&mut header).await {
175 Ok(_) => {}
176 Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => {
177 return Err(FrameError::ConnectionClosed);
178 }
179 Err(e) => return Err(e.into()),
180 }
181
182 let length = u32::from_be_bytes([header[0], header[1], header[2], header[3]]) as usize;
183 let message_type = MessageType::try_from(u16::from_be_bytes([header[4], header[5]]))?;
184
185 if length > MAX_FRAME_SIZE {
186 return Err(FrameError::FrameTooLarge(length));
187 }
188
189 let mut payload = vec![0u8; length];
191 reader.read_exact(&mut payload).await?;
192
193 Ok(Frame {
194 message_type,
195 payload: Bytes::from(payload),
196 })
197}
198
/// Wraps a byte stream and exchanges length-prefixed frames over it.
pub struct FramedStream<S> {
    /// The wrapped transport. This wrapper holds no buffer of its own; frames go straight
    /// to and from it.
    stream: S,
}
203
204impl<S> FramedStream<S> {
205 pub fn new(stream: S) -> Self {
206 Self { stream }
207 }
208
209 pub fn into_inner(self) -> S {
210 self.stream
211 }
212}
213
214impl<S: AsyncRead + Unpin> FramedStream<S> {
215 pub async fn read_frame(&mut self) -> Result<Frame, FrameError> {
217 read_frame(&mut self.stream).await
218 }
219}
220
221impl<S: AsyncWrite + Unpin> FramedStream<S> {
222 pub async fn write_frame(&mut self, frame: &Frame) -> Result<(), FrameError> {
224 write_frame(&mut self.stream, frame).await
225 }
226}
227
228impl<S: AsyncRead + AsyncWrite + Unpin> FramedStream<S> {
229 pub async fn request<Req: Message, Resp: Message + Default>(
231 &mut self,
232 request: &Req,
233 ) -> Result<Resp, FrameError> {
234 let frame = Frame::request(request)?;
235 self.write_frame(&frame).await?;
236
237 let response_frame = self.read_frame().await?;
238 match response_frame.message_type {
239 MessageType::Response => response_frame.decode(),
240 MessageType::Error => {
241 Err(FrameError::Io(std::io::Error::other(
243 "received error response",
244 )))
245 }
246 _ => Err(FrameError::Io(std::io::Error::new(
247 std::io::ErrorKind::InvalidData,
248 "unexpected message type",
249 ))),
250 }
251 }
252
253 pub async fn respond<Resp: Message>(&mut self, response: &Resp) -> Result<(), FrameError> {
255 let frame = Frame::response(response)?;
256 self.write_frame(&frame).await
257 }
258}
259
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_message_type_round_trip() {
        for &mt in &[
            MessageType::Request,
            MessageType::Response,
            MessageType::StreamStart,
            MessageType::StreamData,
            MessageType::StreamEnd,
            MessageType::Error,
        ] {
            let value = mt as u16;
            let decoded = MessageType::try_from(value).unwrap();
            assert_eq!(mt, decoded);
        }
    }

    #[test]
    fn test_unknown_message_type_rejected() {
        for value in [0u16, 7, u16::MAX] {
            assert!(matches!(
                MessageType::try_from(value),
                Err(FrameError::InvalidMessageType(v)) if v == value
            ));
        }
    }

    #[test]
    fn test_encode_wire_layout() {
        let frame = Frame {
            message_type: MessageType::StreamData,
            payload: Bytes::from_static(b"abc"),
        };
        let encoded = frame.encode();
        // 4-byte big-endian length, 2-byte big-endian tag, then the payload.
        assert_eq!(&encoded[..], &[0, 0, 0, 3, 0, 4, b'a', b'b', b'c'][..]);
    }

    #[test]
    fn test_frame_encode_decode() {
        use crate::management_proto::HealthCheckRequest;

        let msg = HealthCheckRequest {};
        let frame = Frame::request(&msg).unwrap();
        let encoded = frame.encode();
        let decoded = Frame::decode_from_bytes(encoded).unwrap();

        assert_eq!(frame.message_type, decoded.message_type);
        assert_eq!(frame.payload, decoded.payload);
    }

    #[test]
    fn test_decode_from_bytes_truncated_header() {
        let err = Frame::decode_from_bytes(Bytes::from_static(&[0, 0, 0])).unwrap_err();
        assert!(matches!(
            &err,
            FrameError::Io(e) if e.kind() == std::io::ErrorKind::UnexpectedEof
        ));
    }

    #[test]
    fn test_decode_from_bytes_truncated_payload() {
        // The header announces 4 payload bytes of a Request frame, but only 2 follow.
        let bytes = Bytes::from_static(&[0, 0, 0, 4, 0, 1, 0xAA, 0xBB]);
        let err = Frame::decode_from_bytes(bytes).unwrap_err();
        assert!(matches!(
            &err,
            FrameError::Io(e) if e.kind() == std::io::ErrorKind::UnexpectedEof
        ));
    }

    #[test]
    fn test_decode_from_bytes_rejects_oversized_length() {
        let mut buf = BytesMut::new();
        buf.put_u32((MAX_FRAME_SIZE + 1) as u32);
        buf.put_u16(MessageType::Request as u16);
        let err = Frame::decode_from_bytes(buf.freeze()).unwrap_err();
        assert!(matches!(err, FrameError::FrameTooLarge(n) if n == MAX_FRAME_SIZE + 1));
    }
}