tower_redis_cell/
service.rs

1use crate::config;
2use crate::error::Error;
3use crate::rule;
4use redis::{FromRedisValue, aio::ConnectionLike};
5pub use redis_cell_rs as redis_cell;
6use std::{pin::Pin, sync::Arc};
7
8pub struct RateLimit<S, PR, ReqTy, RespTy, IntoRespTy, C> {
9    inner: S,
10    config: Arc<config::RateLimitConfig<PR, ReqTy, RespTy, IntoRespTy>>,
11    connection: C,
12}
13
14impl<S, PR, ReqTy, RespTy, IntoRespTy, C> Clone for RateLimit<S, PR, ReqTy, RespTy, IntoRespTy, C>
15where
16    S: Clone,
17    C: Clone,
18{
19    fn clone(&self) -> Self {
20        Self {
21            inner: self.inner.clone(),
22            config: Arc::clone(&self.config),
23            connection: self.connection.clone(),
24        }
25    }
26}
27
28impl<S, PR, ReqTy, RespTy, IntoRespTy, C> RateLimit<S, PR, ReqTy, RespTy, IntoRespTy, C> {
29    pub fn new<RLC>(inner: S, config: RLC, connection: C) -> Self
30    where
31        RLC: Into<Arc<config::RateLimitConfig<PR, ReqTy, RespTy, IntoRespTy>>>,
32    {
33        RateLimit {
34            inner,
35            config: config.into(),
36            connection,
37        }
38    }
39}
40
41impl<S, PR, ReqTy, RespTy, IntoRespTy, C> tower::Service<ReqTy>
42    for RateLimit<S, PR, ReqTy, RespTy, IntoRespTy, C>
43where
44    S: tower::Service<ReqTy, Response = RespTy> + Clone + Send + 'static,
45    S::Future: Send + 'static,
46    S::Error: Send,
47    S::Response: Send,
48    PR: rule::ProvideRule<ReqTy> + Clone + Send + Sync + 'static,
49    ReqTy: Send + 'static,
50    IntoRespTy: Into<RespTy> + 'static,
51    RespTy: 'static,
52    C: ConnectionLike + Clone + Send + 'static,
53{
54    type Response = S::Response;
55    type Error = S::Error;
56    type Future = Pin<Box<dyn Future<Output = Result<S::Response, S::Error>> + Send>>;
57
58    fn poll_ready(
59        &mut self,
60        cx: &mut std::task::Context<'_>,
61    ) -> std::task::Poll<Result<(), Self::Error>> {
62        self.inner.poll_ready(cx)
63    }
64
65    fn call(&mut self, req: ReqTy) -> Self::Future {
66        let mut connection = self.connection.clone();
67        let mut inner = self.inner.clone();
68        let config = self.config.clone();
69
70        Box::pin(async move {
71            let maybe_rule = match config.rule_provider.provide(&req) {
72                Ok(rule) => rule,
73                Err(e) => {
74                    let config::OnError::Sync(ref h) = config.on_error;
75                    let resp = h(Error::ProvideRule(e), &req);
76                    return Ok(resp.into());
77                }
78            };
79            let rule = match maybe_rule {
80                Some(rule) => rule,
81                None => {
82                    return inner
83                        .call(req)
84                        .await
85                        .map(|mut resp| match &config.on_unruled {
86                            config::OnUnruled::Noop => resp,
87                            config::OnUnruled::Sync(h) => {
88                                h(&mut resp);
89                                resp
90                            }
91                        });
92                }
93            };
94            let policy = rule.policy;
95            let cmd = redis_cell::Cmd::new(&rule.key, &policy);
96
97            let redis_response = match connection.req_packed_command(&cmd.into()).await {
98                Ok(res) => res,
99                Err(redis_err) => {
100                    let config::OnError::Sync(ref h) = config.on_error;
101                    let handled = h(redis_err.into(), &req);
102                    return Ok(handled.into());
103                }
104            };
105            let redis_cell_verdict = match redis_cell::Verdict::from_redis_value(&redis_response) {
106                Ok(verdict) => verdict,
107                Err(redis_err) => {
108                    let config::OnError::Sync(ref h) = config.on_error;
109                    let handled = h(Error::Redis(redis_err), &req);
110                    return Ok(handled.into());
111                }
112            };
113            match redis_cell_verdict {
114                redis_cell::Verdict::Blocked(details) => {
115                    let config::OnError::Sync(ref h) = config.on_error;
116                    let handled = h(
117                        Error::RateLimit(rule::RequestBlockedDetails { rule, details }),
118                        &req,
119                    );
120                    Ok(handled.into())
121                }
122                redis_cell::Verdict::Allowed(details) => {
123                    let policy = rule.policy;
124                    let resource = rule.resource;
125                    inner
126                        .call(req)
127                        .await
128                        .map(|mut resp| match &config.on_success {
129                            config::OnSuccess::Noop => resp,
130                            config::OnSuccess::Sync(h) => {
131                                let details = rule::RequestAllowedDetails {
132                                    details,
133                                    policy,
134                                    resource,
135                                };
136                                h(details, &mut resp);
137                                resp
138                            }
139                        })
140                }
141            }
142        })
143    }
144}
145
146pub struct RateLimitLayer<PR, ReqTy, RespTy, IntoRespTy, C> {
147    config: Arc<config::RateLimitConfig<PR, ReqTy, RespTy, IntoRespTy>>,
148    connection: C,
149}
150
151impl<PR, ReqTy, RespTy, IntoRespTy, C> Clone for RateLimitLayer<PR, ReqTy, RespTy, IntoRespTy, C>
152where
153    C: Clone,
154{
155    fn clone(&self) -> Self {
156        Self {
157            config: Arc::clone(&self.config),
158            connection: self.connection.clone(),
159        }
160    }
161}
162
163impl<S, PR, ReqTy, RespTy, IntoRespTy, C> tower::Layer<S>
164    for RateLimitLayer<PR, ReqTy, RespTy, IntoRespTy, C>
165where
166    C: Clone,
167{
168    type Service = RateLimit<S, PR, ReqTy, RespTy, IntoRespTy, C>;
169    fn layer(&self, inner: S) -> Self::Service {
170        RateLimit::new(inner, Arc::clone(&self.config), self.connection.clone())
171    }
172}
173
174impl<PR, ReqTy, RespTy, IntoRespTy, C> RateLimitLayer<PR, ReqTy, RespTy, IntoRespTy, C> {
175    pub fn new<RLC>(config: RLC, connection: C) -> Self
176    where
177        RLC: Into<Arc<config::RateLimitConfig<PR, ReqTy, RespTy, IntoRespTy>>>,
178    {
179        RateLimitLayer {
180            config: config.into(),
181            connection,
182        }
183    }
184}
185
/// Rate limiting backed by a [`deadpool_redis::Pool`] instead of a single
/// cloneable connection.
///
/// The types here mirror `RateLimit` / `RateLimitLayer` from the parent module;
/// the difference is that each rate-limit check checks a connection out of the
/// pool and returns it as soon as the redis round-trip is done.
#[cfg(feature = "deadpool")]
#[cfg_attr(docsrs, doc(cfg(feature = "deadpool")))]
pub mod deadpool {
    use crate::config;
    use crate::error::Error;
    use crate::rule;
    use redis::{FromRedisValue, aio::ConnectionLike};
    pub use redis_cell_rs as redis_cell;
    use redis_cell_rs::Verdict;
    use std::{pin::Pin, sync::Arc};

