1use std::{
2 ptr::{copy, copy_nonoverlapping},
3 rc::Rc,
4 sync::Arc,
5};
6
7use anyhow::bail;
8use byteorder::{BigEndian, WriteBytesExt};
9use monoio::{
10 buf::IoBufMut,
11 io::{AsyncReadRent, AsyncReadRentExt, AsyncWriteRent, AsyncWriteRentExt, Splitable},
12 net::TcpStream,
13};
14use monoio_rustls_fork_shadow_tls::TlsConnector;
15use rand::{prelude::Distribution, seq::SliceRandom, Rng};
16use rustls_fork_shadow_tls::{OwnedTrustAnchor, RootCertStore, ServerName};
17
18use crate::{
19 helper_v2::{copy_with_application_data, copy_without_application_data, HashedReadStream},
20 util::{
21 bind_with_pretty_error, kdf, mod_tcp_conn, prelude::*, support_tls13, verified_relay,
22 xor_slice, Hmac, V3Mode,
23 },
24};
25
/// Bounds for the number of random cookie bytes appended to the decoy request
/// built by `fake_request`: `.0` is inclusive and `.1` is exclusive, because the
/// pair is used as `gen_range(.0...1)`.
const FAKE_REQUEST_LENGTH_RANGE: (usize, usize) = (16, 64);
27
28#[derive(Clone)]
30pub struct ShadowTlsClient<LA, TA> {
31 listen_addr: Arc<LA>,
32 target_addr: Arc<TA>,
33 tls_connector: TlsConnector,
34 tls_names: Arc<TlsNames>,
35 password: Arc<String>,
36 nodelay: bool,
37 v3: V3Mode,
38}
39
40#[derive(Clone, Debug, PartialEq)]
41pub struct TlsNames(Vec<ServerName>);
42
43impl TlsNames {
44 #[inline]
45 pub fn random_choose(&self) -> &ServerName {
46 self.0.choose(&mut rand::thread_rng()).unwrap()
47 }
48}
49
50impl TryFrom<&str> for TlsNames {
51 type Error = anyhow::Error;
52
53 fn try_from(value: &str) -> Result<Self, Self::Error> {
54 let v: Result<Vec<_>, _> = value.trim().split(';').map(ServerName::try_from).collect();
55 let v = v.map_err(Into::into).and_then(|v| {
56 if v.is_empty() {
57 Err(anyhow::anyhow!("empty tls names"))
58 } else {
59 Ok(v)
60 }
61 })?;
62 Ok(Self(v))
63 }
64}
65
66impl std::fmt::Display for TlsNames {
67 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
68 write!(f, "{:?}", self.0)
69 }
70}
71
/// Optional TLS extension settings applied to the client's rustls `ClientConfig`.
#[derive(Default, Debug)]
pub struct TlsExtConfig {
    /// ALPN protocol identifiers to advertise. `None` leaves the config's ALPN list
    /// untouched.
    alpn: Option<Vec<Vec<u8>>>,
}

impl TlsExtConfig {
    /// Creates a config from raw (byte-string) ALPN protocol identifiers.
    // `new` may have no caller inside this crate; the attribute silences the lint.
    #[allow(unused)]
    #[inline]
    pub fn new(alpn: Option<Vec<Vec<u8>>>) -> TlsExtConfig {
        Self { alpn }
    }
}

impl From<Option<Vec<String>>> for TlsExtConfig {
    /// Converts textual ALPN names into their UTF-8 byte form.
    fn from(maybe_alpns: Option<Vec<String>>) -> Self {
        let alpn = maybe_alpns.map(|names| names.into_iter().map(String::into_bytes).collect());
        Self { alpn }
    }
}

