wechat_minapp/client/
token_storage.rs

//! 接口调用凭据（access token）存储与读取模块。
//!
//! 默认提供基于内存（`Arc`）的实现 `MemoryTokenStorage`；如需将凭据保存到 redis、postgresql、mysql 等外部存储，
//! 可参考该实现自行实现 `TokenStorage` trait。
5
6use super::access_token::{AccessToken, is_token_expired};
7use super::token_type::TokenType;
8use crate::Result;
9use async_trait::async_trait;
10use chrono::Utc;
11use std::sync::{
12    Arc,
13    atomic::{AtomicBool, Ordering},
14};
15use tokio::sync::{Notify, RwLock};
16use tracing::debug;
17
18/// 定义接口调用凭据读取存储的行为
19#[async_trait]
20pub trait TokenStorage: Send + Sync {
21    async fn token(&self) -> Result<String>;
22    async fn refresh_access_token(&self) -> Result<String>;
23    fn token_type(&self) -> Arc<dyn TokenType>;
24}
25
26/// 接口调用凭据内存存储结构
27pub struct MemoryTokenStorage {
28    access_token: Arc<RwLock<AccessToken>>,
29    refreshing: Arc<AtomicBool>,
30    notify: Arc<Notify>,
31    token_type: Arc<dyn TokenType>,
32}
33
34impl MemoryTokenStorage {
35    pub fn new(token_type: Arc<dyn TokenType>) -> Self {
36        MemoryTokenStorage {
37            access_token: Arc::new(RwLock::new(AccessToken {
38                access_token: String::new(),
39                expired_at: Utc::now(),
40            })),
41            refreshing: Arc::new(AtomicBool::new(false)),
42            notify: Arc::new(Notify::new()),
43            token_type,
44        }
45    }
46}
47
48/// 内存存储方式的接口调用凭据存储读取实现
49#[async_trait]
50impl TokenStorage for MemoryTokenStorage {
51    /// 获取接口调用凭据
52    async fn token(&self) -> Result<String> {
53        // 第一次检查:快速路径
54        {
55            let guard = self.access_token.read().await;
56            if !is_token_expired(&guard) {
57                return Ok(guard.access_token.clone());
58            }
59        }
60
61        // 使用CAS竞争刷新权
62        if self
63            .refreshing
64            .compare_exchange(false, true, Ordering::AcqRel, Ordering::Acquire)
65            .is_ok()
66        {
67            // 获得刷新权
68            match self.refresh_access_token().await {
69                Ok(token) => {
70                    self.refreshing.store(false, Ordering::Release);
71                    self.notify.notify_waiters();
72                    Ok(token)
73                }
74                Err(e) => {
75                    self.refreshing.store(false, Ordering::Release);
76                    self.notify.notify_waiters();
77                    Err(e)
78                }
79            }
80        } else {
81            // 等待其他线程刷新完成
82            self.notify.notified().await;
83            // 刷新完成后重新读取
84            let guard = self.access_token.read().await;
85            Ok(guard.access_token.clone())
86        }
87    }
88
89    /// 刷新接口调用凭据
90    async fn refresh_access_token(&self) -> Result<String> {
91        let mut guard = self.access_token.write().await;
92
93        if !is_token_expired(&guard) {
94            debug!("token already refreshed by another thread");
95            return Ok(guard.access_token.clone());
96        }
97
98        debug!("performing network request to refresh token");
99
100        let builder = self.token_type.token().await?;
101
102        guard.access_token = builder.access_token.clone();
103        guard.expired_at = builder.expired_at;
104
105        debug!("fresh access token: {:#?}", guard);
106
107        Ok(guard.access_token.clone())
108    }
109
110    fn token_type(&self) -> Arc<dyn TokenType> {
111        self.token_type.clone()
112    }
113}