    /// Tower middleware that checks each request against a redis-cell rate limit,
    /// using a pooled connection, before handing it to the wrapped service.
    pub struct RateLimit<S, PR, ReqTy, RespTy, IntoRespTy> {
        /// The wrapped service.
        inner: S,
        /// Shared configuration: rule provider plus the error/unruled/success hooks.
        config: Arc<config::RateLimitConfig<PR, ReqTy, RespTy, IntoRespTy>>,
        /// Pool the rate-limit connection is checked out from on each call.
        pool: deadpool_redis::Pool,
    }

    // Written by hand: `#[derive(Clone)]` would also demand `PR`, `ReqTy`,
    // `RespTy` and `IntoRespTy` be `Clone`, although they are only reached through
    // the shared `Arc`.
    impl<S, PR, ReqTy, RespTy, IntoRespTy> Clone for RateLimit<S, PR, ReqTy, RespTy, IntoRespTy>
    where
        S: Clone,
    {
        fn clone(&self) -> Self {
            Self {
                inner: self.inner.clone(),
                config: Arc::clone(&self.config),
                pool: self.pool.clone(),
            }
        }
    }

    impl<S, PR, ReqTy, RespTy, IntoRespTy> RateLimit<S, PR, ReqTy, RespTy, IntoRespTy> {
        /// Wraps `inner` with rate limiting that draws connections from `pool`.
        ///
        /// `config` may be an owned `RateLimitConfig` or an existing `Arc` of one.
        pub fn new<RLC>(inner: S, config: RLC, pool: deadpool_redis::Pool) -> Self
        where
            RLC: Into<Arc<config::RateLimitConfig<PR, ReqTy, RespTy, IntoRespTy>>>,
        {
            RateLimit {
                inner,
                config: config.into(),
                pool,
            }
        }
    }