impl std::fmt::Display for TlsExtConfig {
    /// Renders as `ALPN(None)` or `ALPN(Some(p1,p2,))`. Each protocol is decoded
    /// lossily as UTF-8 and followed by a comma.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Some(alpns) = &self.alpn else {
            return f.write_str("ALPN(None)");
        };
        f.write_str("ALPN(Some(")?;
        alpns
            .iter()
            .try_for_each(|alpn| write!(f, "{},", String::from_utf8_lossy(alpn)))?;
        f.write_str("))")
    }
}
110
111impl<LA, TA> ShadowTlsClient<LA, TA> {
112 pub fn new(
114 listen_addr: LA,
115 target_addr: TA,
116 tls_names: TlsNames,
117 tls_ext_config: TlsExtConfig,
118 password: String,
119 nodelay: bool,
120 v3: V3Mode,
121 ) -> anyhow::Result<Self> {
122 let mut root_store = RootCertStore::empty();
123 root_store.add_server_trust_anchors(webpki_roots::TLS_SERVER_ROOTS.0.iter().map(|ta| {
124 OwnedTrustAnchor::from_subject_spki_name_constraints(
125 ta.subject,
126 ta.spki,
127 ta.name_constraints,
128 )
129 }));
130 let mut tls_config = rustls_fork_shadow_tls::ClientConfig::builder()
132 .with_safe_defaults()
133 .with_root_certificates(root_store)
134 .with_no_client_auth();
135
136 if let Some(alpn) = tls_ext_config.alpn {
138 tls_config.alpn_protocols = alpn;
139 }
140
141 let tls_connector = TlsConnector::from(tls_config);
142
143 Ok(Self {
144 listen_addr: Arc::new(listen_addr),
145 target_addr: Arc::new(target_addr),
146 tls_connector,
147 tls_names: Arc::new(tls_names),
148 password: Arc::new(password),
149 nodelay,
150 v3,
151 })
152 }
153
154 pub async fn serve(self) -> anyhow::Result<()>
156 where
157 LA: std::net::ToSocketAddrs + 'static,
158 TA: std::net::ToSocketAddrs + 'static,
159 {
160 let listener = bind_with_pretty_error(self.listen_addr.as_ref())?;
161 let shared = Rc::new(self);
162 loop {
163 match listener.accept().await {
164 Ok((mut conn, addr)) => {
165 tracing::info!("Accepted a connection from {addr}");
166 let client = shared.clone();
167 mod_tcp_conn(&mut conn, true, shared.nodelay);
168 monoio::spawn(async move {
169 let _ = match client.v3.enabled() {
170 false => client.relay_v2(conn).await,
171 true => client.relay_v3(conn).await,
172 };
173 tracing::info!("Relay for {addr} finished");
174 });
175 }
176 Err(e) => {
177 tracing::error!("Accept failed: {e}");
178 }
179 }
180 }
181 }
182
183 async fn relay_v2(&self, mut in_stream: TcpStream) -> anyhow::Result<()>
185 where
186 TA: std::net::ToSocketAddrs,
187 {
188 let (mut out_stream, hash, session) = self.connect_v2().await?;
189 let mut hash_8b = [0; 8];
190 unsafe { std::ptr::copy_nonoverlapping(hash.as_ptr(), hash_8b.as_mut_ptr(), 8) };
191 let (out_r, mut out_w) = out_stream.split();
192 let (mut in_r, mut in_w) = in_stream.split();
193 let mut session_filtered_out_r = crate::helper_v2::SessionFilterStream::new(session, out_r);
194 let (a, b) = monoio::join!(
195 copy_without_application_data(&mut session_filtered_out_r, &mut in_w),
196 copy_with_application_data(&mut in_r, &mut out_w, Some(hash_8b))
197 );
198 let (_, _) = (a?, b?);
199 Ok(())
200 }
201
202 async fn relay_v3(&self, in_stream: TcpStream) -> anyhow::Result<()>
204 where
205 TA: std::net::ToSocketAddrs,
206 {
207 let mut stream = TcpStream::connect(self.target_addr.as_ref()).await?;
208 mod_tcp_conn(&mut stream, true, self.nodelay);
209 tracing::debug!("tcp connected, start handshaking");
210
211 let hamc_sr = Hmac::new(&self.password, (&[], &[]));
213 let stream = StreamWrapper::new(stream, &self.password);
214 let sni = self.tls_names.random_choose().clone();
215 let tls_stream = self
216 .tls_connector
217 .connect_with_session_id_generator(sni, stream, move |data| {
218 generate_session_id(&hamc_sr, data)
219 })
220 .await?;
221 tracing::debug!("handshake success");
222 let (stream, session) = tls_stream.into_parts();
223 let authorized = stream.authorized();
224 let tls13 = stream.tls13;
225 let maybe_srh = stream
226 .state()
227 .as_ref()
228 .map(|s| (s.server_random, s.hmac.to_owned()));
229 let stream = stream.into_inner();
230
231 if (maybe_srh.is_none() || !authorized) && self.v3.strict() {
233 tracing::warn!("V3 strict enabled: traffic hijacked or TLS1.3 is not supported");
234 let tls_stream = monoio_rustls_fork_shadow_tls::ClientTlsStream::new(stream, session);
235 if let Err(e) = fake_request(tls_stream).await {
236 bail!("traffic hijacked or TLS1.3 is not supported, fake request fail: {e}");
237 }
238 bail!("traffic hijacked or TLS1.3 is not supported, but fake request success");
239 }
240
241 drop(session);
242 let (sr, hmac_sr) = maybe_srh.unwrap();
243 tracing::debug!("Authorized, ServerRandom extracted: {sr:?}");
244 let hmac_sr_s = Hmac::new(&self.password, (&sr, b"S"));
245 let hmac_sr_c = Hmac::new(&self.password, (&sr, b"C"));
246
247 verified_relay(
248 in_stream,
249 stream,
250 hmac_sr_c,
251 hmac_sr_s,
252 Some(hmac_sr),
253 !tls13,
254 )
255 .await;
256 Ok(())
257 }
258
259 async fn connect_v2(
263 &self,
264 ) -> anyhow::Result<(
265 TcpStream,
266 [u8; 20],
267 rustls_fork_shadow_tls::ClientConnection,
268 )>
269 where
270 TA: std::net::ToSocketAddrs,
271 {
272 let mut stream = TcpStream::connect(self.target_addr.as_ref()).await?;
273 mod_tcp_conn(&mut stream, true, self.nodelay);
274 tracing::debug!("tcp connected, start handshaking");
275 let stream = HashedReadStream::new(stream, self.password.as_bytes())?;
276 let sni = self.tls_names.random_choose().clone();
277 let tls_stream = self.tls_connector.connect(sni, stream).await?;
278 let (io, session) = tls_stream.into_parts();
279 let hash = io.hash();
280 tracing::debug!("tls handshake finished, signed hmac: {:?}", hash);
281 let stream = io.into_inner();
282 Ok((stream, hash, session))
283 }
284}
285
286struct StreamWrapper<S> {
290 raw: S,
291 password: String,
292
293 read_buf: Option<Vec<u8>>,
294 read_pos: usize,
295
296 read_state: Option<State>,
297 read_authorized: bool,
298 tls13: bool,
299}
300
301#[derive(Clone)]
302struct State {
303 server_random: [u8; TLS_RANDOM_SIZE],
304 hmac: Hmac,
305 key: Vec<u8>,
306}
307
308impl<S> StreamWrapper<S> {
309 fn new(raw: S, password: &str) -> Self {
310 Self {
311 raw,
312 password: password.to_string(),
313
314 read_buf: Some(Vec::new()),
315 read_pos: 0,
316
317 read_state: None,
318 read_authorized: false,
319 tls13: false,
320 }
321 }
322
323 fn authorized(&self) -> bool {
324 self.read_authorized
325 }
326
327 fn state(&self) -> &Option<State> {
328 &self.read_state
329 }
330
331 fn into_inner(self) -> S {
332 self.raw
333 }
334}
335
336impl<S: AsyncReadRent> StreamWrapper<S> {
337 async fn feed_data(&mut self) -> std::io::Result<usize> {
338 let mut buf = self.read_buf.take().unwrap();
339
340 unsafe { buf.set_init(0) };
342 self.read_pos = 0;
343 buf.reserve(TLS_HEADER_SIZE);
344 let (res, buf) = self.raw.read_exact(buf.slice_mut(0..TLS_HEADER_SIZE)).await;
345 match res {
346 Ok(_) => (),
347 Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => {
348 tracing::debug!("stream wrapper eof");
349 self.read_buf = Some(buf.into_inner());
350 return Ok(0);
351 }
352 Err(e) => {
353 tracing::error!("stream wrapper unable to read tls header: {e}");
354 self.read_buf = Some(buf.into_inner());
355 return Err(e);
356 }
357 }
358 let mut buf: Vec<u8> = buf.into_inner();
359 let mut size: [u8; 2] = Default::default();
360 size.copy_from_slice(&buf[3..5]);
361 let data_size = u16::from_be_bytes(size) as usize;
362
363 buf.reserve(data_size);
365 let (res, buf) = self
366 .raw
367 .read_exact(buf.slice_mut(TLS_HEADER_SIZE..TLS_HEADER_SIZE + data_size))
368 .await;
369 if let Err(e) = res {
370 self.read_buf = Some(buf.into_inner());
371 return Err(e);
372 }
373 let mut buf: Vec<u8> = buf.into_inner();
374
375 match buf[0] {
377 HANDSHAKE => {
378 if buf.len() > SERVER_RANDOM_IDX + TLS_RANDOM_SIZE && buf[5] == SERVER_HELLO {
379 let mut server_random = [0; TLS_RANDOM_SIZE];
381 unsafe {
382 copy_nonoverlapping(
383 buf.as_ptr().add(SERVER_RANDOM_IDX),
384 server_random.as_mut_ptr(),
385 TLS_RANDOM_SIZE,
386 )
387 }
388 tracing::debug!("ServerRandom extracted: {server_random:?}");
389 let hmac = Hmac::new(&self.password, (&server_random, &[]));
390 let key = kdf(&self.password, &server_random);
391 self.read_state = Some(State {
392 server_random,
393 hmac,
394 key,
395 });
396 self.tls13 = support_tls13(&buf);
397 }
398 }
399 APPLICATION_DATA => {
400 self.read_authorized = false;
401 if buf.len() > TLS_HMAC_HEADER_SIZE {
402 if let Some(State { hmac, key, .. }) = self.read_state.as_mut() {
403 hmac.update(&buf[TLS_HMAC_HEADER_SIZE..]);
404 if hmac.finalize() == buf[TLS_HEADER_SIZE..TLS_HMAC_HEADER_SIZE] {
405 xor_slice(&mut buf[TLS_HMAC_HEADER_SIZE..], key);
406 unsafe {
407 copy(
408 buf.as_ptr().add(TLS_HMAC_HEADER_SIZE),
409 buf.as_mut_ptr().add(5),
410 buf.len() - 9,
411 )
412 };
413 (&mut buf[3..5])
414 .write_u16::<BigEndian>(data_size as u16 - HMAC_SIZE as u16)
415 .unwrap();
416 unsafe { buf.set_init(buf.len() - HMAC_SIZE) };
417 self.read_authorized = true;
418 } else {
419 tracing::debug!("app data verification failed");
420 }
421 }
422 }
423 }
424 _ => {}
425 }
426
427 let buf_len = buf.len();
429 self.read_buf = Some(buf);
430 Ok(buf_len)
431 }
432}
433
434impl<S: AsyncWriteRent> AsyncWriteRent for StreamWrapper<S> {
435 type WriteFuture<'a, T> = S::WriteFuture<'a, T> where
436 T: monoio::buf::IoBuf + 'a, Self: 'a;
437 type WritevFuture<'a, T>= S::WritevFuture<'a, T> where
438 T: monoio::buf::IoVecBuf + 'a, Self: 'a;
439 type FlushFuture<'a> = S::FlushFuture<'a> where Self: 'a;
440 type ShutdownFuture<'a> = S::ShutdownFuture<'a> where Self: 'a;
441
442 fn write<T: monoio::buf::IoBuf>(&mut self, buf: T) -> Self::WriteFuture<'_, T> {
443 self.raw.write(buf)
444 }
445 fn writev<T: monoio::buf::IoVecBuf>(&mut self, buf_vec: T) -> Self::WritevFuture<'_, T> {
446 self.raw.writev(buf_vec)
447 }
448 fn flush(&mut self) -> Self::FlushFuture<'_> {
449 self.raw.flush()
450 }
451 fn shutdown(&mut self) -> Self::ShutdownFuture<'_> {
452 self.raw.shutdown()
453 }
454}
455
456impl<S: AsyncReadRent> AsyncReadRent for StreamWrapper<S> {
457 type ReadFuture<'a, B> = impl std::future::Future<Output = monoio::BufResult<usize, B>> +'a where
458 B: monoio::buf::IoBufMut + 'a, S: 'a;
459 type ReadvFuture<'a, B> = impl std::future::Future<Output = monoio::BufResult<usize, B>> +'a where
460 B: monoio::buf::IoVecBufMut + 'a, S: 'a;
461
462 fn read<T: monoio::buf::IoBufMut>(&mut self, mut buf: T) -> Self::ReadFuture<'_, T> {
464 async move {
465 loop {
466 let owned_buf = self.read_buf.as_mut().unwrap();
467 let data_len = owned_buf.len() - self.read_pos;
468 if data_len > 0 {
470 let to_copy = buf.bytes_total().min(data_len);
471 unsafe {
472 copy_nonoverlapping(
473 owned_buf.as_ptr().add(self.read_pos),
474 buf.write_ptr(),
475 to_copy,
476 );
477 buf.set_init(to_copy);
478 };
479 self.read_pos += to_copy;
480 return (Ok(to_copy), buf);
481 }
482
483 match self.feed_data().await {
485 Ok(0) => return (Ok(0), buf),
486 Ok(_) => continue,
487 Err(e) => return (Err(e), buf),
488 }
489 }
490 }
491 }
492
493 fn readv<T: monoio::buf::IoVecBufMut>(&mut self, mut buf: T) -> Self::ReadvFuture<'_, T> {
494 async move {
495 let slice = match monoio::buf::IoVecWrapperMut::new(buf) {
496 Ok(slice) => slice,
497 Err(buf) => return (Ok(0), buf),
498 };
499
500 let (result, slice) = self.read(slice).await;
501 buf = slice.into_inner();
502 if let Ok(n) = result {
503 unsafe { buf.set_init(n) };
504 }
505 (result, buf)
506 }
507 }
508}
509
510async fn fake_request(
514 mut stream: monoio_rustls_fork_shadow_tls::ClientTlsStream<TcpStream>,
515) -> std::io::Result<()> {
516 const HEADER: &[u8; 207] = b"GET / HTTP/1.1\nUser-Agent: Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/109.0.0.0 Safari/537.36\nAccept: gzip, deflate, br\nConnection: Close\nCookie: sessionid=";
517 let cnt =
518 rand::thread_rng().gen_range(FAKE_REQUEST_LENGTH_RANGE.0..FAKE_REQUEST_LENGTH_RANGE.1);
519 let mut buffer = Vec::with_capacity(cnt + HEADER.len() + 1);
520
521 buffer.extend_from_slice(HEADER);
522 rand::distributions::Alphanumeric
523 .sample_iter(rand::thread_rng())
524 .take(cnt)
525 .for_each(|c| buffer.push(c));
526 buffer.push(b'\n');
527
528 let (res, mut buf) = stream.write_all(buffer).await;
529 res?;
530 let _ = stream.shutdown().await;
531
532 loop {
534 let (res, b) = stream.read(buf).await;
535 buf = b;
536 if res? == 0 {
537 return Ok(());
538 }
539 }
540}
541
542fn generate_session_id(hmac: &Hmac, buf: &[u8]) -> [u8; TLS_SESSION_ID_SIZE] {
546 const SESSION_ID_START: usize = 1 + 3 + 2 + TLS_RANDOM_SIZE + 1;
548
549 if buf.len() < SESSION_ID_START + TLS_SESSION_ID_SIZE {
550 tracing::warn!("unexpected client hello length");
551 return [0; TLS_SESSION_ID_SIZE];
552 }
553
554 let mut session_id = [0; TLS_SESSION_ID_SIZE];
555 rand::thread_rng().fill(&mut session_id[..TLS_SESSION_ID_SIZE - HMAC_SIZE]);
556 let mut hmac = hmac.to_owned();
557 hmac.update(&buf[0..SESSION_ID_START]);
558 hmac.update(&session_id);
559 hmac.update(&buf[SESSION_ID_START + TLS_SESSION_ID_SIZE..]);
560 let hmac_val = hmac.finalize();
561 unsafe {
562 copy_nonoverlapping(
563 hmac_val.as_ptr(),
564 session_id.as_mut_ptr().add(TLS_SESSION_ID_SIZE - HMAC_SIZE),
565 HMAC_SIZE,
566 )
567 }
568 tracing::debug!("ClientHello before sign: {buf:?}, session_id {session_id:?}");
569 session_id
570}