throttlecrab_server/transport/
http.rs

1//! HTTP/JSON transport for easy integration
2//!
3//! This transport provides a REST API with JSON payloads, making it easy
4//! to integrate with any programming language or tool that supports HTTP.
5//!
6//! # API Endpoints
7//!
8//! ## POST /throttle
9//!
10//! Check rate limit for a key.
11//!
12//! ### Request Body
13//!
14//! ```json
15//! {
16//!   "key": "user:123",
17//!   "max_burst": 10,
18//!   "count_per_period": 100,
19//!   "period": 60,
20//!   "quantity": 1
21//! }
22//! ```
23//!
24//! - `quantity` is optional (defaults to 1)
25//!
26//! ### Response
27//!
28//! ```json
29//! {
30//!   "allowed": true,
31//!   "limit": 10,
32//!   "remaining": 9,
33//!   "reset_after": 60,
34//!   "retry_after": 0
35//! }
36//! ```
37//!
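//! ## GET /health
//!
//! Health check endpoint. Returns "OK" with 200 status.
//!
//! ## GET /metrics
//!
//! Returns the server's metrics as plain text, rendered by `Metrics::export_prometheus`.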
41
42use super::Transport;
43use crate::actor::RateLimiterHandle;
44use crate::metrics::{Metrics, Transport as MetricsTransport};
45use crate::types::{ThrottleRequest as InternalRequest, ThrottleResponse};
46use anyhow::Result;
47use async_trait::async_trait;
48use axum::{
49    Router,
50    extract::State,
51    http::StatusCode,
52    response::Json,
53    routing::{get, post},
54};
55use serde::{Deserialize, Serialize};
56use std::net::SocketAddr;
57use std::sync::Arc;
58use std::time::SystemTime;
59
60/// HTTP request format for rate limiting
61#[derive(Debug, Serialize, Deserialize)]
62pub struct HttpThrottleRequest {
63    /// The key to rate limit
64    pub key: String,
65    /// Maximum burst capacity
66    pub max_burst: i64,
67    /// Total requests allowed per period
68    pub count_per_period: i64,
69    /// Time period in seconds
70    pub period: i64,
71    /// Number of tokens to consume (optional, defaults to 1)
72    pub quantity: Option<i64>,
73}
74
75/// Error response format
76#[derive(Debug, Serialize, Deserialize)]
77pub struct HttpErrorResponse {
78    /// Error message
79    pub error: String,
80}
81
82/// HTTP transport implementation
83///
84/// Provides a REST API with JSON payloads for easy integration.
85pub struct HttpTransport {
86    addr: SocketAddr,
87    metrics: Arc<Metrics>,
88}
89
90impl HttpTransport {
91    pub fn new(host: &str, port: u16, metrics: Arc<Metrics>) -> Self {
92        let addr = format!("{host}:{port}").parse().expect("Invalid address");
93        Self { addr, metrics }
94    }
95}
96
97#[async_trait]
98impl Transport for HttpTransport {
99    async fn start(self, limiter: RateLimiterHandle) -> Result<()> {
100        let metrics = Arc::clone(&self.metrics);
101        let app_state = Arc::new(AppState { limiter, metrics });
102
103        let app = Router::new()
104            .route("/throttle", post(handle_throttle))
105            .route("/health", get(|| async { "OK" }))
106            .route("/metrics", get(handle_metrics))
107            .with_state(app_state);
108
109        tracing::info!("HTTP server listening on {}", self.addr);
110
111        let listener = tokio::net::TcpListener::bind(self.addr).await?;
112        axum::serve(listener, app).await?;
113
114        Ok(())
115    }
116}
117
118struct AppState {
119    limiter: RateLimiterHandle,
120    metrics: Arc<Metrics>,
121}
122
123async fn handle_throttle(
124    State(state): State<Arc<AppState>>,
125    Json(req): Json<HttpThrottleRequest>,
126) -> Result<Json<ThrottleResponse>, (StatusCode, Json<HttpErrorResponse>)> {
127    // Always use server timestamp
128    let timestamp = SystemTime::now();
129
130    let internal_req = InternalRequest {
131        key: req.key.clone(),
132        max_burst: req.max_burst,
133        count_per_period: req.count_per_period,
134        period: req.period,
135        quantity: req.quantity.unwrap_or(1),
136        timestamp,
137    };
138
139    match state.limiter.throttle(internal_req).await {
140        Ok(response) => {
141            state.metrics.record_request_with_key(
142                MetricsTransport::Http,
143                response.allowed,
144                &req.key,
145            );
146            Ok(Json(response))
147        }
148        Err(e) => {
149            tracing::error!("Rate limiter error: {}", e);
150            state.metrics.record_error(MetricsTransport::Http);
151            Err((
152                StatusCode::INTERNAL_SERVER_ERROR,
153                Json(HttpErrorResponse {
154                    error: format!("Internal server error: {e}"),
155                }),
156            ))
157        }
158    }
159}
160
161async fn handle_metrics(State(state): State<Arc<AppState>>) -> Result<String, StatusCode> {
162    Ok(state.metrics.export_prometheus())
163}