Skip to main content

stack_auth/
token_store.rs

1//! Pluggable persistence for service tokens.
2//!
3//! [`AutoRefresh`](crate::auto_refresh::AutoRefresh) consults a [`TokenStore`]
4//! on cold start (no in-memory token) and writes back after every successful
5//! refresh or initial auth. This lets strategies share a service-token cache
6//! across short-lived processes — HTTP-only cookies in Edge Functions, KV
7//! stores in Cloudflare Workers, Redis in multi-instance Node services, or a
8//! shared cache across the CipherStash Proxy's worker pool.
9//!
10//! Wire a store onto a strategy via the builder:
11//!
12//! ```no_run
13//! use std::sync::Arc;
14//! use stack_auth::{AccessKey, AccessKeyStrategy, InMemoryTokenStore};
15//! use cts_common::Region;
16//!
17//! let region = Region::aws("ap-southeast-2").unwrap();
18//! let key: AccessKey = "CSAKmyKeyId.myKeySecret".parse().unwrap();
19//! let store = Arc::new(InMemoryTokenStore::new());
20//! let strategy = AccessKeyStrategy::builder(region, key)
21//!     .with_token_store(store)
22//!     .build()
23//!     .unwrap();
24//! ```
25//!
26//! For cookie-style storage where the load/save logic lives in the calling
27//! request handler, use [`TokenStoreFn::new`] with two async closures
28//! that deal in JSON strings:
29//!
30//! ```no_run
31//! use std::sync::Arc;
32//! use stack_auth::TokenStoreFn;
33//!
34//! let store = Arc::new(TokenStoreFn::new(
35//!     || async { /* read cookie */ None::<String> },
36//!     |_json: String| async move { /* write Set-Cookie header */ },
37//! ));
38//! ```
39//!
40//! See also: [`AuthStrategyFn`](crate::AuthStrategyFn) — the closure-shaped
41//! impl of the *acquisition* layer ([`AuthStrategy`](crate::AuthStrategy)).
42//! `TokenStoreFn` plugs into an existing strategy as a persistence backend;
43//! `AuthStrategyFn` replaces the whole acquisition pipeline (used by FFI
44//! consumers like `protect-ffi` that source tokens from JS).
45
46use std::future::Future;
47use std::sync::Arc;
48
49use tokio::sync::Mutex;
50use zeroize::Zeroizing;
51
52use crate::Token;
53
/// Pluggable persistent cache for service tokens.
///
/// Implementations are consulted by `AutoRefresh` whenever it has no
/// in-memory token (cold start), and written to after every successful
/// refresh or initial authentication. Implementations should treat both
/// methods as best-effort — `load` returns [`None`] for "no token, or load
/// failed", `save` is fire-and-forget. The `AutoRefresh` state machine
/// always validates freshness via [`Token::is_usable`] / [`Token::is_expired`]
/// before returning a loaded token, so implementations don't need to.
///
/// On native targets the trait carries `Send + Sync` bounds so the store can
/// be shared across `tokio::spawn` background work. On wasm32 the bounds are
/// dropped — edge runtimes are single-threaded.
///
/// Both methods return `impl Future` (return-position `impl Trait` in a
/// trait), so implementors can simply write `async fn load` / `async fn save`.
/// On native targets the resulting futures must be `Send`, i.e. they must not
/// hold non-`Send` values across an `.await`. Because of those opaque return
/// types the trait is not dyn-compatible: share a single store through
/// `Arc<T>` (which forwards to `T`) rather than `Arc<dyn TokenStore>`.
#[cfg(not(target_arch = "wasm32"))]
pub trait TokenStore: Send + Sync {
    /// Load the most recently saved token, or `None` if none has been stored
    /// (or the load failed). Errors are swallowed — the calling state machine
    /// falls back to fresh authentication when this returns `None`.
    fn load(&self) -> impl Future<Output = Option<Token>> + Send;

