1use crate::{
2 Result,
3 access_token::{AccessToken, get_access_token, get_stable_access_token},
4 constants,
5 credential::{Credential, CredentialBuilder},
6 error::Error::InternalServer,
7 response::Response,
8};
9use chrono::Utc;
10use std::{
11 collections::HashMap,
12 sync::{
13 Arc,
14 atomic::{AtomicBool, Ordering},
15 },
16};
17use tokio::sync::{Notify, RwLock};
18use tracing::{Level, event, instrument};
19
/// Handle to the mini-program API for a single `app_id` / `secret` pair.
///
/// Every field sits behind an `Arc`, so `clone()` is cheap and all clones share
/// the same HTTP client, credentials and cached access token.
#[derive(Debug, Clone)]
pub struct Client {
    // App id, app secret and the underlying `reqwest::Client`.
    inner: Arc<ClientInner>,
    // Most recently fetched access token together with its `expired_at`.
    access_token: Arc<RwLock<AccessToken>>,
    // `true` while some task is fetching a new access token.
    refreshing: Arc<AtomicBool>,
    // Used to wake tasks that are waiting for an in-flight token refresh.
    notify: Arc<Notify>,
}
28
29impl Client {
30 pub fn new(app_id: &str, secret: &str) -> Self {
44 let client = reqwest::Client::new();
45
46 Self {
47 inner: Arc::new(ClientInner {
48 app_id: app_id.into(),
49 secret: secret.into(),
50 client,
51 }),
52 access_token: Arc::new(RwLock::new(AccessToken {
53 access_token: "".to_string(),
54 expired_at: Utc::now(),
55 force_refresh: None,
56 })),
57 refreshing: Arc::new(AtomicBool::new(false)),
58 notify: Arc::new(Notify::new()),
59 }
60 }
61
62 pub(crate) fn request(&self) -> &reqwest::Client {
63 &self.inner.client
64 }
65
66 #[instrument(skip(self, code))]
89 pub async fn login(&self, code: &str) -> Result<Credential> {
90 event!(Level::DEBUG, "code: {}", code);
91
92 let mut map: HashMap<&str, &str> = HashMap::new();
93
94 map.insert("appid", &self.inner.app_id);
95 map.insert("secret", &self.inner.secret);
96 map.insert("js_code", code);
97 map.insert("grant_type", "authorization_code");
98
99 let response = self
100 .inner
101 .client
102 .get(constants::AUTHENTICATION_END_POINT)
103 .query(&map)
104 .send()
105 .await?;
106
107 event!(Level::DEBUG, "authentication response: {:#?}", response);
108
109 if response.status().is_success() {
110 let response = response.json::<Response<CredentialBuilder>>().await?;
111
112 let credential = response.extract()?.build();
113
114 event!(Level::DEBUG, "credential: {:#?}", credential);
115
116 Ok(credential)
117 } else {
118 Err(InternalServer(response.text().await?))
119 }
120 }
121
122 pub async fn access_token(&self) -> Result<String> {
138 let guard = self.access_token.read().await;
139 event!(Level::DEBUG, "expired at: {}", guard.expired_at);
140
141 if self.refreshing.load(Ordering::Acquire) {
142 event!(Level::DEBUG, "refreshing");
143
144 self.notify.notified().await;
145 } else {
146 event!(Level::DEBUG, "prepare to fresh");
147
148 self.refreshing.store(true, Ordering::Release);
149
150 drop(guard);
151
152 event!(Level::DEBUG, "write access token guard");
153
154 let mut guard = self.access_token.write().await;
155
156 let builder = get_access_token(
157 self.inner.client.clone(),
158 &self.inner.app_id,
159 &self.inner.secret,
160 )
161 .await?;
162
163 guard.access_token = builder.access_token;
164 guard.expired_at = builder.expired_at;
165
166 self.refreshing.store(false, Ordering::Release);
167
168 self.notify.notify_waiters();
169
170 event!(Level::DEBUG, "fresh access token: {:#?}", guard);
171
172 return Ok(guard.access_token.clone());
173 }
174 Ok(guard.access_token.clone())
175 }
176
177
178 pub async fn stable_access_token(
195 &self,
196 force_refresh: impl Into<Option<bool>> + Clone + Send,
197 ) -> Result<String> {
198 let guard = self.access_token.read().await;
199 event!(Level::DEBUG, "expired at: {}", guard.expired_at);
200
201 if self.refreshing.load(Ordering::Acquire) {
202 event!(Level::DEBUG, "refreshing");
203
204 self.notify.notified().await;
205 } else {
206 event!(Level::DEBUG, "prepare to fresh");
207
208 self.refreshing.store(true, Ordering::Release);
209
210 drop(guard);
211
212 event!(Level::DEBUG, "write access token guard");
213
214 let mut guard = self.access_token.write().await;
215
216 let builder = get_stable_access_token(
217 self.inner.client.clone(),
218 &self.inner.app_id,
219 &self.inner.secret,
220 force_refresh,
221 )
222 .await?;
223
224 guard.access_token = builder.access_token;
225 guard.expired_at = builder.expired_at;
226
227 self.refreshing.store(false, Ordering::Release);
228
229 self.notify.notify_waiters();
230
231 event!(Level::DEBUG, "fresh access token: {:#?}", guard);
232
233 return Ok(guard.access_token.clone());
234 }
235 Ok(guard.access_token.clone())
236 }
237}
238
239#[derive(Debug)]
240struct ClientInner {
241 app_id: String,
242 secret: String,
243 client: reqwest::Client,
244}