1use bytes::BytesMut;
7
/// Length of the hex-encoded hash that opens every request: 56 ASCII hex
/// characters, i.e. 28 raw bytes (presumably hex(SHA-224(password)) as in the
/// Trojan protocol — confirm).
pub const HASH_LEN: usize = 28 * 2;
/// Two-byte line terminator separating the hash, the request header and the payload.
pub const CRLF: &[u8; 2] = &[b'\r', b'\n'];

/// Command byte: connect to the target address (same value as SOCKS5 CONNECT).
pub const CMD_CONNECT: u8 = 1;
/// Command byte: UDP associate (same value as SOCKS5 UDP ASSOCIATE).
pub const CMD_UDP_ASSOCIATE: u8 = 3;
/// Command byte 0x7f, presumably a multiplexing extension — confirm.
/// Note that `parse_request` in this module does not accept it.
pub const CMD_MUX: u8 = 127;

/// Upper bound on UDP payload size (8 KiB). Not enforced by the UDP parse/write
/// functions in this module; presumably used by callers for buffer sizing — confirm.
pub const MAX_UDP_PAYLOAD: usize = 8192;

/// Longest encodable domain name: its length is carried in a single byte.
pub const MAX_DOMAIN_LEN: usize = u8::MAX as usize;

/// Address type: IPv4 (4 octets), same value as SOCKS5.
pub const ATYP_IPV4: u8 = 1;
/// Address type: length-prefixed domain name, same value as SOCKS5.
pub const ATYP_DOMAIN: u8 = 3;
/// Address type: IPv6 (16 octets), same value as SOCKS5.
pub const ATYP_IPV6: u8 = 4;
24
/// Reasons a byte stream is rejected by the request / UDP frame parsers.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ParseError {
    /// A `\r\n` separator was expected but other bytes were found.
    InvalidCrlf,
    /// The command byte is not one the parser accepts.
    InvalidCommand,
    /// The address-type byte is not IPv4, domain or IPv6.
    InvalidAtyp,
    /// A domain address declared a length of zero.
    InvalidDomainLen,
    /// A domain address is not valid UTF-8.
    InvalidUtf8,
    /// The leading hash field is not made up of ASCII hex digits.
    InvalidHashFormat,
}

impl std::fmt::Display for ParseError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let msg = match self {
            Self::InvalidCrlf => "expected CRLF separator",
            Self::InvalidCommand => "unsupported command",
            Self::InvalidAtyp => "unsupported address type",
            Self::InvalidDomainLen => "invalid domain length",
            Self::InvalidUtf8 => "domain is not valid UTF-8",
            Self::InvalidHashFormat => "invalid hash format",
        };
        f.write_str(msg)
    }
}

// Lets callers propagate parse failures with `?` into `Box<dyn Error>` and friends.
impl std::error::Error for ParseError {}
35
/// Reasons a header or UDP frame cannot be encoded. Writers check these before
/// appending anything, so a failed write leaves the output buffer unchanged.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum WriteError {
    /// The UDP payload is longer than `u16::MAX` bytes and cannot be length-prefixed.
    PayloadTooLarge,
    /// A domain host is longer than its one-byte length prefix allows.
    DomainTooLong,
    /// The supplied hash does not have the required length.
    InvalidHashLen,
}

impl std::fmt::Display for WriteError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let msg = match self {
            Self::PayloadTooLarge => "UDP payload too large for 16-bit length field",
            Self::DomainTooLong => "domain name too long",
            Self::InvalidHashLen => "hash has invalid length",
        };
        f.write_str(msg)
    }
}

