use std::future::Future;
use std::sync::Arc;

use futures::stream::BoxStream;
use futures::{Stream, StreamExt};
use smol::block_on;

pub use crate::runtime::pool::CurrentThreadWorkerPool;
use crate::runtime::{BlockingRuntime, Executor, Handle};
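/// A single-threaded async runtime built around a shared `smol::Executor`.
/// Spawned tasks make progress only while some thread is blocking on the
/// executor, either through the `block_on*` methods or through workers
/// created via [`CurrentThreadRuntime::new_pool`].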
#[derive(Clone, Default)]
pub struct CurrentThreadRuntime {
    executor: Arc<smol::Executor<'static>>,
}

impl CurrentThreadRuntime {
    pub fn new() -> Self {
        Self::default()
    }

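    /// Creates a worker pool sharing this runtime's executor; adding workers
    /// to the pool lets spawned tasks run in the background.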
    pub fn new_pool(&self) -> CurrentThreadWorkerPool {
        CurrentThreadWorkerPool::new(self.executor.clone())
    }

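    /// Builds the stream returned by `f` and drives it on a detached executor
    /// task, yielding its items through a cloneable [`ThreadSafeIterator`]
    /// that can be consumed from multiple threads.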
    pub fn block_on_stream_thread_safe<F, S, R>(&self, f: F) -> ThreadSafeIterator<R>
    where
        F: FnOnce(Handle) -> S,
        S: Stream<Item = R> + Send + 'static,
        R: Send + 'static,
    {
        let stream = f(self.handle());

        let (result_tx, result_rx) = kanal::bounded_async(1);
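        // Drive the stream on a detached task, forwarding items into the
        // bounded channel; with capacity 1 the producer is only polled about
        // as fast as consumers call `next`.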
        self.executor
            .spawn(async move {
                futures::pin_mut!(stream);
                while let Some(item) = stream.next().await {
                    if let Err(e) = result_tx.send(item).await {
                        // Sending fails only once every receiver clone is gone.
                        log::trace!("all receivers dropped, stopping stream: {}", e);
                        break;
                    }
                }
            })
            .detach();

        ThreadSafeIterator {
            executor: self.executor.clone(),
            results: result_rx,
        }
    }
}

impl BlockingRuntime for CurrentThreadRuntime {
    type BlockingIterator<'a, R: 'a> = CurrentThreadIterator<'a, R>;

    fn handle(&self) -> Handle {
        let executor: Arc<dyn Executor> = self.executor.clone();
        Handle::new(Arc::downgrade(&executor))
    }

    fn block_on<Fut, R>(&self, fut: Fut) -> R
    where
        Fut: Future<Output = R>,
    {
        block_on(self.executor.run(fut))
    }

    fn block_on_stream<'a, S, R>(&self, stream: S) -> Self::BlockingIterator<'a, R>
    where
        S: Stream<Item = R> + Send + 'a,
        R: Send + 'a,
    {
        CurrentThreadIterator {
            executor: self.executor.clone(),
            stream: stream.boxed(),
        }
    }
}

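/// Blocking iterator returned by [`BlockingRuntime::block_on_stream`]. Each
/// call to `next` runs the executor on the current thread until the stream
/// yields an item or finishes.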
pub struct CurrentThreadIterator<'a, T> {
    executor: Arc<smol::Executor<'static>>,
    stream: BoxStream<'a, T>,
}

impl<T> Iterator for CurrentThreadIterator<'_, T> {
    type Item = T;

    fn next(&mut self) -> Option<Self::Item> {
        block_on(self.executor.run(self.stream.next()))
    }
}

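/// Thread-safe, cloneable blocking iterator returned by
/// [`CurrentThreadRuntime::block_on_stream_thread_safe`]. Clones share the
/// same channel, so several threads can consume one stream concurrently and
/// each item is delivered to exactly one of them.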
pub struct ThreadSafeIterator<T> {
    executor: Arc<smol::Executor<'static>>,
    results: kanal::AsyncReceiver<T>,
}

impl<T> Clone for ThreadSafeIterator<T> {
    fn clone(&self) -> Self {
        Self {
            executor: self.executor.clone(),
            results: self.results.clone(),
        }
    }
}

impl<T> Iterator for ThreadSafeIterator<T> {
    type Item = T;

    fn next(&mut self) -> Option<Self::Item> {
        // Running the executor here lets the current thread drive the
        // producing task while it waits. `recv` errors once the sender has
        // finished and the channel is closed, which maps to end of iteration.
        block_on(self.executor.run(self.results.recv())).ok()
    }
}

#[allow(clippy::if_then_some_else_none)]
#[cfg(test)]
mod tests {
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::{Arc, Barrier};
    use std::thread;
    use std::time::Duration;

    use futures::{StreamExt, stream};
    use parking_lot::Mutex;

    use super::*;

    #[test]
    fn test_worker_thread() {
        let runtime = CurrentThreadRuntime::new();

        let value = Arc::new(AtomicUsize::new(0));
        // Spawn a task; nothing drives the executor yet, so it cannot run.
        let value2 = value.clone();
        runtime
            .handle()
            .spawn(async move {
                value2.store(42, Ordering::SeqCst);
            })
            .detach();

        assert_eq!(value.load(Ordering::SeqCst), 0);

        // Creating a pool alone does not start any workers.
        let pool = runtime.new_pool();
        assert_eq!(value.load(Ordering::SeqCst), 0);

        pool.set_workers(1);
        // With a worker running the executor, the spawned task should complete.
        for _ in 0..10 {
            if value.load(Ordering::SeqCst) == 42 {
                break;
            }
            thread::sleep(Duration::from_millis(10));
        }
        assert_eq!(value.load(Ordering::SeqCst), 42);
    }

    #[test]
    fn test_block_on_stream_single_thread() {
        let mut iter =
            CurrentThreadRuntime::new().block_on_stream(stream::iter(vec![1, 2, 3, 4, 5]).boxed());

        assert_eq!(iter.next(), Some(1));
        assert_eq!(iter.next(), Some(2));
        assert_eq!(iter.next(), Some(3));
        assert_eq!(iter.next(), Some(4));
        assert_eq!(iter.next(), Some(5));
        assert_eq!(iter.next(), None);
    }

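    // Four threads share clones of one thread-safe iterator; every item must
    // be observed exactly once across all of them.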
    #[test]
    fn test_block_on_stream_multiple_threads() {
        let counter = Arc::new(AtomicUsize::new(0));
        let num_threads = 4;
        let items_per_thread = 25;
        let total_items = 100;

        let iter = CurrentThreadRuntime::new()
            .block_on_stream_thread_safe(|_h| stream::iter(0..total_items).boxed());

        let barrier = Arc::new(Barrier::new(num_threads));
        let results = Arc::new(Mutex::new(Vec::new()));

        let threads: Vec<_> = (0..num_threads)
            .map(|_| {
                let mut iter = iter.clone();
                let counter = counter.clone();
                let barrier = barrier.clone();
                let results = results.clone();

                thread::spawn(move || {
                    barrier.wait();
                    let mut local_results = Vec::new();

                    for _ in 0..items_per_thread {
                        if let Some(item) = iter.next() {
                            counter.fetch_add(1, Ordering::SeqCst);
                            local_results.push(item);
                        }
                    }

                    results.lock().push(local_results);
                })
            })
            .collect();

        for thread in threads {
            thread.join().unwrap();
        }

        assert_eq!(counter.load(Ordering::SeqCst), total_items);

        let all_results = results.lock();
        let mut collected: Vec<_> = all_results.iter().flatten().copied().collect();
        collected.sort();
        assert_eq!(collected, (0..total_items).collect::<Vec<_>>());
    }

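    // Clones are driven from several threads while the stream itself performs
    // async work on the runtime via `spawn_cpu`.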
    #[test]
    fn test_block_on_stream_concurrent_clone_and_drive() {
        let num_items = 50;
        let num_threads = 3;

        let iter = CurrentThreadRuntime::new().block_on_stream_thread_safe(|h| {
            stream::unfold(0, move |state| {
                let h = h.clone();
                async move {
                    if state < num_items {
                        h.spawn_cpu(move || {
                            thread::sleep(Duration::from_micros(10));
                            state
                        })
                        .await;
                        Some((state, state + 1))
                    } else {
                        None
                    }
                }
            })
        });

        let collected = Arc::new(Mutex::new(Vec::new()));
        let barrier = Arc::new(Barrier::new(num_threads));

        let threads: Vec<_> = (0..num_threads)
            .map(|thread_id| {
                let iter = iter.clone();
                let collected = collected.clone();
                let barrier = barrier.clone();

                thread::spawn(move || {
                    barrier.wait();
                    let mut local_items = Vec::new();

                    for item in iter {
                        local_items.push((thread_id, item));
                        if local_items.len() >= 5 {
                            break;
                        }
                    }

                    collected.lock().extend(local_items);
                })
            })
            .collect();

        for thread in threads {
            thread.join().unwrap();
        }

        let results = collected.lock();
        let mut values: Vec<_> = results.iter().map(|(_, v)| *v).collect();
        values.sort();
        values.dedup();

        assert!(values.len() >= 5);
        assert!(values.iter().all(|&v| v < num_items));
    }

    #[test]
    fn test_block_on_stream_async_work() {
        let runtime = CurrentThreadRuntime::new();
        let handle = runtime.handle();
        let iter = runtime.block_on_stream({
            stream::unfold((handle, 0), |(h, state)| async move {
                if state < 10 {
                    let value = h
                        .spawn(async move { futures::future::ready(state * 2).await })
                        .await;
                    Some((value, (h, state + 1)))
                } else {
                    None
                }
            })
        });

        let results: Vec<_> = iter.collect();
        assert_eq!(results, vec![0, 2, 4, 6, 8, 10, 12, 14, 16, 18]);
    }

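    // Dropping the blocking iterator drops the underlying stream, so it must
    // stop being polled well before it is exhausted.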
    #[test]
    fn test_block_on_stream_drop_receivers_early() {
        let counter = Arc::new(AtomicUsize::new(0));
        let c = counter.clone();

        let mut iter = CurrentThreadRuntime::new().block_on_stream({
            stream::unfold(0, move |state| {
                let c = c.clone();
                async move {
                    (state < 100).then(|| {
                        c.fetch_add(1, Ordering::SeqCst);
                        (state, state + 1)
                    })
                }
            })
            .boxed()
        });

        assert_eq!(iter.next(), Some(0));
        assert_eq!(iter.next(), Some(1));
        assert_eq!(iter.next(), Some(2));

        drop(iter);

        let final_count = counter.load(Ordering::SeqCst);
        assert!(
            final_count < 100,
            "Stream should stop being polled once the iterator is dropped"
        );
    }

    #[test]
    fn test_block_on_stream_interleaved_access() {
        let barrier = Arc::new(Barrier::new(2));
        let iter = CurrentThreadRuntime::new()
            .block_on_stream_thread_safe(|_h| stream::iter(0..20).boxed());

        let iter1 = iter.clone();
        let iter2 = iter;
        let barrier1 = barrier.clone();
        let barrier2 = barrier;

        let thread1 = thread::spawn(move || {
            let mut iter = iter1;
            let mut results = Vec::new();
            barrier1.wait();

            for _ in 0..5 {
                if let Some(val) = iter.next() {
                    results.push(val);
                    thread::sleep(Duration::from_micros(50));
                }
            }
            results
        });

        let thread2 = thread::spawn(move || {
            let mut iter = iter2;
            let mut results = Vec::new();
            barrier2.wait();

            for _ in 0..5 {
                if let Some(val) = iter.next() {
                    results.push(val);
                    thread::sleep(Duration::from_micros(50));
                }
            }
            results
        });

        let results1 = thread1.join().unwrap();
        let results2 = thread2.join().unwrap();

        let mut all_results = results1;
        all_results.extend(results2);
        all_results.sort();

        assert_eq!(all_results, (0..10).collect::<Vec<_>>());

        for i in 0..10 {
            assert_eq!(all_results.iter().filter(|&&x| x == i).count(), 1);
        }
    }

    #[test]
    fn test_block_on_stream_stress_test() {
        let num_threads = 10;
        let num_items = 1000;

        let iter = CurrentThreadRuntime::new()
            .block_on_stream_thread_safe(|_h| stream::iter(0..num_items).boxed());

        let received = Arc::new(Mutex::new(Vec::new()));
        let barrier = Arc::new(Barrier::new(num_threads));

        let threads: Vec<_> = (0..num_threads)
            .map(|_| {
                let iter = iter.clone();
                let received = received.clone();
                let barrier = barrier.clone();

                thread::spawn(move || {
                    barrier.wait();
                    for val in iter {
                        received.lock().push(val);
                    }
                })
            })
            .collect();

        for thread in threads {
            thread.join().unwrap();
        }

        let mut results = received.lock().clone();
        results.sort();

        assert_eq!(results.len(), num_items);
        assert_eq!(results, (0..num_items).collect::<Vec<_>>());
    }
}