    /// Same request flow as the connection-based service in the parent module:
    /// rule lookup, optional pass-through for unruled requests, a redis-cell check,
    /// then either the `on_error` handler (`Blocked`, or any provider/pool/redis
    /// failure) or the inner service plus `on_success` hook (`Allowed`).
    ///
    /// Responses built by the `on_error` handler are returned as `Ok`; only errors
    /// raised by the inner service surface as `Err`.
    impl<S, PR, ReqTy, RespTy, IntoRespTy> tower::Service<ReqTy>
        for RateLimit<S, PR, ReqTy, RespTy, IntoRespTy>
    where
        S: tower::Service<ReqTy, Response = RespTy> + Clone + Send + 'static,
        S::Future: Send + 'static,
        S::Error: Send,
        S::Response: Send,
        PR: rule::ProvideRule<ReqTy> + Clone + Send + Sync + 'static,
        ReqTy: Send + 'static,
        IntoRespTy: Into<RespTy> + 'static,
        RespTy: 'static,
    {
        type Response = S::Response;
        type Error = S::Error;
        type Future = Pin<Box<dyn Future<Output = Result<S::Response, S::Error>> + Send>>;

        /// Readiness is delegated to the inner service; pool availability is only
        /// discovered inside `call`.
        fn poll_ready(
            &mut self,
            cx: &mut std::task::Context<'_>,
        ) -> std::task::Poll<Result<(), Self::Error>> {
            self.inner.poll_ready(cx)
        }

        fn call(&mut self, req: ReqTy) -> Self::Future {
            let pool = self.pool.clone();
            // Move the instance that `poll_ready` made ready into the future and
            // leave a fresh clone behind; calling an un-polled clone would break the
            // `tower::Service` readiness contract.
            let replacement = self.inner.clone();
            let mut inner = std::mem::replace(&mut self.inner, replacement);
            let config = Arc::clone(&self.config);

            Box::pin(async move {
                let maybe_rule = match config.rule_provider.provide(&req) {
                    Ok(rule) => rule,
                    Err(e) => {
                        let config::OnError::Sync(ref h) = config.on_error;
                        let resp = h(Error::ProvideRule(e), &req);
                        return Ok(resp.into());
                    }
                };
                let rule = match maybe_rule {
                    Some(rule) => rule,
                    // No rule: pass through without touching the pool.
                    None => {
                        return inner
                            .call(req)
                            .await
                            .map(|mut resp| match &config.on_unruled {
                                config::OnUnruled::Noop => resp,
                                config::OnUnruled::Sync(h) => {
                                    h(&mut resp);
                                    resp
                                }
                            });
                    }
                };
                // `policy` is `Copy`; this copy is reused for the success details.
                let policy = rule.policy;
                let cmd = redis_cell::Cmd::new(&rule.key, &policy);

                let mut connection = match pool.get().await {
                    Ok(conn) => conn,
                    Err(deadpool_err) => {
                        let config::OnError::Sync(ref h) = config.on_error;
                        let handled = h(deadpool_err.into(), &req);
                        return Ok(handled.into());
                    }
                };
                let redis_result = connection.req_packed_command(&cmd.into()).await;
                // Return the connection to the pool now. Left in scope it would stay
                // checked out until the async block ends, i.e. for the whole inner
                // service call, which can exhaust the pool under slow handlers.
                drop(connection);

                let redis_response = match redis_result {
                    Ok(res) => res,
                    Err(redis_err) => {
                        let config::OnError::Sync(ref h) = config.on_error;
                        let handled = h(redis_err.into(), &req);
                        return Ok(handled.into());
                    }
                };
                let redis_cell_verdict = match Verdict::from_redis_value(&redis_response) {
                    Ok(verdict) => verdict,
                    Err(redis_err) => {
                        let config::OnError::Sync(ref h) = config.on_error;
                        let handled = h(Error::Redis(redis_err), &req);
                        return Ok(handled.into());
                    }
                };
                match redis_cell_verdict {
                    Verdict::Blocked(details) => {
                        let config::OnError::Sync(ref h) = config.on_error;
                        let handled = h(
                            Error::RateLimit(rule::RequestBlockedDetails { rule, details }),
                            &req,
                        );
                        Ok(handled.into())
                    }
                    Verdict::Allowed(details) => {
                        // Taken out of `rule` before `req` is moved into the inner call.
                        let resource = rule.resource;
                        inner
                            .call(req)
                            .await
                            .map(|mut resp| match &config.on_success {
                                config::OnSuccess::Noop => resp,
                                config::OnSuccess::Sync(h) => {
                                    let details = rule::RequestAllowedDetails {
                                        details,
                                        policy,
                                        resource,
                                    };
                                    h(details, &mut resp);
                                    resp
                                }
                            })
                    }
                }
            })
        }
    }

