// sf_rate_limiter/policy/sliding_window.rs

1use crate::error::{PolicyError, ReserveError};
2use crate::policy::Policy;
3use crate::storage::{State, Storage};
4use crate::LocalTime;
5use crate::{ChronoTimestampMillis, Duration, RateLimit, Reservation};
6use chrono::TimeZone;
7use std::cmp::{max, min};
8use std::ops::Add;
9
10pub struct SlidingWindowPolicy<'a, Store: Storage<SlidingWindowState, SlidingWindowState>> {
11    limit: usize,
12    key: String,
13    interval: chrono::Duration,
14    storage: &'a mut Store,
15}
16
17impl<Store: Storage<SlidingWindowState, SlidingWindowState>> Policy
18    for SlidingWindowPolicy<'_, Store>
19{
20    fn reserve(
21        &mut self,
22        tokens: usize,
23        max_time: Option<i64>,
24    ) -> Result<Reservation, ReserveError> {
25        if tokens > self.limit {
26            // Cannot reserve more tokens than the size of the rate limiter.
27            return Err(ReserveError::TooManyTokensError {
28                requested: tokens,
29                max: self.limit,
30            });
31        }
32
33        let mut state = self
34            .storage
35            .fetch(self.key.as_str())
36            .unwrap_or_else(|| SlidingWindowState::new(self.key.clone(), &self.interval));
37
38        if state.is_expired() {
39            state = SlidingWindowState::create_from_previous_window(&state, &self.interval);
40        }
41
42        let now = LocalTime::now();
43        let hit_count = state.get_hit_count();
44        let available_tokens = self.get_available_tokens(hit_count);
45
46        let reservation = if tokens == 0 {
47            let available_tokens = available_tokens.unwrap_or(0);
48            let reset_duration = state.calculate_time_for_tokens(self.limit, state.get_hit_count());
49            let reset_time = if available_tokens > 0 {
50                LocalTime::now()
51            } else {
52                LocalTime::timestamp_millis_opt(&LocalTime, now.timestamp_millis() + reset_duration)
53                    .unwrap()
54            };
55
56            Reservation {
57                time_to_act: now.clone(),
58                rate_limit: RateLimit {
59                    available_tokens,
60                    retry_after: reset_time,
61                    accepted: true,
62                    limit: self.limit,
63                },
64            }
65        } else if available_tokens.is_some() && available_tokens.unwrap() >= tokens {
66            state.add(Some(tokens));
67            Reservation {
68                time_to_act: now.clone(),
69                rate_limit: RateLimit {
70                    available_tokens: self
71                        .get_available_tokens(state.get_hit_count())
72                        .unwrap_or(0),
73                    retry_after: now.clone(),
74                    accepted: true,
75                    limit: self.limit,
76                },
77            }
78        } else {
79            let wait_duration = state.calculate_time_for_tokens(self.limit, tokens);
80
81            if let Some(max_time) = max_time {
82                if wait_duration > max_time {
83                    return Err(ReserveError::MaxWaitDurationExceededError);
84                }
85            }
86
87            state.add(Some(tokens));
88
89            let retry_after =
90                LocalTime::timestamp_millis_opt(&LocalTime, wait_duration + now.timestamp_millis())
91                    .unwrap();
92
93            Reservation {
94                time_to_act: retry_after.clone(),
95                rate_limit: RateLimit {
96                    available_tokens: self
97                        .get_available_tokens(state.get_hit_count())
98                        .unwrap_or(0),
99                    retry_after,
100                    accepted: false,
101                    limit: self.limit,
102                },
103            }
104        };
105
106        if tokens > 0 {
107            self.storage.save(&self.key, state);
108        }
109
110        Ok(reservation)
111    }
112
113    fn consume(&mut self, tokens: usize) -> Result<Reservation, ReserveError> {
114        self.reserve(tokens, None)
115    }
116}
117
118impl<'a, Store: Storage<SlidingWindowState, SlidingWindowState>> SlidingWindowPolicy<'a, Store> {
119    pub fn new(
120        limit: usize,
121        key: String,
122        interval: Duration,
123        storage: &'a mut Store,
124    ) -> Result<Self, PolicyError> {
125        if limit == 0 {
126            return Err(PolicyError::ZeroLimitError);
127        }
128
129        if key.is_empty() {
130            return Err(PolicyError::EmptyKeyError);
131        }
132
133        Ok(Self {
134            limit,
135            key,
136            interval,
137            storage,
138        })
139    }
140
141    fn get_available_tokens(&self, hit_count: usize) -> Option<usize> {
142        if hit_count > self.limit {
143            return None; // Avoid to subtract with overflow
144        }
145
146        Some(self.limit - hit_count)
147    }
148}
149
150#[derive(Debug, Clone)]
151pub struct SlidingWindowState {
152    pub key: String,
153    hit_count: usize,
154    hit_count_for_last_window: usize,
155    pub interval: ChronoTimestampMillis,
156    pub window_end_at: ChronoTimestampMillis,
157}
158
159impl State<SlidingWindowState> for SlidingWindowState {
160    fn get_id(&self) -> String {
161        self.key.clone()
162    }
163
164    fn get_expiration_time(&self) -> usize {
165        self.interval as usize
166    }
167}
168
169impl SlidingWindowState {
170    pub fn new(key: String, interval: &chrono::Duration) -> Self {
171        Self {
172            key,
173            hit_count: 0,
174            hit_count_for_last_window: 0,
175            interval: interval.num_milliseconds(),
176            window_end_at: LocalTime::now().timestamp_millis() + interval.num_milliseconds(),
177        }
178    }
179
180    pub fn create_from_previous_window(window: &Self, interval: &chrono::Duration) -> Self {
181        let mut new = Self::new(window.key.clone(), interval);
182        let window_end_at = window.window_end_at + interval.num_milliseconds();
183
184        if LocalTime::now().timestamp_millis() < window_end_at {
185            new.hit_count_for_last_window = window.hit_count;
186            new.window_end_at = window_end_at;
187        }
188
189        new
190    }
191
192    pub fn get_expiration_time(&self) -> ChronoTimestampMillis {
193        // TODO : Maybe subtract with overflow?
194        self.window_end_at + self.interval - LocalTime::now().timestamp_millis()
195    }
196
197    pub fn is_expired(&self) -> bool {
198        LocalTime::now().timestamp_millis() > self.window_end_at
199    }
200
201    pub fn add(&mut self, hits: Option<usize>) {
202        let hits = hits.unwrap_or(1); // TODO : maybe error if hits == 0?
203        self.hit_count += hits;
204    }
205
206    /// Calculates the sliding window number of request.
207    pub fn get_hit_count(&self) -> usize {
208        let start_of_window = self.window_end_at - self.interval;
209        let percent_of_current_time_frame =
210            min(LocalTime::now().timestamp_millis() - start_of_window, 1) as usize;
211
212        // TODO : Maybe subtract with overflow?
213        self.hit_count_for_last_window * (1 - percent_of_current_time_frame) + self.hit_count
214    }
215
216    pub fn calculate_time_for_tokens(&self, max_size: usize, tokens: usize) -> i64 {
217        let remaining = max_size - self.get_hit_count();
218
219        if remaining >= tokens {
220            return 0;
221        }
222
223        let time = LocalTime::now().timestamp_millis();
224        let start_of_window = self.window_end_at - self.interval;
225        let time_passed = time - start_of_window;
226
227        // https://github.com/symfony/rate-limiter/blob/f1fbc60e7fed63f1c77bbf8601170cc80fddd95a/Policy/SlidingWindow.php#L97
228        let window_passed: f64 = {
229            // I would do it via std::cmp::min, but Ord<f64> is not implemented,
230            // so you can't do without that shit
231            let value = time_passed as f64 / self.interval as f64;
232            if value > 1. {
233                1.
234            } else {
235                value
236            }
237        };
238
239        // https://github.com/symfony/rate-limiter/blob/f1fbc60e7fed63f1c77bbf8601170cc80fddd95a/Policy/SlidingWindow.php#L98
240        let releasable = max(
241            1,
242            max_size
243                - ((self.hit_count_for_last_window as f64 * (1. - window_passed)).floor() as usize),
244        );
245
246        let remaining_window = (self.interval - time_passed) as usize;
247        let needed = tokens - remaining;
248
249        if releasable >= needed {
250            return (needed as f64 * (remaining_window as f64 / max(1, releasable) as f64)) as i64;
251        }
252
253        // TODO : Refactor
254
255        (self.window_end_at - time)
256            + (needed as i64 - releasable as i64) * (self.interval / max_size as i64)
257    }
258}