sf_rate_limiter/policy/
sliding_window.rs1use crate::error::{PolicyError, ReserveError};
2use crate::policy::Policy;
3use crate::storage::{State, Storage};
4use crate::LocalTime;
5use crate::{ChronoTimestampMillis, Duration, RateLimit, Reservation};
6use chrono::TimeZone;
7use std::cmp::{max, min};
8use std::ops::Add;
9
10pub struct SlidingWindowPolicy<'a, Store: Storage<SlidingWindowState, SlidingWindowState>> {
11 limit: usize,
12 key: String,
13 interval: chrono::Duration,
14 storage: &'a mut Store,
15}
16
17impl<Store: Storage<SlidingWindowState, SlidingWindowState>> Policy
18 for SlidingWindowPolicy<'_, Store>
19{
20 fn reserve(
21 &mut self,
22 tokens: usize,
23 max_time: Option<i64>,
24 ) -> Result<Reservation, ReserveError> {
25 if tokens > self.limit {
26 return Err(ReserveError::TooManyTokensError {
28 requested: tokens,
29 max: self.limit,
30 });
31 }
32
33 let mut state = self
34 .storage
35 .fetch(self.key.as_str())
36 .unwrap_or_else(|| SlidingWindowState::new(self.key.clone(), &self.interval));
37
38 if state.is_expired() {
39 state = SlidingWindowState::create_from_previous_window(&state, &self.interval);
40 }
41
42 let now = LocalTime::now();
43 let hit_count = state.get_hit_count();
44 let available_tokens = self.get_available_tokens(hit_count);
45
46 let reservation = if tokens == 0 {
47 let available_tokens = available_tokens.unwrap_or(0);
48 let reset_duration = state.calculate_time_for_tokens(self.limit, state.get_hit_count());
49 let reset_time = if available_tokens > 0 {
50 LocalTime::now()
51 } else {
52 LocalTime::timestamp_millis_opt(&LocalTime, now.timestamp_millis() + reset_duration)
53 .unwrap()
54 };
55
56 Reservation {
57 time_to_act: now.clone(),
58 rate_limit: RateLimit {
59 available_tokens,
60 retry_after: reset_time,
61 accepted: true,
62 limit: self.limit,
63 },
64 }
65 } else if available_tokens.is_some() && available_tokens.unwrap() >= tokens {
66 state.add(Some(tokens));
67 Reservation {
68 time_to_act: now.clone(),
69 rate_limit: RateLimit {
70 available_tokens: self
71 .get_available_tokens(state.get_hit_count())
72 .unwrap_or(0),
73 retry_after: now.clone(),
74 accepted: true,
75 limit: self.limit,
76 },
77 }
78 } else {
79 let wait_duration = state.calculate_time_for_tokens(self.limit, tokens);
80
81 if let Some(max_time) = max_time {
82 if wait_duration > max_time {
83 return Err(ReserveError::MaxWaitDurationExceededError);
84 }
85 }
86
87 state.add(Some(tokens));
88
89 let retry_after =
90 LocalTime::timestamp_millis_opt(&LocalTime, wait_duration + now.timestamp_millis())
91 .unwrap();
92
93 Reservation {
94 time_to_act: retry_after.clone(),
95 rate_limit: RateLimit {
96 available_tokens: self
97 .get_available_tokens(state.get_hit_count())
98 .unwrap_or(0),
99 retry_after,
100 accepted: false,
101 limit: self.limit,
102 },
103 }
104 };
105
106 if tokens > 0 {
107 self.storage.save(&self.key, state);
108 }
109
110 Ok(reservation)
111 }
112
113 fn consume(&mut self, tokens: usize) -> Result<Reservation, ReserveError> {
114 self.reserve(tokens, None)
115 }
116}
117
118impl<'a, Store: Storage<SlidingWindowState, SlidingWindowState>> SlidingWindowPolicy<'a, Store> {
119 pub fn new(
120 limit: usize,
121 key: String,
122 interval: Duration,
123 storage: &'a mut Store,
124 ) -> Result<Self, PolicyError> {
125 if limit == 0 {
126 return Err(PolicyError::ZeroLimitError);
127 }
128
129 if key.is_empty() {
130 return Err(PolicyError::EmptyKeyError);
131 }
132
133 Ok(Self {
134 limit,
135 key,
136 interval,
137 storage,
138 })
139 }
140
141 fn get_available_tokens(&self, hit_count: usize) -> Option<usize> {
142 if hit_count > self.limit {
143 return None; }
145
146 Some(self.limit - hit_count)
147 }
148}
149
150#[derive(Debug, Clone)]
151pub struct SlidingWindowState {
152 pub key: String,
153 hit_count: usize,
154 hit_count_for_last_window: usize,
155 pub interval: ChronoTimestampMillis,
156 pub window_end_at: ChronoTimestampMillis,
157}
158
159impl State<SlidingWindowState> for SlidingWindowState {
160 fn get_id(&self) -> String {
161 self.key.clone()
162 }
163
164 fn get_expiration_time(&self) -> usize {
165 self.interval as usize
166 }
167}
168
169impl SlidingWindowState {
170 pub fn new(key: String, interval: &chrono::Duration) -> Self {
171 Self {
172 key,
173 hit_count: 0,
174 hit_count_for_last_window: 0,
175 interval: interval.num_milliseconds(),
176 window_end_at: LocalTime::now().timestamp_millis() + interval.num_milliseconds(),
177 }
178 }
179
180 pub fn create_from_previous_window(window: &Self, interval: &chrono::Duration) -> Self {
181 let mut new = Self::new(window.key.clone(), interval);
182 let window_end_at = window.window_end_at + interval.num_milliseconds();
183
184 if LocalTime::now().timestamp_millis() < window_end_at {
185 new.hit_count_for_last_window = window.hit_count;
186 new.window_end_at = window_end_at;
187 }
188
189 new
190 }
191
192 pub fn get_expiration_time(&self) -> ChronoTimestampMillis {
193 self.window_end_at + self.interval - LocalTime::now().timestamp_millis()
195 }
196
197 pub fn is_expired(&self) -> bool {
198 LocalTime::now().timestamp_millis() > self.window_end_at
199 }
200
201 pub fn add(&mut self, hits: Option<usize>) {
202 let hits = hits.unwrap_or(1); self.hit_count += hits;
204 }
205
206 pub fn get_hit_count(&self) -> usize {
208 let start_of_window = self.window_end_at - self.interval;
209 let percent_of_current_time_frame =
210 min(LocalTime::now().timestamp_millis() - start_of_window, 1) as usize;
211
212 self.hit_count_for_last_window * (1 - percent_of_current_time_frame) + self.hit_count
214 }
215
216 pub fn calculate_time_for_tokens(&self, max_size: usize, tokens: usize) -> i64 {
217 let remaining = max_size - self.get_hit_count();
218
219 if remaining >= tokens {
220 return 0;
221 }
222
223 let time = LocalTime::now().timestamp_millis();
224 let start_of_window = self.window_end_at - self.interval;
225 let time_passed = time - start_of_window;
226
227 let window_passed: f64 = {
229 let value = time_passed as f64 / self.interval as f64;
232 if value > 1. {
233 1.
234 } else {
235 value
236 }
237 };
238
239 let releasable = max(
241 1,
242 max_size
243 - ((self.hit_count_for_last_window as f64 * (1. - window_passed)).floor() as usize),
244 );
245
246 let remaining_window = (self.interval - time_passed) as usize;
247 let needed = tokens - remaining;
248
249 if releasable >= needed {
250 return (needed as f64 * (remaining_window as f64 / max(1, releasable) as f64)) as i64;
251 }
252
253 (self.window_end_at - time)
256 + (needed as i64 - releasable as i64) * (self.interval / max_size as i64)
257 }
258}