// Lets callers propagate encoding failures with `?` into `Box<dyn Error>` and friends.
impl std::error::Error for WriteError {}
46
/// Outcome of an incremental parse over a possibly partial buffer.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ParseResult<T> {
    /// The buffer started with a complete, valid item.
    Complete(T),
    /// More input is required. The value is a lower bound on the total buffer
    /// length (counted from the start of the slice that was parsed) needed to
    /// make further progress.
    Incomplete(usize),
    /// The bytes examined are malformed; more input will not fix them.
    Invalid(ParseError),
}
59
/// Destination host as carried on the wire; domain bytes borrow from the input.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum HostRef<'a> {
    /// IPv4 address octets, in wire order.
    Ipv4([u8; 4]),
    /// IPv6 address octets, in wire order.
    Ipv6([u8; 16]),
    /// Domain name bytes without the length prefix. Values produced by the
    /// parser are non-empty, at most 255 bytes and valid UTF-8; the writers only
    /// enforce the length limit.
    Domain(&'a [u8]),
}
66
/// A destination host together with its port.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AddressRef<'a> {
    /// Destination host.
    pub host: HostRef<'a>,
    /// Destination port as a native integer (encoded big-endian on the wire).
    pub port: u16,
}
72
/// A parsed request header; all slices borrow from the parsed buffer.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TrojanRequest<'a> {
    /// The leading hex hash field (exactly `HASH_LEN` bytes).
    pub hash: &'a [u8],
    /// Command byte; the parser only yields `CMD_CONNECT` or `CMD_UDP_ASSOCIATE`.
    pub command: u8,
    /// Requested destination.
    pub address: AddressRef<'a>,
    /// Number of bytes occupied by the header, including its trailing CRLF.
    pub header_len: usize,
    /// Every byte of the buffer after the header.
    pub payload: &'a [u8],
}
81
/// A parsed UDP frame (`ATYP | address | port | length | CRLF | payload`);
/// slices borrow from the parsed buffer.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct UdpPacket<'a> {
    /// Destination (or source) address carried in the frame.
    pub address: AddressRef<'a>,
    /// Payload length as declared by the frame's 16-bit length field.
    pub length: usize,
    /// Total bytes this frame occupies; bytes past this belong to the next frame.
    pub packet_len: usize,
    /// Exactly `length` payload bytes.
    pub payload: &'a [u8],
}
89
90#[inline]
92pub fn is_valid_hash(hash: &[u8]) -> bool {
93 hash.len() == HASH_LEN && hash.iter().all(|&b| b.is_ascii_hexdigit())
94}
95
96#[inline]
97pub fn parse_request(buf: &[u8]) -> ParseResult<TrojanRequest<'_>> {
98 if buf.len() < HASH_LEN {
99 return ParseResult::Incomplete(HASH_LEN);
100 }
101
102 let hash = &buf[..HASH_LEN];
103 if !is_valid_hash(hash) {
104 return ParseResult::Invalid(ParseError::InvalidHashFormat);
105 }
106 let mut offset = HASH_LEN;
107
108 if let Some(res) = expect_crlf(buf, offset) {
109 return res;
110 }
111 offset += 2;
112
113 if buf.len() < offset + 2 {
114 return ParseResult::Incomplete(offset + 2);
115 }
116 let command = buf[offset];
117 if command != CMD_CONNECT && command != CMD_UDP_ASSOCIATE {
118 return ParseResult::Invalid(ParseError::InvalidCommand);
119 }
120 let atyp = buf[offset + 1];
121 offset += 2;
122
123 let addr_res = parse_address(atyp, &buf[offset..]);
124 let (address, addr_len) = match addr_res {
125 ParseResult::Complete(v) => v,
126 ParseResult::Incomplete(n) => return ParseResult::Incomplete(offset + n),
127 ParseResult::Invalid(e) => return ParseResult::Invalid(e),
128 };
129 offset += addr_len;
130
131 if let Some(res) = expect_crlf(buf, offset) {
132 return res;
133 }
134 offset += 2;
135
136 ParseResult::Complete(TrojanRequest {
137 hash,
138 command,
139 address,
140 header_len: offset,
141 payload: &buf[offset..],
142 })
143}
144
145#[inline]
146pub fn parse_udp_packet(buf: &[u8]) -> ParseResult<UdpPacket<'_>> {
147 if buf.is_empty() {
148 return ParseResult::Incomplete(1);
149 }
150 let atyp = buf[0];
151 let addr_res = parse_address(atyp, &buf[1..]);
152 let (address, addr_len) = match addr_res {
153 ParseResult::Complete(v) => v,
154 ParseResult::Incomplete(n) => return ParseResult::Incomplete(1 + n),
155 ParseResult::Invalid(e) => return ParseResult::Invalid(e),
156 };
157
158 let mut offset = 1 + addr_len;
159 if buf.len() < offset + 2 {
160 return ParseResult::Incomplete(offset + 2);
161 }
162 let length = read_u16(&buf[offset..offset + 2]) as usize;
163 if buf.len() < offset + 4 {
164 return ParseResult::Incomplete(offset + 4);
165 }
166 if &buf[offset + 2..offset + 4] != CRLF {
167 return ParseResult::Invalid(ParseError::InvalidCrlf);
168 }
169 offset += 4;
170 if buf.len() < offset + length {
171 return ParseResult::Incomplete(offset + length);
172 }
173
174 ParseResult::Complete(UdpPacket {
175 address,
176 length,
177 packet_len: offset + length,
178 payload: &buf[offset..offset + length],
179 })
180}
181
182#[allow(clippy::cast_possible_truncation)]
188pub fn write_request_header(
189 buf: &mut BytesMut,
190 hash_hex: &[u8],
191 command: u8,
192 address: &AddressRef<'_>,
193) -> Result<(), WriteError> {
194 if hash_hex.len() != HASH_LEN {
195 return Err(WriteError::InvalidHashLen);
196 }
197 if let HostRef::Domain(d) = &address.host
198 && d.len() > MAX_DOMAIN_LEN
199 {
200 return Err(WriteError::DomainTooLong);
201 }
202 buf.extend_from_slice(hash_hex);
203 buf.extend_from_slice(CRLF);
204 buf.extend_from_slice(&[command, address_atyp(address)]);
205 write_address_unchecked(buf, address);
206 buf.extend_from_slice(CRLF);
207 Ok(())
208}
209
210#[allow(clippy::cast_possible_truncation)]
216pub fn write_udp_packet(
217 buf: &mut BytesMut,
218 address: &AddressRef<'_>,
219 payload: &[u8],
220) -> Result<(), WriteError> {
221 if payload.len() > u16::MAX as usize {
222 return Err(WriteError::PayloadTooLarge);
223 }
224 if let HostRef::Domain(d) = &address.host
225 && d.len() > MAX_DOMAIN_LEN
226 {
227 return Err(WriteError::DomainTooLong);
228 }
229 buf.extend_from_slice(&[address_atyp(address)]);
230 write_address_unchecked(buf, address);
231 buf.extend_from_slice(&(payload.len() as u16).to_be_bytes());
232 buf.extend_from_slice(CRLF);
233 buf.extend_from_slice(payload);
234 Ok(())
235}
236
237#[inline]
238fn expect_crlf<T>(buf: &[u8], offset: usize) -> Option<ParseResult<T>> {
239 if buf.len() < offset + 2 {
240 return Some(ParseResult::Incomplete(offset + 2));
241 }
242 if &buf[offset..offset + 2] != CRLF {
243 return Some(ParseResult::Invalid(ParseError::InvalidCrlf));
244 }
245 None
246}
247
248#[inline]
249fn parse_address<'a>(atyp: u8, buf: &'a [u8]) -> ParseResult<(AddressRef<'a>, usize)> {
250 match atyp {
251 ATYP_IPV4 => {
252 if buf.len() < 6 {
253 return ParseResult::Incomplete(6);
254 }
255 let host = HostRef::Ipv4([buf[0], buf[1], buf[2], buf[3]]);
256 let port = read_u16(&buf[4..6]);
257 ParseResult::Complete((AddressRef { host, port }, 6))
258 }
259 ATYP_DOMAIN => {
260 if buf.is_empty() {
261 return ParseResult::Incomplete(1);
262 }
263 let len = buf[0] as usize;
264 if len == 0 {
265 return ParseResult::Invalid(ParseError::InvalidDomainLen);
266 }
267 let need = 1 + len + 2;
268 if buf.len() < need {
269 return ParseResult::Incomplete(need);
270 }
271 let domain = &buf[1..1 + len];
272 if std::str::from_utf8(domain).is_err() {
273 return ParseResult::Invalid(ParseError::InvalidUtf8);
274 }
275 let port = read_u16(&buf[1 + len..1 + len + 2]);
276 ParseResult::Complete((
277 AddressRef {
278 host: HostRef::Domain(domain),
279 port,
280 },
281 need,
282 ))
283 }
284 ATYP_IPV6 => {
285 if buf.len() < 18 {
286 return ParseResult::Incomplete(18);
287 }
288 let mut ip = [0u8; 16];
289 ip.copy_from_slice(&buf[0..16]);
290 let port = read_u16(&buf[16..18]);
291 ParseResult::Complete((
292 AddressRef {
293 host: HostRef::Ipv6(ip),
294 port,
295 },
296 18,
297 ))
298 }
299 _ => ParseResult::Invalid(ParseError::InvalidAtyp),
300 }
301}
302
303#[allow(clippy::cast_possible_truncation)]
305fn write_address_unchecked(buf: &mut BytesMut, address: &AddressRef<'_>) {
306 match address.host {
307 HostRef::Ipv4(ip) => {
308 buf.extend_from_slice(&ip);
309 }
310 HostRef::Ipv6(ip) => {
311 buf.extend_from_slice(&ip);
312 }
313 HostRef::Domain(domain) => {
314 debug_assert!(domain.len() <= MAX_DOMAIN_LEN);
315 buf.extend_from_slice(&[domain.len() as u8]);
316 buf.extend_from_slice(domain);
317 }
318 }
319 buf.extend_from_slice(&address.port.to_be_bytes());
320}
321
322#[inline]
323fn address_atyp(address: &AddressRef<'_>) -> u8 {
324 match address.host {
325 HostRef::Ipv4(_) => ATYP_IPV4,
326 HostRef::Ipv6(_) => ATYP_IPV6,
327 HostRef::Domain(_) => ATYP_DOMAIN,
328 }
329}
330
/// Decodes a big-endian `u16` from the first two bytes of `buf`; extra bytes are ignored.
///
/// # Panics
///
/// Panics if `buf` has fewer than two bytes (with a descriptive message in debug
/// builds, via an index-out-of-bounds panic otherwise).
#[inline]
fn read_u16(buf: &[u8]) -> u16 {
    debug_assert!(buf.len() >= 2, "read_u16 requires at least 2 bytes");
    let (hi, lo) = (buf[0], buf[1]);
    (u16::from(hi) << 8) | u16::from(lo)
}
336
#[cfg(test)]
mod tests {
    use super::*;

