1use std::collections::HashMap;
12use std::path::{Path, PathBuf};
13use std::sync::Arc;
14use std::time::{Duration, SystemTime, UNIX_EPOCH};
15
16use serde::{Deserialize, Serialize};
17use sha2::{Digest, Sha256};
18use tokio::sync::RwLock;
19
/// On-disk format version of `tokens.json`. `TokenStore::load` renames aside
/// (and ignores) any file whose `version` field differs from this value.
const SCHEMA_VERSION: u32 = 1;

/// Identifier shared by every access/refresh token minted in one refresh
/// chain. Revocation (`TokenStore::revoke_chain`) operates on whole chains.
pub type ChainId = String;
23
/// Reasons `TokenStore::consume_refresh` can reject a refresh token.
#[derive(Debug, thiserror::Error)]
pub enum RefreshError {
    /// No record matches the token's hash, or the token's chain has been
    /// revoked (revoked chains deliberately look "unknown" to callers).
    #[error("unknown refresh token")]
    Unknown,
    /// A record exists, but its expiry time is not after the current time.
    #[error("refresh token expired")]
    Expired,
    /// The token had already been consumed once; presenting it again is
    /// treated as a replay. Carries the affected chain id so the whole chain
    /// can be revoked.
    #[error("refresh token already consumed (replay attack signal); chain {0} revoked")]
    Replayed(ChainId),
}
33
/// Stored record for one access token. Only the SHA-256 hex of the raw token
/// is kept, never the token itself.
// Field order is also the JSON field order in tokens.json.
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)]
struct AccessRecord {
    /// Lowercase hex SHA-256 of the raw token (as produced by `sha256_hex`).
    token_hash: String,
    /// Expiry in seconds since the Unix epoch; the token is accepted only
    /// while `expires_at > now`.
    expires_at: u64,
    /// Chain this token belongs to; revoking the chain invalidates the token.
    chain_id: ChainId,
}
41
/// Stored record for one refresh token. Only the SHA-256 hex of the raw token
/// is kept, never the token itself.
// Field order is also the JSON field order in tokens.json.
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)]
struct RefreshRecord {
    /// Lowercase hex SHA-256 of the raw token (as produced by `sha256_hex`).
    token_hash: String,
    /// Expiry in seconds since the Unix epoch; redeemable only while
    /// `expires_at > now`.
    expires_at: u64,
    /// Chain this token belongs to.
    chain_id: ChainId,
    /// Unix seconds at which the token was redeemed. `Some` means any further
    /// presentation is reported as a replay.
    consumed_at: Option<u64>,
}
53
/// Serialized shape of `tokens.json`.
///
/// `Default` yields `version: 0` and an empty `client_id_hash`. A file whose
/// `client_id_hash` is empty is accepted by `TokenStore::load` regardless of
/// the configured client id.
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
struct Snapshot {
    /// Must equal `SCHEMA_VERSION` for the file to be loaded.
    version: u32,
    /// SHA-256 hex of the OAuth client id the tokens were issued under. A
    /// non-empty mismatch on load wipes the store.
    client_id_hash: String,
    access: Vec<AccessRecord>,
    refresh: Vec<RefreshRecord>,
    /// Chains whose tokens must be rejected.
    revoked_chains: Vec<ChainId>,
}
64
/// In-memory token index, keyed by the lowercase SHA-256 hex of the raw
/// token. This is the authoritative copy that `tokens.json` mirrors.
#[derive(Default)]
struct Index {
    access_by_hash: HashMap<String, AccessRecord>,
    refresh_by_hash: HashMap<String, RefreshRecord>,
    /// Chains whose access tokens fail validation and whose refresh tokens
    /// are reported as unknown.
    revoked: std::collections::HashSet<ChainId>,
}
73
/// File-backed store of OAuth access and refresh tokens.
///
/// Tokens are indexed in memory by SHA-256 hash and written to a JSON file
/// after each mutating call (mint, successful refresh, revocation). All state
/// lives behind an `Arc`.
pub struct TokenStore {
    inner: Arc<Inner>,
}
77
/// Shared state behind a `TokenStore`.
struct Inner {
    /// Token index. Mutating methods persist to `path` while still holding
    /// the write lock.
    state: RwLock<Index>,
    /// Location of `tokens.json`.
    path: PathBuf,
    /// Lifetime given to newly minted access tokens.
    access_ttl: Duration,
    /// Lifetime given to newly minted refresh tokens.
    refresh_ttl: Duration,
    /// SHA-256 hex of the configured client id, written into every snapshot.
    client_id_hash: String,
}
85
/// Current wall-clock time in whole seconds since the Unix epoch.
///
/// A clock set before 1970 yields `0` instead of panicking.
fn unix_now() -> u64 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(since_epoch) => since_epoch.as_secs(),
        Err(_) => 0,
    }
}
92
93fn sha256_hex(input: &str) -> String {
94 let digest = Sha256::digest(input.as_bytes());
95 let mut hex = String::with_capacity(64);
96 for byte in digest {
97 use std::fmt::Write;
98 let _ = write!(&mut hex, "{:02x}", byte);
99 }
100 hex
101}
102
103#[derive(Debug, Clone)]
105pub struct MintedPair {
106 pub access_token: String,
107 pub refresh_token: String,
108 pub access_ttl: Duration,
109 pub refresh_ttl: Duration,
110 pub chain_id: ChainId,
111}
112
113impl TokenStore {
114 pub fn load(
121 path: PathBuf,
122 client_id: &str,
123 access_ttl: Duration,
124 refresh_ttl: Duration,
125 ) -> anyhow::Result<Self> {
126 let client_id_hash = sha256_hex(client_id);
127 let mut snapshot = match std::fs::read(&path) {
128 Ok(bytes) => match serde_json::from_slice::<Snapshot>(&bytes) {
129 Ok(snap) if snap.version == SCHEMA_VERSION => snap,
130 Ok(_) | Err(_) => {
131 let backup = path.with_extension(format!("json.broken-{}", unix_now()));
132 if let Err(e) = std::fs::rename(&path, &backup) {
133 tracing::warn!(
134 path = %path.display(),
135 backup = %backup.display(),
136 error = %e,
137 "tokens.json corrupt or wrong schema version; could not rename aside (continuing with empty store)"
138 );
139 } else {
140 tracing::warn!(
141 path = %path.display(),
142 backup = %backup.display(),
143 "tokens.json corrupt or wrong schema version; renamed aside, starting fresh"
144 );
145 }
146 Snapshot::default()
147 }
148 },
149 Err(e) if e.kind() == std::io::ErrorKind::NotFound => {
150 tracing::info!(path = %path.display(), "no tokens.json found; starting fresh");
151 Snapshot::default()
152 }
153 Err(e) => {
154 tracing::warn!(
155 path = %path.display(),
156 error = %e,
157 "could not read tokens.json (transient I/O error?); starting fresh"
158 );
159 Snapshot::default()
160 }
161 };
162
163 if !snapshot.client_id_hash.is_empty() && snapshot.client_id_hash != client_id_hash {
164 tracing::warn!(
165 "tokens.json client_id_hash mismatch; wiping (oauth.toml was likely regenerated)"
166 );
167 snapshot = Snapshot::default();
168 }
169 snapshot.client_id_hash = client_id_hash.clone();
170 snapshot.version = SCHEMA_VERSION;
171
172 let now = unix_now();
173 snapshot.access.retain(|r| r.expires_at > now);
174 snapshot.refresh.retain(|r| r.expires_at > now);
175
176 let mut index = Index::default();
177 for r in &snapshot.access {
178 index.access_by_hash.insert(r.token_hash.clone(), r.clone());
179 }
180 for r in &snapshot.refresh {
181 index
182 .refresh_by_hash
183 .insert(r.token_hash.clone(), r.clone());
184 }
185 for c in &snapshot.revoked_chains {
186 index.revoked.insert(c.clone());
187 }
188
189 Ok(Self {
190 inner: Arc::new(Inner {
191 state: RwLock::new(index),
192 path,
193 access_ttl,
194 refresh_ttl,
195 client_id_hash,
196 }),
197 })
198 }
199
200 pub async fn mint_pair(&self, chain_id: Option<ChainId>) -> anyhow::Result<MintedPair> {
206 let chain_id = chain_id.unwrap_or_else(opaque_id);
207 let access_token = opaque_id();
208 let refresh_token = opaque_id();
209 let now = unix_now();
210 let access_record = AccessRecord {
211 token_hash: sha256_hex(&access_token),
212 expires_at: now + self.inner.access_ttl.as_secs(),
213 chain_id: chain_id.clone(),
214 };
215 let refresh_record = RefreshRecord {
216 token_hash: sha256_hex(&refresh_token),
217 expires_at: now + self.inner.refresh_ttl.as_secs(),
218 chain_id: chain_id.clone(),
219 consumed_at: None,
220 };
221
222 {
223 let mut idx = self.inner.state.write().await;
224 idx.access_by_hash
225 .insert(access_record.token_hash.clone(), access_record);
226 idx.refresh_by_hash
227 .insert(refresh_record.token_hash.clone(), refresh_record);
228 self.persist_locked(&idx);
229 }
230
231 Ok(MintedPair {
232 access_token,
233 refresh_token,
234 access_ttl: self.inner.access_ttl,
235 refresh_ttl: self.inner.refresh_ttl,
236 chain_id,
237 })
238 }
239
240 pub async fn validate_access(&self, raw: &str) -> bool {
243 let hash = sha256_hex(raw);
244 let idx = self.inner.state.read().await;
245 let Some(record) = idx.access_by_hash.get(&hash) else {
246 return false;
247 };
248 if record.expires_at <= unix_now() {
249 return false;
250 }
251 if idx.revoked.contains(&record.chain_id) {
252 return false;
253 }
254 true
255 }
256
257 pub async fn consume_refresh(&self, raw: &str) -> Result<ChainId, RefreshError> {
265 let hash = sha256_hex(raw);
266 let mut idx = self.inner.state.write().await;
267 let now = unix_now();
268
269 let (chain_id, expires_at, consumed_at) = match idx.refresh_by_hash.get(&hash) {
273 Some(r) => (r.chain_id.clone(), r.expires_at, r.consumed_at),
274 None => return Err(RefreshError::Unknown),
275 };
276 if idx.revoked.contains(&chain_id) {
277 return Err(RefreshError::Unknown);
278 }
279 if expires_at <= now {
280 return Err(RefreshError::Expired);
281 }
282 if consumed_at.is_some() {
283 return Err(RefreshError::Replayed(chain_id));
284 }
285
286 idx.refresh_by_hash.get_mut(&hash).unwrap().consumed_at = Some(now);
288 self.persist_locked(&idx);
289 Ok(chain_id)
290 }
291
292 pub async fn revoke_chain(&self, chain_id: ChainId) {
295 let mut idx = self.inner.state.write().await;
296 idx.revoked.insert(chain_id);
297 self.persist_locked(&idx);
298 }
299
300 fn persist_locked(&self, idx: &Index) {
303 let mut access: Vec<_> = idx.access_by_hash.values().cloned().collect();
304 let mut refresh: Vec<_> = idx.refresh_by_hash.values().cloned().collect();
305 access.sort_by(|a, b| a.token_hash.cmp(&b.token_hash));
306 refresh.sort_by(|a, b| a.token_hash.cmp(&b.token_hash));
307 let mut revoked: Vec<_> = idx.revoked.iter().cloned().collect();
308 revoked.sort();
309
310 let snap = Snapshot {
311 version: SCHEMA_VERSION,
312 client_id_hash: self.inner.client_id_hash.clone(),
313 access,
314 refresh,
315 revoked_chains: revoked,
316 };
317 let bytes = match serde_json::to_vec_pretty(&snap) {
318 Ok(b) => b,
319 Err(e) => {
320 tracing::error!(error = %e, "could not serialize token snapshot; in-memory state preserved");
321 return;
322 }
323 };
324 if let Err(e) = atomic_write_0600(&self.inner.path, &bytes) {
325 tracing::error!(
326 path = %self.inner.path.display(),
327 error = %e,
328 "could not persist tokens.json; in-memory state preserved"
329 );
330 }
331 }
332}
333
334fn opaque_id() -> String {
335 format!("{:032x}", rand::random::<u128>())
336}
337
338fn atomic_write_0600(path: &Path, bytes: &[u8]) -> std::io::Result<()> {
339 if let Some(parent) = path.parent() {
340 std::fs::create_dir_all(parent)?;
341 }
342 let tmp = path.with_extension(format!("json.tmp.{:08x}", rand::random::<u32>()));
343 std::fs::write(&tmp, bytes)?;
344 #[cfg(unix)]
345 {
346 use std::os::unix::fs::PermissionsExt;
347 std::fs::set_permissions(&tmp, std::fs::Permissions::from_mode(0o600))?;
348 }
349 std::fs::rename(&tmp, path)?;
350 Ok(())
351}
352
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Opens a store at `<dir>/tokens.json` with the standard test settings.
    fn fresh(dir: &TempDir) -> TokenStore {
        fresh_with_path(&dir.path().join("tokens.json"))
    }

    /// Opens a store at `path` with client id `client-id-1`, a 60 s access
    /// TTL and a 600 s refresh TTL.
    fn fresh_with_path(path: &Path) -> TokenStore {
        TokenStore::load(
            path.to_path_buf(),
            "client-id-1",
            Duration::from_secs(60),
            Duration::from_secs(600),
        )
        .unwrap()
    }

    #[test]
    fn load_treats_missing_file_as_empty() {
        let dir = TempDir::new().unwrap();
        let path = dir.path().join("tokens.json");
        let store = fresh_with_path(&path);
        let idx = store.inner.state.try_read().unwrap();
        assert!(idx.access_by_hash.is_empty());
        assert!(idx.refresh_by_hash.is_empty());
        assert!(idx.revoked.is_empty());
        assert!(!path.exists(), "load must not create tokens.json on its own");
    }

    #[test]
    fn load_renames_corrupt_file_aside_and_starts_fresh() {
        let dir = TempDir::new().unwrap();
        let path = dir.path().join("tokens.json");
        std::fs::write(&path, b"this is not valid json").unwrap();
        let _ = TokenStore::load(
            path.clone(),
            "client-id-1",
            Duration::from_secs(60),
            Duration::from_secs(600),
        )
        .unwrap();
        assert!(
            !path.exists(),
            "original corrupt file should have been moved aside"
        );
        let entries: Vec<_> = std::fs::read_dir(dir.path())
            .unwrap()
            .map(|e| e.unwrap().file_name().to_string_lossy().to_string())
            .collect();
        assert!(
            entries
                .iter()
                .any(|name| name.starts_with("tokens.json.broken-")),
            "expected backup file, got {entries:?}"
        );
    }

    #[test]
    fn load_wipes_store_on_client_id_hash_mismatch() {
        let dir = TempDir::new().unwrap();
        let path = dir.path().join("tokens.json");
        let snap = Snapshot {
            version: SCHEMA_VERSION,
            client_id_hash: sha256_hex("OLD-CLIENT-ID"),
            access: vec![AccessRecord {
                token_hash: "deadbeef".into(),
                expires_at: unix_now() + 9999,
                chain_id: "chain-x".into(),
            }],
            refresh: vec![],
            revoked_chains: vec![],
        };
        std::fs::write(&path, serde_json::to_vec(&snap).unwrap()).unwrap();

        let store = TokenStore::load(
            path,
            "NEW-CLIENT-ID",
            Duration::from_secs(60),
            Duration::from_secs(600),
        )
        .unwrap();
        let idx = store.inner.state.try_read().unwrap();
        assert!(
            idx.access_by_hash.is_empty(),
            "tokens issued under old client_id must be wiped"
        );
    }

    #[test]
    fn load_drops_expired_access_and_refresh_records() {
        let dir = TempDir::new().unwrap();
        let path = dir.path().join("tokens.json");
        let now = unix_now();
        let snap = Snapshot {
            version: SCHEMA_VERSION,
            client_id_hash: sha256_hex("client-id-1"),
            access: vec![
                AccessRecord {
                    token_hash: "fresh-access".into(),
                    expires_at: now + 600,
                    chain_id: "c1".into(),
                },
                AccessRecord {
                    token_hash: "stale-access".into(),
                    expires_at: now - 1,
                    chain_id: "c1".into(),
                },
            ],
            refresh: vec![
                RefreshRecord {
                    token_hash: "fresh-refresh".into(),
                    expires_at: now + 600,
                    chain_id: "c1".into(),
                    consumed_at: None,
                },
                RefreshRecord {
                    token_hash: "stale-refresh".into(),
                    expires_at: now - 1,
                    chain_id: "c1".into(),
                    consumed_at: None,
                },
            ],
            revoked_chains: vec![],
        };
        std::fs::write(&path, serde_json::to_vec(&snap).unwrap()).unwrap();
        let store = fresh_with_path(&path);
        let idx = store.inner.state.try_read().unwrap();
        assert!(idx.access_by_hash.contains_key("fresh-access"));
        assert!(!idx.access_by_hash.contains_key("stale-access"));
        assert!(idx.refresh_by_hash.contains_key("fresh-refresh"));
        assert!(!idx.refresh_by_hash.contains_key("stale-refresh"));
    }

    #[tokio::test]
    async fn mint_pair_returns_two_distinct_tokens() {
        let dir = TempDir::new().unwrap();
        let store = fresh(&dir);
        let pair = store.mint_pair(None).await.unwrap();
        assert_ne!(pair.access_token, pair.refresh_token);
        assert!(pair.access_token.len() >= 32);
        assert!(pair.refresh_token.len() >= 32);
        assert!(!pair.chain_id.is_empty());
    }

    #[tokio::test]
    async fn mint_pair_persists_to_disk_with_mode_0600() {
        let dir = TempDir::new().unwrap();
        let path = dir.path().join("tokens.json");
        let store = fresh_with_path(&path);
        let _ = store.mint_pair(None).await.unwrap();
        assert!(path.exists(), "mint_pair must persist to disk");
        #[cfg(unix)]
        {
            use std::os::unix::fs::PermissionsExt;
            let mode = std::fs::metadata(&path).unwrap().permissions().mode() & 0o777;
            assert_eq!(mode, 0o600);
        }
    }

    #[tokio::test]
    async fn tokens_at_rest_are_hashed_not_plaintext() {
        let dir = TempDir::new().unwrap();
        let path = dir.path().join("tokens.json");
        let store = fresh_with_path(&path);
        let pair = store.mint_pair(None).await.unwrap();
        let bytes = std::fs::read(&path).unwrap();
        let body = String::from_utf8(bytes).unwrap();
        assert!(
            !body.contains(&pair.access_token),
            "raw access token must not appear on disk"
        );
        assert!(
            !body.contains(&pair.refresh_token),
            "raw refresh token must not appear on disk"
        );
        assert!(
            body.contains(&sha256_hex(&pair.access_token)),
            "expected access-token hash in file"
        );
    }

    #[tokio::test]
    async fn mint_pair_with_existing_chain_id_keeps_chain() {
        let dir = TempDir::new().unwrap();
        let store = fresh(&dir);
        let first = store.mint_pair(None).await.unwrap();
        let second = store.mint_pair(Some(first.chain_id.clone())).await.unwrap();
        assert_eq!(first.chain_id, second.chain_id);
        assert_ne!(first.access_token, second.access_token);
    }

    #[tokio::test]
    async fn validate_access_returns_true_for_freshly_minted_token() {
        let dir = TempDir::new().unwrap();
        let store = fresh(&dir);
        let pair = store.mint_pair(None).await.unwrap();
        assert!(store.validate_access(&pair.access_token).await);
    }

    #[tokio::test]
    async fn validate_access_returns_false_for_unknown_token() {
        let dir = TempDir::new().unwrap();
        let store = fresh(&dir);
        assert!(!store.validate_access("not-a-real-token").await);
    }

    #[tokio::test]
    async fn validate_access_returns_false_after_expiry() {
        let dir = TempDir::new().unwrap();
        let store = TokenStore::load(
            dir.path().join("tokens.json"),
            "client-id-1",
            Duration::from_secs(0),
            Duration::from_secs(600),
        )
        .unwrap();
        let pair = store.mint_pair(None).await.unwrap();
        // A zero TTL gives expires_at == mint time, and expiry is checked with
        // `<=`, so the token is already expired: no sleep is needed.
        assert!(!store.validate_access(&pair.access_token).await);
    }

    #[tokio::test]
    async fn consume_refresh_returns_chain_id_on_first_use() {
        let dir = TempDir::new().unwrap();
        let store = fresh(&dir);
        let pair = store.mint_pair(None).await.unwrap();
        let chain = store.consume_refresh(&pair.refresh_token).await.unwrap();
        assert_eq!(chain, pair.chain_id);
    }

    #[tokio::test]
    async fn consume_refresh_replay_returns_replayed_with_chain_id() {
        let dir = TempDir::new().unwrap();
        let store = fresh(&dir);
        let pair = store.mint_pair(None).await.unwrap();
        let _first = store.consume_refresh(&pair.refresh_token).await.unwrap();
        let err = store
            .consume_refresh(&pair.refresh_token)
            .await
            .unwrap_err();
        match err {
            RefreshError::Replayed(chain) => assert_eq!(chain, pair.chain_id),
            other => panic!("expected Replayed, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn consume_refresh_unknown_returns_unknown() {
        let dir = TempDir::new().unwrap();
        let store = fresh(&dir);
        let err = store.consume_refresh("never-issued").await.unwrap_err();
        assert!(matches!(err, RefreshError::Unknown), "got {err:?}");
    }

    #[tokio::test]
    async fn consume_refresh_expired_returns_expired() {
        let dir = TempDir::new().unwrap();
        let store = TokenStore::load(
            dir.path().join("tokens.json"),
            "client-id-1",
            Duration::from_secs(60),
            Duration::from_secs(0),
        )
        .unwrap();
        let pair = store.mint_pair(None).await.unwrap();
        // A zero TTL gives expires_at == mint time, and expiry is checked with
        // `<=`, so the token is already expired: no sleep is needed.
        let err = store
            .consume_refresh(&pair.refresh_token)
            .await
            .unwrap_err();
        assert!(matches!(err, RefreshError::Expired), "got {err:?}");
    }

    #[tokio::test]
    async fn revoke_chain_invalidates_all_access_tokens_in_chain() {
        let dir = TempDir::new().unwrap();
        let store = fresh(&dir);
        let pair = store.mint_pair(None).await.unwrap();
        assert!(store.validate_access(&pair.access_token).await);
        store.revoke_chain(pair.chain_id.clone()).await;
        assert!(!store.validate_access(&pair.access_token).await);
    }

    #[tokio::test]
    async fn revoke_chain_invalidates_subsequent_refresh_consumption() {
        let dir = TempDir::new().unwrap();
        let store = fresh(&dir);
        let pair = store.mint_pair(None).await.unwrap();
        store.revoke_chain(pair.chain_id.clone()).await;
        let err = store
            .consume_refresh(&pair.refresh_token)
            .await
            .unwrap_err();
        assert!(
            matches!(err, RefreshError::Unknown),
            "revoked-chain refresh should look Unknown to callers; got {err:?}"
        );
    }

    #[tokio::test]
    async fn revoke_chain_persists_to_disk() {
        let dir = TempDir::new().unwrap();
        let path = dir.path().join("tokens.json");
        let store = fresh_with_path(&path);
        let pair = store.mint_pair(None).await.unwrap();
        store.revoke_chain(pair.chain_id.clone()).await;
        drop(store);
        let store2 = fresh_with_path(&path);
        assert!(!store2.validate_access(&pair.access_token).await);
    }

    #[tokio::test]
    async fn tokens_survive_oauth_state_recreation() {
        use crate::oauth::{OAuthConfig, OAuthState};
        let dir = tempfile::TempDir::new().unwrap();
        let tokens_path = dir.path().join("tokens.json");
        let config = OAuthConfig {
            client_id: "test-id".into(),
            client_secret: "test-secret".into(),
            issuer: "https://example.test".into(),
            access_token_ttl_secs: None,
            refresh_token_ttl_secs: None,
        };

        let access_token = {
            let state_a =
                OAuthState::with_tokens_path(config.clone(), tokens_path.clone()).unwrap();
            let pair = state_a.token_store().mint_pair(None).await.unwrap();
            assert!(state_a.validate_token(&pair.access_token).await);
            pair.access_token
        };

        let state_b = OAuthState::with_tokens_path(config, tokens_path).unwrap();

        assert!(
            state_b.validate_token(&access_token).await,
            "access token issued before restart must still validate after restart"
        );
    }
}