Skip to main content

workflow_core/
lookup.rs

//!
//! [`LookupHandler`] provides the ability to queue multiple async requests for the same key
//! into a group of futures that resolve upon request completion.
//!
//! This functionality is useful when a client may be making multiple requests
//! for data that is not yet available and may need to be fetched over a transport
//! that can take time (such as network I/O). Each async request for the same
//! key is queued into a set of futures, all of which resolve once
//! the initial request is resolved.
//!
11
12#![allow(unused)]
13
14use crate::channel::*;
15use std::collections::HashMap;
16use std::hash::Hash;
17use std::sync::Arc;
18use std::sync::Mutex;
19use std::sync::atomic::{AtomicUsize, Ordering};
20
/// Outcome delivered by [`LookupHandler`]: the looked-up value `V` on success,
/// or the error `E` on failure.
pub type LookupResult<V, E> = core::result::Result<V, E>;
23/// Outcome of queuing a lookup request, indicating whether the caller initiated
24/// a new lookup or joined an already-pending one. Both variants carry a receiver
25/// that resolves once the lookup completes.
26pub enum RequestType<V, E> {
27    /// No lookup for this key was pending; the caller is responsible for performing it.
28    New(Receiver<LookupResult<V, E>>),
29    /// A lookup for this key is already in progress; the caller merely awaits its result.
30    Pending(Receiver<LookupResult<V, E>>),
31}
32
33/// List of channel senders awaiting for the same key lookup.
34pub type SenderList<V, E> = Vec<Sender<LookupResult<V, E>>>;
35
36///
37/// [`LookupHandler`] provides ability to queue multiple async requests for the same key
38/// into a group of futures that resolve upon request completion.
39///
40/// To use [`LookupHandler`], you need to create a custom lookup function. The example below
41/// declares a function `lookup()` that uses [`LookupHandler`] to queue requests
42/// and if there are no pending requests (request is new) performs the actual
43/// request by calling `lookup_impl()`. The [`LookupHandler::complete()`] will
44/// resolve all pending futures for the specific key.
45///
46/// Example:
47/// ```ignore
48/// ...
49/// pub lookup_handler : LookupHandler<Pubkey,Arc<Data>,Error>
50/// ...
51/// async fn lookup(&self, pubkey:&Pubkey) -> Result<Option<Arc<Data>>> {
52///     let request_type = self.lookup_handler.queue(pubkey).await;
53///     let result = match request_type {
54///         RequestType::New(receiver) => {
55///             // execute the actual lookup
56///             let response = self.lookup_impl(pubkey).await;
57///             // signal completion for all awaiting futures
58///             lookup_handler.complete(pubkey, response).await;
59///             // this request is queued like all the others
60///             // so wait for your own notification as well
61///             receiver.recv().await?
62///         },
63///         RequestType::Pending(receiver) => {
64///             receiver.recv().await?
65///         }
66///     }
67/// };
68/// ```
69pub struct LookupHandler<K, V, E> {
70    /// Map of in-flight lookups, associating each key with the senders awaiting its result.
71    pub map: Arc<Mutex<HashMap<K, SenderList<V, E>>>>,
72    pending: AtomicUsize,
73}
74
75/// Default trait for the LookupHandler
76impl<K, V, E> Default for LookupHandler<K, V, E>
77where
78    V: Clone,
79    K: Clone + Eq + Hash + std::fmt::Debug,
80    E: Clone,
81{
82    fn default() -> Self {
83        LookupHandler::new()
84    }
85}
86
87impl<K, V, E> LookupHandler<K, V, E>
88where
89    V: Clone,
90    K: Clone + Eq + Hash + std::fmt::Debug,
91    E: Clone,
92{
93    /// Create a new instance of the LookupHandler
94    pub fn new() -> Self {
95        LookupHandler {
96            map: Arc::new(Mutex::new(HashMap::new())),
97            pending: AtomicUsize::new(0),
98        }
99    }
100
101    /// Returns the total number of pending requests
102    pub fn pending(&self) -> usize {
103        self.pending.load(Ordering::SeqCst)
104    }
105
106    /// Queue the request for key `K`. Returns [`RequestType::New`] if
107    /// no other requests for the same key are pending and [`RequestType::Pending`]
108    /// if there are pending requests. Both [`RequestType`] values contain a [[`async_channel::Receiver`]]
109    /// that can be listened to for lookup completion. Lookup completion
110    /// can be signaled by [`LookupHandler::complete()`]
111    pub async fn queue(&self, key: &K) -> RequestType<V, E> {
112        let mut pending = self.map.lock().unwrap();
113        let (sender, receiver) = oneshot::<LookupResult<V, E>>();
114
115        if let Some(list) = pending.get_mut(key) {
116            list.push(sender);
117            RequestType::Pending(receiver)
118        } else {
119            pending.insert(key.clone(), vec![sender]);
120            self.pending.fetch_add(1, Ordering::Relaxed);
121            RequestType::New(receiver)
122        }
123    }
124
125    /// Signal the lookup completion for key `K` by supplying a [`LookupResult`]
126    /// with a resulting value `V` or an error `E`.
127    pub async fn complete(&self, key: &K, result: LookupResult<V, E>) {
128        let list = { self.map.lock().unwrap().remove(key) };
129
130        if let Some(list) = list {
131            self.pending.fetch_sub(1, Ordering::Relaxed);
132            for sender in list {
133                sender
134                    .send(result.clone())
135                    .await
136                    .expect("Unable to complete lookup result");
137            }
138        } else {
139            panic!("Lookup handler failure while processing key {key:?}")
140        }
141    }
142}
143
#[cfg(not(target_arch = "bpf"))]
#[cfg(any(test, feature = "test"))]
mod tests {
    use super::LookupHandler;
    use super::RequestType;
    use std::sync::Arc;
    use std::sync::Mutex;
    use std::sync::PoisonError;
    use std::time::Duration;

