// wechat_minapp/client/non_stable_token.rs

1use super::{
2    Client, ClientInner,
3    access_token::{AccessToken, AccessTokenBuilder, is_token_expired},
4};
5use crate::{Result, constants, error::Error::InternalServer, response::Response};
6use async_trait::async_trait;
7use chrono::Utc;
8use std::{
9    collections::HashMap,
10    sync::{
11        Arc,
12        atomic::{AtomicBool, Ordering},
13    },
14};
15use tokio::sync::{Notify, RwLock};
16use tracing::{debug, instrument};
17
18#[derive(Debug, Clone)]
19pub struct NonStableTokenClient {
20    inner: Arc<ClientInner>,
21    access_token: Arc<RwLock<AccessToken>>,
22    refreshing: Arc<AtomicBool>,
23    notify: Arc<Notify>,
24}
25
26#[async_trait]
27impl Client for NonStableTokenClient {
28    #[instrument(skip(self))]
29    async fn token(&self) -> Result<String> {
30        // 第一次检查:快速路径
31        {
32            let guard = self.access_token.read().await;
33            if !is_token_expired(&guard) {
34                return Ok(guard.access_token.clone());
35            }
36        }
37
38        // 使用CAS竞争刷新权
39        if self
40            .refreshing
41            .compare_exchange(false, true, Ordering::AcqRel, Ordering::Acquire)
42            .is_ok()
43        {
44            // 获得刷新权
45            match self.refresh_access_token().await {
46                Ok(token) => {
47                    self.refreshing.store(false, Ordering::Release);
48                    self.notify.notify_waiters();
49                    Ok(token)
50                }
51                Err(e) => {
52                    self.refreshing.store(false, Ordering::Release);
53                    self.notify.notify_waiters();
54                    Err(e)
55                }
56            }
57        } else {
58            // 等待其他线程刷新完成
59            self.notify.notified().await;
60            // 刷新完成后重新读取
61            let guard = self.access_token.read().await;
62            Ok(guard.access_token.clone())
63        }
64    }
65
66    fn inner_client(&self) -> &ClientInner {
67        &self.inner
68    }
69}
70
71impl NonStableTokenClient {
72    /// 创建新的微信小程序客户端
73    ///
74    /// # 参数
75    ///
76    /// - `app_id`: 小程序 AppID
77    /// - `secret`: 小程序 AppSecret
78    ///
79    /// # 返回
80    ///
81    /// 新的 `StableTokenClient` 实例
82    ///
83    /// # 示例
84    ///
85    /// ```no_run
86    /// use wechat_minapp::client::NonStableTokenClient;
87    ///
88    /// let client = NonStableTokenClient::new("your_appid", "your_app_secret_here");
89    /// ```
90    pub fn new(app_id: &str, secret: &str) -> Self {
91        NonStableTokenClient {
92            inner: Arc::new(ClientInner {
93                app_id: app_id.to_string(),
94                secret: secret.to_string(),
95                client: reqwest::Client::new(),
96            }),
97            access_token: Arc::new(RwLock::new(AccessToken {
98                access_token: String::new(),
99                expired_at: Utc::now(),
100            })),
101            refreshing: Arc::new(AtomicBool::new(false)),
102            notify: Arc::new(Notify::new()),
103        }
104    }
105
106    /// 获取小程序全局唯一后台接口调用凭据(access_token)
107    /// https://developers.weixin.qq.com/miniprogram/dev/api-backend/open-api/access-token/auth.getAccessToken.html
108    async fn get_access_token(&self) -> Result<AccessTokenBuilder> {
109        let mut map: HashMap<&str, String> = HashMap::new();
110        let client = &self.inner.client;
111        let appid = &self.inner.app_id;
112        let secret = &self.inner.secret;
113        map.insert("grant_type", "client_credential".into());
114        map.insert("appid", appid.to_string());
115        map.insert("secret", secret.to_string());
116
117        let response = client
118            .post(constants::ACCESS_TOKEN_END_POINT)
119            .json(&map)
120            .send()
121            .await?;
122
123        debug!("response: {:#?}", response);
124
125        if response.status().is_success() {
126            let response = response.json::<Response<AccessTokenBuilder>>().await?;
127
128            let builder = response.extract()?;
129
130            debug!("stable access token builder: {:#?}", builder);
131
132            Ok(builder)
133        } else {
134            Err(InternalServer(response.text().await?))
135        }
136    }
137
138    async fn refresh_access_token(&self) -> Result<String> {
139        let mut guard = self.access_token.write().await;
140
141        if !is_token_expired(&guard) {
142            debug!("token already refreshed by another thread");
143            return Ok(guard.access_token.clone());
144        }
145
146        debug!("performing network request to refresh token");
147
148        let builder = self.get_access_token().await?;
149
150        guard.access_token = builder.access_token.clone();
151        guard.expired_at = builder.expired_at;
152
153        debug!("fresh access token: {:#?}", guard);
154
155        Ok(guard.access_token.clone())
156    }
157}