1use crate::error::{Result, TunnelError};
11use serde::{Deserialize, Serialize};
12use uuid::Uuid;
13
/// Protocol version number of this wire format.
///
/// Not written into the frame header produced by `Message::encode`, which
/// carries only the type byte and payload length.
pub const PROTOCOL_VERSION: u8 = 1;

/// Largest whole frame (header plus payload), in bytes, that `Message::decode`
/// accepts; the payload alone may therefore be at most
/// `MAX_MESSAGE_SIZE - HEADER_SIZE` bytes.
pub const MAX_MESSAGE_SIZE: usize = 64 * 1024;

/// Frame header length: one message-type byte followed by a big-endian `u32`
/// payload length.
pub const HEADER_SIZE: usize = 1 + 4;
22
/// Wire discriminant stored in the first byte of every frame.
///
/// Values are grouped by high nibble per message family: `0x0_` for the
/// `Auth*` messages, `0x1_` for `Register*`, `0x2_` for `Connect*`, `0x3_` for
/// `Heartbeat*`, and `0x4_` for `Unregister` / `Disconnect`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
#[repr(u8)]
pub enum MessageType {
    /// Frame type of [`Message::Auth`].
    Auth = 0x01,
    /// Frame type of [`Message::AuthOk`].
    AuthOk = 0x02,
    /// Frame type of [`Message::AuthFail`].
    AuthFail = 0x03,

    /// Frame type of [`Message::Register`].
    Register = 0x10,
    /// Frame type of [`Message::RegisterOk`].
    RegisterOk = 0x11,
    /// Frame type of [`Message::RegisterFail`].
    RegisterFail = 0x12,

    /// Frame type of [`Message::Connect`].
    Connect = 0x20,
    /// Frame type of [`Message::ConnectAck`].
    ConnectAck = 0x21,
    /// Frame type of [`Message::ConnectFail`].
    ConnectFail = 0x22,

    /// Frame type of [`Message::Heartbeat`].
    Heartbeat = 0x30,
    /// Frame type of [`Message::HeartbeatAck`].
    HeartbeatAck = 0x31,

