1use serde::{Deserialize, Serialize};
2
3#[derive(Debug, Clone, Serialize, Deserialize)]
5#[serde(tag = "kind", rename_all = "snake_case")]
6pub enum WsTextMessage {
7 Command(Command),
8 Response(CommandResponse),
9 Control(ControlMessage),
10}
11
12#[derive(Debug, Clone, Serialize, Deserialize)]
14#[serde(tag = "type", rename_all = "snake_case")]
15pub enum Command {
16 Socks { tunnel_id: u32, port: u16 },
18 ReverseTunnel {
21 tunnel_id: u32,
22 remote_port: u16,
23 local_target: String,
24 },
25 Ping { seq: u64 },
27 StopTunnel { tunnel_id: u32 },
29}
30
31#[derive(Debug, Clone, Serialize, Deserialize)]
33#[serde(tag = "status", rename_all = "snake_case")]
34pub enum CommandResponse {
35 Ok {
36 tunnel_id: Option<u32>,
37 message: Option<String>,
38 },
39 SocksReady {
41 tunnel_id: u32,
42 },
43 ReverseTunnelReady {
45 tunnel_id: u32,
46 },
47 Error {
48 tunnel_id: Option<u32>,
49 message: String,
50 },
51 Pong {
52 seq: u64,
53 },
54}
55
56#[derive(Debug, Clone, Serialize, Deserialize)]
58#[serde(tag = "type", rename_all = "snake_case")]
59pub enum ControlMessage {
60 ChannelOpen {
62 channel_id: u32,
63 tunnel_id: u32,
64 target: Option<String>,
66 },
67 ChannelReady { channel_id: u32 },
69 ChannelClose { channel_id: u32 },
71}
72
/// Builds a binary tunnel frame.
///
/// The frame is a 4-byte big-endian `channel_id` header followed by
/// `payload`, copied verbatim. An empty payload yields a header-only frame.
/// [`parse_tunnel_data`] performs the inverse split.
pub fn frame_tunnel_data(channel_id: u32, payload: &[u8]) -> Vec<u8> {
    let header = channel_id.to_be_bytes();
    // `concat` sizes the output exactly once from the two slice lengths.
    [&header[..], payload].concat()
}
80
/// Splits a binary tunnel frame into its channel id and payload.
///
/// This is the inverse of [`frame_tunnel_data`]:
/// - The first four bytes are decoded as a big-endian `u32` channel id.
/// - The remaining bytes are returned as the payload, borrowed from `data`.
///   The payload may be empty.
///
/// Returns `None` when `data` is shorter than the 4-byte header.
pub fn parse_tunnel_data(data: &[u8]) -> Option<(u32, &[u8])> {
    match data {
        [b0, b1, b2, b3, payload @ ..] => {
            Some((u32::from_be_bytes([*b0, *b1, *b2, *b3]), payload))
        }
        _ => None,
    }
}
89
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_command_serde_roundtrip() {
        let msg = WsTextMessage::Command(Command::Socks { tunnel_id: 1, port: 1080 });
        let json = serde_json::to_string(&msg).unwrap();
        let parsed: WsTextMessage = serde_json::from_str(&json).unwrap();
        match parsed {
            WsTextMessage::Command(Command::Socks { tunnel_id, port }) => {
                assert_eq!(tunnel_id, 1);
                assert_eq!(port, 1080);
            }
            _ => panic!("unexpected variant"),
        }
    }

    #[test]
    fn test_response_serde_roundtrip() {
        let msg = WsTextMessage::Response(CommandResponse::Ok {
            tunnel_id: Some(1),
            message: None,
        });
        let json = serde_json::to_string(&msg).unwrap();
        let parsed: WsTextMessage = serde_json::from_str(&json).unwrap();
        match parsed {
            WsTextMessage::Response(CommandResponse::Ok { tunnel_id, .. }) => {
                assert_eq!(tunnel_id, Some(1));
            }
            _ => panic!("unexpected variant"),
        }
    }

    #[test]
    fn test_control_serde_roundtrip() {
        let msg = WsTextMessage::Control(ControlMessage::ChannelOpen {
            channel_id: 3,
            tunnel_id: 1,
            target: Some("example.com:443".into()),
        });
        let json = serde_json::to_string(&msg).unwrap();
        let parsed: WsTextMessage = serde_json::from_str(&json).unwrap();
        match parsed {
            WsTextMessage::Control(ControlMessage::ChannelOpen {
                channel_id,
                tunnel_id,
                target,
            }) => {
                assert_eq!(channel_id, 3);
                assert_eq!(tunnel_id, 1);
                assert_eq!(target.as_deref(), Some("example.com:443"));
            }
            _ => panic!("unexpected variant"),
        }
    }

    // Round-trips only prove serialize/deserialize symmetry. The wire-format
    // tests below pin the actual JSON shape, so a renamed tag key or tag value
    // (which would break interop with peers) makes a test fail.
    #[test]
    fn test_command_wire_format() {
        let msg = WsTextMessage::Command(Command::ReverseTunnel {
            tunnel_id: 7,
            remote_port: 8080,
            local_target: "127.0.0.1:22".into(),
        });
        let expected = serde_json::json!({
            "kind": "command",
            "type": "reverse_tunnel",
            "tunnel_id": 7,
            "remote_port": 8080,
            "local_target": "127.0.0.1:22",
        });
        assert_eq!(serde_json::to_value(&msg).unwrap(), expected);
    }

    #[test]
    fn test_control_wire_format() {
        let msg = WsTextMessage::Control(ControlMessage::ChannelClose { channel_id: 9 });
        let expected = serde_json::json!({
            "kind": "control",
            "type": "channel_close",
            "channel_id": 9,
        });
        assert_eq!(serde_json::to_value(&msg).unwrap(), expected);
    }

    // Peers may omit optional fields entirely instead of sending `null`.
    #[test]
    fn test_response_missing_optional_field_is_none() {
        let json = r#"{"kind":"response","status":"error","message":"boom"}"#;
        let parsed: WsTextMessage = serde_json::from_str(json).unwrap();
        match parsed {
            WsTextMessage::Response(CommandResponse::Error { tunnel_id, message }) => {
                assert_eq!(tunnel_id, None);
                assert_eq!(message, "boom");
            }
            _ => panic!("unexpected variant"),
        }
    }

    #[test]
    fn test_frame_layout_is_big_endian_header() {
        assert_eq!(
            frame_tunnel_data(0x0102_0304, b"ab"),
            vec![1, 2, 3, 4, b'a', b'b']
        );
    }

    #[test]
    fn test_frame_parse_roundtrip() {
        let data = b"hello world";
        let framed = frame_tunnel_data(42, data);
        let (channel_id, payload) = parse_tunnel_data(&framed).unwrap();
        assert_eq!(channel_id, 42);
        assert_eq!(payload, data);
    }

    #[test]
    fn test_parse_header_only_yields_empty_payload() {
        let (channel_id, payload) = parse_tunnel_data(&frame_tunnel_data(7, &[])).unwrap();
        assert_eq!(channel_id, 7);
        assert!(payload.is_empty());
    }

    #[test]
    fn test_parse_tunnel_data_too_short() {
        assert!(parse_tunnel_data(&[0, 1, 2]).is_none());
        assert!(parse_tunnel_data(&[]).is_none());
    }
}