wxkefu_rs/
token_cache.rs

1#![allow(dead_code)]
2//! Redis-backed access_token manager with refresh-ahead and distributed lock.
3//!
4//! Goals:
5//! - Cache `access_token` in Redis with proper TTL to avoid rate limits.
6//! - Refresh-ahead to minimize chances of using an about-to-expire token.
7//! - Distributed lock to prevent thundering herd across multiple instances.
8//! - Safe logging (no secrets); concurrency-aware behavior.
9//!
10//! Design notes:
11//! - The manager is instantiated per credential (`Auth`) identity.
12//! - Tokens are stored as JSON with an `expires_at` epoch timestamp and Redis TTL.
13//! - Refresh-ahead triggers a background refresh when the token is close to expiry.
14//! - When the token is missing/expired, one instance acquires a lock and refreshes;
15//!   others will wait briefly and re-check the cache to avoid calling upstream too frequently.
16//!
17//! Example usage (WeCom/Kf):
18//! ```ignore
19//! use wxkefu_rs::{Auth, KfClient};
20//! use wxkefu_rs::token_cache::TokenManager;
21//! use redis::aio::ConnectionManager;
22//!
23//! #[tokio::main]
24//! async fn main() -> anyhow::Result<()> {
25//!     // Prepare Redis connection manager
26//!     let client = redis::Client::open("redis://127.0.0.1/")?;
27//!     let mut redis = ConnectionManager::new(client).await?;
28//!
29//!     // Prepare KfClient & Auth
30//!     let kf_client = KfClient::default();
31//!     let auth = Auth::WeCom {
32//!         corp_id: std::env::var("WXKF_CORP_ID")?,
33//!         corp_secret: std::env::var("WXKF_APP_SECRET")?,
34//!     };
35//!
36//!     // Create the TokenManager
37//!     let mut tm = TokenManager::new(redis, kf_client, auth).with_namespace("wxkefu:token");
38//!
39//!     // Obtain a token (cached or freshly fetched)
40//!     let token = tm.get_access_token().await?;
41//!     println!("wecom access_token (redacted len): {}", token.len());
42//!
43//!     Ok(())
44//! }
45//! ```
46
47use std::time::{Duration, SystemTime, UNIX_EPOCH};
48
49use redis::{AsyncCommands, RedisResult, aio::ConnectionManager};
50use serde::{Deserialize, Serialize};
51use thiserror::Error;
52use tokio::time::sleep;
53use tracing::{debug, instrument, warn};
54
55use crate::token::{AccessToken, Auth, Error as TokenError, KfClient};
56
57/// JSON payload stored in Redis for a cached token
/// JSON payload stored in Redis for a cached token.
///
/// Serialized with `serde_json` and written with a matching Redis TTL, so the
/// entry disappears from Redis at roughly the same time `expires_at` passes.
#[derive(Clone, Debug, Serialize, Deserialize)]
struct CachedToken {
    // The raw token value. This is a secret — the module's logging policy
    // ("Safe logging (no secrets)") means it must never appear in log output.
    access_token: String,
    /// Epoch seconds when the token should be considered expired locally
    /// (computed at store time as now + TTL, where TTL already includes the
    /// safety margin — see `compute_ttl`).
    expires_at: i64,
}
64
65/// Unified error for token caching/refreshing
/// Unified error for token caching/refreshing.
///
/// Wraps the three failure domains of this module — Redis I/O, the upstream
/// token fetch, and JSON (de)serialization — plus a clock error.
#[derive(Debug, Error)]
pub enum TokenCacheError {
    /// Any Redis command failure (GET/SET/DEL/lock acquisition).
    #[error("redis error: {0}")]
    Redis(#[from] redis::RedisError),

    /// Upstream WeChat API failure while fetching a fresh token. Also used
    /// for the locally generated lock-wait timeout (see `refresh_with_lock`).
    #[error("token fetch error: {0}")]
    Fetch(#[from] TokenError),

    /// Failure to encode/decode the `CachedToken` JSON payload.
    #[error("serialization error: {0}")]
    Serde(#[from] serde_json::Error),

    /// System clock reported a time before the Unix epoch.
    #[error("time error")]
    Time,
}
80
81/// Manages access_token caching in Redis with refresh-ahead and distributed lock
82#[derive(Clone)]
83pub struct TokenManager {
84    /// Redis connection manager (async)
85    redis: ConnectionManager,
86    /// Namespacing for keys, e.g. "wxkefu:token"
87    namespace: String,
88    /// Optional explicit Redis key override (when set, namespace/identity are ignored)
89    key_override: Option<String>,
90    /// HTTP client for WeChat Kf APIs
91    client: KfClient,
92    /// Identity (WeCom or OA/MP). This determines the cache key and which endpoint to call.
93    auth: Auth,
94    /// Refresh-ahead threshold in seconds (default: 300s)
95    refresh_ahead_secs: u32,
96    /// Safety margin subtracted from upstream `expires_in` (default: 120s)
97    safety_margin_secs: u32,
98    /// Distributed lock TTL in seconds (default: 30s)
99    lock_ttl_secs: u32,
100    /// Maximum total wait when another worker holds the lock (default: 5s)
101    max_wait_secs: u32,
102}
103
104impl TokenManager {
105    /// Create a new `TokenManager` with sensible defaults
106    pub fn new(redis: ConnectionManager, client: KfClient, auth: Auth) -> Self {
107        Self {
108            redis,
109            namespace: "wxkefu:token".to_string(),
110            key_override: None,
111            client,
112            auth,
113            refresh_ahead_secs: 300, // 5 minutes
114            safety_margin_secs: 120, // 2 minutes
115            lock_ttl_secs: 30,       // 30 seconds
116            max_wait_secs: 5,        // 5 seconds
117        }
118    }
119
120    /// Override the Redis key namespace
121    pub fn with_namespace(mut self, namespace: impl Into<String>) -> Self {
122        self.namespace = namespace.into();
123        self
124    }
125
126    /// Override the Redis token key explicitly (full key). When set,
127    /// `token_key()` will return this value and ignore `namespace` and identity.
128    pub fn with_key_override(mut self, key: impl Into<String>) -> Self {
129        self.key_override = Some(key.into());
130        self
131    }
132
133    /// Override refresh-ahead threshold
134    pub fn with_refresh_ahead(mut self, secs: u32) -> Self {
135        self.refresh_ahead_secs = secs;
136        self
137    }
138
139    /// Override safety margin (subtracted from `expires_in`)
140    pub fn with_safety_margin(mut self, secs: u32) -> Self {
141        self.safety_margin_secs = secs;
142        self
143    }
144
145    /// Override lock TTL seconds
146    pub fn with_lock_ttl(mut self, secs: u32) -> Self {
147        self.lock_ttl_secs = secs;
148        self
149    }
150
151    /// Override max wait seconds when lock is held by another worker
152    pub fn with_max_wait(mut self, secs: u32) -> Self {
153        self.max_wait_secs = secs;
154        self
155    }
156
157    /// Get current access_token, refreshing if missing or expired.
158    ///
159    /// Behavior:
160    /// - If a valid token exists, return it immediately.
161    /// - If valid but expiring soon, attempt a background refresh (non-blocking).
162    /// - If missing/expired, acquire a distributed lock and fetch; otherwise, wait briefly
163    ///   for another worker to populate and then return.
164    #[instrument(level = "debug", skip(self))]
165    pub async fn get_access_token(&mut self) -> Result<String, TokenCacheError> {
166        let key = self.token_key();
167        if let Some(ct) = self.read_cached_token(&key).await? {
168            let now = epoch()?;
169            if ct.expires_at > now {
170                // Valid token
171                let remaining = ct.expires_at - now;
172                debug!("token valid, remaining={}s", remaining);
173
174                // Refresh-ahead if close to expiry (best-effort, non-blocking)
175                if remaining <= self.refresh_ahead_secs as i64 {
176                    debug!("token is close to expiry; attempting refresh-ahead");
177                    self.try_refresh_ahead(&key).await?;
178                }
179
180                return Ok(ct.access_token);
181            } else {
182                debug!("token expired locally; will refresh synchronously");
183            }
184        } else {
185            debug!("no token found in cache; will refresh synchronously");
186        }
187
188        // Missing or expired: refresh synchronously with distributed lock
189        self.refresh_with_lock(&key).await
190    }
191
192    /// Force refresh now (ignoring cache content), using distributed lock
193    pub async fn force_refresh(&mut self) -> Result<String, TokenCacheError> {
194        let key = self.token_key();
195        self.refresh_with_lock(&key).await
196    }
197
198    /// Explicitly invalidate the cache for this identity
199    pub async fn invalidate(&mut self) -> Result<(), TokenCacheError> {
200        let key = self.token_key();
201        let _: () = self.redis.del(key).await?;
202        Ok(())
203    }
204
205    /// Internal: read and decode token JSON from Redis
206    async fn read_cached_token(
207        &mut self,
208        key: &str,
209    ) -> Result<Option<CachedToken>, TokenCacheError> {
210        let raw: Option<String> = self.redis.get(key).await?;
211        if let Some(s) = raw {
212            let ct: CachedToken = serde_json::from_str(&s)?;
213            Ok(Some(ct))
214        } else {
215            Ok(None)
216        }
217    }
218
219    /// Internal: compute the token cache key
220    /// - If `key_override` is set, use it directly (full key)
221    /// - Otherwise, compose from `namespace` and identity
222    fn token_key(&self) -> String {
223        if let Some(ref k) = self.key_override {
224            return k.clone();
225        }
226        let ident = match &self.auth {
227            Auth::OfficialAccount { appid, .. } => format!("oa:{}", appid),
228            Auth::WeCom { corp_id, .. } => format!("wecom:{}", corp_id),
229        };
230        format!("{}:{}", self.namespace, ident)
231    }
232
233    /// Internal: compute the lock key for the given identity
234    fn lock_key(&self) -> String {
235        format!("{}:lock", self.token_key())
236    }
237
238    /// Internal: refresh token with distributed lock.
239    ///
240    /// - Try to acquire the lock via SET NX EX.
241    /// - If acquired, fetch upstream token and populate Redis.
242    /// - If not acquired, wait with jitter and re-check the cache for a limited time.
243    async fn refresh_with_lock(&mut self, key: &str) -> Result<String, TokenCacheError> {
244        let lock_key = self.lock_key();
245
246        if self.try_acquire_lock(&lock_key).await? {
247            debug!("lock acquired; fetching upstream token");
248            match self.fetch_and_store(key).await {
249                Ok(ct) => Ok(ct.access_token),
250                Err(e) => {
251                    warn!("fetch_and_store failed: {e}");
252                    // Let the lock expire; others may retry after lock TTL.
253                    Err(e)
254                }
255            }
256        } else {
257            debug!("lock held by another worker; waiting and re-checking cache");
258            let start = epoch()?;
259            let max_wait = self.max_wait_secs as i64;
260            let mut attempt = 0;
261
262            loop {
263                if let Some(ct) = self.read_cached_token(key).await? {
264                    let now = epoch()?;
265                    if ct.expires_at > now {
266                        debug!("another worker populated token; returning cached value");
267                        return Ok(ct.access_token);
268                    }
269                }
270
271                let now = epoch()?;
272                if now - start >= max_wait {
273                    warn!(
274                        "waited {}s for token, still unavailable; attempting to acquire lock again",
275                        max_wait
276                    );
277                    // As a last attempt, try to acquire the lock again
278                    if self.try_acquire_lock(&lock_key).await? {
279                        debug!("lock acquired on second attempt; fetching upstream token");
280                        let ct = self.fetch_and_store(key).await?;
281                        return Ok(ct.access_token);
282                    } else {
283                        return Err(TokenCacheError::Fetch(TokenError::Wx {
284                            code: 40001,
285                            message:
286                                "timeout waiting for token; lock held by another worker; try again"
287                                    .to_string(),
288                        }));
289                    }
290                }
291
292                attempt += 1;
293                let sleep_ms = 100 + ((attempt * 37) % 200); // lightweight jitter without extra deps
294                sleep(Duration::from_millis(sleep_ms as u64)).await;
295            }
296        }
297    }
298
299    /// Internal: best-effort, non-blocking refresh-ahead.
300    ///
301    /// - Attempt to acquire the lock (SET NX EX).
302    /// - If acquired, spawn a background task to refresh and store.
303    /// - If not acquired, do nothing (another worker is likely refreshing).
304    async fn try_refresh_ahead(&mut self, key: &str) -> Result<(), TokenCacheError> {
305        let lock_key = self.lock_key();
306        if self.try_acquire_lock(&lock_key).await? {
307            let redis = self.redis.clone();
308            let client = self.client.clone();
309            let auth = self.auth.clone();
310            let key = key.to_string();
311            let safety_margin = self.safety_margin_secs;
312
313            tokio::spawn(async move {
314                if let Err(e) = refresh_task(redis, client, auth, &key, safety_margin).await {
315                    warn!("refresh-ahead task failed: {e}");
316                } else {
317                    debug!("refresh-ahead task completed");
318                }
319                // lock expires automatically (EX)
320            });
321        } else {
322            debug!("refresh-ahead skipped; lock is held by another worker");
323        }
324        Ok(())
325    }
326
327    /// Internal: try to acquire a distributed lock with TTL.
328    ///
329    /// Uses `SET key value NX EX ttl`. We do not implement ownership-based release here;
330    /// the lock expires automatically.
331    async fn try_acquire_lock(&mut self, lock_key: &str) -> Result<bool, TokenCacheError> {
332        let ttl = self.lock_ttl_secs;
333        let val = lock_value();
334
335        // Equivalent Redis command: SET lock_key val NX EX ttl
336        let acquired: RedisResult<Option<String>> = redis::cmd("SET")
337            .arg(lock_key)
338            .arg(val)
339            .arg("NX")
340            .arg("EX")
341            .arg(ttl)
342            .query_async(&mut self.redis)
343            .await;
344
345        match acquired {
346            Ok(Some(_)) => Ok(true),
347            Ok(None) => Ok(false), // Not set → lock exists
348            Err(e) => Err(TokenCacheError::Redis(e)),
349        }
350    }
351
352    /// Internal: fetch upstream token and store to Redis with safety margin and TTL
353    async fn fetch_and_store(&mut self, key: &str) -> Result<CachedToken, TokenCacheError> {
354        // Call upstream
355        let resp: AccessToken = self.client.get_access_token(&self.auth).await?;
356        // Compute TTL with safety margin, minimum clamp
357        let ttl = compute_ttl(resp.expires_in, self.safety_margin_secs);
358        let now = epoch()?;
359        let ct = CachedToken {
360            access_token: resp.access_token,
361            expires_at: now + ttl as i64,
362        };
363
364        // Store JSON with EX TTL
365        let json = serde_json::to_string(&ct)?;
366        let _: () = redis::pipe()
367            .cmd("SET")
368            .arg(key)
369            .arg(&json)
370            .arg("EX")
371            .arg(ttl)
372            .ignore()
373            .query_async(&mut self.redis)
374            .await?;
375
376        Ok(ct)
377    }
378}
379
380/// Background refresh task for refresh-ahead
381#[instrument(level = "debug", skip(redis, client, auth))]
382async fn refresh_task(
383    mut redis: ConnectionManager,
384    client: KfClient,
385    auth: Auth,
386    key: &str,
387    safety_margin_secs: u32,
388) -> Result<(), TokenCacheError> {
389    let resp: AccessToken = client.get_access_token(&auth).await?;
390    let ttl = compute_ttl(resp.expires_in, safety_margin_secs);
391    let now = epoch()?;
392    let ct = CachedToken {
393        access_token: resp.access_token,
394        expires_at: now + ttl as i64,
395    };
396    let json = serde_json::to_string(&ct)?;
397
398    let _: () = redis::pipe()
399        .cmd("SET")
400        .arg(key)
401        .arg(&json)
402        .arg("EX")
403        .arg(ttl)
404        .ignore()
405        .query_async(&mut redis)
406        .await?;
407
408    Ok(())
409}
410
411/// Compute TTL with a safety margin and minimum clamp
412fn compute_ttl(expires_in: u32, safety_margin: u32) -> u32 {
413    let min_ttl = 60; // never store very short TTL
414    let ttl = expires_in.saturating_sub(safety_margin);
415    ttl.max(min_ttl)
416}
417
418/// Get current epoch seconds
419fn epoch() -> Result<i64, TokenCacheError> {
420    let now = SystemTime::now()
421        .duration_since(UNIX_EPOCH)
422        .map_err(|_| TokenCacheError::Time)?;
423    Ok(now.as_secs() as i64)
424}
425
426/// Redact an ID for keys/logs: keep first 2 and last 2 chars where possible
427fn redact_id(id: &str) -> String {
428    if id.len() <= 4 {
429        format!("{}***", id)
430    } else {
431        format!("{}***{}", &id[..2], &id[id.len().saturating_sub(2)..])
432    }
433}
434
435/// Generate a simple lock value string (timestamp-based)
436fn lock_value() -> String {
437    let now = SystemTime::now()
438        .duration_since(UNIX_EPOCH)
439        .map(|d| d.as_nanos())
440        .unwrap_or(0);
441    format!("ts-{}", now)
442}
443
444// Re-export for users
445pub use TokenCacheError as Error;
446pub use TokenManager as RedisTokenManager;