1use std::{
2 borrow::Cow,
3 sync::{
4 LazyLock,
5 atomic::{AtomicUsize, Ordering},
6 },
7 time::{SystemTime, UNIX_EPOCH},
8};
9use tank::{
10 AsValue, Entity, Error, Executor, QueryBuilder, Result, Value, current_timestamp_ms, expr,
11 join, stream::StreamExt,
12};
13use tokio::sync::Mutex;
14
15static MUTEX: LazyLock<Mutex<()>> = LazyLock::new(|| Mutex::new(()));
16
/// Request method, mirroring the verbs GET / POST / PUT / DELETE.
///
/// Persisted as a lowercase `VARCHAR` (`"get"`, `"post"`, `"put"`, `"delete"`) by its
/// `AsValue` implementation.
#[derive(Debug, Clone, Eq, PartialEq)]
pub enum Method {
    /// Encoded as `"get"`.
    GET,
    /// Encoded as `"post"`.
    POST,
    /// Encoded as `"put"`.
    PUT,
    /// Encoded as `"delete"`.
    DELETE,
}
24impl AsValue for Method {
25 fn as_empty_value() -> Value {
26 Value::Varchar(None)
27 }
28 fn as_value(self) -> Value {
29 Value::Varchar(Some(
30 match self {
31 Method::GET => "get",
32 Method::POST => "post",
33 Method::PUT => "put",
34 Method::DELETE => "delete",
35 }
36 .into(),
37 ))
38 }
39 fn try_from_value(value: Value) -> Result<Self>
40 where
41 Self: Sized,
42 {
43 if let Value::Varchar(Some(v)) = value.try_as(&String::as_empty_value())? {
44 match &*v {
45 "get" => return Ok(Method::GET),
46 "post" => return Ok(Method::POST),
47 "put" => return Ok(Method::PUT),
48 "delete" => return Ok(Method::DELETE),
49 _ => {
50 return Err(Error::msg(format!(
51 "Unexpected value `{v}` for Method enum"
52 )));
53 }
54 }
55 }
56 Err(Error::msg("Unexpected value for Method enum"))
57 }
58}
59
// Rate-limit rule stored in the `api` schema. The query in `requests` reports a rule as
// violated when the number of matching requests is >= `requests`.
// NOTE(review): field order and the `#[tank]` attributes feed the `Entity` derive (columns,
// primary key). Plain `//` comments are used so nothing is added to the derive input.
#[derive(Default, Entity, PartialEq, Eq)]
#[tank(schema = "api")]
struct RequestLimit {
    // Primary key. `RequestLimit::new` assigns it from `GLOBAL_COUNTER`.
    #[tank(primary_key)]
    pub id: i32,
    // SQL `LIKE` pattern matched against `Request::target` (e.g. `v1/%`).
    pub target_pattern: Cow<'static, str>,
    // Threshold: the rule is violated once this many matching requests are counted.
    pub requests: i32,
    // Method the rule applies to. `None` makes it apply to every method.
    pub method: Option<Method>,
    // `None`: count only requests that have not ended yet.
    // `Some(ms)`: count requests that ended within the last `ms` milliseconds.
    pub interval_ms: Option<i32>,
}
72impl RequestLimit {
73 pub fn new(
74 target_pattern: &'static str,
75 requests: i32,
76 method: Option<Method>,
77 interval_ms: Option<i32>,
78 ) -> Self {
79 let id = GLOBAL_COUNTER.fetch_add(1, Ordering::Relaxed) as _;
80 Self {
81 id,
82 target_pattern: target_pattern.into(),
83 requests,
84 method,
85 interval_ms,
86 }
87 }
88}
89
// A recorded call in the `api` schema. The `requests` test counts these against
// `RequestLimit` rules.
// NOTE(review): field order and `#[tank]` attributes feed the `Entity` derive, so only plain
// `//` comments are used here.
#[derive(Entity, PartialEq, Eq)]
#[tank(schema = "api")]
pub struct Request {
    // Primary key. `Request::new` assigns it from `GLOBAL_COUNTER`.
    #[tank(primary_key)]
    pub id: i64,
    // Requested path, matched against `RequestLimit::target_pattern` with `LIKE`.
    pub target: String,
    // Method of the call. `None` when unspecified.
    pub method: Option<Method>,
    // Start time in milliseconds since the Unix epoch (system clock).
    // NOTE(review): "beign" is a typo for "begin". Renaming it would change a public field
    // and presumably the column name, so it is left as is.
    pub beign_timestamp_ms: i64,
    // End time in milliseconds since the Unix epoch. `None` while the request is in flight.
    pub end_timestamp_ms: Option<i64>,
}
100
/// Process-wide id source shared by `Request::new` and `RequestLimit::new`.
///
/// Because both constructors draw from it, ids do not repeat across the two tables within
/// one process (barring overflow of the narrowing casts in those constructors).
static GLOBAL_COUNTER: AtomicUsize = AtomicUsize::new(usize::MIN);
102
103impl Request {
104 pub fn new(target: String, method: Option<Method>) -> Self {
105 let id = GLOBAL_COUNTER.fetch_add(1, Ordering::Relaxed) as _;
106 Self {
107 id,
108 target,
109 method,
110 beign_timestamp_ms: SystemTime::now()
111 .duration_since(UNIX_EPOCH)
112 .unwrap()
113 .as_millis() as _,
114 end_timestamp_ms: None,
115 }
116 }
117 pub fn end(&mut self) {
118 self.end_timestamp_ms = Some(
119 SystemTime::now()
120 .duration_since(UNIX_EPOCH)
121 .unwrap()
122 .as_millis() as _,
123 );
124 }
125}
126
127pub async fn requests<E: Executor>(executor: &mut E) {
128 let _lock = MUTEX.lock();
129
130 RequestLimit::drop_table(executor, true, false)
132 .await
133 .expect("Could not drop the RequestLimit table");
134 Request::drop_table(executor, true, false)
135 .await
136 .expect("Could not drop the Request table");
137
138 RequestLimit::create_table(executor, false, true)
139 .await
140 .expect("Could not create the RequestLimit table");
141 Request::create_table(executor, false, true)
142 .await
143 .expect("Could not create the Request table");
144
145 RequestLimit::insert_many(
147 executor,
148 &[
149 RequestLimit::new("v1/%", 3, None, None),
151 RequestLimit::new("v1/server/data/%", 5, None, None),
153 RequestLimit::new("v1/server/user/%", 2, Method::PUT.into(), None),
155 RequestLimit::new("v1/server/user/%", 1, Method::DELETE.into(), None),
157 RequestLimit::new("v2/%", 5, None, 60_000.into()),
159 ],
160 )
161 .await
162 .expect("Could not insert the limits");
163 let limits = RequestLimit::find_many(executor, true, None)
164 .map(|v| v.expect("Found error"))
165 .count()
166 .await;
167 assert_eq!(limits, 5);
168
169 #[cfg(not(feature = "disable-joins"))]
170 {
171 let mut violated_limits = executor
172 .prepare(
173 QueryBuilder::new()
174 .select([
175 RequestLimit::target_pattern,
176 RequestLimit::requests,
177 RequestLimit::method,
178 RequestLimit::interval_ms,
179 ])
180 .from(join!(RequestLimit CROSS JOIN Request))
181 .where_expr(expr!(
182 ? == RequestLimit::target_pattern as LIKE
183 && Request::target == RequestLimit::target_pattern as LIKE
184 && (RequestLimit::method == NULL
185 || RequestLimit::method == Request::method)
186 && (RequestLimit::interval_ms == NULL && Request::end_timestamp_ms == NULL
187 || RequestLimit::interval_ms != NULL
188 && Request::end_timestamp_ms
189 >= current_timestamp_ms!() - RequestLimit::interval_ms)
190 ))
191 .group_by([
192 RequestLimit::target_pattern,
193 RequestLimit::requests,
194 RequestLimit::method,
195 RequestLimit::interval_ms,
196 ])
197 .having(expr!(COUNT(Request::id) >= RequestLimit::requests))
198 .build(&executor.driver()),
199 )
200 .await
201 .expect("Failed to prepare the limit query");
202
203 let mut r1 = Request::new("v1/server/user/new/1".into(), Method::PUT.into());
204 let mut r2 = Request::new("v1/server/user/new/2".into(), Method::PUT.into());
205 let r3 = Request::new("v1/server/user/new/3".into(), Method::PUT.into());
206 let r4 = Request::new("v1/server/articles/new/4".into(), Method::PUT.into());
207 let r5 = Request::new("v1/server/user/new/5".into(), Method::PUT.into());
208
209 violated_limits.bind(r1.target.clone()).unwrap();
210 assert_eq!(executor.fetch(&mut violated_limits).count().await, 0);
211 r1.save(executor).await.expect("Failed to save r1");
212
213 violated_limits.bind(r2.target.clone()).unwrap();
214 assert_eq!(executor.fetch(&mut violated_limits).count().await, 0);
215 r2.save(executor).await.expect("Failed to save r2");
216
217 violated_limits.bind(r3.target.clone()).unwrap();
218 assert_eq!(executor.fetch(&mut violated_limits).count().await, 1); violated_limits.bind(r4.target.clone()).unwrap();
222 assert_eq!(executor.fetch(&mut violated_limits).count().await, 0);
223 r4.save(executor).await.expect("Failed to save r4");
224
225 violated_limits.bind(r5.target.clone()).unwrap();
226 assert_eq!(executor.fetch(&mut violated_limits).count().await, 2); r1.end();
229 r1.save(executor).await.expect("Could not terminate r1");
230
231 violated_limits.bind(r3.target.clone()).unwrap();
232 assert_eq!(executor.fetch(&mut violated_limits).count().await, 0);
233 r3.save(executor).await.expect("Failed to save r3");
234
235 let mut data_reqs = vec![];
238 for i in 0..5 {
239 let req = Request::new(format!("v1/server/data/item/{}", i), None);
240 req.save(executor)
241 .await
242 .expect("Failed to save data request");
243 data_reqs.push(req);
244 }
245
246 violated_limits
249 .bind("v1/server/data/item/999".to_string())
250 .unwrap();
251 assert_eq!(executor.fetch(&mut violated_limits).count().await, 2); for i in 0..4 {
254 data_reqs[i].end();
255 data_reqs[i]
256 .save(executor)
257 .await
258 .expect(&format!("Failed to save data_reqs[{i}]"));
259 }
260
261 violated_limits
264 .bind("v1/server/data/item/999".to_string())
265 .unwrap();
266 assert_eq!(executor.fetch(&mut violated_limits).count().await, 1); r2.end();
269 r2.save(executor).await.expect("Could not terminate r2");
270 data_reqs[4].end();
271 data_reqs[4]
272 .save(executor)
273 .await
274 .expect("Could not terminate data_reqs[4]");
275
276 violated_limits
277 .bind("v1/server/data/item/999".to_string())
278 .unwrap();
279 assert_eq!(executor.fetch(&mut violated_limits).count().await, 0);
280 }
281}