Skip to main content

stormchaser_api/
rate_limit.rs

1use async_nats::jetstream;
2use async_nats::jetstream::kv::Config;
3use async_nats::Client;
4use axum::{
5    extract::{ConnectInfo, Request, State},
6    http::StatusCode,
7    middleware::Next,
8    response::{IntoResponse, Response},
9};
10use bytes::Bytes;
11use chrono::Utc;
12use std::env;
13use std::net::SocketAddr;
14use std::sync::Arc;
15use std::time::Duration;
16use tokio::sync::OnceCell;
17
18/// State for the NATS-backed rate limiter
19#[derive(Clone)]
20pub struct RateLimitState {
21    /// NATS client connection
22    pub nats: Client,
23    /// Lazy initialized Key-Value store for rate limiting
24    pub store: Arc<OnceCell<jetstream::kv::Store>>,
25    /// Allowed requests per second
26    pub per_second: u64,
27    /// Maximum burst size for requests
28    pub burst_size: u64,
29}
30
31/// Middleware that limits request rates using NATS KV
32pub async fn nats_rate_limiter(
33    State(state): State<Arc<RateLimitState>>,
34    ConnectInfo(addr): ConnectInfo<SocketAddr>,
35    req: Request,
36    next: Next,
37) -> Response {
38    if env::var("TEST_BYPASS_RATE_LIMIT").is_ok() {
39        return next.run(req).await;
40    }
41
42    let ip = addr.ip().to_string();
43    let current_second = Utc::now().timestamp();
44    let ip_safe = ip.replace(['.', ':'], "_");
45    let key = format!("{}_{}", ip_safe, current_second);
46
47    let store = match state
48        .store
49        .get_or_try_init(|| async {
50            let js = jetstream::new(state.nats.clone());
51            js.create_key_value(Config {
52                bucket: "api_rate_limits".to_string(),
53                max_age: Duration::from_secs(60),
54                ..Default::default()
55            })
56            .await
57        })
58        .await
59    {
60        Ok(s) => s,
61        Err(e) => {
62            tracing::error!("Failed to init NATS KV for rate limiting: {:?}", e);
63            // Fail open
64            return next.run(req).await;
65        }
66    };
67
68    let mut retries = 0;
69    let allowed = loop {
70        if retries > 5 {
71            break true; // fail open under high contention
72        }
73
74        match store.entry(&key).await {
75            Ok(Some(entry)) => {
76                let count: u64 = std::str::from_utf8(&entry.value)
77                    .unwrap_or("0")
78                    .parse()
79                    .unwrap_or(0);
80                if count >= state.burst_size {
81                    break false;
82                }
83                let next_count = count + 1;
84                let next_val: Bytes = next_count.to_string().into();
85                match store.update(&key, next_val, entry.revision).await {
86                    Ok(_) => break true,
87                    Err(_) => {
88                        retries += 1;
89                        continue;
90                    }
91                }
92            }
93            Ok(None) => {
94                let next_val: Bytes = "1".into();
95                match store.update(&key, next_val, 0).await {
96                    Ok(_) => break true,
97                    Err(_) => {
98                        retries += 1;
99                        continue;
100                    }
101                }
102            }
103            Err(e) => {
104                tracing::error!("NATS KV rate limit error: {:?}", e);
105                break true; // fail open
106            }
107        }
108    };
109
110    if !allowed {
111        return (StatusCode::TOO_MANY_REQUESTS, "Too Many Requests").into_response();
112    }
113
114    next.run(req).await
115}