rust_vhost/
lib.rs

1
2
3
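//! vhost
//!
//! Reads the TLS ClientHello from an incoming connection to extract the SNI
//! (server name), while keeping the bytes it consumed available to later reads
//! of the returned connection.
//!
//! # Example
//!
//! ```no_run
//! use std::net::TcpListener;
//!
//! let listener = TcpListener::bind("127.0.0.1:443")?;
//! let (conn, _) = listener.accept()?;
//! let tls_conn = rust_vhost::vhost::new(conn)?;
//! let sni = tls_conn.get_sni();
//! assert_eq!(sni, "google.com");
//! # Ok::<(), std::io::Error>(())
//! ```
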
pub mod vhost {
    use std::io::{self, Cursor, Read, Write};
    use std::net::TcpStream;

    /// TLS record content type for handshake messages.
    const CONTENT_TYPE_HANDSHAKE: u8 = 0x16;
    /// Handshake message type of a ClientHello.
    const HANDSHAKE_TYPE_CLIENT_HELLO: u8 = 0x01;
    /// Extension type of `server_name` (RFC 6066, section 3).
    const EXTENSION_SERVER_NAME: u16 = 0x0000;
    /// `NameType` value for a DNS host name inside the `server_name` extension.
    const NAME_TYPE_HOST_NAME: u8 = 0x00;
    /// Size of a TLS record header: content type (1), legacy version (2), length (2).
    const RECORD_HEADER_LEN: usize = 5;
    /// Largest plaintext record payload a peer may send (2^14 bytes).
    const MAX_RECORD_LEN: usize = 1 << 14;

    /// Reads the first TLS record from `stream`, which is expected to carry the
    /// ClientHello, and extracts the SNI host name from it. The stream is then
    /// wrapped so that the bytes consumed here are replayed to the next reader.
    ///
    /// # Errors
    ///
    /// Returns an error in any of these cases:
    /// - reading from `stream` fails, including EOF before a full record arrives;
    /// - the first record is not a TLS handshake record;
    /// - the record announces more than 2^14 bytes;
    /// - the ClientHello inside it is malformed.
    ///
    /// A ClientHello fragmented across several TLS records is not supported.
    pub fn new(mut stream: TcpStream) -> Result<SharedConn, io::Error> {
        let mut record = vec![0_u8; RECORD_HEADER_LEN];
        stream.read_exact(&mut record)?;

        // Reject non-TLS traffic before waiting for a body. On such input the
        // "length" bytes are arbitrary and could make us block for data that
        // never comes.
        if record[0] != CONTENT_TYPE_HANDSHAKE {
            return Err(malformed("not a tls handshake record"));
        }
        let body_len = usize::from(u16::from_be_bytes([record[3], record[4]]));
        if body_len > MAX_RECORD_LEN {
            return Err(malformed("tls record is too long"));
        }

        // A single `read` may return only part of the record, so read exactly
        // as many bytes as the header announces.
        record.resize(RECORD_HEADER_LEN + body_len, 0);
        stream.read_exact(&mut record[RECORD_HEADER_LEN..])?;

        let sni = parse_sni(&record, record.len())?;

        Ok(SharedConn {
            stream,
            buffer: Cursor::new(record),
            sni,
        })
    }

    /// A TCP connection that replays the ClientHello bytes already consumed by
    /// [`new`] before reading any further data from the socket.
    ///
    /// Reads through the [`Read`] impl see exactly the byte stream the peer
    /// sent. Writes go straight to the socket.
    #[derive(Debug)]
    pub struct SharedConn {
        /// The underlying socket. Reading from it directly skips the buffered
        /// ClientHello bytes, so prefer the [`Read`] impl on `SharedConn`.
        pub stream: TcpStream,
        /// Bytes read from `stream` while parsing the ClientHello. The cursor
        /// position tracks how many have been handed out again.
        buffer: Cursor<Vec<u8>>,
        /// Host name from the `server_name` extension, or empty if none was sent.
        sni: String,
    }

    impl SharedConn {
        /// Returns the SNI host name sent by the client. The string is empty
        /// when the ClientHello carried no `server_name` host name.
        pub fn get_sni(&self) -> String {
            self.sni.clone()
        }
    }

    impl Read for SharedConn {
        /// Serves the buffered ClientHello bytes first, then reads from the socket.
        fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
            if self.buffer.position() < self.buffer.get_ref().len() as u64 {
                self.buffer.read(buf)
            } else {
                self.stream.read(buf)
            }
        }
    }

    impl Write for SharedConn {
        /// Writes directly to the underlying TCP stream.
        fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
            self.stream.write(buf)
        }

        /// Flushes the underlying TCP stream.
        fn flush(&mut self) -> io::Result<()> {
            self.stream.flush()
        }
    }

    /// Builds the error returned for malformed or truncated handshake data.
    fn malformed(msg: &'static str) -> io::Error {
        io::Error::new(io::ErrorKind::Other, msg)
    }

    /// Splits the first `len` bytes off `input`, failing if fewer remain.
    fn take<'a>(input: &mut &'a [u8], len: usize) -> io::Result<&'a [u8]> {
        if input.len() < len {
            return Err(malformed("tls handshake is too short"));
        }
        let (head, rest) = input.split_at(len);
        *input = rest;
        Ok(head)
    }

    /// Reads one byte from the front of `input`.
    fn read_u8(input: &mut &[u8]) -> io::Result<u8> {
        Ok(take(input, 1)?[0])
    }

    /// Reads a big-endian `u16` from the front of `input`.
    fn read_u16(input: &mut &[u8]) -> io::Result<u16> {
        let bytes = take(input, 2)?;
        Ok(u16::from_be_bytes([bytes[0], bytes[1]]))
    }

    /// Reads a TLS variable-length vector and returns its contents. The vector
    /// starts with a big-endian length prefix of `prefix_len` bytes.
    fn take_vector<'a>(input: &mut &'a [u8], prefix_len: usize) -> io::Result<&'a [u8]> {
        let len = take(input, prefix_len)?
            .iter()
            .fold(0_usize, |acc, &b| (acc << 8) | usize::from(b));
        take(input, len)
    }

    /// Extracts the SNI host name from a TLS record that holds a ClientHello.
    ///
    /// Only the first `n` bytes of `buf` are considered. Every length field is
    /// bounds-checked, so malformed input produces an error instead of a panic.
    /// The function returns an empty string when the ClientHello has no
    /// `server_name` extension or no `host_name` entry.
    ///
    /// # Errors
    ///
    /// Fails in any of these cases:
    /// - the data is not a handshake record carrying a ClientHello;
    /// - a length field points past the available data;
    /// - the host name is not valid UTF-8.
    fn parse_sni(buf: &[u8], n: usize) -> Result<String, io::Error> {
        let mut input = &buf[..n.min(buf.len())];

        // Record header: content type, legacy version, then the length-prefixed payload.
        if read_u8(&mut input)? != CONTENT_TYPE_HANDSHAKE {
            return Err(malformed("not a tls handshake record"));
        }
        take(&mut input, 2)?;
        let mut record = take_vector(&mut input, 2)?;

        // Handshake header: message type and a 24-bit body length.
        if read_u8(&mut record)? != HANDSHAKE_TYPE_CLIENT_HELLO {
            return Err(malformed("not a tls client hello"));
        }
        let mut hello = take_vector(&mut record, 3)?;

        take(&mut hello, 2 + 32)?; // legacy client version + random
        take_vector(&mut hello, 1)?; // session id
        take_vector(&mut hello, 2)?; // cipher suites
        take_vector(&mut hello, 1)?; // compression methods

        // The extensions block is optional in a ClientHello. Without it there is no SNI.
        if hello.is_empty() {
            return Ok(String::new());
        }

        let mut extensions = take_vector(&mut hello, 2)?;
        while !extensions.is_empty() {
            let ext_type = read_u16(&mut extensions)?;
            let mut ext_data = take_vector(&mut extensions, 2)?;
            if ext_type != EXTENSION_SERVER_NAME {
                continue;
            }
            let mut names = take_vector(&mut ext_data, 2)?;
            while !names.is_empty() {
                let name_type = read_u8(&mut names)?;
                let name = take_vector(&mut names, 2)?;
                if name_type == NAME_TYPE_HOST_NAME {
                    return String::from_utf8(name.to_vec())
                        .map_err(|_| malformed("server name is not valid utf-8"));
                }
            }
            break;
        }

        Ok(String::new())
    }
}

// Tests for the code above.
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::{Read, Write};
    use std::net::{TcpListener, TcpStream};
    use std::thread;

    /// Encodes `len` as a big-endian 2-byte TLS length field.
    fn be16(len: usize) -> [u8; 2] {
        u16::try_from(len).expect("test length fits in u16").to_be_bytes()
    }

    /// Builds a minimal TLS ClientHello record.
    ///
    /// A `padding` extension (type 0x0015) of `padding` zero bytes comes before
    /// the `server_name` extension carrying `host`. Large `padding` values
    /// therefore push the SNI deep into the record.
    fn client_hello(host: &[u8], padding: usize) -> Vec<u8> {
        let mut extensions: Vec<u8> = Vec::new();
        extensions.extend_from_slice(&[0x00, 0x15]);
        extensions.extend_from_slice(&be16(padding));
        extensions.resize(extensions.len() + padding, 0);

        let entry_len = 1 + 2 + host.len();
        extensions.extend_from_slice(&[0x00, 0x00]); // server_name
        extensions.extend_from_slice(&be16(2 + entry_len));
        extensions.extend_from_slice(&be16(entry_len));
        extensions.push(0x00); // host_name
        extensions.extend_from_slice(&be16(host.len()));
        extensions.extend_from_slice(host);

        let mut hello = vec![0x03_u8, 0x03]; // legacy client version
        hello.extend_from_slice(&[0x11; 32]); // random
        hello.push(32);
        hello.extend_from_slice(&[0x22; 32]); // session id
        hello.extend_from_slice(&[0x00, 0x04, 0x13, 0x01, 0x13, 0x02]); // two cipher suites
        hello.extend_from_slice(&[0x01, 0x00]); // one (null) compression method
        hello.extend_from_slice(&be16(extensions.len()));
        hello.extend_from_slice(&extensions);

        let hello_len = u32::try_from(hello.len()).expect("test length fits in u32");
        let mut handshake = vec![0x01_u8]; // ClientHello
        handshake.extend_from_slice(&hello_len.to_be_bytes()[1..]);
        handshake.extend_from_slice(&hello);

        let mut record = vec![0x16_u8, 0x03, 0x01]; // handshake record, legacy version
        record.extend_from_slice(&be16(handshake.len()));
        record.extend_from_slice(&handshake);
        record
    }

    /// Opens a loopback connection on an ephemeral port and has the client send
    /// `payload` and then close. Returns the server side of the connection.
    fn accept_with_payload(payload: Vec<u8>) -> TcpStream {
        let listener = TcpListener::bind("127.0.0.1:0").expect("bind loopback");
        let addr = listener.local_addr().expect("local addr");
        thread::spawn(move || {
            let mut client = TcpStream::connect(addr).expect("connect");
            client.write_all(&payload).expect("send payload");
        });
        listener.accept().expect("accept").0
    }

    #[test]
    fn test_parse_sni() {
        let conn = vhost::new(accept_with_payload(client_hello(b"www.baidu.com", 0))).unwrap();
        assert_eq!(conn.get_sni(), "www.baidu.com");

        // For a manual end-to-end check against a real client, listen on port 443
        // (this usually needs elevated privileges) and run:
        //   curl -vv --resolve www.baidu.com:443:127.0.0.1 https://www.baidu.com
    }

    #[test]
    fn replays_consumed_bytes_before_socket_data() {
        let mut payload = client_hello(b"www.example.com", 4);
        payload.extend_from_slice(b"tail");
        let mut conn = vhost::new(accept_with_payload(payload.clone())).unwrap();
        let mut replayed = Vec::new();
        conn.read_to_end(&mut replayed).unwrap();
        assert_eq!(replayed, payload);
    }

    #[test]
    fn handles_client_hello_larger_than_one_read() {
        let conn = vhost::new(accept_with_payload(client_hello(b"big.example.org", 1500))).unwrap();
        assert_eq!(conn.get_sni(), "big.example.org");
    }

    #[test]
    fn rejects_malformed_input_without_panicking() {
        assert!(vhost::new(accept_with_payload(b"GET / HTTP/1.1\r\n\r\n".to_vec())).is_err());
        assert!(vhost::new(accept_with_payload(client_hello(&[0xff, 0xfe], 0))).is_err());
        let mut truncated = client_hello(b"a.example", 0);
        truncated.pop();
        assert!(vhost::new(accept_with_payload(truncated)).is_err());
    }
}