stormchaser_api/
rate_limit.rs1use async_nats::jetstream;
2use async_nats::jetstream::kv::Config;
3use async_nats::Client;
4use axum::{
5 extract::{ConnectInfo, Request, State},
6 http::StatusCode,
7 middleware::Next,
8 response::{IntoResponse, Response},
9};
10use bytes::Bytes;
11use chrono::Utc;
12use std::env;
13use std::net::SocketAddr;
14use std::sync::Arc;
15use std::time::Duration;
16use tokio::sync::OnceCell;
17
18#[derive(Clone)]
20pub struct RateLimitState {
21 pub nats: Client,
23 pub store: Arc<OnceCell<jetstream::kv::Store>>,
25 pub per_second: u64,
27 pub burst_size: u64,
29}
30
31pub async fn nats_rate_limiter(
33 State(state): State<Arc<RateLimitState>>,
34 ConnectInfo(addr): ConnectInfo<SocketAddr>,
35 req: Request,
36 next: Next,
37) -> Response {
38 if env::var("TEST_BYPASS_RATE_LIMIT").is_ok() {
39 return next.run(req).await;
40 }
41
42 let ip = addr.ip().to_string();
43 let current_second = Utc::now().timestamp();
44 let ip_safe = ip.replace(['.', ':'], "_");
45 let key = format!("{}_{}", ip_safe, current_second);
46
47 let store = match state
48 .store
49 .get_or_try_init(|| async {
50 let js = jetstream::new(state.nats.clone());
51 js.create_key_value(Config {
52 bucket: "api_rate_limits".to_string(),
53 max_age: Duration::from_secs(60),
54 ..Default::default()
55 })
56 .await
57 })
58 .await
59 {
60 Ok(s) => s,
61 Err(e) => {
62 tracing::error!("Failed to init NATS KV for rate limiting: {:?}", e);
63 return next.run(req).await;
65 }
66 };
67
68 let mut retries = 0;
69 let allowed = loop {
70 if retries > 5 {
71 break true; }
73
74 match store.entry(&key).await {
75 Ok(Some(entry)) => {
76 let count: u64 = std::str::from_utf8(&entry.value)
77 .unwrap_or("0")
78 .parse()
79 .unwrap_or(0);
80 if count >= state.burst_size {
81 break false;
82 }
83 let next_count = count + 1;
84 let next_val: Bytes = next_count.to_string().into();
85 match store.update(&key, next_val, entry.revision).await {
86 Ok(_) => break true,
87 Err(_) => {
88 retries += 1;
89 continue;
90 }
91 }
92 }
93 Ok(None) => {
94 let next_val: Bytes = "1".into();
95 match store.update(&key, next_val, 0).await {
96 Ok(_) => break true,
97 Err(_) => {
98 retries += 1;
99 continue;
100 }
101 }
102 }
103 Err(e) => {
104 tracing::error!("NATS KV rate limit error: {:?}", e);
105 break true; }
107 }
108 };
109
110 if !allowed {
111 return (StatusCode::TOO_MANY_REQUESTS, "Too Many Requests").into_response();
112 }
113
114 next.run(req).await
115}