1use {
2 bytes::{Buf, BufMut, Bytes, BytesMut},
3 tokio_util::codec::{Decoder, Encoder},
4};
5
6use crate::{
7 compression::Compressor,
8 encryption::DesEncryptor,
9 error::{Result, ZusError},
10 protocol::{RpcMessage, RpcProtocolHeader},
11};
12
13pub struct RpcCodec {
15 max_frame_length: usize,
16 pub compressor: Compressor,
17 encryptor: Option<DesEncryptor>,
18}
19
20impl RpcCodec {
21 pub fn new() -> Self {
22 Self {
23 max_frame_length: 10 * 1024 * 1024, compressor: Compressor::new(), encryptor: None,
26 }
27 }
28
29 pub fn with_max_frame_length(max_frame_length: usize) -> Self {
30 Self {
31 max_frame_length,
32 compressor: Compressor::new(),
33 encryptor: None,
34 }
35 }
36
37 pub fn with_compressor(compressor: Compressor) -> Self {
38 Self {
39 max_frame_length: 10 * 1024 * 1024,
40 compressor,
41 encryptor: None,
42 }
43 }
44
45 pub fn with_config(max_frame_length: usize, compressor: Compressor) -> Self {
46 Self {
47 max_frame_length,
48 compressor,
49 encryptor: None,
50 }
51 }
52
53 pub fn with_encryption(key: &[u8]) -> Result<Self> {
58 let encryptor = DesEncryptor::try_new(key)?;
59 Ok(Self {
60 max_frame_length: 10 * 1024 * 1024,
61 compressor: Compressor::new(),
62 encryptor: Some(encryptor),
63 })
64 }
65
66 pub fn with_full_config(
73 max_frame_length: usize,
74 compressor: Compressor,
75 encryption_key: Option<&[u8]>,
76 ) -> Result<Self> {
77 let encryptor = match encryption_key {
78 | Some(key) => Some(DesEncryptor::try_new(key)?),
79 | None => None,
80 };
81 Ok(Self {
82 max_frame_length,
83 compressor,
84 encryptor,
85 })
86 }
87
88 pub fn set_encryption_key(&mut self, key: Option<&[u8]>) -> Result<()> {
90 self.encryptor = match key {
91 | Some(k) => Some(DesEncryptor::try_new(k)?),
92 | None => None,
93 };
94 Ok(())
95 }
96
97 pub fn is_encryption_enabled(&self) -> bool {
99 self.encryptor.is_some()
100 }
101}
102
103impl Default for RpcCodec {
104 fn default() -> Self {
105 Self::new()
106 }
107}
108
109impl Decoder for RpcCodec {
110 type Error = ZusError;
111 type Item = RpcMessage;
112
113 fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>> {
114 if src.len() < RpcProtocolHeader::HEADER_SIZE {
116 return Ok(None);
117 }
118
119 let mut peek_buf = src.clone();
121 let header = RpcProtocolHeader::decode(&mut peek_buf)?;
122
123 let total_length = RpcProtocolHeader::HEADER_SIZE + header.body_length as usize;
125 if src.len() < total_length {
126 src.reserve(total_length - src.len());
128 return Ok(None);
129 }
130
131 if total_length > self.max_frame_length {
133 return Err(ZusError::Protocol(format!(
134 "Frame too large: {} > {}",
135 total_length, self.max_frame_length
136 )));
137 }
138
139 let header = RpcProtocolHeader::decode(src)?;
141
142 let raw_body = src.split_to(header.body_length as usize).freeze();
144
145 if !header.verify_datacrc(&raw_body) {
148 return Err(ZusError::CrcMismatch);
149 }
150
151 let decrypted_body = if header.is_encrypted() {
154 match &self.encryptor {
155 | Some(enc) => {
156 let decrypted = enc.decrypt(&raw_body)?;
157 Bytes::from(decrypted)
158 }
159 | None => {
160 return Err(ZusError::Encryption(
161 "Received encrypted message but no encryption key configured".to_string(),
162 ));
163 }
164 }
165 } else {
166 raw_body
167 };
168
169 let (method, body) = if header.msg_type == zus_proto::constants::MSG_TYPE_REQ
173 || header.msg_type == zus_proto::constants::MSG_TYPE_NOTIFY
174 {
175 let mut body_buf = decrypted_body;
179
180 if body_buf.len() < 4 {
182 return Err(ZusError::Protocol("Invalid method name length".to_string()));
183 }
184 let method_len = body_buf.get_u32() as usize;
185 if body_buf.len() < method_len {
186 return Err(ZusError::Protocol("Invalid method name".to_string()));
187 }
188 let method = body_buf.split_to(method_len);
189
190 if body_buf.len() < 4 {
192 return Err(ZusError::Protocol("Invalid params length".to_string()));
193 }
194 let params_len = body_buf.get_u32() as usize;
195 if body_buf.len() < params_len {
196 return Err(ZusError::Protocol("Invalid params data".to_string()));
197 }
198
199 let raw_params = body_buf.split_to(params_len);
201
202 let params = if header.is_compressed() {
204 self.compressor.decompress(&raw_params)?
205 } else {
206 raw_params
207 };
208
209 (method, params)
210 } else {
211 let body_data = if header.is_encrypted() {
213 let mut body_buf = decrypted_body;
215 if body_buf.len() < 4 {
216 return Err(ZusError::Protocol(
217 "Invalid encrypted response: missing length prefix".to_string(),
218 ));
219 }
220 let original_len = body_buf.get_u32() as usize;
221 if body_buf.len() < original_len {
222 return Err(ZusError::Protocol(format!(
223 "Invalid encrypted response: expected {} bytes, got {}",
224 original_len,
225 body_buf.len()
226 )));
227 }
228 body_buf.split_to(original_len)
229 } else {
230 decrypted_body
231 };
232
233 let body = if header.is_compressed() {
235 self.compressor.decompress(&body_data)?
236 } else {
237 body_data
238 };
239 (Bytes::new(), body)
240 };
241
242 Ok(Some(RpcMessage { header, method, body }))
243 }
244}
245
246impl Encoder<RpcMessage> for RpcCodec {
247 type Error = ZusError;
248
249 fn encode(&mut self, mut item: RpcMessage, dst: &mut BytesMut) -> Result<()> {
250 let is_request = !item.method.is_empty();
260 let mut full_body = BytesMut::new();
261 let was_compressed: bool;
262
263 if is_request {
264 let (compressed_params, params_compressed) = self.compressor.compress(&item.body)?;
267 was_compressed = params_compressed;
268
269 full_body.put_u32(item.method.len() as u32);
271 full_body.put(item.method);
272 full_body.put_u32(compressed_params.len() as u32); full_body.put(compressed_params);
274 } else {
275 let (compressed_body, body_compressed) = self.compressor.compress(&item.body)?;
277 was_compressed = body_compressed;
278 full_body.put(compressed_body);
279 }
280
281 let final_body = if let Some(ref enc) = self.encryptor {
284 let data_to_encrypt = if is_request {
285 full_body.freeze()
287 } else {
288 let mut with_len = BytesMut::with_capacity(4 + full_body.len());
290 with_len.put_u32(full_body.len() as u32);
291 with_len.put(full_body);
292 with_len.freeze()
293 };
294 let encrypted = enc.encrypt(&data_to_encrypt)?;
295 item.header.set_encrypted(true);
296 Bytes::from(encrypted)
297 } else {
298 item.header.set_encrypted(false);
299 full_body.freeze()
300 };
301
302 item.header.body_length = final_body.len() as u32;
304 item.header.set_compressed(was_compressed);
305
306 item.header.datacrc = RpcProtocolHeader::calculate_datacrc(&final_body);
309
310 let total_length = RpcProtocolHeader::HEADER_SIZE + final_body.len();
312 dst.reserve(total_length);
313
314 item.header.encode(dst);
316
317 dst.put(final_body);
319
320 Ok(())
321 }
322}
323
#[cfg(test)]
mod tests {
    use super::*;

