wechat_minapp/client/
non_stable_token.rs1use super::{
2 Client, ClientInner,
3 access_token::{AccessToken, AccessTokenBuilder, is_token_expired},
4};
5use crate::{Result, constants, error::Error::InternalServer, response::Response};
6use async_trait::async_trait;
7use chrono::Utc;
8use std::{
9 collections::HashMap,
10 sync::{
11 Arc,
12 atomic::{AtomicBool, Ordering},
13 },
14};
15use tokio::sync::{Notify, RwLock};
16use tracing::{debug, instrument};
17
18#[derive(Debug, Clone)]
19pub struct NonStableTokenClient {
20 inner: Arc<ClientInner>,
21 access_token: Arc<RwLock<AccessToken>>,
22 refreshing: Arc<AtomicBool>,
23 notify: Arc<Notify>,
24}
25
26#[async_trait]
27impl Client for NonStableTokenClient {
28 #[instrument(skip(self))]
29 async fn token(&self) -> Result<String> {
30 {
32 let guard = self.access_token.read().await;
33 if !is_token_expired(&guard) {
34 return Ok(guard.access_token.clone());
35 }
36 }
37
38 if self
40 .refreshing
41 .compare_exchange(false, true, Ordering::AcqRel, Ordering::Acquire)
42 .is_ok()
43 {
44 match self.refresh_access_token().await {
46 Ok(token) => {
47 self.refreshing.store(false, Ordering::Release);
48 self.notify.notify_waiters();
49 Ok(token)
50 }
51 Err(e) => {
52 self.refreshing.store(false, Ordering::Release);
53 self.notify.notify_waiters();
54 Err(e)
55 }
56 }
57 } else {
58 self.notify.notified().await;
60 let guard = self.access_token.read().await;
62 Ok(guard.access_token.clone())
63 }
64 }
65
66 fn inner_client(&self) -> &ClientInner {
67 &self.inner
68 }
69}
70
71impl NonStableTokenClient {
72 pub fn new(app_id: &str, secret: &str) -> Self {
91 NonStableTokenClient {
92 inner: Arc::new(ClientInner {
93 app_id: app_id.to_string(),
94 secret: secret.to_string(),
95 client: reqwest::Client::new(),
96 }),
97 access_token: Arc::new(RwLock::new(AccessToken {
98 access_token: String::new(),
99 expired_at: Utc::now(),
100 })),
101 refreshing: Arc::new(AtomicBool::new(false)),
102 notify: Arc::new(Notify::new()),
103 }
104 }
105
106 async fn get_access_token(&self) -> Result<AccessTokenBuilder> {
109 let mut map: HashMap<&str, String> = HashMap::new();
110 let client = &self.inner.client;
111 let appid = &self.inner.app_id;
112 let secret = &self.inner.secret;
113 map.insert("grant_type", "client_credential".into());
114 map.insert("appid", appid.to_string());
115 map.insert("secret", secret.to_string());
116
117 let response = client
118 .post(constants::ACCESS_TOKEN_END_POINT)
119 .json(&map)
120 .send()
121 .await?;
122
123 debug!("response: {:#?}", response);
124
125 if response.status().is_success() {
126 let response = response.json::<Response<AccessTokenBuilder>>().await?;
127
128 let builder = response.extract()?;
129
130 debug!("stable access token builder: {:#?}", builder);
131
132 Ok(builder)
133 } else {
134 Err(InternalServer(response.text().await?))
135 }
136 }
137
138 async fn refresh_access_token(&self) -> Result<String> {
139 let mut guard = self.access_token.write().await;
140
141 if !is_token_expired(&guard) {
142 debug!("token already refreshed by another thread");
143 return Ok(guard.access_token.clone());
144 }
145
146 debug!("performing network request to refresh token");
147
148 let builder = self.get_access_token().await?;
149
150 guard.access_token = builder.access_token.clone();
151 guard.expired_at = builder.expired_at;
152
153 debug!("fresh access token: {:#?}", guard);
154
155 Ok(guard.access_token.clone())
156 }
157}