1pub mod messages;
7
8pub use messages::*;
9
/// Wire-protocol version stamped into every frame by [`encode_message`].
///
/// [`decode_message`] refuses any frame whose version field differs from this
/// value, returning [`ProtocolError::VersionMismatch`].
/// NOTE(review): presumably bumped whenever the frame layout or message
/// encoding changes incompatibly — confirm against release process.
pub const PROTOCOL_VERSION: u32 = 5;
14
15use bytes::{Buf, BufMut, BytesMut};
16
17pub fn encode_message<T: serde::Serialize>(msg: &T) -> Result<BytesMut, ProtocolError> {
27 let payload =
28 bincode::serialize(msg).map_err(|e| ProtocolError::Serialization(e.to_string()))?;
29 let frame_len: u32 = (4 + payload.len())
30 .try_into()
31 .map_err(|_| ProtocolError::MessageTooLarge(payload.len()))?;
32
33 let mut buf = BytesMut::with_capacity(4 + 4 + payload.len());
34 buf.put_u32_le(frame_len);
35 buf.put_u32_le(PROTOCOL_VERSION);
36 buf.extend_from_slice(&payload);
37 Ok(buf)
38}
39
40pub fn decode_message<T: serde::de::DeserializeOwned>(
50 buf: &mut BytesMut,
51) -> Result<Option<T>, ProtocolError> {
52 if buf.len() < 4 {
53 return Ok(None);
54 }
55
56 let len = u32::from_le_bytes([buf[0], buf[1], buf[2], buf[3]]) as usize;
57
58 if len > MAX_MESSAGE_SIZE {
59 return Err(ProtocolError::MessageTooLarge(len));
60 }
61
62 if buf.len() < 4 + len {
63 return Ok(None);
64 }
65
66 if len < 4 {
67 return Err(ProtocolError::Deserialization(
68 "frame too small for protocol version".into(),
69 ));
70 }
71
72 buf.advance(4);
73 let frame = buf.split_to(len);
74
75 let remote_ver = u32::from_le_bytes([frame[0], frame[1], frame[2], frame[3]]);
76 if remote_ver != PROTOCOL_VERSION {
77 return Err(ProtocolError::VersionMismatch {
78 expected: PROTOCOL_VERSION,
79 received: remote_ver,
80 });
81 }
82
83 let msg = bincode::deserialize(&frame[4..])
84 .map_err(|e| ProtocolError::Deserialization(e.to_string()))?;
85 Ok(Some(msg))
86}
87
/// Largest frame-length prefix `decode_message` will accept (16 MiB).
///
/// Larger declared lengths are rejected with `ProtocolError::MessageTooLarge`
/// before any of the frame body is buffered.
const MAX_MESSAGE_SIZE: usize = 16 << 20; // 16 MiB
90
91#[derive(Debug, thiserror::Error)]
93pub enum ProtocolError {
94 #[error("serialization error: {0}")]
95 Serialization(String),
96
97 #[error("deserialization error: {0}")]
98 Deserialization(String),
99
100 #[error("message too large: {0} bytes")]
101 MessageTooLarge(usize),
102
103 #[error(
104 "protocol version mismatch: expected v{expected}, received v{received}. \
105 Run `zccache stop` first."
106 )]
107 VersionMismatch { expected: u32, received: u32 },
108}
109
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn encode_decode_roundtrip() {
        let msg = messages::Request::Ping;
        let encoded = encode_message(&msg).unwrap();
        let mut buf = BytesMut::from(&encoded[..]);
        let decoded: Option<messages::Request> = decode_message(&mut buf).unwrap();
        assert_eq!(decoded, Some(messages::Request::Ping));
        assert!(buf.is_empty());
    }

    #[test]
    fn frame_includes_protocol_version() {
        let encoded = encode_message(&messages::Request::Ping).unwrap();
        // Bytes 0..4 are the length prefix; the version word follows immediately.
        let ver = u32::from_le_bytes([encoded[4], encoded[5], encoded[6], encoded[7]]);
        assert_eq!(ver, PROTOCOL_VERSION);
    }

    #[test]
    fn length_prefix_counts_version_and_payload() {
        let encoded = encode_message(&messages::Request::Ping).unwrap();
        let len = u32::from_le_bytes([encoded[0], encoded[1], encoded[2], encoded[3]]) as usize;
        assert_eq!(len, encoded.len() - 4);
    }

    #[test]
    fn version_mismatch_returns_error() {
        let mut encoded = encode_message(&messages::Request::Ping).unwrap();
        let bad_ver: u32 = PROTOCOL_VERSION + 1;
        encoded[4..8].copy_from_slice(&bad_ver.to_le_bytes());

        let mut buf = BytesMut::from(&encoded[..]);
        let result: Result<Option<messages::Request>, _> = decode_message(&mut buf);
        assert!(matches!(
            result,
            Err(ProtocolError::VersionMismatch { expected, received })
                if expected == PROTOCOL_VERSION && received == bad_ver
        ));
    }

    #[test]
    fn old_frame_without_protocol_version_fails() {
        // Pre-versioning layout: [len][payload] with no version word.
        let payload = bincode::serialize(&messages::Request::Ping).unwrap();
        let len = payload.len() as u32;
        let mut buf = BytesMut::with_capacity(4 + payload.len());
        buf.put_u32_le(len);
        buf.extend_from_slice(&payload);

        let result: Result<Option<messages::Request>, _> = decode_message(&mut buf);
        assert!(
            result.is_err(),
            "old-format frame must not decode successfully"
        );
    }

    #[test]
    fn incomplete_frame_returns_none() {
        let encoded = encode_message(&messages::Request::Ping).unwrap();
        let mut buf = BytesMut::from(&encoded[..encoded.len() - 1]);
        let result: Option<messages::Request> = decode_message(&mut buf).unwrap();
        assert!(result.is_none());
        // A partial frame must stay buffered so more bytes can be appended.
        assert_eq!(buf.len(), encoded.len() - 1);
    }

    #[test]
    fn empty_or_short_header_returns_none() {
        for prefix_len in 0..4 {
            let mut buf = BytesMut::from(&[0u8; 4][..prefix_len]);
            let result: Option<messages::Request> = decode_message(&mut buf).unwrap();
            assert!(result.is_none());
            assert_eq!(buf.len(), prefix_len);
        }
    }

    #[test]
    fn oversized_length_prefix_is_rejected_without_waiting_for_body() {
        let too_big = MAX_MESSAGE_SIZE + 1;
        let prefix: u32 = too_big.try_into().unwrap();
        let mut buf = BytesMut::new();
        buf.put_u32_le(prefix);

        let result: Result<Option<messages::Request>, _> = decode_message(&mut buf);
        assert!(matches!(result, Err(ProtocolError::MessageTooLarge(n)) if n == too_big));
    }

    #[test]
    fn frame_shorter_than_version_field_is_rejected() {
        let mut buf = BytesMut::new();
        buf.put_u32_le(2);
        buf.extend_from_slice(&[0, 0]);

        let result: Result<Option<messages::Request>, _> = decode_message(&mut buf);
        assert!(matches!(result, Err(ProtocolError::Deserialization(_))));
    }

    #[test]
    fn consecutive_frames_decode_in_order() {
        let mut buf = encode_message(&messages::Request::Ping).unwrap();
        buf.extend_from_slice(&encode_message(&messages::Request::Ping).unwrap());

        for _ in 0..2 {
            let decoded: Option<messages::Request> = decode_message(&mut buf).unwrap();
            assert_eq!(decoded, Some(messages::Request::Ping));
        }
        assert!(buf.is_empty());
    }
}