1use alloc::boxed::Box;
2use core::cell::{Cell, UnsafeCell};
3use core::future::Future;
4use core::marker::{PhantomData, PhantomPinned};
5#[cfg(feature = "std")]
6use core::ops::Deref;
7use core::ops::DerefMut;
8use core::pin::Pin;
9use core::ptr::NonNull;
10use core::task::{Context, Poll};
11
12use futures_core::Stream;
13use futures_sink::Sink;
14use pin_project_lite::pin_project;
15
16#[cfg(feature = "std")]
17use crate::LocalThread;
18
#[cfg(feature = "std")]
/// Boxed, type-erased consumer factory used by [`ScopedSink`].
///
/// Each call receives a pinned [`SinkInner`] — a [`Stream`] of the items pushed into the
/// sink — and returns the future that consumes them. The closure itself must be `Send` and
/// may borrow from `'env`.
pub type DynSinkFn<'env, T, E> = Box<
    dyn for<'scope> FnMut(Pin<&'scope mut SinkInner<'scope, 'env, T>>) -> DynSinkFuture<'scope, E>
        + Send
        + 'env,
>;
40
#[cfg(feature = "std")]
/// Boxed, `Send` consumer future produced by a [`DynSinkFn`]; it resolves to `Ok(())` or
/// to the consumer's error.
pub type DynSinkFuture<'scope, E> = Pin<Box<dyn 'scope + Send + Future<Output = Result<(), E>>>>;
44
#[cfg(feature = "std")]
pin_project! {
    /// A [`Sink`] that hands every item to a consumer future built by a user closure.
    ///
    /// Items are buffered one at a time. Whenever an item is waiting and no consumer future
    /// is running, the closure `f` is called with a pinned [`SinkInner`] (a [`Stream`] of the
    /// items) to create one; the sink then drives that future from its own `poll_*` methods.
    #[must_use = "Sink will not do anything if not used"]
    pub struct ScopedSink<'env, T, E> {
        // Factory for consumer futures.
        f: DynSinkFn<'env, T, E>,
        // The currently running consumer future, if any.
        inner: Option<DynSinkFuture<'env, E>>,

        // One-item slot read by the consumer. Fields drop in declaration order, so this is
        // dropped after `inner`, whose future may hold a pointer into it.
        data: Pin<Box<SinkInner<'env, 'env, T>>>,
    }
}
61
/// One-item hand-off slot shared between a sink and its consumer future.
///
/// The producer stores an item with `send`; the consumer removes it with `next`.
struct SinkInnerData<T> {
    /// Item waiting to be picked up by the consumer (at most one at a time).
    data: UnsafeCell<Option<T>>,
    /// Set once the producing side closes the sink.
    closed: Cell<bool>,

    /// Keeps the slot `!Unpin`: consumer futures are handed pointers derived from its
    /// address, so it must not move.
    _pinned: PhantomPinned,
}
70
// SAFETY: `UnsafeCell<Option<T>>`, `Cell<bool>` and `PhantomPinned` are all `Send` when
// `T: Send`, so this matches the impl the compiler would otherwise derive automatically.
unsafe impl<T: Send> Send for SinkInnerData<T> {}
// SAFETY: `Cell`/`UnsafeCell` are not `Sync`, and `send`/`next`/`flush`/`close` mutate
// through `&self` without any locking, so concurrent use from two threads would race.
// NOTE(review): soundness presumably relies on every shared access being confined to one
// thread at a time (via `LocalThread` for the `Send` sink, and via the `!Send`/`!Sync`
// marker on `LocalSinkInner` for the local one) — confirm against `LocalThread`'s contract.
unsafe impl<T: Send> Sync for SinkInnerData<T> {}
75
76impl<T> SinkInnerData<T> {
77 const fn new() -> Self {
78 Self {
79 data: UnsafeCell::new(None),
80 closed: Cell::new(false),
81 _pinned: PhantomPinned,
82 }
83 }
84}
85
#[cfg(feature = "std")]
pin_project! {
    /// Consumer-side handle of a [`ScopedSink`]: a [`Stream`] yielding the items sent into
    /// the sink.
    ///
    /// Polling yields the buffered item if there is one, `None` once the sink has been
    /// closed and drained, and `Pending` otherwise. The waker of the polling `Context` is not
    /// registered; progress relies on the owning sink re-polling the consumer future.
    #[must_use = "SinkInner will not do anything if not used"]
    pub struct SinkInner<'scope, 'env: 'scope, T> {
        // The item slot, wrapped in the project type `LocalThread` (its exact guarantees are
        // defined elsewhere).
        #[pin]
        inner: LocalThread<SinkInnerData<T>>,

        // Ties the handle to `'scope`/`'env`; being behind `&mut`, it is invariant in `'env`
        // and `T`.
        phantom: PhantomData<&'scope mut &'env T>,
    }
}
112
#[cfg(feature = "std")]
impl<'env, T: 'env, E: 'env> ScopedSink<'env, T, E> {
    /// Creates a sink from an already boxed consumer factory.
    ///
    /// No consumer future is started here; `f` is first called when an item has to be
    /// delivered.
    pub fn new_dyn(f: DynSinkFn<'env, T, E>) -> Self {
        let slot = SinkInner {
            phantom: PhantomData,
            inner: LocalThread::new(SinkInnerData::new()),
        };

        Self {
            f,
            inner: None,
            data: Box::pin(slot),
        }
    }