    /// Frame type of [`Message::Unregister`].
    Unregister = 0x40,
    /// Frame type of [`Message::Disconnect`].
    Disconnect = 0x41,
}
54
55impl TryFrom<u8> for MessageType {
56 type Error = TunnelError;
57
58 fn try_from(value: u8) -> Result<Self> {
59 match value {
60 0x01 => Ok(Self::Auth),
61 0x02 => Ok(Self::AuthOk),
62 0x03 => Ok(Self::AuthFail),
63 0x10 => Ok(Self::Register),
64 0x11 => Ok(Self::RegisterOk),
65 0x12 => Ok(Self::RegisterFail),
66 0x20 => Ok(Self::Connect),
67 0x21 => Ok(Self::ConnectAck),
68 0x22 => Ok(Self::ConnectFail),
69 0x30 => Ok(Self::Heartbeat),
70 0x31 => Ok(Self::HeartbeatAck),
71 0x40 => Ok(Self::Unregister),
72 0x41 => Ok(Self::Disconnect),
73 _ => Err(TunnelError::protocol(format!(
74 "unknown message type: 0x{value:02x}"
75 ))),
76 }
77 }
78}
79
80#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
82#[serde(rename_all = "lowercase")]
83pub enum ServiceProtocol {
84 #[default]
86 Tcp,
87 Udp,
89}
90
91impl ServiceProtocol {
92 #[must_use]
94 pub const fn to_byte(self) -> u8 {
95 match self {
96 Self::Tcp => 0,
97 Self::Udp => 1,
98 }
99 }
100
101 pub fn from_byte(byte: u8) -> Result<Self> {
107 match byte {
108 0 => Ok(Self::Tcp),
109 1 => Ok(Self::Udp),
110 _ => Err(TunnelError::protocol(format!(
111 "unknown protocol type: {byte}"
112 ))),
113 }
114 }
115}
116
117#[derive(Debug, Clone, PartialEq, Eq)]
119pub enum Message {
120 Auth {
122 token: String,
124 client_id: Uuid,
126 },
127
128 AuthOk {
130 tunnel_id: Uuid,
132 },
133
134 AuthFail {
136 reason: String,
138 },
139
140 Register {
142 name: String,
144 protocol: ServiceProtocol,
146 local_port: u16,
148 remote_port: u16,
150 },
151
152 RegisterOk {
154 service_id: Uuid,
156 },
157
158 RegisterFail {
160 reason: String,
162 },
163
164 Connect {
166 service_id: Uuid,
168 connection_id: Uuid,
170 client_addr: String,
172 },
173
174 ConnectAck {
176 connection_id: Uuid,
178 },
179
180 ConnectFail {
182 connection_id: Uuid,
184 reason: String,
186 },
187
188 Heartbeat {
190 timestamp: u64,
192 },
193
194 HeartbeatAck {
196 timestamp: u64,
198 },
199
200 Unregister {
202 service_id: Uuid,
204 },
205
206 Disconnect {
208 reason: String,
210 },
211}
212
213impl Message {
214 #[must_use]
216 pub const fn message_type(&self) -> MessageType {
217 match self {
218 Self::Auth { .. } => MessageType::Auth,
219 Self::AuthOk { .. } => MessageType::AuthOk,
220 Self::AuthFail { .. } => MessageType::AuthFail,
221 Self::Register { .. } => MessageType::Register,
222 Self::RegisterOk { .. } => MessageType::RegisterOk,
223 Self::RegisterFail { .. } => MessageType::RegisterFail,
224 Self::Connect { .. } => MessageType::Connect,
225 Self::ConnectAck { .. } => MessageType::ConnectAck,
226 Self::ConnectFail { .. } => MessageType::ConnectFail,
227 Self::Heartbeat { .. } => MessageType::Heartbeat,
228 Self::HeartbeatAck { .. } => MessageType::HeartbeatAck,
229 Self::Unregister { .. } => MessageType::Unregister,
230 Self::Disconnect { .. } => MessageType::Disconnect,
231 }
232 }
233
234 #[must_use]
236 #[allow(clippy::cast_possible_truncation, clippy::match_same_arms)]
237 pub fn encode(&self) -> Vec<u8> {
238 let mut payload = Vec::new();
239
240 match self {
241 Self::Auth { token, client_id } => {
242 let token_bytes = token.as_bytes();
244 payload.extend_from_slice(&(token_bytes.len() as u16).to_be_bytes());
245 payload.extend_from_slice(token_bytes);
246 payload.extend_from_slice(client_id.as_bytes());
247 }
248
249 Self::AuthOk { tunnel_id } => {
250 payload.extend_from_slice(tunnel_id.as_bytes());
252 }
253
254 Self::AuthFail { reason } => {
255 let reason_bytes = reason.as_bytes();
257 payload.extend_from_slice(&(reason_bytes.len() as u16).to_be_bytes());
258 payload.extend_from_slice(reason_bytes);
259 }
260
261 Self::Register {
262 name,
263 protocol,
264 local_port,
265 remote_port,
266 } => {
267 let name_bytes = name.as_bytes();
269 payload.push(name_bytes.len() as u8);
270 payload.extend_from_slice(name_bytes);
271 payload.push(protocol.to_byte());
272 payload.extend_from_slice(&local_port.to_be_bytes());
273 payload.extend_from_slice(&remote_port.to_be_bytes());
274 }
275
276 Self::RegisterOk { service_id } => {
277 payload.extend_from_slice(service_id.as_bytes());
279 }
280
281 Self::RegisterFail { reason } => {
282 let reason_bytes = reason.as_bytes();
284 payload.extend_from_slice(&(reason_bytes.len() as u16).to_be_bytes());
285 payload.extend_from_slice(reason_bytes);
286 }
287
288 Self::Connect {
289 service_id,
290 connection_id,
291 client_addr,
292 } => {
293 payload.extend_from_slice(service_id.as_bytes());
295 payload.extend_from_slice(connection_id.as_bytes());
296 let addr_bytes = client_addr.as_bytes();
297 payload.extend_from_slice(&(addr_bytes.len() as u16).to_be_bytes());
298 payload.extend_from_slice(addr_bytes);
299 }
300
301 Self::ConnectAck { connection_id } => {
302 payload.extend_from_slice(connection_id.as_bytes());
304 }
305
306 Self::ConnectFail {
307 connection_id,
308 reason,
309 } => {
310 payload.extend_from_slice(connection_id.as_bytes());
312 let reason_bytes = reason.as_bytes();
313 payload.extend_from_slice(&(reason_bytes.len() as u16).to_be_bytes());
314 payload.extend_from_slice(reason_bytes);
315 }
316
317 Self::Heartbeat { timestamp } | Self::HeartbeatAck { timestamp } => {
318 payload.extend_from_slice(×tamp.to_be_bytes());
320 }
321
322 Self::Unregister { service_id } => {
323 payload.extend_from_slice(service_id.as_bytes());
325 }
326
327 Self::Disconnect { reason } => {
328 let reason_bytes = reason.as_bytes();
330 payload.extend_from_slice(&(reason_bytes.len() as u16).to_be_bytes());
331 payload.extend_from_slice(reason_bytes);
332 }
333 }
334
335 let msg_type = self.message_type() as u8;
337 let payload_len = payload.len() as u32;
338
339 let mut result = Vec::with_capacity(HEADER_SIZE + payload.len());
340 result.push(msg_type);
341 result.extend_from_slice(&payload_len.to_be_bytes());
342 result.extend_from_slice(&payload);
343
344 result
345 }
346
347 pub fn decode(bytes: &[u8]) -> Result<(Self, usize)> {
358 if bytes.len() < HEADER_SIZE {
359 return Err(TunnelError::protocol(format!(
360 "message too short: {} bytes, need at least {}",
361 bytes.len(),
362 HEADER_SIZE
363 )));
364 }
365
366 let msg_type = MessageType::try_from(bytes[0])?;
367 let payload_len = u32::from_be_bytes([bytes[1], bytes[2], bytes[3], bytes[4]]) as usize;
368
369 if payload_len > MAX_MESSAGE_SIZE - HEADER_SIZE {
370 return Err(TunnelError::protocol(format!(
371 "payload too large: {payload_len} bytes, max {}",
372 MAX_MESSAGE_SIZE - HEADER_SIZE
373 )));
374 }
375
376 let total_len = HEADER_SIZE + payload_len;
377 if bytes.len() < total_len {
378 return Err(TunnelError::protocol(format!(
379 "incomplete message: have {} bytes, need {}",
380 bytes.len(),
381 total_len
382 )));
383 }
384
385 let payload = &bytes[HEADER_SIZE..total_len];
386 let message = Self::decode_payload(msg_type, payload)?;
387
388 Ok((message, total_len))
389 }
390
391 #[allow(clippy::too_many_lines)]
393 fn decode_payload(msg_type: MessageType, payload: &[u8]) -> Result<Self> {
394 match msg_type {
395 MessageType::Auth => Self::decode_auth(payload),
396 MessageType::AuthOk => Self::decode_auth_ok(payload),
397 MessageType::AuthFail => Self::decode_auth_fail(payload),
398 MessageType::Register => Self::decode_register(payload),
399 MessageType::RegisterOk => Self::decode_register_ok(payload),
400 MessageType::RegisterFail => Self::decode_register_fail(payload),
401 MessageType::Connect => Self::decode_connect(payload),
402 MessageType::ConnectAck => Self::decode_connect_ack(payload),
403 MessageType::ConnectFail => Self::decode_connect_fail(payload),
404 MessageType::Heartbeat => Self::decode_heartbeat(payload),
405 MessageType::HeartbeatAck => Self::decode_heartbeat_ack(payload),
406 MessageType::Unregister => Self::decode_unregister(payload),
407 MessageType::Disconnect => Self::decode_disconnect(payload),
408 }
409 }
410
411 fn decode_auth(payload: &[u8]) -> Result<Self> {
412 if payload.len() < 2 {
414 return Err(TunnelError::protocol(
415 "Auth: payload too short for token length",
416 ));
417 }
418 let token_len = u16::from_be_bytes([payload[0], payload[1]]) as usize;
419 if payload.len() < 2 + token_len + 16 {
420 return Err(TunnelError::protocol("Auth: payload too short"));
421 }
422 let token = String::from_utf8(payload[2..2 + token_len].to_vec())
423 .map_err(|e| TunnelError::protocol(format!("Auth: invalid token UTF-8: {e}")))?;
424 let client_id = Uuid::from_slice(&payload[2 + token_len..2 + token_len + 16])
425 .map_err(|e| TunnelError::protocol(format!("Auth: invalid client_id: {e}")))?;
426 Ok(Self::Auth { token, client_id })
427 }
428
429 fn decode_auth_ok(payload: &[u8]) -> Result<Self> {
430 if payload.len() < 16 {
432 return Err(TunnelError::protocol("AuthOk: payload too short"));
433 }
434 let tunnel_id = Uuid::from_slice(&payload[..16])
435 .map_err(|e| TunnelError::protocol(format!("AuthOk: invalid tunnel_id: {e}")))?;
436 Ok(Self::AuthOk { tunnel_id })
437 }
438
439 fn decode_auth_fail(payload: &[u8]) -> Result<Self> {
440 if payload.len() < 2 {
442 return Err(TunnelError::protocol(
443 "AuthFail: payload too short for reason length",
444 ));
445 }
446 let reason_len = u16::from_be_bytes([payload[0], payload[1]]) as usize;
447 if payload.len() < 2 + reason_len {
448 return Err(TunnelError::protocol(
449 "AuthFail: payload too short for reason",
450 ));
451 }
452 let reason = String::from_utf8(payload[2..2 + reason_len].to_vec())
453 .map_err(|e| TunnelError::protocol(format!("AuthFail: invalid reason UTF-8: {e}")))?;
454 Ok(Self::AuthFail { reason })
455 }
456
457 fn decode_register(payload: &[u8]) -> Result<Self> {
458 if payload.is_empty() {
460 return Err(TunnelError::protocol(
461 "Register: payload too short for name length",
462 ));
463 }
464 let name_len = payload[0] as usize;
465 if payload.len() < 1 + name_len + 1 + 2 + 2 {
466 return Err(TunnelError::protocol("Register: payload too short"));
467 }
468 let name = String::from_utf8(payload[1..=name_len].to_vec())
469 .map_err(|e| TunnelError::protocol(format!("Register: invalid name UTF-8: {e}")))?;
470 let protocol = ServiceProtocol::from_byte(payload[1 + name_len])?;
471 let local_port = u16::from_be_bytes([payload[2 + name_len], payload[3 + name_len]]);
472 let remote_port = u16::from_be_bytes([payload[4 + name_len], payload[5 + name_len]]);
473 Ok(Self::Register {
474 name,
475 protocol,
476 local_port,
477 remote_port,
478 })
479 }
480
481 fn decode_register_ok(payload: &[u8]) -> Result<Self> {
482 if payload.len() < 16 {
484 return Err(TunnelError::protocol("RegisterOk: payload too short"));
485 }
486 let service_id = Uuid::from_slice(&payload[..16])
487 .map_err(|e| TunnelError::protocol(format!("RegisterOk: invalid service_id: {e}")))?;
488 Ok(Self::RegisterOk { service_id })
489 }
490
491 fn decode_register_fail(payload: &[u8]) -> Result<Self> {
492 if payload.len() < 2 {
494 return Err(TunnelError::protocol(
495 "RegisterFail: payload too short for reason length",
496 ));
497 }
498 let reason_len = u16::from_be_bytes([payload[0], payload[1]]) as usize;
499 if payload.len() < 2 + reason_len {
500 return Err(TunnelError::protocol(
501 "RegisterFail: payload too short for reason",
502 ));
503 }
504 let reason = String::from_utf8(payload[2..2 + reason_len].to_vec()).map_err(|e| {
505 TunnelError::protocol(format!("RegisterFail: invalid reason UTF-8: {e}"))
506 })?;
507 Ok(Self::RegisterFail { reason })
508 }
509
510 fn decode_connect(payload: &[u8]) -> Result<Self> {
511 if payload.len() < 16 + 16 + 2 {
513 return Err(TunnelError::protocol("Connect: payload too short"));
514 }
515 let service_id = Uuid::from_slice(&payload[..16])
516 .map_err(|e| TunnelError::protocol(format!("Connect: invalid service_id: {e}")))?;
517 let connection_id = Uuid::from_slice(&payload[16..32])
518 .map_err(|e| TunnelError::protocol(format!("Connect: invalid connection_id: {e}")))?;
519 let addr_len = u16::from_be_bytes([payload[32], payload[33]]) as usize;
520 if payload.len() < 34 + addr_len {
521 return Err(TunnelError::protocol(
522 "Connect: payload too short for client_addr",
523 ));
524 }
525 let client_addr = String::from_utf8(payload[34..34 + addr_len].to_vec()).map_err(|e| {
526 TunnelError::protocol(format!("Connect: invalid client_addr UTF-8: {e}"))
527 })?;
528 Ok(Self::Connect {
529 service_id,
530 connection_id,
531 client_addr,
532 })
533 }
534
535 fn decode_connect_ack(payload: &[u8]) -> Result<Self> {
536 if payload.len() < 16 {
538 return Err(TunnelError::protocol("ConnectAck: payload too short"));
539 }
540 let connection_id = Uuid::from_slice(&payload[..16]).map_err(|e| {
541 TunnelError::protocol(format!("ConnectAck: invalid connection_id: {e}"))
542 })?;
543 Ok(Self::ConnectAck { connection_id })
544 }
545
546 fn decode_connect_fail(payload: &[u8]) -> Result<Self> {
547 if payload.len() < 16 + 2 {
549 return Err(TunnelError::protocol("ConnectFail: payload too short"));
550 }
551 let connection_id = Uuid::from_slice(&payload[..16]).map_err(|e| {
552 TunnelError::protocol(format!("ConnectFail: invalid connection_id: {e}"))
553 })?;
554 let reason_len = u16::from_be_bytes([payload[16], payload[17]]) as usize;
555 if payload.len() < 18 + reason_len {
556 return Err(TunnelError::protocol(
557 "ConnectFail: payload too short for reason",
558 ));
559 }
560 let reason = String::from_utf8(payload[18..18 + reason_len].to_vec()).map_err(|e| {
561 TunnelError::protocol(format!("ConnectFail: invalid reason UTF-8: {e}"))
562 })?;
563 Ok(Self::ConnectFail {
564 connection_id,
565 reason,
566 })
567 }
568
569 fn decode_heartbeat(payload: &[u8]) -> Result<Self> {
570 if payload.len() < 8 {
572 return Err(TunnelError::protocol("Heartbeat: payload too short"));
573 }
574 let timestamp = u64::from_be_bytes([
575 payload[0], payload[1], payload[2], payload[3], payload[4], payload[5], payload[6],
576 payload[7],
577 ]);
578 Ok(Self::Heartbeat { timestamp })
579 }
580
581 fn decode_heartbeat_ack(payload: &[u8]) -> Result<Self> {
582 if payload.len() < 8 {
584 return Err(TunnelError::protocol("HeartbeatAck: payload too short"));
585 }
586 let timestamp = u64::from_be_bytes([
587 payload[0], payload[1], payload[2], payload[3], payload[4], payload[5], payload[6],
588 payload[7],
589 ]);
590 Ok(Self::HeartbeatAck { timestamp })
591 }
592
593 fn decode_unregister(payload: &[u8]) -> Result<Self> {
594 if payload.len() < 16 {
596 return Err(TunnelError::protocol("Unregister: payload too short"));
597 }
598 let service_id = Uuid::from_slice(&payload[..16])
599 .map_err(|e| TunnelError::protocol(format!("Unregister: invalid service_id: {e}")))?;
600 Ok(Self::Unregister { service_id })
601 }
602
603 fn decode_disconnect(payload: &[u8]) -> Result<Self> {
604 if payload.len() < 2 {
606 return Err(TunnelError::protocol(
607 "Disconnect: payload too short for reason length",
608 ));
609 }
610 let reason_len = u16::from_be_bytes([payload[0], payload[1]]) as usize;
611 if payload.len() < 2 + reason_len {
612 return Err(TunnelError::protocol(
613 "Disconnect: payload too short for reason",
614 ));
615 }
616 let reason = String::from_utf8(payload[2..2 + reason_len].to_vec())
617 .map_err(|e| TunnelError::protocol(format!("Disconnect: invalid reason UTF-8: {e}")))?;
618 Ok(Self::Disconnect { reason })
619 }
620}
621
#[cfg(test)]
mod tests {
    use super::*;

