1
2
3
/// Peeks at the first TLS record of an incoming TCP connection to learn the
/// SNI host name, while keeping every byte read so the connection can still be
/// proxied or handed to a TLS implementation unchanged.
pub mod vhost {
    use std::io::{self, Cursor, Read, Write};
    use std::net::TcpStream;
    use std::sync::{Arc, Mutex};

    /// TLS record content type for handshake messages.
    const CONTENT_TYPE_HANDSHAKE: u8 = 0x16;
    /// Handshake message type of a ClientHello.
    const HANDSHAKE_CLIENT_HELLO: u8 = 0x01;
    /// Extension type of `server_name` (RFC 6066, section 3).
    const EXTENSION_SERVER_NAME: usize = 0x0000;
    /// `NameType` value for a DNS host name inside the `server_name` extension.
    const NAME_TYPE_HOST_NAME: u8 = 0x00;
    /// Size of a TLS record header: content type (1), version (2), length (2).
    const RECORD_HEADER_LEN: usize = 5;
    /// Largest plaintext TLS record payload (2^14 bytes).
    const MAX_RECORD_LEN: usize = 1 << 14;
    /// Offset of the session-id length byte: record header (5) + handshake
    /// header (4) + client version (2) + random (32).
    const SESSION_ID_LEN_OFFSET: usize = 43;
    /// Size of each individual `read` issued while collecting the ClientHello.
    const READ_CHUNK_LEN: usize = 1024;

    /// Reads the first TLS record from `stream` and extracts its SNI host name.
    ///
    /// Every byte consumed from `stream` is kept and replayed first by the
    /// returned [`SharedConn`]'s `Read` implementation.
    ///
    /// This blocks until the whole first record (as sized by its header) or EOF
    /// has been received; set a read timeout on `stream` beforehand to bound it.
    ///
    /// # Errors
    ///
    /// Returns any I/O error from reading `stream`, or an error if the data is
    /// not a well-formed TLS ClientHello. A ClientHello without a `server_name`
    /// extension is not an error: `get_sni` then returns an empty string.
    pub fn new(mut stream: TcpStream) -> Result<SharedConn, io::Error> {
        let hello = read_client_hello(&mut stream)?;
        let sni = parse_sni(&hello, hello.len())?;

        Ok(SharedConn {
            stream,
            buffer: Arc::new(Mutex::new(Cursor::new(hello))),
            sni,
        })
    }

    /// A TCP connection whose already-consumed ClientHello bytes are replayed
    /// before any further data from the socket.
    #[derive(Debug)]
    pub struct SharedConn {
        /// The underlying socket. Writes go straight to it; reading it directly
        /// bypasses the replay buffer.
        pub stream: TcpStream,
        /// Bytes read while sniffing the SNI; drained before `stream` is read.
        buffer: Arc<Mutex<Cursor<Vec<u8>>>>,
        /// SNI host name, or empty if the ClientHello carried none.
        sni: String,
    }

    impl SharedConn {
        /// Returns a copy of the SNI host name, or an empty string if the
        /// ClientHello had no `server_name` extension.
        pub fn get_sni(&self) -> String {
            self.sni.clone()
        }
    }

