Skip to main content

tycho_core/block_strider/subscriber/
futures.rs

1use std::future::Future;
2use std::pin::Pin;
3use std::sync::{Arc, Mutex};
4use std::task::{Context, Poll};
5
6use anyhow::Result;
7use futures_util::StreamExt;
8use futures_util::stream::FuturesUnordered;
9use tycho_util::futures::JoinTask;
10
11#[derive(Clone)]
12pub struct DelayedTasks {
13    inner: Arc<DelayedTasksInner>,
14}
15
16impl DelayedTasks {
17    pub fn new() -> (DelayedTasksSpawner, Self) {
18        let inner = Arc::new(DelayedTasksInner {
19            state: Mutex::new(DelayedTasksState::BeforeSpawn {
20                make_fns: Vec::new(),
21            }),
22        });
23        let handle = DelayedTasksSpawner {
24            inner: inner.clone(),
25        };
26        (handle, Self { inner })
27    }
28
29    pub fn spawn<F, Fut>(&self, f: F) -> Result<()>
30    where
31        F: FnOnce() -> Fut + Send + 'static,
32        Fut: Future<Output = Result<()>> + Send + 'static,
33    {
34        let mut inner = self.inner.state.lock().unwrap();
35        match &mut *inner {
36            DelayedTasksState::BeforeSpawn { make_fns } => {
37                make_fns.push(Box::new(move || JoinTask::new(f())));
38                Ok(())
39            }
40            DelayedTasksState::AfterSpawn { tasks } => {
41                tasks.push(JoinTask::new(f()));
42                Ok(())
43            }
44            DelayedTasksState::Closed => anyhow::bail!("delayed tasks context closed"),
45        }
46    }
47}
48
49pub struct DelayedTasksSpawner {
50    inner: Arc<DelayedTasksInner>,
51}
52
53impl DelayedTasksSpawner {
54    pub fn spawn(self) -> DelayedTasksJoinHandle {
55        {
56            let mut state = self.inner.state.lock().unwrap();
57            let make_fns = match &mut *state {
58                DelayedTasksState::BeforeSpawn { make_fns } => std::mem::take(make_fns),
59                DelayedTasksState::AfterSpawn { .. } | DelayedTasksState::Closed => {
60                    unreachable!("spawn can only be called once");
61                }
62            };
63            *state = DelayedTasksState::AfterSpawn {
64                tasks: make_fns.into_iter().map(|f| f()).collect(),
65            }
66        };
67
68        DelayedTasksJoinHandle { inner: self.inner }
69    }
70}
71
72pub struct DelayedTasksJoinHandle {
73    inner: Arc<DelayedTasksInner>,
74}
75
76impl DelayedTasksJoinHandle {
77    pub async fn join(self) -> Result<()> {
78        let mut tasks = {
79            let mut state = self.inner.state.lock().unwrap();
80            match std::mem::replace(&mut *state, DelayedTasksState::Closed) {
81                DelayedTasksState::AfterSpawn { tasks } => tasks,
82                DelayedTasksState::BeforeSpawn { .. } | DelayedTasksState::Closed => {
83                    unreachable!("join can only be called once");
84                }
85            }
86        };
87
88        while let Some(res) = tasks.next().await {
89            res?;
90        }
91        Ok(())
92    }
93}
94
95struct DelayedTasksInner {
96    state: Mutex<DelayedTasksState>,
97}
98
99enum DelayedTasksState {
100    BeforeSpawn {
101        make_fns: Vec<MakeTaskFn>,
102    },
103    AfterSpawn {
104        tasks: FuturesUnordered<JoinTask<Result<()>>>,
105    },
106    Closed,
107}
108
109type MakeTaskFn = Box<dyn FnOnce() -> JoinTask<Result<()>> + Send + 'static>;
110
111pin_project_lite::pin_project! {
112    pub struct OptionPrepareFut<F> {
113        #[pin]
114        inner: Option<F>,
115    }
116}
117
118impl<F> From<Option<F>> for OptionPrepareFut<F> {
119    #[inline]
120    fn from(inner: Option<F>) -> Self {
121        Self { inner }
122    }
123}
124
125impl<F, T, E> Future for OptionPrepareFut<F>
126where
127    F: Future<Output = Result<T, E>>,
128{
129    type Output = Result<Option<T>, E>;
130
131    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
132        match self.project().inner.as_pin_mut() {
133            Some(f) => match f.poll(cx) {
134                Poll::Ready(Ok(res)) => Poll::Ready(Ok(Some(res))),
135                Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
136                Poll::Pending => Poll::Pending,
137            },
138            None => Poll::Ready(Ok(None)),
139        }
140    }
141}
142
143pin_project_lite::pin_project! {
144    pub struct OptionHandleFut<F> {
145        #[pin]
146        inner: Option<F>,
147    }
148}
149
150impl<F> From<Option<F>> for OptionHandleFut<F> {
151    #[inline]
152    fn from(inner: Option<F>) -> Self {
153        Self { inner }
154    }
155}
156
157impl<F, T, E> Future for OptionHandleFut<F>
158where
159    F: Future<Output = Result<T, E>>,
160    T: Default,
161{
162    type Output = F::Output;
163
164    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
165        match self.project().inner.as_pin_mut() {
166            Some(f) => f.poll(cx),
167            None => Poll::Ready(Ok(T::default())),
168        }
169    }
170}
171
#[cfg(test)]
mod tests {
    use std::pin::pin;
    use std::sync::atomic::{AtomicUsize, Ordering};

