1use bytes::{Buf, BufMut, BytesMut};
2use serde::{de::DeserializeOwned, Serialize};
3use thiserror::Error;
4use tokio_util::codec::{Decoder, Encoder};
5
/// Upper bound, in bytes, on a single frame's JSON payload (16 MiB).
///
/// Checked by both the encoder and the decoder; the 4-byte length header is
/// not counted against it.
const MAX_FRAME_SIZE: usize = 16 << 20;
8
9#[derive(Debug, Error)]
11pub enum CodecError {
12 #[error("Frame too large: {0} bytes (max {MAX_FRAME_SIZE})")]
13 FrameTooLarge(usize),
14
15 #[error("IO error: {0}")]
16 Io(#[from] std::io::Error),
17
18 #[error("JSON error: {0}")]
19 Json(#[from] serde_json::Error),
20}
21
/// Length-prefixed JSON codec for tunnel messages of type `T`.
///
/// Each frame is a 4-byte big-endian payload length followed by that many
/// bytes of JSON. The codec carries no data of its own; `T` only selects which
/// message type is encoded and decoded.
pub struct TunnelCodec<T> {
    // `fn() -> T` rather than `T`: the codec never owns a `T`, so it should not
    // inherit `T`'s auto traits (Send/Sync/Unpin) or drop-check obligations.
    _phantom: std::marker::PhantomData<fn() -> T>,
}

// Written by hand so that `TunnelCodec<T>: Debug` does not require `T: Debug`,
// which `#[derive(Debug)]` would impose.
impl<T> std::fmt::Debug for TunnelCodec<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("TunnelCodec").finish()
    }
}

// Written by hand so that cloning does not require `T: Clone`, which
// `#[derive(Clone)]` would impose. The codec has no state to copy.
impl<T> Clone for TunnelCodec<T> {
    fn clone(&self) -> Self {
        Self {
            _phantom: std::marker::PhantomData,
        }
    }
}
34
35impl<T> TunnelCodec<T> {
36 pub fn new() -> Self {
37 Self {
38 _phantom: std::marker::PhantomData,
39 }
40 }
41}
42
43impl<T> Default for TunnelCodec<T> {
44 fn default() -> Self {
45 Self::new()
46 }
47}
48
49impl<T: DeserializeOwned> Decoder for TunnelCodec<T> {
50 type Item = T;
51 type Error = CodecError;
52
53 fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
54 if src.len() < 4 {
56 return Ok(None);
57 }
58
59 let length = u32::from_be_bytes([src[0], src[1], src[2], src[3]]) as usize;
61
62 if length > MAX_FRAME_SIZE {
64 return Err(CodecError::FrameTooLarge(length));
65 }
66
67 let total_len = 4 + length;
69 if src.len() < total_len {
70 src.reserve(total_len - src.len());
72 return Ok(None);
73 }
74
75 src.advance(4);
77
78 let payload = src.split_to(length);
80
81 let message = serde_json::from_slice(&payload)?;
83 Ok(Some(message))
84 }
85}
86
87impl<T: Serialize> Encoder<T> for TunnelCodec<T> {
88 type Error = CodecError;
89
90 fn encode(&mut self, item: T, dst: &mut BytesMut) -> Result<(), Self::Error> {
91 let json = serde_json::to_vec(&item)?;
93
94 if json.len() > MAX_FRAME_SIZE {
96 return Err(CodecError::FrameTooLarge(json.len()));
97 }
98
99 dst.reserve(4 + json.len());
101 dst.put_u32(json.len() as u32);
102 dst.put_slice(&json);
103
104 Ok(())
105 }
106}
107
#[cfg(test)]
mod tests {
    use super::*;
    use crate::messages::{ClientMessage, ServerMessage, TunnelType};

