1#![cfg_attr(docsrs, feature(doc_cfg))]
2
3pub mod in_memory;
24
25#[cfg(feature = "distributed-impl")]
26#[cfg_attr(docsrs, doc(cfg(feature = "distributed-impl")))]
27pub mod distributed;
28
29#[cfg(feature = "redis-impl")]
30#[cfg_attr(docsrs, doc(cfg(feature = "redis-impl")))]
31pub mod in_redis;
32
33pub use in_memory::*;
34
35#[cfg(feature = "distributed-impl")]
36#[cfg_attr(docsrs, doc(cfg(feature = "distributed-impl")))]
37pub use distributed::*;
38
39#[cfg(feature = "redis-impl")]
40#[cfg_attr(docsrs, doc(cfg(feature = "redis-impl")))]
41pub use in_redis::*;
42
/// Backend that [`TokenBucket`] forwards every acquisition to.
///
/// NOTE(review): presumably each implementation owns a [`State`] and applies
/// [`TokenBucketAlgorithm::try_acquire`] to it — confirm per backend.
pub trait Storage {
    /// Error produced by [`Storage::try_acquire`]. It must be convertible from
    /// [`RateLimitExceededError`] so a rejected acquisition can be reported through it.
    type Error: From<RateLimitExceededError>;

    /// Attempts to take `permits` tokens, using `alg` to decide how a shortfall is handled.
    fn try_acquire(&self, alg: TokenBucketAlgorithm, permits: u32) -> Result<(), Self::Error>;
}
52
/// Mutable bookkeeping of a single token bucket, updated by [`TokenBucketAlgorithm`].
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct State {
    /// Maximum number of tokens the bucket can hold; refills never push past it.
    pub cap: u32,
    /// Tokens currently available to be acquired.
    pub available_tokens: u32,
    /// Instant up to which elapsed time has already been converted into tokens.
    pub last_refill: time::OffsetDateTime,
    /// Amount of elapsed time that produces one token.
    pub refill_tick: time::Duration,
}
61
/// Rate limiter that delegates token bookkeeping to a [`Storage`] backend.
///
/// `Debug` is derived so the limiter can be logged or embedded in other `Debug` types; it is
/// only available when the storage type itself is `Debug`.
#[derive(Debug)]
pub struct TokenBucket<S> {
    /// Backend that every acquisition is forwarded to.
    storage: S,
}
66
67impl<S> TokenBucket<S>
68where
69 S: Storage,
70{
71 pub fn new(storage: S) -> Self {
73 Self { storage }
74 }
75
76 pub fn try_acquire(&self, permits: u32) -> Result<(), S::Error> {
82 self.storage
83 .try_acquire(TokenBucketAlgorithm { mode: Mode::N }, permits)
84 }
85
86 pub fn try_acquire_one(&self) -> Result<(), S::Error> {
92 self.try_acquire(1)
93 }
94
95 pub fn try_acquire_n_or_all(&self, permits: u32) -> Result<(), S::Error> {
101 self.storage
102 .try_acquire(TokenBucketAlgorithm { mode: Mode::All }, permits)
103 }
104}
105
/// Refill-and-consume logic applied to a [`State`], parameterized by how a request larger
/// than the available tokens is handled.
#[derive(Debug)]
pub struct TokenBucketAlgorithm {
    // `Mode::N`: the request must be satisfied in full; `Mode::All`: take whatever is left.
    mode: Mode,
}
111
/// How `TokenBucketAlgorithm::try_acquire` treats a request for `permits` tokens.
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
enum Mode {
    /// Take exactly `permits` tokens, or fail without consuming any.
    N,
    /// Take `min(permits, available)` tokens; never fails.
    All,
}
117
118impl TokenBucketAlgorithm {
119 pub fn try_acquire(
120 &self,
121 state: &mut State,
122 permits: u32,
123 ) -> Result<(), RateLimitExceededError> {
124 self.refill_state(state);
125
126 match self.mode {
127 Mode::N => {
128 if state.available_tokens >= permits {
129 state.available_tokens -= permits;
130 Ok(())
131 } else {
132 Err(RateLimitExceededError(()))
133 }
134 }
135 Mode::All => {
136 state.available_tokens -= u32::min(permits, state.available_tokens);
137 Ok(())
138 }
139 }
140 }
141
142 fn refill_state(&self, state: &mut State) {
143 let now = time::OffsetDateTime::now_utc();
144 let since_last_refill = now - state.last_refill;
145
146 if since_last_refill <= state.refill_tick {
147 return;
148 }
149
150 let tokens_since_last_refill = {
151 let mut tokens_count = 0u32;
152 let mut k = since_last_refill;
153 loop {
154 k -= state.refill_tick;
155 if k <= time::Duration::ZERO {
156 break;
157 }
158 tokens_count += 1;
159 }
160 tokens_count
161 };
162
163 state.available_tokens =
164 u32::min(state.available_tokens + tokens_since_last_refill, state.cap);
165 state.last_refill += state.refill_tick * tokens_since_last_refill;
166 }
167}
168
/// Returned by [`TokenBucketAlgorithm::try_acquire`] when an all-or-nothing acquisition finds
/// fewer tokens than requested.
///
/// The private unit field prevents construction outside this module.
#[derive(Debug, thiserror::Error)]
#[error("rate limit exceeded")]
pub struct RateLimitExceededError(());