    /// Encodes `msg`, decodes the result, and asserts that the whole frame was
    /// consumed and that the decoded message equals the original.
    fn roundtrip(msg: &Message) {
        let encoded = msg.encode();
        let (decoded, consumed) = Message::decode(&encoded).expect("decode failed");
        assert_eq!(consumed, encoded.len(), "consumed bytes mismatch");
        assert_eq!(&decoded, msg, "roundtrip mismatch");
    }

    /// Builds a raw frame from a type byte and payload without going through
    /// `Message::encode`, so malformed payloads can be fed to the decoder.
    fn frame(msg_type: u8, payload: &[u8]) -> Vec<u8> {
        let len = u32::try_from(payload.len()).expect("test payload fits in u32");
        let mut bytes = vec![msg_type];
        bytes.extend_from_slice(&len.to_be_bytes());
        bytes.extend_from_slice(payload);
        bytes
    }

    #[test]
    fn test_auth_roundtrip() {
        roundtrip(&Message::Auth {
            token: "tun_abc123".to_string(),
            client_id: Uuid::new_v4(),
        });

        // Empty token and nil id.
        roundtrip(&Message::Auth {
            token: String::new(),
            client_id: Uuid::nil(),
        });

        // Token longer than a one-byte length could express.
        roundtrip(&Message::Auth {
            token: "a".repeat(1000),
            client_id: Uuid::new_v4(),
        });
    }

