tower_rate_limit_fred/
layer.rs

1use crate::{KeyResolver, config::Config, service::RateLimit};
2use fred::prelude::LuaInterface;
3use std::time::Duration;
4use tower::Layer;
5
6/// Enforces rate limit on the underlying service.
7///
8/// Construct using [`RateLimitLayerBuilder::build`], with the builder obtained from [`RateLimitLayer::builder`].
9#[derive(Debug, Clone)]
10pub struct RateLimitLayer<L, K>
11where
12    L: LuaInterface,
13    K: KeyResolver,
14{
15    config: Config<L, K>,
16}
17
18impl<S, L, K> Layer<S> for RateLimitLayer<L, K>
19where
20    L: LuaInterface,
21    K: KeyResolver,
22{
23    type Service = RateLimit<S, L, K>;
24
25    fn layer(&self, inner: S) -> Self::Service {
26        RateLimit::new(inner, self.config.clone())
27    }
28}
29
30impl<L, K> RateLimitLayer<L, K>
31where
32    L: LuaInterface,
33    K: KeyResolver,
34{
35    /// Returns a new builder for configuring the layer.
36    pub fn builder() -> RateLimitLayerBuilder<L, K, false, false, false, false> {
37        RateLimitLayerBuilder {
38            conn: None,
39            key_resolver: None,
40            capacity: None,
41            emission_interval: None,
42            fail_open: false,
43        }
44    }
45}
46
47/// Builder for [`RateLimitLayer`].
48///
49/// Uses const generics to enforce these required fields are set at compile time:
50/// - [`conn`](Self::conn)
51/// - [`key_resolver`](Self::key_resolver)
52/// - [`capacity`](Self::capacity)
53/// - [`emission_interval`](Self::emission_interval)
54///
55/// Each method consumes `self` and returns a new builder.
56/// Clone the builder first if you need to branch configurations.
57#[derive(Debug, Clone)]
58pub struct RateLimitLayerBuilder<
59    L,
60    K,
61    const CONN: bool,
62    const KEY_RESOLVER: bool,
63    const CAPACITY: bool,
64    const EMISSION_INTERVAL: bool,
65> where
66    L: LuaInterface,
67    K: KeyResolver,
68{
69    conn: Option<L>,
70    key_resolver: Option<K>,
71    capacity: Option<u16>,
72    emission_interval: Option<u32>,
73    fail_open: bool,
74}
75
76impl<
77    L,
78    K,
79    const CONN: bool,
80    const KEY_RESOLVER: bool,
81    const CAPACITY: bool,
82    const EMISSION_INTERVAL: bool,
83> RateLimitLayerBuilder<L, K, CONN, KEY_RESOLVER, CAPACITY, EMISSION_INTERVAL>
84where
85    L: LuaInterface,
86    K: KeyResolver,
87{
88    /// Sets the Redis connection.
89    pub fn conn(
90        self,
91        conn: L,
92    ) -> RateLimitLayerBuilder<L, K, true, KEY_RESOLVER, CAPACITY, EMISSION_INTERVAL> {
93        RateLimitLayerBuilder {
94            conn: Some(conn),
95            key_resolver: self.key_resolver,
96            capacity: self.capacity,
97            emission_interval: self.emission_interval,
98            fail_open: self.fail_open,
99        }
100    }
101
102    /// Sets the key resolver.
103    pub fn key_resolver(
104        self,
105        key_resolver: K,
106    ) -> RateLimitLayerBuilder<L, K, CONN, true, CAPACITY, EMISSION_INTERVAL> {
107        RateLimitLayerBuilder {
108            conn: self.conn,
109            key_resolver: Some(key_resolver),
110            capacity: self.capacity,
111            emission_interval: self.emission_interval,
112            fail_open: self.fail_open,
113        }
114    }
115
116    /// Sets the capacity.
117    ///
118    /// Panics if `capacity == 0`.
119    pub fn capacity(
120        self,
121        capacity: u16,
122    ) -> RateLimitLayerBuilder<L, K, CONN, KEY_RESOLVER, true, EMISSION_INTERVAL> {
123        assert!(capacity > 0, "capacity cannot be 0");
124        RateLimitLayerBuilder {
125            conn: self.conn,
126            key_resolver: self.key_resolver,
127            capacity: Some(capacity),
128            emission_interval: self.emission_interval,
129            fail_open: self.fail_open,
130        }
131    }
132
133    /// Sets the emission interval.
134    ///
135    /// Panics if the duration is less than 1 or greater than `u32::MAX` milliseconds.
136    pub fn emission_interval(
137        self,
138        emission_interval: Duration,
139    ) -> RateLimitLayerBuilder<L, K, CONN, KEY_RESOLVER, CAPACITY, true> {
140        let ms = emission_interval.as_millis();
141        assert!(
142            ms >= 1,
143            "emission_interval cannot be shorter than 1 millisecond. Got {emission_interval:?}",
144        );
145        RateLimitLayerBuilder {
146            conn: self.conn,
147            key_resolver: self.key_resolver,
148            capacity: self.capacity,
149            emission_interval: Some(ms.try_into().unwrap_or_else(|_| {
150                panic!(
151                    "per duration is too long. Max allowed is {:?}, got {emission_interval:?}",
152                    Duration::from_millis(u32::MAX.into())
153                )
154            })),
155            fail_open: self.fail_open,
156        }
157    }
158
159    /// Sets whether to allow requests if Redis is unavailable.
160    ///
161    /// Defaults to `false`.
162    pub fn fail_open(
163        self,
164        fail_open: bool,
165    ) -> RateLimitLayerBuilder<L, K, CONN, KEY_RESOLVER, CAPACITY, EMISSION_INTERVAL> {
166        RateLimitLayerBuilder {
167            conn: self.conn,
168            key_resolver: self.key_resolver,
169            capacity: self.capacity,
170            emission_interval: self.emission_interval,
171            fail_open,
172        }
173    }
174}
175
176impl<L, K> RateLimitLayerBuilder<L, K, true, true, true, true>
177where
178    L: LuaInterface,
179    K: KeyResolver,
180{
181    /// Builds the layer.
182    pub fn build(self) -> RateLimitLayer<L, K> {
183        RateLimitLayer {
184            config: unsafe {
185                Config {
186                    conn: self.conn.unwrap_unchecked(),
187                    key_resolver: self.key_resolver.unwrap_unchecked(),
188                    capacity: self.capacity.unwrap_unchecked(),
189                    emission_interval: self.emission_interval.unwrap_unchecked(),
190                    fail_open: self.fail_open,
191                }
192            },
193        }
194    }
195}