throttlecrab_server/transport/
http.rs1use super::Transport;
45use crate::actor::RateLimiterHandle;
46use crate::types::{ThrottleRequest as InternalRequest, ThrottleResponse};
47use anyhow::Result;
48use async_trait::async_trait;
49use axum::{Router, extract::State, http::StatusCode, response::Json, routing::post};
50use serde::{Deserialize, Serialize};
51use std::net::SocketAddr;
52use std::sync::Arc;
53use std::time::SystemTime;
54
55#[derive(Debug, Serialize, Deserialize)]
57pub struct HttpThrottleRequest {
58 pub key: String,
60 pub max_burst: i64,
62 pub count_per_period: i64,
64 pub period: i64,
66 pub quantity: Option<i64>,
68 pub timestamp: Option<i64>,
70}
71
72#[derive(Debug, Serialize, Deserialize)]
74pub struct HttpErrorResponse {
75 pub error: String,
77}
78
/// HTTP/JSON transport for the rate limiter, served with `axum`.
#[derive(Debug)]
pub struct HttpTransport {
    /// Socket address the server binds to when started.
    addr: SocketAddr,
}

impl HttpTransport {
    /// Creates a transport that will listen on `host:port`.
    ///
    /// `host` must be an IP literal. IPv6 addresses are accepted either bare
    /// (`::1`) or bracketed (`[::1]`). Naively joining a bare IPv6 host with
    /// `":{port}"` would give an unparsable string such as `::1:8080`.
    ///
    /// # Panics
    ///
    /// Panics if `host` is not a valid IP address. Hostnames such as
    /// `localhost` are not resolved.
    pub fn new(host: &str, port: u16) -> Self {
        // Accept `[::1]` as well as `::1`: `IpAddr` parsing rejects brackets.
        let unbracketed = host
            .strip_prefix('[')
            .and_then(|inner| inner.strip_suffix(']'))
            .unwrap_or(host);

        let addr = match unbracketed.parse::<std::net::IpAddr>() {
            Ok(ip) => SocketAddr::new(ip, port),
            // Fall back to full `host:port` socket-address syntax. This keeps
            // accepting forms a plain `IpAddr` rejects, e.g. a bracketed IPv6
            // address carrying a numeric scope id (`[fe80::1%2]`).
            Err(_) => format!("{host}:{port}")
                .parse()
                .unwrap_or_else(|e| panic!("Invalid address {host:?} (port {port}): {e}")),
        };

        Self { addr }
    }
}
92
93#[async_trait]
94impl Transport for HttpTransport {
95 async fn start(self, limiter: RateLimiterHandle) -> Result<()> {
96 let app_state = Arc::new(AppState { limiter });
97
98 let app = Router::new()
99 .route("/throttle", post(handle_throttle))
100 .route("/health", axum::routing::get(|| async { "OK" }))
101 .with_state(app_state);
102
103 tracing::info!("HTTP server listening on {}", self.addr);
104
105 let listener = tokio::net::TcpListener::bind(self.addr).await?;
106 axum::serve(listener, app).await?;
107
108 Ok(())
109 }
110}
111
112struct AppState {
113 limiter: RateLimiterHandle,
114}
115
116async fn handle_throttle(
117 State(state): State<Arc<AppState>>,
118 Json(req): Json<HttpThrottleRequest>,
119) -> Result<Json<ThrottleResponse>, (StatusCode, Json<HttpErrorResponse>)> {
120 let timestamp = if let Some(nanos) = req.timestamp {
121 std::time::UNIX_EPOCH + std::time::Duration::from_nanos(nanos as u64)
122 } else {
123 SystemTime::now()
124 };
125
126 let internal_req = InternalRequest {
127 key: req.key,
128 max_burst: req.max_burst,
129 count_per_period: req.count_per_period,
130 period: req.period,
131 quantity: req.quantity.unwrap_or(1),
132 timestamp,
133 };
134
135 match state.limiter.throttle(internal_req).await {
136 Ok(response) => Ok(Json(response)),
137 Err(e) => {
138 tracing::error!("Rate limiter error: {}", e);
139 Err((
140 StatusCode::INTERNAL_SERVER_ERROR,
141 Json(HttpErrorResponse {
142 error: format!("Internal server error: {e}"),
143 }),
144 ))
145 }
146 }
147}