shadow_tls/
server.rs

1use std::{
2    borrow::Cow,
3    collections::VecDeque,
4    ptr::{copy, copy_nonoverlapping},
5    rc::Rc,
6    sync::Arc,
7};
8
9use anyhow::bail;
10use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt};
11use local_sync::oneshot::Sender;
12use monoio::{
13    buf::{IoBuf, IoBufMut, Slice, SliceMut},
14    io::{
15        AsyncReadRent, AsyncReadRentExt, AsyncWriteRent, AsyncWriteRentExt, PrefixedReadIo,
16        Splitable,
17    },
18    net::TcpStream,
19};
20
21use crate::{
22    helper_v2::{
23        copy_with_application_data, copy_without_application_data, ErrGroup, FirstRetGroup,
24        FutureOrOutput, HashedWriteStream, HmacHandler, HMAC_SIZE_V2,
25    },
26    util::{
27        bind_with_pretty_error, copy_bidirectional, copy_until_eof, kdf, mod_tcp_conn, prelude::*,
28        support_tls13, verified_relay, xor_slice, CursorExt, Hmac, V3Mode,
29    },
30    WildcardSNI,
31};
32
33/// ShadowTlsServer.
34#[derive(Clone)]
35pub struct ShadowTlsServer<LA, TA> {
36    listen_addr: Arc<LA>,
37    target_addr: Arc<TA>,
38    tls_addr: Arc<TlsAddrs>,
39    password: Arc<String>,
40    nodelay: bool,
41    v3: V3Mode,
42}
43
44#[derive(Clone, Debug, PartialEq)]
45pub struct TlsAddrs {
46    dispatch: rustc_hash::FxHashMap<String, String>,
47    fallback: String,
48    wildcard_sni: WildcardSNI,
49}
50
51impl TlsAddrs {
52    fn find<'a>(&'a self, key: Option<&str>, auth: bool) -> Cow<'a, str> {
53        match key {
54            Some(k) => match self.dispatch.get(k) {
55                Some(v) => Cow::Borrowed(v),
56                None => match self.wildcard_sni {
57                    WildcardSNI::Authed if auth => Cow::Owned(format!("{k}:443")),
58                    WildcardSNI::All => Cow::Owned(format!("{k}:443")),
59                    _ => Cow::Borrowed(&self.fallback),
60                },
61            },
62            None => Cow::Borrowed(&self.fallback),
63        }
64    }
65
66    fn is_empty(&self) -> bool {
67        self.dispatch.is_empty()
68    }
69
70    pub fn set_wildcard_sni(&mut self, wildcard_sni: WildcardSNI) {
71        self.wildcard_sni = wildcard_sni;
72    }
73}
74
75impl TryFrom<&str> for TlsAddrs {
76    type Error = anyhow::Error;
77
78    fn try_from(arg: &str) -> Result<Self, Self::Error> {
79        let mut rev_parts = arg.split(';').rev();
80        let fallback = rev_parts
81            .next()
82            .and_then(|x| if x.trim().is_empty() { None } else { Some(x) })
83            .ok_or_else(|| anyhow::anyhow!("empty server addrs"))?;
84        let fallback = if !fallback.contains(':') {
85            format!("{fallback}:443")
86        } else {
87            fallback.to_string()
88        };
89
90        let mut dispatch = rustc_hash::FxHashMap::default();
91        for p in rev_parts {
92            let mut p = p.trim().split(':').rev();
93            let mut port = Cow::<'static, str>::Borrowed("443");
94            let maybe_port = p
95                .next()
96                .ok_or_else(|| anyhow::anyhow!("empty part found in server addrs"))?;
97            let host = if maybe_port.parse::<u16>().is_ok() {
98                // there is a port at the end
99                port = maybe_port.into();
100                p.next()
101                    .ok_or_else(|| anyhow::anyhow!("no host found in server addrs part"))?
102            } else {
103                maybe_port
104            };
105            let key = match p.next() {
106                Some(key) => key,
107                None => host,
108            };
109            if p.next().is_some() {
110                bail!("unrecognized server addrs part");
111            }
112            if dispatch
113                .insert(key.to_string(), format!("{host}:{port}"))
114                .is_some()
115            {
116                bail!("duplicate server addrs part found");
117            }
118        }
119        Ok(TlsAddrs {
120            dispatch,
121            fallback,
122            wildcard_sni: Default::default(),
123        })
124    }
125}
126
127impl std::fmt::Display for TlsAddrs {
128    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
129        write!(f, "(wildcard-sni:{})", self.wildcard_sni)?;
130        for (k, v) in self.dispatch.iter() {
131            write!(f, "{k}->{v};")?;
132        }
133        write!(f, "fallback->{}", self.fallback)
134    }
135}
136
137impl<LA, TA> ShadowTlsServer<LA, TA> {
138    pub fn new(
139        listen_addr: LA,
140        target_addr: TA,
141        tls_addr: TlsAddrs,
142        password: String,
143        nodelay: bool,
144        v3: V3Mode,
145    ) -> Self {
146        Self {
147            listen_addr: Arc::new(listen_addr),
148            target_addr: Arc::new(target_addr),
149            tls_addr: Arc::new(tls_addr),
150            password: Arc::new(password),
151            nodelay,
152            v3,
153        }
154    }
155}
156
157impl<LA, TA> ShadowTlsServer<LA, TA> {
158    /// Serve a raw connection.
159    pub async fn serve(self) -> anyhow::Result<()>
160    where
161        LA: std::net::ToSocketAddrs + 'static,
162        TA: std::net::ToSocketAddrs + 'static,
163    {
164        let listener = bind_with_pretty_error(self.listen_addr.as_ref())?;
165        let shared = Rc::new(self);
166        loop {
167            match listener.accept().await {
168                Ok((mut conn, addr)) => {
169                    tracing::info!("Accepted a connection from {addr}");
170                    let server = shared.clone();
171                    mod_tcp_conn(&mut conn, true, shared.nodelay);
172                    monoio::spawn(async move {
173                        let _ = match server.v3.enabled() {
174                            false => server.relay_v2(conn).await,
175                            true => server.relay_v3(conn).await,
176                        };
177                        tracing::info!("Relay for {addr} finished");
178                    });
179                }
180                Err(e) => {
181                    tracing::error!("Accept failed: {e}");
182                }
183            }
184        }
185    }
186
187    /// Main relay for V2 protocol.
188    async fn relay_v2(&self, in_stream: TcpStream) -> anyhow::Result<()>
189    where
190        TA: std::net::ToSocketAddrs,
191    {
192        // wrap in_stream with hash layer
193        let mut in_stream = HashedWriteStream::new(in_stream, self.password.as_bytes())?;
194        let mut hmac = in_stream.hmac_handler();
195
196        // read and extract server name
197        // if there is only one fallback server, skip it
198        let (prefix, server_name) = match self.tls_addr.is_empty() {
199            true => (Vec::new(), None),
200            false => extract_sni_v2(&mut in_stream).await?,
201        };
202        let mut prefixed_io = PrefixedReadIo::new(&mut in_stream, std::io::Cursor::new(prefix));
203        tracing::debug!("server name extracted from SNI extention: {server_name:?}");
204
205        // choose handshake server addr and connect
206        let server_name = server_name.and_then(|s| String::from_utf8(s).ok());
207        let addr = self
208            .tls_addr
209            .find(server_name.as_ref().map(AsRef::as_ref), true);
210        let mut out_stream = TcpStream::connect(addr.as_ref()).await?;
211        mod_tcp_conn(&mut out_stream, true, self.nodelay);
212        tracing::debug!("handshake server connected: {addr}");
213
214        // copy stage 1
215        let (mut out_r, mut out_w) = out_stream.split();
216        let (mut in_r, mut in_w) = prefixed_io.split();
217        let (switch, cp) = FirstRetGroup::new(
218            copy_until_handshake_finished(&mut in_r, &mut out_w, &hmac),
219            Box::pin(copy_until_eof(&mut out_r, &mut in_w)),
220        )
221        .await?;
222        hmac.disable();
223        tracing::debug!("handshake finished, switch: {switch:?}");
224
225        // copy stage 2
226        match switch {
227            SwitchResult::Switch(data_left) => {
228                drop(cp);
229                let mut in_stream = in_stream.into_inner();
230                let (mut in_r, mut in_w) = in_stream.split();
231
232                // connect our data server
233                let _ = out_stream.shutdown().await;
234                drop(out_stream);
235                let mut data_stream = TcpStream::connect(self.target_addr.as_ref()).await?;
236                mod_tcp_conn(&mut data_stream, true, self.nodelay);
237                tracing::debug!("data server connected, start relay");
238                let (mut data_r, mut data_w) = data_stream.split();
239                let (result, _) = data_w.write(data_left).await;
240                result?;
241                ErrGroup::new(
242                    copy_with_application_data::<0, _, _>(&mut data_r, &mut in_w, None),
243                    copy_without_application_data(&mut in_r, &mut data_w),
244                )
245                .await?;
246            }
247            SwitchResult::DirectProxy => match cp {
248                FutureOrOutput::Future(cp) => {
249                    ErrGroup::new(cp, copy_until_eof(in_r, out_w)).await?;
250                }
251                FutureOrOutput::Output(_) => {
252                    copy_until_eof(in_r, out_w).await?;
253                }
254            },
255        }
256        Ok(())
257    }
258
259    /// Main relay for V3 protocol.
260    async fn relay_v3(&self, mut in_stream: TcpStream) -> anyhow::Result<()>
261    where
262        TA: std::net::ToSocketAddrs,
263    {
264        // stage 1.1: read and validate client hello
265        let first_client_frame = read_exact_frame(&mut in_stream).await?;
266        let (client_hello_pass, sni) = verified_extract_sni(&first_client_frame, &self.password);
267
268        // connect handshake server
269        let server_name = sni.and_then(|s| String::from_utf8(s).ok());
270        let addr = self
271            .tls_addr
272            .find(server_name.as_ref().map(AsRef::as_ref), client_hello_pass);
273        let mut handshake_stream = TcpStream::connect(addr.as_ref()).await?;
274        mod_tcp_conn(&mut handshake_stream, true, self.nodelay);
275        tracing::debug!("handshake server connected: {addr}");
276        tracing::trace!("ClientHello frame {first_client_frame:?}");
277        let (res, _) = handshake_stream.write_all(first_client_frame).await;
278        res?;
279        if !client_hello_pass {
280            // if client verify failed, bidirectional copy and return
281            tracing::warn!("ClientHello verify failed, will work as a SNI proxy");
282            copy_bidirectional(&mut in_stream, &mut handshake_stream).await;
283            return Ok(());
284        }
285        tracing::debug!("ClientHello verify success");
286
287        // stage 1.2: read server hello and extract server random from it
288        let first_server_frame = read_exact_frame(&mut handshake_stream).await?;
289        let (res, first_server_frame) = in_stream.write_all(first_server_frame).await;
290        res?;
291        let server_random = match extract_server_random(&first_server_frame) {
292            Some(sr) => sr,
293            None => {
294                // we cannot extract server random, bidirectional copy and return
295                tracing::warn!("ServerRandom extract failed, will copy bidirectional");
296                copy_bidirectional(&mut in_stream, &mut handshake_stream).await;
297                return Ok(());
298            }
299        };
300        tracing::debug!("Client authenticated. ServerRandom extracted: {server_random:?}");
301
302        let use_tls13 = support_tls13(&first_server_frame);
303        if self.v3.strict() && !use_tls13 {
304            tracing::error!(
305                "V3 strict enabled and TLS 1.3 is not supported, will copy bidirectional"
306            );
307            copy_bidirectional(&mut in_stream, &mut handshake_stream).await;
308            return Ok(());
309        }
310
311        // stage 1.3.1: create HMAC_ServerRandomC and HMAC_ServerRandom
312        let mut hmac_sr_c = Hmac::new(&self.password, (&server_random, b"C"));
313        let hmac_sr_s = Hmac::new(&self.password, (&server_random, b"S"));
314        let mut hmac_sr = Hmac::new(&self.password, (&server_random, &[]));
315
316        // stage 1.3.2: copy ShadowTLS Client -> Handshake Server until hamc matches
317        // stage 1.3.3: copy and modify Handshake Server -> ShadowTLS Client until 1.3.2 stops
318        let pure_data = {
319            let (mut c_read, mut c_write) = in_stream.split();
320            let (mut h_read, mut h_write) = handshake_stream.split();
321            let (mut sender, mut recevier) = local_sync::oneshot::channel::<()>();
322            let key = kdf(&self.password, &server_random);
323            let (maybe_pure, _) = monoio::join!(
324                async {
325                    let r =
326                        copy_by_frame_until_hmac_matches(&mut c_read, &mut h_write, &mut hmac_sr_c)
327                            .await;
328                    recevier.close();
329                    if r.is_err() {
330                        let _ = h_write.shutdown().await;
331                    }
332                    r
333                },
334                async {
335                    let r = copy_by_frame_with_modification(
336                        &mut h_read,
337                        &mut c_write,
338                        &mut hmac_sr,
339                        &key,
340                        &mut sender,
341                    )
342                    .await;
343                    if r.is_err() {
344                        let _ = c_write.shutdown().await;
345                    }
346                }
347            );
348            maybe_pure?
349        };
350        tracing::debug!("handshake relay finished");
351
352        // early drop useless resources
353        drop(handshake_stream);
354        drop(first_server_frame);
355
356        // stage 2.2: copy ShadowTLS Client -> Data Server
357        // stage 2.3: copy Data Server -> ShadowTLS Client
358        let mut data_stream = TcpStream::connect(self.target_addr.as_ref()).await?;
359        mod_tcp_conn(&mut data_stream, true, self.nodelay);
360        let (res, _) = data_stream.write_all(pure_data).await;
361        res?;
362        verified_relay(
363            data_stream,
364            in_stream,
365            hmac_sr_s,
366            hmac_sr_c,
367            None,
368            !use_tls13,
369        )
370        .await;
371        Ok(())
372    }
373}
374
/// Outcome of the V2 handshake stage: where the rest of the client stream goes.
///
/// Only used by V2 protocol.
enum SwitchResult {
    /// The client presented a matching HMAC. Carries the remainder of that
    /// ApplicationData frame (after the HMAC bytes), which still has to be
    /// delivered to the data server.
    Switch(Vec<u8>),
    /// Keep proxying the client to the handshake server.
    DirectProxy,
}
382
383impl std::fmt::Debug for SwitchResult {
384    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
385        match self {
386            Self::Switch(_) => write!(f, "Switch"),
387            Self::DirectProxy => write!(f, "DirectProxy"),
388        }
389    }
390}
391
392/// Copy until handshake finished.
393/// We use HMAC to check if handshake finished.
394///
395/// Only used by V2 protocol.
396async fn copy_until_handshake_finished<R, W>(
397    mut read_half: R,
398    mut write_half: W,
399    hmac: &HmacHandler,
400) -> std::io::Result<SwitchResult>
401where
402    R: AsyncReadRent,
403    W: AsyncWriteRent,
404{
405    // We maintain 2 state to make sure current session is in an tls session.
406    // This is essential for preventing active detection.
407    let mut has_seen_change_cipher_spec = false;
408    let mut has_seen_handshake = false;
409
410    // header_buf is used to read handshake frame header, will be a fixed size buffer.
411    let mut header_buf = vec![0_u8; TLS_HEADER_SIZE].into_boxed_slice();
412    let mut header_read_len = 0;
413    let mut header_write_len = 0;
414    // data_buf is used to read and write data, and can be expanded.
415    let mut data_hmac_buf = vec![0_u8; HMAC_SIZE_V2].into_boxed_slice();
416    let mut data_buf = vec![0_u8; 2048];
417    let mut application_data_count: usize = 0;
418
419    let mut hashes = VecDeque::with_capacity(10);
420    loop {
421        let header_buf_slice = SliceMut::new(header_buf, header_read_len, TLS_HEADER_SIZE);
422        let (res, header_buf_slice_) = read_half.read(header_buf_slice).await;
423        header_buf = header_buf_slice_.into_inner();
424        let read_len = res?;
425        header_read_len += read_len;
426
427        // If EOF, close write half.
428        if read_len == 0 {
429            let _ = write_half.shutdown().await;
430            return Err(std::io::ErrorKind::UnexpectedEof.into());
431        }
432
433        // We have to relay data now no matter header is enough or not.
434        let header_buf_slice_w = Slice::new(header_buf, header_write_len, header_read_len);
435        let (res, header_buf_slice_w_) = write_half.write_all(header_buf_slice_w).await;
436        header_buf = header_buf_slice_w_.into_inner();
437        header_write_len += res?;
438
439        if header_read_len != TLS_HEADER_SIZE {
440            // Here we have not got enough data to parse header.
441            // continue to read.
442            continue;
443        }
444
445        // Now header has been read and redirected successfully.
446        // We should clear header status.
447        header_read_len = 0;
448        header_write_len = 0;
449
450        // Parse length.
451        let mut size: [u8; 2] = Default::default();
452        size.copy_from_slice(&header_buf[3..5]);
453        let data_size = u16::from_be_bytes(size) as usize;
454        tracing::debug!(
455            "read header with type {} and length {}",
456            header_buf[0],
457            data_size
458        );
459
460        // Check data type, if not app data we want, we can forward it directly(in streaming way).
461        if header_buf[0] != APPLICATION_DATA
462            || !has_seen_handshake
463            || !has_seen_change_cipher_spec
464            || data_size < HMAC_SIZE_V2
465        {
466            // The first packet must be handshake.
467            // Also, every packet's version must be valid.
468            let valid = (has_seen_handshake || header_buf[0] == HANDSHAKE)
469                && header_buf[1] == TLS_MAJOR
470                && (header_buf[2] == TLS_MINOR.0 || header_buf[2] == TLS_MINOR.1);
471            if header_buf[0] == CHANGE_CIPHER_SPEC {
472                has_seen_change_cipher_spec = true;
473            }
474            if header_buf[0] == HANDSHAKE {
475                has_seen_handshake = true;
476            }
477            // Copy data.
478            let mut to_copy = data_size;
479            while to_copy != 0 {
480                let max_read = data_buf.capacity().min(to_copy);
481                let buf = SliceMut::new(data_buf, 0, max_read);
482                let (read_res, buf) = read_half.read(buf).await;
483
484                // if EOF, close write half.
485                let read_len = read_res?;
486                if read_len == 0 {
487                    let _ = write_half.shutdown().await;
488                    return Err(std::io::ErrorKind::UnexpectedEof.into());
489                }
490
491                let buf = buf.into_inner().slice(0..read_len);
492                let (write_res, buf) = write_half.write_all(buf).await;
493                to_copy -= write_res?;
494                data_buf = buf.into_inner();
495            }
496            tracing::debug!("copied data with length {:?}", data_size);
497            if !valid {
498                tracing::debug!("early invalid tls: header {:?}", &header_buf[..3]);
499                return Ok(SwitchResult::DirectProxy);
500            }
501            continue;
502        }
503
504        // Here we need to check hmac.
505        // We have to read and copy the maybe_hmac.
506        // Note: Send this 8 byte to remote does not matters:
507        // If the data is sent by our authorized client, the handshake server must within
508        // a tls session. So it must read exact that length data and then process it.
509        // For this reason, sending 8 byte hmac will not cause the handshake server
510        // shuting down the connection.
511        // If the data in sent by an attacker, we must behaves like a tcp proxy so it seems
512        // we are the handshake server.
513        let mut hmac_read_len = 0;
514        while hmac_read_len < HMAC_SIZE_V2 {
515            let buf = SliceMut::new(data_hmac_buf, hmac_read_len, HMAC_SIZE_V2);
516            let (res, buf_) = read_half.read(buf).await;
517            // if EOF, close write half.
518            let read_len = res?;
519            if read_len == 0 {
520                let _ = write_half.shutdown().await;
521                return Err(std::io::ErrorKind::UnexpectedEof.into());
522            }
523
524            let buf = Slice::new(buf_.into_inner(), hmac_read_len, hmac_read_len + read_len);
525            let (write_res, buf_) = write_half.write_all(buf).await;
526            write_res?;
527            hmac_read_len += read_len;
528            data_hmac_buf = buf_.into_inner();
529        }
530
531        // Now hmac has been read and copied.
532        // If hmac matches, we need to read current data and return.
533        let hash = hmac.hash();
534        let mut hash_trim = [0; HMAC_SIZE_V2];
535        unsafe { copy_nonoverlapping(hash.as_ptr(), hash_trim.as_mut_ptr(), HMAC_SIZE_V2) };
536        tracing::debug!("hmac calculated: {hash_trim:?}");
537        if hashes.len() + 1 > hashes.capacity() {
538            hashes.pop_front();
539        }
540        hashes.push_back(hash_trim);
541        unsafe {
542            copy_nonoverlapping(data_hmac_buf.as_ptr(), hash_trim.as_mut_ptr(), HMAC_SIZE_V2)
543        };
544        if hashes.contains(&hash_trim) {
545            tracing::debug!("hmac matches");
546            let pure_data = vec![0; data_size - HMAC_SIZE_V2];
547            let (read_res, pure_data) = read_half.read_exact(pure_data).await;
548            read_res?;
549            return Ok(SwitchResult::Switch(pure_data));
550        }
551
552        // Now hmac does not match. We have to acc the counter and do copy.
553        application_data_count += 1;
554        let mut to_copy = data_size - HMAC_SIZE_V2;
555        while to_copy != 0 {
556            let max_read = data_buf.capacity().min(to_copy);
557            let buf = SliceMut::new(data_buf, 0, max_read);
558            let (read_res, buf) = read_half.read(buf).await;
559
560            // if EOF, close write half.
561            let read_len = read_res?;
562            if read_len == 0 {
563                let _ = write_half.shutdown().await;
564                return Err(std::io::ErrorKind::UnexpectedEof.into());
565            }
566
567            let buf = buf.into_inner().slice(0..read_len);
568            let (write_res, buf) = write_half.write_all(buf).await;
569            to_copy -= write_res?;
570            data_buf = buf.into_inner();
571        }
572
573        if application_data_count > 3 {
574            tracing::debug!("hmac not matches after 3 times, fallback to direct");
575            return Ok(SwitchResult::DirectProxy);
576        }
577    }
578}
579
580/// Read from connection and parse the frame.
581/// Return consumed data and SNI.
582///
583/// Only used by V2 protocol.
584async fn extract_sni_v2<R: AsyncReadRent>(mut r: R) -> std::io::Result<(Vec<u8>, Option<Vec<u8>>)> {
585    macro_rules! read_ok {
586        ($res: expr, $data: expr) => {
587            match $res {
588                Ok(r) => r,
589                Err(_) => {
590                    return Ok(($data, None));
591                }
592            }
593        };
594    }
595
596    let header = vec![0; TLS_HEADER_SIZE];
597    let (res, header) = r.read_exact(header).await;
598    res?;
599
600    // validate header and fail fast
601    if header[0] != HANDSHAKE
602        || header[1] != TLS_MAJOR
603        || (header[2] != TLS_MINOR.0 && header[2] != TLS_MINOR.1)
604    {
605        return Ok((header, None));
606    }
607
608    // read tls frame length
609    let mut size: [u8; 2] = Default::default();
610    size.copy_from_slice(&header[3..5]);
611    let data_size = u16::from_be_bytes(size);
612    tracing::debug!("read handshake length {}", data_size);
613
614    // read tls frame
615    let mut data = vec![0; data_size as usize + TLS_HEADER_SIZE];
616    unsafe { copy_nonoverlapping(header.as_ptr(), data.as_mut_ptr(), TLS_HEADER_SIZE) };
617    let (res, data_slice) = r.read_exact(data.slice_mut(TLS_HEADER_SIZE..)).await;
618    res?;
619
620    // validate client hello
621    let data_slice: SliceMut<Vec<u8>> = data_slice;
622    let data = data_slice.into_inner();
623    let mut cursor = std::io::Cursor::new(&data[TLS_HEADER_SIZE..]);
624    if read_ok!(cursor.read_u8(), data) != CLIENT_HELLO {
625        tracing::debug!("first packet is not client hello");
626        return Ok((data, None));
627    }
628    // length[0] must be 0
629    if read_ok!(cursor.read_u8(), data) != 0 {
630        tracing::debug!("client hello length first byte is not zero");
631        return Ok((data, None));
632    }
633    // client hello length[1..=2]
634    let prot_size = read_ok!(cursor.read_u16::<BigEndian>(), data);
635    if prot_size + 4 > data_size {
636        tracing::debug!("invalid client hello length");
637        return Ok((data, None));
638    }
639    // reset cursor with new smaller length limit
640    let mut cursor = std::io::Cursor::new(
641        &data[TLS_HMAC_HEADER_SIZE..TLS_HMAC_HEADER_SIZE + prot_size as usize],
642    );
643    // skip 2 byte version
644    read_ok!(cursor.read_u16::<BigEndian>(), data);
645    // skip 32 byte random
646    read_ok!(cursor.skip(TLS_RANDOM_SIZE), data);
647    // skip session id
648    read_ok!(cursor.skip_by_u8(), data);
649    // skip cipher suites
650    read_ok!(cursor.skip_by_u16(), data);
651    // skip compression method
652    read_ok!(cursor.skip_by_u8(), data);
653    // skip ext length
654    read_ok!(cursor.read_u16::<BigEndian>(), data);
655
656    loop {
657        let ext_type = read_ok!(cursor.read_u16::<BigEndian>(), data);
658        if ext_type != SNI_EXT_TYPE {
659            read_ok!(cursor.skip_by_u16(), data);
660            continue;
661        }
662        tracing::debug!("found server_name extension");
663        let _ext_len = read_ok!(cursor.read_u16::<BigEndian>(), data);
664        let _sni_len = read_ok!(cursor.read_u16::<BigEndian>(), data);
665        // must be host_name
666        if read_ok!(cursor.read_u8(), data) != 0 {
667            return Ok((data, None));
668        }
669        let sni = Some(read_ok!(cursor.read_by_u16(), data));
670        return Ok((data, sni));
671    }
672}
673
674/// Read a single frame and return Vec.
675///
676/// Only used by V3 protocol.
677async fn read_exact_frame(r: impl AsyncReadRent) -> std::io::Result<Vec<u8>> {
678    read_exact_frame_into(r, Vec::new()).await
679}
680
681/// Read a single frame into given Vec.
682///
683/// Only used by V3 protocol.
684async fn read_exact_frame_into(
685    mut r: impl AsyncReadRent,
686    mut buffer: Vec<u8>,
687) -> std::io::Result<Vec<u8>> {
688    unsafe { buffer.set_len(0) };
689    buffer.reserve(TLS_HEADER_SIZE);
690    let (res, header) = r.read_exact(buffer.slice_mut(..TLS_HEADER_SIZE)).await;
691    res?;
692    let mut buffer = header.into_inner();
693
694    // read tls frame length
695    let mut size: [u8; 2] = Default::default();
696    size.copy_from_slice(&buffer[3..5]);
697    let data_size = u16::from_be_bytes(size) as usize;
698
699    // read tls frame body
700    buffer.reserve(data_size);
701    let (res, data_slice) = r
702        .read_exact(buffer.slice_mut(TLS_HEADER_SIZE..TLS_HEADER_SIZE + data_size))
703        .await;
704    res?;
705
706    Ok(data_slice.into_inner())
707}
708
709/// Parse frame, verify it and extract SNI.
710/// Return is_pass and Option<SNI>.
711/// It requires &mut but it is meant for doing operation inplace.
712/// It does not modify the data.
713///
714/// Only used by V3 protocol.
715fn verified_extract_sni(frame: &[u8], password: &str) -> (bool, Option<Vec<u8>>) {
716    // 5 frame header + 1 handshake type + 3 length + 2 version + 32 random + 1 session id len + 32 session id
717    const MIN_LEN: usize = TLS_HEADER_SIZE + 1 + 3 + 2 + TLS_RANDOM_SIZE + 1 + TLS_SESSION_ID_SIZE;
718    const HMAC_IDX: usize = SESSION_ID_LEN_IDX + 1 + TLS_SESSION_ID_SIZE - HMAC_SIZE;
719    const ZERO4B: [u8; HMAC_SIZE] = [0; HMAC_SIZE];
720
721    if frame.len() < SESSION_ID_LEN_IDX || frame[0] != HANDSHAKE || frame[5] != CLIENT_HELLO {
722        return (false, None);
723    }
724
725    let pass = if frame.len() < MIN_LEN || frame[SESSION_ID_LEN_IDX] != TLS_SESSION_ID_SIZE as u8 {
726        false
727    } else {
728        let mut hmac = Hmac::new(password, (&[], &[]));
729        hmac.update(&frame[TLS_HEADER_SIZE..HMAC_IDX]);
730        hmac.update(&ZERO4B);
731        hmac.update(&frame[HMAC_IDX + HMAC_SIZE..]);
732        hmac.finalize() == frame[HMAC_IDX..HMAC_IDX + HMAC_SIZE]
733    };
734
735    let mut cursor = std::io::Cursor::new(&frame[SESSION_ID_LEN_IDX..]);
736    macro_rules! read_ok {
737        ($res: expr) => {
738            match $res {
739                Ok(r) => r,
740                Err(_) => {
741                    return (pass, None);
742                }
743            }
744        };
745    }
746
747    // skip session id
748    read_ok!(cursor.skip_by_u8());
749    // skip cipher suites
750    read_ok!(cursor.skip_by_u16());
751    // skip compression method
752    read_ok!(cursor.skip_by_u8());
753    // skip ext length
754    read_ok!(cursor.read_u16::<BigEndian>());
755
756    loop {
757        let ext_type = read_ok!(cursor.read_u16::<BigEndian>());
758        if ext_type != SNI_EXT_TYPE {
759            read_ok!(cursor.skip_by_u16());
760            continue;
761        }
762        tracing::debug!("found server_name extension");
763        let _ext_len = read_ok!(cursor.read_u16::<BigEndian>());
764        let _sni_len = read_ok!(cursor.read_u16::<BigEndian>());
765        // must be host_name
766        if read_ok!(cursor.read_u8()) != 0 {
767            return (pass, None);
768        }
769        let sni = Some(read_ok!(cursor.read_by_u16()));
770        return (pass, sni);
771    }
772}
773
774/// Parse given frame and extract ServerRandom.
775/// Return Option<ServerRandom>.
776///
777/// Only used by V3 protocol.
778fn extract_server_random(frame: &[u8]) -> Option<[u8; TLS_RANDOM_SIZE]> {
779    // 5 frame header + 1 handshake type + 3 length + 2 version + 32 random
780    const MIN_LEN: usize = TLS_HEADER_SIZE + 1 + 3 + 2 + TLS_RANDOM_SIZE;
781
782    if frame.len() < MIN_LEN || frame[0] != HANDSHAKE || frame[5] != SERVER_HELLO {
783        return None;
784    }
785
786    let mut server_random = [0; TLS_RANDOM_SIZE];
787    unsafe {
788        copy_nonoverlapping(
789            frame.as_ptr().add(SERVER_RANDOM_IDX),
790            server_random.as_mut_ptr(),
791            TLS_RANDOM_SIZE,
792        )
793    };
794
795    Some(server_random)
796}
797
798/// Copy frame by frame until a appdata frame matches hmac.
799/// Return the matched pure data(without header).
800///
801/// Only used by V3 protocol.
802async fn copy_by_frame_until_hmac_matches(
803    mut read: impl AsyncReadRent,
804    mut write: impl AsyncWriteRent,
805    hmac: &mut Hmac,
806) -> std::io::Result<Vec<u8>> {
807    let mut g_buffer = Vec::new();
808
809    loop {
810        let buffer = read_exact_frame_into(&mut read, g_buffer).await?;
811        if buffer.len() > 9 && buffer[0] == APPLICATION_DATA {
812            // check hmac
813            let mut tmp_hmac = hmac.to_owned();
814            tmp_hmac.update(&buffer[TLS_HMAC_HEADER_SIZE..]);
815            let h = tmp_hmac.finalize();
816
817            if buffer[TLS_HEADER_SIZE..TLS_HMAC_HEADER_SIZE] == h {
818                hmac.update(&buffer[TLS_HMAC_HEADER_SIZE..]);
819                hmac.update(&buffer[TLS_HEADER_SIZE..TLS_HMAC_HEADER_SIZE]);
820                return Ok(buffer[TLS_HMAC_HEADER_SIZE..].to_vec());
821            }
822        }
823
824        let (res, buffer) = write.write_all(buffer).await;
825        res?;
826        g_buffer = buffer;
827    }
828}
829
830/// Copy frame by frame.
831/// Modify appdata frame:
832/// 1. Cycle XOR xor data.
833/// 2. Calculate HMAC and insert before the frame data.
834///
835/// Only used by V3 protocol.
836async fn copy_by_frame_with_modification(
837    mut read: impl AsyncReadRent,
838    mut write: impl AsyncWriteRent,
839    hmac: &mut Hmac,
840    xor: &[u8],
841    stop: &mut Sender<()>,
842) -> std::io::Result<()> {
843    let mut g_buffer = Vec::new();
844    let stop = stop.closed();
845    monoio::pin!(stop);
846
847    loop {
848        monoio::select! {
849            // this function can be stopped by a channel when reading.
850            _ = &mut stop => {
851                return Ok(());
852            },
853            buffer_res = read_exact_frame_into(&mut read, g_buffer) => {
854                let mut buffer = buffer_res?;
855                // Note: if we get frame, it is guaranteed valid.
856                if buffer[0] == APPLICATION_DATA {
857                    // do modification: xor data, add 4-byte hmac, update tls frame length
858                    xor_slice(&mut buffer[TLS_HEADER_SIZE..], xor);
859                    hmac.update(&buffer[TLS_HEADER_SIZE..]);
860                    let hash = hmac.finalize();
861                    buffer.extend_from_slice(&hash);
862                    unsafe {
863                        copy(buffer.as_ptr().add(TLS_HEADER_SIZE), buffer.as_mut_ptr().add(TLS_HMAC_HEADER_SIZE), buffer.len() - TLS_HMAC_HEADER_SIZE);
864                        copy_nonoverlapping(hash.as_ptr(), buffer.as_mut_ptr().add(TLS_HEADER_SIZE), HMAC_SIZE);
865                    }
866
867                    let mut size: [u8; 2] = Default::default();
868                    size.copy_from_slice(&buffer[3..5]);
869                    let data_size = u16::from_be_bytes(size);
870                    // Normally it does not overflow.
871                    let data_size = data_size.wrapping_add(HMAC_SIZE as u16);
872                    (&mut buffer[3..5]).write_u16::<BigEndian>(data_size).unwrap();
873                }
874
875                // writing is not cancelable
876                let (res, buffer) = write.write_all(buffer).await;
877                res?;
878                g_buffer = buffer;
879            }
880        }
881    }
882}
883
#[cfg(test)]
mod tests {
    use super::*;