    /// Creates a sink from a consumer closure by boxing it into a [`DynSinkFn`].
    ///
    /// The closure is called with a pinned [`SinkInner`] stream and must return the boxed
    /// future that consumes it.
    pub fn new<F>(f: F) -> Self
    where
        for<'scope> F: 'env
            + Send
            + FnMut(
                Pin<&'scope mut SinkInner<'scope, 'env, T>>,
            ) -> Pin<Box<dyn Future<Output = Result<(), E>> + Send + 'scope>>,
    {
        let boxed: DynSinkFn<'env, T, E> = Box::new(f);
        Self::new_dyn(boxed)
    }
}
181
182impl<T> SinkInnerData<T> {
183 fn flush<E, U, F>(
184 &self,
185 cx: &mut Context<'_>,
186 fut: &mut Option<Pin<U>>,
187 mut f: F,
188 ) -> Poll<Result<(), E>>
189 where
190 U: DerefMut,
191 U::Target: Future<Output = Result<(), E>>,
192 F: FnMut() -> Pin<U>,
193 {
194 loop {
195 if unsafe { (*self.data.get()).is_none() } {
196 return Poll::Ready(Ok(()));
198 }
199
200 let fp = if let Some(v) = fut {
201 v
202 } else if self.closed.get() {
203 return Poll::Ready(Ok(()));
204 } else {
205 fut.get_or_insert_with(&mut f)
206 };
207
208 let Poll::Ready(v) = fp.as_mut().poll(cx) else {
209 break;
210 };
211
212 *fut = None;
214
215 if v.is_err() {
216 return Poll::Ready(v);
217 }
218
219 }
221
222 if unsafe { (*self.data.get()).is_some() } {
223 Poll::Pending
224 } else {
225 Poll::Ready(Ok(()))
226 }
227 }
228
229 fn send(&self, item: T) {
230 if self.closed.get() {
231 panic!("Sink is closed!");
232 }
233 let data = unsafe { &mut *self.data.get() };
234 if data.is_some() {
235 panic!("poll_ready() is not called yet!");
236 }
237
238 *data = Some(item);
239 }
240
241 fn close<E, U, F>(
242 &self,
243 cx: &mut Context<'_>,
244 fut: &mut Option<Pin<U>>,
245 f: F,
246 ) -> Poll<Result<(), E>>
247 where
248 U: DerefMut,
249 U::Target: Future<Output = Result<(), E>>,
250 F: FnMut() -> Pin<U>,
251 {
252 self.closed.set(true);
253
254 if unsafe { (*self.data.get()).is_some() } {
256 let ret = self.flush(cx, &mut *fut, f);
257 if let Poll::Ready(Err(_)) = ret {
258 return ret;
259 }
260 return match fut {
261 Some(_) => Poll::Pending,
263 None => ret,
264 };
265 }
266
267 let ret = match fut {
268 Some(p) => p.as_mut().poll(cx),
269 None => return Poll::Ready(Ok(())),
270 };
271 if ret.is_ready() {
272 *fut = None;
273 }
274 ret
275 }
276
277 fn next(&self) -> Poll<Option<T>> {
278 match unsafe { (*self.data.get()).take() } {
279 v @ Some(_) => Poll::Ready(v),
280 None if self.closed.get() => Poll::Ready(None),
281 None => Poll::Pending,
282 }
283 }
284}
285
/// Turns a raw pointer plus a callback taking `Pin<&'a mut T>` into a zero-argument factory.
///
/// Every call of the returned closure re-borrows `*ptr` as `Pin<&'a mut T>`, passes it to
/// `f` and returns `f`'s result (in this file: a consumer future).
///
/// # Safety
///
/// `'a` is chosen by the caller and is not checked against `ptr`. The caller must ensure:
/// * `ptr` points to a live `T` that is never moved or freed for as long as any value
///   returned by the closure (which may borrow for `'a`) is still in use;
/// * no two `Pin<&'a mut T>` produced by successive calls, nor anything derived from them,
///   are used at the same time — the closure itself does not prevent this.
unsafe fn make_future<'a, T: 'a, R, F>(mut ptr: NonNull<T>, mut f: F) -> impl FnMut() -> R
where
    F: FnMut(Pin<&'a mut T>) -> R,
{
    // The closure is written inside this `unsafe fn`, so the unchecked calls below rely on
    // the caller's obligations listed above.
    move || f(Pin::new_unchecked(ptr.as_mut()))
}
292
#[cfg(feature = "std")]
impl<'env, T: 'env, E: 'env> ScopedSink<'env, T, E> {
    /// Splits the pinned sink into the three pieces `SinkInnerData::flush`/`close` need:
    ///
    /// 1. a handle to the item slot, obtained from `LocalThread::set_inner_ctx`
    ///    (NOTE(review): presumably a guard scoping `LocalThread`'s context — confirm);
    /// 2. the slot holding the currently running consumer future;
    /// 3. a factory that calls the user closure `f` with a freshly re-borrowed
    ///    `Pin<&mut SinkInner>` pointing into `data`.
    fn future_wrapper(
        self: Pin<&mut Self>,
    ) -> (
        impl Deref<Target = SinkInnerData<T>> + '_,
        &mut Option<DynSinkFuture<'env, E>>,
        impl FnMut() -> DynSinkFuture<'env, E> + '_,
    ) {
        let this = self.project();
        // SAFETY: `data` is a `Pin<Box<_>>`, so its target never moves, and the pointer is
        // only turned back into `Pin<&mut _>` by `make_future`.
        // NOTE(review): the produced futures are typed as borrowing for `'env`, longer than
        // the box really lives. This relies on `inner` (declared before `data`) being
        // dropped first, and on `flush` creating a new future only when `inner` is `None`.
        // The handle in (1) is a shared borrow of `data` coexisting with this raw pointer,
        // which is only acceptable because the slot is accessed through `Cell`/`UnsafeCell`
        // — confirm.
        let f = unsafe {
            make_future(
                NonNull::from(this.data.as_mut().get_unchecked_mut()),
                this.f,
            )
        };

        (
            this.data.as_ref().get_ref().inner.set_inner_ctx(),
            this.inner,
            f,
        )
    }
}
319
#[cfg(feature = "std")]
impl<'env, T: 'env, E: 'env> Sink<T> for ScopedSink<'env, T, E> {
    type Error = E;

    /// Ready once the previously sent item (if any) has been handed to a consumer.
    fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), E>> {
        // The buffer holds a single item, so "ready for more" is exactly "flushed".
        <Self as Sink<T>>::poll_flush(self, cx)
    }

    /// Runs consumer futures until the buffered item has been taken.
    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), E>> {
        let (slot, running, spawn) = self.future_wrapper();
        slot.flush(cx, running, spawn)
    }

    /// Buffers `item`; delivery happens on a later `poll_ready`/`poll_flush`/`poll_close`.
    fn start_send(self: Pin<&mut Self>, item: T) -> Result<(), E> {
        let slot = self.data.inner.set_inner_ctx();
        slot.send(item);
        Ok(())
    }

    /// Flushes, then lets the running consumer observe end-of-stream and finish.
    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), E>> {
        let (slot, running, spawn) = self.future_wrapper();
        slot.close(cx, running, spawn)
    }
}
345
#[cfg(feature = "std")]
impl<'scope, 'env, T> Stream for SinkInner<'scope, 'env, T> {
    type Item = T;

