Skip to main content

swansong/implementation/
guarded.rs

1use super::Inner;
2use crate::Guard;
3use futures_core::Stream;
4use std::{
5    future::Future,
6    ops::{Deref, DerefMut},
7    pin::Pin,
8    sync::Arc,
9    task::{Context, Poll},
10};
11
12pin_project_lite::pin_project! {
13    /// Guarded is a convenient way to attach a [`Guard`] to another type.
14    ///
15    /// Guarded does not stop the wrapped type on shutdown, but will delay shutdown until it is
16    /// dropped. To stop the wrapped type, use
17    /// [`Swansong::interrupt`][crate::Swansong::interrupt]. To both stop the wrapped type and
18    /// also act as a guard, use [`Interrupt::guarded`][crate::Interrupt::guarded].
19    ///
20    /// Guarded implements Future, Stream, Clone, Debug, AsyncRead, and AsyncWrite when the wrapped
21    /// type also does.
22    ///
23    /// Guarded implements [`Deref`] and [`DerefMut`] to the wrapped type.
24    #[derive(Clone, Debug, PartialEq, Eq)]
25    pub struct Guarded<T> {
26        guard: Guard,
27        #[pin]
28        wrapped_type: T
29    }
30}
31
32impl<T> Guarded<T> {
33    pub(crate) fn new(inner: &Arc<Inner>, wrapped_type: T) -> Self {
34        Self {
35            guard: Guard::new(inner),
36            wrapped_type,
37        }
38    }
39
40    /// Transform this `Guarded<T>` into the inner `T`, dropping the [`Guard`] in the process.
41    ///
42    /// Doing this allows shutdown to proceed if no other guards exist and shutdown is initiated.
43    pub fn into_inner(self) -> T {
44        self.wrapped_type
45    }
46}
47
48impl<T: Future> Future for Guarded<T> {
49    type Output = T::Output;
50
51    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
52        self.project().wrapped_type.poll(cx)
53    }
54}
55
56impl<T: Stream> Stream for Guarded<T> {
57    type Item = T::Item;
58
59    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
60        self.project().wrapped_type.poll_next(cx)
61    }
62
63    fn size_hint(&self) -> (usize, Option<usize>) {
64        self.wrapped_type.size_hint()
65    }
66}
67
68impl<T> Deref for Guarded<T> {
69    type Target = T;
70
71    fn deref(&self) -> &Self::Target {
72        &self.wrapped_type
73    }
74}
75
76impl<T> DerefMut for Guarded<T> {
77    fn deref_mut(&mut self) -> &mut Self::Target {
78        &mut self.wrapped_type
79    }
80}
81
/// Reads are forwarded unchanged to the wrapped reader.
#[cfg(feature = "futures-io")]
impl<T> futures_io::AsyncRead for Guarded<T>
where
    T: futures_io::AsyncRead,
{
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut [u8],
    ) -> Poll<std::io::Result<usize>> {
        let this = self.project();
        futures_io::AsyncRead::poll_read(this.wrapped_type, cx, buf)
    }

    fn poll_read_vectored(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        bufs: &mut [std::io::IoSliceMut<'_>],
    ) -> Poll<std::io::Result<usize>> {
        let this = self.project();
        futures_io::AsyncRead::poll_read_vectored(this.wrapped_type, cx, bufs)
    }
}
100
/// Writes, flushes, and closes are forwarded unchanged to the wrapped writer.
#[cfg(feature = "futures-io")]
impl<T> futures_io::AsyncWrite for Guarded<T>
where
    T: futures_io::AsyncWrite,
{
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<std::io::Result<usize>> {
        let this = self.project();
        futures_io::AsyncWrite::poll_write(this.wrapped_type, cx, buf)
    }

    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
        let this = self.project();
        futures_io::AsyncWrite::poll_flush(this.wrapped_type, cx)
    }

    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
        let this = self.project();
        futures_io::AsyncWrite::poll_close(this.wrapped_type, cx)
    }

    fn poll_write_vectored(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        bufs: &[std::io::IoSlice<'_>],
    ) -> Poll<std::io::Result<usize>> {
        let this = self.project();
        futures_io::AsyncWrite::poll_write_vectored(this.wrapped_type, cx, bufs)
    }
}
127
/// Buffered reads are forwarded unchanged to the wrapped reader.
#[cfg(feature = "futures-io")]
impl<T> futures_io::AsyncBufRead for Guarded<T>
where
    T: futures_io::AsyncBufRead,
{
    fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<&[u8]>> {
        let this = self.project();
        futures_io::AsyncBufRead::poll_fill_buf(this.wrapped_type, cx)
    }

    fn consume(self: Pin<&mut Self>, amt: usize) {
        let this = self.project();
        futures_io::AsyncBufRead::consume(this.wrapped_type, amt);
    }
}
138
/// Reads are forwarded unchanged to the wrapped tokio reader.
#[cfg(feature = "tokio")]
impl<T> tokio::io::AsyncRead for Guarded<T>
where
    T: tokio::io::AsyncRead,
{
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut tokio::io::ReadBuf<'_>,
    ) -> Poll<std::io::Result<()>> {
        let this = self.project();
        tokio::io::AsyncRead::poll_read(this.wrapped_type, cx, buf)
    }
}
/// Writes, flushes, shutdowns, and the vectored-write capability query are all forwarded
/// unchanged to the wrapped tokio writer.
#[cfg(feature = "tokio")]
impl<T> tokio::io::AsyncWrite for Guarded<T>
where
    T: tokio::io::AsyncWrite,
{
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<Result<usize, std::io::Error>> {
        let this = self.project();
        tokio::io::AsyncWrite::poll_write(this.wrapped_type, cx, buf)
    }

    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
        let this = self.project();
        tokio::io::AsyncWrite::poll_flush(this.wrapped_type, cx)
    }

    fn poll_shutdown(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Result<(), std::io::Error>> {
        let this = self.project();
        tokio::io::AsyncWrite::poll_shutdown(this.wrapped_type, cx)
    }

    fn poll_write_vectored(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        bufs: &[std::io::IoSlice<'_>],
    ) -> Poll<Result<usize, std::io::Error>> {
        let this = self.project();
        tokio::io::AsyncWrite::poll_write_vectored(this.wrapped_type, cx, bufs)
    }

    fn is_write_vectored(&self) -> bool {
        tokio::io::AsyncWrite::is_write_vectored(&self.wrapped_type)
    }
}
182
/// Buffered reads are forwarded unchanged to the wrapped tokio reader.
#[cfg(feature = "tokio")]
impl<T> tokio::io::AsyncBufRead for Guarded<T>
where
    T: tokio::io::AsyncBufRead,
{
    fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<&[u8]>> {
        let this = self.project();
        tokio::io::AsyncBufRead::poll_fill_buf(this.wrapped_type, cx)
    }

    fn consume(self: Pin<&mut Self>, amt: usize) {
        let this = self.project();
        tokio::io::AsyncBufRead::consume(this.wrapped_type, amt);
    }
}