ruvector_security/
auth.rs1use crate::error::{SecurityError, SecurityResult};
6use base64::{engine::general_purpose::STANDARD as BASE64, Engine};
7use hmac::{Hmac, Mac};
8use sha2::Sha256;
9use std::sync::Arc;
10use subtle::ConstantTimeEq;
11
/// HMAC-SHA256 instance used by [`BearerTokenValidator`] to digest tokens
/// before comparing them in constant time.
type HmacSha256 = Hmac<Sha256>;
13
14#[derive(Debug, Clone, Default, serde::Serialize, serde::Deserialize, PartialEq, Eq)]
16#[serde(rename_all = "lowercase")]
17pub enum AuthMode {
18 #[default]
20 None,
21 Bearer,
23 Mtls,
25}
26
27#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
29pub struct AuthConfig {
30 pub mode: AuthMode,
32 #[serde(skip_serializing)]
34 pub token: Option<String>,
35 #[serde(skip_serializing)]
37 pub secret_key: Option<String>,
38 pub token_expiry_secs: u64,
40 pub allow_localhost: bool,
42}
43
44impl Default for AuthConfig {
45 fn default() -> Self {
46 Self {
47 mode: AuthMode::None,
48 token: None,
49 secret_key: None,
50 token_expiry_secs: 0,
51 allow_localhost: true,
52 }
53 }
54}
55
56pub trait TokenValidator: Send + Sync {
58 fn validate(&self, token: &str) -> SecurityResult<()>;
60}
61
/// Bearer-token validator that never stores the expected token itself.
///
/// Only a keyed HMAC-SHA256 digest of the expected token is kept. Presented
/// tokens are digested with the same key and compared in constant time, so
/// the comparison always covers two equal-length digests regardless of the
/// presented token's length.
#[derive(Clone)]
pub struct BearerTokenValidator {
    /// HMAC-SHA256 of the expected token under `hmac_key`.
    token_hash: Vec<u8>,
    /// Per-instance random key (32 bytes when built via `new`).
    hmac_key: Vec<u8>,
}
70
71impl BearerTokenValidator {
72 pub fn new(expected_token: &str) -> Self {
74 let mut hmac_key = vec![0u8; 32];
76 rand::RngCore::fill_bytes(&mut rand::thread_rng(), &mut hmac_key);
77
78 let mut mac = HmacSha256::new_from_slice(&hmac_key).expect("HMAC can take key of any size");
80 mac.update(expected_token.as_bytes());
81 let token_hash = mac.finalize().into_bytes().to_vec();
82
83 Self {
84 token_hash,
85 hmac_key,
86 }
87 }
88
89 pub fn from_env(env_var: &str) -> Option<Self> {
91 std::env::var(env_var).ok().map(|token| Self::new(&token))
92 }
93}
94
95impl TokenValidator for BearerTokenValidator {
96 fn validate(&self, token: &str) -> SecurityResult<()> {
97 let mut mac =
99 HmacSha256::new_from_slice(&self.hmac_key).expect("HMAC can take key of any size");
100 mac.update(token.as_bytes());
101 let token_hash = mac.finalize().into_bytes();
102
103 if token_hash.ct_eq(&self.token_hash).into() {
105 Ok(())
106 } else {
107 Err(SecurityError::InvalidToken)
108 }
109 }
110}
111
112#[derive(Clone)]
114pub struct AuthMiddleware {
115 validator: Option<Arc<dyn TokenValidator>>,
116 config: AuthConfig,
117}
118
119impl AuthMiddleware {
120 pub fn new(config: AuthConfig) -> Self {
122 let validator: Option<Arc<dyn TokenValidator>> = match &config.mode {
123 AuthMode::None => None,
124 AuthMode::Bearer => config
125 .token
126 .as_ref()
127 .map(|t| Arc::new(BearerTokenValidator::new(t)) as Arc<dyn TokenValidator>),
128 AuthMode::Mtls => {
129 tracing::warn!("mTLS authentication not yet implemented, falling back to None");
130 None
131 }
132 };
133
134 Self { validator, config }
135 }
136
137 pub fn none() -> Self {
139 Self::new(AuthConfig::default())
140 }
141
142 pub fn bearer(token: &str) -> Self {
144 Self::new(AuthConfig {
145 mode: AuthMode::Bearer,
146 token: Some(token.to_string()),
147 ..Default::default()
148 })
149 }
150
151 pub fn from_env(env_var: &str) -> Self {
153 match std::env::var(env_var) {
154 Ok(token) if !token.is_empty() => Self::bearer(&token),
155 _ => Self::none(),
156 }
157 }
158
159 pub fn validate_header(&self, auth_header: Option<&str>) -> SecurityResult<()> {
161 if self.config.mode == AuthMode::None {
163 return Ok(());
164 }
165
166 let validator = self
168 .validator
169 .as_ref()
170 .ok_or(SecurityError::AuthenticationRequired)?;
171
172 let header = auth_header.ok_or(SecurityError::AuthenticationRequired)?;
174
175 let token = header
177 .strip_prefix("Bearer ")
178 .or_else(|| header.strip_prefix("bearer "))
179 .ok_or(SecurityError::InvalidToken)?;
180
181 validator.validate(token)
182 }
183
184 pub fn is_localhost_allowed(&self, remote_addr: &str) -> bool {
186 if !self.config.allow_localhost {
187 return false;
188 }
189
190 remote_addr.starts_with("127.0.0.1")
191 || remote_addr.starts_with("::1")
192 || remote_addr.starts_with("localhost")
193 }
194
195 pub fn mode(&self) -> &AuthMode {
197 &self.config.mode
198 }
199}
200
201impl Default for AuthMiddleware {
202 fn default() -> Self {
203 Self::none()
204 }
205}
206
207pub fn generate_token() -> String {
209 let mut bytes = [0u8; 32];
210 rand::RngCore::fill_bytes(&mut rand::thread_rng(), &mut bytes);
211 BASE64.encode(bytes)
212}
213
214pub fn generate_prefixed_token(prefix: &str) -> String {
216 format!("{}_{}", prefix, generate_token())
217}
218
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_bearer_token_validation() {
        let secret = "my_secret_token";
        let checker = BearerTokenValidator::new(secret);

        // The configured token is accepted; anything else is not.
        assert!(checker.validate(secret).is_ok());
        assert!(checker.validate("wrong_token").is_err());
    }

    #[test]
    fn test_auth_middleware_bearer() {
        let secret = "test_token_12345";
        let gate = AuthMiddleware::bearer(secret);
        let good_header = format!("Bearer {}", secret);

        assert!(gate.validate_header(Some(good_header.as_str())).is_ok());
        assert!(gate.validate_header(Some("Bearer wrong_token")).is_err());
        // A missing header is rejected outright.
        assert!(gate.validate_header(None).is_err());
    }

    #[test]
    fn test_auth_middleware_none() {
        let gate = AuthMiddleware::none();

        // With authentication disabled, any header (or none at all) passes.
        for header in [None, Some("anything")] {
            assert!(gate.validate_header(header).is_ok());
        }
    }

    #[test]
    fn test_generate_token() {
        let (first, second) = (generate_token(), generate_token());

        // Independent draws are expected to differ.
        assert_ne!(first, second);
        // 32 random bytes encode to more than 40 base64 characters.
        assert!(first.len() >= 40);
    }

    #[test]
    fn test_localhost_bypass() {
        let gate = AuthMiddleware::new(AuthConfig {
            mode: AuthMode::Bearer,
            token: Some("secret".to_string()),
            allow_localhost: true,
            ..Default::default()
        });

        assert!(gate.is_localhost_allowed("127.0.0.1:8080"));
        assert!(gate.is_localhost_allowed("::1:8080"));
        assert!(!gate.is_localhost_allowed("192.168.1.1:8080"));
    }
}