// tower_memlim/service.rs

use std::{
    future::Future,
    pin::Pin,
    task::{Context, Poll},
    time::Duration,
};

use tokio::time::{Instant, Sleep};
use tower_service::Service;

use crate::{
    error::{BoxError, MemCheckFailure},
    future::ResponseFuture,
    memory::{AvailableMemory, Threshold},
};

17/// Enforces a limit on the underlying service when a memory threshold is met.
18#[derive(Debug)]
19pub struct MemoryLimit<T, M>
20where
21    M: AvailableMemory,
22{
23    inner: T,
24    threshold: Threshold,
25    mem_checker: M,
26    retry_interval: Duration,
27    is_ready: bool,
28    sleep: Pin<Box<Sleep>>,
29}
30
31impl<T, M> MemoryLimit<T, M>
32where
33    M: AvailableMemory,
34{
35    /// Create a new memory limiter.
36    pub fn new(inner: T, threshold: Threshold, mem_checker: M, retry_interval: Duration) -> Self {
37        Self {
38            inner,
39            threshold,
40            mem_checker,
41            retry_interval,
42            is_ready: false,
43            sleep: Box::pin(tokio::time::sleep(retry_interval)),
44        }
45    }
46
47    /// Get a reference to the inner service
48    pub fn get_ref(&self) -> &T {
49        &self.inner
50    }
51
52    /// Get a mutable reference to the inner service
53    pub fn get_mut(&mut self) -> &mut T {
54        &mut self.inner
55    }
56
57    /// Consume `self`, returning the inner service
58    pub fn into_inner(self) -> T {
59        self.inner
60    }
61}
62
63impl<S, Request, M> Service<Request> for MemoryLimit<S, M>
64where
65    S: Service<Request>,
66    M: AvailableMemory,
67    S::Error: Into<BoxError>,
68{
69    type Response = S::Response;
70    type Error = BoxError;
71    type Future = ResponseFuture<S::Future>;
72
73    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
74        // Check current memory usage
75        match self.threshold {
76            Threshold::MinAvailableBytes(min_m) => match self.mem_checker.available_memory() {
77                Ok(v) => {
78                    if v < min_m as usize {
79                        // Reset sleep
80                        self.sleep
81                            .as_mut()
82                            .reset(Instant::now() + self.retry_interval);
83
84                        // Wake up after sleep
85                        match self.sleep.as_mut().poll(cx) {
86                            Poll::Ready(_r) => {
87                                // Unlikely as we just reset the Sleep but we handle it
88                                cx.waker().wake_by_ref();
89                            }
90                            Poll::Pending => (),
91                        }
92
93                        Poll::Pending
94                    } else {
95                        self.is_ready = true;
96                        self.inner.poll_ready(cx).map_err(Into::into)
97                    }
98                }
99                Err(e) => Poll::Ready(Err(MemCheckFailure::new(e).into())),
100            },
101        }
102    }
103
104    fn call(&mut self, request: Request) -> Self::Future {
105        if self.is_ready {
106            ResponseFuture::called(self.inner.call(request))
107        } else {
108            panic!("service not ready; poll_ready must be called first")
109        }
110    }
111}
112
113impl<S: Clone, M> Clone for MemoryLimit<S, M>
114where
115    M: AvailableMemory,
116{
117    fn clone(&self) -> Self {
118        Self {
119            inner: self.inner.clone(),
120            threshold: self.threshold.clone(),
121            mem_checker: self.mem_checker.clone(),
122            retry_interval: self.retry_interval,
123            is_ready: false,
124            sleep: Box::pin(tokio::time::sleep(self.retry_interval)),
125        }
126    }
127}