1use std::sync::Arc;
5
6use futures::stream::BoxStream;
7use futures::{Stream, StreamExt};
8use smol::block_on;
9
10use crate::runtime::{BlockingRuntime, Executor, Handle};
11
12#[derive(Clone, Default)]
16pub struct CurrentThreadRuntime {
17 executor: Arc<smol::Executor<'static>>,
18}
19
20impl BlockingRuntime for CurrentThreadRuntime {
21 type BlockingIterator<'a, R: 'a> = CurrentThreadIterator<'a, R>;
22
23 fn handle(&self) -> Handle {
24 let executor: Arc<dyn Executor> = self.executor.clone();
25 Handle::new(Arc::downgrade(&executor))
26 }
27
28 fn block_on<F, Fut, R>(&self, f: F) -> R
29 where
30 F: FnOnce(Handle) -> Fut,
31 Fut: Future<Output = R>,
32 {
33 block_on(self.executor.run(f(self.handle())))
34 }
35
36 fn block_on_stream<'a, F, S, R>(&self, f: F) -> Self::BlockingIterator<'a, R>
37 where
38 F: FnOnce(Handle) -> S,
39 S: Stream<Item = R> + Send + 'a,
40 R: Send + 'a,
41 {
42 CurrentThreadIterator {
43 executor: self.executor.clone(),
44 stream: f(self.handle()).boxed(),
45 }
46 }
47}
48
49impl CurrentThreadRuntime {
50 pub fn new() -> Self {
52 Self::default()
53 }
54
55 pub fn block_on_stream_thread_safe<F, S, R>(&self, f: F) -> ThreadSafeIterator<R>
63 where
64 F: FnOnce(Handle) -> S,
65 S: Stream<Item = R> + Send + 'static,
66 R: Send + 'static,
67 {
68 let stream = f(self.handle());
69
70 let (result_tx, result_rx) = kanal::bounded_async(1);
74 self.executor
75 .spawn(async move {
76 futures::pin_mut!(stream);
77 while let Some(item) = stream.next().await {
78 if let Err(e) = result_tx.send(item).await {
80 log::trace!("all receivers dropped, stopping stream: {}", e);
81 break;
82 }
83 }
84 })
85 .detach();
86
87 ThreadSafeIterator {
88 executor: self.executor.clone(),
89 results: result_rx,
90 }
91 }
92}
93
94pub struct CurrentThreadIterator<'a, T> {
96 executor: Arc<smol::Executor<'static>>,
97 stream: BoxStream<'a, T>,
98}
99
100impl<T> Iterator for CurrentThreadIterator<'_, T> {
101 type Item = T;
102
103 fn next(&mut self) -> Option<Self::Item> {
104 block_on(self.executor.run(self.stream.next()))
105 }
106}
107
108pub struct ThreadSafeIterator<T> {
110 executor: Arc<smol::Executor<'static>>,
111 results: kanal::AsyncReceiver<T>,
112}
113
114impl<T> Clone for ThreadSafeIterator<T> {
116 fn clone(&self) -> Self {
117 Self {
118 executor: self.executor.clone(),
119 results: self.results.clone(),
120 }
121 }
122}
123
124impl<T> Iterator for ThreadSafeIterator<T> {
125 type Item = T;
126
127 fn next(&mut self) -> Option<Self::Item> {
128 block_on(self.executor.run(self.results.recv())).ok()
129 }
130}
131
#[cfg(test)]
// Several `stream::unfold` bodies below `.await` inside their `Some` branch. That cannot be
// moved into a `bool::then` closure, so this restriction lint is silenced for the module.
#[allow(clippy::if_then_some_else_none)]
mod tests {
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::{Arc, Barrier};
    use std::thread;
    use std::time::Duration;

    use futures::{StreamExt, stream};
    use parking_lot::Mutex;

    use super::*;

    #[test]
    fn test_block_on_stream_single_thread() {
        let mut iter = CurrentThreadRuntime::new()
            .block_on_stream(|_h| stream::iter(vec![1, 2, 3, 4, 5]).boxed());

        assert_eq!(iter.next(), Some(1));
        assert_eq!(iter.next(), Some(2));
        assert_eq!(iter.next(), Some(3));
        assert_eq!(iter.next(), Some(4));
        assert_eq!(iter.next(), Some(5));
        assert_eq!(iter.next(), None);
    }

