use std::iter::FusedIterator;
use std::sync::Arc;

use futures::Stream;
use futures::StreamExt;
use futures::stream::BoxStream;
use smol::block_on;

use crate::runtime::BlockingRuntime;
use crate::runtime::Executor;
use crate::runtime::Handle;
pub use crate::runtime::pool::CurrentThreadWorkerPool;

16#[derive(Clone, Default)]
30pub struct CurrentThreadRuntime {
31 executor: Arc<smol::Executor<'static>>,
32}
33
34impl CurrentThreadRuntime {
35 pub fn new() -> Self {
37 Self::default()
38 }
39
40 pub fn new_pool(&self) -> CurrentThreadWorkerPool {
48 CurrentThreadWorkerPool::new(self.executor.clone())
49 }
50
51 pub fn block_on_stream_thread_safe<F, S, R>(&self, f: F) -> ThreadSafeIterator<R>
59 where
60 F: FnOnce(Handle) -> S,
61 S: Stream<Item = R> + Send + 'static,
62 R: Send + 'static,
63 {
64 let stream = f(self.handle());
65
66 let (result_tx, result_rx) = kanal::bounded_async(1);
70 self.executor
71 .spawn(async move {
72 futures::pin_mut!(stream);
73 while let Some(item) = stream.next().await {
74 if let Err(e) = result_tx.send(item).await {
76 tracing::trace!("all receivers dropped, stopping stream: {}", e);
77 break;
78 }
79 }
80 })
81 .detach();
82
83 ThreadSafeIterator {
84 executor: self.executor.clone(),
85 results: result_rx,
86 }
87 }
88}
89
90impl BlockingRuntime for CurrentThreadRuntime {
91 type BlockingIterator<'a, R: 'a> = CurrentThreadIterator<'a, R>;
92
93 fn handle(&self) -> Handle {
94 let executor: Arc<dyn Executor> = self.executor.clone();
95 Handle::new(Arc::downgrade(&executor))
96 }
97
98 fn block_on<Fut, R>(&self, fut: Fut) -> R
99 where
100 Fut: Future<Output = R>,
101 {
102 block_on(self.executor.run(fut))
103 }
104
105 fn block_on_stream<'a, S, R>(&self, stream: S) -> Self::BlockingIterator<'a, R>
106 where
107 S: Stream<Item = R> + Send + 'a,
108 R: Send + 'a,
109 {
110 CurrentThreadIterator {
111 executor: self.executor.clone(),
112 stream: stream.boxed(),
113 }
114 }
115}
116
117pub struct CurrentThreadIterator<'a, T> {
119 executor: Arc<smol::Executor<'static>>,
120 stream: BoxStream<'a, T>,
121}
122
123impl<T> Iterator for CurrentThreadIterator<'_, T> {
124 type Item = T;
125
126 fn next(&mut self) -> Option<Self::Item> {
127 block_on(self.executor.run(self.stream.next()))
128 }
129}
130
131pub struct ThreadSafeIterator<T> {
133 executor: Arc<smol::Executor<'static>>,
134 results: kanal::AsyncReceiver<T>,
135}
136
137impl<T> Clone for ThreadSafeIterator<T> {
139 fn clone(&self) -> Self {
140 Self {
141 executor: self.executor.clone(),
142 results: self.results.clone(),
143 }
144 }
145}
146
147impl<T> Iterator for ThreadSafeIterator<T> {
148 type Item = T;
149
150 fn next(&mut self) -> Option<Self::Item> {
151 block_on(self.executor.run(self.results.recv())).ok()
152 }
153}
154
// Several `stream::unfold` closures below use `if cond { ...; Some(..) } else { None }` with an
// `.await` inside the `Some` branch, which cannot be moved into a `bool::then` closure.
#[allow(clippy::if_then_some_else_none)]
#[cfg(test)]
mod tests {
    use std::sync::Arc;
    use std::sync::Barrier;
    use std::sync::atomic::AtomicUsize;
    use std::sync::atomic::Ordering;
    use std::thread;
    use std::time::Duration;
    use std::time::Instant;

    use futures::StreamExt;
    use futures::stream;
    use parking_lot::Mutex;

    use super::*;

