ttpkit_utils/
io.rs

1//! IO utilities.
2
3use std::{
4    future::Future,
5    io::{self, IoSlice},
6    pin::Pin,
7    task::{Context, Poll},
8    time::Duration,
9};
10
11use bytes::Bytes;
12use futures::{FutureExt, channel::oneshot};
13use tokio::{
14    io::{AsyncRead, AsyncWrite, ReadBuf, ReadHalf, WriteHalf},
15    time::Sleep,
16};
17
18/// Connection builder.
19pub struct ConnectionBuilder {
20    read_timeout: Option<Duration>,
21    write_timeout: Option<Duration>,
22}
23
24impl ConnectionBuilder {
25    /// Create a new connection builder.
26    #[inline]
27    const fn new() -> Self {
28        Self {
29            read_timeout: None,
30            write_timeout: None,
31        }
32    }
33
34    /// Set read timeout.
35    #[inline]
36    pub const fn read_timeout(mut self, timeout: Option<Duration>) -> Self {
37        self.read_timeout = timeout;
38        self
39    }
40
41    /// Set write timeout.
42    #[inline]
43    pub const fn write_timeout(mut self, timeout: Option<Duration>) -> Self {
44        self.write_timeout = timeout;
45        self
46    }
47
48    /// Build the connection.
49    pub fn build<IO>(self, io: IO) -> Connection<IO> {
50        let context = ConnectionContext::new(self.read_timeout, self.write_timeout);
51
52        Connection {
53            inner: io,
54            buffer: PrependBuffer::new(),
55            context: Box::pin(context),
56        }
57    }
58}
59
pin_project_lite::pin_project! {
    /// Connection abstraction.
    ///
    /// Wraps an IO object and adds a prepend buffer (data handed back to the
    /// connection that is returned by subsequent reads before any new data
    /// from the IO object) and optional read/write timeouts.
    pub struct Connection<IO> {
        // the wrapped IO object; structurally pinned
        #[pin]
        inner: IO,
        // data that is returned by reads before `inner` gets polled
        buffer: PrependBuffer,
        // timeout configuration and timers, boxed and pinned separately
        context: Pin<Box<ConnectionContext>>,
    }
}
69
70impl Connection<()> {
71    /// Get a connection builder.
72    #[inline]
73    pub const fn builder() -> ConnectionBuilder {
74        ConnectionBuilder::new()
75    }
76}
77
78impl<IO> Connection<IO> {
79    /// Prepend given data and return it on the next poll.
80    #[inline]
81    pub fn prepend(mut self, item: Bytes) -> Self {
82        self.buffer.prepend(item);
83        self
84    }
85}
86
87impl<IO> Connection<IO>
88where
89    IO: AsyncRead + AsyncWrite,
90{
91    /// Split the connection into a reader half and a writer half.
92    pub fn split(mut self) -> (ConnectionReader<IO>, ConnectionWriter<IO>) {
93        let buffer = self.buffer.take();
94
95        let (r, w) = tokio::io::split(self);
96
97        let reader = ConnectionReader { inner: r, buffer };
98
99        let writer = ConnectionWriter { inner: w };
100
101        (reader, writer)
102    }
103}
104
105impl<IO> Connection<IO>
106where
107    IO: AsyncRead + AsyncWrite + Send + 'static,
108{
109    /// Repurpose the underlying connection.
110    ///
111    /// This method can be used for HTTP protocol upgrades.
112    pub fn upgrade(self) -> Upgraded {
113        Upgraded {
114            inner: Box::pin(self.inner),
115            buffer: self.buffer,
116        }
117    }
118}
119
120impl<IO> AsyncRead for Connection<IO>
121where
122    IO: AsyncRead,
123{
124    fn poll_read(
125        self: Pin<&mut Self>,
126        cx: &mut Context<'_>,
127        buf: &mut ReadBuf<'_>,
128    ) -> Poll<io::Result<()>> {
129        let this = self.project();
130
131        if !this.buffer.is_empty() {
132            // always read the data from the internal buffer first
133            this.buffer.read(buf);
134
135            return Poll::Ready(Ok(()));
136        }
137
138        let res = this.inner.poll_read(cx, buf);
139
140        if res.is_ready() {
141            this.context.as_mut().reset_read_timeout();
142        } else {
143            this.context.as_mut().check_read_timeout(cx)?;
144        }
145
146        res
147    }
148}
149
150impl<IO> AsyncWrite for Connection<IO>
151where
152    IO: AsyncWrite,
153{
154    fn poll_write(
155        self: Pin<&mut Self>,
156        cx: &mut Context<'_>,
157        buf: &[u8],
158    ) -> Poll<io::Result<usize>> {
159        let this = self.project();
160
161        let res = this.inner.poll_write(cx, buf);
162
163        if res.is_ready() {
164            this.context.as_mut().reset_write_timeout();
165        } else {
166            this.context.as_mut().check_write_timeout(cx)?;
167        }
168
169        res
170    }
171
172    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
173        let this = self.project();
174
175        let res = this.inner.poll_flush(cx);
176
177        if res.is_ready() {
178            this.context.as_mut().reset_write_timeout();
179        } else {
180            this.context.as_mut().check_write_timeout(cx)?;
181        }
182
183        res
184    }
185
186    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
187        let this = self.project();
188
189        let res = this.inner.poll_shutdown(cx);
190
191        if res.is_ready() {
192            this.context.as_mut().reset_write_timeout();
193        } else {
194            this.context.as_mut().check_write_timeout(cx)?;
195        }
196
197        res
198    }
199
200    fn poll_write_vectored(
201        self: Pin<&mut Self>,
202        cx: &mut Context<'_>,
203        bufs: &[IoSlice<'_>],
204    ) -> Poll<io::Result<usize>> {
205        let this = self.project();
206
207        let res = this.inner.poll_write_vectored(cx, bufs);
208
209        if res.is_ready() {
210            this.context.as_mut().reset_write_timeout();
211        } else {
212            this.context.as_mut().check_write_timeout(cx)?;
213        }
214
215        res
216    }
217
218    #[inline]
219    fn is_write_vectored(&self) -> bool {
220        self.inner.is_write_vectored()
221    }
222}
223
pin_project_lite::pin_project! {
    /// Helper struct to avoid duplicating non-generic functions.
    ///
    /// Holds the configured read/write timeouts together with the timers
    /// used to enforce them.
    struct ConnectionContext {
        // configured limits; `None` means the timeout is disabled
        read_timeout: Option<Duration>,
        write_timeout: Option<Duration>,
        // timer created by `check_read_timeout`, cleared by
        // `reset_read_timeout`
        #[pin]
        read_timeout_delay: Option<Sleep>,
        // timer created by `check_write_timeout`, cleared by
        // `reset_write_timeout`
        #[pin]
        write_timeout_delay: Option<Sleep>,
    }
}
235
236impl ConnectionContext {
237    /// Create a new connection context.
238    #[inline]
239    const fn new(read_timeout: Option<Duration>, write_timeout: Option<Duration>) -> Self {
240        Self {
241            read_timeout,
242            write_timeout,
243            read_timeout_delay: None,
244            write_timeout_delay: None,
245        }
246    }
247
248    /// Reset read timeout.
249    #[inline]
250    fn reset_read_timeout(self: Pin<&mut Self>) {
251        let mut this = self.project();
252
253        this.read_timeout_delay.set(None);
254    }
255
256    /// Check if the read timeout has elapsed.
257    fn check_read_timeout(self: Pin<&mut Self>, cx: &mut Context<'_>) -> io::Result<()> {
258        let mut this = self.project();
259
260        if let Some(timeout) = *this.read_timeout {
261            if this.read_timeout_delay.is_none() {
262                this.read_timeout_delay
263                    .set(Some(tokio::time::sleep(timeout)));
264            }
265
266            if let Some(timeout) = this.read_timeout_delay.as_pin_mut() {
267                if timeout.poll(cx).is_ready() {
268                    return Err(io::Error::new(io::ErrorKind::TimedOut, "read timeout"));
269                }
270            }
271        }
272
273        Ok(())
274    }
275
276    /// Reset write timeout.
277    #[inline]
278    fn reset_write_timeout(self: Pin<&mut Self>) {
279        let mut this = self.project();
280
281        this.write_timeout_delay.set(None);
282    }
283
284    /// Check if the write timeout has elapsed.
285    fn check_write_timeout(self: Pin<&mut Self>, cx: &mut Context<'_>) -> io::Result<()> {
286        let mut this = self.project();
287
288        if let Some(timeout) = *this.write_timeout {
289            if this.write_timeout_delay.is_none() {
290                this.write_timeout_delay
291                    .set(Some(tokio::time::sleep(timeout)));
292            }
293
294            if let Some(timeout) = this.write_timeout_delay.as_pin_mut() {
295                if timeout.poll(cx).is_ready() {
296                    return Err(io::Error::new(io::ErrorKind::TimedOut, "write timeout"));
297                }
298            }
299        }
300
301        Ok(())
302    }
303}
304
305/// Reader half of a connection.
306pub struct ConnectionReader<IO> {
307    inner: ReadHalf<Connection<IO>>,
308    buffer: PrependBuffer,
309}
310
311impl<IO> ConnectionReader<IO> {
312    /// Prepend given data and return it on the next poll.
313    #[inline]
314    pub fn prepend(mut self, item: Bytes) -> Self {
315        self.buffer.prepend(item);
316        self
317    }
318}
319
320impl<IO> ConnectionReader<IO>
321where
322    IO: Unpin,
323{
324    /// Join again with the other half of the connection.
325    pub fn join(self, writer: ConnectionWriter<IO>) -> Connection<IO> {
326        let mut connection = self.inner.unsplit(writer.inner);
327
328        connection.buffer = self.buffer;
329        connection
330    }
331}
332
333impl<IO> AsyncRead for ConnectionReader<IO>
334where
335    IO: AsyncRead,
336{
337    fn poll_read(
338        mut self: Pin<&mut Self>,
339        cx: &mut Context<'_>,
340        buf: &mut ReadBuf<'_>,
341    ) -> Poll<io::Result<()>> {
342        if !self.buffer.is_empty() {
343            // always read the data from the internal buffer first
344            self.buffer.read(buf);
345
346            return Poll::Ready(Ok(()));
347        }
348
349        let pinned = Pin::new(&mut self.inner);
350
351        pinned.poll_read(cx, buf)
352    }
353}
354
355/// Writer half of a connection.
356pub struct ConnectionWriter<IO> {
357    inner: WriteHalf<Connection<IO>>,
358}
359
360impl<IO> AsyncWrite for ConnectionWriter<IO>
361where
362    IO: AsyncWrite,
363{
364    #[inline]
365    fn poll_write(
366        mut self: Pin<&mut Self>,
367        cx: &mut Context<'_>,
368        buf: &[u8],
369    ) -> Poll<io::Result<usize>> {
370        AsyncWrite::poll_write(Pin::new(&mut self.inner), cx, buf)
371    }
372
373    #[inline]
374    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
375        AsyncWrite::poll_flush(Pin::new(&mut self.inner), cx)
376    }
377
378    #[inline]
379    fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
380        AsyncWrite::poll_shutdown(Pin::new(&mut self.inner), cx)
381    }
382
383    #[inline]
384    fn poll_write_vectored(
385        mut self: Pin<&mut Self>,
386        cx: &mut Context<'_>,
387        bufs: &[IoSlice<'_>],
388    ) -> Poll<io::Result<usize>> {
389        AsyncWrite::poll_write_vectored(Pin::new(&mut self.inner), cx, bufs)
390    }
391
392    #[inline]
393    fn is_write_vectored(&self) -> bool {
394        self.inner.is_write_vectored()
395    }
396}
397
398/// Helper trait.
399trait AsyncReadWrite: AsyncRead + AsyncWrite {}
400
401impl<T> AsyncReadWrite for T where T: AsyncRead + AsyncWrite {}
402
403/// Upgraded connection.
404pub struct Upgraded {
405    inner: Pin<Box<dyn AsyncReadWrite + Send>>,
406    buffer: PrependBuffer,
407}
408
409impl AsyncRead for Upgraded {
410    fn poll_read(
411        mut self: Pin<&mut Self>,
412        cx: &mut Context<'_>,
413        buf: &mut ReadBuf<'_>,
414    ) -> Poll<io::Result<()>> {
415        if !self.buffer.is_empty() {
416            // always read the data from the internal buffer first
417            self.buffer.read(buf);
418
419            return Poll::Ready(Ok(()));
420        }
421
422        let pinned = Pin::new(&mut self.inner);
423
424        pinned.poll_read(cx, buf)
425    }
426}
427
428impl AsyncWrite for Upgraded {
429    #[inline]
430    fn poll_write(
431        mut self: Pin<&mut Self>,
432        cx: &mut Context<'_>,
433        buf: &[u8],
434    ) -> Poll<io::Result<usize>> {
435        AsyncWrite::poll_write(Pin::new(&mut self.inner), cx, buf)
436    }
437
438    #[inline]
439    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
440        AsyncWrite::poll_flush(Pin::new(&mut self.inner), cx)
441    }
442
443    #[inline]
444    fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
445        AsyncWrite::poll_shutdown(Pin::new(&mut self.inner), cx)
446    }
447
448    #[inline]
449    fn poll_write_vectored(
450        mut self: Pin<&mut Self>,
451        cx: &mut Context<'_>,
452        bufs: &[IoSlice<'_>],
453    ) -> Poll<io::Result<usize>> {
454        AsyncWrite::poll_write_vectored(Pin::new(&mut self.inner), cx, bufs)
455    }
456
457    #[inline]
458    fn is_write_vectored(&self) -> bool {
459        self.inner.is_write_vectored()
460    }
461}
462
463/// Future upgraded connection.
464pub struct UpgradeFuture {
465    inner: oneshot::Receiver<Upgraded>,
466}
467
468impl UpgradeFuture {
469    /// Create a new future-request pair for a new connection upgrade.
470    pub fn new() -> (Self, UpgradeRequest) {
471        let (tx, rx) = oneshot::channel();
472
473        let tx = UpgradeRequest { inner: tx };
474        let rx = Self { inner: rx };
475
476        (rx, tx)
477    }
478}
479
480impl Future for UpgradeFuture {
481    type Output = io::Result<Upgraded>;
482
483    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
484        self.inner
485            .poll_unpin(cx)
486            .map_err(|_| io::Error::from(io::ErrorKind::BrokenPipe))
487    }
488}
489
490/// Connection upgrade request.
491pub struct UpgradeRequest {
492    inner: oneshot::Sender<Upgraded>,
493}
494
495impl UpgradeRequest {
496    /// Resolve the request.
497    pub fn resolve(self, connection: Upgraded) {
498        let _ = self.inner.send(connection);
499    }
500}
501
502/// Connection prepend buffer.
503struct PrependBuffer {
504    inner: Vec<Bytes>,
505}
506
507impl PrependBuffer {
508    /// Create a new prepend buffer.
509    #[inline]
510    const fn new() -> Self {
511        Self { inner: Vec::new() }
512    }
513
514    /// Prepend given data.
515    fn prepend(&mut self, item: Bytes) {
516        if !item.is_empty() {
517            self.inner.push(item);
518        }
519    }
520
521    /// Read data from the buffer into a given `ReadBuf`.
522    fn read(&mut self, buf: &mut ReadBuf<'_>) {
523        if let Some(chunk) = self.inner.last_mut() {
524            let available = chunk.len();
525
526            let take = available.min(buf.remaining());
527
528            buf.put_slice(&chunk.split_to(take));
529
530            if chunk.is_empty() {
531                self.inner.pop();
532            }
533        }
534    }
535
536    /// Take the buffered data.
537    #[inline]
538    fn take(&mut self) -> Self {
539        Self {
540            inner: std::mem::take(&mut self.inner),
541        }
542    }
543
544    /// Check if the buffer is empty.
545    #[inline]
546    fn is_empty(&self) -> bool {
547        self.inner.is_empty()
548    }
549}