    fn sample_hash() -> [u8; HASH_LEN] {
        [b'a'; HASH_LEN]
    }

    /// Builds `hash | CRLF | rest` by hand, for inputs the writer cannot produce.
    fn raw_request(rest: &[u8]) -> Vec<u8> {
        let mut v = sample_hash().to_vec();
        v.extend_from_slice(CRLF);
        v.extend_from_slice(rest);
        v
    }

    #[test]
    fn test_is_valid_hash() {
        assert!(is_valid_hash(&[b'a'; HASH_LEN]));
        assert!(is_valid_hash(
            b"0123456789abcdef0123456789abcdef0123456789abcdef01234567"
        ));

        assert!(is_valid_hash(
            b"0123456789ABCDEF0123456789abcdef0123456789abcdef01234567"
        ));

        assert!(!is_valid_hash(&[b'a'; HASH_LEN - 1]));
        assert!(!is_valid_hash(&[b'a'; HASH_LEN + 1]));

        let mut invalid = [b'a'; HASH_LEN];
        invalid[0] = b'g';
        assert!(!is_valid_hash(&invalid));
    }

    #[test]
    fn parse_request_connect_ipv4() {
        let addr = AddressRef {
            host: HostRef::Ipv4([1, 2, 3, 4]),
            port: 443,
        };
        let mut buf = BytesMut::new();
        write_request_header(&mut buf, &sample_hash(), CMD_CONNECT, &addr).unwrap();
        buf.extend_from_slice(b"hello");

        let res = parse_request(&buf);
        match res {
            ParseResult::Complete(req) => {
                assert_eq!(req.command, CMD_CONNECT);
                assert_eq!(req.address, addr);
                assert_eq!(req.payload, b"hello");
            }
            _ => panic!("unexpected parse result: {:?}", res),
        }
    }

