rustrade_integration/stream/ext/
indexed.rs1use derive_more::Constructor;
2use futures::Stream;
3use pin_project::pin_project;
4use rustrade_instrument::index::error::IndexError;
5use std::{
6 pin::Pin,
7 task::{Context, Poll},
8};
9
10pub trait Indexer {
16 type Unindexed;
17 type Indexed;
18
19 fn index(&self, item: Self::Unindexed) -> Result<Self::Indexed, IndexError>;
21}
22
23#[derive(Debug, Constructor)]
25#[pin_project]
26pub struct IndexedStream<Stream, Indexer> {
27 #[pin]
28 stream: Stream,
29 indexer: Indexer,
30}
31
32impl<St, Index> Stream for IndexedStream<St, Index>
33where
34 St: Stream,
35 Index: Indexer<Unindexed = St::Item>,
36{
37 type Item = Result<Index::Indexed, IndexError>;
38
39 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
40 let this = self.project();
41 match this.stream.poll_next(cx) {
42 Poll::Ready(Some(item)) => Poll::Ready(Some(this.indexer.index(item))),
43 Poll::Ready(None) => Poll::Ready(None),
44 Poll::Pending => Poll::Pending,
45 }
46 }
47}
48
#[cfg(test)]
// `unwrap` is only used on channel sends, where a failure would itself be a test bug.
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    use crate::stream::ext::BarterStreamExt;
    use futures::StreamExt;
    use std::collections::HashMap;
    use tokio::sync::mpsc;
    use tokio_stream::wrappers::UnboundedReceiverStream;
    use tokio_test::{assert_pending, assert_ready};

    #[derive(Debug, Clone)]
    struct UnindexedData {
        key: String,
        value: i32,
    }

    #[derive(Debug, Clone, PartialEq)]
    struct IndexedData {
        index: usize,
        value: i32,
    }

    /// Test indexer that resolves `UnindexedData::key` through a lookup table.
    struct MapIndexer {
        map: HashMap<String, usize>,
    }

    impl Indexer for MapIndexer {
        type Unindexed = UnindexedData;
        type Indexed = IndexedData;

        fn index(&self, item: Self::Unindexed) -> Result<Self::Indexed, IndexError> {
            self.map
                .get(&item.key)
                .map(|&index| IndexedData {
                    index,
                    value: item.value,
                })
                .ok_or_else(|| IndexError::InstrumentIndex(format!("key '{}' not found", item.key)))
        }
    }

    #[tokio::test]
    async fn test_indexed_stream() {
        let waker = futures::task::noop_waker_ref();
        let mut cx = Context::from_waker(waker);

        let (tx, rx) = mpsc::unbounded_channel::<UnindexedData>();
        let rx = UnboundedReceiverStream::new(rx);

        let mut map = HashMap::new();
        map.insert("a".to_string(), 0);
        map.insert("b".to_string(), 1);
        map.insert("c".to_string(), 2);

        let mut stream = rx.with_index(MapIndexer { map });

        assert_pending!(stream.poll_next_unpin(&mut cx));

        tx.send(UnindexedData {
            key: "a".to_string(),
            value: 10,
        })
        .unwrap();
        assert_eq!(
            assert_ready!(stream.poll_next_unpin(&mut cx)),
            Some(Ok(IndexedData {
                index: 0,
                value: 10
            }))
        );

        tx.send(UnindexedData {
            key: "b".to_string(),
            value: 20,
        })
        .unwrap();
        assert_eq!(
            assert_ready!(stream.poll_next_unpin(&mut cx)),
            Some(Ok(IndexedData {
                index: 1,
                value: 20
            }))
        );

        drop(tx);
        assert_eq!(assert_ready!(stream.poll_next_unpin(&mut cx)), None);
    }

    /// An indexing failure must surface as an `Err` item without ending the stream.
    #[tokio::test]
    async fn test_indexed_stream_error_does_not_terminate() {
        let waker = futures::task::noop_waker_ref();
        let mut cx = Context::from_waker(waker);

        let (tx, rx) = mpsc::unbounded_channel::<UnindexedData>();
        let rx = UnboundedReceiverStream::new(rx);

        let map = HashMap::from([("a".to_string(), 0)]);
        let mut stream = rx.with_index(MapIndexer { map });

        tx.send(UnindexedData {
            key: "missing".to_string(),
            value: 5,
        })
        .unwrap();
        assert_eq!(
            assert_ready!(stream.poll_next_unpin(&mut cx)),
            Some(Err(IndexError::InstrumentIndex(
                "key 'missing' not found".to_string()
            )))
        );

        // The stream keeps yielding after an error.
        tx.send(UnindexedData {
            key: "a".to_string(),
            value: 30,
        })
        .unwrap();
        assert_eq!(
            assert_ready!(stream.poll_next_unpin(&mut cx)),
            Some(Ok(IndexedData {
                index: 0,
                value: 30
            }))
        );

        drop(tx);
        assert_eq!(assert_ready!(stream.poll_next_unpin(&mut cx)), None);
    }
}