trace_weft_server/
auth.rs1use std::collections::HashMap;
9
10use axum::http::HeaderMap;
11use sha2::{Digest, Sha256};
12use subtle::ConstantTimeEq;
13
/// Identity attached to a request that passed authentication.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Auth {
    /// A recognised API key was presented; holds the project that key maps to.
    Project(String),
    /// No recognised key was presented, but dev mode let the request through.
    DevBypass,
}

impl Auth {
    /// Returns the project this request is scoped to.
    ///
    /// A dev-mode bypass is not tied to any project, so it yields `None`.
    pub fn project(&self) -> Option<&str> {
        match self {
            Self::DevBypass => None,
            Self::Project(name) => Some(name.as_str()),
        }
    }
}
35
36#[derive(Clone, Default)]
39pub struct AuthConfig {
40 keys: HashMap<String, String>,
41 dev_mode: bool,
42}
43
44impl AuthConfig {
45 pub fn from_env() -> Self {
52 let raw_keys = raw_keys_from_env();
53 let dev_mode = dev_mode_from_env().unwrap_or(false);
54 let config = Self::new(raw_keys, dev_mode);
55 if config.keys.is_empty() && !dev_mode {
56 tracing::warn!(
57 "No API keys configured and dev mode is off; all requests will be rejected with 401. \
58 Set TRACE_WEFT_API_KEYS or TRACE_WEFT_DEV_MODE=1."
59 );
60 }
61 config
62 }
63
64 pub fn from_env_local_first() -> Self {
71 let raw_keys = raw_keys_from_env();
72 let dev_mode = dev_mode_from_env().unwrap_or(raw_keys.is_empty());
73 Self::new(raw_keys, dev_mode)
74 }
75
76 pub fn new(raw_keys: impl IntoIterator<Item = (String, String)>, dev_mode: bool) -> Self {
79 let keys = raw_keys
80 .into_iter()
81 .map(|(key, project)| (hash_key(&key), project))
82 .collect();
83 Self { keys, dev_mode }
84 }
85
86 pub fn authenticate(&self, headers: &HeaderMap) -> Option<Auth> {
90 if let Some(project) = bearer_token(headers).and_then(|token| self.lookup(&token)) {
91 return Some(Auth::Project(project));
92 }
93 self.dev_mode.then_some(Auth::DevBypass)
94 }
95
96 fn lookup(&self, presented: &str) -> Option<String> {
100 let presented_hash = hash_key(presented);
101 let mut matched: Option<String> = None;
102 for (stored_hash, project) in &self.keys {
103 if bool::from(stored_hash.as_bytes().ct_eq(presented_hash.as_bytes())) {
105 matched = Some(project.clone());
106 }
107 }
108 matched
109 }
110}
111
/// Parses `TRACE_WEFT_API_KEYS` into `(key, project)` pairs.
///
/// Format and parsing rules:
/// - The variable holds comma-separated `key:project` entries.
/// - Each entry, and both halves of it, are whitespace-trimmed.
/// - The split happens at the first `:`, so a project name may itself
///   contain colons.
/// - Entries with no `:`, or with an empty key or project, are skipped.
/// - An unset or non-Unicode variable yields an empty list.
fn raw_keys_from_env() -> Vec<(String, String)> {
    let setting = std::env::var("TRACE_WEFT_API_KEYS").unwrap_or_default();
    let mut pairs = Vec::new();
    for entry in setting.split(',') {
        let (raw_key, raw_project) = match entry.trim().split_once(':') {
            Some(halves) => halves,
            None => continue,
        };
        let key = raw_key.trim();
        let project = raw_project.trim();
        if key.is_empty() || project.is_empty() {
            continue;
        }
        pairs.push((key.to_owned(), project.to_owned()));
    }
    pairs
}
125
/// Reads `TRACE_WEFT_DEV_MODE` as an explicit on/off override.
///
/// The value is trimmed and compared ASCII case-insensitively:
/// - `1`, `true`, `yes`, `on` give `Some(true)`.
/// - `0`, `false`, `no`, `off` give `Some(false)`.
/// - Anything else, an unset variable, or a non-Unicode value gives `None`,
///   leaving the caller to apply its own default.
fn dev_mode_from_env() -> Option<bool> {
    let raw = std::env::var("TRACE_WEFT_DEV_MODE").ok()?;
    // Normalise so that e.g. `FALSE` or ` 0` is honoured as an explicit opt-out.
    // Previously such values fell through to `None`. In
    // `AuthConfig::from_env_local_first` with no keys configured, `None`
    // defaults to enabling the bypass.
    match raw.trim().to_ascii_lowercase().as_str() {
        "1" | "true" | "yes" | "on" => Some(true),
        "0" | "false" | "no" | "off" => Some(false),
        _ => None,
    }
}
135
136fn bearer_token(headers: &HeaderMap) -> Option<String> {
137 let value = headers.get("Authorization")?.to_str().ok()?;
138 value
139 .strip_prefix("Bearer ")
140 .map(|token| token.trim().to_string())
141 .filter(|token| !token.is_empty())
142}
143
144fn hash_key(key: &str) -> String {
145 let mut hasher = Sha256::new();
146 hasher.update(key.as_bytes());
147 hasher
148 .finalize()
149 .iter()
150 .map(|byte| format!("{byte:02x}"))
151 .collect()
152}
153
#[cfg(test)]
mod tests {
    use super::*;
    use axum::http::HeaderValue;

