// tls_api_rustls/rustls_utils.rs

1use rustls::ClientConnection;
2use rustls::ServerConnection;
3use rustls::StreamOwned;
4use std::fmt::Arguments;
5use std::io;
6use std::io::IoSlice;
7use std::io::IoSliceMut;
8use std::io::Read;
9use std::io::Write;
10use tls_api::async_as_sync::WriteShutdown;
11
12pub enum RustlsSessionRef<'a> {
13    Client(&'a ClientConnection),
14    Server(&'a ServerConnection),
15}
16
17/// Merge client and server stream into single interface
18pub(crate) enum RustlsStream<S: Read + Write> {
19    Server(StreamOwned<ServerConnection, S>),
20    Client(StreamOwned<ClientConnection, S>),
21}
22
23impl<S: Read + Write> RustlsStream<S> {
24    pub fn session(&self) -> RustlsSessionRef {
25        match self {
26            RustlsStream::Server(s) => RustlsSessionRef::Server(&s.conn),
27            RustlsStream::Client(s) => RustlsSessionRef::Client(&s.conn),
28        }
29    }
30}
31
32impl<S: Read + Write> RustlsStream<S> {
33    pub fn get_socket_mut(&mut self) -> &mut S {
34        match self {
35            RustlsStream::Server(s) => s.get_mut(),
36            RustlsStream::Client(s) => s.get_mut(),
37        }
38    }
39
40    pub fn get_socket_ref(&self) -> &S {
41        match self {
42            RustlsStream::Server(s) => s.get_ref(),
43            RustlsStream::Client(s) => s.get_ref(),
44        }
45    }
46
47    pub fn is_handshaking(&self) -> bool {
48        match self {
49            RustlsStream::Server(s) => s.conn.is_handshaking(),
50            RustlsStream::Client(s) => s.conn.is_handshaking(),
51        }
52    }
53
54    pub fn complete_io(&mut self) -> io::Result<(usize, usize)> {
55        match self {
56            RustlsStream::Server(s) => s.conn.complete_io(&mut s.sock),
57            RustlsStream::Client(s) => s.conn.complete_io(&mut s.sock),
58        }
59    }
60
61    pub fn get_alpn_protocol(&self) -> Option<&[u8]> {
62        match self {
63            RustlsStream::Server(s) => s.conn.alpn_protocol(),
64            RustlsStream::Client(s) => s.conn.alpn_protocol(),
65        }
66    }
67}
68
69impl<S: Read + Write> Write for RustlsStream<S> {
70    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
71        match self {
72            RustlsStream::Server(s) => s.write(buf),
73            RustlsStream::Client(s) => s.write(buf),
74        }
75    }
76
77    fn write_vectored(&mut self, bufs: &[IoSlice<'_>]) -> io::Result<usize> {
78        match self {
79            RustlsStream::Server(s) => s.write_vectored(bufs),
80            RustlsStream::Client(s) => s.write_vectored(bufs),
81        }
82    }
83
84    fn flush(&mut self) -> io::Result<()> {
85        match self {
86            RustlsStream::Server(s) => s.flush(),
87            RustlsStream::Client(s) => s.flush(),
88        }
89    }
90
91    fn write_all(&mut self, buf: &[u8]) -> io::Result<()> {
92        match self {
93            RustlsStream::Server(s) => s.write_all(buf),
94            RustlsStream::Client(s) => s.write_all(buf),
95        }
96    }
97
98    fn write_fmt(&mut self, fmt: Arguments<'_>) -> io::Result<()> {
99        match self {
100            RustlsStream::Server(s) => s.write_fmt(fmt),
101            RustlsStream::Client(s) => s.write_fmt(fmt),
102        }
103    }
104}
105
106impl<S: Read + Write> WriteShutdown for RustlsStream<S> {
107    fn shutdown(&mut self) -> Result<(), io::Error> {
108        match self {
109            RustlsStream::Server(s) => s.conn.send_close_notify(),
110            RustlsStream::Client(s) => s.conn.send_close_notify(),
111        }
112        self.flush()?;
113        Ok(())
114    }
115}
116
117impl<S: Read + Write> Read for RustlsStream<S> {
118    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
119        match self {
120            RustlsStream::Server(s) => s.read(buf),
121            RustlsStream::Client(s) => s.read(buf),
122        }
123    }
124
125    fn read_vectored(&mut self, bufs: &mut [IoSliceMut<'_>]) -> io::Result<usize> {
126        match self {
127            RustlsStream::Server(s) => s.read_vectored(bufs),
128            RustlsStream::Client(s) => s.read_vectored(bufs),
129        }
130    }
131
132    fn read_to_end(&mut self, buf: &mut Vec<u8>) -> io::Result<usize> {
133        match self {
134            RustlsStream::Server(s) => s.read_to_end(buf),
135            RustlsStream::Client(s) => s.read_to_end(buf),
136        }
137    }
138
139    fn read_to_string(&mut self, buf: &mut String) -> io::Result<usize> {
140        match self {
141            RustlsStream::Server(s) => s.read_to_string(buf),
142            RustlsStream::Client(s) => s.read_to_string(buf),
143        }
144    }
145
146    fn read_exact(&mut self, buf: &mut [u8]) -> io::Result<()> {
147        match self {
148            RustlsStream::Server(s) => s.read_exact(buf),
149            RustlsStream::Client(s) => s.read_exact(buf),
150        }
151    }
152}