1use std::{
2 borrow::Cow,
3 collections::VecDeque,
4 ptr::{copy, copy_nonoverlapping},
5 rc::Rc,
6 sync::Arc,
7};
8
9use anyhow::bail;
10use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt};
11use local_sync::oneshot::Sender;
12use monoio::{
13 buf::{IoBuf, IoBufMut, Slice, SliceMut},
14 io::{
15 AsyncReadRent, AsyncReadRentExt, AsyncWriteRent, AsyncWriteRentExt, PrefixedReadIo,
16 Splitable,
17 },
18 net::TcpStream,
19};
20
21use crate::{
22 helper_v2::{
23 copy_with_application_data, copy_without_application_data, ErrGroup, FirstRetGroup,
24 FutureOrOutput, HashedWriteStream, HmacHandler, HMAC_SIZE_V2,
25 },
26 util::{
27 bind_with_pretty_error, copy_bidirectional, copy_until_eof, kdf, mod_tcp_conn, prelude::*,
28 support_tls13, verified_relay, xor_slice, CursorExt, Hmac, V3Mode,
29 },
30 WildcardSNI,
31};
32
/// ShadowTLS server configuration.
///
/// Accepts client connections on `listen_addr`, relays the TLS handshake to a handshake
/// server chosen from `tls_addr`, and forwards authenticated traffic to `target_addr`.
#[derive(Clone)]
pub struct ShadowTlsServer<LA, TA> {
    /// Address the server binds and accepts connections on.
    listen_addr: Arc<LA>,
    /// Data server that authenticated client traffic is forwarded to.
    target_addr: Arc<TA>,
    /// Handshake server routing table (SNI dispatch entries + fallback).
    tls_addr: Arc<TlsAddrs>,
    /// Shared secret used to build the HMACs and to derive the XOR key (`kdf`).
    password: Arc<String>,
    /// Forwarded to `mod_tcp_conn` for every accepted and outgoing TCP connection.
    nodelay: bool,
    /// Protocol mode: the V3 relay is used when `v3.enabled()` is true, V2 otherwise.
    v3: V3Mode,
}
43
/// Routing table for handshake servers.
///
/// Maps an SNI host name to the `host:port` of the TLS server used for the handshake.
/// `fallback` is used when no entry applies.
#[derive(Clone, Debug, PartialEq)]
pub struct TlsAddrs {
    /// SNI -> `host:port` of the handshake server.
    dispatch: rustc_hash::FxHashMap<String, String>,
    /// `host:port` used when the SNI is missing or unmatched (subject to `wildcard_sni`).
    fallback: String,
    /// Controls whether an unmatched SNI may itself be used as the handshake host (port 443).
    wildcard_sni: WildcardSNI,
}
50
51impl TlsAddrs {
52 fn find<'a>(&'a self, key: Option<&str>, auth: bool) -> Cow<'a, str> {
53 match key {
54 Some(k) => match self.dispatch.get(k) {
55 Some(v) => Cow::Borrowed(v),
56 None => match self.wildcard_sni {
57 WildcardSNI::Authed if auth => Cow::Owned(format!("{k}:443")),
58 WildcardSNI::All => Cow::Owned(format!("{k}:443")),
59 _ => Cow::Borrowed(&self.fallback),
60 },
61 },
62 None => Cow::Borrowed(&self.fallback),
63 }
64 }
65
66 fn is_empty(&self) -> bool {
67 self.dispatch.is_empty()
68 }
69
70 pub fn set_wildcard_sni(&mut self, wildcard_sni: WildcardSNI) {
71 self.wildcard_sni = wildcard_sni;
72 }
73}
74
75impl TryFrom<&str> for TlsAddrs {
76 type Error = anyhow::Error;
77
78 fn try_from(arg: &str) -> Result<Self, Self::Error> {
79 let mut rev_parts = arg.split(';').rev();
80 let fallback = rev_parts
81 .next()
82 .and_then(|x| if x.trim().is_empty() { None } else { Some(x) })
83 .ok_or_else(|| anyhow::anyhow!("empty server addrs"))?;
84 let fallback = if !fallback.contains(':') {
85 format!("{fallback}:443")
86 } else {
87 fallback.to_string()
88 };
89
90 let mut dispatch = rustc_hash::FxHashMap::default();
91 for p in rev_parts {
92 let mut p = p.trim().split(':').rev();
93 let mut port = Cow::<'static, str>::Borrowed("443");
94 let maybe_port = p
95 .next()
96 .ok_or_else(|| anyhow::anyhow!("empty part found in server addrs"))?;
97 let host = if maybe_port.parse::<u16>().is_ok() {
98 port = maybe_port.into();
100 p.next()
101 .ok_or_else(|| anyhow::anyhow!("no host found in server addrs part"))?
102 } else {
103 maybe_port
104 };
105 let key = match p.next() {
106 Some(key) => key,
107 None => host,
108 };
109 if p.next().is_some() {
110 bail!("unrecognized server addrs part");
111 }
112 if dispatch
113 .insert(key.to_string(), format!("{host}:{port}"))
114 .is_some()
115 {
116 bail!("duplicate server addrs part found");
117 }
118 }
119 Ok(TlsAddrs {
120 dispatch,
121 fallback,
122 wildcard_sni: Default::default(),
123 })
124 }
125}
126
127impl std::fmt::Display for TlsAddrs {
128 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
129 write!(f, "(wildcard-sni:{})", self.wildcard_sni)?;
130 for (k, v) in self.dispatch.iter() {
131 write!(f, "{k}->{v};")?;
132 }
133 write!(f, "fallback->{}", self.fallback)
134 }
135}
136
137impl<LA, TA> ShadowTlsServer<LA, TA> {
138 pub fn new(
139 listen_addr: LA,
140 target_addr: TA,
141 tls_addr: TlsAddrs,
142 password: String,
143 nodelay: bool,
144 v3: V3Mode,
145 ) -> Self {
146 Self {
147 listen_addr: Arc::new(listen_addr),
148 target_addr: Arc::new(target_addr),
149 tls_addr: Arc::new(tls_addr),
150 password: Arc::new(password),
151 nodelay,
152 v3,
153 }
154 }
155}
156
157impl<LA, TA> ShadowTlsServer<LA, TA> {
158 pub async fn serve(self) -> anyhow::Result<()>
160 where
161 LA: std::net::ToSocketAddrs + 'static,
162 TA: std::net::ToSocketAddrs + 'static,
163 {
164 let listener = bind_with_pretty_error(self.listen_addr.as_ref())?;
165 let shared = Rc::new(self);
166 loop {
167 match listener.accept().await {
168 Ok((mut conn, addr)) => {
169 tracing::info!("Accepted a connection from {addr}");
170 let server = shared.clone();
171 mod_tcp_conn(&mut conn, true, shared.nodelay);
172 monoio::spawn(async move {
173 let _ = match server.v3.enabled() {
174 false => server.relay_v2(conn).await,
175 true => server.relay_v3(conn).await,
176 };
177 tracing::info!("Relay for {addr} finished");
178 });
179 }
180 Err(e) => {
181 tracing::error!("Accept failed: {e}");
182 }
183 }
184 }
185 }
186
187 async fn relay_v2(&self, in_stream: TcpStream) -> anyhow::Result<()>
189 where
190 TA: std::net::ToSocketAddrs,
191 {
192 let mut in_stream = HashedWriteStream::new(in_stream, self.password.as_bytes())?;
194 let mut hmac = in_stream.hmac_handler();
195
196 let (prefix, server_name) = match self.tls_addr.is_empty() {
199 true => (Vec::new(), None),
200 false => extract_sni_v2(&mut in_stream).await?,
201 };
202 let mut prefixed_io = PrefixedReadIo::new(&mut in_stream, std::io::Cursor::new(prefix));
203 tracing::debug!("server name extracted from SNI extention: {server_name:?}");
204
205 let server_name = server_name.and_then(|s| String::from_utf8(s).ok());
207 let addr = self
208 .tls_addr
209 .find(server_name.as_ref().map(AsRef::as_ref), true);
210 let mut out_stream = TcpStream::connect(addr.as_ref()).await?;
211 mod_tcp_conn(&mut out_stream, true, self.nodelay);
212 tracing::debug!("handshake server connected: {addr}");
213
214 let (mut out_r, mut out_w) = out_stream.split();
216 let (mut in_r, mut in_w) = prefixed_io.split();
217 let (switch, cp) = FirstRetGroup::new(
218 copy_until_handshake_finished(&mut in_r, &mut out_w, &hmac),
219 Box::pin(copy_until_eof(&mut out_r, &mut in_w)),
220 )
221 .await?;
222 hmac.disable();
223 tracing::debug!("handshake finished, switch: {switch:?}");
224
225 match switch {
227 SwitchResult::Switch(data_left) => {
228 drop(cp);
229 let mut in_stream = in_stream.into_inner();
230 let (mut in_r, mut in_w) = in_stream.split();
231
232 let _ = out_stream.shutdown().await;
234 drop(out_stream);
235 let mut data_stream = TcpStream::connect(self.target_addr.as_ref()).await?;
236 mod_tcp_conn(&mut data_stream, true, self.nodelay);
237 tracing::debug!("data server connected, start relay");
238 let (mut data_r, mut data_w) = data_stream.split();
239 let (result, _) = data_w.write(data_left).await;
240 result?;
241 ErrGroup::new(
242 copy_with_application_data::<0, _, _>(&mut data_r, &mut in_w, None),
243 copy_without_application_data(&mut in_r, &mut data_w),
244 )
245 .await?;
246 }
247 SwitchResult::DirectProxy => match cp {
248 FutureOrOutput::Future(cp) => {
249 ErrGroup::new(cp, copy_until_eof(in_r, out_w)).await?;
250 }
251 FutureOrOutput::Output(_) => {
252 copy_until_eof(in_r, out_w).await?;
253 }
254 },
255 }
256 Ok(())
257 }
258
259 async fn relay_v3(&self, mut in_stream: TcpStream) -> anyhow::Result<()>
261 where
262 TA: std::net::ToSocketAddrs,
263 {
264 let first_client_frame = read_exact_frame(&mut in_stream).await?;
266 let (client_hello_pass, sni) = verified_extract_sni(&first_client_frame, &self.password);
267
268 let server_name = sni.and_then(|s| String::from_utf8(s).ok());
270 let addr = self
271 .tls_addr
272 .find(server_name.as_ref().map(AsRef::as_ref), client_hello_pass);
273 let mut handshake_stream = TcpStream::connect(addr.as_ref()).await?;
274 mod_tcp_conn(&mut handshake_stream, true, self.nodelay);
275 tracing::debug!("handshake server connected: {addr}");
276 tracing::trace!("ClientHello frame {first_client_frame:?}");
277 let (res, _) = handshake_stream.write_all(first_client_frame).await;
278 res?;
279 if !client_hello_pass {
280 tracing::warn!("ClientHello verify failed, will work as a SNI proxy");
282 copy_bidirectional(&mut in_stream, &mut handshake_stream).await;
283 return Ok(());
284 }
285 tracing::debug!("ClientHello verify success");
286
287 let first_server_frame = read_exact_frame(&mut handshake_stream).await?;
289 let (res, first_server_frame) = in_stream.write_all(first_server_frame).await;
290 res?;
291 let server_random = match extract_server_random(&first_server_frame) {
292 Some(sr) => sr,
293 None => {
294 tracing::warn!("ServerRandom extract failed, will copy bidirectional");
296 copy_bidirectional(&mut in_stream, &mut handshake_stream).await;
297 return Ok(());
298 }
299 };
300 tracing::debug!("Client authenticated. ServerRandom extracted: {server_random:?}");
301
302 let use_tls13 = support_tls13(&first_server_frame);
303 if self.v3.strict() && !use_tls13 {
304 tracing::error!(
305 "V3 strict enabled and TLS 1.3 is not supported, will copy bidirectional"
306 );
307 copy_bidirectional(&mut in_stream, &mut handshake_stream).await;
308 return Ok(());
309 }
310
311 let mut hmac_sr_c = Hmac::new(&self.password, (&server_random, b"C"));
313 let hmac_sr_s = Hmac::new(&self.password, (&server_random, b"S"));
314 let mut hmac_sr = Hmac::new(&self.password, (&server_random, &[]));
315
316 let pure_data = {
319 let (mut c_read, mut c_write) = in_stream.split();
320 let (mut h_read, mut h_write) = handshake_stream.split();
321 let (mut sender, mut recevier) = local_sync::oneshot::channel::<()>();
322 let key = kdf(&self.password, &server_random);
323 let (maybe_pure, _) = monoio::join!(
324 async {
325 let r =
326 copy_by_frame_until_hmac_matches(&mut c_read, &mut h_write, &mut hmac_sr_c)
327 .await;
328 recevier.close();
329 if r.is_err() {
330 let _ = h_write.shutdown().await;
331 }
332 r
333 },
334 async {
335 let r = copy_by_frame_with_modification(
336 &mut h_read,
337 &mut c_write,
338 &mut hmac_sr,
339 &key,
340 &mut sender,
341 )
342 .await;
343 if r.is_err() {
344 let _ = c_write.shutdown().await;
345 }
346 }
347 );
348 maybe_pure?
349 };
350 tracing::debug!("handshake relay finished");
351
352 drop(handshake_stream);
354 drop(first_server_frame);
355
356 let mut data_stream = TcpStream::connect(self.target_addr.as_ref()).await?;
359 mod_tcp_conn(&mut data_stream, true, self.nodelay);
360 let (res, _) = data_stream.write_all(pure_data).await;
361 res?;
362 verified_relay(
363 data_stream,
364 in_stream,
365 hmac_sr_s,
366 hmac_sr_c,
367 None,
368 !use_tls13,
369 )
370 .await;
371 Ok(())
372 }
373}
374
/// Outcome of relaying the V2 handshake (see `copy_until_handshake_finished`).
enum SwitchResult {
    /// The client authenticated. Carries the rest of the authenticating record's payload
    /// (after its HMAC prefix), which `relay_v2` sends first to the data server.
    Switch(Vec<u8>),
    /// The client did not authenticate; keep proxying to the handshake server.
    DirectProxy,
}
382
383impl std::fmt::Debug for SwitchResult {
384 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
385 match self {
386 Self::Switch(_) => write!(f, "Switch"),
387 Self::DirectProxy => write!(f, "DirectProxy"),
388 }
389 }
390}
391
/// Copies client -> handshake-server TLS records (V2) until the client authenticates or the
/// traffic is judged not to be ShadowTLS.
///
/// Everything read from `read_half` is forwarded to `write_half`, except the part of the
/// authenticating record that follows its HMAC prefix (the prefix itself is forwarded).
/// Each record is handled as follows:
/// - Not yet checked for a tag: any non-ApplicationData record, any record before both a
///   Handshake and a ChangeCipherSpec record have been seen, and any record shorter than
///   `HMAC_SIZE_V2`. Such a record is copied verbatim. If it does not look like TLS (wrong
///   version bytes, or no Handshake record seen yet and this is not one), returns
///   `DirectProxy` after copying it.
/// - Otherwise the first `HMAC_SIZE_V2` payload bytes are compared with the most recent
///   values of `hmac.hash()`. On a match, the rest of the payload is read (not forwarded)
///   and returned as `Switch(payload)`.
/// - After more than 3 non-matching ApplicationData records, returns `DirectProxy`.
///
/// # Errors
/// Propagates I/O errors. On EOF from `read_half`, `write_half` is shut down (best effort)
/// and `UnexpectedEof` is returned.
async fn copy_until_handshake_finished<R, W>(
    mut read_half: R,
    mut write_half: W,
    hmac: &HmacHandler,
) -> std::io::Result<SwitchResult>
where
    R: AsyncReadRent,
    W: AsyncWriteRent,
{
    let mut has_seen_change_cipher_spec = false;
    let mut has_seen_handshake = false;

    // The header may arrive in pieces: track how much has been read and how much forwarded.
    let mut header_buf = vec![0_u8; TLS_HEADER_SIZE].into_boxed_slice();
    let mut header_read_len = 0;
    let mut header_write_len = 0;
    let mut data_hmac_buf = vec![0_u8; HMAC_SIZE_V2].into_boxed_slice();
    let mut data_buf = vec![0_u8; 2048];
    let mut application_data_count: usize = 0;

    // Ring of recent `hmac.hash()` values; the oldest is evicted once capacity is reached.
    let mut hashes = VecDeque::with_capacity(10);
    loop {
        let header_buf_slice = SliceMut::new(header_buf, header_read_len, TLS_HEADER_SIZE);
        let (res, header_buf_slice_) = read_half.read(header_buf_slice).await;
        header_buf = header_buf_slice_.into_inner();
        let read_len = res?;
        header_read_len += read_len;

        if read_len == 0 {
            let _ = write_half.shutdown().await;
            return Err(std::io::ErrorKind::UnexpectedEof.into());
        }

        // Forward whatever part of the header has been read but not yet written.
        let header_buf_slice_w = Slice::new(header_buf, header_write_len, header_read_len);
        let (res, header_buf_slice_w_) = write_half.write_all(header_buf_slice_w).await;
        header_buf = header_buf_slice_w_.into_inner();
        header_write_len += res?;

        if header_read_len != TLS_HEADER_SIZE {
            continue;
        }

        header_read_len = 0;
        header_write_len = 0;

        // Big-endian record payload length from header bytes 3..5.
        let mut size: [u8; 2] = Default::default();
        size.copy_from_slice(&header_buf[3..5]);
        let data_size = u16::from_be_bytes(size) as usize;
        tracing::debug!(
            "read header with type {} and length {}",
            header_buf[0],
            data_size
        );

        if header_buf[0] != APPLICATION_DATA
            || !has_seen_handshake
            || !has_seen_change_cipher_spec
            || data_size < HMAC_SIZE_V2
        {
            // Evaluated before the flags below are updated for this record.
            let valid = (has_seen_handshake || header_buf[0] == HANDSHAKE)
                && header_buf[1] == TLS_MAJOR
                && (header_buf[2] == TLS_MINOR.0 || header_buf[2] == TLS_MINOR.1);
            if header_buf[0] == CHANGE_CIPHER_SPEC {
                has_seen_change_cipher_spec = true;
            }
            if header_buf[0] == HANDSHAKE {
                has_seen_handshake = true;
            }
            // Copy the record payload verbatim, at most one `data_buf` at a time.
            let mut to_copy = data_size;
            while to_copy != 0 {
                let max_read = data_buf.capacity().min(to_copy);
                let buf = SliceMut::new(data_buf, 0, max_read);
                let (read_res, buf) = read_half.read(buf).await;

                let read_len = read_res?;
                if read_len == 0 {
                    let _ = write_half.shutdown().await;
                    return Err(std::io::ErrorKind::UnexpectedEof.into());
                }

                let buf = buf.into_inner().slice(0..read_len);
                let (write_res, buf) = write_half.write_all(buf).await;
                to_copy -= write_res?;
                data_buf = buf.into_inner();
            }
            tracing::debug!("copied data with length {:?}", data_size);
            if !valid {
                tracing::debug!("early invalid tls: header {:?}", &header_buf[..3]);
                return Ok(SwitchResult::DirectProxy);
            }
            continue;
        }

        // Read (and forward) the first HMAC_SIZE_V2 payload bytes: the candidate tag.
        let mut hmac_read_len = 0;
        while hmac_read_len < HMAC_SIZE_V2 {
            let buf = SliceMut::new(data_hmac_buf, hmac_read_len, HMAC_SIZE_V2);
            let (res, buf_) = read_half.read(buf).await;
            let read_len = res?;
            if read_len == 0 {
                let _ = write_half.shutdown().await;
                return Err(std::io::ErrorKind::UnexpectedEof.into());
            }

            let buf = Slice::new(buf_.into_inner(), hmac_read_len, hmac_read_len + read_len);
            let (write_res, buf_) = write_half.write_all(buf).await;
            write_res?;
            hmac_read_len += read_len;
            data_hmac_buf = buf_.into_inner();
        }

        // NOTE(review): `hmac.hash()` presumably reflects the data written back to the client
        // through `HashedWriteStream`; confirm in helper_v2.
        let hash = hmac.hash();
        let mut hash_trim = [0; HMAC_SIZE_V2];
        // SAFETY: assumes `hash` holds at least HMAC_SIZE_V2 bytes (not visible here);
        // `hash_trim` is exactly HMAC_SIZE_V2 bytes and a distinct allocation.
        unsafe { copy_nonoverlapping(hash.as_ptr(), hash_trim.as_mut_ptr(), HMAC_SIZE_V2) };
        tracing::debug!("hmac calculated: {hash_trim:?}");
        if hashes.len() + 1 > hashes.capacity() {
            hashes.pop_front();
        }
        hashes.push_back(hash_trim);
        // Reuse `hash_trim` as scratch space for the client's candidate tag.
        // SAFETY: both buffers are exactly HMAC_SIZE_V2 bytes and do not overlap.
        unsafe {
            copy_nonoverlapping(data_hmac_buf.as_ptr(), hash_trim.as_mut_ptr(), HMAC_SIZE_V2)
        };
        if hashes.contains(&hash_trim) {
            tracing::debug!("hmac matches");
            // The remainder of this record is the client's first payload; not forwarded.
            let pure_data = vec![0; data_size - HMAC_SIZE_V2];
            let (read_res, pure_data) = read_half.read_exact(pure_data).await;
            read_res?;
            return Ok(SwitchResult::Switch(pure_data));
        }

        application_data_count += 1;
        // No match: forward the rest of the record payload unchanged.
        let mut to_copy = data_size - HMAC_SIZE_V2;
        while to_copy != 0 {
            let max_read = data_buf.capacity().min(to_copy);
            let buf = SliceMut::new(data_buf, 0, max_read);
            let (read_res, buf) = read_half.read(buf).await;

            let read_len = read_res?;
            if read_len == 0 {
                let _ = write_half.shutdown().await;
                return Err(std::io::ErrorKind::UnexpectedEof.into());
            }

            let buf = buf.into_inner().slice(0..read_len);
            let (write_res, buf) = write_half.write_all(buf).await;
            to_copy -= write_res?;
            data_buf = buf.into_inner();
        }

        if application_data_count > 3 {
            tracing::debug!("hmac not matches after 3 times, fallback to direct");
            return Ok(SwitchResult::DirectProxy);
        }
    }
}
579
580async fn extract_sni_v2<R: AsyncReadRent>(mut r: R) -> std::io::Result<(Vec<u8>, Option<Vec<u8>>)> {
585 macro_rules! read_ok {
586 ($res: expr, $data: expr) => {
587 match $res {
588 Ok(r) => r,
589 Err(_) => {
590 return Ok(($data, None));
591 }
592 }
593 };
594 }
595
596 let header = vec![0; TLS_HEADER_SIZE];
597 let (res, header) = r.read_exact(header).await;
598 res?;
599
600 if header[0] != HANDSHAKE
602 || header[1] != TLS_MAJOR
603 || (header[2] != TLS_MINOR.0 && header[2] != TLS_MINOR.1)
604 {
605 return Ok((header, None));
606 }
607
608 let mut size: [u8; 2] = Default::default();
610 size.copy_from_slice(&header[3..5]);
611 let data_size = u16::from_be_bytes(size);
612 tracing::debug!("read handshake length {}", data_size);
613
614 let mut data = vec![0; data_size as usize + TLS_HEADER_SIZE];
616 unsafe { copy_nonoverlapping(header.as_ptr(), data.as_mut_ptr(), TLS_HEADER_SIZE) };
617 let (res, data_slice) = r.read_exact(data.slice_mut(TLS_HEADER_SIZE..)).await;
618 res?;
619
620 let data_slice: SliceMut<Vec<u8>> = data_slice;
622 let data = data_slice.into_inner();
623 let mut cursor = std::io::Cursor::new(&data[TLS_HEADER_SIZE..]);
624 if read_ok!(cursor.read_u8(), data) != CLIENT_HELLO {
625 tracing::debug!("first packet is not client hello");
626 return Ok((data, None));
627 }
628 if read_ok!(cursor.read_u8(), data) != 0 {
630 tracing::debug!("client hello length first byte is not zero");
631 return Ok((data, None));
632 }
633 let prot_size = read_ok!(cursor.read_u16::<BigEndian>(), data);
635 if prot_size + 4 > data_size {
636 tracing::debug!("invalid client hello length");
637 return Ok((data, None));
638 }
639 let mut cursor = std::io::Cursor::new(
641 &data[TLS_HMAC_HEADER_SIZE..TLS_HMAC_HEADER_SIZE + prot_size as usize],
642 );
643 read_ok!(cursor.read_u16::<BigEndian>(), data);
645 read_ok!(cursor.skip(TLS_RANDOM_SIZE), data);
647 read_ok!(cursor.skip_by_u8(), data);
649 read_ok!(cursor.skip_by_u16(), data);
651 read_ok!(cursor.skip_by_u8(), data);
653 read_ok!(cursor.read_u16::<BigEndian>(), data);
655
656 loop {
657 let ext_type = read_ok!(cursor.read_u16::<BigEndian>(), data);
658 if ext_type != SNI_EXT_TYPE {
659 read_ok!(cursor.skip_by_u16(), data);
660 continue;
661 }
662 tracing::debug!("found server_name extension");
663 let _ext_len = read_ok!(cursor.read_u16::<BigEndian>(), data);
664 let _sni_len = read_ok!(cursor.read_u16::<BigEndian>(), data);
665 if read_ok!(cursor.read_u8(), data) != 0 {
667 return Ok((data, None));
668 }
669 let sni = Some(read_ok!(cursor.read_by_u16(), data));
670 return Ok((data, sni));
671 }
672}
673
674async fn read_exact_frame(r: impl AsyncReadRent) -> std::io::Result<Vec<u8>> {
678 read_exact_frame_into(r, Vec::new()).await
679}
680
681async fn read_exact_frame_into(
685 mut r: impl AsyncReadRent,
686 mut buffer: Vec<u8>,
687) -> std::io::Result<Vec<u8>> {
688 unsafe { buffer.set_len(0) };
689 buffer.reserve(TLS_HEADER_SIZE);
690 let (res, header) = r.read_exact(buffer.slice_mut(..TLS_HEADER_SIZE)).await;
691 res?;
692 let mut buffer = header.into_inner();
693
694 let mut size: [u8; 2] = Default::default();
696 size.copy_from_slice(&buffer[3..5]);
697 let data_size = u16::from_be_bytes(size) as usize;
698
699 buffer.reserve(data_size);
701 let (res, data_slice) = r
702 .read_exact(buffer.slice_mut(TLS_HEADER_SIZE..TLS_HEADER_SIZE + data_size))
703 .await;
704 res?;
705
706 Ok(data_slice.into_inner())
707}
708
709fn verified_extract_sni(frame: &[u8], password: &str) -> (bool, Option<Vec<u8>>) {
716 const MIN_LEN: usize = TLS_HEADER_SIZE + 1 + 3 + 2 + TLS_RANDOM_SIZE + 1 + TLS_SESSION_ID_SIZE;
718 const HMAC_IDX: usize = SESSION_ID_LEN_IDX + 1 + TLS_SESSION_ID_SIZE - HMAC_SIZE;
719 const ZERO4B: [u8; HMAC_SIZE] = [0; HMAC_SIZE];
720
721 if frame.len() < SESSION_ID_LEN_IDX || frame[0] != HANDSHAKE || frame[5] != CLIENT_HELLO {
722 return (false, None);
723 }
724
725 let pass = if frame.len() < MIN_LEN || frame[SESSION_ID_LEN_IDX] != TLS_SESSION_ID_SIZE as u8 {
726 false
727 } else {
728 let mut hmac = Hmac::new(password, (&[], &[]));
729 hmac.update(&frame[TLS_HEADER_SIZE..HMAC_IDX]);
730 hmac.update(&ZERO4B);
731 hmac.update(&frame[HMAC_IDX + HMAC_SIZE..]);
732 hmac.finalize() == frame[HMAC_IDX..HMAC_IDX + HMAC_SIZE]
733 };
734
735 let mut cursor = std::io::Cursor::new(&frame[SESSION_ID_LEN_IDX..]);
736 macro_rules! read_ok {
737 ($res: expr) => {
738 match $res {
739 Ok(r) => r,
740 Err(_) => {
741 return (pass, None);
742 }
743 }
744 };
745 }
746
747 read_ok!(cursor.skip_by_u8());
749 read_ok!(cursor.skip_by_u16());
751 read_ok!(cursor.skip_by_u8());
753 read_ok!(cursor.read_u16::<BigEndian>());
755
756 loop {
757 let ext_type = read_ok!(cursor.read_u16::<BigEndian>());
758 if ext_type != SNI_EXT_TYPE {
759 read_ok!(cursor.skip_by_u16());
760 continue;
761 }
762 tracing::debug!("found server_name extension");
763 let _ext_len = read_ok!(cursor.read_u16::<BigEndian>());
764 let _sni_len = read_ok!(cursor.read_u16::<BigEndian>());
765 if read_ok!(cursor.read_u8()) != 0 {
767 return (pass, None);
768 }
769 let sni = Some(read_ok!(cursor.read_by_u16()));
770 return (pass, sni);
771 }
772}
773
774fn extract_server_random(frame: &[u8]) -> Option<[u8; TLS_RANDOM_SIZE]> {
779 const MIN_LEN: usize = TLS_HEADER_SIZE + 1 + 3 + 2 + TLS_RANDOM_SIZE;
781
782 if frame.len() < MIN_LEN || frame[0] != HANDSHAKE || frame[5] != SERVER_HELLO {
783 return None;
784 }
785
786 let mut server_random = [0; TLS_RANDOM_SIZE];
787 unsafe {
788 copy_nonoverlapping(
789 frame.as_ptr().add(SERVER_RANDOM_IDX),
790 server_random.as_mut_ptr(),
791 TLS_RANDOM_SIZE,
792 )
793 };
794
795 Some(server_random)
796}
797
798async fn copy_by_frame_until_hmac_matches(
803 mut read: impl AsyncReadRent,
804 mut write: impl AsyncWriteRent,
805 hmac: &mut Hmac,
806) -> std::io::Result<Vec<u8>> {
807 let mut g_buffer = Vec::new();
808
809 loop {
810 let buffer = read_exact_frame_into(&mut read, g_buffer).await?;
811 if buffer.len() > 9 && buffer[0] == APPLICATION_DATA {
812 let mut tmp_hmac = hmac.to_owned();
814 tmp_hmac.update(&buffer[TLS_HMAC_HEADER_SIZE..]);
815 let h = tmp_hmac.finalize();
816
817 if buffer[TLS_HEADER_SIZE..TLS_HMAC_HEADER_SIZE] == h {
818 hmac.update(&buffer[TLS_HMAC_HEADER_SIZE..]);
819 hmac.update(&buffer[TLS_HEADER_SIZE..TLS_HMAC_HEADER_SIZE]);
820 return Ok(buffer[TLS_HMAC_HEADER_SIZE..].to_vec());
821 }
822 }
823
824 let (res, buffer) = write.write_all(buffer).await;
825 res?;
826 g_buffer = buffer;
827 }
828}
829
830async fn copy_by_frame_with_modification(
837 mut read: impl AsyncReadRent,
838 mut write: impl AsyncWriteRent,
839 hmac: &mut Hmac,
840 xor: &[u8],
841 stop: &mut Sender<()>,
842) -> std::io::Result<()> {
843 let mut g_buffer = Vec::new();
844 let stop = stop.closed();
845 monoio::pin!(stop);
846
847 loop {
848 monoio::select! {
849 _ = &mut stop => {
851 return Ok(());
852 },
853 buffer_res = read_exact_frame_into(&mut read, g_buffer) => {
854 let mut buffer = buffer_res?;
855 if buffer[0] == APPLICATION_DATA {
857 xor_slice(&mut buffer[TLS_HEADER_SIZE..], xor);
859 hmac.update(&buffer[TLS_HEADER_SIZE..]);
860 let hash = hmac.finalize();
861 buffer.extend_from_slice(&hash);
862 unsafe {
863 copy(buffer.as_ptr().add(TLS_HEADER_SIZE), buffer.as_mut_ptr().add(TLS_HMAC_HEADER_SIZE), buffer.len() - TLS_HMAC_HEADER_SIZE);
864 copy_nonoverlapping(hash.as_ptr(), buffer.as_mut_ptr().add(TLS_HEADER_SIZE), HMAC_SIZE);
865 }
866
867 let mut size: [u8; 2] = Default::default();
868 size.copy_from_slice(&buffer[3..5]);
869 let data_size = u16::from_be_bytes(size);
870 let data_size = data_size.wrapping_add(HMAC_SIZE as u16);
872 (&mut buffer[3..5]).write_u16::<BigEndian>(data_size).unwrap();
873 }
874
875 let (res, buffer) = write.write_all(buffer).await;
877 res?;
878 g_buffer = buffer;
879 }
880 }
881 }
882}
883
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the `TlsAddrs` that parsing is expected to produce, with wildcard SNI left at
    /// its default.
    fn expected(pairs: &[(&str, &str)], fallback: &str) -> TlsAddrs {
        TlsAddrs {
            dispatch: pairs
                .iter()
                .map(|&(sni, addr)| (sni.to_owned(), addr.to_owned()))
                .collect(),
            fallback: fallback.to_owned(),
            wildcard_sni: Default::default(),
        }
    }

    #[test]
    fn parse_tls_addrs() {
        let cases = [
            ("google.com", expected(&[], "google.com:443")),
            (
                "feishu.cn;cloudflare.com:1.1.1.1:80;google.com",
                expected(
                    &[
                        ("feishu.cn", "feishu.cn:443"),
                        ("cloudflare.com", "1.1.1.1:80"),
                    ],
                    "google.com:443",
                ),
            ),
            (
                "captive.apple.com;feishu.cn:80",
                expected(
                    &[("captive.apple.com", "captive.apple.com:443")],
                    "feishu.cn:80",
                ),
            ),
        ];
        for (input, want) in cases.iter() {
            assert_eq!(TlsAddrs::try_from(*input).unwrap(), *want, "input: {input}");
        }
    }
}