    #[test]
    fn test_auth_ok_roundtrip() {
        roundtrip(&Message::AuthOk {
            tunnel_id: Uuid::new_v4(),
        });

        roundtrip(&Message::AuthOk {
            tunnel_id: Uuid::nil(),
        });
    }

    #[test]
    fn test_auth_fail_roundtrip() {
        roundtrip(&Message::AuthFail {
            reason: "invalid token".to_string(),
        });

        roundtrip(&Message::AuthFail {
            reason: String::new(),
        });

        roundtrip(&Message::AuthFail {
            reason: "x".repeat(500),
        });
    }

    #[test]
    fn test_register_roundtrip() {
        roundtrip(&Message::Register {
            name: "ssh".to_string(),
            protocol: ServiceProtocol::Tcp,
            local_port: 22,
            remote_port: 2222,
        });

        roundtrip(&Message::Register {
            name: "game".to_string(),
            protocol: ServiceProtocol::Udp,
            local_port: 27015,
            remote_port: 0,
        });

        // Longest name the one-byte length prefix can express, with max ports.
        roundtrip(&Message::Register {
            name: "a".repeat(255),
            protocol: ServiceProtocol::Tcp,
            local_port: 65535,
            remote_port: 65535,
        });
    }

    #[test]
    fn test_register_ok_roundtrip() {
        roundtrip(&Message::RegisterOk {
            service_id: Uuid::new_v4(),
        });
    }

