tycho_core/block_strider/subscriber/
futures.rs1use std::future::Future;
2use std::pin::Pin;
3use std::sync::{Arc, Mutex};
4use std::task::{Context, Poll};
5
6use anyhow::Result;
7use futures_util::StreamExt;
8use futures_util::stream::FuturesUnordered;
9use tycho_util::futures::JoinTask;
10
11#[derive(Clone)]
12pub struct DelayedTasks {
13 inner: Arc<DelayedTasksInner>,
14}
15
16impl DelayedTasks {
17 pub fn new() -> (DelayedTasksSpawner, Self) {
18 let inner = Arc::new(DelayedTasksInner {
19 state: Mutex::new(DelayedTasksState::BeforeSpawn {
20 make_fns: Vec::new(),
21 }),
22 });
23 let handle = DelayedTasksSpawner {
24 inner: inner.clone(),
25 };
26 (handle, Self { inner })
27 }
28
29 pub fn spawn<F, Fut>(&self, f: F) -> Result<()>
30 where
31 F: FnOnce() -> Fut + Send + 'static,
32 Fut: Future<Output = Result<()>> + Send + 'static,
33 {
34 let mut inner = self.inner.state.lock().unwrap();
35 match &mut *inner {
36 DelayedTasksState::BeforeSpawn { make_fns } => {
37 make_fns.push(Box::new(move || JoinTask::new(f())));
38 Ok(())
39 }
40 DelayedTasksState::AfterSpawn { tasks } => {
41 tasks.push(JoinTask::new(f()));
42 Ok(())
43 }
44 DelayedTasksState::Closed => anyhow::bail!("delayed tasks context closed"),
45 }
46 }
47}
48
49pub struct DelayedTasksSpawner {
50 inner: Arc<DelayedTasksInner>,
51}
52
53impl DelayedTasksSpawner {
54 pub fn spawn(self) -> DelayedTasksJoinHandle {
55 {
56 let mut state = self.inner.state.lock().unwrap();
57 let make_fns = match &mut *state {
58 DelayedTasksState::BeforeSpawn { make_fns } => std::mem::take(make_fns),
59 DelayedTasksState::AfterSpawn { .. } | DelayedTasksState::Closed => {
60 unreachable!("spawn can only be called once");
61 }
62 };
63 *state = DelayedTasksState::AfterSpawn {
64 tasks: make_fns.into_iter().map(|f| f()).collect(),
65 }
66 };
67
68 DelayedTasksJoinHandle { inner: self.inner }
69 }
70}
71
72pub struct DelayedTasksJoinHandle {
73 inner: Arc<DelayedTasksInner>,
74}
75
76impl DelayedTasksJoinHandle {
77 pub async fn join(self) -> Result<()> {
78 let mut tasks = {
79 let mut state = self.inner.state.lock().unwrap();
80 match std::mem::replace(&mut *state, DelayedTasksState::Closed) {
81 DelayedTasksState::AfterSpawn { tasks } => tasks,
82 DelayedTasksState::BeforeSpawn { .. } | DelayedTasksState::Closed => {
83 unreachable!("join can only be called once");
84 }
85 }
86 };
87
88 while let Some(res) = tasks.next().await {
89 res?;
90 }
91 Ok(())
92 }
93}
94
95struct DelayedTasksInner {
96 state: Mutex<DelayedTasksState>,
97}
98
99enum DelayedTasksState {
100 BeforeSpawn {
101 make_fns: Vec<MakeTaskFn>,
102 },
103 AfterSpawn {
104 tasks: FuturesUnordered<JoinTask<Result<()>>>,
105 },
106 Closed,
107}
108
109type MakeTaskFn = Box<dyn FnOnce() -> JoinTask<Result<()>> + Send + 'static>;
110
111pin_project_lite::pin_project! {
112 pub struct OptionPrepareFut<F> {
113 #[pin]
114 inner: Option<F>,
115 }
116}
117
118impl<F> From<Option<F>> for OptionPrepareFut<F> {
119 #[inline]
120 fn from(inner: Option<F>) -> Self {
121 Self { inner }
122 }
123}
124
125impl<F, T, E> Future for OptionPrepareFut<F>
126where
127 F: Future<Output = Result<T, E>>,
128{
129 type Output = Result<Option<T>, E>;
130
131 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
132 match self.project().inner.as_pin_mut() {
133 Some(f) => match f.poll(cx) {
134 Poll::Ready(Ok(res)) => Poll::Ready(Ok(Some(res))),
135 Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
136 Poll::Pending => Poll::Pending,
137 },
138 None => Poll::Ready(Ok(None)),
139 }
140 }
141}
142
143pin_project_lite::pin_project! {
144 pub struct OptionHandleFut<F> {
145 #[pin]
146 inner: Option<F>,
147 }
148}
149
150impl<F> From<Option<F>> for OptionHandleFut<F> {
151 #[inline]
152 fn from(inner: Option<F>) -> Self {
153 Self { inner }
154 }
155}
156
157impl<F, T, E> Future for OptionHandleFut<F>
158where
159 F: Future<Output = Result<T, E>>,
160 T: Default,
161{
162 type Output = F::Output;
163
164 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
165 match self.project().inner.as_pin_mut() {
166 Some(f) => f.poll(cx),
167 None => Poll::Ready(Ok(T::default())),
168 }
169 }
170}
171
#[cfg(test)]
mod tests {
    use std::pin::pin;
    use std::sync::atomic::{AtomicUsize, Ordering};