    #[test]
    fn parse_request_udp_associate_domain() {
        let addr = AddressRef {
            host: HostRef::Domain(b"example.com"),
            port: 8080,
        };
        let mut buf = BytesMut::new();
        write_request_header(&mut buf, &sample_hash(), CMD_UDP_ASSOCIATE, &addr).unwrap();

        match parse_request(&buf) {
            ParseResult::Complete(req) => {
                assert_eq!(req.hash, &sample_hash()[..]);
                assert_eq!(req.command, CMD_UDP_ASSOCIATE);
                assert_eq!(req.address, addr);
                assert_eq!(req.header_len, buf.len());
                assert!(req.payload.is_empty());
            }
            other => panic!("unexpected parse result: {other:?}"),
        }
    }

    #[test]
    fn parse_request_ipv6() {
        let mut ip = [0u8; 16];
        ip[15] = 1;
        let addr = AddressRef {
            host: HostRef::Ipv6(ip),
            port: 80,
        };
        let mut buf = BytesMut::new();
        write_request_header(&mut buf, &sample_hash(), CMD_CONNECT, &addr).unwrap();
        buf.extend_from_slice(b"GET /");

        match parse_request(&buf) {
            ParseResult::Complete(req) => {
                assert_eq!(req.address, addr);
                assert_eq!(req.payload, b"GET /");
            }
            other => panic!("unexpected parse result: {other:?}"),
        }
    }