    #[test]
    fn test_register_fail_roundtrip() {
        roundtrip(&Message::RegisterFail {
            reason: "port already in use".to_string(),
        });
    }

    #[test]
    fn test_connect_roundtrip() {
        roundtrip(&Message::Connect {
            service_id: Uuid::new_v4(),
            connection_id: Uuid::new_v4(),
            client_addr: "192.168.1.100:54321".to_string(),
        });

        roundtrip(&Message::Connect {
            service_id: Uuid::new_v4(),
            connection_id: Uuid::new_v4(),
            client_addr: "[::1]:8080".to_string(),
        });
    }

    #[test]
    fn test_connect_ack_roundtrip() {
        roundtrip(&Message::ConnectAck {
            connection_id: Uuid::new_v4(),
        });
    }

    #[test]
    fn test_connect_fail_roundtrip() {
        roundtrip(&Message::ConnectFail {
            connection_id: Uuid::new_v4(),
            reason: "connection refused".to_string(),
        });
    }

    #[test]
    fn test_heartbeat_roundtrip() {
        roundtrip(&Message::Heartbeat {
            timestamp: 1_705_320_000_000,
        });

        roundtrip(&Message::Heartbeat { timestamp: 0 });

        roundtrip(&Message::Heartbeat {
            timestamp: u64::MAX,
        });
    }