    impl Read for SharedConn {
        /// Serves buffered ClientHello bytes first, then reads from the socket.
        fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
            // The cursor cannot be left inconsistent by a panic, so a poisoned
            // lock is safe to recover rather than propagate.
            let mut buffer = self
                .buffer
                .lock()
                .unwrap_or_else(|poisoned| poisoned.into_inner());
            if buffer.position() < buffer.get_ref().len() as u64 {
                buffer.read(buf)
            } else {
                self.stream.read(buf)
            }
        }
    }

    impl Write for SharedConn {
        fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
            self.stream.write(buf)
        }

        fn flush(&mut self) -> io::Result<()> {
            self.stream.flush()
        }
    }

    /// Reads from `reader` until it holds the complete first TLS record, EOF is
    /// reached, or the data is evidently not a TLS handshake record.
    ///
    /// A single `read` may return only part of a ClientHello, and ClientHellos
    /// can exceed one chunk, so reading once is not enough. Bytes past the end
    /// of the record may be read as well; all of them are returned so the
    /// caller can replay them.
    fn read_client_hello<R: Read>(reader: &mut R) -> io::Result<Vec<u8>> {
        let mut data = Vec::with_capacity(READ_CHUNK_LEN);
        let mut chunk = [0_u8; READ_CHUNK_LEN];

        loop {
            let wanted = match data.first() {
                None => RECORD_HEADER_LEN,
                // Not a handshake record: stop instead of waiting for bytes a
                // non-TLS client may never send; `parse_sni` will reject it.
                Some(&content_type) if content_type != CONTENT_TYPE_HANDSHAKE => break,
                Some(_) if data.len() < RECORD_HEADER_LEN => RECORD_HEADER_LEN,
                Some(_) => {
                    let record_len = usize::from(u16::from_be_bytes([data[3], data[4]]));
                    RECORD_HEADER_LEN + record_len.min(MAX_RECORD_LEN)
                }
            };
            if data.len() >= wanted {
                break;
            }
            match reader.read(&mut chunk) {
                Ok(0) => break,
                Ok(n) => data.extend_from_slice(&chunk[..n]),
                Err(e) if e.kind() == io::ErrorKind::Interrupted => continue,
                Err(e) => return Err(e),
            }
        }

        Ok(data)
    }

    /// Error used whenever a length field points past the available data.
    fn too_short() -> io::Error {
        io::Error::new(io::ErrorKind::Other, "tls handshake is too short")
    }

    /// Error used for structurally invalid (rather than truncated) input.
    fn invalid(msg: &'static str) -> io::Error {
        io::Error::new(io::ErrorKind::InvalidData, msg)
    }

    /// Splits `len` bytes off the front of `input`, failing if fewer remain.
    fn take<'a>(input: &mut &'a [u8], len: usize) -> io::Result<&'a [u8]> {
        let current: &'a [u8] = *input;
        if current.len() < len {
            return Err(too_short());
        }
        let (head, tail) = current.split_at(len);
        *input = tail;
        Ok(head)
    }

    /// Consumes one byte from the front of `input`.
    fn take_u8(input: &mut &[u8]) -> io::Result<u8> {
        Ok(take(input, 1)?[0])
    }

    /// Consumes a big-endian `u16` length/type field from the front of `input`.
    fn take_u16(input: &mut &[u8]) -> io::Result<usize> {
        let bytes = take(input, 2)?;
        Ok(usize::from(u16::from_be_bytes([bytes[0], bytes[1]])))
    }

    /// Extracts the SNI host name from the first `n` bytes of `buf`, which must
    /// start with a TLS handshake record containing a ClientHello.
    ///
    /// Returns an empty string when the ClientHello has no extensions or no
    /// `server_name` extension. Every length field is bounds-checked, so
    /// malformed input yields an error instead of a panic.
    fn parse_sni(buf: &[u8], n: usize) -> Result<String, io::Error> {
        let mut rest = &buf[..n.min(buf.len())];

        if rest.len() <= SESSION_ID_LEN_OFFSET {
            return Err(too_short());
        }
        if rest[0] != CONTENT_TYPE_HANDSHAKE || rest[RECORD_HEADER_LEN] != HANDSHAKE_CLIENT_HELLO {
            return Err(invalid("not a tls client hello"));
        }

        // Skip headers, version and random; then the variable-length fields.
        take(&mut rest, SESSION_ID_LEN_OFFSET)?;
        let session_id_len = usize::from(take_u8(&mut rest)?);
        take(&mut rest, session_id_len)?;
        let cipher_suites_len = take_u16(&mut rest)?;
        take(&mut rest, cipher_suites_len)?;
        let compression_methods_len = usize::from(take_u8(&mut rest)?);
        take(&mut rest, compression_methods_len)?;

        // The extensions block is optional in a ClientHello.
        if rest.is_empty() {
            return Ok(String::new());
        }

        let extensions_len = take_u16(&mut rest)?;
        let mut extensions = take(&mut rest, extensions_len)?;
        while !extensions.is_empty() {
            let ext_type = take_u16(&mut extensions)?;
            let ext_len = take_u16(&mut extensions)?;
            let ext_data = take(&mut extensions, ext_len)?;
            if ext_type == EXTENSION_SERVER_NAME {
                return parse_server_name(ext_data);
            }
        }

        Ok(String::new())
    }

    /// Returns the first `host_name` entry of a `server_name` extension body,
    /// or an empty string if the list contains none.
    fn parse_server_name(mut ext_data: &[u8]) -> io::Result<String> {
        let list_len = take_u16(&mut ext_data)?;
        let mut names = take(&mut ext_data, list_len)?;
        while !names.is_empty() {
            let name_type = take_u8(&mut names)?;
            let name_len = take_u16(&mut names)?;
            let name = take(&mut names, name_len)?;
            if name_type == NAME_TYPE_HOST_NAME {
                return String::from_utf8(name.to_vec())
                    .map_err(|_| invalid("server name is not valid utf-8"));
            }
        }
        Ok(String::new())
    }
}
148
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;
    use std::net::{Shutdown, TcpListener, TcpStream};
    use std::thread;

    /// Builds a minimal ClientHello record whose only extension is
    /// `server_name` carrying `host`.
    fn client_hello(host: &str) -> Vec<u8> {
        let host = host.as_bytes();
        let be16 = |len: usize| (len as u16).to_be_bytes();

        // server_name extension body: list length, name type 0 (host_name), name.
        let mut sni: Vec<u8> = Vec::new();
        sni.extend_from_slice(&be16(host.len() + 3));
        sni.push(0);
        sni.extend_from_slice(&be16(host.len()));
        sni.extend_from_slice(host);

        let mut extensions: Vec<u8> = vec![0, 0]; // extension type: server_name
        extensions.extend_from_slice(&be16(sni.len()));
        extensions.extend_from_slice(&sni);

        let mut body: Vec<u8> = vec![0x03, 0x03]; // client version
        body.extend_from_slice(&[0_u8; 32]); // random
        body.push(0); // empty session id
        body.extend_from_slice(&[0x00, 0x02, 0x13, 0x01]); // one cipher suite
        body.extend_from_slice(&[0x01, 0x00]); // null compression only
        body.extend_from_slice(&be16(extensions.len()));
        body.extend_from_slice(&extensions);

        let mut handshake: Vec<u8> = vec![0x01, 0x00]; // ClientHello, 24-bit length
        handshake.extend_from_slice(&be16(body.len()));
        handshake.extend_from_slice(&body);

        let mut record: Vec<u8> = vec![0x16, 0x03, 0x01]; // handshake record header
        record.extend_from_slice(&be16(handshake.len()));
        record.extend_from_slice(&handshake);
        record
    }

    #[test]
    fn test_parse_sni() {
        // Loopback on an ephemeral port: needs no root privileges and no
        // external client, unlike listening on 0.0.0.0:443.
        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
        let addr = listener.local_addr().unwrap();
        let hello = client_hello("www.baidu.com");
        let client = thread::spawn(move || {
            let mut stream = TcpStream::connect(addr).unwrap();
            stream.write_all(&hello).unwrap();
            stream.shutdown(Shutdown::Write).unwrap();
            stream
        });

        let (stream, _) = listener.accept().unwrap();
        let tls_conn = vhost::new(stream).unwrap();
        let sni = tls_conn.get_sni();
        assert_eq!(sni, "www.baidu.com");
        client.join().unwrap();
    }
}