// skp_ratelimit/algorithm/leaky_bucket.rs
use std::time::Duration;

use crate::algorithm::{current_timestamp_ms, timestamp_to_instant, Algorithm};
use crate::decision::{Decision, DecisionMetadata, RateLimitInfo};
use crate::error::Result;
use crate::quota::Quota;
use crate::storage::{Storage, StorageEntry};

/// Leaky-bucket rate-limiting algorithm.
///
/// Every admitted request adds one unit to a per-key "level". That level
/// drains continuously at the quota's refill rate. A request is admitted only
/// while the level plus one stays within the quota's burst capacity.
///
/// The type itself carries no state; per-key levels are kept in the
/// [`Storage`] backend passed to each call.
#[derive(Clone, Debug, Default)]
pub struct LeakyBucket;

21impl LeakyBucket {
22 pub fn new() -> Self {
24 Self
25 }
26
27 fn calculate_leak(&self, elapsed_ms: u64, leak_rate: f64) -> f64 {
29 let elapsed_secs = elapsed_ms as f64 / 1000.0;
30 elapsed_secs * leak_rate
31 }
32}
33
34impl Algorithm for LeakyBucket {
35 fn name(&self) -> &'static str {
36 "leaky_bucket"
37 }
38
39 async fn check_and_record<S: Storage>(
40 &self,
41 storage: &S,
42 key: &str,
43 quota: &Quota,
44 ) -> Result<Decision> {
45 let now = current_timestamp_ms();
46 let max_level = quota.effective_burst() as f64;
47 let leak_rate = quota.effective_refill_rate(); let ttl_ms = ((max_level / leak_rate) * 1000.0 * 2.0) as u64;
50 let ttl = Duration::from_millis(ttl_ms.max(1000));
51
52 let decision = storage
53 .execute_atomic(key, ttl, |entry| {
54 let (mut level, last_update) = match entry {
55 Some(e) => (e.tokens.unwrap_or(0.0), e.last_update),
56 None => (0.0, now),
57 };
58
59 if now > last_update {
61 let elapsed = now - last_update;
62 let leaked = self.calculate_leak(elapsed, leak_rate);
63 level = (level - leaked).max(0.0);
64 }
65
66 if level + 1.0 <= max_level {
68 level += 1.0;
69 let new_entry = StorageEntry::with_tokens(level, now);
70
71 let remaining = (max_level - level).floor() as u64;
72 let drain_time = (level / leak_rate * 1000.0) as u64;
73 let reset_at = timestamp_to_instant(now + drain_time);
74
75 let info = RateLimitInfo::new(max_level as u64, remaining, reset_at, timestamp_to_instant(now))
76 .with_algorithm("leaky_bucket")
77 .with_metadata(DecisionMetadata::new().with_tokens_available(max_level - level));
78
79 (new_entry, Decision::allowed(info))
80 } else {
81 let new_entry = StorageEntry::with_tokens(level, now);
82
83 let wait_ms = ((level + 1.0 - max_level) / leak_rate * 1000.0) as u64;
85 let reset_at = timestamp_to_instant(now + wait_ms);
86
87 let info = RateLimitInfo::new(max_level as u64, 0, reset_at, timestamp_to_instant(now))
88 .with_algorithm("leaky_bucket")
89 .with_retry_after(Duration::from_millis(wait_ms));
90
91 (new_entry, Decision::denied(info))
92 }
93 })
94 .await?;
95
96 Ok(decision)
97 }
98
99 async fn check<S: Storage>(
100 &self,
101 storage: &S,
102 key: &str,
103 quota: &Quota,
104 ) -> Result<Decision> {
105 let now = current_timestamp_ms();
106 let max_level = quota.effective_burst() as f64;
107 let leak_rate = quota.effective_refill_rate();
108
109 let entry = storage.get(key).await?;
110
111 let (mut level, last_update) = match entry {
112 Some(e) => (e.tokens.unwrap_or(0.0), e.last_update),
113 None => (0.0, now),
114 };
115
116 if now > last_update {
117 let elapsed = now - last_update;
118 let leaked = self.calculate_leak(elapsed, leak_rate);
119 level = (level - leaked).max(0.0);
120 }
121
122 let remaining = (max_level - level).floor() as u64;
123 let drain_time = (level / leak_rate * 1000.0) as u64;
124 let reset_at = timestamp_to_instant(now + drain_time);
125
126 let info = RateLimitInfo::new(max_level as u64, remaining, reset_at, timestamp_to_instant(now))
127 .with_algorithm("leaky_bucket");
128
129 Ok(if level + 1.0 <= max_level {
130 Decision::allowed(info)
131 } else {
132 let wait_ms = ((level + 1.0 - max_level) / leak_rate * 1000.0) as u64;
133 Decision::denied(info.with_retry_after(Duration::from_millis(wait_ms)))
134 })
135 }
136}
137
#[cfg(test)]
mod tests {
    use super::*;
    use crate::storage::MemoryStorage;

    /// A burst of 5 admits exactly 5 back-to-back requests, then denies.
    #[tokio::test]
    async fn test_leaky_bucket_basic() {
        let algorithm = LeakyBucket::new();
        let storage = MemoryStorage::new();
        let quota = Quota::per_second(10).with_burst(5);

        for i in 1..=5 {
            let decision = algorithm
                .check_and_record(&storage, "user:1", &quota)
                .await
                .unwrap();
            assert!(decision.is_allowed(), "Request {} should be allowed", i);
        }

        let decision = algorithm
            .check_and_record(&storage, "user:1", &quota)
            .await
            .unwrap();
        assert!(decision.is_denied());
    }

    /// After the bucket fills, waiting long enough for it to leak re-admits requests.
    #[tokio::test]
    async fn test_leaky_bucket_drain() {
        let algorithm = LeakyBucket::new();
        let storage = MemoryStorage::new();
        let quota = Quota::per_second(10).with_burst(2);

        algorithm
            .check_and_record(&storage, "user:1", &quota)
            .await
            .unwrap();
        algorithm
            .check_and_record(&storage, "user:1", &quota)
            .await
            .unwrap();

        let decision = algorithm
            .check_and_record(&storage, "user:1", &quota)
            .await
            .unwrap();
        assert!(decision.is_denied());

        // At 10 units/second, 150 ms leaks 1.5 units: enough room for one more request.
        tokio::time::sleep(Duration::from_millis(150)).await;

        let decision = algorithm
            .check_and_record(&storage, "user:1", &quota)
            .await
            .unwrap();
        assert!(decision.is_allowed());
    }

    /// `check` is read-only: repeated checks never fill the bucket.
    #[tokio::test]
    async fn test_leaky_bucket_check_does_not_record() {
        let algorithm = LeakyBucket::new();
        let storage = MemoryStorage::new();
        let quota = Quota::per_second(10).with_burst(1);

        for _ in 0..3 {
            let decision = algorithm.check(&storage, "user:1", &quota).await.unwrap();
            assert!(decision.is_allowed());
        }

        let decision = algorithm
            .check_and_record(&storage, "user:1", &quota)
            .await
            .unwrap();
        assert!(decision.is_allowed());

        let decision = algorithm.check(&storage, "user:1", &quota).await.unwrap();
        assert!(decision.is_denied());
    }
}