1use crate::{KeyResolver, config::Config, service::RateLimit};
2use fred::prelude::LuaInterface;
3use std::time::Duration;
4use tower::Layer;
5
6#[derive(Debug, Clone)]
10pub struct RateLimitLayer<L, K>
11where
12 L: LuaInterface,
13 K: KeyResolver,
14{
15 config: Config<L, K>,
16}
17
18impl<S, L, K> Layer<S> for RateLimitLayer<L, K>
19where
20 L: LuaInterface,
21 K: KeyResolver,
22{
23 type Service = RateLimit<S, L, K>;
24
25 fn layer(&self, inner: S) -> Self::Service {
26 RateLimit::new(inner, self.config.clone())
27 }
28}
29
30impl<L, K> RateLimitLayer<L, K>
31where
32 L: LuaInterface,
33 K: KeyResolver,
34{
35 pub fn builder() -> RateLimitLayerBuilder<L, K, false, false, false, false> {
37 RateLimitLayerBuilder {
38 conn: None,
39 key_resolver: None,
40 capacity: None,
41 emission_interval: None,
42 fail_open: false,
43 }
44 }
45}
46
47#[derive(Debug, Clone)]
58pub struct RateLimitLayerBuilder<
59 L,
60 K,
61 const CONN: bool,
62 const KEY_RESOLVER: bool,
63 const CAPACITY: bool,
64 const EMISSION_INTERVAL: bool,
65> where
66 L: LuaInterface,
67 K: KeyResolver,
68{
69 conn: Option<L>,
70 key_resolver: Option<K>,
71 capacity: Option<u16>,
72 emission_interval: Option<u32>,
73 fail_open: bool,
74}
75
76impl<
77 L,
78 K,
79 const CONN: bool,
80 const KEY_RESOLVER: bool,
81 const CAPACITY: bool,
82 const EMISSION_INTERVAL: bool,
83> RateLimitLayerBuilder<L, K, CONN, KEY_RESOLVER, CAPACITY, EMISSION_INTERVAL>
84where
85 L: LuaInterface,
86 K: KeyResolver,
87{
88 pub fn conn(
90 self,
91 conn: L,
92 ) -> RateLimitLayerBuilder<L, K, true, KEY_RESOLVER, CAPACITY, EMISSION_INTERVAL> {
93 RateLimitLayerBuilder {
94 conn: Some(conn),
95 key_resolver: self.key_resolver,
96 capacity: self.capacity,
97 emission_interval: self.emission_interval,
98 fail_open: self.fail_open,
99 }
100 }
101
102 pub fn key_resolver(
104 self,
105 key_resolver: K,
106 ) -> RateLimitLayerBuilder<L, K, CONN, true, CAPACITY, EMISSION_INTERVAL> {
107 RateLimitLayerBuilder {
108 conn: self.conn,
109 key_resolver: Some(key_resolver),
110 capacity: self.capacity,
111 emission_interval: self.emission_interval,
112 fail_open: self.fail_open,
113 }
114 }
115
116 pub fn capacity(
120 self,
121 capacity: u16,
122 ) -> RateLimitLayerBuilder<L, K, CONN, KEY_RESOLVER, true, EMISSION_INTERVAL> {
123 assert!(capacity > 0, "capacity cannot be 0");
124 RateLimitLayerBuilder {
125 conn: self.conn,
126 key_resolver: self.key_resolver,
127 capacity: Some(capacity),
128 emission_interval: self.emission_interval,
129 fail_open: self.fail_open,
130 }
131 }
132
133 pub fn emission_interval(
137 self,
138 emission_interval: Duration,
139 ) -> RateLimitLayerBuilder<L, K, CONN, KEY_RESOLVER, CAPACITY, true> {
140 let ms = emission_interval.as_millis();
141 assert!(
142 ms >= 1,
143 "emission_interval cannot be shorter than 1 millisecond. Got {emission_interval:?}",
144 );
145 RateLimitLayerBuilder {
146 conn: self.conn,
147 key_resolver: self.key_resolver,
148 capacity: self.capacity,
149 emission_interval: Some(ms.try_into().unwrap_or_else(|_| {
150 panic!(
151 "per duration is too long. Max allowed is {:?}, got {emission_interval:?}",
152 Duration::from_millis(u32::MAX.into())
153 )
154 })),
155 fail_open: self.fail_open,
156 }
157 }
158
159 pub fn fail_open(
163 self,
164 fail_open: bool,
165 ) -> RateLimitLayerBuilder<L, K, CONN, KEY_RESOLVER, CAPACITY, EMISSION_INTERVAL> {
166 RateLimitLayerBuilder {
167 conn: self.conn,
168 key_resolver: self.key_resolver,
169 capacity: self.capacity,
170 emission_interval: self.emission_interval,
171 fail_open,
172 }
173 }
174}
175
176impl<L, K> RateLimitLayerBuilder<L, K, true, true, true, true>
177where
178 L: LuaInterface,
179 K: KeyResolver,
180{
181 pub fn build(self) -> RateLimitLayer<L, K> {
183 RateLimitLayer {
184 config: unsafe {
185 Config {
186 conn: self.conn.unwrap_unchecked(),
187 key_resolver: self.key_resolver.unwrap_unchecked(),
188 capacity: self.capacity.unwrap_unchecked(),
189 emission_interval: self.emission_interval.unwrap_unchecked(),
190 fail_open: self.fail_open,
191 }
192 },
193 }
194 }
195}