// wechat_minapp/client/stable_token.rs

use super::{
    Client, ClientInner,
    access_token::{AccessToken, AccessTokenBuilder, is_token_expired},
};
use crate::{Result, constants, error::Error::InternalServer, response::Response};
use async_trait::async_trait;
use chrono::Utc;
use std::{
    collections::HashMap,
    sync::{
        Arc,
        atomic::{AtomicBool, Ordering},
    },
};
use tokio::sync::{Notify, RwLock};
use tracing::{debug, instrument};

18#[derive(Debug, Clone)]
19pub struct StableTokenClient {
20    inner: Arc<ClientInner>,
21    access_token: Arc<RwLock<AccessToken>>,
22    refreshing: Arc<AtomicBool>,
23    notify: Arc<Notify>,
24    force_refresh: bool,
25}
26
27#[async_trait]
28impl Client for StableTokenClient {
29    #[instrument(skip(self))]
30    async fn token(&self) -> Result<String> {
31        // 第一次检查:快速路径
32        {
33            let guard = self.access_token.read().await;
34            if !is_token_expired(&guard) {
35                return Ok(guard.access_token.clone());
36            }
37        }
38
39        // 使用CAS竞争刷新权
40        if self
41            .refreshing
42            .compare_exchange(false, true, Ordering::AcqRel, Ordering::Acquire)
43            .is_ok()
44        {
45            // 获得刷新权
46            match self.refresh_access_token().await {
47                Ok(token) => {
48                    self.refreshing.store(false, Ordering::Release);
49                    self.notify.notify_waiters();
50                    Ok(token)
51                }
52                Err(e) => {
53                    self.refreshing.store(false, Ordering::Release);
54                    self.notify.notify_waiters();
55                    Err(e)
56                }
57            }
58        } else {
59            // 等待其他线程刷新完成
60            self.notify.notified().await;
61            // 刷新完成后重新读取
62            let guard = self.access_token.read().await;
63            Ok(guard.access_token.clone())
64        }
65    }
66
67    fn inner_client(&self) -> &ClientInner {
68        &self.inner
69    }
70}
71
72impl StableTokenClient {
73    /// 创建新的使用稳定版接口调用凭据的微信小程序客户端
74    ///
75    /// # 参数
76    ///
77    /// - `app_id`: 小程序 AppID
78    /// - `secret`: 小程序 AppSecret
79    ///
80    /// # 返回
81    ///
82    /// 新的 `StableTokenClient` 实例
83    ///
84    /// # 示例
85    ///
86    /// ```no_run
87    /// use wechat_minapp::client::StableTokenClient;
88    ///
89    /// let client = StableTokenClient::new("your_appid", "your_app_secret_here");
90    /// ```
91    pub fn new(app_id: &str, secret: &str) -> Self {
92        StableTokenClient {
93            inner: Arc::new(ClientInner {
94                app_id: app_id.to_string(),
95                secret: secret.to_string(),
96                client: reqwest::Client::new(),
97            }),
98            access_token: Arc::new(RwLock::new(AccessToken {
99                access_token: String::new(),
100                expired_at: Utc::now(),
101            })),
102            refreshing: Arc::new(AtomicBool::new(false)),
103            notify: Arc::new(Notify::new()),
104            force_refresh: false,
105        }
106    }
107
108    /// 稳定版接口调用凭据有两种调用模式:
109    /// 1. 普通模式,access_token 有效期内重复调用该接口不会更新 access_token,绝大部分场景下使用该模式;
110    /// 2. 强制刷新模式,会导致上次获取的 access_token 失效,并返回新的 access_token;
111    ///
112    /// 默认使用普通模式,如果需要强制刷新,可调用此方法
113    /// ```no_run
114    /// use wechat_minapp::client::StableTokenClient;
115    ///
116    /// let mut client = StableTokenClient::new("your_appid", "your_app_secret_here");
117    /// client.with_fore_refresh(true);
118    /// ```
119    pub fn with_force_refresh(mut self, force_refresh: bool) -> Self {
120        self.force_refresh = force_refresh;
121        self
122    }
123
124    /// 获取小程序全局唯一后台接口调用凭据(access_token)
125    /// https://developers.weixin.qq.com/miniprogram/dev/OpenApiDoc/mp-access-token/getStableAccessToken.html
126    async fn get_access_token(&self) -> Result<AccessTokenBuilder> {
127        let mut map: HashMap<&str, String> = HashMap::new();
128        let client = &self.inner.client;
129        let appid = &self.inner.app_id;
130        let secret = &self.inner.secret;
131        let force_refresh = self.force_refresh;
132        map.insert("grant_type", "client_credential".into());
133        map.insert("appid", appid.to_string());
134        map.insert("secret", secret.to_string());
135
136        if force_refresh {
137            debug!("force_refresh: {}", force_refresh);
138
139            map.insert("force_refresh", force_refresh.to_string());
140        }
141
142        let response = client
143            .post(constants::STABLE_ACCESS_TOKEN_END_POINT)
144            .json(&map)
145            .send()
146            .await?;
147
148        debug!("response: {:#?}", response);
149
150        if response.status().is_success() {
151            let response = response.json::<Response<AccessTokenBuilder>>().await?;
152
153            let builder = response.extract()?;
154
155            debug!("stable access token builder: {:#?}", builder);
156
157            Ok(builder)
158        } else {
159            Err(InternalServer(response.text().await?))
160        }
161    }
162
163    async fn refresh_access_token(&self) -> Result<String> {
164        let mut guard = self.access_token.write().await;
165
166        if !is_token_expired(&guard) {
167            debug!("token already refreshed by another thread");
168            return Ok(guard.access_token.clone());
169        }
170
171        debug!("performing network request to refresh token");
172
173        let builder = self.get_access_token().await?;
174
175        guard.access_token = builder.access_token.clone();
176        guard.expired_at = builder.expired_at;
177
178        debug!("fresh access token: {:#?}", guard);
179
180        Ok(guard.access_token.clone())
181    }
182}