    #[test]
    fn test_worker_thread() {
        let runtime = CurrentThreadRuntime::new();

        let value = Arc::new(AtomicUsize::new(0));
        let value2 = value.clone();
        runtime
            .handle()
            .spawn(async move {
                value2.store(42, Ordering::SeqCst);
            })
            .detach();

        // Nothing is driving the executor yet, so the spawned task cannot have run.
        assert_eq!(value.load(Ordering::SeqCst), 0);

        // A freshly created pool is expected to have no workers, so the task still has not run.
        let pool = runtime.new_pool();
        assert_eq!(value.load(Ordering::SeqCst), 0);

        pool.set_workers(1);

        // Use a generous deadline: the worker thread may be scheduled late on a loaded machine,
        // and the loop exits as soon as the value is observed.
        let deadline = Instant::now() + Duration::from_secs(5);
        while value.load(Ordering::SeqCst) != 42 && Instant::now() < deadline {
            thread::sleep(Duration::from_millis(10));
        }
        assert_eq!(value.load(Ordering::SeqCst), 42);
    }

    #[test]
    fn test_block_on_stream_single_thread() {
        let mut iter =
            CurrentThreadRuntime::new().block_on_stream(stream::iter(vec![1, 2, 3, 4, 5]).boxed());

        assert_eq!(iter.next(), Some(1));
        assert_eq!(iter.next(), Some(2));
        assert_eq!(iter.next(), Some(3));
        assert_eq!(iter.next(), Some(4));
        assert_eq!(iter.next(), Some(5));
        assert_eq!(iter.next(), None);
    }

    #[test]
    fn test_block_on_stream_multiple_threads() {
        let counter = Arc::new(AtomicUsize::new(0));
        let num_threads = 4;
        let items_per_thread = 25;
        let total_items = 100;

        let iter = CurrentThreadRuntime::new()
            .block_on_stream_thread_safe(|_h| stream::iter(0..total_items).boxed());

        let barrier = Arc::new(Barrier::new(num_threads));
        let results = Arc::new(Mutex::new(Vec::new()));

        let threads: Vec<_> = (0..num_threads)
            .map(|_| {
                let mut iter = iter.clone();
                let counter = counter.clone();
                let barrier = barrier.clone();
                let results = results.clone();

                thread::spawn(move || {
                    barrier.wait();
                    let mut local_results = Vec::new();

                    for _ in 0..items_per_thread {
                        if let Some(item) = iter.next() {
                            counter.fetch_add(1, Ordering::SeqCst);
                            local_results.push(item);
                        }
                    }

                    results.lock().push(local_results);
                })
            })
            .collect();

        for thread in threads {
            thread.join().unwrap();
        }

        assert_eq!(counter.load(Ordering::SeqCst), total_items);

        let all_results = results.lock();
        let mut collected: Vec<_> = all_results.iter().flatten().copied().collect();
        collected.sort();
        assert_eq!(collected, (0..total_items).collect::<Vec<_>>());
    }

    #[test]
    fn test_block_on_stream_concurrent_clone_and_drive() {
        let num_items = 50;
        let num_threads = 3;

        let iter = CurrentThreadRuntime::new().block_on_stream_thread_safe(|h| {
            stream::unfold(0, move |state| {
                let h = h.clone();
                async move {
                    if state < num_items {
                        h.spawn_cpu(move || {
                            thread::sleep(Duration::from_micros(10));
                            state
                        })
                        .await;
                        Some((state, state + 1))
                    } else {
                        None
                    }
                }
            })
        });

        let collected = Arc::new(Mutex::new(Vec::new()));
        let barrier = Arc::new(Barrier::new(num_threads));

        let threads: Vec<_> = (0..num_threads)
            .map(|thread_id| {
                let iter = iter.clone();
                let collected = collected.clone();
                let barrier = barrier.clone();

                thread::spawn(move || {
                    barrier.wait();
                    let mut local_items = Vec::new();

                    for item in iter {
                        local_items.push((thread_id, item));
                        if local_items.len() >= 5 {
                            break;
                        }
                    }

                    collected.lock().extend(local_items);
                })
            })
            .collect();

        for thread in threads {
            thread.join().unwrap();
        }

        let results = collected.lock();
        let mut values: Vec<_> = results.iter().map(|(_, v)| *v).collect();
        values.sort();
        values.dedup();

        assert!(values.len() >= 5);
        assert!(values.iter().all(|&v| v < num_items));
    }