    /// A plain request survives an encode/decode round trip unchanged.
    #[test]
    fn test_codec_roundtrip() {
        let mut codec = RpcCodec::new();

        let method = Bytes::from("test.method");
        let body = Bytes::from("hello");
        let msg = RpcMessage::new_request(1, method.clone(), body.clone());

        let mut buf = BytesMut::new();
        codec.encode(msg, &mut buf).unwrap();

        let decoded = codec.decode(&mut buf).unwrap().unwrap();
        assert_eq!(decoded.header.sequence, 1);
        assert_eq!(decoded.method, method);
        assert_eq!(decoded.body, body);
    }

    /// A frame that is missing its last byte yields `None`, and decodes once
    /// the missing byte arrives.
    #[test]
    fn test_codec_partial_frame_waits_for_more_data() {
        let mut codec = RpcCodec::new();

        let method = Bytes::from("test.method");
        let body = Bytes::from("hello");
        let msg = RpcMessage::new_request(7, method.clone(), body.clone());

        let mut buf = BytesMut::new();
        codec.encode(msg, &mut buf).unwrap();

        let tail = buf.split_off(buf.len() - 1);
        assert!(codec.decode(&mut buf).unwrap().is_none());

        buf.extend_from_slice(&tail);
        let decoded = codec.decode(&mut buf).unwrap().unwrap();
        assert_eq!(decoded.header.sequence, 7);
        assert_eq!(decoded.method, method);
        assert_eq!(decoded.body, body);
    }