    #[test]
    fn parse_request_max_length_domain() {
        let domain = vec![b'x'; MAX_DOMAIN_LEN];
        let addr = AddressRef {
            host: HostRef::Domain(&domain),
            port: 443,
        };
        let mut buf = BytesMut::new();
        write_request_header(&mut buf, &sample_hash(), CMD_CONNECT, &addr).unwrap();

        match parse_request(&buf) {
            ParseResult::Complete(req) => assert_eq!(req.address, addr),
            other => panic!("unexpected parse result: {other:?}"),
        }
    }

    #[test]
    fn parse_request_every_truncation_is_incomplete() {
        let addr = AddressRef {
            host: HostRef::Domain(b"a.b"),
            port: 1,
        };
        let mut buf = BytesMut::new();
        write_request_header(&mut buf, &sample_hash(), CMD_CONNECT, &addr).unwrap();

        for cut in 0..buf.len() {
            match parse_request(&buf[..cut]) {
                ParseResult::Incomplete(need) => {
                    assert!(need > cut && need <= buf.len(), "cut={cut} need={need}");
                }
                other => panic!("cut={cut}: unexpected parse result: {other:?}"),
            }
        }
    }

    #[test]
    fn parse_request_invalid_hash() {
        let addr = AddressRef {
            host: HostRef::Ipv4([1, 2, 3, 4]),
            port: 443,
        };
        let mut buf = BytesMut::new();
        let mut invalid_hash = [b'a'; HASH_LEN];
        invalid_hash[0] = b'g';
        write_request_header(&mut buf, &invalid_hash, CMD_CONNECT, &addr).unwrap();

        let res = parse_request(&buf);
        assert_eq!(res, ParseResult::Invalid(ParseError::InvalidHashFormat));
    }

    #[test]
    fn parse_request_invalid_crlf_after_hash() {
        let mut buf = sample_hash().to_vec();
        buf.extend_from_slice(b"\n\r");
        assert_eq!(
            parse_request(&buf),
            ParseResult::Invalid(ParseError::InvalidCrlf)
        );
    }

    #[test]
    fn parse_request_invalid_command() {
        let buf = raw_request(&[0x02, ATYP_IPV4, 1, 2, 3, 4, 0, 80, b'\r', b'\n']);
        assert_eq!(
            parse_request(&buf),
            ParseResult::Invalid(ParseError::InvalidCommand)
        );
    }

    #[test]
    fn parse_request_invalid_atyp() {
        let buf = raw_request(&[CMD_CONNECT, 0x02]);
        assert_eq!(
            parse_request(&buf),
            ParseResult::Invalid(ParseError::InvalidAtyp)
        );
    }

    #[test]
    fn parse_request_incomplete() {
        let data = vec![b'a'; HASH_LEN - 1];
        assert_eq!(parse_request(&data), ParseResult::Incomplete(HASH_LEN));
    }

    #[test]
    fn parse_udp_packet_ipv4() {
        let addr = AddressRef {
            host: HostRef::Ipv4([8, 8, 8, 8]),
            port: 53,
        };
        let mut buf = BytesMut::new();
        write_udp_packet(&mut buf, &addr, b"ping").unwrap();
        let res = parse_udp_packet(&buf);
        match res {
            ParseResult::Complete(pkt) => {
                assert_eq!(pkt.address, addr);
                assert_eq!(pkt.payload, b"ping");
            }
            _ => panic!("unexpected parse result: {:?}", res),
        }
    }

