1use std::collections::hash_map::Entry;
2use std::collections::HashMap;
3use std::error::Error;
4use std::fmt::Debug;
5use std::hash::Hash;
6use std::sync::{Arc, Mutex};
7
8use tokio::sync::Mutex as AsyncMutex;
9use tokio::time::{Duration, Instant};
10
11use crate::io::Stream;
12
/// Marker trait for anything usable as a key of [`StreamsCache`].
///
/// It is blanket-implemented for every `Eq + Hash + Debug` type, so it never needs a
/// manual impl. `Debug` is required because lookups log the key.
pub trait CacheKey: Eq + Hash + Debug {}

impl<T> CacheKey for T where T: Eq + Hash + Debug {}
15
16pub trait StreamCreator: Fn() -> Result<Stream, Box<dyn Error>> + Send + Sync + 'static {}
17impl<T: Fn() -> Result<Stream, Box<dyn Error>> + Send + Sync + 'static> StreamCreator for T {}
18
19struct CacheEntry {
20 stream: Arc<AsyncMutex<Stream>>,
21 last_activity: Instant,
22}
23
/// Cleanup interval used by `StreamsCache::with_default_cleanup_duration`: one minute.
const DEFAULT_CLEANUP_INTERVAL: Duration = Duration::from_millis(60 * 1000);
25
26pub struct StreamsCache<F: StreamCreator, K: CacheKey> {
27 new_stream_creator: F,
28 entries: Mutex<HashMap<K, CacheEntry>>,
29 idle_entry_timeout: Duration,
30 cleanup_interval: Duration,
31 last_cleanup: Mutex<Instant>,
32}
33
34impl<F: StreamCreator, K: CacheKey> StreamsCache<F, K> {
35 pub fn new(
36 new_stream_creator: F,
37 idle_entry_timeout: Duration,
38 cleanup_interval: Duration,
39 ) -> Self {
40 Self {
41 new_stream_creator,
42 entries: Mutex::new(HashMap::new()),
43 idle_entry_timeout,
44 cleanup_interval,
45 last_cleanup: Mutex::new(Instant::now()),
46 }
47 }
48
49 pub fn with_default_cleanup_duration(
50 new_stream_creator: F,
51 idle_entry_timeout: Duration,
52 ) -> Self {
53 StreamsCache::new(
54 new_stream_creator,
55 idle_entry_timeout,
56 DEFAULT_CLEANUP_INTERVAL,
57 )
58 }
59
60 pub fn get(&self, key: K, now: Instant) -> Result<Arc<AsyncMutex<Stream>>, Box<dyn Error>> {
61 let res = self.get_or_create_stream(key, now);
62 let mut last_cleanup = self.last_cleanup.lock().unwrap();
63 if now - *last_cleanup > self.cleanup_interval {
64 self.cleanup_old_idle_streams(now);
65 *last_cleanup = now;
66 };
67 res
68 }
69
70 fn get_or_create_stream(
71 &self,
72 key: K,
73 now: Instant,
74 ) -> Result<Arc<AsyncMutex<Stream>>, Box<dyn Error>> {
75 let new_stream_creator = &self.new_stream_creator;
76 let mut entries = self.entries.lock().unwrap();
77 log::debug!("got key {:?}, entries size is {}", key, entries.len());
78 let entry = match entries.entry(key) {
79 Entry::Occupied(o) => o.into_mut(),
80 Entry::Vacant(v) => v.insert(CacheEntry {
81 stream: Arc::new(AsyncMutex::new(new_stream_creator()?)),
82 last_activity: now,
83 }),
84 };
85 entry.last_activity = now;
86 Ok(entry.stream.clone())
87 }
88
89 fn cleanup_old_idle_streams(&self, now: Instant) {
90 let mut entries = self.entries.lock().unwrap();
91 entries.retain(|_, v| now - v.last_activity < self.idle_entry_timeout);
92 }
93}
94
#[cfg(test)]
mod tests {
    use tokio::time::Duration;
    use tokio_test::io::Builder;

    use super::*;

    /// Returns a stream backed by empty mocks; these tests only inspect the cache, never the I/O.
    fn mock_stream() -> Stream {
        Stream::new(Builder::new().build(), Builder::new().build())
    }

