veilid_tools/must_join_single_future.rs
1use super::*;
2
3use core::task::Poll;
4use futures_util::poll;
5
6#[derive(Debug)]
7struct MustJoinSingleFutureInner<T>
8where
9 T: 'static,
10{
11 locked: bool,
12 join_handle: Option<MustJoinHandle<T>>,
13}
14
15/// Spawns a single background processing task idempotently, possibly returning the return value of the previously executed background task
16/// This does not queue, just ensures that no more than a single copy of the task is running at a time, but allowing tasks to be retriggered
17#[derive(Debug, Clone)]
18pub struct MustJoinSingleFuture<T>
19where
20 T: 'static,
21{
22 inner: Arc<Mutex<MustJoinSingleFutureInner<T>>>,
23}
24
25impl<T> Default for MustJoinSingleFuture<T>
26where
27 T: 'static,
28{
29 fn default() -> Self {
30 Self::new()
31 }
32}
33
34impl<T> MustJoinSingleFuture<T>
35where
36 T: 'static,
37{
38 pub fn new() -> Self {
39 Self {
40 inner: Arc::new(Mutex::new(MustJoinSingleFutureInner {
41 locked: false,
42 join_handle: None,
43 })),
44 }
45 }
46
47 fn try_lock(&self) -> Result<Option<MustJoinHandle<T>>, ()> {
48 let mut inner = self.inner.lock();
49 if inner.locked {
50 // If already locked error out
51 return Err(());
52 }
53 inner.locked = true;
54 // If we got the lock, return what we have for a join handle if anything
55 Ok(inner.join_handle.take())
56 }
57
58 fn unlock(&self, jh: Option<MustJoinHandle<T>>) {
59 let mut inner = self.inner.lock();
60 assert!(inner.locked);
61 assert!(inner.join_handle.is_none());
62 inner.locked = false;
63 inner.join_handle = jh;
64 }
65
66 /// Check the result and take it if there is one
67 #[cfg_attr(feature = "tracing", instrument(level = "trace", skip_all))]
68 pub async fn check(&self) -> Result<Option<T>, ()> {
69 let mut out: Option<T> = None;
70
71 // See if we have a result we can return
72 let maybe_jh = match self.try_lock() {
73 Ok(v) => v,
74 Err(_) => {
75 // If we are already polling somewhere else, don't hand back a result
76 return Err(());
77 }
78 };
79 if let Some(mut jh) = maybe_jh {
80 // See if we finished, if so, return the value of the last execution
81 if let Poll::Ready(r) = poll!(&mut jh) {
82 out = Some(r);
83 // Task finished, unlock with nothing
84 self.unlock(None);
85 } else {
86 // Still running put the join handle back so we can check on it later
87 self.unlock(Some(jh));
88 }
89 } else {
90 // No task, unlock with nothing
91 self.unlock(None);
92 }
93
94 // Return the prior result if we have one
95 Ok(out)
96 }
97
98 /// Wait for the result and take it
99 #[cfg_attr(feature = "tracing", instrument(level = "trace", skip_all))]
100 pub async fn join(&self) -> Result<Option<T>, ()> {
101 let mut out: Option<T> = None;
102
103 // See if we have a result we can return
104 let maybe_jh = match self.try_lock() {
105 Ok(v) => v,
106 Err(_) => {
107 // If we are already polling somewhere else,
108 // that's an error because you can only join
109 // these things once
110 return Err(());
111 }
112 };
113 if let Some(jh) = maybe_jh {
114 // Wait for return value of the last execution
115 out = Some(jh.await);
116 // Task finished, unlock with nothing
117 } else {
118 // No task, unlock with nothing
119 }
120 self.unlock(None);
121
122 // Return the prior result if we have one
123 Ok(out)
124 }
125
126 // Possibly spawn the future possibly returning the value of the last execution
127 pub async fn single_spawn_local(
128 &self,
129 name: &str,
130 future: impl Future<Output = T> + 'static,
131 ) -> Result<(Option<T>, bool), ()> {
132 let mut out: Option<T> = None;
133
134 // See if we have a result we can return
135 let maybe_jh = match self.try_lock() {
136 Ok(v) => v,
137 Err(_) => {
138 // If we are already polling somewhere else, don't hand back a result
139 return Err(());
140 }
141 };
142 let mut run = true;
143
144 if let Some(mut jh) = maybe_jh {
145 // See if we finished, if so, return the value of the last execution
146 if let Poll::Ready(r) = poll!(&mut jh) {
147 out = Some(r);
148 // Task finished, unlock with a new task
149 } else {
150 // Still running, don't run again, unlock with the current join handle
151 run = false;
152 self.unlock(Some(jh));
153 }
154 }
155
156 // Run if we should do that
157 if run {
158 self.unlock(Some(spawn_local(name, future)));
159 }
160
161 // Return the prior result if we have one
162 Ok((out, run))
163 }
164}
165
166impl<T> MustJoinSingleFuture<T>
167where
168 T: 'static + Send,
169{
170 pub async fn single_spawn(
171 &self,
172 name: &str,
173 future: impl Future<Output = T> + Send + 'static,
174 ) -> Result<(Option<T>, bool), ()> {
175 let mut out: Option<T> = None;
176 // See if we have a result we can return
177 let maybe_jh = match self.try_lock() {
178 Ok(v) => v,
179 Err(_) => {
180 // If we are already polling somewhere else, don't hand back a result
181 return Err(());
182 }
183 };
184 let mut run = true;
185 if let Some(mut jh) = maybe_jh {
186 // See if we finished, if so, return the value of the last execution
187 if let Poll::Ready(r) = poll!(&mut jh) {
188 out = Some(r);
189 // Task finished, unlock with a new task
190 } else {
191 // Still running, don't run again, unlock with the current join handle
192 run = false;
193 self.unlock(Some(jh));
194 }
195 }
196 // Run if we should do that
197 if run {
198 self.unlock(Some(spawn(name, future)));
199 }
200 // Return the prior result if we have one
201 Ok((out, run))
202 }
203}