1use crate::{
2 Result,
3 access_token::{AccessToken, get_access_token, get_stable_access_token},
4 constants,
5 credential::{Credential, CredentialBuilder},
6 error::Error::InternalServer,
7 response::Response,
8};
9use chrono::{Duration, Utc};
10use std::{
11 collections::HashMap,
12 sync::{
13 Arc,
14 atomic::{AtomicBool, Ordering},
15 },
16};
17use tokio::sync::{Notify, RwLock};
18use tracing::{debug, instrument};
19
/// Client handle bundling the application credentials, a shared HTTP client
/// and a cached access token.
///
/// Every field sits behind an `Arc`, so the derived `Clone` yields handles
/// that share the same credentials, token cache and refresh state.
#[derive(Debug, Clone)]
pub struct Client {
    /// App id, secret and the underlying `reqwest::Client`.
    inner: Arc<ClientInner>,
    /// Cached access token, guarded by an async read/write lock.
    access_token: Arc<RwLock<AccessToken>>,
    /// Set (via compare-exchange) by the caller that is currently refreshing
    /// the token; cleared again once that refresh finishes.
    refreshing: Arc<AtomicBool>,
    /// Wakes callers that are waiting for an in-flight refresh to finish.
    notify: Arc<Notify>,
}
28
29impl Client {
30 pub fn new(app_id: &str, secret: &str) -> Self {
44 let client = reqwest::Client::new();
45
46 Self {
47 inner: Arc::new(ClientInner {
48 app_id: app_id.into(),
49 secret: secret.into(),
50 client,
51 }),
52 access_token: Arc::new(RwLock::new(AccessToken {
53 access_token: "".to_string(),
54 expired_at: Utc::now(),
55 force_refresh: None,
56 })),
57 refreshing: Arc::new(AtomicBool::new(false)),
58 notify: Arc::new(Notify::new()),
59 }
60 }
61
/// Returns a reference to the `reqwest::Client` stored in `inner`, for use
/// by other modules of this crate.
pub(crate) fn request(&self) -> &reqwest::Client {
    &self.inner.client
}
65
66 #[instrument(skip(self, code))]
87 pub async fn login(&self, code: &str) -> Result<Credential> {
88 debug!("code: {}", code);
89
90 let mut map: HashMap<&str, &str> = HashMap::new();
91
92 map.insert("appid", &self.inner.app_id);
93 map.insert("secret", &self.inner.secret);
94 map.insert("js_code", code);
95 map.insert("grant_type", "authorization_code");
96
97 let response = self
98 .inner
99 .client
100 .get(constants::AUTHENTICATION_END_POINT)
101 .query(&map)
102 .send()
103 .await?;
104
105 debug!("authentication response: {:#?}", response);
106
107 if response.status().is_success() {
108 let response = response.json::<Response<CredentialBuilder>>().await?;
109
110 let credential = response.extract()?.build();
111
112 debug!("credential: {:#?}", credential);
113
114 Ok(credential)
115 } else {
116 Err(InternalServer(response.text().await?))
117 }
118 }
119
120 pub async fn access_token(&self) -> Result<String> {
136 {
138 let guard = self.access_token.read().await;
139 if !is_token_expired(&guard) {
140 return Ok(guard.access_token.clone());
141 }
142 }
143
144 if self
146 .refreshing
147 .compare_exchange(false, true, Ordering::AcqRel, Ordering::Acquire)
148 .is_ok()
149 {
150 match self.refresh_access_token().await {
152 Ok(token) => {
153 self.refreshing.store(false, Ordering::Release);
154 self.notify.notify_waiters();
155 Ok(token)
156 }
157 Err(e) => {
158 self.refreshing.store(false, Ordering::Release);
159 self.notify.notify_waiters();
160 Err(e)
161 }
162 }
163 } else {
164 self.notify.notified().await;
166 let guard = self.access_token.read().await;
168 Ok(guard.access_token.clone())
169 }
170 }
171
172 async fn refresh_access_token(&self) -> Result<String> {
173 let mut guard = self.access_token.write().await;
174
175 if !is_token_expired(&guard) {
176 debug!("token already refreshed by another thread");
177 return Ok(guard.access_token.clone());
178 }
179
180 debug!("performing network request to refresh token");
181
182 let builder = get_access_token(
183 self.inner.client.clone(),
184 &self.inner.app_id,
185 &self.inner.secret,
186 )
187 .await?;
188
189 guard.access_token = builder.access_token.clone();
190 guard.expired_at = builder.expired_at;
191
192 debug!("fresh access token: {:#?}", guard);
193
194 Ok(guard.access_token.clone())
195 }
196
197 pub async fn stable_access_token(
214 &self,
215 force_refresh: impl Into<Option<bool>> + Clone + Send,
216 ) -> Result<String> {
217 {
219 let guard = self.access_token.read().await;
220 if !is_token_expired(&guard) {
221 return Ok(guard.access_token.clone());
222 }
223 }
224
225 if self
227 .refreshing
228 .compare_exchange(false, true, Ordering::AcqRel, Ordering::Acquire)
229 .is_ok()
230 {
231 match self.refresh_stable_access_token(force_refresh).await {
233 Ok(token) => {
234 self.refreshing.store(false, Ordering::Release);
235 self.notify.notify_waiters();
236 Ok(token)
237 }
238 Err(e) => {
239 self.refreshing.store(false, Ordering::Release);
240 self.notify.notify_waiters();
241 Err(e)
242 }
243 }
244 } else {
245 self.notify.notified().await;
247 let guard = self.access_token.read().await;
249 Ok(guard.access_token.clone())
250 }
251 }
252
253 async fn refresh_stable_access_token(
254 &self,
255 force_refresh: impl Into<Option<bool>> + Clone + Send,
256 ) -> Result<String> {
257 let mut guard = self.access_token.write().await;
259
260 if !is_token_expired(&guard) {
264 debug!("token already refreshed by another thread");
266 return Ok(guard.access_token.clone());
267 }
268
269 debug!("performing network request to refresh token");
271
272 let builder = get_stable_access_token(
273 self.inner.client.clone(),
274 &self.inner.app_id,
275 &self.inner.secret,
276 force_refresh,
277 )
278 .await?;
279
280 guard.access_token = builder.access_token.clone();
282 guard.expired_at = builder.expired_at;
283
284 debug!("fresh access token: {:#?}", guard);
285
286 Ok(guard.access_token.clone())
288 }
289}
290
/// Immutable state shared by every clone of `Client`.
///
/// NOTE(review): the derived `Debug` prints `secret` verbatim, and `Client`
/// derives `Debug` over this type; confirm that is acceptable wherever a
/// `Client` may end up in log output.
#[derive(Debug)]
struct ClientInner {
    /// Application id, sent as the `appid` query parameter on login.
    app_id: String,
    /// Application secret, sent as the `secret` query parameter on login.
    secret: String,
    /// HTTP client used for all outgoing requests.
    client: reqwest::Client,
}
297
298fn is_token_expired(token: &AccessToken) -> bool {
299 let now = Utc::now();
301 token.expired_at.signed_duration_since(now) < Duration::minutes(5)
302}