1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143
use std::fmt;
use std::future::Future;
use std::pin::Pin;
use std::sync::{Arc, Mutex, PoisonError};

use futures::channel::oneshot;
use futures::future::{Fuse, FusedFuture, FutureExt};
use futures::ready;
use futures::task::{Context, Poll};
use pin_project_lite::pin_project;

/// Why a [`StopWait`] future resolved.
///
/// The derived `Debug`/`Clone` impls carry the same `T: Debug` / `T: Clone`
/// bounds the previous hand-written impls had.
#[derive(Debug, Clone)]
pub enum StopReason<T> {
    /// The shared one-shot sender was dropped without a reason being sent,
    /// i.e. every [`StopHandle`] clone went away without `stop` being called.
    HandleLost,
    /// [`StopHandle::stop`] was called with this reason.
    Requested(T),
}

impl<T> fmt::Display for StopReason<T>
where
    T: fmt::Display,
{
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            StopReason::HandleLost => write!(f, "handle lost"),
            StopReason::Requested(r) => write!(f, "requested with reason `{}`", r),
        }
    }
}

/// Cloneable handle used to make the paired [`StopWait`] resolve.
///
/// All clones share one one-shot sender, so only the first call to
/// [`stop`](StopHandle::stop) across all clones delivers its reason; later
/// calls find the sender already taken and do nothing.
pub struct StopHandle<T> {
    inner: Arc<Mutex<Option<oneshot::Sender<T>>>>,
}

impl<T> fmt::Debug for StopHandle<T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "StopHandle")
    }
}

// Implemented by hand rather than derived so that cloning the handle does not
// require `T: Clone`; only the `Arc` is cloned.
impl<T> Clone for StopHandle<T> {
    fn clone(&self) -> Self {
        StopHandle {
            inner: Arc::clone(&self.inner),
        }
    }
}

impl<T> StopHandle<T> {
    /// Requests a stop with `reason`.
    ///
    /// Only the first call across all clones of this handle has any effect;
    /// subsequent calls drop `reason`. If the [`StopWait`] has already been
    /// dropped, the reason is discarded as well.
    pub fn stop(&self, reason: T) {
        // Take the sender in its own statement so the guard is released
        // before `send` runs: sending may wake the waiting task and, if the
        // receiver is gone, drops `reason` — neither needs the lock held.
        // The guarded `Option` can never be left half-updated, so recovering
        // from a poisoned mutex is safe and avoids panicking on every call.
        let sender = self
            .inner
            .lock()
            .unwrap_or_else(PoisonError::into_inner)
            .take();
        if let Some(tx) = sender {
            // `Err` only means the receiver was dropped; nobody to notify.
            let _ = tx.send(reason);
        }
    }
}

pin_project! {
    /// Future that resolves with a [`StopReason`] once the paired
    /// [`StopHandle`] requests a stop or all handles are dropped.
    ///
    /// The receiver is wrapped in `Fuse`, so after completion further polls
    /// return `Poll::Pending` and [`FusedFuture::is_terminated`] is `true`.
    pub struct StopWait<T> {
        #[pin]
        inner: Fuse<oneshot::Receiver<T>>,
    }
}

impl<T> FusedFuture for StopWait<T> {
    fn is_terminated(&self) -> bool {
        self.inner.is_terminated()
    }
}

impl<T> Future for StopWait<T> {
    type Output = StopReason<T>;

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // The oneshot receiver yields `Err(Canceled)` when the sender is
        // dropped without sending, which maps to `HandleLost`.
        let reason = match ready!(Future::poll(self.project().inner, cx)) {
            Ok(reason) => StopReason::Requested(reason),
            Err(_) => StopReason::HandleLost,
        };
        Poll::Ready(reason)
    }
}

/// Creates a connected [`StopHandle`] / [`StopWait`] pair backed by a
/// one-shot channel.
pub fn stop_handle<T>() -> (StopHandle<T>, StopWait<T>) {
    let (tx, rx) = oneshot::channel();
    let stop_handle = StopHandle {
        inner: Arc::new(Mutex::new(Some(tx))),
    };
    let stop_wait = StopWait { inner: rx.fuse() };
    (stop_handle, stop_wait)
}

#[cfg(test)]
mod tests {
    use std::time::Duration;

    use matches::assert_matches;
    use tokio::time::delay_for;

    use super::*;

    #[derive(Debug)]
    pub enum TerminationReason {
        Manual,
    }

    #[tokio::test]
    async fn test_stop_handle() {
        let (stop_handle, stop_wait) = stop_handle();
        tokio::spawn(async move {
            // Brief delay so the stop usually arrives while the waiter is pending.
            delay_for(Duration::from_millis(10)).await;
            stop_handle.stop(TerminationReason::Manual);
        });
        let res = stop_wait.await;
        assert_matches!(res, StopReason::Requested(TerminationReason::Manual));
    }

    #[tokio::test]
    async fn test_handle_lost() {
        let (stop_handle, stop_wait) = stop_handle::<TerminationReason>();
        drop(stop_handle);
        assert_matches!(stop_wait.await, StopReason::HandleLost);
    }

    #[tokio::test]
    async fn test_only_first_stop_is_delivered() {
        let (stop_handle, stop_wait) = stop_handle();
        let other = stop_handle.clone();
        stop_handle.stop(1u32);
        other.stop(2u32);
        assert_matches!(stop_wait.await, StopReason::Requested(1));
    }
}