    #[test]
    fn parse_udp_packet_domain_with_trailing_data() {
        let addr = AddressRef {
            host: HostRef::Domain(b"dns.example"),
            port: 53,
        };
        let mut buf = BytesMut::new();
        write_udp_packet(&mut buf, &addr, b"query").unwrap();
        let frame_len = buf.len();
        buf.extend_from_slice(b"next");

        match parse_udp_packet(&buf) {
            ParseResult::Complete(pkt) => {
                assert_eq!(pkt.address, addr);
                assert_eq!(pkt.length, 5);
                assert_eq!(pkt.packet_len, frame_len);
                assert_eq!(pkt.payload, b"query");
            }
            other => panic!("unexpected parse result: {other:?}"),
        }
    }

    #[test]
    fn parse_udp_packet_truncated() {
        let addr = AddressRef {
            host: HostRef::Ipv4([8, 8, 8, 8]),
            port: 53,
        };
        let mut buf = BytesMut::new();
        write_udp_packet(&mut buf, &addr, b"ping").unwrap();
        let full = buf.len();

        assert_eq!(parse_udp_packet(&buf[..full - 1]), ParseResult::Incomplete(full));
        assert_eq!(parse_udp_packet(&[]), ParseResult::Incomplete(1));
    }

    #[test]
    fn parse_udp_packet_invalid_inputs() {
        assert_eq!(
            parse_udp_packet(&[0x02]),
            ParseResult::Invalid(ParseError::InvalidAtyp)
        );
        assert_eq!(
            parse_udp_packet(&[ATYP_DOMAIN, 0]),
            ParseResult::Invalid(ParseError::InvalidDomainLen)
        );
        assert_eq!(
            parse_udp_packet(&[ATYP_DOMAIN, 1, 0xff, 0, 53]),
            ParseResult::Invalid(ParseError::InvalidUtf8)
        );

        let mut frame = vec![ATYP_IPV4, 8, 8, 8, 8, 0, 53, 0, 4, b'\n', b'\r'];
        frame.extend_from_slice(b"ping");
        assert_eq!(
            parse_udp_packet(&frame),
            ParseResult::Invalid(ParseError::InvalidCrlf)
        );
    }

    #[test]
    fn write_udp_packet_payload_too_large() {
        let addr = AddressRef {
            host: HostRef::Ipv4([8, 8, 8, 8]),
            port: 53,
        };
        let mut buf = BytesMut::new();
        let large_payload = vec![0u8; u16::MAX as usize + 1];
        let res = write_udp_packet(&mut buf, &addr, &large_payload);
        assert_eq!(res, Err(WriteError::PayloadTooLarge));
    }

    #[test]
    fn write_errors_leave_buffer_untouched() {
        let long_domain = vec![b'a'; MAX_DOMAIN_LEN + 1];
        let addr = AddressRef {
            host: HostRef::Domain(&long_domain),
            port: 53,
        };
        let mut buf = BytesMut::new();
        assert_eq!(
            write_udp_packet(&mut buf, &addr, b"x"),
            Err(WriteError::DomainTooLong)
        );
        assert_eq!(
            write_request_header(&mut buf, &sample_hash(), CMD_CONNECT, &addr),
            Err(WriteError::DomainTooLong)
        );
        assert!(buf.is_empty());
    }

    #[test]
    fn write_request_header_domain_too_long() {
        let long_domain = vec![b'a'; 256];
        let addr = AddressRef {
            host: HostRef::Domain(&long_domain),
            port: 443,
        };
        let mut buf = BytesMut::new();
        let res = write_request_header(&mut buf, &sample_hash(), CMD_CONNECT, &addr);
        assert_eq!(res, Err(WriteError::DomainTooLong));
    }

    #[test]
    fn write_request_header_invalid_hash_len() {
        let addr = AddressRef {
            host: HostRef::Ipv4([1, 2, 3, 4]),
            port: 443,
        };
        let mut buf = BytesMut::new();
        let short_hash = [b'a'; HASH_LEN - 1];
        let res = write_request_header(&mut buf, &short_hash, CMD_CONNECT, &addr);
        assert_eq!(res, Err(WriteError::InvalidHashLen));
    }
}