    /// Takes the buffered item, if any.
    ///
    /// Yields `Ready(Some(item))` when an item is waiting, `Ready(None)` once the sink has
    /// been closed and drained, and `Pending` otherwise. The waker in `_cx` is not
    /// registered; this relies on the owning sink re-polling the consumer from its own
    /// `poll_*` methods.
    fn poll_next(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let slot = self.inner.get_inner();
        slot.next()
    }
}
354
355pub type DynLocalSinkFn<'env, T, E> = Box<
371 dyn 'env
372 + for<'scope> FnMut(
373 Pin<&'scope mut LocalSinkInner<'scope, 'env, T>>,
374 ) -> DynLocalSinkFuture<'scope, E>,
375>;
376
/// Boxed consumer future produced by a [`DynLocalSinkFn`]; unlike the `Send` variant it may
/// hold non-`Send` state. Resolves to `Ok(())` or the consumer's error.
pub type DynLocalSinkFuture<'scope, E> = Pin<Box<dyn 'scope + Future<Output = Result<(), E>>>>;
379
pin_project! {
    /// Single-threaded [`Sink`] that hands every item to a consumer future built by a user
    /// closure; neither the closure nor its futures need to be `Send`.
    ///
    /// Items are buffered one at a time. Whenever an item is waiting and no consumer future
    /// is running, the closure `f` is called with a pinned [`LocalSinkInner`] (a [`Stream`]
    /// of the items) to create one; the sink then drives it from its own `poll_*` methods.
    #[must_use = "Sink will not do anything if not used"]
    pub struct LocalScopedSink<'env, T, E> {
        // Factory for consumer futures.
        f: DynLocalSinkFn<'env, T, E>,
        // The currently running consumer future, if any.
        inner: Option<DynLocalSinkFuture<'env, E>>,

        // One-item slot read by the consumer. Fields drop in declaration order, so this is
        // dropped after `inner`, whose future may hold a pointer into it.
        data: Pin<Box<LocalSinkInner<'env, 'env, T>>>,
    }
}
392
pin_project! {
    /// Consumer-side handle of a [`LocalScopedSink`]: a [`Stream`] yielding the items sent
    /// into the sink.
    ///
    /// Polling yields the buffered item if there is one, `None` once the sink has been
    /// closed and drained, and `Pending` otherwise (without registering the waker).
    pub struct LocalSinkInner<'scope, 'env: 'scope, T> {
        // The one-item slot shared with the owning sink.
        #[pin]
        inner: SinkInnerData<T>,

        // Ties the handle to `'scope`/`'env`. The raw-pointer component makes this type
        // `!Send` and `!Sync`, so the `Send`/`Sync` impls of `SinkInnerData` do not leak
        // through it.
        phantom: PhantomData<(&'scope mut &'env T, *mut u8)>,
    }
}
404
405impl<'env, T: 'env, E: 'env> LocalScopedSink<'env, T, E> {
406 pub fn new_dyn(f: DynLocalSinkFn<'env, T, E>) -> Self {
420 Self {
421 data: Box::pin(LocalSinkInner {
422 inner: SinkInnerData::new(),
423
424 phantom: PhantomData,
425 }),
426
427 f,
428 inner: None,
429 }
430 }
431
432 pub fn new<F>(f: F) -> Self
462 where
463 for<'scope> F: 'env
464 + FnMut(Pin<&'scope mut LocalSinkInner<'scope, 'env, T>>) -> DynLocalSinkFuture<'scope, E>,
465 {
466 Self::new_dyn(Box::new(f))
467 }
468}
469
impl<'env, T: 'env, E: 'env> LocalScopedSink<'env, T, E> {
    /// Splits the pinned sink into the three pieces `SinkInnerData::flush`/`close` need:
    /// a shared reference to the item slot, the slot holding the currently running consumer
    /// future, and a factory that calls the user closure `f` with a freshly re-borrowed
    /// `Pin<&mut LocalSinkInner>` pointing into `data`.
    fn future_wrapper(
        self: Pin<&mut Self>,
    ) -> (
        &SinkInnerData<T>,
        &mut Option<DynLocalSinkFuture<'env, E>>,
        impl FnMut() -> DynLocalSinkFuture<'env, E> + '_,
    ) {
        let this = self.project();
        // SAFETY: `data` is a `Pin<Box<_>>`, so its target never moves, and the pointer is
        // only turned back into `Pin<&mut _>` by `make_future`.
        // NOTE(review): the produced futures are typed as borrowing for `'env`, longer than
        // the box really lives. This relies on `inner` (declared before `data`) being
        // dropped first, and on `flush` creating a new future only when `inner` is `None`.
        // The returned `&SinkInnerData` coexists with this raw pointer, which is only
        // acceptable because the slot is accessed through `Cell`/`UnsafeCell` — confirm.
        let f = unsafe {
            make_future(
                NonNull::from(this.data.as_mut().get_unchecked_mut()),
                this.f,
            )
        };

        (&this.data.inner, this.inner, f)
    }
}
491
492impl<'env, T: 'env, E: 'env> Sink<T> for LocalScopedSink<'env, T, E> {
493 type Error = E;
494
495 fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), E>> {
496 self.poll_flush(cx)
497 }
498
499 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), E>> {
500 let (data, fut, f) = self.future_wrapper();
501
502 data.flush(cx, fut, f)
503 }
504
505 fn start_send(self: Pin<&mut Self>, item: T) -> Result<(), E> {
506 self.data.inner.send(item);
507 Ok(())
508 }
509
510 fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), E>> {
511 let (data, fut, f) = self.future_wrapper();
512
513 data.close(cx, fut, f)
514 }
515}
516
517impl<'scope, 'env, T> Stream for LocalSinkInner<'scope, 'env, T> {
518 type Item = T;
519
520 fn poll_next(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
521 self.inner.next()
522 }
523}
524
#[cfg(test)]
mod tests {
    use super::*;

