// wechat_minapp/client/stable_token.rs
use super::{
    Client, ClientInner,
    access_token::{AccessToken, AccessTokenBuilder, is_token_expired},
};
use crate::{Result, constants, error::Error::InternalServer, response::Response};
use async_trait::async_trait;
use chrono::Utc;
use std::{
    collections::HashMap,
    sync::{
        Arc,
        atomic::{AtomicBool, Ordering},
    },
};
use tokio::sync::{Notify, RwLock};
use tracing::{debug, instrument};
18#[derive(Debug, Clone)]
19pub struct StableTokenClient {
20 inner: Arc<ClientInner>,
21 access_token: Arc<RwLock<AccessToken>>,
22 refreshing: Arc<AtomicBool>,
23 notify: Arc<Notify>,
24 force_refresh: bool,
25}
26
27#[async_trait]
28impl Client for StableTokenClient {
29 #[instrument(skip(self))]
30 async fn token(&self) -> Result<String> {
31 {
33 let guard = self.access_token.read().await;
34 if !is_token_expired(&guard) {
35 return Ok(guard.access_token.clone());
36 }
37 }
38
39 if self
41 .refreshing
42 .compare_exchange(false, true, Ordering::AcqRel, Ordering::Acquire)
43 .is_ok()
44 {
45 match self.refresh_access_token().await {
47 Ok(token) => {
48 self.refreshing.store(false, Ordering::Release);
49 self.notify.notify_waiters();
50 Ok(token)
51 }
52 Err(e) => {
53 self.refreshing.store(false, Ordering::Release);
54 self.notify.notify_waiters();
55 Err(e)
56 }
57 }
58 } else {
59 self.notify.notified().await;
61 let guard = self.access_token.read().await;
63 Ok(guard.access_token.clone())
64 }
65 }
66
67 fn inner_client(&self) -> &ClientInner {
68 &self.inner
69 }
70}
71
72impl StableTokenClient {
73 pub fn new(app_id: &str, secret: &str) -> Self {
92 StableTokenClient {
93 inner: Arc::new(ClientInner {
94 app_id: app_id.to_string(),
95 secret: secret.to_string(),
96 client: reqwest::Client::new(),
97 }),
98 access_token: Arc::new(RwLock::new(AccessToken {
99 access_token: String::new(),
100 expired_at: Utc::now(),
101 })),
102 refreshing: Arc::new(AtomicBool::new(false)),
103 notify: Arc::new(Notify::new()),
104 force_refresh: false,
105 }
106 }
107
108 pub fn with_force_refresh(mut self, force_refresh: bool) -> Self {
120 self.force_refresh = force_refresh;
121 self
122 }
123
124 async fn get_access_token(&self) -> Result<AccessTokenBuilder> {
127 let mut map: HashMap<&str, String> = HashMap::new();
128 let client = &self.inner.client;
129 let appid = &self.inner.app_id;
130 let secret = &self.inner.secret;
131 let force_refresh = self.force_refresh;
132 map.insert("grant_type", "client_credential".into());
133 map.insert("appid", appid.to_string());
134 map.insert("secret", secret.to_string());
135
136 if force_refresh {
137 debug!("force_refresh: {}", force_refresh);
138
139 map.insert("force_refresh", force_refresh.to_string());
140 }
141
142 let response = client
143 .post(constants::STABLE_ACCESS_TOKEN_END_POINT)
144 .json(&map)
145 .send()
146 .await?;
147
148 debug!("response: {:#?}", response);
149
150 if response.status().is_success() {
151 let response = response.json::<Response<AccessTokenBuilder>>().await?;
152
153 let builder = response.extract()?;
154
155 debug!("stable access token builder: {:#?}", builder);
156
157 Ok(builder)
158 } else {
159 Err(InternalServer(response.text().await?))
160 }
161 }
162
163 async fn refresh_access_token(&self) -> Result<String> {
164 let mut guard = self.access_token.write().await;
165
166 if !is_token_expired(&guard) {
167 debug!("token already refreshed by another thread");
168 return Ok(guard.access_token.clone());
169 }
170
171 debug!("performing network request to refresh token");
172
173 let builder = self.get_access_token().await?;
174
175 guard.access_token = builder.access_token.clone();
176 guard.expired_at = builder.expired_at;
177
178 debug!("fresh access token: {:#?}", guard);
179
180 Ok(guard.access_token.clone())
181 }
182}