Skip to main content

whois_service/
cache.rs

1use crate::{config::Config, WhoisResponse, errors::WhoisError};
2use moka::future::Cache;
3use std::{sync::Arc, time::Duration, future::Future};
4use tracing::debug;
5
/// In-memory cache of WHOIS lookup results, keyed by normalized domain name.
///
/// Keys are produced by `CacheService::normalize_domain`, so a domain with
/// and without a trailing dot maps to the same entry.
pub struct CacheService {
    /// Backing moka async cache; each stored response is held as an
    /// `Arc<WhoisResponse>`.
    cache: Cache<String, Arc<WhoisResponse>>,
}
9
10impl CacheService {
11    /// Create a new cache service with the given configuration.
12    /// 
13    /// Note: This cannot fail - moka cache creation is infallible.
14    #[must_use]
15    pub fn new(config: Arc<Config>) -> Self {
16        let cache = Cache::builder()
17            .max_capacity(config.cache_max_entries)
18            .time_to_live(Duration::from_secs(config.cache_ttl_seconds))
19            .build();
20
21        Self { cache }
22    }
23
24    /// Get a cached response for a domain.
25    ///
26    /// Returns `Some(response)` with `cached=true` if found, `None` otherwise.
27    pub async fn get(&self, domain: &str) -> Option<WhoisResponse> {
28        let key = Self::normalize_domain(domain);
29
30        match self.cache.get(&key).await {
31            Some(cached_response) => {
32                debug!("Cache hit for domain: {}", domain);
33                // Create a new response with cached=true
34                // This avoids mutating the cached Arc
35                Some(WhoisResponse {
36                    cached: true,
37                    ..(*cached_response).clone()
38                })
39            },
40            None => {
41                debug!("Cache miss for domain: {}", domain);
42                None
43            }
44        }
45    }
46
47    /// Store a response in the cache.
48    pub async fn set(&self, domain: &str, response: &WhoisResponse) {
49        let key = Self::normalize_domain(domain);
50        self.cache.insert(key, Arc::new(response.clone())).await;
51        debug!("Cached response for domain: {}", domain);
52    }
53
54    /// Get cached response or fetch with automatic deduplication.
55    ///
56    /// If multiple concurrent requests for the same domain arrive, only ONE
57    /// fetch operation will be executed. All other requests will wait for and
58    /// share the same result. This prevents thundering herd problems and
59    /// reduces load on WHOIS/RDAP servers.
60    ///
61    /// # Arguments
62    /// * `domain` - The domain to lookup
63    /// * `fetch_fn` - Async function that performs the actual lookup
64    ///
65    /// # Returns
66    /// WhoisResponse with `cached` field set appropriately
67    pub async fn get_or_fetch<F, Fut>(
68        &self,
69        domain: &str,
70        fetch_fn: F,
71    ) -> Result<WhoisResponse, WhoisError>
72    where
73        F: FnOnce() -> Fut,
74        Fut: Future<Output = Result<WhoisResponse, WhoisError>>,
75    {
76        let key = Self::normalize_domain(domain);
77
78        // Check if already cached
79        if let Some(cached) = self.cache.get(&key).await {
80            debug!("Cache hit for domain: {}", domain);
81            return Ok(WhoisResponse {
82                cached: true,
83                ..(*cached).clone()
84            });
85        }
86
87        // Not cached - perform fetch
88        debug!("Cache miss - executing fetch for domain: {}", domain);
89        let mut response = fetch_fn().await?;
90        response.cached = false;
91
92        // Store in cache
93        self.cache.insert(key, Arc::new(response.clone())).await;
94
95        Ok(response)
96    }
97
98    /// Normalize domain for consistent cache keys.
99    /// Domain is already lowercased by ValidatedDomain, just handle trailing dot.
100    fn normalize_domain(domain: &str) -> String {
101        // Remove trailing dot if present (common in DNS contexts)
102        // Domain is already trimmed and lowercased by ValidatedDomain
103        if let Some(stripped) = domain.strip_suffix('.') {
104            stripped.to_string()
105        } else {
106            domain.to_string()
107        }
108    }
109}
110
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ParsedWhoisData;

    /// Configuration with a small capacity and a long TTL so entries never
    /// expire during a test run.
    fn create_test_config() -> Arc<Config> {
        Arc::new(Config {
            port: 3000,
            whois_timeout_seconds: 30,
            max_response_size: 1024 * 1024,
            cache_ttl_seconds: 3600,
            cache_max_entries: 100,
            max_referrals: 5,
            discovery_timeout_seconds: 10,
            concurrent_whois_queries: 4,
            buffer_pool_size: 10,
            buffer_size: 4096,
        })
    }

    /// Fresh, empty cache built from the test configuration.
    fn create_test_cache() -> CacheService {
        CacheService::new(create_test_config())
    }

    /// Fully populated, non-cached response for `domain`.
    fn create_test_response(domain: &str) -> WhoisResponse {
        WhoisResponse {
            domain: domain.to_string(),
            whois_server: "whois.test.com".to_string(),
            raw_data: "test data".to_string(),
            parsed_data: Some(ParsedWhoisData {
                registrar: Some("Test Registrar".to_string()),
                creation_date: Some("2020-01-01".to_string()),
                expiration_date: Some("2030-01-01".to_string()),
                updated_date: Some("2024-01-01".to_string()),
                name_servers: vec!["ns1.test.com".to_string()],
                status: vec!["ok".to_string()],
                registrant_name: None,
                registrant_email: None,
                admin_email: None,
                tech_email: None,
                created_ago: Some(1000),
                updated_ago: Some(100),
                expires_in: Some(2000),
            }),
            cached: false,
            query_time_ms: 100,
            parsing_analysis: None,
        }
    }

    #[tokio::test]
    async fn test_cache_creation() {
        // Construction should not panic or error.
        let cache = create_test_cache();
        drop(cache);
    }

    #[tokio::test]
    async fn test_cache_miss() {
        let cache = create_test_cache();

        assert!(cache.get("example.com").await.is_none());
    }

    #[tokio::test]
    async fn test_cache_hit() {
        let cache = create_test_cache();

        let response = create_test_response("example.com");
        cache.set("example.com", &response).await;

        let cached = cache
            .get("example.com")
            .await
            .expect("entry was just inserted");

        assert_eq!(cached.domain, "example.com");
        assert!(cached.cached);
        assert_eq!(cached.whois_server, "whois.test.com");
    }

    #[tokio::test]
    async fn test_cache_normalization() {
        let cache = create_test_cache();

        // Store with trailing dot
        let response = create_test_response("example.com.");
        cache.set("example.com.", &response).await;

        // Retrieve without trailing dot (should hit)
        assert!(cache.get("example.com").await.is_some());

        // Retrieve with trailing dot (should also hit)
        assert!(cache.get("example.com.").await.is_some());
    }

    #[tokio::test]
    async fn test_cache_get_or_fetch_miss() {
        let cache = create_test_cache();

        let mut fetch_count = 0;

        let result = cache
            .get_or_fetch("example.com", || async {
                fetch_count += 1;
                Ok(create_test_response("example.com"))
            })
            .await;

        assert_eq!(fetch_count, 1);

        let response = result.expect("fetch should succeed");
        assert!(!response.cached);
        assert_eq!(response.domain, "example.com");

        // The fetched value must have been stored for later lookups.
        let stored = cache
            .get("example.com")
            .await
            .expect("fetched response should be cached");
        assert!(stored.cached);
        assert_eq!(stored.domain, "example.com");
    }

    #[tokio::test]
    async fn test_cache_get_or_fetch_hit() {
        let cache = create_test_cache();

        // Pre-populate cache
        let response = create_test_response("example.com");
        cache.set("example.com", &response).await;

        let mut fetch_count = 0;

        let result = cache
            .get_or_fetch("example.com", || async {
                fetch_count += 1;
                Ok(create_test_response("example.com"))
            })
            .await;

        assert_eq!(fetch_count, 0); // Should NOT fetch

        let response = result.expect("cache hit should succeed");
        assert!(response.cached);
    }

    #[tokio::test]
    async fn test_cache_get_or_fetch_normalization() {
        let cache = create_test_cache();

        // Stored under the trailing-dot form...
        cache
            .set("example.com.", &create_test_response("example.com"))
            .await;

        let mut fetch_count = 0;

        // ...and served for the dot-less form without fetching.
        let result = cache
            .get_or_fetch("example.com", || async {
                fetch_count += 1;
                Ok(create_test_response("example.com"))
            })
            .await;

        assert_eq!(fetch_count, 0);
        assert!(result.expect("cache hit should succeed").cached);
    }

    #[tokio::test]
    async fn test_cache_get_or_fetch_error() {
        let cache = create_test_cache();

        let result = cache
            .get_or_fetch("example.com", || async {
                Err(crate::errors::WhoisError::Internal("Test error".to_string()))
            })
            .await;

        assert!(result.is_err());

        // A failed fetch must not leave anything behind in the cache.
        assert!(cache.get("example.com").await.is_none());
    }

    #[tokio::test]
    async fn test_cache_multiple_domains() {
        let cache = create_test_cache();

        // Cache multiple domains
        cache.set("example.com", &create_test_response("example.com")).await;
        cache.set("test.com", &create_test_response("test.com")).await;
        cache.set("demo.org", &create_test_response("demo.org")).await;

        // All should be retrievable
        assert!(cache.get("example.com").await.is_some());
        assert!(cache.get("test.com").await.is_some());
        assert!(cache.get("demo.org").await.is_some());

        // Non-existent should miss
        assert!(cache.get("notcached.com").await.is_none());
    }

    #[test]
    fn test_normalize_domain_trailing_dot() {
        assert_eq!(CacheService::normalize_domain("example.com."), "example.com");
        assert_eq!(CacheService::normalize_domain("example.com"), "example.com");
        assert_eq!(CacheService::normalize_domain("test.co.uk."), "test.co.uk");
        assert_eq!(CacheService::normalize_domain("test.co.uk"), "test.co.uk");
    }
}