    use std::prelude::rust_2021::*;
    use std::sync::atomic::{AtomicU8, Ordering};
    use std::sync::Arc;
    use std::time::Duration;

    use anyhow::{bail, Error as AnyError, Result as AnyResult};
    use futures_util::{SinkExt, StreamExt};
    use num_integer::Roots as _;
    use tokio::spawn;
    use tokio::sync::mpsc::channel;
    use tokio::task::yield_now;
    use tokio::time::timeout;

    /// Awaits `f`, failing with "Time ran out" if it does not finish within 5 seconds.
    async fn test_helper<F>(f: F) -> AnyResult<()>
    where
        F: Future<Output = AnyResult<()>> + Send,
    {
        match timeout(Duration::from_secs(5), f).await {
            Ok(v) => v,
            Err(_) => bail!("Time ran out"),
        }
    }

    #[tokio::test]
    async fn test_simple() -> AnyResult<()> {
        let mut sink: ScopedSink<usize, AnyError> = ScopedSink::new(|_| Box::pin(async { Ok(()) }));

        test_helper(async move {
            println!("Closing");
            sink.close().await?;

            Ok(())
        })
        .await
    }

    #[tokio::test]
    async fn test_send_one() -> AnyResult<()> {
        let mut sink: ScopedSink<usize, AnyError> = ScopedSink::new(|mut src| {
            Box::pin(async move {
                println!("Starting sink");
                while let Some(v) = src.next().await {
                    println!("Value: {v}");
                }
                println!("Stopping sink");

                Ok(())
            })
        });

        test_helper(async move {
            sink.feed(1).await?;

            println!("Closing");
            sink.close().await?;

            Ok(())
        })
        .await
    }

