task_executor/
local_builder.rs

1use std::fmt::Debug;
2use std::hash::Hash;
3use std::marker::Unpin;
4use std::ops::{Deref, DerefMut};
5use std::pin::Pin;
6use std::task::{Context, Poll};
7
8use futures::channel::mpsc;
9
10use super::{assert_future, LocalExecutor, LocalSpawner, LocalTaskType};
11
12impl<T: ?Sized> LocalSpawnExt for T where T: futures::Future {}
13
14pub trait LocalSpawnExt: futures::Future {
15    #[inline]
16    fn spawn<Tx, G>(self, exec: &LocalExecutor<Tx, G>) -> LocalSpawner<Self, Tx, G, ()>
17        where
18            Self: Sized + 'static,
19            Self::Output: 'static,
20            Tx: Clone + Unpin + futures::Sink<((), LocalTaskType)> + Sync + 'static,
21            G: Hash + Eq + Clone + Debug + Sync + 'static,
22    {
23        let f = LocalSpawner::new(exec, self, ());
24        assert_future::<_, _>(f)
25    }
26
27    #[inline]
28    fn spawn_with<Tx, G, D>(
29        self,
30        exec: &LocalExecutor<Tx, G, D>,
31        name: D,
32    ) -> LocalSpawner<Self, Tx, G, D>
33        where
34            Self: Sized + 'static,
35            Self::Output: 'static,
36            Tx: Clone + Unpin + futures::Sink<(D, LocalTaskType)> + Sync + 'static,
37            G: Hash + Eq + Clone + Debug + Sync + 'static,
38    {
39        let f = LocalSpawner::new(exec, self, name);
40        assert_future::<_, _>(f)
41    }
42}
43
/// Configuration for constructing a [`LocalExecutor`].
///
/// Start from `LocalBuilder::default()`, adjust it with the chained setters, then
/// call `build` (internal bounded channel) or `with_channel` (caller-supplied channel).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct LocalBuilder {
    /// Worker count forwarded to `LocalExecutor::with_channel`.
    /// NOTE(review): its exact meaning is defined by `LocalExecutor` — confirm there.
    workers: usize,
    /// Buffer size of the bounded channel created by `build`; also forwarded to
    /// `LocalExecutor::with_channel`.
    queue_max: usize,
}
48
49impl Default for LocalBuilder {
50    fn default() -> Self {
51        Self {
52            workers: 100,
53            queue_max: 100_000,
54        }
55    }
56}
57
58impl LocalBuilder {
59    #[inline]
60    pub fn workers(mut self, workers: usize) -> Self {
61        self.workers = workers;
62        self
63    }
64
65    #[inline]
66    pub fn queue_max(mut self, queue_max: usize) -> Self {
67        self.queue_max = queue_max;
68        self
69    }
70
71    #[inline]
72    pub fn with_channel<Tx, Rx, D>(self, tx: Tx, rx: Rx) -> ChannelLocalBuilder<Tx, Rx, D>
73        where
74            Tx: Clone + futures::Sink<(D, LocalTaskType)> + Unpin + Sync + 'static,
75            Rx: futures::Stream<Item=(D, LocalTaskType)> + Unpin,
76    {
77        ChannelLocalBuilder {
78            builder: self,
79            tx,
80            rx,
81            _d: std::marker::PhantomData,
82        }
83    }
84
85    #[inline]
86    pub fn build(self) -> (LocalExecutor, impl futures::Future<Output=()>) {
87        let (tx, rx) = futures::channel::mpsc::channel(self.queue_max);
88        LocalExecutor::with_channel(self.workers, self.queue_max, SyncSender(tx), rx)
89    }
90}
91
92pub struct ChannelLocalBuilder<Tx, Rx, D> {
93    builder: LocalBuilder,
94    tx: Tx,
95    rx: Rx,
96    _d: std::marker::PhantomData<D>,
97}
98
99impl<Tx, Rx, D> ChannelLocalBuilder<Tx, Rx, D>
100    where
101        Tx: Clone + futures::Sink<(D, LocalTaskType)> + Unpin + Sync + 'static,
102        Rx: futures::Stream<Item=(D, LocalTaskType)> + Unpin,
103{
104    #[inline]
105    pub fn build(self) -> (LocalExecutor<Tx, (), D>, impl futures::Future<Output=()>) {
106        LocalExecutor::with_channel(
107            self.builder.workers,
108            self.builder.queue_max,
109            self.tx,
110            self.rx,
111        )
112    }
113
114    #[inline]
115    pub fn group(self) -> GroupChannelLocalBuilder<Tx, Rx, D> {
116        GroupChannelLocalBuilder { builder: self }
117    }
118}
119
/// Grouped variant of [`ChannelLocalBuilder`], obtained from `ChannelLocalBuilder::group`.
///
/// Its `build` lets the caller pick the executor's `G` type parameter.
pub struct GroupChannelLocalBuilder<Tx, Rx, D> {
    // Channel builder whose settings and channel halves are consumed by `build`.
    builder: ChannelLocalBuilder<Tx, Rx, D>,
}
123
124impl<Tx, Rx, D> GroupChannelLocalBuilder<Tx, Rx, D>
125    where
126        Tx: Clone + futures::Sink<((), LocalTaskType)> + Unpin + Sync + 'static,
127        Rx: futures::Stream<Item=((), LocalTaskType)> + Unpin,
128{
129    #[inline]
130    pub fn build<G>(self) -> (LocalExecutor<Tx, G>, impl futures::Future<Output=()>)
131        where
132            G: Hash + Eq + Clone + Debug + Sync + 'static,
133    {
134        LocalExecutor::with_channel(
135            self.builder.builder.workers,
136            self.builder.builder.queue_max,
137            self.builder.tx,
138            self.builder.rx,
139        )
140    }
141}
142
/// Item type of the default channel built by `LocalBuilder::build`: a unit task name
/// `()` paired with a `LocalTaskType`.
type DataType = ((), LocalTaskType);
144
/// Newtype around the default channel's `mpsc::Sender`, marked `Sync` by the
/// `unsafe impl` below so it satisfies the executor's `Tx: Sync` bound.
///
/// NOTE(review): presumably needed because `mpsc::Sender<DataType>` is `Sync` only when
/// `DataType: Send`, which `LocalTaskType` may not be — confirm.
#[derive(Clone)]
pub struct SyncSender(pub mpsc::Sender<DataType>);
147
// SAFETY: NOTE(review) — no justification was recorded for this impl.
// Per futures' auto-trait impls, `mpsc::Sender<T>` is already `Sync` when `T: Send`.
// This impl therefore only changes anything if `LocalTaskType` is not `Send`.
// In that case, a `&SyncSender` reached from another thread could be cloned there and
// used to send tasks that are then received on the original thread. That would move a
// non-`Send` value across threads. Confirm that a `SyncSender` is never shared between
// threads in practice, and record that invariant here.
unsafe impl Sync for SyncSender {}
149
150impl Deref for SyncSender {
151    type Target = mpsc::Sender<DataType>;
152    #[inline]
153    fn deref(&self) -> &Self::Target {
154        &self.0
155    }
156}
157
158impl DerefMut for SyncSender {
159    #[inline]
160    fn deref_mut(&mut self) -> &mut Self::Target {
161        &mut self.0
162    }
163}
164
165impl futures::Sink<DataType> for SyncSender {
166    type Error = mpsc::SendError;
167
168    fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
169        self.0.poll_ready(cx)
170    }
171
172    fn start_send(mut self: Pin<&mut Self>, msg: DataType) -> Result<(), Self::Error> {
173        self.0.start_send(msg)
174    }
175
176    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
177        Pin::new(&mut self.0).poll_flush(cx)
178    }
179
180    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
181        Pin::new(&mut self.0).poll_close(cx)
182    }
183}