Skip to main content

tokio_simplified/
lib.rs

1extern crate futures_promises;
2extern crate tokio;
3
4use futures::sync::mpsc::Sender;
5use futures::{future::Future, sink::Sink, stream::Stream, sync::mpsc::channel};
6use std::collections::HashMap;
7use std::sync::{Arc, Mutex};
8
9use futures_promises::promises::{Promise, PromiseHandle};
10
// Capacity of every internal mpsc channel created by this crate, chosen at compile time.
// `very_big_channels` wins over `big_channels`; with neither feature enabled it is 16.
#[cfg(feature = "very_big_channels")]
const CHANNEL_SIZE: usize = 256;
#[cfg(all(feature = "big_channels", not(feature = "very_big_channels")))]
const CHANNEL_SIZE: usize = 64;
#[cfg(all(not(feature = "big_channels"), not(feature = "very_big_channels")))]
const CHANNEL_SIZE: usize = 16;
17
18/// A simple interface to interact with a tokio sink.
19///
20/// Should always be constructed by a call to some IoManager's get_writer().
21pub struct IoWriter<SendType> {
22    tx: futures::sync::mpsc::Sender<SendType>,
23}
24
25impl<SendType> Clone for IoWriter<SendType> {
26    fn clone(&self) -> Self {
27        IoWriter {
28            tx: self.tx.clone(),
29        }
30    }
31}
32
33impl<SendType> IoWriter<SendType>
34where
35    SendType: std::marker::Send + 'static,
36{
37    pub fn new<SinkType>(sink: SinkType) -> Self
38    where
39        SinkType: Sink<SinkItem = SendType> + Send + 'static,
40    {
41        let (tx, sink_rx) = channel::<<SinkType as Sink>::SinkItem>(CHANNEL_SIZE);
42        let sink_task = sink_rx.forward(sink.sink_map_err(|_| ())).map(|_| ());
43        tokio::spawn(sink_task);
44        IoWriter { tx }
45    }
46
47    /// Forwards the frame to the tokio sink associated with the IoManager that build this instance.
48    pub fn write<T: Into<SendType>>(&mut self, frame: T) -> PromiseHandle<()> {
49        let promise = Promise::new();
50        let handle = promise.get_handle();
51        tokio::spawn(self.tx.clone().send(frame.into()).then(move |result| {
52            match result {
53                Ok(_) => promise.resolve(()),
54                Err(e) => {
55                    promise.reject(format!("{}", e));
56                }
57            };
58            Ok::<(), ()>(())
59        }));
60        handle
61    }
62}
63
64pub trait Filter<SendType, ReceiveType>:
65    FnMut(ReceiveType, &mut IoWriter<SendType>) -> Option<ReceiveType> + std::marker::Send + 'static
66{
67}
68
69impl<T, SendType, ReceiveType> Filter<SendType, ReceiveType> for T where
70    T: FnMut(ReceiveType, &mut IoWriter<SendType>) -> Option<ReceiveType>
71        + std::marker::Send
72        + 'static
73{
74}
75
/// Callback that receives the stream's error when reading the stream fails.
///
/// Every `Send + 'static` closure taking the error by value implements it automatically.
pub trait ErrorHandler<ErrorType>: FnMut(ErrorType) + Send + 'static {}

