1use crate::{CacheKey, CachePolicy, CacheStorage, PutHandle, StoredEntry};
30use futures_lite::{AsyncRead, AsyncWrite};
31use moka::{future::Cache, ops::compute::Op};
32use std::{
33 fmt::{self, Debug, Formatter},
34 io,
35 pin::Pin,
36 sync::Arc,
37 task::{Context, Poll},
38 time::Duration,
39};
40use trillium_http::{Body, BodySource, Headers};
41
/// Default weight cap used by [`InMemoryStorage::new`]: 256 MiB.
const DEFAULT_MAX_CAPACITY_BYTES: u64 = 256 << 20;
43
/// Every variant stored for one [`CacheKey`], held as an immutable shared slice.
///
/// Buckets are never mutated in place: writers (`finalize`, `refresh_policy`)
/// build a new slice and replace the whole bucket in the cache.
type Bucket = Arc<[Variant]>;
47
/// One stored response for a key: its cache policy, complete body, and
/// optional trailers.
///
/// `policy` and `body` are reference-counted, so cloning a variant (into an
/// `InMemoryEntry` or into a rebuilt bucket) does not copy the body bytes.
#[derive(Clone)]
struct Variant {
    policy: Arc<CachePolicy>,
    body: Arc<[u8]>,
    trailers: Option<Headers>,
}
54
55impl Debug for Variant {
56 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
57 f.debug_struct("Variant")
58 .field("body_len", &self.body.len())
59 .field("has_trailers", &self.trailers.is_some())
60 .finish_non_exhaustive()
61 }
62}
63
/// In-memory [`CacheStorage`] backed by a moka `future::Cache`.
///
/// Each [`CacheKey`] maps to a bucket holding every stored variant for that
/// key. Total size is bounded by a weight cap (see [`InMemoryStorage::new`] for
/// the default), optionally together with idle / live expiry.
///
/// Clones share the same underlying cache. The `with_*` and `unbounded`
/// builder methods, however, replace the cache with a freshly built one, so
/// configure the storage before cloning it or storing anything in it.
#[derive(Clone)]
pub struct InMemoryStorage {
    // The live cache; rebuilt whenever one of the settings below changes.
    cache: Cache<CacheKey, Bucket>,
    // Settings kept so the cache can be rebuilt and shown in `Debug`.
    max_capacity_bytes: Option<u64>,
    time_to_idle: Option<Duration>,
    time_to_live: Option<Duration>,
}
82
83impl Debug for InMemoryStorage {
84 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
85 f.debug_struct("InMemoryStorage")
86 .field("entry_count", &self.cache.entry_count())
87 .field("weighted_size", &self.cache.weighted_size())
88 .field("max_capacity_bytes", &self.max_capacity_bytes)
89 .field("time_to_idle", &self.time_to_idle)
90 .field("time_to_live", &self.time_to_live)
91 .finish_non_exhaustive()
92 }
93}
94
95impl Default for InMemoryStorage {
96 fn default() -> Self {
97 Self::new()
98 }
99}
100
101impl InMemoryStorage {
102 pub fn new() -> Self {
105 Self {
106 cache: build_cache(Some(DEFAULT_MAX_CAPACITY_BYTES), None, None),
107 max_capacity_bytes: Some(DEFAULT_MAX_CAPACITY_BYTES),
108 time_to_idle: None,
109 time_to_live: None,
110 }
111 }
112
113 pub fn with_max_capacity_bytes(mut self, bytes: u64) -> Self {
117 self.max_capacity_bytes = Some(bytes);
118 self.rebuild();
119 self
120 }
121
122 pub fn unbounded(mut self) -> Self {
126 self.max_capacity_bytes = None;
127 self.rebuild();
128 self
129 }
130
131 pub fn with_time_to_idle(mut self, duration: Duration) -> Self {
134 self.time_to_idle = Some(duration);
135 self.rebuild();
136 self
137 }
138
139 pub fn with_time_to_live(mut self, duration: Duration) -> Self {
147 self.time_to_live = Some(duration);
148 self.rebuild();
149 self
150 }
151
152 pub fn entry_count(&self) -> u64 {
157 self.cache.entry_count()
158 }
159
160 pub fn weighted_size(&self) -> u64 {
165 self.cache.weighted_size()
166 }
167
168 pub async fn run_pending_tasks(&self) {
173 self.cache.run_pending_tasks().await;
174 }
175
176 fn rebuild(&mut self) {
179 self.cache = build_cache(
180 self.max_capacity_bytes,
181 self.time_to_idle,
182 self.time_to_live,
183 );
184 }
185}
186
187fn build_cache(
188 max_capacity_bytes: Option<u64>,
189 time_to_idle: Option<Duration>,
190 time_to_live: Option<Duration>,
191) -> Cache<CacheKey, Bucket> {
192 let mut builder = Cache::<CacheKey, Bucket>::builder().weigher(weigh_bucket);
193 if let Some(cap) = max_capacity_bytes {
194 builder = builder.max_capacity(cap);
195 }
196 if let Some(tti) = time_to_idle {
197 builder = builder.time_to_idle(tti);
198 }
199 if let Some(ttl) = time_to_live {
200 builder = builder.time_to_live(ttl);
201 }
202 builder.build()
203}
204
205fn weigh_bucket(_key: &CacheKey, bucket: &Bucket) -> u32 {
206 let total: u64 = bucket.iter().map(|v| v.body.len() as u64).sum();
207 u32::try_from(total).unwrap_or(u32::MAX)
208}
209
/// A snapshot of one stored variant, as returned by `get` on
/// [`InMemoryStorage`].
///
/// Besides the variant itself it carries a handle to the shared cache and the
/// key it was read from, so that `refresh_policy` can write an updated policy
/// back into the stored bucket.
#[derive(Clone)]
pub struct InMemoryEntry {
    variant: Variant,
    cache: Cache<CacheKey, Bucket>,
    key: CacheKey,
}
218
219impl Debug for InMemoryEntry {
220 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
221 f.debug_struct("InMemoryEntry")
222 .field("key", &self.key)
223 .field("variant", &self.variant)
224 .finish_non_exhaustive()
225 }
226}
227
228impl StoredEntry for InMemoryEntry {
229 fn policy(&self) -> &CachePolicy {
230 &self.variant.policy
231 }
232
233 async fn refresh_policy(&mut self, new_policy: CachePolicy) -> io::Result<()> {
234 let new_arc = Arc::new(new_policy);
235 self.variant.policy = Arc::clone(&new_arc);
239
240 self.cache
241 .entry(self.key.clone())
242 .and_compute_with(|maybe_entry| async move {
243 let Some(entry) = maybe_entry else {
244 return Op::Nop;
245 };
246 let bucket = entry.into_value();
247 let mut updated = false;
248 let new_variants: Vec<Variant> = bucket
249 .iter()
250 .map(|v| {
251 if !updated && v.policy.same_variant_as(&new_arc) {
252 updated = true;
253 Variant {
254 policy: Arc::clone(&new_arc),
255 body: Arc::clone(&v.body),
256 trailers: v.trailers.clone(),
257 }
258 } else {
259 v.clone()
260 }
261 })
262 .collect();
263 if updated {
264 Op::Put(Arc::from(new_variants.into_boxed_slice()))
265 } else {
266 Op::Nop
267 }
268 })
269 .await;
270 Ok(())
271 }
272
273 async fn open(self) -> io::Result<Body> {
274 let Variant { body, trailers, .. } = self.variant;
275 let len = u64::try_from(body.len()).ok();
276 let source = ReplayBodySource {
277 body,
278 position: 0,
279 trailers,
280 };
281 Ok(Body::new_with_trailers(source, len))
282 }
283}
284
/// Body source that replays a fully buffered cached body from memory and hands
/// back the stored trailers once, via `BodySource::trailers`.
struct ReplayBodySource {
    // Complete cached body, shared (not copied) with the stored variant.
    body: Arc<[u8]>,
    // Number of bytes of `body` already copied out by `poll_read`;
    // never exceeds `body.len()`.
    position: usize,
    // Trailers to return; `take`n on the first `trailers` call.
    trailers: Option<Headers>,
}
291
292impl AsyncRead for ReplayBodySource {
293 fn poll_read(
294 mut self: Pin<&mut Self>,
295 _cx: &mut Context<'_>,
296 buf: &mut [u8],
297 ) -> Poll<io::Result<usize>> {
298 let remaining = self.body.len() - self.position;
299 let n = remaining.min(buf.len());
300 if n > 0 {
301 buf[..n].copy_from_slice(&self.body[self.position..self.position + n]);
302 self.position += n;
303 }
304 Poll::Ready(Ok(n))
305 }
306}
307
308impl BodySource for ReplayBodySource {
309 fn trailers(self: Pin<&mut Self>) -> Option<Headers> {
310 self.get_mut().trailers.take()
311 }
312}
313
314impl CacheStorage for InMemoryStorage {
315 type PutHandle = InMemoryPutHandle;
316 type StoredEntry = InMemoryEntry;
317
318 async fn get(&self, key: &CacheKey) -> Vec<Self::StoredEntry> {
319 let Some(bucket) = self.cache.get(key).await else {
320 return Vec::new();
321 };
322 bucket
323 .iter()
324 .map(|variant| InMemoryEntry {
325 variant: variant.clone(),
326 cache: self.cache.clone(),
327 key: key.clone(),
328 })
329 .collect()
330 }
331
332 async fn put(&self, key: CacheKey, policy: CachePolicy) -> io::Result<Self::PutHandle> {
333 Ok(InMemoryPutHandle {
334 cache: self.cache.clone(),
335 key,
336 policy,
337 buffer: Vec::new(),
338 })
339 }
340
341 async fn invalidate(&self, key: &CacheKey) {
342 self.cache.invalidate(key).await;
343 }
344}
345
346#[derive(Debug)]
352pub struct InMemoryPutHandle {
353 cache: Cache<CacheKey, Bucket>,
354 key: CacheKey,
355 policy: CachePolicy,
356 buffer: Vec<u8>,
357}
358
359impl AsyncWrite for InMemoryPutHandle {
360 fn poll_write(
361 mut self: Pin<&mut Self>,
362 _cx: &mut Context<'_>,
363 buf: &[u8],
364 ) -> Poll<io::Result<usize>> {
365 self.buffer.extend_from_slice(buf);
366 Poll::Ready(Ok(buf.len()))
367 }
368
369 fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
370 Poll::Ready(Ok(()))
371 }
372
373 fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
374 Poll::Ready(Ok(()))
375 }
376}
377
378impl PutHandle for InMemoryPutHandle {
379 async fn finalize(self, trailers: Option<Headers>) -> io::Result<()> {
380 let Self {
381 cache,
382 key,
383 policy,
384 buffer,
385 } = self;
386 let new_variant = Variant {
387 policy: Arc::new(policy),
388 body: Arc::from(buffer.into_boxed_slice()),
389 trailers,
390 };
391
392 cache
393 .entry(key)
394 .and_upsert_with(|maybe_entry| async move {
395 let mut variants: Vec<Variant> = match maybe_entry {
396 Some(entry) => entry.into_value().to_vec(),
397 None => Vec::new(),
398 };
399 variants.retain(|v| !v.policy.same_variant_as(&new_variant.policy));
400 variants.push(new_variant);
401 Arc::from(variants.into_boxed_slice())
402 })
403 .await;
404 Ok(())
405 }
406}
407
#[cfg(test)]
mod tests {
    // Unit tests for `InMemoryStorage`, exercised through the public
    // `CacheStorage`, `StoredEntry` and `PutHandle` trait methods.
    use super::*;
    use crate::test_helpers::*;
    use futures_lite::{AsyncReadExt, AsyncWriteExt};
    use std::time::SystemTime;
    use trillium_client::Conn;
    use trillium_http::{KnownHeaderName::*, Method, Status};
    use trillium_testing::{TestResult, harness, test};