    #[test]
    fn test_block_on_stream_multiple_threads() {
        let counter = Arc::new(AtomicUsize::new(0));
        let num_threads = 4;
        let items_per_thread = 25;
        // Every thread blocks until it has received `items_per_thread` items. If the stream
        // held fewer items than that in total, some thread would wait forever, so the total is
        // derived from the other two values instead of being hard-coded.
        let total_items = num_threads * items_per_thread;

        let iter = CurrentThreadRuntime::new()
            .block_on_stream_thread_safe(|_h| stream::iter(0..total_items).boxed());

        let barrier = Arc::new(Barrier::new(num_threads));
        let results = Arc::new(Mutex::new(Vec::new()));

        let threads: Vec<_> = (0..num_threads)
            .map(|_| {
                let mut iter = iter.clone();
                let counter = counter.clone();
                let barrier = barrier.clone();
                let results = results.clone();

                thread::spawn(move || {
                    barrier.wait();
                    let mut local_results = Vec::new();

                    for _ in 0..items_per_thread {
                        if let Some(item) = iter.next() {
                            counter.fetch_add(1, Ordering::SeqCst);
                            local_results.push(item);
                        }
                    }

                    results.lock().push(local_results);
                })
            })
            .collect();

        for thread in threads {
            thread.join().unwrap();
        }

        assert_eq!(counter.load(Ordering::SeqCst), total_items);

        let all_results = results.lock();
        let mut collected: Vec<_> = all_results.iter().flatten().copied().collect();
        collected.sort();
        assert_eq!(collected, (0..total_items).collect::<Vec<_>>());
    }

    #[test]
    fn test_block_on_stream_concurrent_clone_and_drive() {
        let num_items = 50;
        let num_threads = 3;

        let iter = CurrentThreadRuntime::new().block_on_stream_thread_safe(|h| {
            stream::unfold(0, move |state| {
                let h = h.clone();
                async move {
                    if state < num_items {
                        h.spawn_cpu(move || {
                            thread::sleep(Duration::from_micros(10));
                            state
                        })
                        .await;
                        Some((state, state + 1))
                    } else {
                        None
                    }
                }
            })
        });

        let collected = Arc::new(Mutex::new(Vec::new()));
        let barrier = Arc::new(Barrier::new(num_threads));

        let threads: Vec<_> = (0..num_threads)
            .map(|thread_id| {
                let iter = iter.clone();
                let collected = collected.clone();
                let barrier = barrier.clone();

                thread::spawn(move || {
                    barrier.wait();
                    let mut local_items = Vec::new();

                    for item in iter {
                        local_items.push((thread_id, item));
                        if local_items.len() >= 5 {
                            break;
                        }
                    }

                    collected.lock().extend(local_items);
                })
            })
            .collect();

        for thread in threads {
            thread.join().unwrap();
        }

        let results = collected.lock();
        let mut values: Vec<_> = results.iter().map(|(_, v)| *v).collect();
        values.sort();
        values.dedup();

        assert!(values.len() >= 5);
        assert!(values.iter().all(|&v| v < num_items));
    }

    #[test]
    fn test_block_on_stream_async_work() {
        let iter = CurrentThreadRuntime::new().block_on_stream(|h| {
            stream::unfold((h, 0), |(h, state)| async move {
                if state < 10 {
                    let value = h
                        .spawn(async move { futures::future::ready(state * 2).await })
                        .await;
                    Some((value, (h, state + 1)))
                } else {
                    None
                }
            })
        });

        let results: Vec<_> = iter.collect();
        assert_eq!(results, vec![0, 2, 4, 6, 8, 10, 12, 14, 16, 18]);
    }

