webull_rs/utils/
cache.rs

1use std::collections::HashMap;
2use std::hash::Hash;
3use std::sync::{Arc, Mutex};
4use std::time::{Duration, Instant};
5
/// A single cached value together with the bookkeeping needed to expire it.
#[derive(Debug, Clone)]
struct CacheEntry<T> {
    /// The value handed back on a cache hit.
    value: T,

    /// Moment the entry was stored; expiry is measured from here.
    created_at: Instant,

    /// How long after `created_at` the entry remains valid.
    ttl: Duration,
}

impl<T> CacheEntry<T> {
    /// Wrap `value`, stamping it with the current time.
    fn new(value: T, ttl: Duration) -> Self {
        let created_at = Instant::now();
        CacheEntry {
            value,
            created_at,
            ttl,
        }
    }

    /// Returns `true` once strictly more than `ttl` has passed since creation.
    fn is_expired(&self) -> bool {
        let age = self.created_at.elapsed();
        self.ttl < age
    }
}
34
/// Identity of a cached request.
///
/// Two requests share a cache slot only when method, URL, query and body are
/// all equal.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
struct CacheKey {
    /// Method (GET, POST, etc.)
    method: String,

    /// URL
    url: String,

    /// Query parameters, if any
    query: Option<String>,

    /// Request body, if any
    body: Option<String>,
}

