Skip to main content

rustgate/
protocol.rs

1use serde::{Deserialize, Serialize};
2
3/// Top-level WebSocket text message envelope.
4#[derive(Debug, Clone, Serialize, Deserialize)]
5#[serde(tag = "kind", rename_all = "snake_case")]
6pub enum WsTextMessage {
7    Command(Command),
8    Response(CommandResponse),
9    Control(ControlMessage),
10}
11
12/// Command sent from server to client.
13#[derive(Debug, Clone, Serialize, Deserialize)]
14#[serde(tag = "type", rename_all = "snake_case")]
15pub enum Command {
16    /// Start a SOCKS5 proxy listener on the client.
17    Socks { tunnel_id: u32, port: u16 },
18    /// Create a reverse TCP tunnel: server binds remote_port,
19    /// forwards connections back to client's local_target.
20    ReverseTunnel {
21        tunnel_id: u32,
22        remote_port: u16,
23        local_target: String,
24    },
25    /// Ping/keepalive.
26    Ping { seq: u64 },
27    /// Request client to shut down a specific tunnel.
28    StopTunnel { tunnel_id: u32 },
29}
30
31/// Response from client to server.
32#[derive(Debug, Clone, Serialize, Deserialize)]
33#[serde(tag = "status", rename_all = "snake_case")]
34pub enum CommandResponse {
35    Ok {
36        tunnel_id: Option<u32>,
37        message: Option<String>,
38    },
39    /// SOCKS listener successfully bound — authorizes the tunnel on the server.
40    SocksReady {
41        tunnel_id: u32,
42    },
43    /// Reverse tunnel target validated (client confirmed local_target is reachable).
44    ReverseTunnelReady {
45        tunnel_id: u32,
46    },
47    Error {
48        tunnel_id: Option<u32>,
49        message: String,
50    },
51    Pong {
52        seq: u64,
53    },
54}
55
56/// Control messages for channel lifecycle (sent by both sides).
57#[derive(Debug, Clone, Serialize, Deserialize)]
58#[serde(tag = "type", rename_all = "snake_case")]
59pub enum ControlMessage {
60    /// A new data channel has been opened.
61    ChannelOpen {
62        channel_id: u32,
63        tunnel_id: u32,
64        /// For SOCKS: the destination the remote side should connect to.
65        target: Option<String>,
66    },
67    /// The channel is ready for data transfer.
68    ChannelReady { channel_id: u32 },
69    /// The channel has been closed.
70    ChannelClose { channel_id: u32 },
71}
72
/// Build a binary WebSocket message carrying tunnel data.
///
/// The result is `channel_id` encoded as 4 big-endian bytes followed by
/// `payload` copied verbatim. An empty `payload` yields a 4-byte frame that
/// holds only the header.
pub fn frame_tunnel_data(channel_id: u32, payload: &[u8]) -> Vec<u8> {
    let header = channel_id.to_be_bytes();
    // Header first, payload second, joined into one freshly allocated buffer.
    [&header[..], payload].concat()
}
80
/// Split a binary WebSocket message into `(channel_id, payload)`.
///
/// The first 4 bytes are decoded as a big-endian `u32` channel ID; everything
/// after them is returned as the payload, borrowed from `data`. Returns
/// `None` when `data` is shorter than the 4-byte header. A message of exactly
/// 4 bytes yields an empty payload.
pub fn parse_tunnel_data(data: &[u8]) -> Option<(u32, &[u8])> {
    match data {
        // At least four header bytes are present; the rest is the payload.
        [b0, b1, b2, b3, payload @ ..] => {
            let channel_id = u32::from_be_bytes([*b0, *b1, *b2, *b3]);
            Some((channel_id, payload))
        }
        // Fewer than four bytes: no complete header.
        _ => None,
    }
}
89
#[cfg(test)]
mod tests {
    use super::*;

    /// Serialize `msg` to JSON text and parse that text back into a message.
    fn roundtrip(msg: &WsTextMessage) -> WsTextMessage {
        let json = serde_json::to_string(msg).unwrap();
        serde_json::from_str(&json).unwrap()
    }

    /// Assert that the JSON text contains every expected `"key":value` fragment.
    ///
    /// Checking fragments instead of the whole string keeps the tests
    /// independent of the order in which fields are written.
    fn assert_json_contains(json: &str, fragments: &[&str]) {
        for fragment in fragments {
            assert!(
                json.contains(*fragment),
                "expected fragment {} in {}",
                fragment,
                json
            );
        }
    }

    #[test]
    fn test_command_serde_roundtrip() {
        let msg = WsTextMessage::Command(Command::Socks { tunnel_id: 1, port: 1080 });
        match roundtrip(&msg) {
            WsTextMessage::Command(Command::Socks { tunnel_id, port }) => {
                assert_eq!(tunnel_id, 1);
                assert_eq!(port, 1080);
            }
            other => panic!("unexpected variant: {:?}", other),
        }
    }