    // Exercises `block_on_stream`, which polls lazily: after three `next` calls the stream has
    // produced only three items. The channel-based receiver path is covered by
    // `test_block_on_stream_thread_safe_drop_receivers_early` below.
    #[test]
    fn test_block_on_stream_drop_receivers_early() {
        let counter = Arc::new(AtomicUsize::new(0));
        let c = counter.clone();

        let mut iter = CurrentThreadRuntime::new().block_on_stream(|_h| {
            stream::unfold(0, move |state| {
                let c = c.clone();
                async move {
                    (state < 100).then(|| {
                        c.fetch_add(1, Ordering::SeqCst);
                        (state, state + 1)
                    })
                }
            })
            .boxed()
        });

        assert_eq!(iter.next(), Some(0));
        assert_eq!(iter.next(), Some(1));
        assert_eq!(iter.next(), Some(2));

        drop(iter);

        let final_count = counter.load(Ordering::SeqCst);
        assert!(
            final_count < 100,
            "Stream should stop when all receivers are dropped"
        );
    }

    // Dropping the only `ThreadSafeIterator` drops the channel receiver. After that, the
    // producer task must not keep draining the stream to completion.
    #[test]
    fn test_block_on_stream_thread_safe_drop_receivers_early() {
        let counter = Arc::new(AtomicUsize::new(0));
        let c = counter.clone();

        let mut iter = CurrentThreadRuntime::new().block_on_stream_thread_safe(|_h| {
            stream::unfold(0, move |state| {
                let c = c.clone();
                async move {
                    (state < 100).then(|| {
                        c.fetch_add(1, Ordering::SeqCst);
                        (state, state + 1)
                    })
                }
            })
            .boxed()
        });

        assert_eq!(iter.next(), Some(0));
        assert_eq!(iter.next(), Some(1));
        assert_eq!(iter.next(), Some(2));

        drop(iter);

        let final_count = counter.load(Ordering::SeqCst);
        assert!(
            final_count < 100,
            "Stream should stop when all receivers are dropped"
        );
    }

    #[test]
    fn test_block_on_stream_interleaved_access() {
        let barrier = Arc::new(Barrier::new(2));
        let iter = CurrentThreadRuntime::new()
            .block_on_stream_thread_safe(|_h| stream::iter(0..20).boxed());

        let iter1 = iter.clone();
        let iter2 = iter;
        let barrier1 = barrier.clone();
        let barrier2 = barrier;

        let thread1 = thread::spawn(move || {
            let mut iter = iter1;
            let mut results = Vec::new();
            barrier1.wait();

            for _ in 0..5 {
                if let Some(val) = iter.next() {
                    results.push(val);
                    thread::sleep(Duration::from_micros(50));
                }
            }
            results
        });

        let thread2 = thread::spawn(move || {
            let mut iter = iter2;
            let mut results = Vec::new();
            barrier2.wait();

            for _ in 0..5 {
                if let Some(val) = iter.next() {
                    results.push(val);
                    thread::sleep(Duration::from_micros(50));
                }
            }
            results
        });

        let results1 = thread1.join().unwrap();
        let results2 = thread2.join().unwrap();

        let mut all_results = results1;
        all_results.extend(results2);
        all_results.sort();

        assert_eq!(all_results, (0..10).collect::<Vec<_>>());

        for i in 0..10 {
            assert_eq!(all_results.iter().filter(|&&x| x == i).count(), 1);
        }
    }

    #[test]
    fn test_block_on_stream_stress_test() {
        let num_threads = 10;
        let num_items = 1000;

        let iter = CurrentThreadRuntime::new()
            .block_on_stream_thread_safe(|_h| stream::iter(0..num_items).boxed());

        let received = Arc::new(Mutex::new(Vec::new()));
        let barrier = Arc::new(Barrier::new(num_threads));

        let threads: Vec<_> = (0..num_threads)
            .map(|_| {
                let iter = iter.clone();
                let received = received.clone();
                let barrier = barrier.clone();

                thread::spawn(move || {
                    barrier.wait();
                    for val in iter {
                        received.lock().push(val);
                    }
                })
            })
            .collect();

        for thread in threads {
            thread.join().unwrap();
        }

        let mut results = received.lock().clone();
        results.sort();

        assert_eq!(results.len(), num_items);
        assert_eq!(results, (0..num_items).collect::<Vec<_>>());
    }
}