1use errno::{set_errno, Errno};
5use s2n_tls::{
6 config::Config,
7 connection::{Builder, Connection},
8 enums::{Blinding, CallbackResult, Mode},
9 error::Error,
10};
11use std::{
12 fmt,
13 future::Future,
14 io,
15 os::raw::{c_int, c_void},
16 pin::Pin,
17 task::{
18 Context, Poll,
19 Poll::{Pending, Ready},
20 },
21};
22use tokio::{
23 io::{AsyncRead, AsyncWrite, ReadBuf},
24 time::{sleep, Duration, Sleep},
25};
26
27mod task;
29use task::waker::debug_assert_contract as debug_assert_waker_contract;
30
/// Unwraps a `Poll::Ready(value)` into `value`, or makes the enclosing
/// function (or closure) return `Poll::Pending` if the input is not ready.
///
/// The `Poll` paths are fully qualified so the expansion does not depend on
/// which `Poll` names happen to be imported at the call site.
macro_rules! ready {
    ($x:expr) => {
        match $x {
            ::std::task::Poll::Ready(value) => value,
            ::std::task::Poll::Pending => return ::std::task::Poll::Pending,
        }
    };
}
39
40#[derive(Clone)]
41pub struct TlsAcceptor<B: Builder = Config>
42where
43 <B as Builder>::Output: Unpin,
44{
45 builder: B,
46}
47
48impl<B: Builder> TlsAcceptor<B>
49where
50 <B as Builder>::Output: Unpin,
51{
52 pub fn new(builder: B) -> Self {
53 TlsAcceptor { builder }
54 }
55
56 pub async fn accept<S>(&self, stream: S) -> Result<TlsStream<S, B::Output>, Error>
57 where
58 S: AsyncRead + AsyncWrite + Unpin,
59 {
60 let conn = self.builder.build_connection(Mode::Server)?;
61 TlsStream::open(conn, stream).await
62 }
63}
64
65#[derive(Clone)]
66pub struct TlsConnector<B: Builder = Config>
67where
68 <B as Builder>::Output: Unpin,
69{
70 builder: B,
71}
72
73impl<B: Builder> TlsConnector<B>
74where
75 <B as Builder>::Output: Unpin,
76{
77 pub fn new(builder: B) -> Self {
78 TlsConnector { builder }
79 }
80
81 pub async fn connect<S>(
82 &self,
83 domain: &str,
84 stream: S,
85 ) -> Result<TlsStream<S, B::Output>, Error>
86 where
87 S: AsyncRead + AsyncWrite + Unpin,
88 {
89 let mut conn = self.builder.build_connection(Mode::Client)?;
90 conn.as_mut().set_server_name(domain)?;
91 TlsStream::open(conn, stream).await
92 }
93}
94
95struct TlsHandshake<'a, S, C>
96where
97 C: AsRef<Connection> + AsMut<Connection> + Unpin,
98 S: AsyncRead + AsyncWrite + Unpin,
99{
100 tls: &'a mut TlsStream<S, C>,
101 error: Option<Error>,
102}
103
104impl<S, C> Future for TlsHandshake<'_, S, C>
105where
106 C: AsRef<Connection> + AsMut<Connection> + Unpin,
107 S: AsyncRead + AsyncWrite + Unpin,
108{
109 type Output = Result<(), Error>;
110
111 fn poll(mut self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Self::Output> {
112 debug_assert_waker_contract(ctx, |ctx| {
113 let result = match self.error.take() {
118 Some(err) => Err(err),
119 None => {
120 let handshake_poll = self.tls.with_io(ctx, |context| {
121 let conn = context.get_mut().as_mut();
122 conn.poll_negotiate().map(|r| r.map(|_| ()))
123 });
124 ready!(handshake_poll)
125 }
126 };
127 match result {
134 Ok(r) => Ok(r).into(),
135 Err(e) if e.is_retryable() => Err(e).into(),
136 Err(e) => match Pin::new(&mut self.tls).poll_shutdown(ctx) {
137 Pending => {
138 self.error = Some(e);
139 Pending
140 }
141 Ready(_) => Err(e).into(),
142 },
143 }
144 })
145 }
146}
147
148pub struct TlsStream<S, C = Connection>
149where
150 C: AsRef<Connection> + AsMut<Connection> + Unpin,
151 S: AsyncRead + AsyncWrite + Unpin,
152{
153 conn: C,
154 stream: S,
155 blinding: Option<Pin<Box<Sleep>>>,
156 shutdown_error: Option<Error>,
157}
158
159impl<S, C> TlsStream<S, C>
160where
161 C: AsRef<Connection> + AsMut<Connection> + Unpin,
162 S: AsyncRead + AsyncWrite + Unpin,
163{
164 pub fn get_ref(&self) -> &S {
166 &self.stream
167 }
168
169 pub fn get_mut(&mut self) -> &mut S {
171 &mut self.stream
172 }
173
174 async fn open(conn: C, stream: S) -> Result<Self, Error> {
175 let mut tls = TlsStream {
176 conn,
177 stream,
178 blinding: None,
179 shutdown_error: None,
180 };
181 TlsHandshake {
182 tls: &mut tls,
183 error: None,
184 }
185 .await?;
186 Ok(tls)
187 }
188
189 fn with_io<F, R>(&mut self, ctx: &mut Context, action: F) -> Poll<Result<R, Error>>
190 where
191 F: FnOnce(Pin<&mut Self>) -> Poll<Result<R, Error>>,
192 {
193 unsafe {
198 let context = self as *mut Self as *mut c_void;
199
200 self.as_mut().set_receive_callback(Some(Self::recv_io_cb))?;
201 self.as_mut().set_send_callback(Some(Self::send_io_cb))?;
202 self.as_mut().set_receive_context(context)?;
203 self.as_mut().set_send_context(context)?;
204 self.as_mut().set_waker(Some(ctx.waker()))?;
205 self.as_mut().set_blinding(Blinding::SelfService)?;
206
207 let result = action(Pin::new(self));
208
209 self.as_mut().set_receive_callback(None)?;
210 self.as_mut().set_send_callback(None)?;
211 self.as_mut().set_receive_context(std::ptr::null_mut())?;
212 self.as_mut().set_send_context(std::ptr::null_mut())?;
213 self.as_mut().set_waker(None)?;
214 result
215 }
216 }
217
218 fn poll_io<F>(ctx: *mut c_void, action: F) -> c_int
219 where
220 F: FnOnce(Pin<&mut S>, &mut Context) -> Poll<Result<usize, std::io::Error>>,
221 {
222 debug_assert_ne!(ctx, std::ptr::null_mut());
223 let tls = unsafe { &mut *(ctx as *mut Self) };
224
225 let mut async_context = Context::from_waker(tls.conn.as_ref().waker().unwrap());
226 let stream = Pin::new(&mut tls.stream);
227
228 let res = debug_assert_waker_contract(&mut async_context, |async_context| {
229 action(stream, async_context)
230 });
231
232 match res {
233 Poll::Ready(Ok(len)) => len as c_int,
234 Poll::Pending => {
235 set_errno(Errno(libc::EWOULDBLOCK));
236 CallbackResult::Failure.into()
237 }
238 _ => CallbackResult::Failure.into(),
239 }
240 }
241
242 unsafe extern "C" fn recv_io_cb(ctx: *mut c_void, buf: *mut u8, len: u32) -> c_int {
243 Self::poll_io(ctx, |stream, async_context| {
244 let mut dest = ReadBuf::new(std::slice::from_raw_parts_mut(buf, len as usize));
245 stream
246 .poll_read(async_context, &mut dest)
247 .map_ok(|_| dest.filled().len())
248 })
249 }
250
251 unsafe extern "C" fn send_io_cb(ctx: *mut c_void, buf: *const u8, len: u32) -> c_int {
252 Self::poll_io(ctx, |stream, async_context| {
253 let src = std::slice::from_raw_parts(buf, len as usize);
254 stream.poll_write(async_context, src)
255 })
256 }
257
258 pub fn poll_blinding(self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Result<(), Error>> {
271 debug_assert_waker_contract(ctx, |ctx| {
272 let tls = self.get_mut();
273
274 if tls.blinding.is_none() {
275 let delay = tls.as_ref().remaining_blinding_delay()?;
276 if !delay.is_zero() {
277 let safety = Duration::from_millis(1);
280 tls.blinding = Some(Box::pin(sleep(delay.saturating_add(safety))));
281 }
282 };
283
284 if let Some(timer) = tls.blinding.as_mut() {
285 ready!(timer.as_mut().poll(ctx));
286 tls.blinding = None;
287 }
288
289 Poll::Ready(Ok(()))
290 })
291 }
292
293 pub async fn apply_blinding(&mut self) -> Result<(), Error> {
294 ApplyBlinding { stream: self }.await
295 }
296}
297
298impl<S, C> AsRef<Connection> for TlsStream<S, C>
299where
300 C: AsRef<Connection> + AsMut<Connection> + Unpin,
301 S: AsyncRead + AsyncWrite + Unpin,
302{
303 fn as_ref(&self) -> &Connection {
304 self.conn.as_ref()
305 }
306}
307
308impl<S, C> AsMut<Connection> for TlsStream<S, C>
309where
310 C: AsRef<Connection> + AsMut<Connection> + Unpin,
311 S: AsyncRead + AsyncWrite + Unpin,
312{
313 fn as_mut(&mut self) -> &mut Connection {
314 self.conn.as_mut()
315 }
316}
317
318impl<S, C> AsyncRead for TlsStream<S, C>
319where
320 C: AsRef<Connection> + AsMut<Connection> + Unpin,
321 S: AsyncRead + AsyncWrite + Unpin,
322{
323 fn poll_read(
324 self: Pin<&mut Self>,
325 ctx: &mut Context<'_>,
326 buf: &mut ReadBuf<'_>,
327 ) -> Poll<io::Result<()>> {
328 let tls = self.get_mut();
329 tls.with_io(ctx, |mut context| {
330 context
331 .conn
332 .as_mut()
333 .poll_recv_uninitialized(unsafe { buf.unfilled_mut() })
336 .map_ok(|size| {
337 unsafe {
338 buf.assume_init(size);
342 }
343 buf.advance(size);
344 })
345 })
346 .map_err(io::Error::from)
347 }
348}
349
350impl<S, C> AsyncWrite for TlsStream<S, C>
351where
352 C: AsRef<Connection> + AsMut<Connection> + Unpin,
353 S: AsyncRead + AsyncWrite + Unpin,
354{
355 fn poll_write(
356 self: Pin<&mut Self>,
357 ctx: &mut Context<'_>,
358 buf: &[u8],
359 ) -> Poll<io::Result<usize>> {
360 let tls = self.get_mut();
361 tls.with_io(ctx, |mut context| context.conn.as_mut().poll_send(buf))
362 .map_err(io::Error::from)
363 }
364
365 fn poll_flush(self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<io::Result<()>> {
366 let tls = self.get_mut();
367
368 ready!(tls.with_io(ctx, |mut context| {
369 context.conn.as_mut().poll_flush().map(|r| r.map(|_| ()))
370 }))
371 .map_err(io::Error::from)?;
372
373 Pin::new(&mut tls.stream).poll_flush(ctx)
374 }
375
376 fn poll_shutdown(mut self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<io::Result<()>> {
377 debug_assert_waker_contract(ctx, |ctx| {
378 ready!(self.as_mut().poll_blinding(ctx))?;
379
380 if self.shutdown_error.is_none() {
382 let result = ready!(self.as_mut().with_io(ctx, |mut context| {
383 context
384 .conn
385 .as_mut()
386 .poll_shutdown_send()
387 .map(|r| r.map(|_| ()))
388 }));
389 if let Err(error) = result {
390 self.shutdown_error = Some(error);
391 }
394 };
395
396 let tcp_result = ready!(Pin::new(&mut self.as_mut().stream).poll_shutdown(ctx));
397
398 if let Some(err) = self.shutdown_error.take() {
399 let next_error = Error::application("Shutdown called again after error".into());
405 self.shutdown_error = Some(next_error);
406
407 Ready(Err(io::Error::from(err)))
408 } else {
409 Ready(tcp_result)
410 }
411 })
412 }
413}
414
415impl<S, C> fmt::Debug for TlsStream<S, C>
416where
417 C: AsRef<Connection> + AsMut<Connection> + Unpin,
418 S: AsyncRead + AsyncWrite + Unpin,
419{
420 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
421 f.debug_struct("TlsStream")
422 .field("connection", self.as_ref())
423 .finish()
424 }
425}
426
427struct ApplyBlinding<'a, S, C>
428where
429 C: AsRef<Connection> + AsMut<Connection> + Unpin,
430 S: AsyncRead + AsyncWrite + Unpin,
431{
432 stream: &'a mut TlsStream<S, C>,
433}
434
435impl<S, C> Future for ApplyBlinding<'_, S, C>
436where
437 C: AsRef<Connection> + AsMut<Connection> + Unpin,
438 S: AsyncRead + AsyncWrite + Unpin,
439{
440 type Output = Result<(), Error>;
441
442 fn poll(mut self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Self::Output> {
443 Pin::new(&mut *self.as_mut().stream).poll_blinding(ctx)
444 }
445}