    #[test]
    fn test_heartbeat_ack_roundtrip() {
        roundtrip(&Message::HeartbeatAck {
            timestamp: 1_705_320_000_000,
        });
    }

    #[test]
    fn test_unregister_roundtrip() {
        roundtrip(&Message::Unregister {
            service_id: Uuid::new_v4(),
        });
    }

    #[test]
    fn test_disconnect_roundtrip() {
        roundtrip(&Message::Disconnect {
            reason: "server shutdown".to_string(),
        });
    }

    #[test]
    fn test_roundtrip_at_max_payload_size() {
        // Disconnect payload = 2-byte length + reason, so this exactly fills the limit.
        let reason_len = MAX_MESSAGE_SIZE - HEADER_SIZE - 2;
        roundtrip(&Message::Disconnect {
            reason: "x".repeat(reason_len),
        });

        // One byte more and the decoder must reject the frame.
        let oversized = Message::Disconnect {
            reason: "x".repeat(reason_len + 1),
        };
        assert!(Message::decode(&oversized.encode()).is_err());
    }

    #[test]
    fn test_message_type_discriminants() {
        let cases = [
            (
                Message::Auth {
                    token: String::new(),
                    client_id: Uuid::nil(),
                },
                MessageType::Auth,
            ),
            (
                Message::AuthOk {
                    tunnel_id: Uuid::nil(),
                },
                MessageType::AuthOk,
            ),
            (
                Message::AuthFail {
                    reason: String::new(),
                },
                MessageType::AuthFail,
            ),
            (
                Message::Register {
                    name: String::new(),
                    protocol: ServiceProtocol::Tcp,
                    local_port: 0,
                    remote_port: 0,
                },
                MessageType::Register,
            ),
            (
                Message::RegisterOk {
                    service_id: Uuid::nil(),
                },
                MessageType::RegisterOk,
            ),
            (
                Message::RegisterFail {
                    reason: String::new(),
                },
                MessageType::RegisterFail,
            ),
            (
                Message::Connect {
                    service_id: Uuid::nil(),
                    connection_id: Uuid::nil(),
                    client_addr: String::new(),
                },
                MessageType::Connect,
            ),
            (
                Message::ConnectAck {
                    connection_id: Uuid::nil(),
                },
                MessageType::ConnectAck,
            ),
            (
                Message::ConnectFail {
                    connection_id: Uuid::nil(),
                    reason: String::new(),
                },
                MessageType::ConnectFail,
            ),
            (Message::Heartbeat { timestamp: 0 }, MessageType::Heartbeat),
            (
                Message::HeartbeatAck { timestamp: 0 },
                MessageType::HeartbeatAck,
            ),
            (
                Message::Unregister {
                    service_id: Uuid::nil(),
                },
                MessageType::Unregister,
            ),
            (
                Message::Disconnect {
                    reason: String::new(),
                },
                MessageType::Disconnect,
            ),
        ];

        for (msg, expected) in cases {
            assert_eq!(msg.message_type(), expected, "{msg:?}");
        }
    }