    #[test]
    fn test_roundtrip_client_message() {
        let mut codec = TunnelCodec::<ClientMessage>::new();
        let msg = ClientMessage::RequestTunnel {
            subdomain: Some("test".to_string()),
            tunnel_type: TunnelType::Http,
            local_port: 8080,
        };

        let mut buf = BytesMut::new();
        codec.encode(msg, &mut buf).unwrap();

        let decoded = codec.decode(&mut buf).unwrap().unwrap();
        match decoded {
            ClientMessage::RequestTunnel {
                subdomain,
                tunnel_type,
                local_port,
            } => {
                assert_eq!(subdomain, Some("test".to_string()));
                assert_eq!(tunnel_type, TunnelType::Http);
                assert_eq!(local_port, 8080);
            }
            _ => panic!("Wrong variant"),
        }

        // The decoder must consume exactly one whole frame.
        assert!(buf.is_empty());
    }

    #[test]
    fn test_roundtrip_server_message() {
        let mut codec = TunnelCodec::<ServerMessage>::new();
        let msg = ServerMessage::HttpRequest {
            stream_id: 42,
            method: "GET".to_string(),
            uri: "/api/test".to_string(),
            headers: vec![("Host".to_string(), "example.com".to_string())],
            body: vec![],
        };

        let mut buf = BytesMut::new();
        codec.encode(msg, &mut buf).unwrap();

        let decoded = codec.decode(&mut buf).unwrap().unwrap();
        match decoded {
            ServerMessage::HttpRequest {
                stream_id,
                method,
                uri,
                ..
            } => {
                assert_eq!(stream_id, 42);
                assert_eq!(method, "GET");
                assert_eq!(uri, "/api/test");
            }
            _ => panic!("Wrong variant"),
        }
    }

    #[test]
    fn test_partial_frame() {
        let mut codec = TunnelCodec::<ClientMessage>::new();
        let msg = ClientMessage::Ping { timestamp: 12345 };

        let mut buf = BytesMut::new();
        codec.encode(msg, &mut buf).unwrap();

        let full_len = buf.len();
        let mut partial = buf.split_to(full_len / 2);

        assert!(codec.decode(&mut partial).unwrap().is_none());

        partial.unsplit(buf);

        let decoded = codec.decode(&mut partial).unwrap().unwrap();
        match decoded {
            ClientMessage::Ping { timestamp } => assert_eq!(timestamp, 12345),
            _ => panic!("Wrong variant"),
        }
    }

    #[test]
    fn test_incomplete_header() {
        let mut codec = TunnelCodec::<ClientMessage>::new();
        let mut buf = BytesMut::from(&[0u8, 0, 0][..]);

        // Fewer than 4 bytes: wait for more data and consume nothing.
        assert!(codec.decode(&mut buf).unwrap().is_none());
        assert_eq!(buf.len(), 3);
    }

    #[test]
    fn test_oversized_frame_rejected_from_header() {
        let mut codec = TunnelCodec::<ClientMessage>::new();
        let mut buf = BytesMut::new();
        buf.put_u32(u32::MAX);

        // Only the header is present; the size check must fire without
        // waiting for (or buffering) the announced payload.
        let result = codec.decode(&mut buf);
        assert!(matches!(
            result,
            Err(CodecError::FrameTooLarge(len)) if len == u32::MAX as usize
        ));
        assert_eq!(buf.len(), 4);
    }

    #[test]
    fn test_multiple_frames_in_one_buffer() {
        let mut codec = TunnelCodec::<ClientMessage>::new();
        let mut buf = BytesMut::new();
        codec
            .encode(ClientMessage::Ping { timestamp: 1 }, &mut buf)
            .unwrap();
        codec
            .encode(ClientMessage::Ping { timestamp: 2 }, &mut buf)
            .unwrap();

        for expected in [1, 2] {
            match codec.decode(&mut buf).unwrap().unwrap() {
                ClientMessage::Ping { timestamp } => assert_eq!(timestamp, expected),
                _ => panic!("Wrong variant"),
            }
        }

        assert!(codec.decode(&mut buf).unwrap().is_none());
        assert!(buf.is_empty());
    }
}