tls_api_rustls/
rustls_utils.rs1use rustls::ClientConnection;
2use rustls::ServerConnection;
3use rustls::StreamOwned;
4use std::fmt::Arguments;
5use std::io;
6use std::io::IoSlice;
7use std::io::IoSliceMut;
8use std::io::Read;
9use std::io::Write;
10use tls_api::async_as_sync::WriteShutdown;
11
12pub enum RustlsSessionRef<'a> {
13 Client(&'a ClientConnection),
14 Server(&'a ServerConnection),
15}
16
17pub(crate) enum RustlsStream<S: Read + Write> {
19 Server(StreamOwned<ServerConnection, S>),
20 Client(StreamOwned<ClientConnection, S>),
21}
22
23impl<S: Read + Write> RustlsStream<S> {
24 pub fn session(&self) -> RustlsSessionRef {
25 match self {
26 RustlsStream::Server(s) => RustlsSessionRef::Server(&s.conn),
27 RustlsStream::Client(s) => RustlsSessionRef::Client(&s.conn),
28 }
29 }
30}
31
32impl<S: Read + Write> RustlsStream<S> {
33 pub fn get_socket_mut(&mut self) -> &mut S {
34 match self {
35 RustlsStream::Server(s) => s.get_mut(),
36 RustlsStream::Client(s) => s.get_mut(),
37 }
38 }
39
40 pub fn get_socket_ref(&self) -> &S {
41 match self {
42 RustlsStream::Server(s) => s.get_ref(),
43 RustlsStream::Client(s) => s.get_ref(),
44 }
45 }
46
47 pub fn is_handshaking(&self) -> bool {
48 match self {
49 RustlsStream::Server(s) => s.conn.is_handshaking(),
50 RustlsStream::Client(s) => s.conn.is_handshaking(),
51 }
52 }
53
54 pub fn complete_io(&mut self) -> io::Result<(usize, usize)> {
55 match self {
56 RustlsStream::Server(s) => s.conn.complete_io(&mut s.sock),
57 RustlsStream::Client(s) => s.conn.complete_io(&mut s.sock),
58 }
59 }
60
61 pub fn get_alpn_protocol(&self) -> Option<&[u8]> {
62 match self {
63 RustlsStream::Server(s) => s.conn.alpn_protocol(),
64 RustlsStream::Client(s) => s.conn.alpn_protocol(),
65 }
66 }
67}
68
69impl<S: Read + Write> Write for RustlsStream<S> {
70 fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
71 match self {
72 RustlsStream::Server(s) => s.write(buf),
73 RustlsStream::Client(s) => s.write(buf),
74 }
75 }
76
77 fn write_vectored(&mut self, bufs: &[IoSlice<'_>]) -> io::Result<usize> {
78 match self {
79 RustlsStream::Server(s) => s.write_vectored(bufs),
80 RustlsStream::Client(s) => s.write_vectored(bufs),
81 }
82 }
83
84 fn flush(&mut self) -> io::Result<()> {
85 match self {
86 RustlsStream::Server(s) => s.flush(),
87 RustlsStream::Client(s) => s.flush(),
88 }
89 }
90
91 fn write_all(&mut self, buf: &[u8]) -> io::Result<()> {
92 match self {
93 RustlsStream::Server(s) => s.write_all(buf),
94 RustlsStream::Client(s) => s.write_all(buf),
95 }
96 }
97
98 fn write_fmt(&mut self, fmt: Arguments<'_>) -> io::Result<()> {
99 match self {
100 RustlsStream::Server(s) => s.write_fmt(fmt),
101 RustlsStream::Client(s) => s.write_fmt(fmt),
102 }
103 }
104}
105
106impl<S: Read + Write> WriteShutdown for RustlsStream<S> {
107 fn shutdown(&mut self) -> Result<(), io::Error> {
108 match self {
109 RustlsStream::Server(s) => s.conn.send_close_notify(),
110 RustlsStream::Client(s) => s.conn.send_close_notify(),
111 }
112 self.flush()?;
113 Ok(())
114 }
115}
116
117impl<S: Read + Write> Read for RustlsStream<S> {
118 fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
119 match self {
120 RustlsStream::Server(s) => s.read(buf),
121 RustlsStream::Client(s) => s.read(buf),
122 }
123 }
124
125 fn read_vectored(&mut self, bufs: &mut [IoSliceMut<'_>]) -> io::Result<usize> {
126 match self {
127 RustlsStream::Server(s) => s.read_vectored(bufs),
128 RustlsStream::Client(s) => s.read_vectored(bufs),
129 }
130 }
131
132 fn read_to_end(&mut self, buf: &mut Vec<u8>) -> io::Result<usize> {
133 match self {
134 RustlsStream::Server(s) => s.read_to_end(buf),
135 RustlsStream::Client(s) => s.read_to_end(buf),
136 }
137 }
138
139 fn read_to_string(&mut self, buf: &mut String) -> io::Result<usize> {
140 match self {
141 RustlsStream::Server(s) => s.read_to_string(buf),
142 RustlsStream::Client(s) => s.read_to_string(buf),
143 }
144 }
145
146 fn read_exact(&mut self, buf: &mut [u8]) -> io::Result<()> {
147 match self {
148 RustlsStream::Server(s) => s.read_exact(buf),
149 RustlsStream::Client(s) => s.read_exact(buf),
150 }
151 }
152}