tibba_middleware/
limit.rs1use super::{ClientIp, Error, LOG_TARGET};
16use axum::extract::Request;
17use axum::extract::State;
18use axum::middleware::Next;
19use axum::response::IntoResponse;
20use axum::response::Response;
21use scopeguard::defer;
22use std::net::IpAddr;
23use std::time::Duration;
24use tibba_cache::RedisCache;
25use tibba_state::{AppState, CTX};
26use tracing::debug;
27
28type Result<T> = std::result::Result<T, tibba_error::Error>;
30
31pub async fn processing_limit(
44 State(state): State<&AppState>,
45 req: Request,
46 next: Next,
47) -> Result<impl IntoResponse> {
48 debug!(target: LOG_TARGET, "--> processing_limit");
50 defer!(debug!(target: LOG_TARGET, "<-- processing_limit"););
52
53 let limit = state.get_processing_limit();
55
56 if limit < 0 {
58 let res = next.run(req).await;
59 if res.status().as_u16() >= 400 {
60 state.inc_error_requests();
61 }
62 return Ok(res);
63 }
64
65 let count = state.inc_processing();
66 defer!(state.dec_processing(););
67
68 if count > limit {
70 state.inc_error_requests();
71 return Err(Error::TooManyRequests {
73 limit: limit as i64,
74 current: count as i64,
75 }
76 .into());
77 }
78
79 let res = next.run(req).await;
80 if res.status().as_u16() >= 400 {
81 state.inc_error_requests();
82 }
83 Ok(res)
84}
85
/// Selects which request attribute identifies the caller when building a
/// rate-limit counter key.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub enum LimitType {
    /// Identify the caller by client IP address (the default).
    #[default]
    Ip,
    /// Identify the caller by the value of the named request header; the client IP is
    /// used instead when the header is absent or its value cannot be read as text.
    Header(String),
    /// Identify the caller by the account stored in the request context; the client IP
    /// is used instead when that account is empty.
    Account,
}
94
95#[derive(Debug, Clone, Default)]
97pub struct LimitParams {
98 limit_type: LimitType, category: String, max: i64, ttl: Duration, }
103
104impl LimitParams {
105 pub fn new(max: i64) -> Self {
108 Self {
109 limit_type: LimitType::Ip,
110 max,
111 ttl: Duration::from_secs(5 * 60),
112 ..Default::default()
113 }
114 }
115
116 #[must_use]
118 pub fn with_category(mut self, category: impl Into<String>) -> Self {
119 self.category = category.into();
120 self
121 }
122
123 #[must_use]
125 pub fn with_ttl(mut self, ttl: Duration) -> Self {
126 self.ttl = ttl;
127 self
128 }
129
130 #[must_use]
132 pub fn with_limit_type(mut self, limit_type: LimitType) -> Self {
133 self.limit_type = limit_type;
134 self
135 }
136}
137
138fn get_limit_params(req: &Request, ip: IpAddr, params: &LimitParams) -> (String, Duration) {
147 let identifier = match ¶ms.limit_type {
148 LimitType::Header(header_name) => req
149 .headers()
150 .get(header_name)
151 .and_then(|value| value.to_str().ok())
152 .map(|s| s.to_string())
153 .unwrap_or_else(|| ip.to_string()),
154 LimitType::Account => {
155 let account = CTX.get().get_account();
156 if account.is_empty() {
157 ip.to_string()
158 } else {
159 account.to_string()
160 }
161 }
162 LimitType::Ip => ip.to_string(),
163 };
164 let key = if params.category.is_empty() {
166 identifier
167 } else {
168 format!("{}:{}", params.category, identifier)
169 };
170 let ttl = if params.ttl.is_zero() {
172 Duration::from_secs(5 * 60)
173 } else {
174 params.ttl
175 };
176 (key, ttl)
177}
178
179pub async fn error_limiter(
182 ClientIp(ip): ClientIp,
183 State(params): State<LimitParams>,
184 State(cache): State<&'static RedisCache>,
185 req: Request,
186 next: Next,
187) -> Result<Response> {
188 let (key, ttl) = get_limit_params(&req, ip, ¶ms);
189 let current_count = cache.get::<i64>(&key).await.unwrap_or(0);
191 if current_count > params.max {
192 return Err(Error::TooManyRequests {
193 limit: params.max,
194 current: current_count,
195 }
196 .into());
197 }
198 let res = next.run(req).await;
199 if res.status().as_u16() >= 400 {
201 let _ = cache.incr(&key, 1, Some(ttl)).await;
203 }
204 Ok(res)
205}
206
207pub async fn limiter(
210 ClientIp(ip): ClientIp,
211 State(params): State<LimitParams>,
212 State(cache): State<&'static RedisCache>,
213 req: Request,
214 next: Next,
215) -> Result<Response> {
216 let (key, ttl) = get_limit_params(&req, ip, ¶ms);
217
218 let count = cache.incr(&key, 1, Some(ttl)).await?;
220 if count > params.max {
221 return Err(Error::TooManyRequests {
222 limit: params.max,
223 current: count,
224 }
225 .into());
226 }
227
228 Ok(next.run(req).await)
229}