// common/cache.rs

1use std::collections::hash_map::Entry;
2use std::collections::HashMap;
3use std::error::Error;
4use std::fmt::Debug;
5use std::hash::Hash;
6use std::sync::{Arc, Mutex};
7
8use tokio::sync::Mutex as AsyncMutex;
9use tokio::time::{Duration, Instant};
10
11use crate::io::Stream;
12
/// Requirements for a key of a [`StreamsCache`]: hashable and comparable for the map lookup,
/// and `Debug` so lookups can be logged.
pub trait CacheKey: Eq + Hash + Debug {}

/// Blanket impl: any type meeting the bounds is automatically a `CacheKey`.
impl<T> CacheKey for T where T: Eq + Hash + Debug {}

16pub trait StreamCreator: Fn() -> Result<Stream, Box<dyn Error>> + Send + Sync + 'static {}
17impl<T: Fn() -> Result<Stream, Box<dyn Error>> + Send + Sync + 'static> StreamCreator for T {}
18
19struct CacheEntry {
20    stream: Arc<AsyncMutex<Stream>>,
21    last_activity: Instant,
22}
23
/// Cleanup interval used by `StreamsCache::with_default_cleanup_duration` (one minute).
const DEFAULT_CLEANUP_INTERVAL: Duration = Duration::from_millis(60_000);

26pub struct StreamsCache<F: StreamCreator, K: CacheKey> {
27    new_stream_creator: F,
28    entries: Mutex<HashMap<K, CacheEntry>>,
29    idle_entry_timeout: Duration,
30    cleanup_interval: Duration,
31    last_cleanup: Mutex<Instant>,
32}
33
34impl<F: StreamCreator, K: CacheKey> StreamsCache<F, K> {
35    pub fn new(
36        new_stream_creator: F,
37        idle_entry_timeout: Duration,
38        cleanup_interval: Duration,
39    ) -> Self {
40        Self {
41            new_stream_creator,
42            entries: Mutex::new(HashMap::new()),
43            idle_entry_timeout,
44            cleanup_interval,
45            last_cleanup: Mutex::new(Instant::now()),
46        }
47    }
48
49    pub fn with_default_cleanup_duration(
50        new_stream_creator: F,
51        idle_entry_timeout: Duration,
52    ) -> Self {
53        StreamsCache::new(
54            new_stream_creator,
55            idle_entry_timeout,
56            DEFAULT_CLEANUP_INTERVAL,
57        )
58    }
59
60    pub fn get(&self, key: K, now: Instant) -> Result<Arc<AsyncMutex<Stream>>, Box<dyn Error>> {
61        let res = self.get_or_create_stream(key, now);
62        let mut last_cleanup = self.last_cleanup.lock().unwrap();
63        if now - *last_cleanup > self.cleanup_interval {
64            self.cleanup_old_idle_streams(now);
65            *last_cleanup = now;
66        };
67        res
68    }
69
70    fn get_or_create_stream(
71        &self,
72        key: K,
73        now: Instant,
74    ) -> Result<Arc<AsyncMutex<Stream>>, Box<dyn Error>> {
75        let new_stream_creator = &self.new_stream_creator;
76        let mut entries = self.entries.lock().unwrap();
77        log::debug!("got key {:?}, entries size is {}", key, entries.len());
78        let entry = match entries.entry(key) {
79            Entry::Occupied(o) => o.into_mut(),
80            Entry::Vacant(v) => v.insert(CacheEntry {
81                stream: Arc::new(AsyncMutex::new(new_stream_creator()?)),
82                last_activity: now,
83            }),
84        };
85        entry.last_activity = now;
86        Ok(entry.stream.clone())
87    }
88
89    fn cleanup_old_idle_streams(&self, now: Instant) {
90        let mut entries = self.entries.lock().unwrap();
91        entries.retain(|_, v| now - v.last_activity < self.idle_entry_timeout);
92    }
93}
94
#[cfg(test)]
mod tests {
    use tokio::time::Duration;
    use tokio_test::io::Builder;

    use super::*;

    /// Builds a stream over mock I/O with no scripted actions.
    fn mock_stream() -> Stream {
        Stream::new(Builder::new().build(), Builder::new().build())
    }

    /// Builds a cache entry around a fresh mock stream.
    fn mock_entry(last_activity: Instant) -> CacheEntry {
        CacheEntry {
            stream: Arc::new(AsyncMutex::new(mock_stream())),
            last_activity,
        }
    }

    #[test]
    fn stream_cache_get_new_stream_creator_failed() -> Result<(), Box<dyn Error>> {
        let cache = StreamsCache::new(
            || Err(String::from("bla").into()),
            Duration::from_secs(3 * 60),
            Duration::from_secs(60),
        );
        let res = cache.get("bla", Instant::now());
        assert!(res.is_err());
        // A failed creation must not leave an entry behind.
        assert!(cache.entries.lock().unwrap().is_empty());
        Ok(())
    }

