throttlecrab_server/transport/
http.rs1use super::Transport;
43use crate::actor::RateLimiterHandle;
44use crate::types::{ThrottleRequest as InternalRequest, ThrottleResponse};
45use anyhow::Result;
46use async_trait::async_trait;
47use axum::{Router, extract::State, http::StatusCode, response::Json, routing::post};
48use serde::{Deserialize, Serialize};
49use std::net::SocketAddr;
50use std::sync::Arc;
51use std::time::SystemTime;
52
53#[derive(Debug, Serialize, Deserialize)]
55pub struct HttpThrottleRequest {
56 pub key: String,
58 pub max_burst: i64,
60 pub count_per_period: i64,
62 pub period: i64,
64 pub quantity: Option<i64>,
66}
67
68#[derive(Debug, Serialize, Deserialize)]
70pub struct HttpErrorResponse {
71 pub error: String,
73}
74
/// Transport that exposes the rate limiter over HTTP.
pub struct HttpTransport {
    /// Socket address the HTTP server binds to.
    addr: SocketAddr,
}

impl HttpTransport {
    /// Creates a transport that will listen on `host`:`port`.
    ///
    /// `host` must be an IP literal. Accepted forms include IPv4 (`"0.0.0.0"`),
    /// bracketed IPv6 (`"[::1]"`), and bare IPv6 (`"::1"`). Hostnames such as
    /// `"localhost"` are not resolved.
    ///
    /// # Panics
    ///
    /// Panics if `host` is not a valid IP address literal.
    pub fn new(host: &str, port: u16) -> Self {
        // Try the `host:port` form first so every previously accepted input
        // (IPv4, bracketed IPv6 including scope ids) behaves exactly as before.
        // Then fall back to parsing `host` as a bare IP. Without the fallback,
        // a bare IPv6 host such as "::1" yields "::1:8080", which is not a
        // valid socket address and used to panic.
        let addr = format!("{host}:{port}")
            .parse::<SocketAddr>()
            .or_else(|_| {
                host.parse::<std::net::IpAddr>()
                    .map(|ip| SocketAddr::new(ip, port))
            })
            .unwrap_or_else(|e| panic!("Invalid address {host}:{port}: {e}"));
        Self { addr }
    }
}
88
89#[async_trait]
90impl Transport for HttpTransport {
91 async fn start(self, limiter: RateLimiterHandle) -> Result<()> {
92 let app_state = Arc::new(AppState { limiter });
93
94 let app = Router::new()
95 .route("/throttle", post(handle_throttle))
96 .route("/health", axum::routing::get(|| async { "OK" }))
97 .with_state(app_state);
98
99 tracing::info!("HTTP server listening on {}", self.addr);
100
101 let listener = tokio::net::TcpListener::bind(self.addr).await?;
102 axum::serve(listener, app).await?;
103
104 Ok(())
105 }
106}
107
108struct AppState {
109 limiter: RateLimiterHandle,
110}
111
112async fn handle_throttle(
113 State(state): State<Arc<AppState>>,
114 Json(req): Json<HttpThrottleRequest>,
115) -> Result<Json<ThrottleResponse>, (StatusCode, Json<HttpErrorResponse>)> {
116 let timestamp = SystemTime::now();
118
119 let internal_req = InternalRequest {
120 key: req.key,
121 max_burst: req.max_burst,
122 count_per_period: req.count_per_period,
123 period: req.period,
124 quantity: req.quantity.unwrap_or(1),
125 timestamp,
126 };
127
128 match state.limiter.throttle(internal_req).await {
129 Ok(response) => Ok(Json(response)),
130 Err(e) => {
131 tracing::error!("Rate limiter error: {}", e);
132 Err((
133 StatusCode::INTERNAL_SERVER_ERROR,
134 Json(HttpErrorResponse {
135 error: format!("Internal server error: {e}"),
136 }),
137 ))
138 }
139 }
140}