1use std::future::Future;
2#[cfg(unix)]
3use std::os::unix::io::{AsRawFd, RawFd};
4#[cfg(windows)]
5use std::os::windows::io::{AsRawSocket, RawSocket};
6use std::pin::Pin;
7#[cfg(feature = "early-data")]
8use std::task::Waker;
9use std::task::{Context, Poll};
10use std::{
11 io::{self, BufRead as _},
12 sync::Arc,
13};
14
15use rustls::{pki_types::ServerName, ClientConfig, ClientConnection};
16use tokio::io::{AsyncBufRead, AsyncRead, AsyncWrite, ReadBuf};
17
18use crate::common::{IoSession, MidHandshake, Stream, TlsState};
19
20#[derive(Clone)]
22pub struct TlsConnector {
23 inner: Arc<ClientConfig>,
24 #[cfg(feature = "early-data")]
25 early_data: bool,
26}
27
28impl TlsConnector {
29 #[cfg(feature = "early-data")]
34 pub fn early_data(mut self, flag: bool) -> Self {
35 self.early_data = flag;
36 self
37 }
38
39 #[inline]
40 pub fn connect<IO>(&self, domain: ServerName<'static>, stream: IO) -> Connect<IO>
41 where
42 IO: AsyncRead + AsyncWrite + Unpin,
43 {
44 self.connect_impl(domain, stream, None, |_| ())
45 }
46
47 #[inline]
48 pub fn connect_with<IO, F>(&self, domain: ServerName<'static>, stream: IO, f: F) -> Connect<IO>
49 where
50 IO: AsyncRead + AsyncWrite + Unpin,
51 F: FnOnce(&mut ClientConnection),
52 {
53 self.connect_impl(domain, stream, None, f)
54 }
55
56 fn connect_impl<IO, F>(
57 &self,
58 domain: ServerName<'static>,
59 stream: IO,
60 alpn_protocols: Option<Vec<Vec<u8>>>,
61 f: F,
62 ) -> Connect<IO>
63 where
64 IO: AsyncRead + AsyncWrite + Unpin,
65 F: FnOnce(&mut ClientConnection),
66 {
67 let alpn = alpn_protocols.unwrap_or_else(|| self.inner.alpn_protocols.clone());
68 let mut session = match ClientConnection::new_with_alpn(self.inner.clone(), domain, alpn) {
69 Ok(session) => session,
70 Err(error) => {
71 return Connect(MidHandshake::Error {
72 io: stream,
73 error: io::Error::new(io::ErrorKind::Other, error),
76 });
77 }
78 };
79 f(&mut session);
80
81 Connect(MidHandshake::Handshaking(TlsStream {
82 io: stream,
83
84 #[cfg(not(feature = "early-data"))]
85 state: TlsState::Stream,
86
87 #[cfg(feature = "early-data")]
88 state: if self.early_data && session.early_data().is_some() {
89 TlsState::EarlyData(0, Vec::new())
90 } else {
91 TlsState::Stream
92 },
93
94 need_flush: false,
95
96 #[cfg(feature = "early-data")]
97 early_waker: None,
98
99 session,
100 }))
101 }
102
103 pub fn with_alpn(&self, alpn_protocols: Vec<Vec<u8>>) -> TlsConnectorWithAlpn<'_> {
104 TlsConnectorWithAlpn {
105 inner: self,
106 alpn_protocols,
107 }
108 }
109
110 pub fn config(&self) -> &Arc<ClientConfig> {
112 &self.inner
113 }
114}
115
116impl From<Arc<ClientConfig>> for TlsConnector {
117 fn from(inner: Arc<ClientConfig>) -> Self {
118 Self {
119 inner,
120 #[cfg(feature = "early-data")]
121 early_data: false,
122 }
123 }
124}
125
126pub struct TlsConnectorWithAlpn<'c> {
127 inner: &'c TlsConnector,
128 alpn_protocols: Vec<Vec<u8>>,
129}
130
131impl TlsConnectorWithAlpn<'_> {
132 #[inline]
133 pub fn connect<IO>(self, domain: ServerName<'static>, stream: IO) -> Connect<IO>
134 where
135 IO: AsyncRead + AsyncWrite + Unpin,
136 {
137 self.inner
138 .connect_impl(domain, stream, Some(self.alpn_protocols), |_| ())
139 }
140
141 #[inline]
142 pub fn connect_with<IO, F>(self, domain: ServerName<'static>, stream: IO, f: F) -> Connect<IO>
143 where
144 IO: AsyncRead + AsyncWrite + Unpin,
145 F: FnOnce(&mut ClientConnection),
146 {
147 self.inner
148 .connect_impl(domain, stream, Some(self.alpn_protocols), f)
149 }
150}
151
152pub struct Connect<IO>(MidHandshake<TlsStream<IO>>);
155
156impl<IO> Connect<IO> {
157 #[inline]
158 pub fn into_fallible(self) -> FallibleConnect<IO> {
159 FallibleConnect(self.0)
160 }
161
162 pub fn get_ref(&self) -> Option<&IO> {
163 match &self.0 {
164 MidHandshake::Handshaking(sess) => Some(sess.get_ref().0),
165 MidHandshake::SendAlert { io, .. } => Some(io),
166 MidHandshake::Error { io, .. } => Some(io),
167 MidHandshake::End => None,
168 }
169 }
170
171 pub fn get_mut(&mut self) -> Option<&mut IO> {
172 match &mut self.0 {
173 MidHandshake::Handshaking(sess) => Some(sess.get_mut().0),
174 MidHandshake::SendAlert { io, .. } => Some(io),
175 MidHandshake::Error { io, .. } => Some(io),
176 MidHandshake::End => None,
177 }
178 }
179}
180
181impl<IO: AsyncRead + AsyncWrite + Unpin> Future for Connect<IO> {
182 type Output = io::Result<TlsStream<IO>>;
183
184 #[inline]
185 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
186 Pin::new(&mut self.0).poll(cx).map_err(|(err, _)| err)
187 }
188}
189
190impl<IO: AsyncRead + AsyncWrite + Unpin> Future for FallibleConnect<IO> {
191 type Output = Result<TlsStream<IO>, (io::Error, IO)>;
192
193 #[inline]
194 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
195 Pin::new(&mut self.0).poll(cx)
196 }
197}
198
199pub struct FallibleConnect<IO>(MidHandshake<TlsStream<IO>>);
201
202#[derive(Debug)]
205pub struct TlsStream<IO> {
206 pub(crate) io: IO,
207 pub(crate) session: ClientConnection,
208 pub(crate) state: TlsState,
209 pub(crate) need_flush: bool,
210
211 #[cfg(feature = "early-data")]
212 pub(crate) early_waker: Option<Waker>,
213}
214
215impl<IO> TlsStream<IO> {
216 #[inline]
217 pub fn get_ref(&self) -> (&IO, &ClientConnection) {
218 (&self.io, &self.session)
219 }
220
221 #[inline]
222 pub fn get_mut(&mut self) -> (&mut IO, &mut ClientConnection) {
223 (&mut self.io, &mut self.session)
224 }
225
226 #[inline]
227 pub fn into_inner(self) -> (IO, ClientConnection) {
228 (self.io, self.session)
229 }
230}
231
232#[cfg(unix)]
233impl<S> AsRawFd for TlsStream<S>
234where
235 S: AsRawFd,
236{
237 fn as_raw_fd(&self) -> RawFd {
238 self.get_ref().0.as_raw_fd()
239 }
240}
241
/// Exposes the raw socket of the underlying IO object.
#[cfg(windows)]
impl<S> AsRawSocket for TlsStream<S>
where
    S: AsRawSocket,
{
    fn as_raw_socket(&self) -> RawSocket {
        self.io.as_raw_socket()
    }
}
251
252impl<IO> IoSession for TlsStream<IO> {
253 type Io = IO;
254 type Session = ClientConnection;
255
256 #[inline]
257 fn skip_handshake(&self) -> bool {
258 self.state.is_early_data()
259 }
260
261 #[inline]
262 fn get_mut(&mut self) -> (&mut TlsState, &mut Self::Io, &mut Self::Session, &mut bool) {
263 (
264 &mut self.state,
265 &mut self.io,
266 &mut self.session,
267 &mut self.need_flush,
268 )
269 }
270
271 #[inline]
272 fn into_io(self) -> Self::Io {
273 self.io
274 }
275}
276
#[cfg(feature = "early-data")]
impl<IO> TlsStream<IO>
where
    IO: AsyncRead + AsyncWrite + Unpin,
{
    /// Parks the current task's waker while the stream is in the early-data state.
    ///
    /// The read path returns `Pending` during early data because the TLS
    /// connection is not established yet. To avoid losing the wakeup, the waker
    /// is stored here and woken by `poll_handle_early_data` once it switches
    /// the state to `TlsState::Stream`. An already stored waker is kept when it
    /// would wake the same task.
    fn poll_early_data(&mut self, cx: &mut Context<'_>) {
        let current = cx.waker();
        let up_to_date = match &self.early_waker {
            Some(stored) => current.will_wake(stored),
            None => false,
        };

        if !up_to_date {
            self.early_waker = Some(current.clone());
        }
    }
}
299
300impl<IO> AsyncRead for TlsStream<IO>
301where
302 IO: AsyncRead + AsyncWrite + Unpin,
303{
304 fn poll_read(
305 mut self: Pin<&mut Self>,
306 cx: &mut Context<'_>,
307 buf: &mut ReadBuf<'_>,
308 ) -> Poll<io::Result<()>> {
309 let data = ready!(self.as_mut().poll_fill_buf(cx))?;
310 let len = data.len().min(buf.remaining());
311 buf.put_slice(&data[..len]);
312 self.consume(len);
313 Poll::Ready(Ok(()))
314 }
315}
316
/// Buffered reading of decrypted plaintext from the TLS session.
impl<IO> AsyncBufRead for TlsStream<IO>
where
    IO: AsyncRead + AsyncWrite + Unpin,
{
    /// Fills and returns the session's plaintext buffer.
    ///
    /// * Early-data state: parks the task's waker via `poll_early_data` and
    ///   returns `Pending`. `poll_handle_early_data` wakes it after switching
    ///   the state to `TlsState::Stream`.
    /// * `Stream` / `WriteShutdown`: reads through [`Stream`]. An empty buffer,
    ///   or a `ConnectionAborted` error, marks the read side as shut down.
    /// * `ReadShutdown` / `FullyShutdown`: returns an empty slice.
    fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> {
        match self.state {
            #[cfg(feature = "early-data")]
            TlsState::EarlyData(..) => {
                self.get_mut().poll_early_data(cx);
                Poll::Pending
            }
            TlsState::Stream | TlsState::WriteShutdown => {
                let this = self.get_mut();
                // `set_eof` tells `Stream` whether the read side is already
                // closed. NOTE(review): exact effect lives in `common` — confirm.
                let stream =
                    Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable());

                match stream.poll_fill_buf(cx) {
                    Poll::Ready(Ok(buf)) => {
                        // An empty fill is treated as end of the read side.
                        if buf.is_empty() {
                            this.state.shutdown_read();
                        }

                        Poll::Ready(Ok(buf))
                    }
                    // An aborted connection also closes the read side, but the
                    // error is still reported to the caller.
                    Poll::Ready(Err(err)) if err.kind() == io::ErrorKind::ConnectionAborted => {
                        this.state.shutdown_read();
                        Poll::Ready(Err(err))
                    }
                    output => output,
                }
            }
            TlsState::ReadShutdown | TlsState::FullyShutdown => Poll::Ready(Ok(&[])),
        }
    }

    /// Marks `amt` bytes of the session's plaintext buffer as consumed.
    fn consume(mut self: Pin<&mut Self>, amt: usize) {
        self.session.reader().consume(amt);
    }
}
356
/// Writing plaintext through the TLS session.
impl<IO> AsyncWrite for TlsStream<IO>
where
    IO: AsyncRead + AsyncWrite + Unpin,
{
    /// Writes `buf` through the TLS session.
    ///
    /// With the `early-data` feature, `buf` is first offered to
    /// `poll_handle_early_data`:
    /// * a non-zero count from it is returned as is (the bytes went out as
    ///   early data);
    /// * `Ready(0)` means there is no early-data work left, so the regular
    ///   stream path runs;
    /// * `Pending` is propagated after saving `need_flush`.
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        let this = self.get_mut();
        let mut stream = Stream::new(&mut this.io, &mut this.session)
            .set_eof(!this.state.readable())
            .set_need_flush(this.need_flush);

        #[cfg(feature = "early-data")]
        {
            let bufs = [io::IoSlice::new(buf)];
            let written = poll_handle_early_data(
                &mut this.state,
                &mut stream,
                &mut this.early_waker,
                cx,
                &bufs,
            )?;
            match written {
                Poll::Ready(0) => {}
                Poll::Ready(written) => return Poll::Ready(Ok(written)),
                Poll::Pending => {
                    // `stream` is rebuilt from `this.need_flush` on every call,
                    // so copy the flag back to keep it across `Pending`.
                    this.need_flush = stream.need_flush;
                    return Poll::Pending;
                }
            }
        }

        stream.as_mut_pin().poll_write(cx, buf)
    }

    /// Vectored variant of [`AsyncWrite::poll_write`].
    ///
    /// Early-data handling works as in `poll_write`, with `bufs` passed
    /// through directly.
    fn poll_write_vectored(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        bufs: &[io::IoSlice<'_>],
    ) -> Poll<io::Result<usize>> {
        let this = self.get_mut();
        let mut stream = Stream::new(&mut this.io, &mut this.session)
            .set_eof(!this.state.readable())
            .set_need_flush(this.need_flush);

        #[cfg(feature = "early-data")]
        {
            let written = poll_handle_early_data(
                &mut this.state,
                &mut stream,
                &mut this.early_waker,
                cx,
                bufs,
            )?;
            match written {
                Poll::Ready(0) => {}
                Poll::Ready(written) => return Poll::Ready(Ok(written)),
                Poll::Pending => {
                    // Persist the pending-flush flag across `Pending` (see `poll_write`).
                    this.need_flush = stream.need_flush;
                    return Poll::Pending;
                }
            }
        }

        stream.as_mut_pin().poll_write_vectored(cx, bufs)
    }

    /// `poll_write_vectored` is implemented here (forwarding to `Stream`),
    /// so vectored writes are advertised as supported.
    #[inline]
    fn is_write_vectored(&self) -> bool {
        true
    }

    /// Flushes pending TLS records and the underlying IO object.
    ///
    /// With the `early-data` feature, a stream still in the early-data state
    /// first completes its early-data phase. That means driving the handshake
    /// and, if needed, re-sending the early data via `poll_handle_early_data`
    /// with no new data.
    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        let this = self.get_mut();
        let mut stream = Stream::new(&mut this.io, &mut this.session)
            .set_eof(!this.state.readable())
            .set_need_flush(this.need_flush);

        #[cfg(feature = "early-data")]
        {
            let written = poll_handle_early_data(
                &mut this.state,
                &mut stream,
                &mut this.early_waker,
                cx,
                &[],
            )?;
            if written.is_pending() {
                // Persist the pending-flush flag across `Pending` (see `poll_write`).
                this.need_flush = stream.need_flush;
                return Poll::Pending;
            }
        }

        stream.as_mut_pin().poll_flush(cx)
    }

    /// Sends a TLS `close_notify` and shuts down the underlying IO object.
    ///
    /// With the `early-data` feature, a stream still in the early-data state is
    /// flushed first, which completes the handshake before closing.
    fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        #[cfg(feature = "early-data")]
        {
            // Finish the handshake (and any early-data resend) before closing.
            if matches!(self.state, TlsState::EarlyData(..)) {
                ready!(self.as_mut().poll_flush(cx))?;
            }
        }

        // Queue close_notify only while the write side is still open.
        // NOTE(review): presumably `writeable()` is false after
        // `shutdown_write()`, so re-polling after `Pending` does not queue a
        // second alert — confirm in `common`.
        if self.state.writeable() {
            self.session.send_close_notify();
            self.state.shutdown_write();
        }

        let this = self.get_mut();
        let mut stream =
            Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable());
        stream.as_mut_pin().poll_shutdown(cx)
    }
}
479
/// Drives the early-data phase on behalf of the write and flush paths.
///
/// Does nothing unless `state` is `TlsState::EarlyData(pos, data)`. In that
/// case:
///
/// 1. If the session still offers an early-data writer
///    (`session.early_data()` is `Some`), as much of `bufs` as it accepts is
///    written into it and also appended to `data`. A non-zero count is
///    returned immediately.
/// 2. If nothing was written, the handshake is driven to completion.
/// 3. If the session reports the early data was not accepted
///    (`is_early_data_accepted()` is false), the bytes in `data` from offset
///    `pos` onward are re-sent through the regular stream. `pos` records
///    progress so a `Pending` resend resumes where it stopped.
/// 4. The state switches to `TlsState::Stream`, and any reader waker parked by
///    `TlsStream::poll_early_data` is woken.
///
/// Returns `Ready(Ok(n))` with `n > 0` when `n` bytes were taken as early data.
/// Returns `Ready(Ok(0))` when the caller should continue with the regular
/// stream path. Returns `Pending` while the handshake or the resend is in
/// progress.
#[cfg(feature = "early-data")]
fn poll_handle_early_data<IO>(
    state: &mut TlsState,
    stream: &mut Stream<IO, ClientConnection>,
    early_waker: &mut Option<Waker>,
    cx: &mut Context<'_>,
    bufs: &[io::IoSlice<'_>],
) -> Poll<io::Result<usize>>
where
    IO: AsyncRead + AsyncWrite + Unpin,
{
    if let TlsState::EarlyData(pos, data) = state {
        use std::io::Write;

        // Offer the plaintext to the session's early-data writer first.
        if let Some(mut early_data) = stream.session.early_data() {
            let mut written = 0;

            for buf in bufs {
                if buf.is_empty() {
                    continue;
                }

                let len = match early_data.write(buf) {
                    // Nothing accepted (presumably the early-data allowance is
                    // used up): stop offering more.
                    Ok(0) => break,
                    Ok(n) => n,
                    Err(err) => return Poll::Ready(Err(err)),
                };

                written += len;
                // Keep a copy so the bytes can be re-sent if early data is rejected.
                data.extend_from_slice(&buf[..len]);

                // Partial write: stop so `written` covers a contiguous prefix of `bufs`.
                if len < buf.len() {
                    break;
                }
            }

            if written != 0 {
                return Poll::Ready(Ok(written));
            }
        }

        // Complete the handshake before leaving the early-data state.
        while stream.session.is_handshaking() {
            ready!(stream.handshake(cx))?;
        }

        // Fallback: early data was not accepted, so re-send the recorded bytes
        // through the regular stream, resuming from `pos`.
        if !stream.session.is_early_data_accepted() {
            while *pos < data.len() {
                let len = ready!(stream.as_mut_pin().poll_write(cx, &data[*pos..]))?;
                *pos += len;
            }
        }

        // Early-data phase finished; later calls take the regular path.
        *state = TlsState::Stream;

        // Wake any reader that `poll_fill_buf` parked while in early data.
        if let Some(waker) = early_waker.take() {
            waker.wake();
        }
    }

    Poll::Ready(Ok(0))
}