    #[test]
    fn test_message_type_from_u8() {
        let cases = [
            (0x01_u8, MessageType::Auth),
            (0x02, MessageType::AuthOk),
            (0x03, MessageType::AuthFail),
            (0x10, MessageType::Register),
            (0x11, MessageType::RegisterOk),
            (0x12, MessageType::RegisterFail),
            (0x20, MessageType::Connect),
            (0x21, MessageType::ConnectAck),
            (0x22, MessageType::ConnectFail),
            (0x30, MessageType::Heartbeat),
            (0x31, MessageType::HeartbeatAck),
            (0x40, MessageType::Unregister),
            (0x41, MessageType::Disconnect),
        ];

        for (byte, expected) in cases {
            assert_eq!(MessageType::try_from(byte).unwrap(), expected);
            // The enum discriminant must agree with the wire value.
            assert_eq!(expected as u8, byte);
        }

        assert!(MessageType::try_from(0xFF).is_err());
        // 0x00 is deliberately not a valid type.
        assert!(MessageType::try_from(0x00).is_err());
    }

    #[test]
    fn test_service_protocol_roundtrip() {
        assert_eq!(
            ServiceProtocol::from_byte(ServiceProtocol::Tcp.to_byte()).unwrap(),
            ServiceProtocol::Tcp
        );
        assert_eq!(
            ServiceProtocol::from_byte(ServiceProtocol::Udp.to_byte()).unwrap(),
            ServiceProtocol::Udp
        );
        assert!(ServiceProtocol::from_byte(0xFF).is_err());
    }

