yield_return/
iter.rs

1use std::{
2    future::Future,
3    iter::FusedIterator,
4    ops::{Deref, DerefMut},
5    pin::Pin,
6    sync::{Arc, Mutex},
7    task::{Context, Poll},
8};
9
10use futures_core::{FusedStream, Stream};
11
12use crate::utils::noop_waker;
13
/// Producer half of the one-slot hand-off between the async function and the
/// consumer that drives it.
///
/// The shared slot is `Some` while a yielded value is waiting to be taken.
struct Sender<T>(Arc<Mutex<Option<T>>>);

impl<T> Sender<T> {
    /// Stores `value` in the shared slot so the consumer can take it.
    ///
    /// # Panics
    ///
    /// Panics if the slot is still occupied, i.e. the previously stored value has
    /// not been taken yet (typically because the future returned by `ret` was not
    /// awaited).
    #[track_caller]
    fn set(&self, value: T) {
        let mut slot = self.0.lock().unwrap();
        if slot.is_some() {
            panic!("The result of `ret` is not await.");
        }
        *slot = Some(value);
    }
}

/// Awaiting a `Sender` suspends until the consumer has emptied the slot.
impl<T> Future for Sender<T> {
    type Output = ();

    fn poll(self: Pin<&mut Self>, _cx: &mut Context) -> Poll<Self::Output> {
        // No waker is registered: after taking the value, the consumer polls the
        // whole future again, which is what lets this resolve.
        match *self.0.lock().unwrap() {
            Some(_) => Poll::Pending,
            None => Poll::Ready(()),
        }
    }
}
35
36/// `Send` iterator context.
37///
38/// This type implements `Send`.
39pub struct IterContext<T>(Sender<T>);
40
41impl<T> IterContext<T>
42where
43    T: Send,
44{
45    /// Yields a single value. Similar to C#'s `yield return` or Python's `yield`.
46    #[track_caller]
47    pub fn ret(&mut self, value: T) -> impl Future<Output = ()> + Send + Sync + '_ {
48        self.0.set(value);
49        &mut self.0
50    }
51
52    /// Yields all values from an iterator. Similar to Python's `yield from` or JavaScript's `yield*`.
53    pub async fn ret_iter(&mut self, iter: impl IntoIterator<Item = T> + Send + Sync) {
54        for value in iter {
55            self.ret(value).await;
56        }
57    }
58}
59
/// State shared by `Iter` and `AsyncIter`: the slot a yielded value travels
/// through, plus the boxed future of the user's async function.
struct Data<'a, T> {
    // Slot filled by the producer (`Sender::set`) and emptied by `poll_next`.
    value: Arc<Mutex<Option<T>>>,
    // The async function body; set to `None` once it has run to completion.
    fut: Option<Pin<Box<dyn Future<Output = ()> + Send + Sync + 'a>>>,
}

impl<T> Data<'_, T> {
    /// Advances the async function by one poll.
    ///
    /// Returns `Ready(Some(v))` when the function is parked with `v` in the slot,
    /// `Ready(None)` once it has finished (and on every later call), and `Pending`
    /// when it is waiting on something without having yielded a value.
    ///
    /// # Panics
    ///
    /// Panics if the function finishes while a yielded value is still in the slot,
    /// which happens when the future returned by `ret` was not awaited.
    fn poll_next(&mut self, cx: &mut Context) -> Poll<Option<T>> {
        let fut = match self.fut.as_mut() {
            Some(fut) => fut,
            None => return Poll::Ready(None),
        };
        if fut.as_mut().poll(cx).is_ready() {
            assert!(
                self.value.lock().unwrap().is_none(),
                "The result of `ret` is not await."
            );
            self.fut = None;
            return Poll::Ready(None);
        }
        // Still pending: either parked on `ret` (slot filled) or on another future.
        self.value
            .lock()
            .unwrap()
            .take()
            .map_or(Poll::Pending, |value| Poll::Ready(Some(value)))
    }
}
89
90/// `Send` iterator implemented using async functions.
91///
92/// This type implements `Send`.
93pub struct Iter<'a, T>(Data<'a, T>);
94
95impl<'a, T: 'a + Send> Iter<'a, T> {
96    /// Create an iterator from an asynchronous function.
97    ///
98    /// # Example
99    ///
100    /// ```
101    /// use yield_return::Yield;
102    /// let iter = Yield::new(|mut y| async move {
103    ///     y.ret(1).await;
104    ///     y.ret(2).await;
105    /// });
106    /// let list: Vec<_> = iter.collect();
107    /// assert_eq!(list, vec![1, 2]);
108    /// ```
109    pub fn new<Fut: Future<Output = ()> + Send + Sync + 'a>(
110        f: impl FnOnce(IterContext<T>) -> Fut,
111    ) -> Self {
112        let value = Arc::new(Mutex::new(None));
113        let cx = IterContext(Sender(value.clone()));
114        let fut: Pin<Box<dyn Future<Output = ()> + Send + Sync + 'a>> = Box::pin(f(cx));
115        let fut = Some(fut);
116        Self(Data { value, fut })
117    }
118}
119
120impl<T> Iterator for Iter<'_, T> {
121    type Item = T;
122    #[track_caller]
123    fn next(&mut self) -> Option<Self::Item> {
124        match self.0.poll_next(&mut Context::from_waker(&noop_waker())) {
125            Poll::Ready(value) => value,
126            Poll::Pending => panic!("`YieldContext::ret` is not called."),
127        }
128    }
129}
130impl<T> FusedIterator for Iter<'_, T> {}
131
132/// `Send` stream context.
133///
134/// This type implements `Send`.
135pub struct AsyncIterContext<T>(IterContext<T>);
136impl<T> Deref for AsyncIterContext<T> {
137    type Target = IterContext<T>;
138    fn deref(&self) -> &Self::Target {
139        &self.0
140    }
141}
142impl<T> DerefMut for AsyncIterContext<T> {
143    fn deref_mut(&mut self) -> &mut Self::Target {
144        &mut self.0
145    }
146}
147
148/// `Send` stream implemented using async functions.
149///
150/// This type implements `Send`.
151pub struct AsyncIter<'a, T>(Iter<'a, T>);
152
153impl<'a, T: Send + 'a> AsyncIter<'a, T> {
154    /// Create a stream from an asynchronous function.
155    ///
156    /// # Example
157    /// ```
158    /// use yield_return::AsyncIter;
159    /// # futures::executor::block_on(async {
160    /// let iter = AsyncIter::new(|mut y| async move {
161    ///     y.ret(1).await;
162    ///     y.ret(2).await;
163    /// });
164    /// let list: Vec<_> = futures::StreamExt::collect(iter).await;
165    /// assert_eq!(list, vec![1, 2]);
166    /// # });
167    /// ```
168    pub fn new<Fut: Future<Output = ()> + Send + Sync + 'a>(
169        f: impl FnOnce(AsyncIterContext<T>) -> Fut + Send + Sync,
170    ) -> Self {
171        Self(Iter::new(|cx| f(AsyncIterContext(cx))))
172    }
173}
174
175impl<T> Stream for AsyncIter<'_, T> {
176    type Item = T;
177
178    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
179        self.0 .0.poll_next(cx)
180    }
181}
182impl<T> FusedStream for AsyncIter<'_, T> {
183    fn is_terminated(&self) -> bool {
184        self.0 .0.fut.is_none()
185    }
186}