1use std::io::{self, Cursor, Read, Write};
2use std::net::TcpStream;
3use std::sync::{Arc, Mutex};
4
5
/// A client TCP connection whose first bytes were read up front to extract
/// the TLS SNI host name.
///
/// The bytes consumed during that initial read are kept in `buffer`, so the
/// `Read` implementation can hand them out again before pulling fresh data
/// from `stream`.
pub(crate) struct SharedConn {
    /// Underlying client socket.
    pub stream: TcpStream,
    /// Bytes already read from `stream`; the cursor position marks how much
    /// of them has been handed back to readers so far.
    buffer: Arc<Mutex<Cursor<Vec<u8>>>>,
    /// Host name parsed from the initial bytes (empty when none was found).
    sni: String,
}
20
21impl SharedConn {
22 pub fn new(mut stream: TcpStream) -> Result<SharedConn, std::io::Error> {
23 let buffer = Arc::new(Mutex::new(Cursor::new(Vec::new())));
24
25 let mut buf: [u8; 1024] = [0_u8; 1024];
27 let n = stream.read(&mut buf)?;
28 if n > 0 {
29 let mut buffer = buffer.lock().unwrap();
30 buffer.get_mut().extend_from_slice(&buf[..n]);
31 }
32
33 let sni = parse_sni(&buf, n)?;
34
35 Ok(SharedConn {
36 stream,
37 buffer,
38 sni,
39 })
40 }
41
42 pub fn get_sni(&self) -> String {
43 self.sni.clone()
44 }
45}
46
47impl Read for SharedConn {
48 fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
49 let mut buffer = self.buffer.lock().unwrap();
50 if buffer.position() < buffer.get_ref().len() as u64 {
51 buffer.read(buf)
52 } else {
53 self.stream.read(buf)
54 }
55 }
56}
57
58impl Write for SharedConn {
59 fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
60 self.stream.write(buf)
61 }
62
63 fn flush(&mut self) -> io::Result<()> {
65 self.stream.flush()
67 }
68}
69
/// Extracts the host name from the `server_name` (SNI) extension of a TLS
/// ClientHello.
///
/// `buf` holds the bytes received from the client, starting at the TLS record
/// header; only the first `n` bytes are treated as valid (`n` is clamped to
/// `buf.len()`). The record and handshake headers are skipped without being
/// validated.
///
/// Returns the host name, or an empty string when the ClientHello carries no
/// extensions block or no `server_name` extension.
///
/// # Errors
///
/// * [`io::ErrorKind::Other`] ("tls handshake is too short") when the data
///   ends before a field that the preceding length fields announce.
/// * [`io::ErrorKind::InvalidData`] when the `server_name` extension is
///   malformed or the host name is not valid UTF-8.
fn parse_sni(buf: &[u8], n: usize) -> Result<String, io::Error> {
    // Record header (5) + handshake header (4) + client version (2) + random (32).
    const FIXED_PREFIX_LEN: usize = 43;
    // Extension type assigned to server_name.
    const EXT_SERVER_NAME: usize = 0;

    fn too_short() -> io::Error {
        io::Error::new(io::ErrorKind::Other, "tls handshake is too short")
    }

    /// Splits `len` bytes off the front of `data`, failing if fewer remain.
    fn take<'a>(data: &mut &'a [u8], len: usize) -> Result<&'a [u8], io::Error> {
        let current: &'a [u8] = *data;
        if current.len() < len {
            return Err(too_short());
        }
        let (head, tail) = current.split_at(len);
        *data = tail;
        Ok(head)
    }

    /// Reads a one-byte length field.
    fn take_u8(data: &mut &[u8]) -> Result<usize, io::Error> {
        Ok(usize::from(take(data, 1)?[0]))
    }

    /// Reads a big-endian two-byte field.
    fn take_u16(data: &mut &[u8]) -> Result<usize, io::Error> {
        let bytes = take(data, 2)?;
        Ok(usize::from(u16::from_be_bytes([bytes[0], bytes[1]])))
    }

    /// Returns the first name of a server_name extension body.
    fn host_name(mut body: &[u8]) -> Result<&[u8], io::Error> {
        // Skip server_name_list length (2) and name_type (1).
        take(&mut body, 3)?;
        let name_len = take_u16(&mut body)?;
        take(&mut body, name_len)
    }

    let mut rest = &buf[..n.min(buf.len())];

    take(&mut rest, FIXED_PREFIX_LEN)?;
    let session_id_len = take_u8(&mut rest)?;
    take(&mut rest, session_id_len)?;
    let cipher_suites_len = take_u16(&mut rest)?;
    take(&mut rest, cipher_suites_len)?;
    let compression_methods_len = take_u8(&mut rest)?;
    take(&mut rest, compression_methods_len)?;

    // A ClientHello may end here, without any extensions block.
    if rest.is_empty() {
        return Ok(String::new());
    }

    let extensions_len = take_u16(&mut rest)?;
    let mut extensions = take(&mut rest, extensions_len)?;

    while !extensions.is_empty() {
        let ext_type = take_u16(&mut extensions)?;
        let ext_len = take_u16(&mut extensions)?;
        let body = take(&mut extensions, ext_len)?;
        if ext_type != EXT_SERVER_NAME {
            continue;
        }

        // The bytes come from the network: report bad data, never panic.
        let name = host_name(body).map_err(|_| {
            io::Error::new(
                io::ErrorKind::InvalidData,
                "malformed server_name extension",
            )
        })?;
        return String::from_utf8(name.to_vec())
            .map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err));
    }

    Ok(String::new())
}
142
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::{Read, Write};
    use std::net::{TcpListener, TcpStream};
    use std::thread;

    /// Builds a minimal TLS ClientHello record whose only extension is a
    /// `server_name` entry for `host`.
    fn client_hello(host: &str) -> Vec<u8> {
        let name = host.as_bytes();
        let mut hello = vec![
            0x16, 0x03, 0x01, 0x00, 0x00, // record header (length left at 0; the parser skips it)
            0x01, 0x00, 0x00, 0x00, // handshake header
            0x03, 0x03, // client version
        ];
        hello.extend_from_slice(&[0_u8; 32]); // random
        hello.push(0); // session id length
        hello.extend_from_slice(&[0x00, 0x02, 0x13, 0x01]); // one cipher suite
        hello.extend_from_slice(&[0x01, 0x00]); // null compression only
        hello.extend_from_slice(&((name.len() + 9) as u16).to_be_bytes()); // extensions length
        hello.extend_from_slice(&[0x00, 0x00]); // extension type: server_name
        hello.extend_from_slice(&((name.len() + 5) as u16).to_be_bytes()); // extension length
        hello.extend_from_slice(&((name.len() + 3) as u16).to_be_bytes()); // server name list length
        hello.push(0); // name type: host_name
        hello.extend_from_slice(&(name.len() as u16).to_be_bytes());
        hello.extend_from_slice(name);
        hello
    }

    #[test]
    fn test_parse_sni() {
        let hello = client_hello("www.baidu.com");
        let sni = parse_sni(&hello, hello.len()).unwrap();
        assert_eq!(sni, "www.baidu.com");
    }

    #[test]
    fn test_shared_conn_replays_client_hello() {
        // Ephemeral loopback port: needs no privileges and no outside client.
        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
        let addr = listener.local_addr().unwrap();

        let hello = client_hello("www.baidu.com");
        let expected = hello.clone();
        let client = thread::spawn(move || {
            let mut stream = TcpStream::connect(addr).unwrap();
            stream.write_all(&hello).unwrap();
            // Dropping the stream closes it, so the server side sees EOF.
        });

        let (stream, _) = listener.accept().unwrap();
        let mut tls_conn = SharedConn::new(stream).unwrap();
        assert_eq!(tls_conn.get_sni(), "www.baidu.com");

        client.join().unwrap();

        // The bytes consumed while sniffing must still be readable.
        let mut replayed = Vec::new();
        tls_conn.read_to_end(&mut replayed).unwrap();
        assert_eq!(replayed, expected);
    }
}