    use crate::task::sleep;
    use futures::join;
    use std::collections::HashMap;
    use workflow_core::channel::RecvError;

    /// Error type used by the test fixture.
    #[derive(thiserror::Error, Debug, Clone)]
    pub enum Error {
        #[error("{0}")]
        String(String),
    }

    impl<T> From<PoisonError<T>> for Error {
        fn from(_: PoisonError<T>) -> Self {
            Error::String("PoisonError".to_string())
        }
    }

    impl From<RecvError> for Error {
        fn from(_: RecvError) -> Self {
            Error::String("RecvError".to_string())
        }
    }

    type Result<T> = std::result::Result<T, Error>;

    /// Records which [`RequestType`] variant each simulated request received.
    #[derive(Debug, Eq, PartialEq)]
    enum RequestTypeTest {
        New = 0,
        Pending = 1,
    }

    /// Test fixture: a [`LookupHandler`] in front of a slow in-memory "remote" map.
    struct LookupHandlerTest {
        pub lookup_handler: LookupHandler<u32, Option<u32>, Error>,
        pub map: Arc<Mutex<HashMap<u32, u32>>>,
        pub request_types: Arc<Mutex<Vec<RequestTypeTest>>>,
    }

    impl LookupHandlerTest {
        pub fn new() -> Self {
            Self {
                lookup_handler: LookupHandler::new(),
                map: Arc::new(Mutex::new(HashMap::new())),
                request_types: Arc::new(Mutex::new(Vec::new())),
            }
        }

        /// Seed the simulated remote store with `key -> value`.
        pub fn insert(self: &Arc<Self>, key: u32, value: u32) -> Result<()> {
            self.map.lock()?.insert(key, value);
            Ok(())
        }

        /// Simulated remote lookup; the delay makes concurrent requests overlap.
        pub async fn lookup_remote_impl(self: &Arc<Self>, key: &u32) -> Result<Option<u32>> {
            sleep(Duration::from_millis(100)).await;
            let map = self.map.lock()?;
            Ok(map.get(key).cloned())
        }

        /// Lookup through the handler, recording whether the request was new or pending.
        pub async fn lookup_handler_request(self: &Arc<Self>, key: &u32) -> Result<Option<u32>> {
            match self.lookup_handler.queue(key).await {
                RequestType::New(receiver) => {
                    self.request_types.lock()?.push(RequestTypeTest::New);
                    let response = self.lookup_remote_impl(key).await;
                    self.lookup_handler.complete(key, response).await;
                    // The initiator is queued like every other waiter, so it also
                    // collects the result from its own receiver.
                    receiver.recv().await?
                }
                RequestType::Pending(receiver) => {
                    self.request_types.lock()?.push(RequestTypeTest::Pending);
                    receiver.recv().await?
                }
            }
        }
    }

    pub async fn lookup_handler_test() -> Result<()> {
        let lht = Arc::new(LookupHandlerTest::new());
        lht.insert(0xc0fee, 0xdecaf)?;

        // Three concurrent requests for the same key: `join!` polls them in order, so
        // the first performs the lookup and the other two join it as pending.
        let v0 = lht.lookup_handler_request(&0xc0fee);
        let v1 = lht.lookup_handler_request(&0xc0fee);
        let v2 = lht.lookup_handler_request(&0xc0fee);
        let (r0, r1, r2) = join!(v0, v1, v2);

        println!("[lh] results: {:?}", (&r0, &r1, &r2));
        let values = (r0.unwrap().unwrap(), r1.unwrap().unwrap(), r2.unwrap().unwrap());
        assert_eq!(values, (0xdecaf, 0xdecaf, 0xdecaf));

        let request_types = lht.request_types.lock()?;
        println!("[lh] request types: {:?}", request_types);
        assert_eq!(
            request_types[..],
            [
                RequestTypeTest::New,
                RequestTypeTest::Pending,
                RequestTypeTest::Pending
            ]
        );
        println!("all looks good ... 😎");

        Ok(())
    }

    #[cfg(not(any(target_arch = "wasm32", target_arch = "bpf")))]
    #[cfg(test)]
    mod native_tests {
        use super::*;

        #[tokio::test]
        pub async fn lookup_handler_test() -> Result<()> {
            super::lookup_handler_test().await
        }
    }
}