1#![allow(dead_code)]
2use std::time::{Duration, SystemTime, UNIX_EPOCH};
48
49use redis::{AsyncCommands, RedisResult, aio::ConnectionManager};
50use serde::{Deserialize, Serialize};
51use thiserror::Error;
52use tokio::time::sleep;
53use tracing::{debug, instrument, warn};
54
55use crate::token::{AccessToken, Auth, Error as TokenError, KfClient};
56
57#[derive(Clone, Debug, Serialize, Deserialize)]
59struct CachedToken {
60 access_token: String,
61 expires_at: i64,
63}
64
65#[derive(Debug, Error)]
67pub enum TokenCacheError {
68 #[error("redis error: {0}")]
69 Redis(#[from] redis::RedisError),
70
71 #[error("token fetch error: {0}")]
72 Fetch(#[from] TokenError),
73
74 #[error("serialization error: {0}")]
75 Serde(#[from] serde_json::Error),
76
77 #[error("time error")]
78 Time,
79}
80
81#[derive(Clone)]
83pub struct TokenManager {
84 redis: ConnectionManager,
86 namespace: String,
88 key_override: Option<String>,
90 client: KfClient,
92 auth: Auth,
94 refresh_ahead_secs: u32,
96 safety_margin_secs: u32,
98 lock_ttl_secs: u32,
100 max_wait_secs: u32,
102}
103
104impl TokenManager {
105 pub fn new(redis: ConnectionManager, client: KfClient, auth: Auth) -> Self {
107 Self {
108 redis,
109 namespace: "wxkefu:token".to_string(),
110 key_override: None,
111 client,
112 auth,
113 refresh_ahead_secs: 300, safety_margin_secs: 120, lock_ttl_secs: 30, max_wait_secs: 5, }
118 }
119
120 pub fn with_namespace(mut self, namespace: impl Into<String>) -> Self {
122 self.namespace = namespace.into();
123 self
124 }
125
126 pub fn with_key_override(mut self, key: impl Into<String>) -> Self {
129 self.key_override = Some(key.into());
130 self
131 }
132
133 pub fn with_refresh_ahead(mut self, secs: u32) -> Self {
135 self.refresh_ahead_secs = secs;
136 self
137 }
138
139 pub fn with_safety_margin(mut self, secs: u32) -> Self {
141 self.safety_margin_secs = secs;
142 self
143 }
144
145 pub fn with_lock_ttl(mut self, secs: u32) -> Self {
147 self.lock_ttl_secs = secs;
148 self
149 }
150
151 pub fn with_max_wait(mut self, secs: u32) -> Self {
153 self.max_wait_secs = secs;
154 self
155 }
156
157 #[instrument(level = "debug", skip(self))]
165 pub async fn get_access_token(&mut self) -> Result<String, TokenCacheError> {
166 let key = self.token_key();
167 if let Some(ct) = self.read_cached_token(&key).await? {
168 let now = epoch()?;
169 if ct.expires_at > now {
170 let remaining = ct.expires_at - now;
172 debug!("token valid, remaining={}s", remaining);
173
174 if remaining <= self.refresh_ahead_secs as i64 {
176 debug!("token is close to expiry; attempting refresh-ahead");
177 self.try_refresh_ahead(&key).await?;
178 }
179
180 return Ok(ct.access_token);
181 } else {
182 debug!("token expired locally; will refresh synchronously");
183 }
184 } else {
185 debug!("no token found in cache; will refresh synchronously");
186 }
187
188 self.refresh_with_lock(&key).await
190 }
191
192 pub async fn force_refresh(&mut self) -> Result<String, TokenCacheError> {
194 let key = self.token_key();
195 self.refresh_with_lock(&key).await
196 }
197
198 pub async fn invalidate(&mut self) -> Result<(), TokenCacheError> {
200 let key = self.token_key();
201 let _: () = self.redis.del(key).await?;
202 Ok(())
203 }
204
205 async fn read_cached_token(
207 &mut self,
208 key: &str,
209 ) -> Result<Option<CachedToken>, TokenCacheError> {
210 let raw: Option<String> = self.redis.get(key).await?;
211 if let Some(s) = raw {
212 let ct: CachedToken = serde_json::from_str(&s)?;
213 Ok(Some(ct))
214 } else {
215 Ok(None)
216 }
217 }
218
219 fn token_key(&self) -> String {
223 if let Some(ref k) = self.key_override {
224 return k.clone();
225 }
226 let ident = match &self.auth {
227 Auth::OfficialAccount { appid, .. } => format!("oa:{}", appid),
228 Auth::WeCom { corp_id, .. } => format!("wecom:{}", corp_id),
229 };
230 format!("{}:{}", self.namespace, ident)
231 }
232
233 fn lock_key(&self) -> String {
235 format!("{}:lock", self.token_key())
236 }
237
238 async fn refresh_with_lock(&mut self, key: &str) -> Result<String, TokenCacheError> {
244 let lock_key = self.lock_key();
245
246 if self.try_acquire_lock(&lock_key).await? {
247 debug!("lock acquired; fetching upstream token");
248 match self.fetch_and_store(key).await {
249 Ok(ct) => Ok(ct.access_token),
250 Err(e) => {
251 warn!("fetch_and_store failed: {e}");
252 Err(e)
254 }
255 }
256 } else {
257 debug!("lock held by another worker; waiting and re-checking cache");
258 let start = epoch()?;
259 let max_wait = self.max_wait_secs as i64;
260 let mut attempt = 0;
261
262 loop {
263 if let Some(ct) = self.read_cached_token(key).await? {
264 let now = epoch()?;
265 if ct.expires_at > now {
266 debug!("another worker populated token; returning cached value");
267 return Ok(ct.access_token);
268 }
269 }
270
271 let now = epoch()?;
272 if now - start >= max_wait {
273 warn!(
274 "waited {}s for token, still unavailable; attempting to acquire lock again",
275 max_wait
276 );
277 if self.try_acquire_lock(&lock_key).await? {
279 debug!("lock acquired on second attempt; fetching upstream token");
280 let ct = self.fetch_and_store(key).await?;
281 return Ok(ct.access_token);
282 } else {
283 return Err(TokenCacheError::Fetch(TokenError::Wx {
284 code: 40001,
285 message:
286 "timeout waiting for token; lock held by another worker; try again"
287 .to_string(),
288 }));
289 }
290 }
291
292 attempt += 1;
293 let sleep_ms = 100 + ((attempt * 37) % 200); sleep(Duration::from_millis(sleep_ms as u64)).await;
295 }
296 }
297 }
298
299 async fn try_refresh_ahead(&mut self, key: &str) -> Result<(), TokenCacheError> {
305 let lock_key = self.lock_key();
306 if self.try_acquire_lock(&lock_key).await? {
307 let redis = self.redis.clone();
308 let client = self.client.clone();
309 let auth = self.auth.clone();
310 let key = key.to_string();
311 let safety_margin = self.safety_margin_secs;
312
313 tokio::spawn(async move {
314 if let Err(e) = refresh_task(redis, client, auth, &key, safety_margin).await {
315 warn!("refresh-ahead task failed: {e}");
316 } else {
317 debug!("refresh-ahead task completed");
318 }
319 });
321 } else {
322 debug!("refresh-ahead skipped; lock is held by another worker");
323 }
324 Ok(())
325 }
326
327 async fn try_acquire_lock(&mut self, lock_key: &str) -> Result<bool, TokenCacheError> {
332 let ttl = self.lock_ttl_secs;
333 let val = lock_value();
334
335 let acquired: RedisResult<Option<String>> = redis::cmd("SET")
337 .arg(lock_key)
338 .arg(val)
339 .arg("NX")
340 .arg("EX")
341 .arg(ttl)
342 .query_async(&mut self.redis)
343 .await;
344
345 match acquired {
346 Ok(Some(_)) => Ok(true),
347 Ok(None) => Ok(false), Err(e) => Err(TokenCacheError::Redis(e)),
349 }
350 }
351
352 async fn fetch_and_store(&mut self, key: &str) -> Result<CachedToken, TokenCacheError> {
354 let resp: AccessToken = self.client.get_access_token(&self.auth).await?;
356 let ttl = compute_ttl(resp.expires_in, self.safety_margin_secs);
358 let now = epoch()?;
359 let ct = CachedToken {
360 access_token: resp.access_token,
361 expires_at: now + ttl as i64,
362 };
363
364 let json = serde_json::to_string(&ct)?;
366 let _: () = redis::pipe()
367 .cmd("SET")
368 .arg(key)
369 .arg(&json)
370 .arg("EX")
371 .arg(ttl)
372 .ignore()
373 .query_async(&mut self.redis)
374 .await?;
375
376 Ok(ct)
377 }
378}
379
380#[instrument(level = "debug", skip(redis, client, auth))]
382async fn refresh_task(
383 mut redis: ConnectionManager,
384 client: KfClient,
385 auth: Auth,
386 key: &str,
387 safety_margin_secs: u32,
388) -> Result<(), TokenCacheError> {
389 let resp: AccessToken = client.get_access_token(&auth).await?;
390 let ttl = compute_ttl(resp.expires_in, safety_margin_secs);
391 let now = epoch()?;
392 let ct = CachedToken {
393 access_token: resp.access_token,
394 expires_at: now + ttl as i64,
395 };
396 let json = serde_json::to_string(&ct)?;
397
398 let _: () = redis::pipe()
399 .cmd("SET")
400 .arg(key)
401 .arg(&json)
402 .arg("EX")
403 .arg(ttl)
404 .ignore()
405 .query_async(&mut redis)
406 .await?;
407
408 Ok(())
409}
410
/// Computes how many seconds a freshly fetched token may be cached.
///
/// The upstream lifetime `expires_in` is reduced by `safety_margin` (saturating at
/// zero) and raised to a 60-second floor. The result is then capped at `expires_in`
/// so the cache never claims a token is valid for longer than upstream says, and
/// kept at 1 or more because Redis rejects a zero expiry.
fn compute_ttl(expires_in: u32, safety_margin: u32) -> u32 {
    const MIN_TTL_SECS: u32 = 60;

    expires_in
        .saturating_sub(safety_margin)
        .max(MIN_TTL_SECS)
        // The floor must not outlive the token itself.
        .min(expires_in)
        // `SET ... EX 0` is an error in Redis.
        .max(1)
}
417
418fn epoch() -> Result<i64, TokenCacheError> {
420 let now = SystemTime::now()
421 .duration_since(UNIX_EPOCH)
422 .map_err(|_| TokenCacheError::Time)?;
423 Ok(now.as_secs() as i64)
424}
425
/// Masks an identifier for logging, keeping only its first and last two characters.
///
/// Identifiers of four characters or fewer are masked entirely, since showing two
/// characters from each end would reveal the whole value. Counting and slicing are
/// done per `char`, so non-ASCII identifiers cannot cause a slicing panic.
fn redact_id(id: &str) -> String {
    const VISIBLE: usize = 2;

    let total = id.chars().count();
    if total <= VISIBLE * 2 {
        return "***".to_string();
    }

    let head: String = id.chars().take(VISIBLE).collect();
    let tail: String = id.chars().skip(total - VISIBLE).collect();
    format!("{head}***{tail}")
}
434
/// Builds the value stored in the lock key: `ts-` followed by the current time in
/// nanoseconds since the Unix epoch, or `ts-0` if the clock is before the epoch.
fn lock_value() -> String {
    let nanos = match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_nanos(),
        Err(_) => 0,
    };
    format!("ts-{}", nanos)
}
443
444pub use TokenCacheError as Error;
446pub use TokenManager as RedisTokenManager;