    #[test]
    fn test_decode_too_short() {
        // Fewer than HEADER_SIZE bytes.
        assert!(Message::decode(&[]).is_err());
        assert!(Message::decode(&[0x01]).is_err());
        assert!(Message::decode(&[0x01, 0x00, 0x00, 0x00]).is_err());
    }

    #[test]
    fn test_decode_incomplete_payload() {
        // Header declares a 32-byte payload, but only one byte follows.
        let bytes = [0x01, 0x00, 0x00, 0x00, 0x20, 0x00];
        assert!(Message::decode(&bytes).is_err());
    }

    #[test]
    fn test_decode_invalid_message_type() {
        let bytes = [0xFF, 0x00, 0x00, 0x00, 0x00];
        assert!(Message::decode(&bytes).is_err());
    }

    #[test]
    fn test_decode_payload_too_large() {
        // Declared payload length of u32::MAX exceeds MAX_MESSAGE_SIZE.
        let bytes = [0x01, 0xFF, 0xFF, 0xFF, 0xFF];
        assert!(Message::decode(&bytes).is_err());
    }

    #[test]
    fn test_decode_string_length_exceeds_payload() {
        // Auth token length claims 16 bytes, but only 2 follow (and no client_id).
        assert!(Message::decode(&frame(0x01, &[0x00, 0x10, b'a', b'b'])).is_err());
        // Disconnect reason length claims 5 bytes, but only 1 follows.
        assert!(Message::decode(&frame(0x41, &[0x00, 0x05, b'x'])).is_err());
    }

    #[test]
    fn test_decode_invalid_utf8() {
        assert!(Message::decode(&frame(0x41, &[0x00, 0x02, 0xFF, 0xFE])).is_err());
    }

    #[test]
    fn test_decode_register_unknown_protocol() {
        // name "x", protocol byte 7 (unknown), ports 22 and 22.
        let payload = [0x01, b'x', 0x07, 0x00, 0x16, 0x00, 0x16];
        assert!(Message::decode(&frame(0x10, &payload)).is_err());
    }

    #[test]
    fn test_decode_short_fixed_payloads() {
        // Heartbeat needs 8 bytes.
        assert!(Message::decode(&frame(0x30, &[0; 7])).is_err());
        // AuthOk needs 16 bytes.
        assert!(Message::decode(&frame(0x02, &[0; 15])).is_err());
        // Connect needs two UUIDs plus a 2-byte length.
        assert!(Message::decode(&frame(0x20, &[0; 33])).is_err());
    }

    #[test]
    fn test_header_size_constant() {
        // Heartbeat payload is exactly one u64.
        let msg = Message::Heartbeat { timestamp: 0 };
        let encoded = msg.encode();
        assert_eq!(encoded.len(), HEADER_SIZE + 8);
    }

    #[test]
    fn test_multiple_messages_in_buffer() {
        let msg1 = Message::Heartbeat { timestamp: 100 };
        let msg2 = Message::HeartbeatAck { timestamp: 100 };

        let mut buffer = msg1.encode();
        buffer.extend_from_slice(&msg2.encode());

        // First decode must stop at the end of the first frame.
        let (decoded1, consumed1) = Message::decode(&buffer).unwrap();
        assert_eq!(decoded1, msg1);

        let (decoded2, consumed2) = Message::decode(&buffer[consumed1..]).unwrap();
        assert_eq!(decoded2, msg2);

        assert_eq!(consumed1 + consumed2, buffer.len());
    }
}