throttlecrab_server/transport/http.rs

use super::Transport;
use crate::actor::RateLimiterHandle;
use crate::metrics::{Metrics, Transport as MetricsTransport};
use crate::types::{ThrottleRequest as InternalRequest, ThrottleResponse};
use anyhow::Result;
use async_trait::async_trait;
use axum::{
    Router,
    extract::State,
    http::StatusCode,
    response::Json,
    routing::{get, post},
};
use serde::{Deserialize, Serialize};
use std::net::SocketAddr;
use std::sync::Arc;
use std::time::{Instant, SystemTime};

/// JSON request body accepted by the `POST /throttle` endpoint.
#[derive(Debug, Serialize, Deserialize)]
pub struct HttpThrottleRequest {
    /// Rate limit key identifying the caller or resource.
    pub key: String,
    /// Maximum burst size allowed for this key.
    pub max_burst: i64,
    /// Number of requests permitted per period.
    pub count_per_period: i64,
    /// Length of the rate limiting period.
    pub period: i64,
    /// Cost of this request; defaults to 1 when omitted.
    pub quantity: Option<i64>,
}

/// JSON error body returned when the rate limiter call fails.
#[derive(Debug, Serialize, Deserialize)]
pub struct HttpErrorResponse {
    pub error: String,
}

/// HTTP transport that exposes the rate limiter as a small JSON API.
pub struct HttpTransport {
    addr: SocketAddr,
    metrics: Arc<Metrics>,
}

impl HttpTransport {
    /// Creates a transport bound to `host:port`.
    ///
    /// Panics if `host:port` does not parse as a valid socket address.
    pub fn new(host: &str, port: u16, metrics: Arc<Metrics>) -> Self {
        let addr = format!("{host}:{port}").parse().expect("Invalid address");
        Self { addr, metrics }
    }
}
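
// Construction sketch (illustrative only: `Metrics::new()` and the `limiter`
// value are assumptions about how the rest of the crate wires things together):
//
//     let metrics = Arc::new(Metrics::new());
//     let transport = HttpTransport::new("127.0.0.1", 8080, metrics);
//     transport.start(limiter).await?;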

#[async_trait]
impl Transport for HttpTransport {
    /// Binds the listener and serves the `/throttle`, `/health`, and `/metrics` routes.
    async fn start(self, limiter: RateLimiterHandle) -> Result<()> {
        let metrics = Arc::clone(&self.metrics);
        let app_state = Arc::new(AppState { limiter, metrics });

        let app = Router::new()
            .route("/throttle", post(handle_throttle))
            .route("/health", get(|| async { "OK" }))
            .route("/metrics", get(handle_metrics))
            .with_state(app_state);

        tracing::info!("HTTP server listening on {}", self.addr);

        let listener = tokio::net::TcpListener::bind(self.addr).await?;
        axum::serve(listener, app).await?;

        Ok(())
    }
}

/// Shared state handed to each request handler.
struct AppState {
    limiter: RateLimiterHandle,
    metrics: Arc<Metrics>,
}

/// Handles `POST /throttle`: forwards the request to the rate limiter actor and
/// records per-request metrics.
async fn handle_throttle(
    State(state): State<Arc<AppState>>,
    Json(req): Json<HttpThrottleRequest>,
) -> Result<Json<ThrottleResponse>, (StatusCode, Json<HttpErrorResponse>)> {
    let start = Instant::now();

    let timestamp = SystemTime::now();

    let internal_req = InternalRequest {
        key: req.key.clone(),
        max_burst: req.max_burst,
        count_per_period: req.count_per_period,
        period: req.period,
        quantity: req.quantity.unwrap_or(1),
        timestamp,
    };

    match state.limiter.throttle(internal_req).await {
        Ok(response) => {
            let latency_us = start.elapsed().as_micros() as u64;
            state.metrics.record_request_with_key(
                MetricsTransport::Http,
                latency_us,
                response.allowed,
                &req.key,
            );
            Ok(Json(response))
        }
        Err(e) => {
            tracing::error!("Rate limiter error: {}", e);
            let latency_us = start.elapsed().as_micros() as u64;
            state
                .metrics
                .record_error(MetricsTransport::Http, latency_us);
            Err((
                StatusCode::INTERNAL_SERVER_ERROR,
                Json(HttpErrorResponse {
                    error: format!("Internal server error: {e}"),
                }),
            ))
        }
    }
}

/// Handles `GET /metrics`: returns metrics in Prometheus text format.
async fn handle_metrics(State(state): State<Arc<AppState>>) -> Result<String, StatusCode> {
    Ok(state.metrics.export_prometheus())
}
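
// Client-side sketch (illustrative; assumes a `reqwest` client with its `json`
// feature and a server bound to 127.0.0.1:8080, neither of which is part of
// this file). `GET /health` returns "OK" and `GET /metrics` returns the
// Prometheus text produced by `export_prometheus`.
//
//     let resp = reqwest::Client::new()
//         .post("http://127.0.0.1:8080/throttle")
//         .json(&serde_json::json!({
//             "key": "user:42",
//             "max_burst": 10,
//             "count_per_period": 100,
//             "period": 60
//         }))
//         .send()
//         .await?;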