ruvector_security/
middleware.rs

1//! Axum middleware layers for security
2//!
3//! Provides ready-to-use Tower layers for authentication and rate limiting.
4
5use crate::{
6    auth::{AuthMiddleware, AuthMode},
7    error::SecurityError,
8    rate_limit::{OperationType, RateLimiter},
9};
10use axum::{
11    body::Body,
12    extract::{ConnectInfo, Request, State},
13    http::{header::AUTHORIZATION, StatusCode},
14    middleware::Next,
15    response::{IntoResponse, Response},
16};
17use std::net::SocketAddr;
18use std::sync::Arc;
19
20/// Security state for middleware
21#[derive(Clone)]
22pub struct SecurityState {
23    /// Authentication middleware
24    pub auth: AuthMiddleware,
25    /// Rate limiter
26    pub rate_limiter: RateLimiter,
27}
28
29impl SecurityState {
30    /// Create new security state
31    pub fn new(auth: AuthMiddleware, rate_limiter: RateLimiter) -> Self {
32        Self { auth, rate_limiter }
33    }
34
35    /// Create development security state (no auth, disabled rate limiting)
36    pub fn development() -> Self {
37        Self {
38            auth: AuthMiddleware::none(),
39            rate_limiter: RateLimiter::disabled(),
40        }
41    }
42
43    /// Create production security state
44    pub fn production(token: &str) -> Self {
45        Self {
46            auth: AuthMiddleware::bearer(token),
47            rate_limiter: RateLimiter::default(),
48        }
49    }
50}
51
52impl Default for SecurityState {
53    fn default() -> Self {
54        Self::development()
55    }
56}
57
58/// Authentication middleware layer for axum
59///
60/// Checks the Authorization header for a valid bearer token.
61///
62/// # Example
63///
64/// ```rust,ignore
65/// use axum::{Router, routing::get, middleware};
66/// use ruvector_security::middleware::{auth_layer, SecurityState};
67///
68/// let security = SecurityState::production("my_secret_token");
69/// let app = Router::new()
70///     .route("/api", get(|| async { "protected" }))
71///     .layer(middleware::from_fn_with_state(security, auth_layer));
72/// ```
73pub async fn auth_layer(
74    State(security): State<SecurityState>,
75    ConnectInfo(addr): ConnectInfo<SocketAddr>,
76    request: Request,
77    next: Next,
78) -> Response {
79    // Check if localhost bypass is allowed
80    let remote_addr = addr.to_string();
81    if security.auth.is_localhost_allowed(&remote_addr) {
82        return next.run(request).await;
83    }
84
85    // Skip auth if mode is None
86    if *security.auth.mode() == AuthMode::None {
87        return next.run(request).await;
88    }
89
90    // Get authorization header
91    let auth_header = request
92        .headers()
93        .get(AUTHORIZATION)
94        .and_then(|h| h.to_str().ok());
95
96    // Validate token
97    match security.auth.validate_header(auth_header) {
98        Ok(()) => next.run(request).await,
99        Err(e) => {
100            let (status, message) = match e {
101                SecurityError::AuthenticationRequired => {
102                    (StatusCode::UNAUTHORIZED, "Authentication required")
103                }
104                SecurityError::InvalidToken => (StatusCode::UNAUTHORIZED, "Invalid token"),
105                _ => (StatusCode::INTERNAL_SERVER_ERROR, "Authentication error"),
106            };
107            (status, message).into_response()
108        }
109    }
110}
111
112/// Rate limiting middleware layer for axum
113///
114/// Applies rate limiting based on operation type and client IP.
115///
116/// # Example
117///
118/// ```rust,ignore
119/// use axum::{Router, routing::get, middleware};
120/// use ruvector_security::middleware::{rate_limit_layer, SecurityState};
121///
122/// let security = SecurityState::default();
123/// let app = Router::new()
124///     .route("/api", get(|| async { "limited" }))
125///     .layer(middleware::from_fn_with_state(security, rate_limit_layer));
126/// ```
127pub async fn rate_limit_layer(
128    State(security): State<SecurityState>,
129    ConnectInfo(addr): ConnectInfo<SocketAddr>,
130    request: Request,
131    next: Next,
132) -> Response {
133    let ip = addr.ip().to_string();
134
135    // Determine operation type from HTTP method
136    let op_type = match *request.method() {
137        axum::http::Method::GET | axum::http::Method::HEAD => OperationType::Read,
138        axum::http::Method::POST | axum::http::Method::PUT | axum::http::Method::DELETE => {
139            OperationType::Write
140        }
141        _ => OperationType::Read,
142    };
143
144    // Check rate limit
145    match security.rate_limiter.check(op_type, Some(&ip)).await {
146        Ok(()) => next.run(request).await,
147        Err(SecurityError::RateLimitExceeded { retry_after_secs }) => {
148            let mut response = (
149                StatusCode::TOO_MANY_REQUESTS,
150                format!("Rate limit exceeded. Retry after {} seconds.", retry_after_secs),
151            )
152                .into_response();
153
154            // Add Retry-After header
155            response.headers_mut().insert(
156                "Retry-After",
157                retry_after_secs.to_string().parse().unwrap(),
158            );
159
160            response
161        }
162        Err(_) => (StatusCode::INTERNAL_SERVER_ERROR, "Rate limiting error").into_response(),
163    }
164}
165
166/// Combined security middleware layer (auth + rate limiting)
167///
168/// Applies both authentication and rate limiting in a single middleware.
169pub async fn security_layer(
170    State(security): State<SecurityState>,
171    ConnectInfo(addr): ConnectInfo<SocketAddr>,
172    request: Request,
173    next: Next,
174) -> Response {
175    let remote_addr = addr.to_string();
176    let ip = addr.ip().to_string();
177
178    // Skip all security for localhost in development mode
179    let is_localhost = security.auth.is_localhost_allowed(&remote_addr);
180    let is_no_auth = *security.auth.mode() == AuthMode::None;
181
182    if !is_localhost && !is_no_auth {
183        // Check authentication first
184        let auth_header = request
185            .headers()
186            .get(AUTHORIZATION)
187            .and_then(|h| h.to_str().ok());
188
189        if let Err(e) = security.auth.validate_header(auth_header) {
190            let (status, message) = match e {
191                SecurityError::AuthenticationRequired => {
192                    (StatusCode::UNAUTHORIZED, "Authentication required")
193                }
194                SecurityError::InvalidToken => (StatusCode::UNAUTHORIZED, "Invalid token"),
195                _ => (StatusCode::INTERNAL_SERVER_ERROR, "Authentication error"),
196            };
197            return (status, message).into_response();
198        }
199    }
200
201    // Check rate limit (always, even for localhost)
202    let op_type = match *request.method() {
203        axum::http::Method::GET | axum::http::Method::HEAD => OperationType::Read,
204        axum::http::Method::POST | axum::http::Method::PUT | axum::http::Method::DELETE => {
205            OperationType::Write
206        }
207        _ => OperationType::Read,
208    };
209
210    if let Err(SecurityError::RateLimitExceeded { retry_after_secs }) =
211        security.rate_limiter.check(op_type, Some(&ip)).await
212    {
213        let mut response = (
214            StatusCode::TOO_MANY_REQUESTS,
215            format!("Rate limit exceeded. Retry after {} seconds.", retry_after_secs),
216        )
217            .into_response();
218
219        response.headers_mut().insert(
220            "Retry-After",
221            retry_after_secs.to_string().parse().unwrap(),
222        );
223
224        return response;
225    }
226
227    next.run(request).await
228}
229
230/// Rate limit headers extractor
231///
232/// Adds X-RateLimit-* headers to responses
233pub struct RateLimitHeaders {
234    pub limit: u32,
235    pub remaining: u32,
236    pub reset: u64,
237}
238
239impl RateLimitHeaders {
240    /// Apply rate limit headers to a response
241    pub fn apply_to_response(&self, mut response: Response) -> Response {
242        let headers = response.headers_mut();
243        headers.insert("X-RateLimit-Limit", self.limit.into());
244        headers.insert("X-RateLimit-Remaining", self.remaining.into());
245        headers.insert("X-RateLimit-Reset", self.reset.into());
246        response
247    }
248}
249
#[cfg(test)]
mod tests {
    use super::*;

    /// Development state must have authentication disabled.
    #[test]
    fn test_security_state_development() {
        let state = SecurityState::development();
        assert_eq!(*state.auth.mode(), AuthMode::None);
    }

    /// Production state must require a bearer token.
    #[test]
    fn test_security_state_production() {
        let state = SecurityState::production("secret_token");
        assert_eq!(*state.auth.mode(), AuthMode::Bearer);
    }

    /// `Default` is documented as the development configuration; pin that so a
    /// change to a stricter default is a deliberate, visible decision.
    #[test]
    fn test_security_state_default_matches_development() {
        let state = SecurityState::default();
        assert_eq!(*state.auth.mode(), AuthMode::None);
    }
}