1use std::{
2 borrow::Cow,
3 sync::{
4 LazyLock,
5 atomic::{AtomicUsize, Ordering},
6 },
7 time::{SystemTime, UNIX_EPOCH},
8};
9use tank::{
10 AsValue, Entity, Error, Executor, QueryBuilder, Result, Value, current_timestamp_ms, expr,
11 join,
12 stream::{StreamExt, TryStreamExt},
13};
14use tokio::sync::Mutex;
15
16static MUTEX: LazyLock<Mutex<()>> = LazyLock::new(|| Mutex::new(()));
17
/// HTTP method a request was issued with, or that a request limit applies to.
///
/// Persisted as a lowercase `VARCHAR` through its `AsValue` implementation.
///
/// The enum carries no data, so it is also `Copy` and `Hash`. Callers can pass it by value
/// without `.clone()` and use it directly as a set/map key.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum Method {
    GET,
    POST,
    PUT,
    DELETE,
}
25impl AsValue for Method {
26 fn as_empty_value() -> Value {
27 Value::Varchar(None)
28 }
29 fn as_value(self) -> Value {
30 Value::Varchar(Some(
31 match self {
32 Method::GET => "get",
33 Method::POST => "post",
34 Method::PUT => "put",
35 Method::DELETE => "delete",
36 }
37 .into(),
38 ))
39 }
40 fn try_from_value(value: Value) -> Result<Self>
41 where
42 Self: Sized,
43 {
44 if let Value::Varchar(Some(v)) = value.try_as(&String::as_empty_value())? {
45 match &*v {
46 "get" => return Ok(Method::GET),
47 "post" => return Ok(Method::POST),
48 "put" => return Ok(Method::PUT),
49 "delete" => return Ok(Method::DELETE),
50 _ => {
51 return Err(Error::msg(format!(
52 "Unexpected value `{v}` for Method enum"
53 )));
54 }
55 }
56 }
57 Err(Error::msg("Unexpected value for Method enum"))
58 }
59}
60
// A rate limit on requests, as evaluated by the prepared query in `requests`. A new request
// violates the limit once the number of counted matching requests reaches `requests`.
#[derive(Default, Entity, PartialEq, Eq)]
#[tank(schema = "api")]
struct RequestLimit {
    #[tank(primary_key)]
    pub id: i32,
    // SQL `LIKE` pattern matched against request targets (e.g. `v1/%`).
    pub target_pattern: Cow<'static, str>,
    // Threshold: the limit is violated when the counted requests are >= this value.
    pub requests: i32,
    // `None` applies the limit to every method; `Some(m)` only counts requests issued with `m`.
    pub method: Option<Method>,
    // `None` counts requests that have not ended yet (no end timestamp); `Some(ms)` counts
    // requests whose end timestamp falls within the last `ms` milliseconds.
    pub interval_ms: Option<i32>,
}
73impl RequestLimit {
74 pub fn new(
75 target_pattern: &'static str,
76 requests: i32,
77 method: Option<Method>,
78 interval_ms: Option<i32>,
79 ) -> Self {
80 let id = GLOBAL_COUNTER.fetch_add(1, Ordering::Relaxed) as _;
81 Self {
82 id,
83 target_pattern: target_pattern.into(),
84 requests,
85 method,
86 interval_ms,
87 }
88 }
89}
90
// A single request against `target`. `end_timestamp_ms` is `None` until `Request::end` is called.
#[derive(Entity, PartialEq, Eq)]
#[tank(schema = "api")]
pub struct Request {
    #[tank(primary_key)]
    pub id: i64,
    // Concrete target path, matched with `LIKE` against `RequestLimit::target_pattern`.
    pub target: String,
    // Method the request was issued with; `None` when no method was given.
    pub method: Option<Method>,
    // Start time in milliseconds since the Unix epoch (set by `Request::new`).
    // NOTE(review): "beign" looks like a typo for "begin". Renaming it would change a public
    // field and presumably the column name, so it is left as is.
    pub beign_timestamp_ms: i64,
    // End time in milliseconds since the Unix epoch; `None` until the request is ended.
    pub end_timestamp_ms: Option<i64>,
}