    #[tokio::test]
    async fn test_send_many() -> AnyResult<()> {
        let mut sink: ScopedSink<usize, AnyError> = ScopedSink::new(|mut src| {
            Box::pin(async move {
                println!("Starting sink");
                while let Some(v) = src.next().await {
                    println!("Value: {v}");
                }
                println!("Stopping sink");

                Ok(())
            })
        });

        test_helper(async move {
            for i in 0..10 {
                println!("Sending: {i}");
                sink.feed(i).await?;
            }

            println!("Closing");
            sink.close().await?;

            Ok(())
        })
        .await
    }

    #[tokio::test]
    async fn test_send_yield() -> AnyResult<()> {
        let mut sink: ScopedSink<usize, AnyError> = ScopedSink::new(|mut src| {
            Box::pin(async move {
                println!("Starting sink");
                while let Some(v) = src.next().await {
                    println!("Value: {v}");
                    for _ in 0..5 {
                        yield_now().await;
                    }
                }
                println!("Stopping sink");

                Ok(())
            })
        });

        test_helper(async move {
            for i in 0..10 {
                println!("Sending: {i}");
                sink.feed(i).await?;
            }

            println!("Closing");
            sink.close().await?;

            Ok(())
        })
        .await
    }

    #[tokio::test]
    async fn test_send_yield2() -> AnyResult<()> {
        let mut sink: ScopedSink<usize, AnyError> = ScopedSink::new(|mut src| {
            Box::pin(async move {
                println!("Starting sink");
                while let Some(v) = src.next().await {
                    println!("Value: {v}");
                    for _ in 0..3 {
                        yield_now().await;
                    }
                }
                println!("Stopping sink");

                Ok(())
            })
        });

        test_helper(async move {
            for i in 0..10 {
                println!("Sending: {i}");
                sink.feed(i).await?;

                for _ in 0..5 {
                    yield_now().await;
                }
            }

            println!("Closing");
            sink.close().await?;

            Ok(())
        })
        .await
    }

