ruvector_scipix/api/
middleware.rs

1use axum::{
2    extract::{Request, State},
3    http::HeaderMap,
4    middleware::Next,
5    response::Response,
6};
7use governor::{
8    clock::DefaultClock,
9    state::{InMemoryState, NotKeyed},
10    Quota, RateLimiter,
11};
12use nonzero_ext::nonzero;
13use sha2::{Sha256, Digest};
14use std::sync::Arc;
15use tracing::{debug, warn};
16
17use super::{responses::ErrorResponse, state::AppState};
18
19/// Authentication middleware
20/// Validates app_id and app_key from headers or query parameters
21pub async fn auth_middleware(
22    State(state): State<AppState>,
23    headers: HeaderMap,
24    request: Request,
25    next: Next,
26) -> Result<Response, ErrorResponse> {
27    // Check if authentication is enabled
28    if !state.auth_enabled {
29        debug!("Authentication disabled, allowing request");
30        return Ok(next.run(request).await);
31    }
32
33    // Extract credentials from headers
34    let app_id = headers
35        .get("app_id")
36        .and_then(|v| v.to_str().ok())
37        .or_else(|| {
38            // Fallback to query parameters
39            request
40                .uri()
41                .query()
42                .and_then(|q| extract_query_param(q, "app_id"))
43        });
44
45    let app_key = headers
46        .get("app_key")
47        .and_then(|v| v.to_str().ok())
48        .or_else(|| {
49            request
50                .uri()
51                .query()
52                .and_then(|q| extract_query_param(q, "app_key"))
53        });
54
55    // Validate credentials
56    match (app_id, app_key) {
57        (Some(id), Some(key)) => {
58            if validate_credentials(&state, id, key).await {
59                debug!("Authentication successful for app_id: {}", id);
60                Ok(next.run(request).await)
61            } else {
62                warn!("Invalid credentials for app_id: {}", id);
63                Err(ErrorResponse::unauthorized("Invalid credentials"))
64            }
65        }
66        _ => {
67            warn!("Missing authentication credentials");
68            Err(ErrorResponse::unauthorized("Missing app_id or app_key"))
69        }
70    }
71}
72
73/// Rate limiting middleware using token bucket algorithm
74pub async fn rate_limit_middleware(
75    State(state): State<AppState>,
76    request: Request,
77    next: Next,
78) -> Result<Response, ErrorResponse> {
79    // Check rate limit
80    match state.rate_limiter.check() {
81        Ok(_) => {
82            debug!("Rate limit check passed");
83            Ok(next.run(request).await)
84        }
85        Err(_) => {
86            warn!("Rate limit exceeded");
87            Err(ErrorResponse::rate_limited(
88                "Rate limit exceeded. Please try again later.",
89            ))
90        }
91    }
92}
93
94/// Validate app credentials using secure comparison
95///
96/// SECURITY: This implementation:
97/// 1. Requires credentials to be pre-configured in AppState
98/// 2. Uses constant-time comparison to prevent timing attacks
99/// 3. Hashes the key before comparison
100async fn validate_credentials(state: &AppState, app_id: &str, app_key: &str) -> bool {
101    // Reject empty credentials
102    if app_id.is_empty() || app_key.is_empty() {
103        return false;
104    }
105
106    // Get configured credentials from state
107    let Some(expected_key_hash) = state.api_keys.get(app_id) else {
108        warn!("Unknown app_id attempted authentication: {}", app_id);
109        return false;
110    };
111
112    // Hash the provided key
113    let provided_key_hash = hash_api_key(app_key);
114
115    // Constant-time comparison to prevent timing attacks
116    constant_time_compare(&provided_key_hash, expected_key_hash.as_str())
117}
118
119/// Hash an API key using SHA-256
120fn hash_api_key(key: &str) -> String {
121    let mut hasher = Sha256::new();
122    hasher.update(key.as_bytes());
123    format!("{:x}", hasher.finalize())
124}
125
/// Compare two strings without short-circuiting on the first differing byte.
///
/// Strings of different length are rejected immediately, so only the contents
/// (not the length) are protected. The callers in this file compare
/// fixed-length SHA-256 hex digests. Every byte pair is XOR-ed into one
/// accumulator, and the strings are equal iff the accumulator stays zero.
fn constant_time_compare(a: &str, b: &str) -> bool {
    let (lhs, rhs) = (a.as_bytes(), b.as_bytes());
    if lhs.len() != rhs.len() {
        return false;
    }

    let diff = lhs.iter().zip(rhs).fold(0u8, |acc, (x, y)| acc | (x ^ y));
    diff == 0
}
138
/// Extract the raw value of `param` from a URL query string.
///
/// The query is split into pairs on `&`, and each pair is split at its *first*
/// `=`. Values that themselves contain `=` (for example base64 padding in an
/// API key) are therefore returned intact. Pairs without any `=` are ignored,
/// and the first matching pair wins.
///
/// NOTE: the value is returned verbatim. It is not percent-decoded, and `+` is
/// not translated to a space.
fn extract_query_param<'a>(query: &'a str, param: &str) -> Option<&'a str> {
    query
        .split('&')
        .filter_map(|pair| pair.split_once('='))
        .find_map(|(key, value)| (key == param).then_some(value))
}
151
152/// Create a rate limiter with token bucket algorithm
153pub fn create_rate_limiter() -> Arc<RateLimiter<NotKeyed, InMemoryState, DefaultClock>> {
154    // Allow 100 requests per minute
155    let quota = Quota::per_minute(nonzero!(100u32));
156    Arc::new(RateLimiter::direct(quota))
157}
158
159/// Type alias for rate limiter
160pub type AppRateLimiter = Arc<RateLimiter<NotKeyed, InMemoryState, DefaultClock>>;
161
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_extract_query_param() {
        let query = "app_id=123&app_key=secret&foo=bar";
        assert_eq!(extract_query_param(query, "app_id"), Some("123"));
        assert_eq!(extract_query_param(query, "app_key"), Some("secret"));
        assert_eq!(extract_query_param(query, "foo"), Some("bar"));
        assert_eq!(extract_query_param(query, "missing"), None);
    }

    #[test]
    fn test_extract_query_param_edge_cases() {
        // An empty query string and pairs without `=` yield nothing.
        assert_eq!(extract_query_param("", "app_id"), None);
        assert_eq!(extract_query_param("app_id", "app_id"), None);
        // An explicit empty value is returned as-is
        // (validate_credentials rejects it later).
        assert_eq!(extract_query_param("app_id=", "app_id"), Some(""));
        // The first occurrence wins.
        assert_eq!(extract_query_param("app_id=1&app_id=2", "app_id"), Some("1"));
        // Keys must match exactly, not by prefix.
        assert_eq!(extract_query_param("app_id_x=1", "app_id"), None);
    }

    #[test]
    fn test_hash_api_key() {
        let key = "test_key_123";
        let hash1 = hash_api_key(key);
        let hash2 = hash_api_key(key);
        assert_eq!(hash1, hash2);
        assert_ne!(hash_api_key("different"), hash1);
    }

    #[test]
    fn test_hash_api_key_known_vectors() {
        // Pin the exact stored format: lowercase-hex SHA-256.
        assert_eq!(
            hash_api_key(""),
            "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855"
        );
        assert_eq!(
            hash_api_key("abc"),
            "ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad"
        );
    }

    #[test]
    fn test_constant_time_compare() {
        assert!(constant_time_compare("abc", "abc"));
        assert!(!constant_time_compare("abc", "abd"));
        assert!(!constant_time_compare("abc", "ab"));
        assert!(!constant_time_compare("", "a"));
        assert!(constant_time_compare("", ""));
    }

    #[tokio::test]
    async fn test_validate_credentials_rejects_empty() {
        let state = AppState::new();
        assert!(!validate_credentials(&state, "", "key").await);
        assert!(!validate_credentials(&state, "test", "").await);
        assert!(!validate_credentials(&state, "", "").await);
    }
}