throttlecrab_server/transport/
http.rs1use super::Transport;
43use crate::actor::RateLimiterHandle;
44use crate::metrics::{Metrics, Transport as MetricsTransport};
45use crate::types::{ThrottleRequest as InternalRequest, ThrottleResponse};
46use anyhow::Result;
47use async_trait::async_trait;
48use axum::{
49 Router,
50 extract::State,
51 http::StatusCode,
52 response::Json,
53 routing::{get, post},
54};
55use serde::{Deserialize, Serialize};
56use std::net::SocketAddr;
57use std::sync::Arc;
58use std::time::SystemTime;
59
60#[derive(Debug, Serialize, Deserialize)]
62pub struct HttpThrottleRequest {
63 pub key: String,
65 pub max_burst: i64,
67 pub count_per_period: i64,
69 pub period: i64,
71 pub quantity: Option<i64>,
73}
74
75#[derive(Debug, Serialize, Deserialize)]
77pub struct HttpErrorResponse {
78 pub error: String,
80}
81
82pub struct HttpTransport {
86 addr: SocketAddr,
87 metrics: Arc<Metrics>,
88}
89
90impl HttpTransport {
91 pub fn new(host: &str, port: u16, metrics: Arc<Metrics>) -> Self {
92 let addr = format!("{host}:{port}").parse().expect("Invalid address");
93 Self { addr, metrics }
94 }
95}
96
97#[async_trait]
98impl Transport for HttpTransport {
99 async fn start(self, limiter: RateLimiterHandle) -> Result<()> {
100 let metrics = Arc::clone(&self.metrics);
101 let app_state = Arc::new(AppState { limiter, metrics });
102
103 let app = Router::new()
104 .route("/throttle", post(handle_throttle))
105 .route("/health", get(|| async { "OK" }))
106 .route("/metrics", get(handle_metrics))
107 .with_state(app_state);
108
109 tracing::info!("HTTP server listening on {}", self.addr);
110
111 let listener = tokio::net::TcpListener::bind(self.addr).await?;
112 axum::serve(listener, app).await?;
113
114 Ok(())
115 }
116}
117
118struct AppState {
119 limiter: RateLimiterHandle,
120 metrics: Arc<Metrics>,
121}
122
123async fn handle_throttle(
124 State(state): State<Arc<AppState>>,
125 Json(req): Json<HttpThrottleRequest>,
126) -> Result<Json<ThrottleResponse>, (StatusCode, Json<HttpErrorResponse>)> {
127 let timestamp = SystemTime::now();
129
130 let internal_req = InternalRequest {
131 key: req.key.clone(),
132 max_burst: req.max_burst,
133 count_per_period: req.count_per_period,
134 period: req.period,
135 quantity: req.quantity.unwrap_or(1),
136 timestamp,
137 };
138
139 match state.limiter.throttle(internal_req).await {
140 Ok(response) => {
141 state.metrics.record_request_with_key(
142 MetricsTransport::Http,
143 response.allowed,
144 &req.key,
145 );
146 Ok(Json(response))
147 }
148 Err(e) => {
149 tracing::error!("Rate limiter error: {}", e);
150 state.metrics.record_error(MetricsTransport::Http);
151 Err((
152 StatusCode::INTERNAL_SERVER_ERROR,
153 Json(HttpErrorResponse {
154 error: format!("Internal server error: {e}"),
155 }),
156 ))
157 }
158 }
159}
160
161async fn handle_metrics(State(state): State<Arc<AppState>>) -> Result<String, StatusCode> {
162 Ok(state.metrics.export_prometheus())
163}