    #[tokio::test]
    async fn test_send_many_flush() -> AnyResult<()> {
        let mut sink: ScopedSink<usize, AnyError> = ScopedSink::new(|mut src| {
            Box::pin(async move {
                println!("Starting sink");
                while let Some(v) = src.next().await {
                    println!("Value: {v}");
                }
                println!("Stopping sink");

                Ok(())
            })
        });

        test_helper(async move {
            for i in 0..10 {
                println!("Sending: {i}");
                sink.feed(i).await?;
            }

            println!("Flushing");
            sink.flush().await?;

            for i in 10..20 {
                println!("Sending: {i}");
                sink.feed(i).await?;
            }

            println!("Closing");
            sink.close().await?;

            Ok(())
        })
        .await
    }

    #[tokio::test]
    async fn test_return_then_receive() -> AnyResult<()> {
        let v = Arc::new(AtomicU8::new(0));
        let mut sink: ScopedSink<usize, AnyError> = ScopedSink::new(move |mut src| {
            let v = v.clone();
            Box::pin(async move {
                let mut v_ = v.load(Ordering::SeqCst);
                v_ = if v_ == 8 {
                    assert_eq!(src.next().await, Some(1));
                    0
                } else {
                    v_ + 1
                };
                v.store(v_, Ordering::SeqCst);

                Ok(())
            })
        });

        test_helper(async move {
            for _ in 0..10 {
                println!("Sending");
                sink.feed(1).await?;
            }

            println!("Closing");
            sink.close().await?;

            Ok(())
        })
        .await
    }

    #[tokio::test]
    async fn test_double_scoped() -> AnyResult<()> {
        let mut sink: ScopedSink<usize, AnyError> = ScopedSink::new(|mut src| {
            Box::pin(async move {
                let mut sink2: ScopedSink<usize, AnyError> = ScopedSink::new(|mut src| {
                    Box::pin(async move {
                        println!("Value: {}", src.next().await.unwrap());
                        Ok(())
                    })
                });

                while let Some(v) = src.next().await {
                    sink2.feed(v).await?;
                    sink2.feed(v).await?;
                }
                sink2.close().await?;

                Ok(())
            })
        });

        test_helper(async move {
            for i in 0..10 {
                sink.feed(i).await?;
            }
            sink.close().await?;

            Ok(())
        })
        .await
    }

    #[tokio::test]
    async fn test_spawn_mpsc() -> AnyResult<()> {
        /// Trial division by every candidate up to and including `isqrt(v)`;
        /// 0 and 1 are not prime.
        fn is_prime(v: u64) -> bool {
            v >= 2 && (2..=v.sqrt()).all(|i| v % i != 0)
        }

        let (s1, mut r1) = channel::<u64>(4);
        let (s2, mut r2) = channel::<u64>(4);

        let mut sink = ScopedSink::new(move |mut stream| {
            let s1 = s1.clone();
            let s2 = s2.clone();

            Box::pin(async move {
                let Some(v) = stream.next().await else {
                    return Ok(());
                };

                if is_prime(v) {
                    s1.send(v).await
                } else {
                    s2.send(v).await
                }
            })
        });

        let mut handles = Vec::new();

        handles.push(spawn(test_helper(async move {
            while let Some(v) = r1.recv().await {
                assert!(is_prime(v));
                for _ in 0..v.sqrt() {
                    yield_now().await
                }
            }

            Ok(())
        })));

        handles.push(spawn(test_helper(async move {
            while let Some(v) = r2.recv().await {
                assert!(!is_prime(v));
                for _ in 0..v.sqrt() {
                    yield_now().await
                }
            }

            Ok(())
        })));

        handles.push(spawn(test_helper(async move {
            for i in 1..1000 {
                sink.feed(i).await?;
            }
            sink.close().await?;

            Ok(())
        })));

        let mut has_error = false;
        for f in handles {
            if let Err(e) = f.await? {
                eprintln!("{e:?}");
                has_error = true;
            }
        }

        if has_error {
            bail!("Some error has happened");
        }

        Ok(())
    }
}