1pub mod messages;
7
8pub use messages::*;
9
/// Wire-protocol version stamped into every frame.
///
/// `encode_message` writes it as the second little-endian `u32` of each frame (right after the
/// length prefix), and `decode_message` rejects any frame carrying a different value with
/// `ProtocolError::VersionMismatch`. Changing it therefore makes this build refuse frames from
/// builds that use another value, and vice versa.
pub const PROTOCOL_VERSION: u32 = 9;
19
20use bytes::{Buf, BufMut, BytesMut};
21
22pub fn encode_message<T: serde::Serialize>(msg: &T) -> Result<BytesMut, ProtocolError> {
32 let payload =
33 bincode::serialize(msg).map_err(|e| ProtocolError::Serialization(e.to_string()))?;
34 let frame_len: u32 = (4 + payload.len())
35 .try_into()
36 .map_err(|_| ProtocolError::MessageTooLarge(payload.len()))?;
37
38 let mut buf = BytesMut::with_capacity(4 + 4 + payload.len());
39 buf.put_u32_le(frame_len);
40 buf.put_u32_le(PROTOCOL_VERSION);
41 buf.extend_from_slice(&payload);
42 Ok(buf)
43}
44
45pub fn decode_message<T: serde::de::DeserializeOwned>(
55 buf: &mut BytesMut,
56) -> Result<Option<T>, ProtocolError> {
57 if buf.len() < 4 {
58 return Ok(None);
59 }
60
61 let len = u32::from_le_bytes([buf[0], buf[1], buf[2], buf[3]]) as usize;
62
63 if len > MAX_MESSAGE_SIZE {
64 return Err(ProtocolError::MessageTooLarge(len));
65 }
66
67 if buf.len() < 4 + len {
68 return Ok(None);
69 }
70
71 if len < 4 {
72 return Err(ProtocolError::Deserialization(
73 "frame too small for protocol version".into(),
74 ));
75 }
76
77 buf.advance(4);
78 let frame = buf.split_to(len);
79
80 let remote_ver = u32::from_le_bytes([frame[0], frame[1], frame[2], frame[3]]);
81 if remote_ver != PROTOCOL_VERSION {
82 return Err(ProtocolError::VersionMismatch {
83 expected: PROTOCOL_VERSION,
84 received: remote_ver,
85 });
86 }
87
88 let msg = bincode::deserialize(&frame[4..])
89 .map_err(|e| ProtocolError::Deserialization(e.to_string()))?;
90 Ok(Some(msg))
91}
92
/// Largest value `decode_message` accepts in a frame's length prefix (version word + payload),
/// in bytes: 16 MiB. Frames announcing more are rejected with `MessageTooLarge` before any of
/// their body is buffered.
const MAX_MESSAGE_SIZE: usize = 16 << 20;
95
96#[derive(Debug, thiserror::Error)]
98pub enum ProtocolError {
99 #[error("serialization error: {0}")]
100 Serialization(String),
101
102 #[error("deserialization error: {0}")]
103 Deserialization(String),
104
105 #[error("message too large: {0} bytes")]
106 MessageTooLarge(usize),
107
108 #[error(
109 "protocol version mismatch: expected v{expected}, received v{received}. \
110 Run `zccache stop` first."
111 )]
112 VersionMismatch { expected: u32, received: u32 },
113}
114
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn encode_decode_roundtrip() {
        let msg = messages::Request::Ping;
        let encoded = encode_message(&msg).unwrap();
        let mut buf = BytesMut::from(&encoded[..]);
        let decoded: Option<messages::Request> = decode_message(&mut buf).unwrap();
        assert_eq!(decoded, Some(messages::Request::Ping));
        assert!(buf.is_empty());
    }

    #[test]
    fn frame_includes_protocol_version() {
        let encoded = encode_message(&messages::Request::Ping).unwrap();
        let ver = u32::from_le_bytes([encoded[4], encoded[5], encoded[6], encoded[7]]);
        assert_eq!(ver, PROTOCOL_VERSION);
    }

    #[test]
    fn length_prefix_counts_version_and_payload() {
        // The prefix covers everything after itself: version word + payload.
        let encoded = encode_message(&messages::Request::Ping).unwrap();
        let len = u32::from_le_bytes([encoded[0], encoded[1], encoded[2], encoded[3]]) as usize;
        assert_eq!(len, encoded.len() - 4);
    }

    #[test]
    fn version_mismatch_returns_error() {
        let mut encoded = encode_message(&messages::Request::Ping).unwrap();
        let bad_ver: u32 = PROTOCOL_VERSION + 1;
        encoded[4..8].copy_from_slice(&bad_ver.to_le_bytes());

        let mut buf = BytesMut::from(&encoded[..]);
        let result: Result<Option<messages::Request>, _> = decode_message(&mut buf);
        assert!(matches!(result, Err(ProtocolError::VersionMismatch { .. })));
    }

    #[test]
    fn old_frame_without_protocol_version_fails() {
        let payload = bincode::serialize(&messages::Request::Ping).unwrap();
        let len = payload.len() as u32;
        let mut buf = BytesMut::with_capacity(4 + payload.len());
        buf.put_u32_le(len);
        buf.extend_from_slice(&payload);

        let result: Result<Option<messages::Request>, _> = decode_message(&mut buf);
        assert!(
            result.is_err(),
            "old-format frame must not decode successfully"
        );
    }

    #[test]
    fn incomplete_frame_returns_none() {
        let encoded = encode_message(&messages::Request::Ping).unwrap();
        let mut buf = BytesMut::from(&encoded[..encoded.len() - 1]);
        let result: Option<messages::Request> = decode_message(&mut buf).unwrap();
        assert!(result.is_none());
    }

    #[test]
    fn empty_buffer_returns_none() {
        let mut buf = BytesMut::new();
        let result: Option<messages::Request> = decode_message(&mut buf).unwrap();
        assert!(result.is_none());
    }

    #[test]
    fn oversized_length_prefix_is_rejected() {
        // Only the 4-byte prefix is present: the limit must be enforced before the body arrives.
        let too_big = MAX_MESSAGE_SIZE + 1;
        let mut buf = BytesMut::new();
        buf.put_u32_le(too_big.try_into().expect("MAX_MESSAGE_SIZE + 1 fits in a u32"));

        let result: Result<Option<messages::Request>, _> = decode_message(&mut buf);
        assert!(matches!(result, Err(ProtocolError::MessageTooLarge(n)) if n == too_big));
    }

    #[test]
    fn frame_shorter_than_version_field_is_rejected() {
        // Prefix announces 2 bytes, which cannot hold the 4-byte version word.
        let mut buf = BytesMut::new();
        buf.put_u32_le(2);
        buf.extend_from_slice(&[0, 0]);

        let result: Result<Option<messages::Request>, _> = decode_message(&mut buf);
        assert!(matches!(result, Err(ProtocolError::Deserialization(_))));
    }

    #[test]
    fn back_to_back_frames_decode_in_order() {
        let frame = encode_message(&messages::Request::Ping).unwrap();
        let mut buf = BytesMut::new();
        buf.extend_from_slice(&frame);
        buf.extend_from_slice(&frame);

        for _ in 0..2 {
            let decoded: Option<messages::Request> = decode_message(&mut buf).unwrap();
            assert_eq!(decoded, Some(messages::Request::Ping));
        }
        assert!(buf.is_empty());

        let decoded: Option<messages::Request> = decode_message(&mut buf).unwrap();
        assert!(decoded.is_none());
    }
}