Skip to main content

rustrade_integration/stream/ext/
indexed.rs

1use derive_more::Constructor;
2use futures::Stream;
3use pin_project::pin_project;
4use rustrade_instrument::index::error::IndexError;
5use std::{
6    pin::Pin,
7    task::{Context, Poll},
8};
9
10/// Type that indexes data structures.
11///
12/// An example `Indexer` use case is "keying" an event: <br>
13/// Unindexed = MarketEvent<MarketDataInstrument, DataKind> <br>
14/// Indexed = MarketEvent<InstrumentIndex, DataKind>
15pub trait Indexer {
16    type Unindexed;
17    type Indexed;
18
19    /// Index the input.
20    fn index(&self, item: Self::Unindexed) -> Result<Self::Indexed, IndexError>;
21}
22
23/// Stream adapter that indexes items using an [`Indexer`].
24#[derive(Debug, Constructor)]
25#[pin_project]
26pub struct IndexedStream<Stream, Indexer> {
27    #[pin]
28    stream: Stream,
29    indexer: Indexer,
30}
31
32impl<St, Index> Stream for IndexedStream<St, Index>
33where
34    St: Stream,
35    Index: Indexer<Unindexed = St::Item>,
36{
37    type Item = Result<Index::Indexed, IndexError>;
38
39    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
40        let this = self.project();
41        match this.stream.poll_next(cx) {
42            Poll::Ready(Some(item)) => Poll::Ready(Some(this.indexer.index(item))),
43            Poll::Ready(None) => Poll::Ready(None),
44            Poll::Pending => Poll::Pending,
45        }
46    }
47}
48
#[cfg(test)]
#[allow(clippy::unwrap_used)] // Test code: panics on bad input are acceptable
mod tests {
    use super::*;
    use crate::stream::ext::BarterStreamExt;
    use futures::StreamExt;
    use std::collections::HashMap;
    use tokio::sync::mpsc;
    use tokio_stream::wrappers::UnboundedReceiverStream;
    use tokio_test::{assert_pending, assert_ready};

    #[derive(Debug, Clone)]
    struct UnindexedData {
        key: String,
        value: i32,
    }

    #[derive(Debug, Clone, PartialEq)]
    struct IndexedData {
        index: usize,
        value: i32,
    }

    /// Test `Indexer` that resolves `UnindexedData::key` through a lookup table.
    struct MapIndexer {
        map: HashMap<String, usize>,
    }

    impl Indexer for MapIndexer {
        type Unindexed = UnindexedData;
        type Indexed = IndexedData;

        fn index(&self, item: Self::Unindexed) -> Result<Self::Indexed, IndexError> {
            self.map
                .get(&item.key)
                .map(|&index| IndexedData {
                    index,
                    value: item.value,
                })
                .ok_or_else(|| IndexError::InstrumentIndex(format!("key '{}' not found", item.key)))
        }
    }

    /// Indexer shared by the tests: "a" -> 0, "b" -> 1, "c" -> 2.
    fn test_indexer() -> MapIndexer {
        MapIndexer {
            map: HashMap::from([
                ("a".to_string(), 0),
                ("b".to_string(), 1),
                ("c".to_string(), 2),
            ]),
        }
    }

    /// Shorthand constructor for an input item.
    fn unindexed(key: &str, value: i32) -> UnindexedData {
        UnindexedData {
            key: key.to_string(),
            value,
        }
    }

    #[tokio::test]
    async fn test_indexed_stream() {
        let waker = futures::task::noop_waker_ref();
        let mut cx = Context::from_waker(waker);

        let (tx, rx) = mpsc::unbounded_channel::<UnindexedData>();
        let mut stream = UnboundedReceiverStream::new(rx).with_index(test_indexer());

        assert_pending!(stream.poll_next_unpin(&mut cx));

        tx.send(unindexed("a", 10)).unwrap();
        assert_eq!(
            assert_ready!(stream.poll_next_unpin(&mut cx)),
            Some(Ok(IndexedData {
                index: 0,
                value: 10
            }))
        );

        tx.send(unindexed("b", 20)).unwrap();
        assert_eq!(
            assert_ready!(stream.poll_next_unpin(&mut cx)),
            Some(Ok(IndexedData {
                index: 1,
                value: 20
            }))
        );

        drop(tx);
        assert_eq!(assert_ready!(stream.poll_next_unpin(&mut cx)), None);
    }

    #[tokio::test]
    async fn test_indexed_stream_unknown_key_yields_err_and_continues() {
        let waker = futures::task::noop_waker_ref();
        let mut cx = Context::from_waker(waker);

        let (tx, rx) = mpsc::unbounded_channel::<UnindexedData>();
        let mut stream = UnboundedReceiverStream::new(rx).with_index(test_indexer());

        // An unknown key surfaces as an `Err` item...
        tx.send(unindexed("missing", 1)).unwrap();
        assert_eq!(
            assert_ready!(stream.poll_next_unpin(&mut cx)),
            Some(Err(IndexError::InstrumentIndex(
                "key 'missing' not found".to_string()
            )))
        );

        // ...and does not end the stream: later items are still indexed.
        tx.send(unindexed("c", 30)).unwrap();
        assert_eq!(
            assert_ready!(stream.poll_next_unpin(&mut cx)),
            Some(Ok(IndexedData {
                index: 2,
                value: 30
            }))
        );

        drop(tx);
        assert_eq!(assert_ready!(stream.poll_next_unpin(&mut cx)), None);
    }
}