    /// Build an `FxHashMap<String, String>` from string-convertible key/value pairs.
    fn to_map<K: Into<String>, V: Into<String>>(
        kvs: Vec<(K, V)>,
    ) -> rustc_hash::FxHashMap<String, String> {
        kvs.into_iter().map(|(k, v)| (k.into(), v.into())).collect()
    }

    /// `map![k => v, ...]` builds a dispatch map; the trailing comma is optional.
    /// The empty arm stays separate because `to_map(vec![])` cannot infer `K`/`V`.
    macro_rules! map {
        [] => { rustc_hash::FxHashMap::<String, String>::default() };
        [$($k:expr => $v:expr),+ $(,)?] => {
            to_map(vec![$(($k.to_owned(), $v.to_owned())),+])
        };
    }

    /// Shorthand for `.to_string()`.
    macro_rules! s {
        ($v:expr) => {
            $v.to_string()
        };
    }

    #[test]
    fn parse_tls_addrs() {
        assert_eq!(
            TlsAddrs::try_from("google.com").unwrap(),
            TlsAddrs {
                dispatch: map![],
                fallback: s!("google.com:443"),
                wildcard_sni: Default::default(),
            }
        );
        assert_eq!(
            TlsAddrs::try_from("feishu.cn;cloudflare.com:1.1.1.1:80;google.com").unwrap(),
            TlsAddrs {
                dispatch: map![
                    "feishu.cn" => "feishu.cn:443",
                    "cloudflare.com" => "1.1.1.1:80",
                ],
                fallback: s!("google.com:443"),
                wildcard_sni: Default::default(),
            }
        );
        assert_eq!(
            TlsAddrs::try_from("captive.apple.com;feishu.cn:80").unwrap(),
            TlsAddrs {
                dispatch: map![
                    "captive.apple.com" => "captive.apple.com:443",
                ],
                fallback: s!("feishu.cn:80"),
                wildcard_sni: Default::default(),
            }
        );
    }

    #[test]
    fn find_tls_addr() {
        let mut addrs = TlsAddrs {
            dispatch: map!["cloudflare.com" => "1.1.1.1:443"],
            fallback: s!("google.com:443"),
            wildcard_sni: WildcardSNI::All,
        };
        // An explicit dispatch entry wins regardless of wildcard mode or auth.
        assert_eq!(addrs.find(Some("cloudflare.com"), false), "1.1.1.1:443");
        // No SNI always goes to the fallback.
        assert_eq!(addrs.find(None, true), "google.com:443");
        // `All`: an unknown SNI is dialed directly on port 443, authed or not.
        assert_eq!(addrs.find(Some("example.com"), false), "example.com:443");

        // `Authed`: an unknown SNI is dialed directly only for authenticated clients.
        addrs.wildcard_sni = WildcardSNI::Authed;
        assert_eq!(addrs.find(Some("example.com"), true), "example.com:443");
        assert_eq!(addrs.find(Some("example.com"), false), "google.com:443");
    }
}