    /// The cache key shared by the single-key tests.
    fn key() -> CacheKey {
        CacheKey::new(Method::Get, "http://example.com/".parse().unwrap())
    }

    /// Stores `body` under `key()` with a policy built from `conn`, going
    /// through the full put -> write -> finalize path (without trailers).
    async fn store(storage: &InMemoryStorage, conn: &Conn, body: &[u8]) {
        let policy = policy_from(conn, SystemTime::now(), private_cache());
        let mut handle = storage.put(key(), policy).await.unwrap();
        handle.write_all(body).await.unwrap();
        handle.finalize(None).await.unwrap();
    }

    /// Opens `entry` and collects its replayed body bytes.
    async fn read_body(entry: InMemoryEntry) -> Vec<u8> {
        let mut body = entry.open().await.unwrap();
        let mut buf = Vec::new();
        body.read_to_end(&mut buf).await.unwrap();
        buf
    }

    #[test(harness)]
    async fn get_missing_key_returns_empty() -> TestResult {
        let storage = InMemoryStorage::new();
        assert!(storage.get(&key()).await.is_empty());
        Ok(())
    }

    // A stored body is returned byte-for-byte by a later `get` + `open`.
    #[test(harness)]
    async fn put_then_get_returns_entry() -> TestResult {
        let storage = InMemoryStorage::new();
        let conn = exchange(
            Method::Get,
            &[],
            Status::Ok,
            &[(CacheControl, "max-age=600")],
        );
        store(&storage, &conn, b"hello").await;
        let result = storage.get(&key()).await;
        assert_eq!(result.len(), 1);
        assert_eq!(read_body(result[0].clone()).await, b"hello");
        Ok(())
    }