    #[test]
    fn test_block_on_stream_async_work() {
        let runtime = CurrentThreadRuntime::new();
        let handle = runtime.handle();
        let iter = runtime.block_on_stream({
            stream::unfold((handle, 0), |(h, state)| async move {
                if state < 10 {
                    let value = h
                        .spawn(async move { futures::future::ready(state * 2).await })
                        .await;
                    Some((value, (h, state + 1)))
                } else {
                    None
                }
            })
        });

        let results: Vec<_> = iter.collect();
        assert_eq!(results, vec![0, 2, 4, 6, 8, 10, 12, 14, 16, 18]);
    }

    #[test]
    fn test_block_on_stream_drop_receivers_early() {
        let counter = Arc::new(AtomicUsize::new(0));
        let c = counter.clone();

        let mut iter = CurrentThreadRuntime::new().block_on_stream({
            stream::unfold(0, move |state| {
                let c = c.clone();
                async move {
                    (state < 100).then(|| {
                        c.fetch_add(1, Ordering::SeqCst);
                        (state, state + 1)
                    })
                }
            })
            .boxed()
        });

        assert_eq!(iter.next(), Some(0));
        assert_eq!(iter.next(), Some(1));
        assert_eq!(iter.next(), Some(2));

        drop(iter);

        // `block_on_stream` is lazy: the stream is only polled inside `next`. Exactly the three
        // requested items were produced, and dropping the iterator prevents any further work.
        assert_eq!(
            counter.load(Ordering::SeqCst),
            3,
            "Stream should not be polled beyond the items that were requested"
        );
    }

    #[test]
    fn test_block_on_stream_interleaved_access() {
        let barrier = Arc::new(Barrier::new(2));
        let iter = CurrentThreadRuntime::new()
            .block_on_stream_thread_safe(|_h| stream::iter(0..20).boxed());

        let iter1 = iter.clone();
        let iter2 = iter;
        let barrier1 = barrier.clone();
        let barrier2 = barrier;

        let thread1 = thread::spawn(move || {
            let mut iter = iter1;
            let mut results = Vec::new();
            barrier1.wait();

            for _ in 0..5 {
                if let Some(val) = iter.next() {
                    results.push(val);
                    thread::sleep(Duration::from_micros(50));
                }
            }
            results
        });

        let thread2 = thread::spawn(move || {
            let mut iter = iter2;
            let mut results = Vec::new();
            barrier2.wait();

            for _ in 0..5 {
                if let Some(val) = iter.next() {
                    results.push(val);
                    thread::sleep(Duration::from_micros(50));
                }
            }
            results
        });

        let results1 = thread1.join().unwrap();
        let results2 = thread2.join().unwrap();

        let mut all_results = results1;
        all_results.extend(results2);
        all_results.sort();

        assert_eq!(all_results, (0..10).collect::<Vec<_>>());

        for i in 0..10 {
            assert_eq!(all_results.iter().filter(|&&x| x == i).count(), 1);
        }
    }

    #[test]
    fn test_block_on_stream_stress_test() {
        let num_threads = 10;
        let num_items = 1000;

        let iter = CurrentThreadRuntime::new()
            .block_on_stream_thread_safe(|_h| stream::iter(0..num_items).boxed());

        let received = Arc::new(Mutex::new(Vec::new()));
        let barrier = Arc::new(Barrier::new(num_threads));

        let threads: Vec<_> = (0..num_threads)
            .map(|_| {
                let iter = iter.clone();
                let received = received.clone();
                let barrier = barrier.clone();

                thread::spawn(move || {
                    barrier.wait();
                    for val in iter {
                        received.lock().push(val);
                    }
                })
            })
            .collect();

        for thread in threads {
            thread.join().unwrap();
        }

        let mut results = received.lock().clone();
        results.sort();

        assert_eq!(results.len(), num_items);
        assert_eq!(results, (0..num_items).collect::<Vec<_>>());
    }
}