ruvector_security/
middleware.rs1use crate::{
6 auth::{AuthMiddleware, AuthMode},
7 error::SecurityError,
8 rate_limit::{OperationType, RateLimiter},
9};
10use axum::{
11 body::Body,
12 extract::{ConnectInfo, Request, State},
13 http::{header::AUTHORIZATION, StatusCode},
14 middleware::Next,
15 response::{IntoResponse, Response},
16};
17use std::net::SocketAddr;
18use std::sync::Arc;
19
20#[derive(Clone)]
22pub struct SecurityState {
23 pub auth: AuthMiddleware,
25 pub rate_limiter: RateLimiter,
27}
28
29impl SecurityState {
30 pub fn new(auth: AuthMiddleware, rate_limiter: RateLimiter) -> Self {
32 Self { auth, rate_limiter }
33 }
34
35 pub fn development() -> Self {
37 Self {
38 auth: AuthMiddleware::none(),
39 rate_limiter: RateLimiter::disabled(),
40 }
41 }
42
43 pub fn production(token: &str) -> Self {
45 Self {
46 auth: AuthMiddleware::bearer(token),
47 rate_limiter: RateLimiter::default(),
48 }
49 }
50}
51
52impl Default for SecurityState {
53 fn default() -> Self {
54 Self::development()
55 }
56}
57
58pub async fn auth_layer(
74 State(security): State<SecurityState>,
75 ConnectInfo(addr): ConnectInfo<SocketAddr>,
76 request: Request,
77 next: Next,
78) -> Response {
79 let remote_addr = addr.to_string();
81 if security.auth.is_localhost_allowed(&remote_addr) {
82 return next.run(request).await;
83 }
84
85 if *security.auth.mode() == AuthMode::None {
87 return next.run(request).await;
88 }
89
90 let auth_header = request
92 .headers()
93 .get(AUTHORIZATION)
94 .and_then(|h| h.to_str().ok());
95
96 match security.auth.validate_header(auth_header) {
98 Ok(()) => next.run(request).await,
99 Err(e) => {
100 let (status, message) = match e {
101 SecurityError::AuthenticationRequired => {
102 (StatusCode::UNAUTHORIZED, "Authentication required")
103 }
104 SecurityError::InvalidToken => (StatusCode::UNAUTHORIZED, "Invalid token"),
105 _ => (StatusCode::INTERNAL_SERVER_ERROR, "Authentication error"),
106 };
107 (status, message).into_response()
108 }
109 }
110}
111
112pub async fn rate_limit_layer(
128 State(security): State<SecurityState>,
129 ConnectInfo(addr): ConnectInfo<SocketAddr>,
130 request: Request,
131 next: Next,
132) -> Response {
133 let ip = addr.ip().to_string();
134
135 let op_type = match *request.method() {
137 axum::http::Method::GET | axum::http::Method::HEAD => OperationType::Read,
138 axum::http::Method::POST | axum::http::Method::PUT | axum::http::Method::DELETE => {
139 OperationType::Write
140 }
141 _ => OperationType::Read,
142 };
143
144 match security.rate_limiter.check(op_type, Some(&ip)).await {
146 Ok(()) => next.run(request).await,
147 Err(SecurityError::RateLimitExceeded { retry_after_secs }) => {
148 let mut response = (
149 StatusCode::TOO_MANY_REQUESTS,
150 format!("Rate limit exceeded. Retry after {} seconds.", retry_after_secs),
151 )
152 .into_response();
153
154 response.headers_mut().insert(
156 "Retry-After",
157 retry_after_secs.to_string().parse().unwrap(),
158 );
159
160 response
161 }
162 Err(_) => (StatusCode::INTERNAL_SERVER_ERROR, "Rate limiting error").into_response(),
163 }
164}
165
166pub async fn security_layer(
170 State(security): State<SecurityState>,
171 ConnectInfo(addr): ConnectInfo<SocketAddr>,
172 request: Request,
173 next: Next,
174) -> Response {
175 let remote_addr = addr.to_string();
176 let ip = addr.ip().to_string();
177
178 let is_localhost = security.auth.is_localhost_allowed(&remote_addr);
180 let is_no_auth = *security.auth.mode() == AuthMode::None;
181
182 if !is_localhost && !is_no_auth {
183 let auth_header = request
185 .headers()
186 .get(AUTHORIZATION)
187 .and_then(|h| h.to_str().ok());
188
189 if let Err(e) = security.auth.validate_header(auth_header) {
190 let (status, message) = match e {
191 SecurityError::AuthenticationRequired => {
192 (StatusCode::UNAUTHORIZED, "Authentication required")
193 }
194 SecurityError::InvalidToken => (StatusCode::UNAUTHORIZED, "Invalid token"),
195 _ => (StatusCode::INTERNAL_SERVER_ERROR, "Authentication error"),
196 };
197 return (status, message).into_response();
198 }
199 }
200
201 let op_type = match *request.method() {
203 axum::http::Method::GET | axum::http::Method::HEAD => OperationType::Read,
204 axum::http::Method::POST | axum::http::Method::PUT | axum::http::Method::DELETE => {
205 OperationType::Write
206 }
207 _ => OperationType::Read,
208 };
209
210 if let Err(SecurityError::RateLimitExceeded { retry_after_secs }) =
211 security.rate_limiter.check(op_type, Some(&ip)).await
212 {
213 let mut response = (
214 StatusCode::TOO_MANY_REQUESTS,
215 format!("Rate limit exceeded. Retry after {} seconds.", retry_after_secs),
216 )
217 .into_response();
218
219 response.headers_mut().insert(
220 "Retry-After",
221 retry_after_secs.to_string().parse().unwrap(),
222 );
223
224 return response;
225 }
226
227 next.run(request).await
228}
229
/// Values for the conventional `X-RateLimit-*` response headers.
///
/// Write them onto a response with [`RateLimitHeaders::apply_to_response`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct RateLimitHeaders {
    /// Emitted as `X-RateLimit-Limit`; presumably the request budget per window.
    pub limit: u32,
    /// Emitted as `X-RateLimit-Remaining`; presumably requests left in the current window.
    pub remaining: u32,
    /// Emitted as `X-RateLimit-Reset`.
    ///
    /// NOTE(review): the unit (seconds until reset vs. Unix timestamp) is not
    /// established here — confirm with callers.
    pub reset: u64,
}
238
239impl RateLimitHeaders {
240 pub fn apply_to_response(&self, mut response: Response) -> Response {
242 let headers = response.headers_mut();
243 headers.insert("X-RateLimit-Limit", self.limit.into());
244 headers.insert("X-RateLimit-Remaining", self.remaining.into());
245 headers.insert("X-RateLimit-Reset", self.reset.into());
246 response
247 }
248}
249
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_security_state_development() {
        let state = SecurityState::development();
        assert_eq!(*state.auth.mode(), AuthMode::None);
    }

    #[test]
    fn test_security_state_default_is_development() {
        // `Default` must stay the permissive development preset.
        let state = SecurityState::default();
        assert_eq!(*state.auth.mode(), AuthMode::None);
    }

    #[test]
    fn test_security_state_production() {
        let state = SecurityState::production("secret_token");
        assert_eq!(*state.auth.mode(), AuthMode::Bearer);
    }

    #[test]
    fn test_rate_limit_headers_applied() {
        let response = RateLimitHeaders {
            limit: 100,
            remaining: 42,
            reset: 30,
        }
        .apply_to_response(StatusCode::OK.into_response());

        // Header names are case-insensitive; the map stores them lowercased.
        let headers = response.headers();
        assert_eq!(headers["x-ratelimit-limit"], "100");
        assert_eq!(headers["x-ratelimit-remaining"], "42");
        assert_eq!(headers["x-ratelimit-reset"], "30");
    }
}