// wechat_minapp/client/token_storage.rs

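//! Storage and retrieval of API access tokens (接口调用凭据).
//!
//! The default implementation, [`MemoryTokenStorage`], keeps the token in memory behind an
//! `Arc<RwLock<_>>`. To persist tokens elsewhere (e.g. Redis, PostgreSQL, MySQL), implement
//! [`TokenStorage`] following the same pattern.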
5
6use super::access_token::{AccessToken, is_token_expired};
7use super::token_type::TokenType;
8use crate::Result;
9use async_trait::async_trait;
10use chrono::Utc;
11use std::sync::{
12    Arc,
13    atomic::{AtomicBool, Ordering},
14};
15use tokio::sync::{Notify, RwLock};
16use tracing::debug;
17
18
19/// 定义接口调用凭据读取存储的行为
20#[async_trait]
21pub trait TokenStorage: Send + Sync {
22    async fn token(&self) -> Result<String>;
23    async fn refresh_access_token(&self) -> Result<String>;
24    fn token_type(&self) -> Arc<dyn TokenType>;
25}
26
27/// 接口调用凭据内存存储结构
28pub struct MemoryTokenStorage {
29    access_token: Arc<RwLock<AccessToken>>,
30    refreshing: Arc<AtomicBool>,
31    notify: Arc<Notify>,
32    token_type: Arc<dyn TokenType>,
33}
34
35impl MemoryTokenStorage {
36    pub fn new(token_type: Arc<dyn TokenType>) -> Self {
37        MemoryTokenStorage {
38            access_token: Arc::new(RwLock::new(AccessToken {
39                access_token: String::new(),
40                expired_at: Utc::now(),
41            })),
42            refreshing: Arc::new(AtomicBool::new(false)),
43            notify: Arc::new(Notify::new()),
44            token_type,
45        }
46    }
47}
48
49
50/// 内存存储方式的接口调用凭据存储读取实现
51#[async_trait]
52impl TokenStorage for MemoryTokenStorage {
53    /// 获取接口调用凭据
54    async fn token(&self) -> Result<String> {
55        // 第一次检查:快速路径
56        {
57            let guard = self.access_token.read().await;
58            if !is_token_expired(&guard) {
59                return Ok(guard.access_token.clone());
60            }
61        }
62
63        // 使用CAS竞争刷新权
64        if self
65            .refreshing
66            .compare_exchange(false, true, Ordering::AcqRel, Ordering::Acquire)
67            .is_ok()
68        {
69            // 获得刷新权
70            match self.refresh_access_token().await {
71                Ok(token) => {
72                    self.refreshing.store(false, Ordering::Release);
73                    self.notify.notify_waiters();
74                    Ok(token)
75                }
76                Err(e) => {
77                    self.refreshing.store(false, Ordering::Release);
78                    self.notify.notify_waiters();
79                    Err(e)
80                }
81            }
82        } else {
83            // 等待其他线程刷新完成
84            self.notify.notified().await;
85            // 刷新完成后重新读取
86            let guard = self.access_token.read().await;
87            Ok(guard.access_token.clone())
88        }
89    }
90
91    /// 刷新接口调用凭据
92    async fn refresh_access_token(&self) -> Result<String> {
93        let mut guard = self.access_token.write().await;
94
95        if !is_token_expired(&guard) {
96            debug!("token already refreshed by another thread");
97            return Ok(guard.access_token.clone());
98        }
99
100        debug!("performing network request to refresh token");
101
102        let builder = self.token_type.token().await?;
103
104        guard.access_token = builder.access_token.clone();
105        guard.expired_at = builder.expired_at;
106
107        debug!("fresh access token: {:#?}", guard);
108
109        Ok(guard.access_token.clone())
110    }
111
112    fn token_type(&self) -> Arc<dyn TokenType> {
113        self.token_type.clone()
114    }
115}