// throttlecrab_server/transport/http.rs
use super::Transport;
use crate::actor::RateLimiterHandle;
use crate::types::{ThrottleRequest as InternalRequest, ThrottleResponse};
use anyhow::Result;
use async_trait::async_trait;
use axum::{Router, extract::State, http::StatusCode, response::Json, routing::post};
use serde::{Deserialize, Serialize};
use std::net::SocketAddr;
use std::sync::Arc;
use std::time::SystemTime;
11
12#[derive(Debug, Serialize, Deserialize)]
13pub struct HttpThrottleRequest {
14 pub key: String,
15 pub max_burst: i64,
16 pub count_per_period: i64,
17 pub period: i64,
18 pub quantity: Option<i64>,
19 pub timestamp: Option<i64>, }
21
22#[derive(Debug, Serialize, Deserialize)]
23pub struct HttpErrorResponse {
24 pub error: String,
25}
26
/// Transport that exposes the rate limiter over HTTP.
pub struct HttpTransport {
    /// Socket address the server binds to when started.
    addr: SocketAddr,
}

impl HttpTransport {
    /// Creates a transport that will listen on `host:port`.
    ///
    /// `host` may be an IPv4 literal (`127.0.0.1`), an IPv6 literal with or
    /// without brackets (`::1`, `[::1]`), or a hostname (`localhost`). Hostnames
    /// are resolved with the system resolver, and the first resolved address is used.
    ///
    /// # Panics
    ///
    /// Panics if `host` cannot be parsed or resolved to a socket address.
    pub fn new(host: &str, port: u16) -> Self {
        // The old `format!("{host}:{port}").parse()` rejected unbracketed IPv6
        // literals ("::1:8080" is not a valid socket address string) and every
        // hostname. The `(&str, u16)` resolver parses IP literals directly and
        // only falls back to a DNS lookup for names.
        let bare_host = host
            .strip_prefix('[')
            .and_then(|inner| inner.strip_suffix(']'))
            .unwrap_or(host);
        let mut resolved = std::net::ToSocketAddrs::to_socket_addrs(&(bare_host, port))
            .unwrap_or_else(|err| panic!("Invalid address {host}:{port}: {err}"));
        let addr = resolved
            .next()
            .unwrap_or_else(|| panic!("Invalid address {host}:{port}: no addresses resolved"));
        Self { addr }
    }
}
37
38#[async_trait]
39impl Transport for HttpTransport {
40 async fn start(self, limiter: RateLimiterHandle) -> Result<()> {
41 let app_state = Arc::new(AppState { limiter });
42
43 let app = Router::new()
44 .route("/throttle", post(handle_throttle))
45 .route("/health", axum::routing::get(|| async { "OK" }))
46 .with_state(app_state);
47
48 tracing::info!("HTTP server listening on {}", self.addr);
49
50 let listener = tokio::net::TcpListener::bind(self.addr).await?;
51 axum::serve(listener, app).await?;
52
53 Ok(())
54 }
55}
56
57struct AppState {
58 limiter: RateLimiterHandle,
59}
60
61async fn handle_throttle(
62 State(state): State<Arc<AppState>>,
63 Json(req): Json<HttpThrottleRequest>,
64) -> Result<Json<ThrottleResponse>, (StatusCode, Json<HttpErrorResponse>)> {
65 let timestamp = if let Some(nanos) = req.timestamp {
66 std::time::UNIX_EPOCH + std::time::Duration::from_nanos(nanos as u64)
67 } else {
68 SystemTime::now()
69 };
70
71 let internal_req = InternalRequest {
72 key: req.key,
73 max_burst: req.max_burst,
74 count_per_period: req.count_per_period,
75 period: req.period,
76 quantity: req.quantity.unwrap_or(1),
77 timestamp,
78 };
79
80 match state.limiter.throttle(internal_req).await {
81 Ok(response) => Ok(Json(response)),
82 Err(e) => {
83 tracing::error!("Rate limiter error: {}", e);
84 Err((
85 StatusCode::INTERNAL_SERVER_ERROR,
86 Json(HttpErrorResponse {
87 error: format!("Internal server error: {e}"),
88 }),
89 ))
90 }
91 }
92}