    /// Persist a token after a successful refresh or initial authentication.
    /// Best-effort — implementations should log on failure rather than
    /// returning an error.
    fn save(&self, token: &Token) -> impl Future<Output = ()> + Send;
}
79
/// Pluggable persistent cache for service tokens (wasm32 variant).
///
/// Same contract as the native trait, minus the `Send + Sync` supertraits and
/// the `Send` bound on the returned futures — edge runtimes are
/// single-threaded. `load` returns `None` for "no token, or load failed";
/// `save` is fire-and-forget.
#[cfg(target_arch = "wasm32")]
pub trait TokenStore {
    /// Load the most recently saved token, or `None` if none has been stored
    /// (or the load failed).
    fn load(&self) -> impl Future<Output = Option<Token>>;
    /// Persist a token after a successful refresh or initial authentication.
    /// Best-effort — failures should be logged, not returned.
    fn save(&self, token: &Token) -> impl Future<Output = ()>;
}
85
86/// Forward [`TokenStore`] through `Arc` so one store can back many strategy
87/// instances (Edge Function pool, CipherStash Proxy worker pool, etc).
88#[cfg(not(target_arch = "wasm32"))]
89impl<T: TokenStore + ?Sized> TokenStore for Arc<T> {
90    fn load(&self) -> impl Future<Output = Option<Token>> + Send {
91        (**self).load()
92    }
93
94    fn save(&self, token: &Token) -> impl Future<Output = ()> + Send {
95        (**self).save(token)
96    }
97}
98
// wasm32 counterpart of the blanket `Arc<T>` impl: identical delegation to
// the inner store, without the `Send` bound on the returned futures.
#[cfg(target_arch = "wasm32")]
impl<T> TokenStore for Arc<T>
where
    T: TokenStore + ?Sized,
{
    fn load(&self) -> impl Future<Output = Option<Token>> {
        // `&Arc<T>` deref-coerces to `&T`, so this calls the inner impl.
        T::load(self)
    }

    fn save(&self, token: &Token) -> impl Future<Output = ()> {
        T::save(self, token)
    }
}
109
110/// Zero-sized default for `AutoRefresh<R, S = NoStore>` — `load` returns
111/// `None`, `save` is a no-op. Carries no per-instance cost.
112#[derive(Debug, Default, Clone, Copy)]
113pub struct NoStore;
114
115impl TokenStore for NoStore {
116    async fn load(&self) -> Option<Token> {
117        None
118    }
119
120    async fn save(&self, _token: &Token) {}
121}
122
123/// In-process token store. Useful for tests and as a shared cache across
124/// multiple strategy instances in the same process (e.g. a worker pool).
125///
126/// Internally stores the JSON-serialised form of the token wrapped in
127/// [`Zeroizing`] so the buffer is wiped on overwrite and on store drop. The
128/// [`SecretToken`](crate::SecretToken) wrapped inside [`Token`] is
129/// [`ZeroizeOnDrop`](zeroize::ZeroizeOnDrop), so we deliberately don't clone
130/// the in-memory `Token` value — round-tripping through serde gives us a
131/// fresh `SecretToken` on each `load` without violating that invariant.
132pub struct InMemoryTokenStore {
133    state: Mutex<Option<Zeroizing<String>>>,
134}
135
136impl InMemoryTokenStore {
137    /// Create a new, empty in-memory token store.
138    pub fn new() -> Self {
139        Self {
140            state: Mutex::new(None),
141        }
142    }
143}
144
145impl Default for InMemoryTokenStore {
146    fn default() -> Self {
147        Self::new()
148    }
149}
150
151impl TokenStore for InMemoryTokenStore {
152    async fn load(&self) -> Option<Token> {
153        let guard = self.state.lock().await;
154        let json = guard.as_ref()?;
155        serde_json::from_str(json).ok()
156    }
157
158    async fn save(&self, token: &Token) {
159        let Ok(json) = serde_json::to_string(token) else {
160            tracing::warn!("InMemoryTokenStore: failed to serialise token");
161            return;
162        };
163        let mut guard = self.state.lock().await;
164        *guard = Some(Zeroizing::new(json));
165    }
166}
167
168/// [`TokenStore`] backed by user-supplied `load` and `save` async closures.
169///
170/// This is the *persistence layer* primitive — it plugs into an existing
171/// strategy (e.g. [`AccessKeyStrategy`](crate::AccessKeyStrategy)) so that
172/// strategy can share its service-token cache across processes. For wiring
173/// in a complete *acquisition pipeline* (e.g. a JS-defined strategy across
174/// an FFI boundary), use [`AuthStrategyFn`](crate::AuthStrategyFn) instead.
175///
176/// Closures deal in JSON strings — the on-the-wire form of [`Token`] — not
177/// the `Token` type itself. This keeps the caller's signatures free of
178/// `stack-auth` internals and matches the natural shape of common storage
179/// substrates: a cookie value, a KV blob, a Redis string.
180///
181/// The closure return types are generic so async blocks / `async ||`
182/// closures / `async fn` adapters all compose without boxing.
183///
184/// **Secret-material handling.** The JSON string passed to the `save`
185/// closure contains the bearer token verbatim (via
186/// [`SecretToken`](crate::SecretToken)'s `#[serde(transparent)]` impl). Once
187/// the value crosses into the user's closure, `stack-auth` has no control
188/// over zeroize semantics — implementations should treat the input as
189/// secret material and clear any local copies promptly. `load` wraps the
190/// returned string in [`Zeroizing`] internally, so the buffer is wiped
191/// after deserialisation. End-to-end protection at rest (e.g. encrypting
192/// the value before it ever leaves the worker) is tracked as a future
193/// `EncryptedTokenStore` decorator.
194pub struct TokenStoreFn<L, S> {
195    load: L,
196    save: S,
197}
198
199impl<L, S> TokenStoreFn<L, S> {
200    /// Build a token store from a `load` closure (returns the stored JSON, or
201    /// `None` if nothing is cached) and a `save` closure (persists the JSON).
202    ///
203    /// See the module-level documentation for an example.
204    pub fn new(load: L, save: S) -> Self {
205        Self { load, save }
206    }
207}
208
209#[cfg(not(target_arch = "wasm32"))]
210impl<L, LF, S, SF> TokenStore for TokenStoreFn<L, S>
211where
212    L: Fn() -> LF + Send + Sync,
213    LF: Future<Output = Option<String>> + Send,
214    S: Fn(String) -> SF + Send + Sync,
215    SF: Future<Output = ()> + Send,
216{
217    async fn load(&self) -> Option<Token> {
218        let json = Zeroizing::new((self.load)().await?);
219        // Don't log the underlying serde_json error — its `Display` impl can
220        // include byte positions of unexpected tokens, leaking partial token
221        // content if the input was mid-parse when it failed.
222        serde_json::from_str(&json).ok()
223    }
224
225    async fn save(&self, token: &Token) {
226        let Ok(json) = serde_json::to_string(token) else {
227            tracing::warn!("TokenStoreFn: failed to serialise token");
228            return;
229        };
230        (self.save)(json).await;
231    }
232}
233
234#[cfg(target_arch = "wasm32")]
235impl<L, LF, S, SF> TokenStore for TokenStoreFn<L, S>
236where
237    L: Fn() -> LF,
238    LF: Future<Output = Option<String>>,
239    S: Fn(String) -> SF,
240    SF: Future<Output = ()>,
241{
242    async fn load(&self) -> Option<Token> {
243        let json = Zeroizing::new((self.load)().await?);
244        // Don't log the underlying serde_json error — its `Display` impl can
245        // include byte positions of unexpected tokens, leaking partial token
246        // content if the input was mid-parse when it failed.
247        serde_json::from_str(&json).ok()
248    }
249
250    async fn save(&self, token: &Token) {
251        let Ok(json) = serde_json::to_string(token) else {
252            tracing::warn!("TokenStoreFn: failed to serialise token");
253            return;
254        };
255        (self.save)(json).await;
256    }
257}
258
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;

    use crate::SecretToken;

    use super::*;

    fn dummy_token(expires_at: u64) -> Token {
        Token {
            access_token: SecretToken::new("dummy-access".to_string()),
            refresh_token: None,
            token_type: "Bearer".to_string(),
            expires_at,
            region: None,
            client_id: None,
            device_instance_id: None,
        }
    }

    #[tokio::test]
    async fn in_memory_load_returns_none_when_empty() {
        let store = InMemoryTokenStore::new();
        assert!(
            store.load().await.is_none(),
            "freshly constructed store should hold no token"
        );
    }

    #[tokio::test]
    async fn in_memory_round_trip_preserves_expires_at() {
        let store = InMemoryTokenStore::new();
        store.save(&dummy_token(4_000_000_000)).await;
        let loaded = store
            .load()
            .await
            .expect("load should return the saved token");
        assert_eq!(
            loaded.expires_at(),
            4_000_000_000,
            "round-trip should preserve expires_at"
        );
        assert_eq!(
            loaded.token_type(),
            "Bearer",
            "round-trip should preserve token_type"
        );
    }

    #[tokio::test]
    async fn in_memory_save_overwrites_previous() {
        let store = InMemoryTokenStore::new();
        store.save(&dummy_token(1_000_000_000)).await;
        store.save(&dummy_token(2_000_000_000)).await;
        let loaded = store.load().await.expect("store should hold a token");
        assert_eq!(
            loaded.expires_at(),
            2_000_000_000,
            "second save should replace the first"
        );
    }

    #[tokio::test]
    async fn no_store_discards_saves_and_never_loads() {
        let store = NoStore;
        store.save(&dummy_token(4_000_000_000)).await;
        assert!(
            store.load().await.is_none(),
            "NoStore should drop every save and miss on every load"
        );
    }

    #[tokio::test]
    async fn arc_store_forwards_to_shared_inner_store() {
        let shared = Arc::new(InMemoryTokenStore::new());
        let writer = Arc::clone(&shared);
        // Both calls go through the `Arc<T>` impl, not the inner one directly.
        writer.save(&dummy_token(3_000_000_000)).await;
        let loaded = shared
            .load()
            .await
            .expect("a save through one Arc handle should be visible through another");
        assert_eq!(
            loaded.expires_at(),
            3_000_000_000,
            "Arc forwarding should reach the same inner store"
        );
    }

    #[tokio::test]
    async fn callback_store_invokes_load_closure_each_call() {
        let calls = Arc::new(AtomicUsize::new(0));
        let calls_clone = Arc::clone(&calls);
        let store = TokenStoreFn::new(
            move || {
                let calls = Arc::clone(&calls_clone);
                async move {
                    let n = calls.fetch_add(1, Ordering::SeqCst);
                    if n == 0 {
                        None
                    } else {
                        Some(serde_json::to_string(&dummy_token(4_000_000_000)).unwrap())
                    }
                }
            },
            |_json: String| async move {},
        );

        assert!(
            store.load().await.is_none(),
            "first load returns None because the closure does"
        );
        assert_eq!(
            calls.load(Ordering::SeqCst),
            1,
            "first call should have invoked the load closure exactly once"
        );

        let loaded = store
            .load()
            .await
            .expect("second load should yield a token");
        assert_eq!(
            loaded.expires_at(),
            4_000_000_000,
            "deserialised token should preserve the JSON payload's expires_at"
        );
        assert_eq!(
            calls.load(Ordering::SeqCst),
            2,
            "second call should have invoked the load closure a second time"
        );
    }

    #[tokio::test]
    async fn callback_store_forwards_serialised_token_to_save_closure() {
        let captured = Arc::new(Mutex::new(None::<String>));
        let captured_clone = Arc::clone(&captured);
        let store = TokenStoreFn::new(
            || async { None },
            move |json: String| {
                let captured = Arc::clone(&captured_clone);
                async move {
                    *captured.lock().await = Some(json);
                }
            },
        );

        store.save(&dummy_token(4_000_000_000)).await;
        let json = captured
            .lock()
            .await
            .clone()
            .expect("save closure should have captured the JSON");
        assert!(
            json.contains("\"expires_at\":4000000000"),
            "captured JSON should encode expires_at; got: {json}"
        );
        assert!(
            json.contains("\"token_type\":\"Bearer\""),
            "captured JSON should encode token_type; got: {json}"
        );
    }

    #[tokio::test]
    async fn callback_store_round_trips_through_its_closures() {
        let cell = Arc::new(Mutex::new(None::<String>));
        let load_cell = Arc::clone(&cell);
        let save_cell = Arc::clone(&cell);
        let store = TokenStoreFn::new(
            move || {
                let cell = Arc::clone(&load_cell);
                async move {
                    let guard = cell.lock().await;
                    guard.clone()
                }
            },
            move |json: String| {
                let cell = Arc::clone(&save_cell);
                async move {
                    *cell.lock().await = Some(json);
                }
            },
        );

        store.save(&dummy_token(4_000_000_000)).await;
        let loaded = store
            .load()
            .await
            .expect("load should return the token persisted by save");
        assert_eq!(
            loaded.expires_at(),
            4_000_000_000,
            "save -> load through the closures should preserve expires_at"
        );
    }

    #[tokio::test]
    async fn callback_store_ignores_invalid_json_on_load() {
        let store = TokenStoreFn::new(
            || async { Some("not valid json".to_string()) },
            |_json: String| async move {},
        );
        assert!(
            store.load().await.is_none(),
            "invalid JSON from the load closure should be treated as cache miss"
        );
    }
}