ruvector_scipix/api/
middleware.rs1use axum::{
2 extract::{Request, State},
3 http::HeaderMap,
4 middleware::Next,
5 response::Response,
6};
7use governor::{
8 clock::DefaultClock,
9 state::{InMemoryState, NotKeyed},
10 Quota, RateLimiter,
11};
12use nonzero_ext::nonzero;
13use sha2::{Sha256, Digest};
14use std::sync::Arc;
15use tracing::{debug, warn};
16
17use super::{responses::ErrorResponse, state::AppState};
18
19pub async fn auth_middleware(
22 State(state): State<AppState>,
23 headers: HeaderMap,
24 request: Request,
25 next: Next,
26) -> Result<Response, ErrorResponse> {
27 if !state.auth_enabled {
29 debug!("Authentication disabled, allowing request");
30 return Ok(next.run(request).await);
31 }
32
33 let app_id = headers
35 .get("app_id")
36 .and_then(|v| v.to_str().ok())
37 .or_else(|| {
38 request
40 .uri()
41 .query()
42 .and_then(|q| extract_query_param(q, "app_id"))
43 });
44
45 let app_key = headers
46 .get("app_key")
47 .and_then(|v| v.to_str().ok())
48 .or_else(|| {
49 request
50 .uri()
51 .query()
52 .and_then(|q| extract_query_param(q, "app_key"))
53 });
54
55 match (app_id, app_key) {
57 (Some(id), Some(key)) => {
58 if validate_credentials(&state, id, key).await {
59 debug!("Authentication successful for app_id: {}", id);
60 Ok(next.run(request).await)
61 } else {
62 warn!("Invalid credentials for app_id: {}", id);
63 Err(ErrorResponse::unauthorized("Invalid credentials"))
64 }
65 }
66 _ => {
67 warn!("Missing authentication credentials");
68 Err(ErrorResponse::unauthorized("Missing app_id or app_key"))
69 }
70 }
71}
72
73pub async fn rate_limit_middleware(
75 State(state): State<AppState>,
76 request: Request,
77 next: Next,
78) -> Result<Response, ErrorResponse> {
79 match state.rate_limiter.check() {
81 Ok(_) => {
82 debug!("Rate limit check passed");
83 Ok(next.run(request).await)
84 }
85 Err(_) => {
86 warn!("Rate limit exceeded");
87 Err(ErrorResponse::rate_limited(
88 "Rate limit exceeded. Please try again later.",
89 ))
90 }
91 }
92}
93
94async fn validate_credentials(state: &AppState, app_id: &str, app_key: &str) -> bool {
101 if app_id.is_empty() || app_key.is_empty() {
103 return false;
104 }
105
106 let Some(expected_key_hash) = state.api_keys.get(app_id) else {
108 warn!("Unknown app_id attempted authentication: {}", app_id);
109 return false;
110 };
111
112 let provided_key_hash = hash_api_key(app_key);
114
115 constant_time_compare(&provided_key_hash, expected_key_hash.as_str())
117}
118
119fn hash_api_key(key: &str) -> String {
121 let mut hasher = Sha256::new();
122 hasher.update(key.as_bytes());
123 format!("{:x}", hasher.finalize())
124}
125
/// Compares two strings byte-by-byte without stopping at the first difference.
///
/// Strings of different lengths are rejected immediately, so only the *contents* are
/// compared in time independent of where they differ. The length is not hidden. In this file
/// the caller compares hex SHA-256 digests, so length is not treated as secret.
fn constant_time_compare(a: &str, b: &str) -> bool {
    // OR together the XOR of every byte pair; any mismatch leaves a non-zero bit.
    a.len() == b.len()
        && a
            .bytes()
            .zip(b.bytes())
            .fold(0u8, |diff, (lhs, rhs)| diff | (lhs ^ rhs))
            == 0
}
138
/// Returns the value of the first `name=value` pair in `query` whose name equals `param`.
///
/// Pairs are separated by `&`. Only the *first* `=` splits the name from the value, so a value
/// that itself contains `=` is returned intact. Base64 padding such as `abc==` is one example.
/// Pairs without any `=` are ignored.
///
/// The value is returned raw, as a slice of `query`. Percent-escapes (e.g. `%2B`) and `+`
/// are **not** decoded, so credentials containing URL-reserved characters only match when
/// sent via headers.
fn extract_query_param<'a>(query: &'a str, param: &str) -> Option<&'a str> {
    query
        .split('&')
        // `split_once` keeps everything after the first '=' as the value. The previous
        // `split('=')` silently truncated values at their second '='.
        .filter_map(|pair| pair.split_once('='))
        .find_map(|(name, value)| (name == param).then_some(value))
}
151
152pub fn create_rate_limiter() -> Arc<RateLimiter<NotKeyed, InMemoryState, DefaultClock>> {
154 let quota = Quota::per_minute(nonzero!(100u32));
156 Arc::new(RateLimiter::direct(quota))
157}
158
159pub type AppRateLimiter = Arc<RateLimiter<NotKeyed, InMemoryState, DefaultClock>>;
161
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_extract_query_param() {
        let query = "app_id=123&app_key=secret&foo=bar";
        let cases = [
            ("app_id", Some("123")),
            ("app_key", Some("secret")),
            ("foo", Some("bar")),
            ("missing", None),
        ];
        for (param, expected) in cases {
            assert_eq!(extract_query_param(query, param), expected, "param `{param}`");
        }
    }

    #[test]
    fn test_hash_api_key() {
        let first = hash_api_key("test_key_123");
        let second = hash_api_key("test_key_123");
        // Same input must always produce the same digest; different input must not collide.
        assert_eq!(first, second);
        assert_ne!(hash_api_key("different"), first);
    }

    #[test]
    fn test_constant_time_compare() {
        assert!(constant_time_compare("abc", "abc"));
        for (lhs, rhs) in [("abc", "abd"), ("abc", "ab"), ("", "a")] {
            assert!(!constant_time_compare(lhs, rhs), "{lhs:?} vs {rhs:?}");
        }
    }

    #[tokio::test]
    async fn test_validate_credentials_rejects_empty() {
        let state = AppState::new();
        for (app_id, app_key) in [("", "key"), ("test", ""), ("", "")] {
            assert!(
                !validate_credentials(&state, app_id, app_key).await,
                "app_id={app_id:?} app_key={app_key:?}"
            );
        }
    }
}