1use crate::config;
2use crate::error::Error;
3use crate::rule;
4use redis::{FromRedisValue, aio::ConnectionLike};
5pub use redis_cell_rs as redis_cell;
6use std::{pin::Pin, sync::Arc};
7
/// Middleware that wraps a service `S` and applies rate limiting to the
/// requests passing through it.
///
/// `C` is the connection handle used to send the rate-limit command; it is
/// cloned for every call.
pub struct RateLimit<S, PR, ReqTy, RespTy, IntoRespTy, C> {
    /// The wrapped service that receives requests which are not rejected.
    inner: S,
    /// Shared configuration: the rule provider and the `on_error`,
    /// `on_unruled` and `on_success` handlers.
    config: Arc<config::RateLimitConfig<PR, ReqTy, RespTy, IntoRespTy>>,
    /// Connection handle cloned per request to issue the rate-limit command.
    connection: C,
}
13
14impl<S, PR, ReqTy, RespTy, IntoRespTy, C> Clone for RateLimit<S, PR, ReqTy, RespTy, IntoRespTy, C>
15where
16 S: Clone,
17 C: Clone,
18{
19 fn clone(&self) -> Self {
20 Self {
21 inner: self.inner.clone(),
22 config: Arc::clone(&self.config),
23 connection: self.connection.clone(),
24 }
25 }
26}
27
28impl<S, PR, ReqTy, RespTy, IntoRespTy, C> RateLimit<S, PR, ReqTy, RespTy, IntoRespTy, C> {
29 pub fn new<RLC>(inner: S, config: RLC, connection: C) -> Self
30 where
31 RLC: Into<Arc<config::RateLimitConfig<PR, ReqTy, RespTy, IntoRespTy>>>,
32 {
33 RateLimit {
34 inner,
35 config: config.into(),
36 connection,
37 }
38 }
39}
40
41impl<S, PR, ReqTy, RespTy, IntoRespTy, C> tower::Service<ReqTy>
42 for RateLimit<S, PR, ReqTy, RespTy, IntoRespTy, C>
43where
44 S: tower::Service<ReqTy, Response = RespTy> + Clone + Send + 'static,
45 S::Future: Send + 'static,
46 S::Error: Send,
47 S::Response: Send,
48 PR: rule::ProvideRule<ReqTy> + Clone + Send + Sync + 'static,
49 ReqTy: Send + 'static,
50 IntoRespTy: Into<RespTy> + 'static,
51 RespTy: 'static,
52 C: ConnectionLike + Clone + Send + 'static,
53{
54 type Response = S::Response;
55 type Error = S::Error;
56 type Future = Pin<Box<dyn Future<Output = Result<S::Response, S::Error>> + Send>>;
57
58 fn poll_ready(
59 &mut self,
60 cx: &mut std::task::Context<'_>,
61 ) -> std::task::Poll<Result<(), Self::Error>> {
62 self.inner.poll_ready(cx)
63 }
64
65 fn call(&mut self, req: ReqTy) -> Self::Future {
66 let mut connection = self.connection.clone();
67 let mut inner = self.inner.clone();
68 let config = self.config.clone();
69
70 Box::pin(async move {
71 let maybe_rule = match config.rule_provider.provide(&req) {
72 Ok(rule) => rule,
73 Err(e) => {
74 let config::OnError::Sync(ref h) = config.on_error;
75 let resp = h(Error::ProvideRule(e), &req);
76 return Ok(resp.into());
77 }
78 };
79 let rule = match maybe_rule {
80 Some(rule) => rule,
81 None => {
82 return inner
83 .call(req)
84 .await
85 .map(|mut resp| match &config.on_unruled {
86 config::OnUnruled::Noop => resp,
87 config::OnUnruled::Sync(h) => {
88 h(&mut resp);
89 resp
90 }
91 });
92 }
93 };
94 let policy = rule.policy;
95 let cmd = redis_cell::Cmd::new(&rule.key, &policy);
96
97 let redis_response = match connection.req_packed_command(&cmd.into()).await {
98 Ok(res) => res,
99 Err(redis_err) => {
100 let config::OnError::Sync(ref h) = config.on_error;
101 let handled = h(redis_err.into(), &req);
102 return Ok(handled.into());
103 }
104 };
105 let redis_cell_verdict = match redis_cell::Verdict::from_redis_value(&redis_response) {
106 Ok(verdict) => verdict,
107 Err(redis_err) => {
108 let config::OnError::Sync(ref h) = config.on_error;
109 let handled = h(Error::Redis(redis_err), &req);
110 return Ok(handled.into());
111 }
112 };
113 match redis_cell_verdict {
114 redis_cell::Verdict::Blocked(details) => {
115 let config::OnError::Sync(ref h) = config.on_error;
116 let handled = h(
117 Error::RateLimit(rule::RequestBlockedDetails { rule, details }),
118 &req,
119 );
120 Ok(handled.into())
121 }
122 redis_cell::Verdict::Allowed(details) => {
123 let policy = rule.policy;
124 let resource = rule.resource;
125 inner
126 .call(req)
127 .await
128 .map(|mut resp| match &config.on_success {
129 config::OnSuccess::Noop => resp,
130 config::OnSuccess::Sync(h) => {
131 let details = rule::RequestAllowedDetails {
132 details,
133 policy,
134 resource,
135 };
136 h(details, &mut resp);
137 resp
138 }
139 })
140 }
141 }
142 })
143 }
144}
145
/// Layer that wraps a service in a [`RateLimit`] middleware, handing each
/// produced service the same shared configuration and a clone of the
/// connection handle.
pub struct RateLimitLayer<PR, ReqTy, RespTy, IntoRespTy, C> {
    /// Configuration shared by every service this layer produces.
    config: Arc<config::RateLimitConfig<PR, ReqTy, RespTy, IntoRespTy>>,
    /// Connection handle cloned into every service this layer produces.
    connection: C,
}
150
151impl<PR, ReqTy, RespTy, IntoRespTy, C> Clone for RateLimitLayer<PR, ReqTy, RespTy, IntoRespTy, C>
152where
153 C: Clone,
154{
155 fn clone(&self) -> Self {
156 Self {
157 config: Arc::clone(&self.config),
158 connection: self.connection.clone(),
159 }
160 }
161}
162
163impl<S, PR, ReqTy, RespTy, IntoRespTy, C> tower::Layer<S>
164 for RateLimitLayer<PR, ReqTy, RespTy, IntoRespTy, C>
165where
166 C: Clone,
167{
168 type Service = RateLimit<S, PR, ReqTy, RespTy, IntoRespTy, C>;
169 fn layer(&self, inner: S) -> Self::Service {
170 RateLimit::new(inner, Arc::clone(&self.config), self.connection.clone())
171 }
172}
173
174impl<PR, ReqTy, RespTy, IntoRespTy, C> RateLimitLayer<PR, ReqTy, RespTy, IntoRespTy, C> {
175 pub fn new<RLC>(config: RLC, connection: C) -> Self
176 where
177 RLC: Into<Arc<config::RateLimitConfig<PR, ReqTy, RespTy, IntoRespTy>>>,
178 {
179 RateLimitLayer {
180 config: config.into(),
181 connection,
182 }
183 }
184}
185
186#[cfg(feature = "deadpool")]
187#[cfg_attr(docsrs, doc(cfg(feature = "deadpool")))]
188pub mod deadpool {
189 use crate::config;
190 use crate::error::Error;
191 use crate::rule;
192 use redis::{FromRedisValue, aio::ConnectionLike};
193 pub use redis_cell_rs as redis_cell;
194 use redis_cell_rs::Verdict;
195 use std::{pin::Pin, sync::Arc};
196
    /// Middleware that wraps a service `S` and applies rate limiting to the
    /// requests passing through it, obtaining its connection from a
    /// `deadpool_redis::Pool`.
    pub struct RateLimit<S, PR, ReqTy, RespTy, IntoRespTy> {
        /// The wrapped service that receives requests which are not rejected.
        inner: S,
        /// Shared configuration: the rule provider and the `on_error`,
        /// `on_unruled` and `on_success` handlers.
        config: Arc<config::RateLimitConfig<PR, ReqTy, RespTy, IntoRespTy>>,
        /// Pool a connection is requested from for each rate-limited request.
        pool: deadpool_redis::Pool,
    }
202
203 impl<S, PR, ReqTy, RespTy, IntoRespTy> Clone for RateLimit<S, PR, ReqTy, RespTy, IntoRespTy>
204 where
205 S: Clone,
206 {
207 fn clone(&self) -> Self {
208 Self {
209 inner: self.inner.clone(),
210 config: Arc::clone(&self.config),
211 pool: self.pool.clone(),
212 }
213 }
214 }
215
216 impl<S, PR, ReqTy, RespTy, IntoRespTy> RateLimit<S, PR, ReqTy, RespTy, IntoRespTy> {
217 pub fn new<RLC>(inner: S, config: RLC, pool: deadpool_redis::Pool) -> Self
218 where
219 RLC: Into<Arc<config::RateLimitConfig<PR, ReqTy, RespTy, IntoRespTy>>>,
220 {
221 RateLimit {
222 inner,
223 config: config.into(),
224 pool,
225 }
226 }
227 }
228
229 impl<S, PR, ReqTy, RespTy, IntoRespTy> tower::Service<ReqTy>
230 for RateLimit<S, PR, ReqTy, RespTy, IntoRespTy>
231 where
232 S: tower::Service<ReqTy, Response = RespTy> + Clone + Send + 'static,
233 S::Future: Send + 'static,
234 S::Error: Send,
235 S::Response: Send,
236 PR: rule::ProvideRule<ReqTy> + Clone + Send + Sync + 'static,
237 ReqTy: Send + 'static,
238 IntoRespTy: Into<RespTy> + 'static,
239 RespTy: 'static,
240 {
241 type Response = S::Response;
242 type Error = S::Error;
243 type Future = Pin<Box<dyn Future<Output = Result<S::Response, S::Error>> + Send>>;
244
245 fn poll_ready(
246 &mut self,
247 cx: &mut std::task::Context<'_>,
248 ) -> std::task::Poll<Result<(), Self::Error>> {
249 self.inner.poll_ready(cx)
250 }
251
252 fn call(&mut self, req: ReqTy) -> Self::Future {
253 let pool = self.pool.clone();
254 let mut inner = self.inner.clone();
255 let config = self.config.clone();
256
257 Box::pin(async move {
258 let maybe_rule = match config.rule_provider.provide(&req) {
259 Ok(rule) => rule,
260 Err(e) => {
261 let config::OnError::Sync(ref h) = config.on_error;
262 let resp = h(Error::ProvideRule(e), &req);
263 return Ok(resp.into());
264 }
265 };
266 let rule = match maybe_rule {
267 Some(rule) => rule,
268 None => {
269 return inner
270 .call(req)
271 .await
272 .map(|mut resp| match &config.on_unruled {
273 config::OnUnruled::Noop => resp,
274 config::OnUnruled::Sync(h) => {
275 h(&mut resp);
276 resp
277 }
278 });
279 }
280 };
281 let policy = rule.policy;
282 let cmd = redis_cell::Cmd::new(&rule.key, &policy);
283
284 let mut connection = match pool.get().await {
285 Ok(conn) => conn,
286 Err(deadpool_err) => {
287 let config::OnError::Sync(ref h) = config.on_error;
288 let handled = h(deadpool_err.into(), &req);
289 return Ok(handled.into());
290 }
291 };
292 let redis_response = match connection.req_packed_command(&cmd.into()).await {
293 Ok(res) => res,
294 Err(redis_err) => {
295 let config::OnError::Sync(ref h) = config.on_error;
296 let handled = h(redis_err.into(), &req);
297 return Ok(handled.into());
298 }
299 };
300 let redis_cell_verdict = match Verdict::from_redis_value(&redis_response) {
301 Ok(verdict) => verdict,
302 Err(redis_err) => {
303 let config::OnError::Sync(ref h) = config.on_error;
304 let handled = h(Error::Redis(redis_err), &req);
305 return Ok(handled.into());
306 }
307 };
308 match redis_cell_verdict {
309 redis_cell::Verdict::Blocked(details) => {
310 let config::OnError::Sync(ref h) = config.on_error;
311 let handled = h(
312 Error::RateLimit(rule::RequestBlockedDetails { rule, details }),
313 &req,
314 );
315 Ok(handled.into())
316 }
317 redis_cell::Verdict::Allowed(details) => {
318 let policy = rule.policy;
319 let resource = rule.resource;
320 inner
321 .call(req)
322 .await
323 .map(|mut resp| match &config.on_success {
324 config::OnSuccess::Noop => resp,
325 config::OnSuccess::Sync(h) => {
326 let details = rule::RequestAllowedDetails {
327 details,
328 policy,
329 resource,
330 };
331 h(details, &mut resp);
332 resp
333 }
334 })
335 }
336 }
337 })
338 }
339 }
340
    /// Layer that wraps a service in the pool-backed [`RateLimit`]
    /// middleware, handing each produced service the same shared
    /// configuration and a clone of the pool.
    pub struct RateLimitLayer<PR, ReqTy, RespTy, IntoRespTy> {
        /// Configuration shared by every service this layer produces.
        config: Arc<config::RateLimitConfig<PR, ReqTy, RespTy, IntoRespTy>>,
        /// Pool cloned into every service this layer produces.
        pool: deadpool_redis::Pool,
    }
345
346 impl<PR, ReqTy, RespTy, IntoRespTy> Clone for RateLimitLayer<PR, ReqTy, RespTy, IntoRespTy> {
347 fn clone(&self) -> Self {
348 Self {
349 config: Arc::clone(&self.config),
350 pool: self.pool.clone(),
351 }
352 }
353 }
354
355 impl<S, PR, ReqTy, RespTy, IntoRespTy> tower::Layer<S>
356 for RateLimitLayer<PR, ReqTy, RespTy, IntoRespTy>
357 {
358 type Service = RateLimit<S, PR, ReqTy, RespTy, IntoRespTy>;
359 fn layer(&self, inner: S) -> Self::Service {
360 RateLimit::new(inner, Arc::clone(&self.config), self.pool.clone())
361 }
362 }
363
364 impl<PR, ReqTy, RespTy, IntoRespTy> RateLimitLayer<PR, ReqTy, RespTy, IntoRespTy> {
365 pub fn new<RLC>(config: RLC, pool: deadpool_redis::Pool) -> Self
366 where
367 RLC: Into<Arc<config::RateLimitConfig<PR, ReqTy, RespTy, IntoRespTy>>>,
368 {
369 RateLimitLayer {
370 config: config.into(),
371 pool,
372 }
373 }
374 }
375}