// sf_rate_limiter/policy/fixed_window.rs

1use crate::error::{PolicyError, ReserveError};
2use crate::policy::Policy;
3use crate::storage::{State, Storage};
4use crate::{Duration, LocalDateTime, LocalTime, RateLimit, Reservation};
5use chrono::TimeZone;
6
7pub struct FixedWindowPolicy<'a, Store: Storage<FixedWindowState, FixedWindowState>> {
8    limit: usize,
9    key: String,
10    interval: chrono::Duration,
11    storage: &'a mut Store,
12}
13
14impl<Store: Storage<FixedWindowState, FixedWindowState>> Policy for FixedWindowPolicy<'_, Store> {
15    fn reserve(
16        &mut self,
17        tokens: usize,
18        max_time: Option<i64>,
19    ) -> Result<Reservation, ReserveError> {
20        if tokens > self.limit {
21            // Cannot reserve more tokens than the size of the rate limiter.
22            return Err(ReserveError::TooManyTokensError {
23                requested: tokens,
24                max: self.limit,
25            });
26        }
27
28        let mut state = self
29            .storage
30            .fetch(self.key.as_str())
31            .unwrap_or_else(|| FixedWindowState::new(self.key.clone(), &self.interval, self.limit));
32
33        let now = LocalTime::now();
34        let available_tokens = state.get_available_tokens(&now);
35
36        let reservation: Reservation = if tokens == 0 {
37            let wait_duration = state.calculate_time_for_tokens(tokens, &now);
38            let retry_after =
39                LocalTime::timestamp_millis_opt(&LocalTime, now.timestamp_millis() + wait_duration)
40                    .unwrap();
41
42            Reservation {
43                time_to_act: retry_after.clone(),
44                rate_limit: RateLimit {
45                    available_tokens: available_tokens.unwrap_or(0),
46                    retry_after,
47                    accepted: true,
48                    limit: self.limit,
49                },
50            }
51        } else if available_tokens.is_some() && available_tokens.unwrap() >= tokens {
52            state.add(Some(tokens), Some(&now));
53            Reservation {
54                time_to_act: now.clone(),
55                rate_limit: RateLimit {
56                    available_tokens: state.get_available_tokens(&now).unwrap_or(0),
57                    retry_after: now.clone(),
58                    accepted: true,
59                    limit: self.limit,
60                },
61            }
62        } else {
63            let wait_duration = state.calculate_time_for_tokens(tokens, &now);
64
65            if let Some(max_time) = max_time {
66                if wait_duration > max_time {
67                    return Err(ReserveError::MaxWaitDurationExceededError);
68                }
69            }
70
71            state.add(Some(tokens), Some(&now));
72
73            let retry_after =
74                LocalTime::timestamp_millis_opt(&LocalTime, now.timestamp_millis() + wait_duration)
75                    .unwrap();
76
77            Reservation {
78                time_to_act: retry_after.clone(),
79                rate_limit: RateLimit {
80                    available_tokens: state.get_available_tokens(&now).unwrap_or(0),
81                    retry_after,
82                    accepted: false,
83                    limit: self.limit,
84                },
85            }
86        };
87
88        if tokens > 0 {
89            self.storage.save(&self.key, state);
90        }
91
92        Ok(reservation)
93    }
94
95    fn consume(&mut self, tokens: usize) -> Result<Reservation, ReserveError> {
96        self.reserve(tokens, None)
97    }
98}
99
100impl<'a, Store: Storage<FixedWindowState, FixedWindowState>> FixedWindowPolicy<'a, Store> {
101    pub fn new(
102        limit: usize,
103        key: String,
104        interval: Duration,
105        storage: &'a mut Store,
106    ) -> Result<Self, PolicyError> {
107        if limit == 0 {
108            return Err(PolicyError::ZeroLimitError);
109        }
110
111        if key.is_empty() {
112            return Err(PolicyError::EmptyKeyError);
113        }
114
115        Ok(Self {
116            limit,
117            key,
118            interval,
119            storage,
120        })
121    }
122}
123
/// Persisted state of one fixed window.
#[derive(Clone, Debug)]
pub struct FixedWindowState {
    /// Storage key this window belongs to.
    pub key: String,
    /// Tokens taken (or booked) in the current window; may exceed `max_size`.
    pub hit_count: usize,
    /// Window length in milliseconds (chrono timestamp millis).
    pub interval: i64,
    /// Maximum number of tokens per window.
    pub max_size: usize,
    /// Start of the current window, as a timestamp in milliseconds.
    pub timer: i64,
}
132
133impl State<FixedWindowState> for FixedWindowState {
134    fn get_id(&self) -> String {
135        self.key.clone()
136    }
137
138    fn get_expiration_time(&self) -> usize {
139        self.interval as usize
140    }
141}
142
143impl FixedWindowState {
144    pub fn new(key: String, interval: &chrono::Duration, max_size: usize) -> Self {
145        Self {
146            key,
147            hit_count: 0,
148            interval: interval.num_milliseconds(),
149            max_size,
150            timer: 0,
151        }
152    }
153
154    pub fn add(&mut self, hits: Option<usize>, now: Option<&LocalDateTime>) {
155        let hits = hits.unwrap_or(1); // TODO : maybe error if hits == 0 ?
156        let now = now
157            .map(|date| date.clone())
158            .unwrap_or_else(|| LocalTime::now())
159            .timestamp_millis();
160
161        if (now - self.timer) > self.interval {
162            // reset window
163            self.timer = now;
164            self.hit_count = 0;
165        }
166
167        self.hit_count += hits;
168    }
169
170    pub fn get_available_tokens(&self, now: &LocalDateTime) -> Option<usize> {
171        let now = now.timestamp_millis();
172
173        if (now - self.timer) > self.interval {
174            return Some(self.max_size);
175        }
176
177        if self.hit_count > self.max_size {
178            return None; // Avoid to subtract with overflow
179        }
180
181        Some(self.max_size - self.hit_count)
182    }
183
184    pub fn calculate_time_for_tokens(&self, tokens: usize, now: &LocalDateTime) -> i64 {
185        if (self.max_size - self.hit_count) >= tokens {
186            return 0;
187        }
188
189        self.timer + self.interval - now.timestamp_millis()
190    }
191}