    use futures_util::FutureExt;

    use super::*;

    /// Walks the context through its whole lifecycle: deferred registration,
    /// spawn, immediate registration, join, and rejection after close.
    #[tokio::test]
    async fn delayed_tasks() -> anyhow::Result<()> {
        let (spawner, tasks) = DelayedTasks::new();

        // `started` counts constructor invocations, `finished` counts
        // futures that ran to completion.
        let started = Arc::new(AtomicUsize::new(0));
        let finished = Arc::new(AtomicUsize::new(0));

        let make_task = || {
            let started = Arc::clone(&started);
            let finished = Arc::clone(&finished);
            move || {
                started.fetch_add(1, Ordering::SeqCst);
                async move {
                    tokio::task::yield_now().await;
                    finished.fetch_add(1, Ordering::SeqCst);
                    Ok::<(), anyhow::Error>(())
                }
            }
        };

        // Registered before spawn: the constructor must not run yet.
        tasks.spawn(make_task())?;
        assert_eq!(started.load(Ordering::SeqCst), 0);

        // Spawning invokes every stored constructor.
        let handle = spawner.spawn();
        assert_eq!(started.load(Ordering::SeqCst), 1);

        // Registered after spawn (through a clone): constructed immediately.
        tasks.clone().spawn(make_task())?;
        assert_eq!(started.load(Ordering::SeqCst), 2);

        // Joining waits for both tasks.
        handle.join().await?;
        assert_eq!(finished.load(Ordering::SeqCst), 2);

        // The context is closed now: registration fails and the constructor
        // is never invoked.
        assert!(tasks.spawn(make_task()).is_err());
        assert_eq!(started.load(Ordering::SeqCst), 2);

        Ok(())
    }

    /// A failing task makes `join` return an error.
    #[tokio::test]
    async fn delayed_tasks_propagate_errors() -> anyhow::Result<()> {
        let (spawner, tasks) = DelayedTasks::new();

        tasks.spawn(|| async { Err::<(), _>(anyhow::anyhow!("task failed")) })?;

        assert!(spawner.spawn().join().await.is_err());
        Ok(())
    }

    #[tokio::test]
    async fn option_futures() {
        type NoopFut = futures_util::future::Ready<Result<(), ()>>;

        // Prepare
        let resolved = OptionPrepareFut::from(None::<NoopFut>);
        assert_eq!(resolved.now_or_never().unwrap(), Ok(None));

        let mut resolved = pin!(OptionPrepareFut::from(Some(async {
            tokio::task::yield_now().await;
            Ok::<_, ()>(())
        })));
        assert_eq!(futures_util::poll!(&mut resolved), Poll::Pending);
        assert_eq!(
            futures_util::poll!(&mut resolved),
            Poll::Ready(Ok(Some(())))
        );

        // Handle
        let resolved = OptionHandleFut::from(None::<NoopFut>);
        assert_eq!(resolved.now_or_never().unwrap(), Ok(()));

        let mut resolved = pin!(OptionHandleFut::from(Some(async {
            tokio::task::yield_now().await;
            Ok::<_, ()>(())
        })));
        assert_eq!(futures_util::poll!(&mut resolved), Poll::Pending);
        assert_eq!(futures_util::poll!(&mut resolved), Poll::Ready(Ok(())));
    }
}