1use super::access_token::AccessToken;
3use super::auth_client::{AuthClient, Form};
4use super::{Auth, AuthError};
5use crate::config::Config;
6
7use async_trait::async_trait;
8use chrono::{DateTime, Duration, Utc};
9use rand::distributions::Alphanumeric;
10use rand::Rng;
11use serde::Serialize;
12
/// OAuth 2.0 authenticator using the authorization-code grant, with refresh-token renewal.
///
/// The struct derives `Serialize`/`Deserialize`; its serialized JSON is what the `store`
/// callback receives (presumably so callers can persist and later restore it — confirm with
/// callers). `client` and `store` are `#[serde(skip)]`, so they are omitted from the JSON and
/// take their `Default` values when deserialized.
///
/// NOTE(review): `client_secret` is not skipped, so it is part of the JSON handed to `store`;
/// treat that payload as sensitive.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct OAuth {
    /// Endpoint configuration; supplies the authorize URL, token API URL, base API URL and
    /// user agent used by this authenticator.
    pub config: Config,
    /// OAuth client ID of the application.
    client_id: String,
    /// OAuth client secret of the application.
    client_secret: String,
    /// Most recently obtained, refreshed, or explicitly set token.
    access_token: AccessToken,
    /// Instant after which `is_expired()` reports `true`.
    expires_by: DateTime<Utc>,

    /// HTTP client used for token requests; never serialized.
    #[serde(skip)]
    client: AuthClient,

    /// Optional callback invoked with the JSON-serialized struct whenever the token changes.
    #[serde(skip)]
    store: Option<fn(String)>,
}
28
29impl OAuth {
30 pub fn new(
31 config: Config,
32 client_id: String,
33 client_secret: String,
34 store: Option<fn(String)>,
35 ) -> Self {
36 OAuth {
37 config,
38 client_id,
39 client_secret,
40 access_token: AccessToken::new(),
41 expires_by: Utc::now(),
42 client: AuthClient::default(),
43 store,
44 }
45 }
46
47 pub fn is_expired(&self) -> bool {
48 Utc::now() > self.expires_by
49 }
50
51 pub fn authorization_url(
52 &self,
53 redirect_url: Option<String>,
54 scope: Option<String>, state: Option<String>,
56 ) -> Result<(String, String), AuthError> {
57 let url = self.config.oauth2_authorize_url.clone();
58 let url = url + "?client_id=" + &self.client_id;
59 let url = url + "&response_type=code";
60
61 let url = match scope {
62 Some(scope) => url + "&scope=" + scope.as_str(),
63 None => url,
64 };
65
66 let local_state = match state {
67 Some(state) => state,
68 None => "box_csrf_token_".to_string() + &generate_state(16),
69 };
70 let url = url + "&state=" + local_state.as_str();
71
72 let url = match redirect_url {
73 Some(redirect_url) => url + "&redirect_uri=" + &urlencode(redirect_url.as_str()),
74 None => url,
75 };
76
77 Ok((url, local_state))
78 }
79
80 pub async fn request_access_token(
81 &mut self,
82 code: String,
85 ) -> Result<AccessToken, AuthError> {
86 let url = self.config.oauth2_api_url.clone() + "/token";
87
88 let headers = None; let mut payload = Form::new();
91 payload.insert("client_id", &self.client_id);
92 payload.insert("client_secret", &self.client_secret);
93 payload.insert("grant_type", "authorization_code");
94 payload.insert("code", &code);
95
96 let now = Utc::now();
97
98 let response = self.client.post_form(&url, headers, &payload).await;
99
100 let data = match response {
101 Ok(data) => data,
102 Err(e) => return Err(e),
103 };
104
105 let access_token = match serde_json::from_str::<AccessToken>(&data) {
106 Ok(access_token) => access_token,
107 Err(e) => {
108 return Err(AuthError::Serde(e));
109 }
110 };
111 let expires_in = access_token.expires_in.unwrap_or_default();
112 self.expires_by = now + Duration::seconds(expires_in);
113 self.access_token = access_token.clone();
114 match self.store {
115 Some(store) => {
116 let json_access_token = self.to_json().await?;
117 store(json_access_token);
118 }
119 None => return Ok(access_token),
120 };
121 Ok(access_token)
122 }
123
124 pub fn set_access_token(&mut self, access_token: AccessToken) -> Result<(), AuthError> {
125 self.access_token = access_token;
126 match self.store {
127 Some(store) => {
128 let json_access_token = serde_json::to_string(&self)?;
129 store(json_access_token);
130 }
131 None => return Ok(()),
132 };
133 Ok(())
134 }
135
136 async fn refresh_access_token(&mut self) -> Result<AccessToken, AuthError> {
137 let url = self.config.oauth2_api_url.clone() + "/token";
138
139 let refresh_token = self.access_token.refresh_token.clone().unwrap_or_default();
140
141 let headers = None; let mut payload = Form::new();
144 payload.insert("grant_type", "client_credentials");
145 payload.insert("client_id", &self.client_id);
146 payload.insert("client_secret", &self.client_secret);
147 payload.insert("grant_type", "refresh_token");
148 payload.insert("refresh_token", &refresh_token);
149
150 let now = Utc::now();
151
152 let response = self.client.post_form(&url, headers, &payload).await;
153
154 let data = match response {
155 Ok(data) => data,
156 Err(e) => return Err(e),
157 };
158
159 let access_token = match serde_json::from_str::<AccessToken>(&data) {
160 Ok(access_token) => access_token,
161 Err(e) => {
162 return Err(AuthError::Serde(e));
163 }
164 };
165 let expires_in = access_token.expires_in.unwrap_or_default();
166 self.expires_by = now + Duration::seconds(expires_in);
167 self.access_token = access_token.clone();
168 match self.store {
169 Some(store) => {
170 let json_access_token = self.to_json().await?;
171 store(json_access_token);
172 }
173 None => return Ok(access_token),
174 };
175 Ok(access_token)
176 }
177}
178
179#[async_trait]
180impl<'a> Auth<'a> for OAuth {
181 async fn access_token(&mut self) -> Result<String, AuthError> {
182 if self.is_expired() {
183 match self.refresh_access_token().await {
184 Ok(access_token) => Ok(access_token.access_token.unwrap_or_default()),
185 Err(e) => Err(e),
186 }
187 } else {
188 let access_token = match self.access_token.access_token.clone() {
189 Some(token) => token,
190 None => return Err(AuthError::Generic("CCG token is not set".to_owned())),
191 };
192 Ok(access_token)
193 }
194 }
195
196 async fn to_json(&mut self) -> Result<String, AuthError> {
197 self.access_token().await?;
198 match serde_json::to_string(&self) {
199 Ok(json) => Ok(json),
200 Err(e) => Err(AuthError::Serde(e)),
201 }
202 }
203
204 fn base_api_url(&self) -> String {
205 self.config.base_api_url()
206 }
207
208 fn user_agent(&self) -> String {
209 self.config.user_agent()
210 }
211}
212
213fn urlencode<T: AsRef<str>>(s: T) -> String {
214 ::url::form_urlencoded::byte_serialize(s.as_ref().as_bytes()).collect()
215}
216
217fn generate_state(length: u8) -> String {
218 rand::thread_rng()
219 .sample_iter(&Alphanumeric)
220 .take(length.into())
221 .map(char::from)
222 .collect()
223}
224
#[cfg(test)]
fn store(json_access_token: String) {
    // Test callback: the persisted JSON must be non-empty and carry both tokens.
    assert!(!json_access_token.is_empty());
    for marker in ["ACCESS_TOKEN", "REFRESH_TOKEN"] {
        assert!(json_access_token.contains(marker));
    }
}
#[test]
fn test_generate_state() {
    // The generated state must be exactly as long as requested.
    let requested: u8 = 16;
    assert_eq!(generate_state(requested).len(), usize::from(requested));
}
237
#[test]
fn test_urlencode() {
    // Reserved characters (`:` and `/`) are percent-encoded; the rest passes through.
    assert_eq!(urlencode("https://example.com"), "https%3A%2F%2Fexample.com");
}
244
#[test]
fn test_authorization_url_default() {
    let auth = OAuth::new(
        Config::new(),
        "client_id".to_owned(),
        "client_secret".to_owned(),
        None,
    );

    // Fail loudly on `Err` instead of asserting against `unwrap_or_default()`'s empty strings.
    let (auth_url, state) = match auth.authorization_url(None, None, None) {
        Ok(pair) => pair,
        Err(_) => panic!("authorization_url failed"),
    };

    assert!(auth_url.contains("client_id=client_id"));
    assert!(auth_url.contains("response_type=code"));
    assert!(auth_url.contains("state=box_csrf_token_"));
    assert!(state.starts_with("box_csrf_token_"));
    // The returned state must be the one embedded in the URL.
    assert!(auth_url.contains(&format!("state={}", state)));

    assert!(!auth_url.contains("scope="));
    assert!(!auth_url.contains("redirect_uri="));
}
267
#[test]
fn test_authorization_url_state() {
    let auth = OAuth::new(
        Config::new(),
        "client_id".to_owned(),
        "client_secret".to_owned(),
        None,
    );

    // Fail loudly on `Err` instead of asserting against `unwrap_or_default()`'s empty strings.
    let (auth_url, state) =
        match auth.authorization_url(None, None, Some("ABCDEF".to_string())) {
            Ok(pair) => pair,
            Err(_) => panic!("authorization_url failed"),
        };

    assert!(auth_url.contains("client_id=client_id"));
    assert!(auth_url.contains("response_type=code"));
    // A caller-supplied state is used verbatim.
    assert!(auth_url.contains("state=ABCDEF"));
    assert_eq!(state, "ABCDEF");

    assert!(!auth_url.contains("scope="));
    assert!(!auth_url.contains("redirect_uri="));
}
292
#[test]
fn test_authorization_url_redirect() {
    let auth = OAuth::new(
        Config::new(),
        "client_id".to_owned(),
        "client_secret".to_owned(),
        None,
    );

    // Fail loudly on `Err` instead of asserting against `unwrap_or_default()`'s empty strings.
    let (auth_url, state) =
        match auth.authorization_url(Some("https://example.com".to_string()), None, None) {
            Ok(pair) => pair,
            Err(_) => panic!("authorization_url failed"),
        };

    // The redirect URI must appear percent-encoded.
    let encoded_redirect = "redirect_uri=".to_string() + &urlencode("https://example.com");

    assert!(auth_url.contains("client_id=client_id"));
    assert!(auth_url.contains("response_type=code"));
    assert!(auth_url.contains("state=box_csrf_token_"));
    assert!(state.starts_with("box_csrf_token_"));
    assert!(auth_url.contains(&encoded_redirect));

    assert!(!auth_url.contains("scope="));
}
#[test]
fn test_authorization_url_scope() {
    let auth = OAuth::new(
        Config::new(),
        "client_id".to_owned(),
        "client_secret".to_owned(),
        None,
    );

    // Fail loudly on `Err` instead of asserting against `unwrap_or_default()`'s empty strings.
    let (auth_url, state) =
        match auth.authorization_url(None, Some("admin_readwrite".to_string()), None) {
            Ok(pair) => pair,
            Err(_) => panic!("authorization_url failed"),
        };

    assert!(auth_url.contains("client_id=client_id"));
    assert!(auth_url.contains("response_type=code"));
    assert!(auth_url.contains("state=box_csrf_token_"));
    assert!(state.starts_with("box_csrf_token_"));
    assert!(auth_url.contains("scope=admin_readwrite"));

    // No redirect was requested, so none may be emitted.
    assert!(!auth_url.contains("redirect_uri="));
}
342
#[test]
fn test_oauth_new() {
    let auth = OAuth::new(
        Config::new(),
        "client_id".to_owned(),
        "client_secret".to_owned(),
        None,
    );

    // The constructor must keep the credentials exactly as given.
    assert_eq!(auth.client_id, "client_id");
    assert_eq!(auth.client_secret, "client_secret");
}
353
#[test]
fn test_oauth_set_access_token() {
    let mut auth = OAuth::new(
        Config::new(),
        "client_id".to_owned(),
        "client_secret".to_owned(),
        None,
    );
    let expected = AccessToken {
        access_token: Some("access_token".to_owned()),
        refresh_token: Some("refresh_token".to_owned()),
        ..Default::default()
    };

    // Without a `store` callback, setting the token must succeed and replace it verbatim.
    if auth.set_access_token(expected.clone()).is_err() {
        panic!("Error setting access token");
    }
    assert_eq!(auth.access_token, expected);
}
373
#[test]
fn test_store() {
    // The `store` test callback asserts that the persisted JSON carries both tokens.
    let mut auth = OAuth::new(
        Config::new(),
        "client_id".to_owned(),
        "client_secret".to_owned(),
        Some(store),
    );
    let token = AccessToken {
        access_token: Some("ACCESS_TOKEN".to_string()),
        expires_in: Some(3333),
        token_type: Some(super::access_token::TokenType::Bearer),
        restricted_to: None,
        refresh_token: Some("REFRESH_TOKEN".to_string()),
        issued_token_type: None,
    };

    if auth.set_access_token(token).is_err() {
        panic!("Error setting access token");
    }
}