    #[test]
    fn test_response_serde_roundtrip() {
        let msg = WsTextMessage::Response(CommandResponse::Ok {
            tunnel_id: Some(1),
            message: None,
        });
        match roundtrip(&msg) {
            WsTextMessage::Response(CommandResponse::Ok { tunnel_id, message }) => {
                assert_eq!(tunnel_id, Some(1));
                assert_eq!(message, None);
            }
            other => panic!("unexpected variant: {:?}", other),
        }
    }

    #[test]
    fn test_control_serde_roundtrip() {
        let msg = WsTextMessage::Control(ControlMessage::ChannelOpen {
            channel_id: 3,
            tunnel_id: 1,
            target: Some("example.com:443".into()),
        });
        match roundtrip(&msg) {
            WsTextMessage::Control(ControlMessage::ChannelOpen {
                channel_id,
                tunnel_id,
                target,
            }) => {
                assert_eq!(channel_id, 3);
                assert_eq!(tunnel_id, 1);
                assert_eq!(target.as_deref(), Some("example.com:443"));
            }
            other => panic!("unexpected variant: {:?}", other),
        }
    }

    // Round-trips cannot detect a rename that changes serialization and
    // deserialization symmetrically, so the tests below pin the tag names and
    // values that actually travel over the wire.

    #[test]
    fn test_command_wire_format() {
        let msg = WsTextMessage::Command(Command::Socks { tunnel_id: 1, port: 1080 });
        let json = serde_json::to_string(&msg).unwrap();
        assert_json_contains(
            &json,
            &[
                r#""kind":"command""#,
                r#""type":"socks""#,
                r#""tunnel_id":1"#,
                r#""port":1080"#,
            ],
        );
    }

    #[test]
    fn test_response_wire_format() {
        let msg = WsTextMessage::Response(CommandResponse::SocksReady { tunnel_id: 7 });
        let json = serde_json::to_string(&msg).unwrap();
        assert_json_contains(
            &json,
            &[
                r#""kind":"response""#,
                r#""status":"socks_ready""#,
                r#""tunnel_id":7"#,
            ],
        );
    }

    #[test]
    fn test_control_wire_format() {
        let msg = WsTextMessage::Control(ControlMessage::ChannelClose { channel_id: 9 });
        let json = serde_json::to_string(&msg).unwrap();
        assert_json_contains(
            &json,
            &[
                r#""kind":"control""#,
                r#""type":"channel_close""#,
                r#""channel_id":9"#,
            ],
        );
    }

    #[test]
    fn test_deserialize_from_literal_json() {
        let pong = r#"{"kind":"response","status":"pong","seq":7}"#;
        match serde_json::from_str::<WsTextMessage>(pong).unwrap() {
            WsTextMessage::Response(CommandResponse::Pong { seq }) => assert_eq!(seq, 7),
            other => panic!("unexpected variant: {:?}", other),
        }

        let reverse = concat!(
            r#"{"kind":"command","type":"reverse_tunnel","#,
            r#""tunnel_id":2,"remote_port":8080,"local_target":"127.0.0.1:22"}"#
        );
        match serde_json::from_str::<WsTextMessage>(reverse).unwrap() {
            WsTextMessage::Command(Command::ReverseTunnel {
                tunnel_id,
                remote_port,
                local_target,
            }) => {
                assert_eq!(tunnel_id, 2);
                assert_eq!(remote_port, 8080);
                assert_eq!(local_target, "127.0.0.1:22");
            }
            other => panic!("unexpected variant: {:?}", other),
        }
    }

    #[test]
    fn test_unknown_kind_is_rejected() {
        let parsed = serde_json::from_str::<WsTextMessage>(r#"{"kind":"bogus"}"#);
        assert!(parsed.is_err());
    }

    #[test]
    fn test_frame_parse_roundtrip() {
        let data = b"hello world";
        let framed = frame_tunnel_data(42, data);
        let (channel_id, payload) = parse_tunnel_data(&framed).unwrap();
        assert_eq!(channel_id, 42);
        assert_eq!(payload, data);
    }

    #[test]
    fn test_frame_header_is_big_endian() {
        // An empty payload leaves only the 4-byte header, most significant
        // byte first.
        assert_eq!(frame_tunnel_data(0x0102_0304, &[]), vec![1u8, 2, 3, 4]);
    }

    #[test]
    fn test_parse_tunnel_data_header_only() {
        // Exactly four bytes is a valid frame with an empty payload.
        let (channel_id, payload) = parse_tunnel_data(&[0, 0, 0, 7]).unwrap();
        assert_eq!(channel_id, 7);
        assert!(payload.is_empty());
    }

    #[test]
    fn test_parse_tunnel_data_too_short() {
        assert!(parse_tunnel_data(&[0, 1, 2]).is_none());
        assert!(parse_tunnel_data(&[]).is_none());
    }
}