    /// [`tower::Layer`] that wraps services in the pool-backed [`RateLimit`].
    pub struct RateLimitLayer<PR, ReqTy, RespTy, IntoRespTy> {
        /// Configuration shared with every service this layer builds.
        config: Arc<config::RateLimitConfig<PR, ReqTy, RespTy, IntoRespTy>>,
        /// Pool handle cloned into every service this layer builds.
        pool: deadpool_redis::Pool,
    }

    impl<PR, ReqTy, RespTy, IntoRespTy> Clone for RateLimitLayer<PR, ReqTy, RespTy, IntoRespTy> {
        fn clone(&self) -> Self {
            Self {
                config: Arc::clone(&self.config),
                pool: self.pool.clone(),
            }
        }
    }

    impl<S, PR, ReqTy, RespTy, IntoRespTy> tower::Layer<S>
        for RateLimitLayer<PR, ReqTy, RespTy, IntoRespTy>
    {
        type Service = RateLimit<S, PR, ReqTy, RespTy, IntoRespTy>;

        fn layer(&self, inner: S) -> Self::Service {
            RateLimit::new(inner, Arc::clone(&self.config), self.pool.clone())
        }
    }

    impl<PR, ReqTy, RespTy, IntoRespTy> RateLimitLayer<PR, ReqTy, RespTy, IntoRespTy> {
        /// Creates a layer from a configuration (owned, or already inside an `Arc`)
        /// and the connection pool to hand out to wrapped services.
        pub fn new<RLC>(config: RLC, pool: deadpool_redis::Pool) -> Self
        where
            RLC: Into<Arc<config::RateLimitConfig<PR, ReqTy, RespTy, IntoRespTy>>>,
        {
            RateLimitLayer {
                config: config.into(),
                pool,
            }
        }
    }
}