    // Storing twice from the same exchange (same Vary-selected request
    // headers) leaves a single variant holding the newer body.
    #[test(harness)]
    async fn put_with_same_vary_replaces() -> TestResult {
        let storage = InMemoryStorage::new();
        let conn = exchange(
            Method::Get,
            &[(AcceptEncoding, "gzip")],
            Status::Ok,
            &[(CacheControl, "max-age=600"), (Vary, "Accept-Encoding")],
        );
        store(&storage, &conn, b"v1").await;
        store(&storage, &conn, b"v2").await;
        let result = storage.get(&key()).await;
        assert_eq!(result.len(), 1);
        assert_eq!(read_body(result[0].clone()).await, b"v2");
        Ok(())
    }

    // Under `Vary: Accept-Encoding`, different Accept-Encoding values produce
    // two variants side by side in the same bucket.
    #[test(harness)]
    async fn put_with_different_vary_appends() -> TestResult {
        let storage = InMemoryStorage::new();
        let gzip = exchange(
            Method::Get,
            &[(AcceptEncoding, "gzip")],
            Status::Ok,
            &[(CacheControl, "max-age=600"), (Vary, "Accept-Encoding")],
        );
        let br = exchange(
            Method::Get,
            &[(AcceptEncoding, "br")],
            Status::Ok,
            &[(CacheControl, "max-age=600"), (Vary, "Accept-Encoding")],
        );
        store(&storage, &gzip, b"gz").await;
        store(&storage, &br, b"br").await;
        let result = storage.get(&key()).await;
        assert_eq!(result.len(), 2);
        Ok(())
    }

