throttlecrab_server/transport/
http.rs

1use super::Transport;
2use crate::actor::RateLimiterHandle;
3use crate::types::{ThrottleRequest as InternalRequest, ThrottleResponse};
4use anyhow::Result;
5use async_trait::async_trait;
6use axum::{Router, extract::State, http::StatusCode, response::Json, routing::post};
7use serde::{Deserialize, Serialize};
8use std::net::SocketAddr;
9use std::sync::Arc;
10use std::time::SystemTime;
11
12#[derive(Debug, Serialize, Deserialize)]
13pub struct HttpThrottleRequest {
14    pub key: String,
15    pub max_burst: i64,
16    pub count_per_period: i64,
17    pub period: i64,
18    pub quantity: Option<i64>,
19    pub timestamp: Option<i64>, // Optional timestamp in nanoseconds
20}
21
22#[derive(Debug, Serialize, Deserialize)]
23pub struct HttpErrorResponse {
24    pub error: String,
25}
26
/// Transport that exposes the rate limiter over HTTP/JSON.
#[derive(Debug)]
pub struct HttpTransport {
    /// Socket address the HTTP server binds to.
    addr: SocketAddr,
}

impl HttpTransport {
    /// Creates a transport that will listen on `host:port`.
    ///
    /// `host` must be an IP address literal. IPv4 (`"0.0.0.0"`) and IPv6 are
    /// accepted; IPv6 may be written bare (`"::1"`) or bracketed (`"[::1]"`).
    /// Host names such as `"localhost"` are not resolved.
    ///
    /// # Panics
    ///
    /// Panics if `host` is not a valid IP address literal.
    pub fn new(host: &str, port: u16) -> Self {
        // Try the "host:port" socket-address form first so every input that was
        // accepted before (including bracketed IPv6 with a scope id) parses the
        // same way. Fall back to parsing `host` as a bare IP: an unbracketed IPv6
        // literal yields e.g. "::1:8080", which is not a valid SocketAddr string.
        let addr = format!("{host}:{port}")
            .parse::<SocketAddr>()
            .or_else(|_| {
                host.parse::<std::net::IpAddr>()
                    .map(|ip| SocketAddr::new(ip, port))
            })
            .unwrap_or_else(|e| panic!("Invalid address {host:?} (port {port}): {e}"));
        Self { addr }
    }
}
37
38#[async_trait]
39impl Transport for HttpTransport {
40    async fn start(self, limiter: RateLimiterHandle) -> Result<()> {
41        let app_state = Arc::new(AppState { limiter });
42
43        let app = Router::new()
44            .route("/throttle", post(handle_throttle))
45            .route("/health", axum::routing::get(|| async { "OK" }))
46            .with_state(app_state);
47
48        tracing::info!("HTTP server listening on {}", self.addr);
49
50        let listener = tokio::net::TcpListener::bind(self.addr).await?;
51        axum::serve(listener, app).await?;
52
53        Ok(())
54    }
55}
56
57struct AppState {
58    limiter: RateLimiterHandle,
59}
60
61async fn handle_throttle(
62    State(state): State<Arc<AppState>>,
63    Json(req): Json<HttpThrottleRequest>,
64) -> Result<Json<ThrottleResponse>, (StatusCode, Json<HttpErrorResponse>)> {
65    let timestamp = if let Some(nanos) = req.timestamp {
66        std::time::UNIX_EPOCH + std::time::Duration::from_nanos(nanos as u64)
67    } else {
68        SystemTime::now()
69    };
70
71    let internal_req = InternalRequest {
72        key: req.key,
73        max_burst: req.max_burst,
74        count_per_period: req.count_per_period,
75        period: req.period,
76        quantity: req.quantity.unwrap_or(1),
77        timestamp,
78    };
79
80    match state.limiter.throttle(internal_req).await {
81        Ok(response) => Ok(Json(response)),
82        Err(e) => {
83            tracing::error!("Rate limiter error: {}", e);
84            Err((
85                StatusCode::INTERNAL_SERVER_ERROR,
86                Json(HttpErrorResponse {
87                    error: format!("Internal server error: {e}"),
88                }),
89            ))
90        }
91    }
92}