1use bytes::{Buf, BufMut, Bytes, BytesMut};
4use thiserror::Error;
5
6#[derive(Error, Debug)]
8pub enum ProtocolError {
9 #[error("Invalid message format: {0}")]
11 InvalidFormat(String),
12
13 #[error("Invalid opcode: {0}")]
15 InvalidOpCode(u8),
16
17 #[error("Invalid status code: {0}")]
19 InvalidStatusCode(u8),
20
21 #[error("Message too large: {0} bytes")]
23 MessageTooLarge(usize),
24
25 #[error("Serialization error: {0}")]
27 Serialization(String),
28}
29
/// Operation selector carried in the first byte of every command frame.
///
/// `#[repr(u8)]` pins the discriminant to exactly the one wire byte, so
/// `op as u8` is guaranteed lossless and a discriminant above `0xFF` is a
/// compile error rather than a silent truncation.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[repr(u8)]
pub enum OpCode {
    /// Liveness probe; [`Command::ping`] sends it with an empty key and value.
    Ping = 0x01,
    /// Store a value; [`Command::set`] sends key and value.
    Set = 0x02,
    /// Look up a key; [`Command::get`] sends the key with an empty value.
    Get = 0x03,
    /// Remove a key; [`Command::delete`] sends the key with an empty value.
    Delete = 0x04,
    /// Compare-and-set; [`Command::cas`] puts the expected version in `extra`.
    Cas = 0x05,
    /// Subscription request (no constructor in this module).
    Subscribe = 0x06,
    /// Cancel a subscription (no constructor in this module).
    Unsubscribe = 0x07,
    /// Publish request (no constructor in this module).
    Publish = 0x08,
    /// [`Command::fetch`] sends the key with an empty value.
    Fetch = 0x09,
    /// Info request (no constructor in this module).
    Info = 0x0A,
}
54
55impl TryFrom<u8> for OpCode {
56 type Error = ProtocolError;
57
58 fn try_from(value: u8) -> Result<Self, Self::Error> {
59 match value {
60 0x01 => Ok(OpCode::Ping),
61 0x02 => Ok(OpCode::Set),
62 0x03 => Ok(OpCode::Get),
63 0x04 => Ok(OpCode::Delete),
64 0x05 => Ok(OpCode::Cas),
65 0x06 => Ok(OpCode::Subscribe),
66 0x07 => Ok(OpCode::Unsubscribe),
67 0x08 => Ok(OpCode::Publish),
68 0x09 => Ok(OpCode::Fetch),
69 0x0A => Ok(OpCode::Info),
70 _ => Err(ProtocolError::InvalidOpCode(value)),
71 }
72 }
73}
74
/// Result code carried in the first byte of every response frame.
///
/// `#[repr(u8)]` pins the discriminant to exactly the one wire byte, so
/// `status as u8` is guaranteed lossless.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[repr(u8)]
pub enum StatusCode {
    /// Success; built by [`Response::ok`].
    Ok = 0x00,
    /// Key absent; built by [`Response::not_found`].
    NotFound = 0x01,
    /// Version conflict — presumably the CAS failure case; confirm with the server.
    VersionMismatch = 0x02,
    /// Request arguments were rejected.
    InvalidArgs = 0x03,
    /// Server-side failure; built by [`Response::error`] and used as the
    /// fallback by [`Response::status`] for unknown bytes.
    InternalError = 0x04,
    /// Caller is not authorized.
    Unauthorized = 0x05,
    /// Operation not supported.
    NotSupported = 0x06,
}
93
94impl TryFrom<u8> for StatusCode {
95 type Error = ProtocolError;
96
97 fn try_from(value: u8) -> Result<Self, Self::Error> {
98 match value {
99 0x00 => Ok(StatusCode::Ok),
100 0x01 => Ok(StatusCode::NotFound),
101 0x02 => Ok(StatusCode::VersionMismatch),
102 0x03 => Ok(StatusCode::InvalidArgs),
103 0x04 => Ok(StatusCode::InternalError),
104 0x05 => Ok(StatusCode::Unauthorized),
105 0x06 => Ok(StatusCode::NotSupported),
106 _ => Err(ProtocolError::InvalidStatusCode(value)),
107 }
108 }
109}
110
/// Fixed-size header at the front of every command frame.
///
/// Field order and widths match the 24-byte encoding written by
/// [`Command::to_bytes`] (1 + 1 + 2 + 4 + 4 + 4 + 8). With `#[repr(C)]` this
/// ordering leaves no padding, but serialization is done explicitly field by
/// field in little-endian, so the in-memory representation is not relied on.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(C)]
pub struct CommandHeader {
    /// Raw [`OpCode`] byte.
    pub opcode: u8,
    /// Flag bits, OR-ed in by [`CommandHeader::with_flag`].
    pub flags: u8,
    /// Reserved; zero when built by [`CommandHeader::new`].
    pub reserved: u16,
    /// Caller-chosen sequence number — presumably echoed in the matching
    /// response's `seq`; confirm with the server.
    pub seq: u32,
    /// Number of key bytes following the header.
    pub key_len: u32,
    /// Number of value bytes following the key.
    pub value_len: u32,
    /// Opcode-specific extra word (e.g. the expected version for `Cas`).
    pub extra: u64,
}
130
131impl CommandHeader {
132 pub fn new(opcode: OpCode, seq: u32) -> Self {
134 Self {
135 opcode: opcode as u8,
136 flags: 0,
137 reserved: 0,
138 seq,
139 key_len: 0,
140 value_len: 0,
141 extra: 0,
142 }
143 }
144
145 pub fn with_lengths(mut self, key_len: u32, value_len: u32) -> Self {
147 self.key_len = key_len;
148 self.value_len = value_len;
149 self
150 }
151
152 pub fn with_extra(mut self, extra: u64) -> Self {
154 self.extra = extra;
155 self
156 }
157
158 pub fn with_flag(mut self, flag: u8) -> Self {
160 self.flags |= flag;
161 self
162 }
163}
164
165#[derive(Debug, Clone)]
167pub struct Command {
168 pub header: CommandHeader,
170 pub key: Bytes,
172 pub value: Bytes,
174}
175
176impl Command {
177 pub fn new(header: CommandHeader, key: impl Into<Bytes>, value: impl Into<Bytes>) -> Self {
179 let key = key.into();
180 let value = value.into();
181 Self {
182 header: header.with_lengths(key.len() as u32, value.len() as u32),
183 key,
184 value,
185 }
186 }
187
188 pub fn ping(seq: u32) -> Self {
190 Self::new(
191 CommandHeader::new(OpCode::Ping, seq),
192 Bytes::new(),
193 Bytes::new(),
194 )
195 }
196
197 pub fn set<K, V>(seq: u32, key: K, value: V) -> Self
199 where
200 K: Into<Bytes>,
201 V: Into<Bytes>,
202 {
203 Self::new(CommandHeader::new(OpCode::Set, seq), key, value)
204 }
205
206 pub fn get<K>(seq: u32, key: K) -> Self
208 where
209 K: Into<Bytes>,
210 {
211 Self::new(CommandHeader::new(OpCode::Get, seq), key, Bytes::new())
212 }
213
214 pub fn delete<K>(seq: u32, key: K) -> Self
216 where
217 K: Into<Bytes>,
218 {
219 Self::new(CommandHeader::new(OpCode::Delete, seq), key, Bytes::new())
220 }
221
222 pub fn cas<K, V>(seq: u32, key: K, expected_version: u64, value: V) -> Self
224 where
225 K: Into<Bytes>,
226 V: Into<Bytes>,
227 {
228 Self::new(
229 CommandHeader::new(OpCode::Cas, seq).with_extra(expected_version),
230 key,
231 value,
232 )
233 }
234
235 pub fn fetch(seq: u32, key: impl Into<Bytes>) -> Self {
237 Self::new(CommandHeader::new(OpCode::Fetch, seq), key, Bytes::new())
238 }
239
240 pub fn to_bytes(&self) -> Bytes {
242 let mut buf = BytesMut::with_capacity(24 + self.key.len() + self.value.len());
243
244 buf.put_u8(self.header.opcode);
246 buf.put_u8(self.header.flags);
247 buf.put_u16_le(self.header.reserved);
248 buf.put_u32_le(self.header.seq);
249 buf.put_u32_le(self.header.key_len);
250 buf.put_u32_le(self.header.value_len);
251 buf.put_u64_le(self.header.extra);
252
253 buf.extend_from_slice(&self.key);
255 buf.extend_from_slice(&self.value);
256
257 buf.freeze()
258 }
259}
260
/// Fixed-size header at the front of every response frame.
///
/// [`Response::from_bytes`] decodes it from 20 wire bytes (1 + 1 + 2 + 4 + 4 + 8,
/// multi-byte fields little-endian, in declaration order). Under `#[repr(C)]`,
/// 4 bytes of padding precede `extra`, so the in-memory size is 24 bytes and
/// does not match the wire encoding byte-for-byte.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(C)]
pub struct ResponseHeader {
    /// Raw [`StatusCode`] byte.
    pub status: u8,
    /// Flag bits; zero when built by [`ResponseHeader::new`].
    pub flags: u8,
    /// Reserved; zero when built by [`ResponseHeader::new`].
    pub reserved: u16,
    /// Sequence number — presumably the `seq` of the command being answered;
    /// confirm with the server.
    pub seq: u32,
    /// Number of payload bytes following the header.
    pub payload_len: u32,
    /// Extra word; its meaning is not defined in this module.
    pub extra: u64,
}
278
279impl ResponseHeader {
280 pub fn new(status: StatusCode, seq: u32) -> Self {
282 Self {
283 status: status as u8,
284 flags: 0,
285 reserved: 0,
286 seq,
287 payload_len: 0,
288 extra: 0,
289 }
290 }
291
292 pub fn with_payload_len(mut self, len: u32) -> Self {
294 self.payload_len = len;
295 self
296 }
297}
298
299#[derive(Debug, Clone)]
301pub struct Response {
302 pub header: ResponseHeader,
304 pub payload: Bytes,
306}
307
308impl Response {
309 pub fn new(header: ResponseHeader, payload: impl Into<Bytes>) -> Self {
311 let payload = payload.into();
312 Self {
313 header: header.with_payload_len(payload.len() as u32),
314 payload,
315 }
316 }
317
318 pub fn ok(seq: u32, payload: impl Into<Bytes>) -> Self {
320 Self::new(ResponseHeader::new(StatusCode::Ok, seq), payload)
321 }
322
323 pub fn not_found(seq: u32) -> Self {
325 Self::new(ResponseHeader::new(StatusCode::NotFound, seq), Bytes::new())
326 }
327
328 pub fn error(seq: u32) -> Self {
330 Self::new(
331 ResponseHeader::new(StatusCode::InternalError, seq),
332 Bytes::new(),
333 )
334 }
335
336 pub fn from_bytes(mut bytes: &[u8]) -> Result<Self, ProtocolError> {
338 if bytes.len() < 20 {
339 return Err(ProtocolError::InvalidFormat("response too short".into()));
340 }
341
342 let status = StatusCode::try_from(bytes.get_u8())?;
344 let flags = bytes.get_u8();
345 let reserved = bytes.get_u16_le();
346 let seq = bytes.get_u32_le();
347 let payload_len = bytes.get_u32_le() as usize;
348 let extra = bytes.get_u64_le();
349
350 if bytes.remaining() < payload_len {
352 return Err(ProtocolError::InvalidFormat(
353 "invalid payload length".into(),
354 ));
355 }
356
357 let payload = bytes.copy_to_bytes(payload_len);
359
360 Ok(Self {
361 header: ResponseHeader {
362 status: status as u8,
363 flags,
364 reserved,
365 seq,
366 payload_len: payload_len as u32,
367 extra,
368 },
369 payload,
370 })
371 }
372
373 pub fn is_ok(&self) -> bool {
375 matches!(StatusCode::try_from(self.header.status), Ok(StatusCode::Ok))
376 }
377
378 pub fn status(&self) -> StatusCode {
380 StatusCode::try_from(self.header.status).unwrap_or(StatusCode::InternalError)
381 }
382}
383
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a raw response frame with the given header fields followed by
    /// `payload` (flags and reserved are zero).
    fn response_frame(status: u8, seq: u32, payload_len: u32, extra: u64, payload: &[u8]) -> BytesMut {
        let mut buf = BytesMut::with_capacity(20 + payload.len());
        buf.put_u8(status);
        buf.put_u8(0); // flags
        buf.put_u16_le(0); // reserved
        buf.put_u32_le(seq);
        buf.put_u32_le(payload_len);
        buf.put_u64_le(extra);
        buf.extend_from_slice(payload);
        buf
    }

    #[test]
    fn test_command_serialization() {
        let cmd = Command::set(1, "key", "value");
        let bytes = cmd.to_bytes();

        assert_eq!(bytes[0], OpCode::Set as u8);
        assert_eq!(&bytes[24..27], b"key");
        assert_eq!(&bytes[27..32], b"value");
    }

    #[test]
    fn test_command_header_layout() {
        let bytes = Command::cas(7, "k", 9, "vv").to_bytes();

        assert_eq!(bytes.len(), 24 + 1 + 2);
        assert_eq!(bytes[0], OpCode::Cas as u8);
        assert_eq!(&bytes[4..8], &7u32.to_le_bytes());
        assert_eq!(&bytes[8..12], &1u32.to_le_bytes());
        assert_eq!(&bytes[12..16], &2u32.to_le_bytes());
        assert_eq!(&bytes[16..24], &9u64.to_le_bytes());
    }

    #[test]
    fn test_response_deserialization() {
        let buf = response_frame(StatusCode::Ok as u8, 42, 5, 0, b"hello");
        let resp = Response::from_bytes(&buf).unwrap();

        assert!(resp.is_ok());
        assert_eq!(resp.header.seq, 42);
        assert_eq!(resp.header.payload_len, 5);
        assert_eq!(&resp.payload[..], b"hello");
    }

    #[test]
    fn test_response_too_short() {
        assert!(matches!(
            Response::from_bytes(&[0u8; 19]),
            Err(ProtocolError::InvalidFormat(_))
        ));
    }

    #[test]
    fn test_response_invalid_status() {
        let buf = response_frame(0xFF, 1, 0, 0, b"");
        assert!(matches!(
            Response::from_bytes(&buf),
            Err(ProtocolError::InvalidStatusCode(0xFF))
        ));
    }

    #[test]
    fn test_response_truncated_payload() {
        let buf = response_frame(StatusCode::Ok as u8, 1, 5, 0, b"abc");
        assert!(matches!(
            Response::from_bytes(&buf),
            Err(ProtocolError::InvalidFormat(_))
        ));
    }

    #[test]
    fn test_opcode_round_trip() {
        for byte in 0x01..=0x0Au8 {
            assert_eq!(OpCode::try_from(byte).unwrap() as u8, byte);
        }
        assert!(matches!(OpCode::try_from(0x00), Err(ProtocolError::InvalidOpCode(0x00))));
        assert!(matches!(OpCode::try_from(0x0B), Err(ProtocolError::InvalidOpCode(0x0B))));
    }

    #[test]
    fn test_status_code_round_trip() {
        for byte in 0x00..=0x06u8 {
            assert_eq!(StatusCode::try_from(byte).unwrap() as u8, byte);
        }
        assert!(matches!(
            StatusCode::try_from(0x07),
            Err(ProtocolError::InvalidStatusCode(0x07))
        ));
    }
}