impl<Handler, ErrorType> ErrorHandler<ErrorType> for Handler where
    Handler: FnMut(ErrorType) + Send + 'static
{
}
82
83/// A builder for `IoManager`, and the only way to build one since the constructors have been deleted.
84pub struct IoManagerBuilder<
85    SinkType,
86    StreamType,
87    BF = (fn(
88        <StreamType as Stream>::Item,
89        &mut IoWriter<<SinkType as Sink>::SinkItem>,
90    ) -> Option<<StreamType as Stream>::Item>),
91    BEH = (fn(<StreamType as Stream>::Error)),
92> where
93    SinkType: Sink,
94    StreamType: Stream,
95    BF: FnMut(
96            <StreamType as Stream>::Item,
97            &mut IoWriter<<SinkType as Sink>::SinkItem>,
98        ) -> Option<<StreamType as Stream>::Item>
99        + std::marker::Send
100        + 'static,
101    BEH: FnMut(<StreamType as Stream>::Error) + std::marker::Send + 'static,
102{
103    sink: SinkType,
104    stream: StreamType,
105    filter: Option<BF>,
106    error_handler: Option<BEH>,
107}
108
109type DefaultFilterType<SinkType, StreamType> = (fn(
110    <StreamType as Stream>::Item,
111    &mut IoWriter<<SinkType as Sink>::SinkItem>,
112) -> Option<<StreamType as Stream>::Item>);
113type DefaultErrorHandlerType<StreamType> = (fn(<StreamType as Stream>::Error));
114
115impl<SinkType, StreamType> IoManagerBuilder<SinkType, StreamType>
116where
117    SinkType: Sink + Send + 'static,
118    StreamType: Stream + Send + 'static,
119    <StreamType as Stream>::Item: Send + Clone + 'static,
120    <StreamType as Stream>::Error: Send,
121    <SinkType as Sink>::SinkItem: Send + 'static,
122{
123    /// Creates a builder for `IoManager`.
124    pub fn new(
125        sink: SinkType,
126        stream: StreamType,
127    ) -> IoManagerBuilder<
128        SinkType,
129        StreamType,
130        DefaultFilterType<SinkType, StreamType>,
131        DefaultErrorHandlerType<StreamType>,
132    > {
133        IoManagerBuilder {
134            sink,
135            stream,
136            filter: None,
137            error_handler: None,
138        }
139    }
140}
141
142impl<SinkType, StreamType, FilterType, ErrorHandlerType>
143    IoManagerBuilder<SinkType, StreamType, FilterType, ErrorHandlerType>
144where
145    SinkType: Sink + Send + 'static,
146    StreamType: Stream + Send + 'static,
147    <StreamType as Stream>::Item: Send + Clone + 'static,
148    <StreamType as Stream>::Error: Send,
149    <SinkType as Sink>::SinkItem: Send + 'static,
150    FilterType: Filter<<SinkType as Sink>::SinkItem, <StreamType as Stream>::Item>,
151    ErrorHandlerType: ErrorHandler<<StreamType as Stream>::Error>,
152{
153    /// Adds a filter to the `IoManager` builder.
154    /// Filters are static in this library. If you need to be able to change the filter without
155    /// droping the sink and steram passed to this instance, you should probably use Box to encapsulate your filter,
156    /// and then whatever you need to make it all thread safe for when you'll need to modify it.
157    /// Type inference should still work, which is nice.
158    pub fn with_filter<NewFilterType>(
159        self,
160        filter: NewFilterType,
161    ) -> IoManagerBuilder<SinkType, StreamType, NewFilterType, ErrorHandlerType>
162    where
163        NewFilterType: Filter<<SinkType as Sink>::SinkItem, <StreamType as Stream>::Item>,
164    {
165        IoManagerBuilder {
166            sink: self.sink,
167            stream: self.stream,
168            filter: Some(filter),
169            error_handler: self.error_handler,
170        }
171    }
172
173    /// Similar to `with_filter`, only for error handling.
174    /// Tip: if you want to be able to catch end of streams with this API,
175    /// you may want your Codec to implement `decode_eof()` and throw an error at the last moment,
176    /// such as this:
177    /// ```rust
178    /// fn decode_eof(&mut self, buf: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
179    ///     let decode = self.decode(buf);
180    ///     match decode {
181    ///         Ok(None) => Err(std::io::Error::from(std::io::ErrorKind::ConnectionAborted)),
182    ///         _ => decode,
183    ///     }
184    /// }
185    /// ```
186    pub fn with_error_handler<NewErrorHandlerType>(
187        self,
188        handler: NewErrorHandlerType,
189    ) -> IoManagerBuilder<SinkType, StreamType, FilterType, NewErrorHandlerType>
190    where
191        NewErrorHandlerType: ErrorHandler<<StreamType as Stream>::Error>,
192    {
193        IoManagerBuilder {
194            sink: self.sink,
195            stream: self.stream,
196            filter: self.filter,
197            error_handler: Some(handler),
198        }
199    }
200
201    pub fn build(self) -> IoManager<SinkType::SinkItem, StreamType::Item> {
202        IoManager::<SinkType::SinkItem, StreamType::Item>::constructor(
203            self.sink,
204            self.stream,
205            self.filter,
206            self.error_handler,
207        )
208    }
209}
210
211/// A simplified interface to interact with tokio's streams and sinks.
212///
213/// Allows easy subscription to the stream's frames, and easy sending to the sink.
214#[derive(Clone)]
215pub struct IoManager<SendType, ReceiveType = SendType> {
216    tx: futures::sync::mpsc::Sender<SendType>,
217    subscribers: Arc<Mutex<HashMap<u32, futures::sync::mpsc::Sender<ReceiveType>>>>,
218    next_handle: Arc<Mutex<u32>>,
219}
220
221impl<SendType, ReceiveType> IoManager<SendType, ReceiveType> {
222    fn constructor<SinkType, StreamType, F, EH>(
223        sink: SinkType,
224        stream: StreamType,
225        mut filter: Option<F>,
226        error_handler: Option<EH>,
227    ) -> IoManager<SinkType::SinkItem, StreamType::Item>
228    where
229        SinkType: Sink + Send + 'static,
230        StreamType: Stream + Send + 'static,
231        <StreamType as Stream>::Item: Send + Clone + 'static,
232        <StreamType as Stream>::Error: Send,
233        <SinkType as Sink>::SinkItem: Send + 'static,
234        F: Filter<SinkType::SinkItem, StreamType::Item>,
235        EH: ErrorHandler<StreamType::Error>,
236    {
237        let (sink_tx, sink_rx) = channel::<<SinkType as Sink>::SinkItem>(CHANNEL_SIZE);
238        let sink_task = sink_rx.forward(sink.sink_map_err(|_| ())).map(|_| ());
239        tokio::spawn(sink_task);
240        let mut filter_writer = IoWriter {
241            tx: sink_tx.clone(),
242        };
243
244        let subscribers = Arc::new(Mutex::new(HashMap::<
245            u32,
246            futures::sync::mpsc::Sender<<StreamType as Stream>::Item>,
247        >::new()));
248        let stream_subscribers_reference = subscribers.clone();
249        let stream_task = stream
250            .for_each(move |frame| {
251                let frame = match &mut filter {
252                    None => Some(frame),
253                    Some(function) => function(frame, &mut filter_writer),
254                };
255                match frame {
256                    Some(frame) => {
257                        for (_handle, tx) in stream_subscribers_reference.lock().unwrap().iter_mut()
258                        {
259                            match tx.start_send(frame.clone()) {
260                                Ok(_) => {}
261                                Err(error) => {
262                                    eprintln!("Stream Subscriber Error: {}", error);
263                                }
264                            };
265                        }
266                    }
267                    None => {}
268                }
269                Ok(())
270            })
271            .map_err(|e| match error_handler {
272                None => (),
273                Some(mut handler) => handler(e),
274            });
275        tokio::spawn(stream_task);
276        IoManager {
277            tx: sink_tx,
278            subscribers,
279            next_handle: Arc::new(Mutex::new(0)),
280        }
281    }
282
283    /// `subscriber` will receive any data polled from the internal stream.
284    pub fn subscribe_mpsc_sender(
285        &self,
286        subscriber: futures::sync::mpsc::Sender<ReceiveType>,
287    ) -> u32 {
288        let mut map = self.subscribers.lock().unwrap();
289        let mut handle_guard = self.next_handle.lock().unwrap();
290        let handle = handle_guard.clone();
291        *handle_guard += 1;
292        map.insert(handle.clone(), subscriber);
293        handle
294    }
295
296    /// `callback` will be called for each `frame` polled from the internal stream.
297    pub fn on_receive<F>(&self, callback: F) -> u32
298    where
299        F: FnMut(ReceiveType) -> Result<(), ()> + std::marker::Send + 'static,
300        ReceiveType: std::marker::Send + 'static,
301    {
302        let (tx, rx) = channel::<ReceiveType>(CHANNEL_SIZE);
303        let on_frame = rx.for_each(callback).map(|_| ());
304        tokio::spawn(on_frame);
305        self.subscribe_mpsc_sender(tx)
306    }
307
308    /// Removes the callback with `key`handle. `key` should be a value returned by either
309    /// `on_receive()` or `subscribe_mpsc_sender()`.
310    ///
311    /// Returns the `mpsc::Sender` that used to be notified upon new frames, just in case.
312    pub fn extract_callback(&self, key: &u32) -> Option<Sender<ReceiveType>> {
313        let mut map = self.subscribers.lock().unwrap();
314        map.remove(key)
315    }
316
317    /// Returns an `IoWriter` that will forward data to the associated tokio sink.
318    pub fn get_writer(&self) -> IoWriter<SendType> {
319        IoWriter {
320            tx: self.tx.clone(),
321        }
322    }
323}
324
325/// Inspired by bkwilliams, these aliases will probably stay, but you shouldn't rely too much on them
326pub mod silly_aliases {
327    pub type DoWhenever<T, U> = crate::IoManager<T, U>;
328    pub type PushWhenever<T> = crate::IoWriter<T>;
329}