    #[test]
    fn stream_cache_get_new_stream_success() -> Result<(), Box<dyn Error>> {
        let cache = StreamsCache::new(
            || Ok(mock_stream()),
            Duration::from_secs(3 * 60),
            Duration::from_secs(60),
        );
        // Compare against the construction-time value, not `now`: two consecutive
        // `Instant::now()` calls may return the same instant, which made `assert_ne!` flaky.
        let initial_cleanup = *cache.last_cleanup.lock().unwrap();

        let now = Instant::now();
        cache.get("bla", now)?;

        let entries = cache.entries.lock().unwrap();
        assert_eq!(entries.len(), 1);
        assert!(entries.contains_key("bla"));
        assert_eq!(entries.get("bla").unwrap().last_activity, now);
        assert_eq!(*cache.last_cleanup.lock().unwrap(), initial_cleanup);
        Ok(())
    }

    #[test]
    fn stream_cache_get_new_stream_another_exists() -> Result<(), Box<dyn Error>> {
        let t1 = Instant::now();
        let cache = StreamsCache::new(
            || Ok(mock_stream()),
            Duration::from_secs(3 * 60),
            Duration::from_secs(60),
        );
        cache.entries.lock().unwrap().insert("bli", mock_entry(t1));
        let initial_cleanup = *cache.last_cleanup.lock().unwrap();

        let t2 = t1 + Duration::from_secs(1);
        cache.get("bla", t2)?;

        let entries = cache.entries.lock().unwrap();
        assert_eq!(entries.len(), 2);
        assert!(entries.contains_key("bla"));
        assert_eq!(entries.get("bla").unwrap().last_activity, t2);
        assert!(entries.contains_key("bli"));
        assert_eq!(entries.get("bli").unwrap().last_activity, t1);
        assert_eq!(*cache.last_cleanup.lock().unwrap(), initial_cleanup);
        Ok(())
    }

    #[test]
    fn stream_cache_get_existing_stream_success() -> Result<(), Box<dyn Error>> {
        let start = Instant::now();
        // The creator always fails, so a successful `get` proves the cached entry was used.
        let cache = StreamsCache::new(
            || Err(String::from("bla").into()),
            Duration::from_secs(3 * 60),
            Duration::from_secs(60),
        );
        let cached = mock_entry(start);
        let cached_stream = Arc::clone(&cached.stream);
        cache.entries.lock().unwrap().insert("bla", cached);
        let initial_cleanup = *cache.last_cleanup.lock().unwrap();

        let now = start + Duration::from_secs(1);
        let stream = cache.get("bla", now)?;
        assert!(Arc::ptr_eq(&stream, &cached_stream));

        let entries = cache.entries.lock().unwrap();
        assert_eq!(entries.len(), 1);
        assert!(entries.contains_key("bla"));
        assert_eq!(entries.get("bla").unwrap().last_activity, now);
        assert_eq!(*cache.last_cleanup.lock().unwrap(), initial_cleanup);
        Ok(())
    }

    #[test]
    fn stream_cache_get_same_key_returns_same_stream() -> Result<(), Box<dyn Error>> {
        let cache = StreamsCache::new(
            || Ok(mock_stream()),
            Duration::from_secs(3 * 60),
            Duration::from_secs(60),
        );
        let now = Instant::now();
        let first = cache.get("bla", now)?;
        let second = cache.get("bla", now + Duration::from_secs(1))?;
        assert!(Arc::ptr_eq(&first, &second));
        assert_eq!(cache.entries.lock().unwrap().len(), 1);
        Ok(())
    }

    #[test]
    fn stream_cache_get_auto_cleanup() -> Result<(), Box<dyn Error>> {
        let start = Instant::now();
        let cache = StreamsCache::new(
            || Ok(mock_stream()),
            Duration::from_secs(1),
            Duration::from_secs(1),
        );
        cache.entries.lock().unwrap().insert("bli", mock_entry(start));

        let now = start + Duration::from_secs(5);
        cache.get("bla", now)?;

        let entries = cache.entries.lock().unwrap();
        assert_eq!(entries.len(), 1);
        assert!(entries.contains_key("bla"));
        assert_eq!(entries.get("bla").unwrap().last_activity, now);
        assert_eq!(*cache.last_cleanup.lock().unwrap(), now);
        Ok(())
    }

    #[test]
    fn stream_cache_cleanup_keeps_recently_active_entries() -> Result<(), Box<dyn Error>> {
        let start = Instant::now();
        let cache = StreamsCache::new(
            || Ok(mock_stream()),
            Duration::from_secs(10),
            Duration::from_secs(1),
        );
        cache.entries.lock().unwrap().insert("bli", mock_entry(start));

        // Past the cleanup interval, but "bli" has only been idle for 5s of its 10s timeout.
        let now = start + Duration::from_secs(5);
        cache.get("bla", now)?;

        let entries = cache.entries.lock().unwrap();
        assert_eq!(entries.len(), 2);
        assert!(entries.contains_key("bli"));
        assert!(entries.contains_key("bla"));
        assert_eq!(*cache.last_cleanup.lock().unwrap(), now);
        Ok(())
    }
}