Skip to main content

tibba_middleware/
limit.rs

1// Copyright 2026 Tree xie.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use super::{ClientIp, Error, LOG_TARGET};
16use axum::extract::Request;
17use axum::extract::State;
18use axum::middleware::Next;
19use axum::response::IntoResponse;
20use axum::response::Response;
21use scopeguard::defer;
22use std::net::IpAddr;
23use std::time::Duration;
24use tibba_cache::RedisCache;
25use tibba_state::{AppState, CTX};
26use tracing::debug;
27
28// Custom Result type that uses the application's Error type
29type Result<T> = std::result::Result<T, tibba_error::Error>;
30
31/// Middleware that implements concurrent request processing limits
32///
33/// This middleware:
34/// 1. Tracks number of concurrent requests being processed
35/// 2. Enforces a maximum limit on concurrent requests
36/// 3. Returns 429 Too Many Requests when limit is exceeded
37/// 4. Properly decrements counter when request processing completes
38///
39/// # Arguments
40/// * `State(state)` - Application state containing limit configuration
41/// * `req` - The incoming request
42/// * `next` - The next middleware in the chain
43pub async fn processing_limit(
44    State(state): State<&AppState>,
45    req: Request,
46    next: Next,
47) -> Result<impl IntoResponse> {
48    // Log middleware entry
49    debug!(target: LOG_TARGET, "--> processing_limit");
50    // Ensure exit logging happens even if processing panics
51    defer!(debug!(target: LOG_TARGET, "<-- processing_limit"););
52
53    // Get configured processing limit from app state
54    let limit = state.get_processing_limit();
55
56    // If limit is negative, processing is unlimited
57    if limit < 0 {
58        let res = next.run(req).await;
59        if res.status().as_u16() >= 400 {
60            state.inc_error_requests();
61        }
62        return Ok(res);
63    }
64
65    let count = state.inc_processing();
66    defer!(state.dec_processing(););
67
68    // Check if processing limit has been exceeded
69    if count > limit {
70        state.inc_error_requests();
71        // Return 429 Too Many Requests error
72        return Err(Error::TooManyRequests {
73            limit: limit as i64,
74            current: count as i64,
75        }
76        .into());
77    }
78
79    let res = next.run(req).await;
80    if res.status().as_u16() >= 400 {
81        state.inc_error_requests();
82    }
83    Ok(res)
84}
85
/// Selects which attribute of a request identifies the caller for rate
/// limiting purposes.
#[derive(Debug, Clone, Default)]
pub enum LimitType {
    /// Key the limit on the client IP address (the default).
    #[default]
    Ip,
    /// Key the limit on the value of the named request header.
    Header(String),
    /// Key the limit on the authenticated account, falling back to the
    /// client IP when no account is available.
    Account,
}
94
95/// Configuration parameters for rate limiting middleware
96#[derive(Debug, Clone, Default)]
97pub struct LimitParams {
98    limit_type: LimitType, // Type of rate limiting to apply
99    category: String,      // Category identifier for the limit
100    max: i64,              // Maximum number of requests allowed
101    ttl: Duration,         // Time-to-live for the rate limit counter
102}
103
104impl LimitParams {
105    /// Creates a new LimitParams with the maximum number of requests allowed.
106    /// Defaults to IP-based limiting with no category and a 5-minute TTL.
107    pub fn new(max: i64) -> Self {
108        Self {
109            limit_type: LimitType::Ip,
110            max,
111            ttl: Duration::from_secs(5 * 60),
112            ..Default::default()
113        }
114    }
115
116    /// Sets the category identifier used as a prefix in the cache key.
117    #[must_use]
118    pub fn with_category(mut self, category: impl Into<String>) -> Self {
119        self.category = category.into();
120        self
121    }
122
123    /// Sets the TTL for the rate limit counter window.
124    #[must_use]
125    pub fn with_ttl(mut self, ttl: Duration) -> Self {
126        self.ttl = ttl;
127        self
128    }
129
130    /// Sets the limit type (IP-based or header-based).
131    #[must_use]
132    pub fn with_limit_type(mut self, limit_type: LimitType) -> Self {
133        self.limit_type = limit_type;
134        self
135    }
136}
137
138/// Generates the cache key and TTL for rate limiting
139///
140/// # Arguments
141/// * `ip` - Client IP address
142/// * `params` - Rate limiting parameters
143///
144/// # Returns
145/// Tuple of (cache_key, ttl_duration)
146fn get_limit_params(req: &Request, ip: IpAddr, params: &LimitParams) -> (String, Duration) {
147    let identifier = match &params.limit_type {
148        LimitType::Header(header_name) => req
149            .headers()
150            .get(header_name)
151            .and_then(|value| value.to_str().ok())
152            .map(|s| s.to_string())
153            .unwrap_or_else(|| ip.to_string()),
154        LimitType::Account => {
155            let account = CTX.get().get_account();
156            if account.is_empty() {
157                ip.to_string()
158            } else {
159                account.to_string()
160            }
161        }
162        LimitType::Ip => ip.to_string(),
163    };
164    // Append category to key if specified
165    let key = if params.category.is_empty() {
166        identifier
167    } else {
168        format!("{}:{}", params.category, identifier)
169    };
170    // Use default TTL of 5 minutes if none specified
171    let ttl = if params.ttl.is_zero() {
172        Duration::from_secs(5 * 60)
173    } else {
174        params.ttl
175    };
176    (key, ttl)
177}
178
179/// Middleware that limits requests only when errors occur
180/// Increments counter only for responses with status code >= 400
181pub async fn error_limiter(
182    ClientIp(ip): ClientIp,
183    State(params): State<LimitParams>,
184    State(cache): State<&'static RedisCache>,
185    req: Request,
186    next: Next,
187) -> Result<Response> {
188    let (key, ttl) = get_limit_params(&req, ip, &params);
189    // Check if current error count exceeds limit
190    let current_count = cache.get::<i64>(&key).await.unwrap_or(0);
191    if current_count > params.max {
192        return Err(Error::TooManyRequests {
193            limit: params.max,
194            current: current_count,
195        }
196        .into());
197    }
198    let res = next.run(req).await;
199    // Increment counter only on error responses
200    if res.status().as_u16() >= 400 {
201        // Ignore Redis errors when incrementing
202        let _ = cache.incr(&key, 1, Some(ttl)).await;
203    }
204    Ok(res)
205}
206
207/// Standard rate limiting middleware
208/// Increments counter for every request regardless of response status
209pub async fn limiter(
210    ClientIp(ip): ClientIp,
211    State(params): State<LimitParams>,
212    State(cache): State<&'static RedisCache>,
213    req: Request,
214    next: Next,
215) -> Result<Response> {
216    let (key, ttl) = get_limit_params(&req, ip, &params);
217
218    // Increment counter and check against limit
219    let count = cache.incr(&key, 1, Some(ttl)).await?;
220    if count > params.max {
221        return Err(Error::TooManyRequests {
222            limit: params.max,
223            current: count,
224        }
225        .into());
226    }
227
228    Ok(next.run(req).await)
229}