    use futures_util::FutureExt;

    use super::*;

    /// Exercises the delayed-tasks lifecycle:
    /// - deferred registration;
    /// - immediate registration after `spawn`;
    /// - rejection after `join`;
    /// - error propagation from `join`.
    #[tokio::test]
    async fn delayed_tasks() -> anyhow::Result<()> {
        let counter = Arc::new(AtomicUsize::new(0));
        // Builds a task factory that adds `delta` to the shared counter.
        let add = |delta: usize| {
            let counter = counter.clone();
            move || async move {
                counter.fetch_add(delta, Ordering::Relaxed);
                Ok::<_, anyhow::Error>(())
            }
        };

        let (spawner, delayed) = DelayedTasks::new();

        // Registered before `spawn`: only the factory is stored, nothing runs yet.
        delayed.spawn(add(1))?;
        assert_eq!(counter.load(Ordering::Relaxed), 0);

        let handle = spawner.spawn();

        // Registered after `spawn`: created immediately and still awaited by `join`.
        delayed.spawn(add(10))?;

        handle.join().await?;
        assert_eq!(counter.load(Ordering::Relaxed), 11);

        // `join` closes the context; later registrations fail and never run.
        assert!(delayed.spawn(add(100)).is_err());
        assert_eq!(counter.load(Ordering::Relaxed), 11);

        // A task error is propagated by `join`.
        let (spawner, delayed) = DelayedTasks::new();
        delayed.spawn(|| async { Err::<(), _>(anyhow::anyhow!("task failed")) })?;
        assert!(spawner.spawn().join().await.is_err());

        Ok(())
    }

    #[tokio::test]
    async fn option_futures() {
        type NoopFut = futures_util::future::Ready<Result<(), ()>>;

        let resolved = OptionPrepareFut::from(None::<NoopFut>);
        assert_eq!(resolved.now_or_never().unwrap(), Ok(None));

        let mut resolved = pin!(OptionPrepareFut::from(Some(async {
            tokio::task::yield_now().await;
            Ok::<_, ()>(())
        })));
        assert_eq!(futures_util::poll!(&mut resolved), Poll::Pending);
        assert_eq!(
            futures_util::poll!(&mut resolved),
            Poll::Ready(Ok(Some(())))
        );

        let resolved = OptionHandleFut::from(None::<NoopFut>);
        assert_eq!(resolved.now_or_never().unwrap(), Ok(()));

        let mut resolved = pin!(OptionHandleFut::from(Some(async {
            tokio::task::yield_now().await;
            Ok::<_, ()>(())
        })));
        assert_eq!(futures_util::poll!(&mut resolved), Poll::Pending);
        assert_eq!(futures_util::poll!(&mut resolved), Poll::Ready(Ok(())));
    }
}