    /// Returns a cache entry wrapping a mock stream, last touched at `last_activity`.
    fn entry_at(last_activity: Instant) -> CacheEntry {
        CacheEntry {
            stream: Arc::new(AsyncMutex::new(mock_stream())),
            last_activity,
        }
    }

    #[test]
    fn stream_cache_get_new_stream_creator_failed() -> Result<(), Box<dyn Error>> {
        let cache = StreamsCache::new(
            || Err(String::from("bla").into()),
            Duration::from_secs(3 * 60),
            Duration::from_secs(60),
        );
        let res = cache.get("bla", Instant::now());
        assert!(res.is_err());
        // A failed creation must not leave anything behind in the cache.
        assert!(cache.entries.lock().unwrap().is_empty());
        Ok(())
    }

    #[test]
    fn stream_cache_get_new_stream_success() -> Result<(), Box<dyn Error>> {
        let cache = StreamsCache::new(
            || Ok(mock_stream()),
            Duration::from_secs(3 * 60),
            Duration::from_secs(60),
        );
        let initial_cleanup = *cache.last_cleanup.lock().unwrap();

        let now = Instant::now();
        cache.get("bla", now)?;

        let entries = cache.entries.lock().unwrap();
        assert_eq!(entries.len(), 1);
        assert!(entries.contains_key("bla"));
        assert_eq!(entries.get("bla").unwrap().last_activity, now);
        // No cleanup ran. Compare with the stored stamp rather than `now`: two consecutive
        // `Instant::now()` calls can be equal on coarse clocks, which would fail `assert_ne!`.
        assert_eq!(*cache.last_cleanup.lock().unwrap(), initial_cleanup);
        Ok(())
    }

    #[test]
    fn stream_cache_get_new_stream_another_exists() -> Result<(), Box<dyn Error>> {
        let t1 = Instant::now();
        let cache = StreamsCache::new(
            || Ok(mock_stream()),
            Duration::from_secs(3 * 60),
            Duration::from_secs(60),
        );
        cache.entries.lock().unwrap().insert("bli", entry_at(t1));
        let initial_cleanup = *cache.last_cleanup.lock().unwrap();

        let t2 = t1 + Duration::from_secs(1);
        cache.get("bla", t2)?;

        let entries = cache.entries.lock().unwrap();
        assert_eq!(entries.len(), 2);
        assert!(entries.contains_key("bla"));
        assert_eq!(entries.get("bla").unwrap().last_activity, t2);
        assert!(entries.contains_key("bli"));
        assert_eq!(entries.get("bli").unwrap().last_activity, t1);
        assert_eq!(*cache.last_cleanup.lock().unwrap(), initial_cleanup);
        Ok(())
    }

    #[test]
    fn stream_cache_get_existing_stream_success() -> Result<(), Box<dyn Error>> {
        let created = Instant::now();
        // The creator always fails, so a successful `get` proves the cached entry was reused.
        let cache = StreamsCache::new(
            || Err(String::from("bla").into()),
            Duration::from_secs(3 * 60),
            Duration::from_secs(60),
        );
        cache.entries.lock().unwrap().insert("bla", entry_at(created));
        let initial_cleanup = *cache.last_cleanup.lock().unwrap();

        let accessed = created + Duration::from_secs(1);
        cache.get("bla", accessed)?;

        let entries = cache.entries.lock().unwrap();
        assert_eq!(entries.len(), 1);
        assert!(entries.contains_key("bla"));
        assert_eq!(entries.get("bla").unwrap().last_activity, accessed);
        assert_eq!(*cache.last_cleanup.lock().unwrap(), initial_cleanup);
        Ok(())
    }

    #[test]
    fn stream_cache_get_auto_cleanup() -> Result<(), Box<dyn Error>> {
        let start = Instant::now();
        let cache = StreamsCache::new(
            || Ok(mock_stream()),
            Duration::from_secs(1),
            Duration::from_secs(1),
        );
        cache.entries.lock().unwrap().insert("bli", entry_at(start));

        // Both the idle timeout and the cleanup interval (1s each) are exceeded.
        let now = start + Duration::from_secs(5);
        cache.get("bla", now)?;

        let entries = cache.entries.lock().unwrap();
        assert_eq!(entries.len(), 1);
        assert!(entries.contains_key("bla"));
        assert_eq!(entries.get("bla").unwrap().last_activity, now);
        assert_eq!(*cache.last_cleanup.lock().unwrap(), now);
        Ok(())
    }
}