    /// A frame larger than the decoder's limit is rejected with a protocol
    /// error.
    #[test]
    fn test_codec_rejects_oversized_frame() {
        let mut encoder = RpcCodec::new();
        let msg = RpcMessage::new_request(1, Bytes::from("test.method"), Bytes::from("hello"));

        let mut buf = BytesMut::new();
        encoder.encode(msg, &mut buf).unwrap();

        // A one-byte limit is smaller than any frame that includes a header.
        let mut decoder = RpcCodec::with_max_frame_length(1);
        let result = decoder.decode(&mut buf);
        assert!(matches!(result, Err(ZusError::Protocol(_))));
    }

    /// An encrypted request is flagged as such on the wire and round-trips.
    #[test]
    fn test_codec_with_encryption_roundtrip() {
        let key = b"12345678";
        let mut codec = RpcCodec::with_encryption(key).unwrap();

        let method = Bytes::from("test.method");
        let body = Bytes::from("hello encrypted world");
        let msg = RpcMessage::new_request(1, method.clone(), body.clone());

        let mut buf = BytesMut::new();
        codec.encode(msg, &mut buf).unwrap();

        // Inspect the header on a copy so the real buffer stays intact.
        let mut peek_buf = buf.clone();
        let header = RpcProtocolHeader::decode(&mut peek_buf).unwrap();
        assert!(header.is_encrypted());

        let decoded = codec.decode(&mut buf).unwrap().unwrap();
        assert_eq!(decoded.header.sequence, 1);
        assert_eq!(decoded.method, method);
        assert_eq!(decoded.body, body);
    }

    /// An encrypted response (no method) round-trips with its exact body.
    #[test]
    fn test_codec_encrypted_response() {
        let key = b"testkey!";
        let mut codec = RpcCodec::with_encryption(key).unwrap();

        let body = Bytes::from("response data");
        let msg = RpcMessage::new_response(42, body.clone());

        let mut buf = BytesMut::new();
        codec.encode(msg, &mut buf).unwrap();

        let decoded = codec.decode(&mut buf).unwrap().unwrap();
        assert_eq!(decoded.header.sequence, 42);
        assert!(decoded.method.is_empty());
        assert_eq!(decoded.body, body);
    }

    /// Decoding with a different key than the one used to encode must fail.
    #[test]
    fn test_codec_encryption_key_mismatch() {
        let key1 = b"12345678";
        let key2 = b"87654321";

        let mut encoder = RpcCodec::with_encryption(key1).unwrap();
        let mut decoder = RpcCodec::with_encryption(key2).unwrap();

        let method = Bytes::from("test.method");
        let body = Bytes::from("hello");
        let msg = RpcMessage::new_request(1, method, body);

        let mut buf = BytesMut::new();
        encoder.encode(msg, &mut buf).unwrap();

        // The full frame is buffered, so `Ok(None)` is not a possible
        // outcome here; the only acceptable result is an error.
        let result = decoder.decode(&mut buf);
        assert!(result.is_err());
    }

    /// An encrypted frame fed to a codec without a key reports an
    /// `Encryption` error.
    #[test]
    fn test_codec_encrypted_without_key_configured() {
        let key = b"12345678";
        let mut encoder = RpcCodec::with_encryption(key).unwrap();

        let method = Bytes::from("test.method");
        let body = Bytes::from("hello");
        let msg = RpcMessage::new_request(1, method, body);

        let mut buf = BytesMut::new();
        encoder.encode(msg, &mut buf).unwrap();

        let mut decoder = RpcCodec::new();
        match decoder.decode(&mut buf) {
            Err(ZusError::Encryption(msg)) => {
                assert!(msg.contains("no encryption key configured"));
            }
            _ => panic!("Expected Encryption error"),
        }
    }

    /// Setting and clearing the key toggles `is_encryption_enabled`.
    #[test]
    fn test_codec_set_encryption_key() {
        let mut codec = RpcCodec::new();
        assert!(!codec.is_encryption_enabled());

        codec.set_encryption_key(Some(b"12345678")).unwrap();
        assert!(codec.is_encryption_enabled());

        codec.set_encryption_key(None).unwrap();
        assert!(!codec.is_encryption_enabled());
    }

    /// `with_full_config` enables encryption only when a key is supplied.
    #[test]
    fn test_codec_full_config() {
        let compressor = Compressor::new();
        let key = b"mykey123";

        let codec = RpcCodec::with_full_config(5 * 1024 * 1024, compressor, Some(key)).unwrap();
        assert!(codec.is_encryption_enabled());

        let codec2 = RpcCodec::with_full_config(5 * 1024 * 1024, Compressor::new(), None).unwrap();
        assert!(!codec2.is_encryption_enabled());
    }
}