throttlecrab_server/transport/
http.rs

//! HTTP/JSON transport for easy integration
//!
//! This transport provides a REST API with JSON payloads, making it easy
//! to integrate with any programming language or tool that supports HTTP.
//!
//! # API Endpoints
//!
//! ## POST /throttle
//!
//! Check the rate limit for a key.
//!
//! ### Request Body
//!
//! ```json
//! {
//!   "key": "user:123",
//!   "max_burst": 10,
//!   "count_per_period": 100,
//!   "period": 60,
//!   "quantity": 1
//! }
//! ```
//!
//! - `period` is expressed in seconds
//! - `quantity` is optional (defaults to 1)
//! - the request timestamp is always taken from the server's clock
//!
//! ### Response
//!
//! ```json
//! {
//!   "allowed": true,
//!   "limit": 10,
//!   "remaining": 9,
//!   "reset_after": 60,
//!   "retry_after": 0
//! }
//! ```
//!
//! If the rate limiter fails, the server responds with `500 Internal Server Error`
//! and a JSON body of the form `{"error": "..."}`.
//!
//! ## GET /health
//!
//! Health check endpoint. Returns "OK" with 200 status.
41
42use super::Transport;
43use crate::actor::RateLimiterHandle;
44use crate::types::{ThrottleRequest as InternalRequest, ThrottleResponse};
45use anyhow::Result;
46use async_trait::async_trait;
47use axum::{Router, extract::State, http::StatusCode, response::Json, routing::post};
48use serde::{Deserialize, Serialize};
49use std::net::SocketAddr;
50use std::sync::Arc;
51use std::time::SystemTime;
52
53/// HTTP request format for rate limiting
54#[derive(Debug, Serialize, Deserialize)]
55pub struct HttpThrottleRequest {
56    /// The key to rate limit
57    pub key: String,
58    /// Maximum burst capacity
59    pub max_burst: i64,
60    /// Total requests allowed per period
61    pub count_per_period: i64,
62    /// Time period in seconds
63    pub period: i64,
64    /// Number of tokens to consume (optional, defaults to 1)
65    pub quantity: Option<i64>,
66}
67
68/// Error response format
69#[derive(Debug, Serialize, Deserialize)]
70pub struct HttpErrorResponse {
71    /// Error message
72    pub error: String,
73}
74
/// HTTP transport implementation.
///
/// Serves a REST API with JSON payloads (see the module docs for the
/// endpoints) on a single socket address.
pub struct HttpTransport {
    /// Address the server binds to when the transport is started.
    addr: SocketAddr,
}

impl HttpTransport {
    /// Creates a transport that will listen on `host:port`.
    ///
    /// `host` must be an IP literal: IPv4 (`"0.0.0.0"`), bare IPv6 (`"::1"`),
    /// or bracketed IPv6 (`"[::1]"`). Hostnames such as `"localhost"` are not
    /// resolved.
    ///
    /// # Panics
    ///
    /// Panics if `host` is not a valid IP literal. Use
    /// [`HttpTransport::try_new`] to handle that case without panicking.
    pub fn new(host: &str, port: u16) -> Self {
        Self::try_new(host, port)
            .unwrap_or_else(|e| panic!("Invalid address {host}:{port}: {e}"))
    }

    /// Fallible counterpart of [`HttpTransport::new`].
    ///
    /// # Errors
    ///
    /// Returns [`std::net::AddrParseError`] if `host` is not a valid IP literal.
    pub fn try_new(host: &str, port: u16) -> Result<Self, std::net::AddrParseError> {
        // A bare IPv6 literal ("::1") cannot just be glued to ":port": the result
        // ("::1:8080") is not a valid socket address. Build the address from the
        // parsed IP directly, and only fall back to "host:port" parsing for other
        // forms such as bracketed IPv6 ("[::1]").
        let addr = match host.parse::<std::net::IpAddr>() {
            Ok(ip) => SocketAddr::new(ip, port),
            Err(_) => format!("{host}:{port}").parse()?,
        };
        Ok(Self { addr })
    }
}
88
89#[async_trait]
90impl Transport for HttpTransport {
91    async fn start(self, limiter: RateLimiterHandle) -> Result<()> {
92        let app_state = Arc::new(AppState { limiter });
93
94        let app = Router::new()
95            .route("/throttle", post(handle_throttle))
96            .route("/health", axum::routing::get(|| async { "OK" }))
97            .with_state(app_state);
98
99        tracing::info!("HTTP server listening on {}", self.addr);
100
101        let listener = tokio::net::TcpListener::bind(self.addr).await?;
102        axum::serve(listener, app).await?;
103
104        Ok(())
105    }
106}
107
108struct AppState {
109    limiter: RateLimiterHandle,
110}
111
112async fn handle_throttle(
113    State(state): State<Arc<AppState>>,
114    Json(req): Json<HttpThrottleRequest>,
115) -> Result<Json<ThrottleResponse>, (StatusCode, Json<HttpErrorResponse>)> {
116    // Always use server timestamp
117    let timestamp = SystemTime::now();
118
119    let internal_req = InternalRequest {
120        key: req.key,
121        max_burst: req.max_burst,
122        count_per_period: req.count_per_period,
123        period: req.period,
124        quantity: req.quantity.unwrap_or(1),
125        timestamp,
126    };
127
128    match state.limiter.throttle(internal_req).await {
129        Ok(response) => Ok(Json(response)),
130        Err(e) => {
131            tracing::error!("Rate limiter error: {}", e);
132            Err((
133                StatusCode::INTERNAL_SERVER_ERROR,
134                Json(HttpErrorResponse {
135                    error: format!("Internal server error: {e}"),
136                }),
137            ))
138        }
139    }
140}