// shadow_tls/client.rs

1use std::{
2    ptr::{copy, copy_nonoverlapping},
3    rc::Rc,
4    sync::Arc,
5};
6
7use anyhow::bail;
8use byteorder::{BigEndian, WriteBytesExt};
9use monoio::{
10    buf::IoBufMut,
11    io::{AsyncReadRent, AsyncReadRentExt, AsyncWriteRent, AsyncWriteRentExt, Splitable},
12    net::TcpStream,
13};
14use monoio_rustls_fork_shadow_tls::TlsConnector;
15use rand::{prelude::Distribution, seq::SliceRandom, Rng};
16use rustls_fork_shadow_tls::{OwnedTrustAnchor, RootCertStore, ServerName};
17
18use crate::{
19    helper_v2::{copy_with_application_data, copy_without_application_data, HashedReadStream},
20    util::{
21        bind_with_pretty_error, kdf, mod_tcp_conn, prelude::*, support_tls13, verified_relay,
22        xor_slice, Hmac, V3Mode,
23    },
24};
25
/// Bounds for the random cookie length appended by `fake_request`:
/// `(inclusive lower bound, exclusive upper bound)`.
const FAKE_REQUEST_LENGTH_RANGE: (usize, usize) = (16, 64);
27
28/// ShadowTlsClient.
29#[derive(Clone)]
30pub struct ShadowTlsClient<LA, TA> {
31    listen_addr: Arc<LA>,
32    target_addr: Arc<TA>,
33    tls_connector: TlsConnector,
34    tls_names: Arc<TlsNames>,
35    password: Arc<String>,
36    nodelay: bool,
37    v3: V3Mode,
38}
39
40#[derive(Clone, Debug, PartialEq)]
41pub struct TlsNames(Vec<ServerName>);
42
43impl TlsNames {
44    #[inline]
45    pub fn random_choose(&self) -> &ServerName {
46        self.0.choose(&mut rand::thread_rng()).unwrap()
47    }
48}
49
50impl TryFrom<&str> for TlsNames {
51    type Error = anyhow::Error;
52
53    fn try_from(value: &str) -> Result<Self, Self::Error> {
54        let v: Result<Vec<_>, _> = value.trim().split(';').map(ServerName::try_from).collect();
55        let v = v.map_err(Into::into).and_then(|v| {
56            if v.is_empty() {
57                Err(anyhow::anyhow!("empty tls names"))
58            } else {
59                Ok(v)
60            }
61        })?;
62        Ok(Self(v))
63    }
64}
65
66impl std::fmt::Display for TlsNames {
67    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
68        write!(f, "{:?}", self.0)
69    }
70}
71
/// Optional TLS extension settings applied to the client config.
#[derive(Debug, Default)]
pub struct TlsExtConfig {
    /// Raw ALPN protocol ids; `None` leaves the config's `alpn_protocols` untouched.
    alpn: Option<Vec<Vec<u8>>>,
}
76
77impl TlsExtConfig {
78    #[allow(unused)]
79    #[inline]
80    pub fn new(alpn: Option<Vec<Vec<u8>>>) -> TlsExtConfig {
81        TlsExtConfig { alpn }
82    }
83}
84
85impl From<Option<Vec<String>>> for TlsExtConfig {
86    fn from(maybe_alpns: Option<Vec<String>>) -> Self {
87        Self {
88            alpn: maybe_alpns.map(|alpns| alpns.into_iter().map(Into::into).collect()),
89        }
90    }
91}
92
93impl std::fmt::Display for TlsExtConfig {
94    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
95        match self.alpn.as_ref() {
96            Some(alpns) => {
97                write!(f, "ALPN(Some(")?;
98                for alpn in alpns.iter() {
99                    write!(f, "{},", String::from_utf8_lossy(alpn))?;
100                }
101                write!(f, "))")?;
102            }
103            None => {
104                write!(f, "ALPN(None)")?;
105            }
106        }
107        Ok(())
108    }
109}
110
111impl<LA, TA> ShadowTlsClient<LA, TA> {
112    /// Create new ShadowTlsClient.
113    pub fn new(
114        listen_addr: LA,
115        target_addr: TA,
116        tls_names: TlsNames,
117        tls_ext_config: TlsExtConfig,
118        password: String,
119        nodelay: bool,
120        v3: V3Mode,
121    ) -> anyhow::Result<Self> {
122        let mut root_store = RootCertStore::empty();
123        root_store.add_server_trust_anchors(webpki_roots::TLS_SERVER_ROOTS.0.iter().map(|ta| {
124            OwnedTrustAnchor::from_subject_spki_name_constraints(
125                ta.subject,
126                ta.spki,
127                ta.name_constraints,
128            )
129        }));
130        // TLS 1.2 and TLS 1.3 is enabled.
131        let mut tls_config = rustls_fork_shadow_tls::ClientConfig::builder()
132            .with_safe_defaults()
133            .with_root_certificates(root_store)
134            .with_no_client_auth();
135
136        // Set tls config
137        if let Some(alpn) = tls_ext_config.alpn {
138            tls_config.alpn_protocols = alpn;
139        }
140
141        let tls_connector = TlsConnector::from(tls_config);
142
143        Ok(Self {
144            listen_addr: Arc::new(listen_addr),
145            target_addr: Arc::new(target_addr),
146            tls_connector,
147            tls_names: Arc::new(tls_names),
148            password: Arc::new(password),
149            nodelay,
150            v3,
151        })
152    }
153
154    /// Serve a raw connection.
155    pub async fn serve(self) -> anyhow::Result<()>
156    where
157        LA: std::net::ToSocketAddrs + 'static,
158        TA: std::net::ToSocketAddrs + 'static,
159    {
160        let listener = bind_with_pretty_error(self.listen_addr.as_ref())?;
161        let shared = Rc::new(self);
162        loop {
163            match listener.accept().await {
164                Ok((mut conn, addr)) => {
165                    tracing::info!("Accepted a connection from {addr}");
166                    let client = shared.clone();
167                    mod_tcp_conn(&mut conn, true, shared.nodelay);
168                    monoio::spawn(async move {
169                        let _ = match client.v3.enabled() {
170                            false => client.relay_v2(conn).await,
171                            true => client.relay_v3(conn).await,
172                        };
173                        tracing::info!("Relay for {addr} finished");
174                    });
175                }
176                Err(e) => {
177                    tracing::error!("Accept failed: {e}");
178                }
179            }
180        }
181    }
182
183    /// Main relay for V2 protocol.
184    async fn relay_v2(&self, mut in_stream: TcpStream) -> anyhow::Result<()>
185    where
186        TA: std::net::ToSocketAddrs,
187    {
188        let (mut out_stream, hash, session) = self.connect_v2().await?;
189        let mut hash_8b = [0; 8];
190        unsafe { std::ptr::copy_nonoverlapping(hash.as_ptr(), hash_8b.as_mut_ptr(), 8) };
191        let (out_r, mut out_w) = out_stream.split();
192        let (mut in_r, mut in_w) = in_stream.split();
193        let mut session_filtered_out_r = crate::helper_v2::SessionFilterStream::new(session, out_r);
194        let (a, b) = monoio::join!(
195            copy_without_application_data(&mut session_filtered_out_r, &mut in_w),
196            copy_with_application_data(&mut in_r, &mut out_w, Some(hash_8b))
197        );
198        let (_, _) = (a?, b?);
199        Ok(())
200    }
201
202    /// Main relay for V3 protocol.
203    async fn relay_v3(&self, in_stream: TcpStream) -> anyhow::Result<()>
204    where
205        TA: std::net::ToSocketAddrs,
206    {
207        let mut stream = TcpStream::connect(self.target_addr.as_ref()).await?;
208        mod_tcp_conn(&mut stream, true, self.nodelay);
209        tracing::debug!("tcp connected, start handshaking");
210
211        // stage1: handshake with wrapper
212        let hamc_sr = Hmac::new(&self.password, (&[], &[]));
213        let stream = StreamWrapper::new(stream, &self.password);
214        let sni = self.tls_names.random_choose().clone();
215        let tls_stream = self
216            .tls_connector
217            .connect_with_session_id_generator(sni, stream, move |data| {
218                generate_session_id(&hamc_sr, data)
219            })
220            .await?;
221        tracing::debug!("handshake success");
222        let (stream, session) = tls_stream.into_parts();
223        let authorized = stream.authorized();
224        let tls13 = stream.tls13;
225        let maybe_srh = stream
226            .state()
227            .as_ref()
228            .map(|s| (s.server_random, s.hmac.to_owned()));
229        let stream = stream.into_inner();
230
231        // stage2:
232        if (maybe_srh.is_none() || !authorized) && self.v3.strict() {
233            tracing::warn!("V3 strict enabled: traffic hijacked or TLS1.3 is not supported");
234            let tls_stream = monoio_rustls_fork_shadow_tls::ClientTlsStream::new(stream, session);
235            if let Err(e) = fake_request(tls_stream).await {
236                bail!("traffic hijacked or TLS1.3 is not supported, fake request fail: {e}");
237            }
238            bail!("traffic hijacked or TLS1.3 is not supported, but fake request success");
239        }
240
241        drop(session);
242        let (sr, hmac_sr) = maybe_srh.unwrap();
243        tracing::debug!("Authorized, ServerRandom extracted: {sr:?}");
244        let hmac_sr_s = Hmac::new(&self.password, (&sr, b"S"));
245        let hmac_sr_c = Hmac::new(&self.password, (&sr, b"C"));
246
247        verified_relay(
248            in_stream,
249            stream,
250            hmac_sr_c,
251            hmac_sr_s,
252            Some(hmac_sr),
253            !tls13,
254        )
255        .await;
256        Ok(())
257    }
258
259    /// Connect remote, do handshaking and calculate HMAC.
260    ///
261    /// Only used by V2 protocol.
262    async fn connect_v2(
263        &self,
264    ) -> anyhow::Result<(
265        TcpStream,
266        [u8; 20],
267        rustls_fork_shadow_tls::ClientConnection,
268    )>
269    where
270        TA: std::net::ToSocketAddrs,
271    {
272        let mut stream = TcpStream::connect(self.target_addr.as_ref()).await?;
273        mod_tcp_conn(&mut stream, true, self.nodelay);
274        tracing::debug!("tcp connected, start handshaking");
275        let stream = HashedReadStream::new(stream, self.password.as_bytes())?;
276        let sni = self.tls_names.random_choose().clone();
277        let tls_stream = self.tls_connector.connect(sni, stream).await?;
278        let (io, session) = tls_stream.into_parts();
279        let hash = io.hash();
280        tracing::debug!("tls handshake finished, signed hmac: {:?}", hash);
281        let stream = io.into_inner();
282        Ok((stream, hash, session))
283    }
284}
285
286/// A wrapper for doing data extraction and modification.
287///
288/// Only used by V3 protocol.
289struct StreamWrapper<S> {
290    raw: S,
291    password: String,
292
293    read_buf: Option<Vec<u8>>,
294    read_pos: usize,
295
296    read_state: Option<State>,
297    read_authorized: bool,
298    tls13: bool,
299}
300
301#[derive(Clone)]
302struct State {
303    server_random: [u8; TLS_RANDOM_SIZE],
304    hmac: Hmac,
305    key: Vec<u8>,
306}
307
308impl<S> StreamWrapper<S> {
309    fn new(raw: S, password: &str) -> Self {
310        Self {
311            raw,
312            password: password.to_string(),
313
314            read_buf: Some(Vec::new()),
315            read_pos: 0,
316
317            read_state: None,
318            read_authorized: false,
319            tls13: false,
320        }
321    }
322
323    fn authorized(&self) -> bool {
324        self.read_authorized
325    }
326
327    fn state(&self) -> &Option<State> {
328        &self.read_state
329    }
330
331    fn into_inner(self) -> S {
332        self.raw
333    }
334}
335
336impl<S: AsyncReadRent> StreamWrapper<S> {
337    async fn feed_data(&mut self) -> std::io::Result<usize> {
338        let mut buf = self.read_buf.take().unwrap();
339
340        // read header
341        unsafe { buf.set_init(0) };
342        self.read_pos = 0;
343        buf.reserve(TLS_HEADER_SIZE);
344        let (res, buf) = self.raw.read_exact(buf.slice_mut(0..TLS_HEADER_SIZE)).await;
345        match res {
346            Ok(_) => (),
347            Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => {
348                tracing::debug!("stream wrapper eof");
349                self.read_buf = Some(buf.into_inner());
350                return Ok(0);
351            }
352            Err(e) => {
353                tracing::error!("stream wrapper unable to read tls header: {e}");
354                self.read_buf = Some(buf.into_inner());
355                return Err(e);
356            }
357        }
358        let mut buf: Vec<u8> = buf.into_inner();
359        let mut size: [u8; 2] = Default::default();
360        size.copy_from_slice(&buf[3..5]);
361        let data_size = u16::from_be_bytes(size) as usize;
362
363        // read body
364        buf.reserve(data_size);
365        let (res, buf) = self
366            .raw
367            .read_exact(buf.slice_mut(TLS_HEADER_SIZE..TLS_HEADER_SIZE + data_size))
368            .await;
369        if let Err(e) = res {
370            self.read_buf = Some(buf.into_inner());
371            return Err(e);
372        }
373        let mut buf: Vec<u8> = buf.into_inner();
374
375        // do extraction and modification
376        match buf[0] {
377            HANDSHAKE => {
378                if buf.len() > SERVER_RANDOM_IDX + TLS_RANDOM_SIZE && buf[5] == SERVER_HELLO {
379                    // we can read server random
380                    let mut server_random = [0; TLS_RANDOM_SIZE];
381                    unsafe {
382                        copy_nonoverlapping(
383                            buf.as_ptr().add(SERVER_RANDOM_IDX),
384                            server_random.as_mut_ptr(),
385                            TLS_RANDOM_SIZE,
386                        )
387                    }
388                    tracing::debug!("ServerRandom extracted: {server_random:?}");
389                    let hmac = Hmac::new(&self.password, (&server_random, &[]));
390                    let key = kdf(&self.password, &server_random);
391                    self.read_state = Some(State {
392                        server_random,
393                        hmac,
394                        key,
395                    });
396                    self.tls13 = support_tls13(&buf);
397                }
398            }
399            APPLICATION_DATA => {
400                self.read_authorized = false;
401                if buf.len() > TLS_HMAC_HEADER_SIZE {
402                    if let Some(State { hmac, key, .. }) = self.read_state.as_mut() {
403                        hmac.update(&buf[TLS_HMAC_HEADER_SIZE..]);
404                        if hmac.finalize() == buf[TLS_HEADER_SIZE..TLS_HMAC_HEADER_SIZE] {
405                            xor_slice(&mut buf[TLS_HMAC_HEADER_SIZE..], key);
406                            unsafe {
407                                copy(
408                                    buf.as_ptr().add(TLS_HMAC_HEADER_SIZE),
409                                    buf.as_mut_ptr().add(5),
410                                    buf.len() - 9,
411                                )
412                            };
413                            (&mut buf[3..5])
414                                .write_u16::<BigEndian>(data_size as u16 - HMAC_SIZE as u16)
415                                .unwrap();
416                            unsafe { buf.set_init(buf.len() - HMAC_SIZE) };
417                            self.read_authorized = true;
418                        } else {
419                            tracing::debug!("app data verification failed");
420                        }
421                    }
422                }
423            }
424            _ => {}
425        }
426
427        // set buffer
428        let buf_len = buf.len();
429        self.read_buf = Some(buf);
430        Ok(buf_len)
431    }
432}
433
434impl<S: AsyncWriteRent> AsyncWriteRent for StreamWrapper<S> {
435    type WriteFuture<'a, T> = S::WriteFuture<'a, T> where
436    T: monoio::buf::IoBuf + 'a, Self: 'a;
437    type WritevFuture<'a, T>= S::WritevFuture<'a, T> where
438    T: monoio::buf::IoVecBuf + 'a, Self: 'a;
439    type FlushFuture<'a> = S::FlushFuture<'a> where Self: 'a;
440    type ShutdownFuture<'a> = S::ShutdownFuture<'a> where Self: 'a;
441
442    fn write<T: monoio::buf::IoBuf>(&mut self, buf: T) -> Self::WriteFuture<'_, T> {
443        self.raw.write(buf)
444    }
445    fn writev<T: monoio::buf::IoVecBuf>(&mut self, buf_vec: T) -> Self::WritevFuture<'_, T> {
446        self.raw.writev(buf_vec)
447    }
448    fn flush(&mut self) -> Self::FlushFuture<'_> {
449        self.raw.flush()
450    }
451    fn shutdown(&mut self) -> Self::ShutdownFuture<'_> {
452        self.raw.shutdown()
453    }
454}
455
456impl<S: AsyncReadRent> AsyncReadRent for StreamWrapper<S> {
457    type ReadFuture<'a, B> = impl std::future::Future<Output = monoio::BufResult<usize, B>> +'a where
458        B: monoio::buf::IoBufMut + 'a, S: 'a;
459    type ReadvFuture<'a, B> = impl std::future::Future<Output = monoio::BufResult<usize, B>> +'a where
460        B: monoio::buf::IoVecBufMut + 'a, S: 'a;
461
462    // uncancelable
463    fn read<T: monoio::buf::IoBufMut>(&mut self, mut buf: T) -> Self::ReadFuture<'_, T> {
464        async move {
465            loop {
466                let owned_buf = self.read_buf.as_mut().unwrap();
467                let data_len = owned_buf.len() - self.read_pos;
468                // there is enough data to copy
469                if data_len > 0 {
470                    let to_copy = buf.bytes_total().min(data_len);
471                    unsafe {
472                        copy_nonoverlapping(
473                            owned_buf.as_ptr().add(self.read_pos),
474                            buf.write_ptr(),
475                            to_copy,
476                        );
477                        buf.set_init(to_copy);
478                    };
479                    self.read_pos += to_copy;
480                    return (Ok(to_copy), buf);
481                }
482
483                // no data now
484                match self.feed_data().await {
485                    Ok(0) => return (Ok(0), buf),
486                    Ok(_) => continue,
487                    Err(e) => return (Err(e), buf),
488                }
489            }
490        }
491    }
492
493    fn readv<T: monoio::buf::IoVecBufMut>(&mut self, mut buf: T) -> Self::ReadvFuture<'_, T> {
494        async move {
495            let slice = match monoio::buf::IoVecWrapperMut::new(buf) {
496                Ok(slice) => slice,
497                Err(buf) => return (Ok(0), buf),
498            };
499
500            let (result, slice) = self.read(slice).await;
501            buf = slice.into_inner();
502            if let Ok(n) = result {
503                unsafe { buf.set_init(n) };
504            }
505            (result, buf)
506        }
507    }
508}
509
510/// Doing fake request.
511///
512/// Only used by V3 protocol.
513async fn fake_request(
514    mut stream: monoio_rustls_fork_shadow_tls::ClientTlsStream<TcpStream>,
515) -> std::io::Result<()> {
516    const HEADER: &[u8; 207] = b"GET / HTTP/1.1\nUser-Agent: Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/109.0.0.0 Safari/537.36\nAccept: gzip, deflate, br\nConnection: Close\nCookie: sessionid=";
517    let cnt =
518        rand::thread_rng().gen_range(FAKE_REQUEST_LENGTH_RANGE.0..FAKE_REQUEST_LENGTH_RANGE.1);
519    let mut buffer = Vec::with_capacity(cnt + HEADER.len() + 1);
520
521    buffer.extend_from_slice(HEADER);
522    rand::distributions::Alphanumeric
523        .sample_iter(rand::thread_rng())
524        .take(cnt)
525        .for_each(|c| buffer.push(c));
526    buffer.push(b'\n');
527
528    let (res, mut buf) = stream.write_all(buffer).await;
529    res?;
530    let _ = stream.shutdown().await;
531
532    // read until eof
533    loop {
534        let (res, b) = stream.read(buf).await;
535        buf = b;
536        if res? == 0 {
537            return Ok(());
538        }
539    }
540}
541
542/// Take a slice of tls frame[5..] and returns signed session id.
543///
544/// Only used by V3 protocol.
545fn generate_session_id(hmac: &Hmac, buf: &[u8]) -> [u8; TLS_SESSION_ID_SIZE] {
546    /// Note: SESSION_ID_START does not include 5 TLS_HEADER_SIZE.
547    const SESSION_ID_START: usize = 1 + 3 + 2 + TLS_RANDOM_SIZE + 1;
548
549    if buf.len() < SESSION_ID_START + TLS_SESSION_ID_SIZE {
550        tracing::warn!("unexpected client hello length");
551        return [0; TLS_SESSION_ID_SIZE];
552    }
553
554    let mut session_id = [0; TLS_SESSION_ID_SIZE];
555    rand::thread_rng().fill(&mut session_id[..TLS_SESSION_ID_SIZE - HMAC_SIZE]);
556    let mut hmac = hmac.to_owned();
557    hmac.update(&buf[0..SESSION_ID_START]);
558    hmac.update(&session_id);
559    hmac.update(&buf[SESSION_ID_START + TLS_SESSION_ID_SIZE..]);
560    let hmac_val = hmac.finalize();
561    unsafe {
562        copy_nonoverlapping(
563            hmac_val.as_ptr(),
564            session_id.as_mut_ptr().add(TLS_SESSION_ID_SIZE - HMAC_SIZE),
565            HMAC_SIZE,
566        )
567    }
568    tracing::debug!("ClientHello before sign: {buf:?}, session_id {session_id:?}");
569    session_id
570}