// webull_rs/utils/cache.rs

use std::collections::HashMap;
use std::hash::Hash;
use std::sync::{Arc, Mutex, MutexGuard, PoisonError};
use std::time::{Duration, Instant};
5
/// A cached value stamped with the moment it was created and how long it lives.
#[derive(Debug, Clone)]
struct CacheEntry<T> {
    /// The stored value.
    value: T,

    /// Instant at which the entry was constructed.
    created_at: Instant,

    /// How long after `created_at` the entry remains valid.
    ttl: Duration,
}

impl<T> CacheEntry<T> {
    /// Wrap `value`, stamping it with the current instant and the given lifetime.
    fn new(value: T, ttl: Duration) -> Self {
        let created_at = Instant::now();
        CacheEntry {
            value,
            created_at,
            ttl,
        }
    }

    /// Returns `true` once strictly more than `ttl` has elapsed since creation.
    fn is_expired(&self) -> bool {
        let age = self.created_at.elapsed();
        self.ttl < age
    }
}
34
/// Identity of a cacheable request: method, URL, optional query and optional body.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
struct CacheKey {
    /// Method (GET, POST, etc.).
    method: String,

    /// Request URL.
    url: String,

    /// Query string, if the request has one.
    query: Option<String>,

    /// Request body, if the request has one.
    body: Option<String>,
}

impl CacheKey {
    /// Build an owned key from borrowed request parts.
    fn new(method: &str, url: &str, query: Option<&str>, body: Option<&str>) -> Self {
        let owned = |part: Option<&str>| part.map(String::from);
        CacheKey {
            method: String::from(method),
            url: String::from(url),
            query: owned(query),
            body: owned(body),
        }
    }
}
62
63/// Response cache.
64pub struct ResponseCache<T: Clone + Send + Sync> {
65    /// Cached responses
66    cache: Mutex<HashMap<CacheKey, CacheEntry<T>>>,
67
68    /// Default time-to-live for cache entries
69    default_ttl: Duration,
70
71    /// Maximum number of entries in the cache
72    max_entries: usize,
73}
74
75impl<T: Clone + Send + Sync> ResponseCache<T> {
76    /// Create a new response cache.
77    pub fn new(default_ttl: Duration, max_entries: usize) -> Self {
78        Self {
79            cache: Mutex::new(HashMap::new()),
80            default_ttl,
81            max_entries,
82        }
83    }
84
85    /// Get a cached response.
86    pub fn get(&self, method: &str, url: &str, query: Option<&str>, body: Option<&str>) -> Option<T> {
87        let key = CacheKey::new(method, url, query, body);
88        let mut cache = self.cache.lock().unwrap();
89
90        if let Some(entry) = cache.get(&key) {
91            if entry.is_expired() {
92                // Remove expired entry
93                cache.remove(&key);
94                None
95            } else {
96                // Return cached value
97                Some(entry.value.clone())
98            }
99        } else {
100            None
101        }
102    }
103
104    /// Store a response in the cache.
105    pub fn set(&self, method: &str, url: &str, query: Option<&str>, body: Option<&str>, value: T, ttl: Option<Duration>) {
106        let key = CacheKey::new(method, url, query, body);
107        let ttl = ttl.unwrap_or(self.default_ttl);
108        let entry = CacheEntry::new(value, ttl);
109
110        let mut cache = self.cache.lock().unwrap();
111
112        // Check if we need to evict entries
113        if cache.len() >= self.max_entries {
114            // Remove expired entries first
115            let expired_keys: Vec<_> = cache.iter()
116                .filter(|(_, entry)| entry.is_expired())
117                .map(|(key, _)| key.clone())
118                .collect();
119
120            for key in expired_keys {
121                cache.remove(&key);
122            }
123
124            // If we still need to evict entries, remove the oldest ones
125            if cache.len() >= self.max_entries {
126                // Get all entries
127                let entries: Vec<_> = cache.iter().collect();
128
129                // Sort by creation time
130                let mut sorted_entries: Vec<_> = entries.iter().collect();
131                sorted_entries.sort_by_key(|(_, entry)| entry.created_at);
132
133                // Calculate how many to remove
134                let to_remove = entries.len() - self.max_entries + 1;
135
136                // Remove the oldest entries
137                let keys_to_remove: Vec<_> = sorted_entries.iter().take(to_remove).map(|(k, _)| (*k).clone()).collect();
138                for key in keys_to_remove {
139                    cache.remove(&key);
140                }
141            }
142        }
143
144        // Add the new entry
145        cache.insert(key, entry);
146    }
147
148    /// Clear the cache.
149    pub fn clear(&self) {
150        let mut cache = self.cache.lock().unwrap();
151        cache.clear();
152    }
153
154    /// Remove expired entries from the cache.
155    pub fn cleanup(&self) {
156        let mut cache = self.cache.lock().unwrap();
157        let expired_keys: Vec<_> = cache.iter()
158            .filter(|(_, entry)| entry.is_expired())
159            .map(|(key, _)| key.clone())
160            .collect();
161
162        for key in expired_keys {
163            cache.remove(&key);
164        }
165    }
166}
167
168/// Cache manager for API responses.
169pub struct CacheManager {
170    /// Response caches for different types
171    caches: Mutex<HashMap<String, Arc<dyn Any + Send + Sync>>>,
172}
173
174impl CacheManager {
175    /// Create a new cache manager.
176    pub fn new() -> Self {
177        Self {
178            caches: Mutex::new(HashMap::new()),
179        }
180    }
181
182    /// Get a cache for a specific type.
183    pub fn get_cache<T: Clone + Send + Sync + 'static>(&self, name: &str) -> Arc<ResponseCache<T>> {
184        let mut caches = self.caches.lock().unwrap();
185
186        // Check if the cache exists
187        if let Some(cache) = caches.get(name) {
188            // Try to downcast to the correct type
189            if let Some(typed_cache) = cache.clone().downcast_arc::<ResponseCache<T>>().ok() {
190                return typed_cache;
191            }
192        }
193
194        // Create a new cache
195        let cache = Arc::new(ResponseCache::<T> {
196            cache: Mutex::new(HashMap::new()),
197            default_ttl: Duration::from_secs(60),
198            max_entries: 1000,
199        });
200
201        // Store the cache
202        caches.insert(name.to_string(), cache.clone() as Arc<dyn Any + Send + Sync>);
203
204        cache
205    }
206
207    /// Clear all caches.
208    pub fn clear_all(&self) {
209        let mut caches = self.caches.lock().unwrap();
210        caches.clear();
211    }
212}
213
use std::any::Any;

/// Extension trait for `Arc<dyn Any + Send + Sync>`.
trait ArcAnyExt {
    /// Downcast to a specific type.
    ///
    /// On success returns an `Arc<T>` sharing the original allocation and
    /// reference count; otherwise hands the original `Arc` back as `Err`.
    fn downcast_arc<T: Any + Send + Sync>(self) -> Result<Arc<T>, Self>
    where
        Self: Sized;
}

impl ArcAnyExt for Arc<dyn Any + Send + Sync> {
    fn downcast_arc<T: Any + Send + Sync>(self) -> Result<Arc<T>, Self> {
        // std's `Arc::downcast` performs the same type-id check and pointer
        // cast safely, so no hand-written `unsafe` is needed here.
        self.downcast::<T>()
    }
}