impl CacheKey {
    /// Build an owned key from borrowed request parts.
    fn new(method: &str, url: &str, query: Option<&str>, body: Option<&str>) -> Self {
        let owned = |part: Option<&str>| part.map(str::to_owned);
        CacheKey {
            method: method.to_owned(),
            url: url.to_owned(),
            query: owned(query),
            body: owned(body),
        }
    }
}
62
63/// Response cache.
64pub struct ResponseCache<T: Clone + Send + Sync> {
65    /// Cached responses
66    cache: Mutex<HashMap<CacheKey, CacheEntry<T>>>,
67
68    /// Default time-to-live for cache entries
69    default_ttl: Duration,
70
71    /// Maximum number of entries in the cache
72    max_entries: usize,
73}
74
75impl<T: Clone + Send + Sync> ResponseCache<T> {
76    /// Create a new response cache.
77    pub fn new(default_ttl: Duration, max_entries: usize) -> Self {
78        Self {
79            cache: Mutex::new(HashMap::new()),
80            default_ttl,
81            max_entries,
82        }
83    }
84
85    /// Get a cached response.
86    pub fn get(
87        &self,
88        method: &str,
89        url: &str,
90        query: Option<&str>,
91        body: Option<&str>,
92    ) -> Option<T> {
93        let key = CacheKey::new(method, url, query, body);
94        let mut cache = self.cache.lock().unwrap();
95
96        if let Some(entry) = cache.get(&key) {
97            if entry.is_expired() {
98                // Remove expired entry
99                cache.remove(&key);
100                None
101            } else {
102                // Return cached value
103                Some(entry.value.clone())
104            }
105        } else {
106            None
107        }
108    }
109
110    /// Store a response in the cache.
111    pub fn set(
112        &self,
113        method: &str,
114        url: &str,
115        query: Option<&str>,
116        body: Option<&str>,
117        value: T,
118        ttl: Option<Duration>,
119    ) {
120        let key = CacheKey::new(method, url, query, body);
121        let ttl = ttl.unwrap_or(self.default_ttl);
122        let entry = CacheEntry::new(value, ttl);
123
124        let mut cache = self.cache.lock().unwrap();
125
126        // Check if we need to evict entries
127        if cache.len() >= self.max_entries {
128            // Remove expired entries first
129            let expired_keys: Vec<_> = cache
130                .iter()
131                .filter(|(_, entry)| entry.is_expired())
132                .map(|(key, _)| key.clone())
133                .collect();
134
135            for key in expired_keys {
136                cache.remove(&key);
137            }
138
139            // If we still need to evict entries, remove the oldest ones
140            if cache.len() >= self.max_entries {
141                // Get all entries
142                let entries: Vec<_> = cache.iter().collect();
143
144                // Sort by creation time
145                let mut sorted_entries: Vec<_> = entries.iter().collect();
146                sorted_entries.sort_by_key(|(_, entry)| entry.created_at);
147
148                // Calculate how many to remove
149                let to_remove = entries.len() - self.max_entries + 1;
150
151                // Remove the oldest entries
152                let keys_to_remove: Vec<_> = sorted_entries
153                    .iter()
154                    .take(to_remove)
155                    .map(|(k, _)| (*k).clone())
156                    .collect();
157                for key in keys_to_remove {
158                    cache.remove(&key);
159                }
160            }
161        }
162
163        // Add the new entry
164        cache.insert(key, entry);
165    }
166
167    /// Clear the cache.
168    pub fn clear(&self) {
169        let mut cache = self.cache.lock().unwrap();
170        cache.clear();
171    }
172
173    /// Remove expired entries from the cache.
174    pub fn cleanup(&self) {
175        let mut cache = self.cache.lock().unwrap();
176        let expired_keys: Vec<_> = cache
177            .iter()
178            .filter(|(_, entry)| entry.is_expired())
179            .map(|(key, _)| key.clone())
180            .collect();
181
182        for key in expired_keys {
183            cache.remove(&key);
184        }
185    }
186}
187
188/// Cache manager for API responses.
189pub struct CacheManager {
190    /// Response caches for different types
191    caches: Mutex<HashMap<String, Arc<dyn Any + Send + Sync>>>,
192}
193
194impl CacheManager {
195    /// Create a new cache manager.
196    pub fn new() -> Self {
197        Self {
198            caches: Mutex::new(HashMap::new()),
199        }
200    }
201
202    /// Get a cache for a specific type.
203    pub fn get_cache<T: Clone + Send + Sync + 'static>(&self, name: &str) -> Arc<ResponseCache<T>> {
204        let mut caches = self.caches.lock().unwrap();
205
206        // Check if the cache exists
207        if let Some(cache) = caches.get(name) {
208            // Try to downcast to the correct type
209            if let Some(typed_cache) = cache.clone().downcast_arc::<ResponseCache<T>>().ok() {
210                return typed_cache;
211            }
212        }
213
214        // Create a new cache
215        let cache = Arc::new(ResponseCache::<T> {
216            cache: Mutex::new(HashMap::new()),
217            default_ttl: Duration::from_secs(60),
218            max_entries: 1000,
219        });
220
221        // Store the cache
222        caches.insert(
223            name.to_string(),
224            cache.clone() as Arc<dyn Any + Send + Sync>,
225        );
226
227        cache
228    }
229
230    /// Clear all caches.
231    pub fn clear_all(&self) {
232        let mut caches = self.caches.lock().unwrap();
233        caches.clear();
234    }
235}
236
237use std::any::{Any, TypeId};
238
/// Extension trait for `Arc<dyn Any + Send + Sync>`.
trait ArcAnyExt {
    /// Downcast to a concrete `Arc<T>`.
    ///
    /// On success the returned `Arc` shares the same allocation and reference
    /// count. If the erased value is not a `T`, the original `Arc` is handed
    /// back unchanged in `Err`. `T` must be `Send + Sync` because only such
    /// types can have been erased into `dyn Any + Send + Sync` in the first
    /// place.
    fn downcast_arc<T: Any + Send + Sync>(self) -> Result<Arc<T>, Self>
    where
        Self: Sized;
}

impl ArcAnyExt for Arc<dyn Any + Send + Sync> {
    fn downcast_arc<T: Any + Send + Sync>(self) -> Result<Arc<T>, Self> {
        // Delegate to the standard library, which performs the TypeId check
        // and the pointer cast soundly, so no `unsafe` is needed here.
        self.downcast::<T>()
    }
}