1use std::{
2 io,
3 io::Write,
4 net::{Shutdown, TcpStream},
5 sync::{Arc, Mutex, MutexGuard},
6};
7
8use rustls::Connection;
9
10mod buffer;
11pub use buffer::BufCfg;
12use buffer::Buffer;
13
14struct Shared {
15 stream: TcpStream,
16 connection: Mutex<Connection>,
17}
18
19pub struct ReadHalf {
20 shared: Arc<Shared>,
21 buf: Buffer,
22}
23
24impl io::Read for ReadHalf {
25 fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
26 let mut connection = self.shared.connection.lock().unwrap();
27
28 while connection.wants_read() {
29 if self.buf.is_empty() {
30 drop(connection);
31
32 let bytes_read = self.buf.read_from(&mut &self.shared.stream)?;
33
34 connection = self.shared.connection.lock().unwrap();
35
36 if bytes_read == 0 {
37 break;
38 }
39 }
40
41 let bytes_read = connection.read_tls(&mut self.buf)?;
42 debug_assert_ne!(bytes_read, 0);
43
44 connection
45 .process_new_packets()
46 .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
47 }
48
49 connection.reader().read(buf)
50 }
51}
52
53impl ReadHalf {
54 pub fn shutdown(&mut self, how: Shutdown) -> io::Result<()> {
55 self.shared.stream.shutdown(how)
56 }
57}
58
59pub struct WriteHalf {
60 shared: Arc<Shared>,
61 buf: Buffer,
62}
63
64impl WriteHalf {
65 pub fn shutdown(&mut self, how: Shutdown) -> io::Result<()> {
66 if how == Shutdown::Read {
67 return self.shared.stream.shutdown(Shutdown::Read);
68 }
69
70 let mut connection = self.shared.connection.lock().unwrap();
71 connection.send_close_notify();
72 let res = flush(&mut self.buf, &self.shared, connection);
73 self.shared.stream.shutdown(how)?;
74 res
75 }
76}
77
78fn wants_write_loop<'a>(
79 buf: &mut Buffer,
80 shared: &'a Shared,
81 mut connection: MutexGuard<'a, Connection>,
82) -> io::Result<MutexGuard<'a, Connection>> {
83 while connection.wants_write() {
84 while buf.is_full() {
85 drop(connection);
86
87 buf.write_to(&mut &shared.stream)?;
88
89 connection = shared.connection.lock().unwrap();
90 }
91
92 connection.write_tls(buf)?;
93 }
94
95 Ok(connection)
96}
97
98fn flush<'a>(
99 buf: &mut Buffer,
100 shared: &'a Shared,
101 mut connection: MutexGuard<'a, Connection>,
102) -> io::Result<()> {
103 connection.writer().flush()?;
104
105 let connection = wants_write_loop(buf, shared, connection)?;
106 std::mem::drop(connection);
107
108 while !buf.is_empty() {
109 buf.write_to(&mut &shared.stream)?;
110 }
111
112 Ok(())
113}
114
115impl io::Write for WriteHalf {
116 fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
117 let connection = self.shared.connection.lock().unwrap();
118 let mut connection = wants_write_loop(&mut self.buf, &self.shared, connection)?;
119 connection.writer().write(buf)
120 }
121
122 fn flush(&mut self) -> io::Result<()> {
123 let connection = self.shared.connection.lock().unwrap();
124 flush(&mut self.buf, &self.shared, connection)
125 }
126}
127
128pub fn split<D1: Into<Vec<u8>>, D2: Into<Vec<u8>>>(
129 stream: TcpStream,
130 connection: Connection,
131 read_buf_cfg: BufCfg<D1>,
132 write_buf_cfg: BufCfg<D2>,
133) -> (ReadHalf, WriteHalf) {
134 assert!(!connection.is_handshaking());
135
136 let shared = Arc::new(Shared {
137 stream,
138 connection: Mutex::new(connection),
139 });
140
141 let read_half = ReadHalf {
142 shared: shared.clone(),
143 buf: Buffer::build_from(read_buf_cfg),
144 };
145
146 let write_half = WriteHalf {
147 shared,
148 buf: Buffer::build_from(write_buf_cfg),
149 };
150
151 (read_half, write_half)
152}