    #[test(harness)]
    async fn invalidate_removes_all_entries_for_key() -> TestResult {
        let storage = InMemoryStorage::new();
        let conn = exchange(
            Method::Get,
            &[],
            Status::Ok,
            &[(CacheControl, "max-age=600")],
        );
        store(&storage, &conn, b"x").await;
        // Run pending maintenance so `entry_count` reflects the write.
        storage.run_pending_tasks().await;
        assert_eq!(storage.entry_count(), 1);
        storage.invalidate(&key()).await;
        // `get` observes the invalidation right away...
        assert!(storage.get(&key()).await.is_empty());
        // ...and the count is checked after running maintenance again.
        storage.run_pending_tasks().await;
        assert_eq!(storage.entry_count(), 0);
        Ok(())
    }

    // Invalidating one key must leave a different key's bucket intact.
    #[test(harness)]
    async fn invalidate_does_not_touch_other_keys() -> TestResult {
        let storage = InMemoryStorage::new();
        let conn = exchange(
            Method::Get,
            &[],
            Status::Ok,
            &[(CacheControl, "max-age=600")],
        );
        let key_a = CacheKey::new(Method::Get, "http://a.example/".parse().unwrap());
        let key_b = CacheKey::new(Method::Get, "http://b.example/".parse().unwrap());
        {
            let policy_a = policy_from(&conn, SystemTime::now(), private_cache());
            let mut h = storage.put(key_a.clone(), policy_a).await.unwrap();
            h.write_all(b"a").await.unwrap();
            h.finalize(None).await.unwrap();
        }
        {
            let policy_b = policy_from(&conn, SystemTime::now(), private_cache());
            let mut h = storage.put(key_b.clone(), policy_b).await.unwrap();
            h.write_all(b"b").await.unwrap();
            h.finalize(None).await.unwrap();
        }
        storage.invalidate(&key_a).await;
        assert!(storage.get(&key_a).await.is_empty());
        assert_eq!(storage.get(&key_b).await.len(), 1);
        Ok(())
    }