// Source of ids for both `RequestLimit::new` and `Request::new`. `fetch_add` is a single atomic
// read-modify-write, so every caller gets a distinct value even with `Relaxed` ordering.
static GLOBAL_COUNTER: AtomicUsize = AtomicUsize::new(0);
103
104impl Request {
105 pub fn new(target: String, method: Option<Method>) -> Self {
106 let id = GLOBAL_COUNTER.fetch_add(1, Ordering::Relaxed) as _;
107 Self {
108 id,
109 target,
110 method,
111 beign_timestamp_ms: SystemTime::now()
112 .duration_since(UNIX_EPOCH)
113 .unwrap()
114 .as_millis() as _,
115 end_timestamp_ms: None,
116 }
117 }
118 pub fn end(&mut self) {
119 self.end_timestamp_ms = Some(
120 SystemTime::now()
121 .duration_since(UNIX_EPOCH)
122 .unwrap()
123 .as_millis() as _,
124 );
125 }
126}
127
128pub async fn requests<E: Executor>(executor: &mut E) {
129 let _lock = MUTEX.lock();
130
131 RequestLimit::drop_table(executor, true, false)
133 .await
134 .expect("Could not drop the RequestLimit table");
135 Request::drop_table(executor, true, false)
136 .await
137 .expect("Could not drop the Request table");
138
139 RequestLimit::create_table(executor, false, true)
140 .await
141 .expect("Could not create the RequestLimit table");
142 Request::create_table(executor, false, true)
143 .await
144 .expect("Could not create the Request table");
145
146 RequestLimit::insert_many(
148 executor,
149 &[
150 RequestLimit::new("v1/%", 3, None, None),
152 RequestLimit::new("v1/server/data/%", 5, None, None),
154 RequestLimit::new("v1/server/user/%", 2, Method::PUT.into(), None),
156 RequestLimit::new("v1/server/user/%", 1, Method::DELETE.into(), None),
158 RequestLimit::new("v2/%", 5, None, 60_000.into()),
160 ],
161 )
162 .await
163 .expect("Could not insert the limits");
164 let limits = RequestLimit::find_many(executor, true, None)
165 .map_err(|e| panic!("{e:#}"))
166 .count()
167 .await;
168 assert_eq!(limits, 5);
169
170 #[cfg(not(feature = "disable-joins"))]
171 {
172 let mut violated_limits = executor
173 .prepare(
174 QueryBuilder::new()
175 .select([
176 RequestLimit::target_pattern,
177 RequestLimit::requests,
178 RequestLimit::method,
179 RequestLimit::interval_ms,
180 ])
181 .from(join!(RequestLimit CROSS JOIN Request))
182 .where_expr(expr!(
183 ? == RequestLimit::target_pattern as LIKE
184 && Request::target == RequestLimit::target_pattern as LIKE
185 && (RequestLimit::method == NULL
186 || RequestLimit::method == Request::method)
187 && (RequestLimit::interval_ms == NULL && Request::end_timestamp_ms == NULL
188 || RequestLimit::interval_ms != NULL
189 && Request::end_timestamp_ms
190 >= current_timestamp_ms!() - RequestLimit::interval_ms)
191 ))
192 .group_by([
193 RequestLimit::target_pattern,
194 RequestLimit::requests,
195 RequestLimit::method,
196 RequestLimit::interval_ms,
197 ])
198 .having(expr!(COUNT(Request::id) >= RequestLimit::requests))
199 .build(&executor.driver()),
200 )
201 .await
202 .expect("Failed to prepare the limit query");
203
204 let mut r1 = Request::new("v1/server/user/new/1".into(), Method::PUT.into());
205 let mut r2 = Request::new("v1/server/user/new/2".into(), Method::PUT.into());
206 let mut r3 = Request::new("v1/server/user/new/3".into(), Method::PUT.into());
207 let mut r4 = Request::new("v1/server/articles/new/4".into(), Method::PUT.into());
208 let r5 = Request::new("v1/server/user/new/5".into(), Method::PUT.into());
209
210 violated_limits.bind(r1.target.clone()).unwrap();
211 assert_eq!(executor.fetch(&mut violated_limits).count().await, 0);
212 r1.save(executor).await.expect("Failed to save r1");
213
214 violated_limits.bind(r2.target.clone()).unwrap();
215 assert_eq!(executor.fetch(&mut violated_limits).count().await, 0);
216 r2.save(executor).await.expect("Failed to save r2");
217
218 violated_limits.bind(r3.target.clone()).unwrap();
219 assert_eq!(executor.fetch(&mut violated_limits).count().await, 1); violated_limits.bind(r4.target.clone()).unwrap();
223 assert_eq!(executor.fetch(&mut violated_limits).count().await, 0);
224 r4.save(executor).await.expect("Failed to save r4");
225
226 violated_limits.bind(r5.target.clone()).unwrap();
227 assert_eq!(executor.fetch(&mut violated_limits).count().await, 2); r1.end();
230 r1.save(executor).await.expect("Could not terminate r1");
231
232 violated_limits.bind(r3.target.clone()).unwrap();
233 assert_eq!(executor.fetch(&mut violated_limits).count().await, 0);
234 r3.save(executor).await.expect("Failed to save r3");
235
236 let mut data_reqs = vec![];
239 for i in 0..5 {
240 let req = Request::new(format!("v1/server/data/item/{}", i), None);
241 req.save(executor)
242 .await
243 .expect("Failed to save data request");
244 data_reqs.push(req);
245 }
246
247 violated_limits
250 .bind("v1/server/data/item/999".to_string())
251 .unwrap();
252 assert_eq!(executor.fetch(&mut violated_limits).count().await, 2); for i in 0..4 {
255 data_reqs[i].end();
256 data_reqs[i]
257 .save(executor)
258 .await
259 .expect(&format!("Failed to save data_reqs[{i}]"));
260 }
261
262 violated_limits
265 .bind("v1/server/data/item/999".to_string())
266 .unwrap();
267 assert_eq!(executor.fetch(&mut violated_limits).count().await, 1); r2.end();
270 r2.save(executor).await.expect("Could not terminate r2");
271 data_reqs[4].end();
272 data_reqs[4]
273 .save(executor)
274 .await
275 .expect("Could not terminate data_reqs[4]");
276
277 violated_limits
278 .bind("v1/server/data/item/999".to_string())
279 .unwrap();
280 assert_eq!(executor.fetch(&mut violated_limits).count().await, 0);
281
282 r3.end();
283 r3.save(executor).await.expect("Could not terminate r3");
284 r4.end();
285 r4.save(executor).await.expect("Could not terminate r4");
286
287 let mut d1 = Request::new("v1/server/user/del/1".into(), Method::DELETE.into());
290 let d2 = Request::new("v1/server/user/del/2".into(), Method::DELETE.into());
291
292 violated_limits.bind(d1.target.clone()).unwrap();
293 assert_eq!(executor.fetch(&mut violated_limits).count().await, 0);
294 d1.save(executor).await.expect("Failed to save d1");
295
296 violated_limits.bind(d2.target.clone()).unwrap();
297 assert_eq!(executor.fetch(&mut violated_limits).count().await, 1); d1.end();
300 d1.save(executor).await.expect("Failed to end d1");
301
302 violated_limits.bind(d2.target.clone()).unwrap();
303 assert_eq!(executor.fetch(&mut violated_limits).count().await, 0);
304
305 let mut v2_reqs = vec![];
306 for i in 0..5 {
307 let mut req = Request::new(format!("v2/resource/{}", i), Method::GET.into());
308 req.save(executor).await.expect("Failed to save v2 req");
309 req.end(); req.save(executor).await.expect("Failed to end v2 req");
311 v2_reqs.push(req);
312 }
313
314 violated_limits.bind("v2/resource/new".to_string()).unwrap();
315 assert_eq!(executor.fetch(&mut violated_limits).count().await, 1); }
317}