rustls_split/
lib.rs

1use std::{
2    io,
3    io::Write,
4    net::{Shutdown, TcpStream},
5    sync::{Arc, Mutex, MutexGuard},
6};
7
8use rustls::Connection;
9
10mod buffer;
11pub use buffer::BufCfg;
12use buffer::Buffer;
13
14struct Shared {
15    stream: TcpStream,
16    connection: Mutex<Connection>,
17}
18
19pub struct ReadHalf {
20    shared: Arc<Shared>,
21    buf: Buffer,
22}
23
24impl io::Read for ReadHalf {
25    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
26        let mut connection = self.shared.connection.lock().unwrap();
27
28        while connection.wants_read() {
29            if self.buf.is_empty() {
30                drop(connection);
31
32                let bytes_read = self.buf.read_from(&mut &self.shared.stream)?;
33
34                connection = self.shared.connection.lock().unwrap();
35
36                if bytes_read == 0 {
37                    break;
38                }
39            }
40
41            let bytes_read = connection.read_tls(&mut self.buf)?;
42            debug_assert_ne!(bytes_read, 0);
43
44            connection
45                .process_new_packets()
46                .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
47        }
48
49        connection.reader().read(buf)
50    }
51}
52
53impl ReadHalf {
54    pub fn shutdown(&mut self, how: Shutdown) -> io::Result<()> {
55        self.shared.stream.shutdown(how)
56    }
57}
58
59pub struct WriteHalf {
60    shared: Arc<Shared>,
61    buf: Buffer,
62}
63
64impl WriteHalf {
65    pub fn shutdown(&mut self, how: Shutdown) -> io::Result<()> {
66        if how == Shutdown::Read {
67            return self.shared.stream.shutdown(Shutdown::Read);
68        }
69
70        let mut connection = self.shared.connection.lock().unwrap();
71        connection.send_close_notify();
72        let res = flush(&mut self.buf, &self.shared, connection);
73        self.shared.stream.shutdown(how)?;
74        res
75    }
76}
77
78fn wants_write_loop<'a>(
79    buf: &mut Buffer,
80    shared: &'a Shared,
81    mut connection: MutexGuard<'a, Connection>,
82) -> io::Result<MutexGuard<'a, Connection>> {
83    while connection.wants_write() {
84        while buf.is_full() {
85            drop(connection);
86
87            buf.write_to(&mut &shared.stream)?;
88
89            connection = shared.connection.lock().unwrap();
90        }
91
92        connection.write_tls(buf)?;
93    }
94
95    Ok(connection)
96}
97
98fn flush<'a>(
99    buf: &mut Buffer,
100    shared: &'a Shared,
101    mut connection: MutexGuard<'a, Connection>,
102) -> io::Result<()> {
103    connection.writer().flush()?;
104
105    let connection = wants_write_loop(buf, shared, connection)?;
106    std::mem::drop(connection);
107
108    while !buf.is_empty() {
109        buf.write_to(&mut &shared.stream)?;
110    }
111
112    Ok(())
113}
114
115impl io::Write for WriteHalf {
116    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
117        let connection = self.shared.connection.lock().unwrap();
118        let mut connection = wants_write_loop(&mut self.buf, &self.shared, connection)?;
119        connection.writer().write(buf)
120    }
121
122    fn flush(&mut self) -> io::Result<()> {
123        let connection = self.shared.connection.lock().unwrap();
124        flush(&mut self.buf, &self.shared, connection)
125    }
126}
127
128pub fn split<D1: Into<Vec<u8>>, D2: Into<Vec<u8>>>(
129    stream: TcpStream,
130    connection: Connection,
131    read_buf_cfg: BufCfg<D1>,
132    write_buf_cfg: BufCfg<D2>,
133) -> (ReadHalf, WriteHalf) {
134    assert!(!connection.is_handshaking());
135
136    let shared = Arc::new(Shared {
137        stream,
138        connection: Mutex::new(connection),
139    });
140
141    let read_half = ReadHalf {
142        shared: shared.clone(),
143        buf: Buffer::build_from(read_buf_cfg),
144    };
145
146    let write_half = WriteHalf {
147        shared,
148        buf: Buffer::build_from(write_buf_cfg),
149    };
150
151    (read_half, write_half)
152}