    // Only `finalize` publishes a variant; dropping the handle after a partial
    // write must leave nothing in the cache.
    #[test(harness)]
    async fn drop_put_handle_without_finalize_discards() -> TestResult {
        let storage = InMemoryStorage::new();
        let conn = exchange(
            Method::Get,
            &[],
            Status::Ok,
            &[(CacheControl, "max-age=600")],
        );
        let policy = policy_from(&conn, SystemTime::now(), private_cache());
        let mut handle = storage.put(key(), policy).await.unwrap();
        handle.write_all(b"partial").await.unwrap();
        drop(handle);
        assert!(storage.get(&key()).await.is_empty());
        Ok(())
    }

    // `refresh_policy` on a snapshot must be written back to storage, so a
    // fresh `get` sees the new policy rather than the original one.
    #[test(harness)]
    async fn refresh_policy_updates_storage() -> TestResult {
        let storage = InMemoryStorage::new();
        let conn = exchange(
            Method::Get,
            &[],
            Status::Ok,
            &[(CacheControl, "max-age=600")],
        );
        store(&storage, &conn, b"body").await;

        let mut entries = storage.get(&key()).await;
        let original_time = entries[0].policy().response_time;
        let refreshed = exchange(
            Method::Get,
            &[],
            Status::Ok,
            &[(CacheControl, "max-age=1200")],
        );
        // Stamp the refreshed policy 100s later so its response_time is
        // distinguishable from the original.
        let new_policy = policy_from(
            &refreshed,
            original_time + Duration::from_secs(100),
            private_cache(),
        );
        entries[0].refresh_policy(new_policy).await.unwrap();

        let fresh = storage.get(&key()).await;
        assert_eq!(fresh.len(), 1);
        assert_ne!(fresh[0].policy().response_time, original_time);
        Ok(())
    }

    #[test(harness)]
    async fn size_cap_evicts_old_entries() -> TestResult {
        // Ten 600-byte bodies under a 1024 weight cap cannot all fit, so the
        // cap must force evictions.
        let storage = InMemoryStorage::new().with_max_capacity_bytes(1024);
        let conn = exchange(
            Method::Get,
            &[],
            Status::Ok,
            &[(CacheControl, "max-age=600")],
        );
        let body = vec![b'x'; 600];
        for i in 0..10 {
            let key = CacheKey::new(
                Method::Get,
                format!("http://example.com/{i}").parse().unwrap(),
            );
            let policy = policy_from(&conn, SystemTime::now(), private_cache());
            let mut h = storage.put(key, policy).await.unwrap();
            h.write_all(&body).await.unwrap();
            h.finalize(None).await.unwrap();
        }
        // Let pending evictions run before reading the weighted size.
        storage.run_pending_tasks().await;
        assert!(
            storage.weighted_size() <= 1024,
            "weighted size {} should be within cap of 1024",
            storage.weighted_size()
        );
        Ok(())
    }
}