    const DEV_MODE_VAR: &str = "TRACE_WEFT_DEV_MODE";
    const API_KEYS_VAR: &str = "TRACE_WEFT_API_KEYS";

    /// Builds a header map carrying a single `Authorization` header.
    fn headers_with(auth: &str) -> HeaderMap {
        let mut headers = HeaderMap::new();
        headers.insert("Authorization", HeaderValue::from_str(auth).unwrap());
        headers
    }

    /// Strict config (dev mode off) with two known keys.
    fn config() -> AuthConfig {
        AuthConfig::new(
            [
                ("tw-alpha-key".to_string(), "proj_alpha".to_string()),
                ("tw-beta-key".to_string(), "proj_beta".to_string()),
            ],
            false,
        )
    }

    /// Removes both auth-related environment variables.
    fn clear_auth_env() {
        // SAFETY: `remove_var`/`set_var` are unsafe because another thread may
        // read the environment concurrently. Within this module, only
        // `local_first_defaults_bypass_on_only_without_keys` touches the
        // process environment.
        // NOTE(review): the test harness runs tests on parallel threads. Any
        // other test in the crate that reads these variables could race with
        // it; confirm none do, or serialize them.
        unsafe {
            std::env::remove_var(DEV_MODE_VAR);
            std::env::remove_var(API_KEYS_VAR);
        }
    }

    /// Sets an environment variable for the current test.
    fn set_env(name: &str, value: &str) {
        // SAFETY: same invariant as documented in `clear_auth_env`.
        unsafe { std::env::set_var(name, value) }
    }

    /// Clears the auth env vars when created and again when dropped.
    ///
    /// Because cleanup runs in `Drop`, it also happens while a failing
    /// assertion unwinds, so no variables leak into later tests.
    struct EnvGuard;

    impl EnvGuard {
        fn new() -> Self {
            clear_auth_env();
            EnvGuard
        }
    }

    impl Drop for EnvGuard {
        fn drop(&mut self) {
            clear_auth_env();
        }
    }

    #[test]
    fn valid_key_resolves_to_its_project() {
        let auth = config().authenticate(&headers_with("Bearer tw-alpha-key"));
        assert_eq!(auth, Some(Auth::Project("proj_alpha".to_string())));

        let auth = config().authenticate(&headers_with("Bearer tw-beta-key"));
        assert_eq!(auth, Some(Auth::Project("proj_beta".to_string())));
    }

    #[test]
    fn unknown_key_is_rejected() {
        assert_eq!(
            config().authenticate(&headers_with("Bearer tw-unknown")),
            None
        );
    }

    #[test]
    fn missing_or_malformed_header_is_rejected() {
        assert_eq!(config().authenticate(&HeaderMap::new()), None);
        assert_eq!(config().authenticate(&headers_with("tw-alpha-key")), None);
        assert_eq!(config().authenticate(&headers_with("Bearer ")), None);
    }

    #[test]
    fn dev_bypass_only_works_when_enabled() {
        let strict = AuthConfig::new([], false);
        assert_eq!(strict.authenticate(&HeaderMap::new()), None);

        let dev = AuthConfig::new([], true);
        assert_eq!(dev.authenticate(&HeaderMap::new()), Some(Auth::DevBypass));
    }

    #[test]
    fn valid_key_takes_precedence_over_dev_bypass() {
        let dev = AuthConfig::new(
            [("tw-alpha-key".to_string(), "proj_alpha".to_string())],
            true,
        );
        // A known key still resolves to its project even in dev mode.
        assert_eq!(
            dev.authenticate(&headers_with("Bearer tw-alpha-key")),
            Some(Auth::Project("proj_alpha".to_string()))
        );
        // An unknown key falls back to the bypass rather than being rejected.
        assert_eq!(
            dev.authenticate(&headers_with("Bearer tw-nope")),
            Some(Auth::DevBypass)
        );
    }

    #[test]
    fn local_first_defaults_bypass_on_only_without_keys() {
        // Held until the end of the test; clears the env even if an assert fails.
        let _env = EnvGuard::new();

        // Nothing configured: local-first falls back to bypass, strict does not.
        assert_eq!(
            AuthConfig::from_env_local_first().authenticate(&HeaderMap::new()),
            Some(Auth::DevBypass)
        );
        assert_eq!(AuthConfig::from_env().authenticate(&HeaderMap::new()), None);

        // Keys configured: local-first no longer enables the bypass by default.
        set_env(API_KEYS_VAR, "tw-x:proj_x");
        let local = AuthConfig::from_env_local_first();
        assert_eq!(local.authenticate(&HeaderMap::new()), None);
        assert_eq!(
            local.authenticate(&headers_with("Bearer tw-x")),
            Some(Auth::Project("proj_x".to_string()))
        );

        // An explicit dev-mode setting overrides the key-based default.
        set_env(DEV_MODE_VAR, "1");
        assert_eq!(
            AuthConfig::from_env_local_first().authenticate(&HeaderMap::new()),
            Some(Auth::DevBypass)
        );
    }

    #[test]
    fn stored_config_holds_hashes_not_raw_keys() {
        let config = config();
        assert!(!config.keys.contains_key("tw-alpha-key"));
        // 64 hex characters = a 32-byte SHA-256 digest.
        assert!(config.keys.keys().all(|k| k.len() == 64));
    }
}