rusty_box/auth/
auth_oauth.rs

//! OAuth 2.0 authorization-code authentication (consent URL, code exchange and token refresh).
use super::access_token::AccessToken;
use super::auth_client::{AuthClient, Form};
use super::{Auth, AuthError};
use crate::config::Config;

use async_trait::async_trait;
use chrono::{DateTime, Duration, Utc};
use rand::distributions::Alphanumeric;
use rand::Rng;
use serde::{Deserialize, Serialize};
12
13/// Client Credentials Grant (CCG) authentication
14#[derive(Debug, Clone, Serialize, Deserialize, Default)]
15pub struct OAuth {
16    pub config: Config,
17    client_id: String,
18    client_secret: String,
19    access_token: AccessToken,
20    expires_by: DateTime<Utc>,
21
22    #[serde(skip)]
23    client: AuthClient,
24
25    #[serde(skip)]
26    store: Option<fn(String)>,
27}
28
29impl OAuth {
30    pub fn new(
31        config: Config,
32        client_id: String,
33        client_secret: String,
34        store: Option<fn(String)>,
35    ) -> Self {
36        OAuth {
37            config,
38            client_id,
39            client_secret,
40            access_token: AccessToken::new(),
41            expires_by: Utc::now(),
42            client: AuthClient::default(),
43            store,
44        }
45    }
46
47    pub fn is_expired(&self) -> bool {
48        Utc::now() > self.expires_by
49    }
50
51    pub fn authorization_url(
52        &self,
53        redirect_url: Option<String>,
54        scope: Option<String>, // TODO: vector of strings?
55        state: Option<String>,
56    ) -> Result<(String, String), AuthError> {
57        let url = self.config.oauth2_authorize_url.clone();
58        let url = url + "?client_id=" + &self.client_id;
59        let url = url + "&response_type=code";
60
61        let url = match scope {
62            Some(scope) => url + "&scope=" + scope.as_str(),
63            None => url,
64        };
65
66        let local_state = match state {
67            Some(state) => state,
68            None => "box_csrf_token_".to_string() + &generate_state(16),
69        };
70        let url = url + "&state=" + local_state.as_str();
71
72        let url = match redirect_url {
73            Some(redirect_url) => url + "&redirect_uri=" + &urlencode(redirect_url.as_str()),
74            None => url,
75        };
76
77        Ok((url, local_state))
78    }
79
80    pub async fn request_access_token(
81        &mut self,
82        // client_id: String,
83        // client_secret: String,
84        code: String,
85    ) -> Result<AccessToken, AuthError> {
86        let url = self.config.oauth2_api_url.clone() + "/token";
87
88        let headers = None; // TODO: Add headers to rquest
89
90        let mut payload = Form::new();
91        payload.insert("client_id", &self.client_id);
92        payload.insert("client_secret", &self.client_secret);
93        payload.insert("grant_type", "authorization_code");
94        payload.insert("code", &code);
95
96        let now = Utc::now();
97
98        let response = self.client.post_form(&url, headers, &payload).await;
99
100        let data = match response {
101            Ok(data) => data,
102            Err(e) => return Err(e),
103        };
104
105        let access_token = match serde_json::from_str::<AccessToken>(&data) {
106            Ok(access_token) => access_token,
107            Err(e) => {
108                return Err(AuthError::Serde(e));
109            }
110        };
111        let expires_in = access_token.expires_in.unwrap_or_default();
112        self.expires_by = now + Duration::seconds(expires_in);
113        self.access_token = access_token.clone();
114        match self.store {
115            Some(store) => {
116                let json_access_token = self.to_json().await?;
117                store(json_access_token);
118            }
119            None => return Ok(access_token),
120        };
121        Ok(access_token)
122    }
123
124    pub fn set_access_token(&mut self, access_token: AccessToken) -> Result<(), AuthError> {
125        self.access_token = access_token;
126        match self.store {
127            Some(store) => {
128                let json_access_token = serde_json::to_string(&self)?;
129                store(json_access_token);
130            }
131            None => return Ok(()),
132        };
133        Ok(())
134    }
135
136    async fn refresh_access_token(&mut self) -> Result<AccessToken, AuthError> {
137        let url = self.config.oauth2_api_url.clone() + "/token";
138
139        let refresh_token = self.access_token.refresh_token.clone().unwrap_or_default();
140
141        let headers = None; // TODO: Add headers to rquest
142
143        let mut payload = Form::new();
144        payload.insert("grant_type", "client_credentials");
145        payload.insert("client_id", &self.client_id);
146        payload.insert("client_secret", &self.client_secret);
147        payload.insert("grant_type", "refresh_token");
148        payload.insert("refresh_token", &refresh_token);
149
150        let now = Utc::now();
151
152        let response = self.client.post_form(&url, headers, &payload).await;
153
154        let data = match response {
155            Ok(data) => data,
156            Err(e) => return Err(e),
157        };
158
159        let access_token = match serde_json::from_str::<AccessToken>(&data) {
160            Ok(access_token) => access_token,
161            Err(e) => {
162                return Err(AuthError::Serde(e));
163            }
164        };
165        let expires_in = access_token.expires_in.unwrap_or_default();
166        self.expires_by = now + Duration::seconds(expires_in);
167        self.access_token = access_token.clone();
168        match self.store {
169            Some(store) => {
170                let json_access_token = self.to_json().await?;
171                store(json_access_token);
172            }
173            None => return Ok(access_token),
174        };
175        Ok(access_token)
176    }
177}
178
179#[async_trait]
180impl<'a> Auth<'a> for OAuth {
181    async fn access_token(&mut self) -> Result<String, AuthError> {
182        if self.is_expired() {
183            match self.refresh_access_token().await {
184                Ok(access_token) => Ok(access_token.access_token.unwrap_or_default()),
185                Err(e) => Err(e),
186            }
187        } else {
188            let access_token = match self.access_token.access_token.clone() {
189                Some(token) => token,
190                None => return Err(AuthError::Generic("CCG token is not set".to_owned())),
191            };
192            Ok(access_token)
193        }
194    }
195
196    async fn to_json(&mut self) -> Result<String, AuthError> {
197        self.access_token().await?;
198        match serde_json::to_string(&self) {
199            Ok(json) => Ok(json),
200            Err(e) => Err(AuthError::Serde(e)),
201        }
202    }
203
204    fn base_api_url(&self) -> String {
205        self.config.base_api_url()
206    }
207
208    fn user_agent(&self) -> String {
209        self.config.user_agent()
210    }
211}
212
213fn urlencode<T: AsRef<str>>(s: T) -> String {
214    ::url::form_urlencoded::byte_serialize(s.as_ref().as_bytes()).collect()
215}
216
217fn generate_state(length: u8) -> String {
218    rand::thread_rng()
219        .sample_iter(&Alphanumeric)
220        .take(length.into())
221        .map(char::from)
222        .collect()
223}
224
/// Test `store` callback: the JSON it receives must be non-empty and contain both
/// the access- and refresh-token values used by `test_store`.
#[cfg(test)]
fn store(json_access_token: String) {
    assert!(!json_access_token.is_empty());
    assert!(json_access_token.contains("ACCESS_TOKEN"));
    assert!(json_access_token.contains("REFRESH_TOKEN"));
}
/// `generate_state` yields exactly `length` ASCII alphanumeric characters.
#[test]
fn test_generate_state() {
    let state = generate_state(16);
    assert_eq!(state.len(), 16);
    assert!(state.chars().all(|c| c.is_ascii_alphanumeric()));

    // boundary: a zero length yields an empty string
    assert!(generate_state(0).is_empty());
}
237
/// Reserved URL characters (`:` and `/`) are percent-encoded by `urlencode`.
#[test]
fn test_urlencode() {
    let expected = "https%3A%2F%2Fexample.com";
    assert_eq!(urlencode("https://example.com"), expected);
}
244
/// Without optional arguments the URL carries only the mandatory parameters plus a
/// generated `box_csrf_token_` state that matches the returned one.
#[test]
fn test_authorization_url_default() {
    let auth = OAuth::new(
        Config::new(),
        "client_id".to_owned(),
        "client_secret".to_owned(),
        None,
    );

    // Fail loudly on an error instead of masking it as empty strings.
    let (auth_url, state) = match auth.authorization_url(None, None, None) {
        Ok(parts) => parts,
        Err(_) => panic!("authorization_url returned an error"),
    };

    // mandatory parameters
    assert!(auth_url.contains("client_id=client_id"));
    assert!(auth_url.contains("response_type=code"));
    assert!(auth_url.contains("state=box_csrf_token_"));
    assert!(state.starts_with("box_csrf_token_"));
    assert!(auth_url.contains(&format!("state={}", state)));

    // optional scope and redirect_uri were not requested, so they must be absent
    assert!(!auth_url.contains("scope="));
    assert!(!auth_url.contains("redirect_uri="));
}
267
/// A caller-supplied state is used verbatim in the URL and returned unchanged.
#[test]
fn test_authorization_url_state() {
    let auth = OAuth::new(
        Config::new(),
        "client_id".to_owned(),
        "client_secret".to_owned(),
        None,
    );

    // Fail loudly on an error instead of masking it as empty strings.
    let (auth_url, state) = match auth.authorization_url(None, None, Some("ABCDEF".to_string())) {
        Ok(parts) => parts,
        Err(_) => panic!("authorization_url returned an error"),
    };

    // mandatory parameters
    assert!(auth_url.contains("client_id=client_id"));
    assert!(auth_url.contains("response_type=code"));
    assert!(auth_url.contains("state=ABCDEF"));
    assert_eq!(state, "ABCDEF");

    // optional scope and redirect_uri were not requested, so they must be absent
    assert!(!auth_url.contains("scope="));
    assert!(!auth_url.contains("redirect_uri="));
}
292
/// A supplied redirect URL is appended, form-URL-encoded, as `redirect_uri`.
#[test]
fn test_authorization_url_redirect() {
    let auth = OAuth::new(
        Config::new(),
        "client_id".to_owned(),
        "client_secret".to_owned(),
        None,
    );

    // Fail loudly on an error instead of masking it as empty strings.
    let (auth_url, state) =
        match auth.authorization_url(Some("https://example.com".to_string()), None, None) {
            Ok(parts) => parts,
            Err(_) => panic!("authorization_url returned an error"),
        };

    let encoded_redirect = "redirect_uri=".to_string() + &urlencode("https://example.com");

    // mandatory parameters plus the encoded redirect
    assert!(auth_url.contains("client_id=client_id"));
    assert!(auth_url.contains("response_type=code"));
    assert!(auth_url.contains("state=box_csrf_token_"));
    assert!(state.starts_with("box_csrf_token_"));
    assert!(auth_url.contains(&encoded_redirect));

    // the optional scope was not requested, so it must be absent
    assert!(!auth_url.contains("scope="));
}
/// A supplied scope is appended as `scope`, without adding a redirect.
#[test]
fn test_authorization_url_scope() {
    let auth = OAuth::new(
        Config::new(),
        "client_id".to_owned(),
        "client_secret".to_owned(),
        None,
    );

    // Fail loudly on an error instead of masking it as empty strings.
    let (auth_url, state) =
        match auth.authorization_url(None, Some("admin_readwrite".to_string()), None) {
            Ok(parts) => parts,
            Err(_) => panic!("authorization_url returned an error"),
        };

    // mandatory parameters plus the scope
    assert!(auth_url.contains("client_id=client_id"));
    assert!(auth_url.contains("response_type=code"));
    assert!(auth_url.contains("state=box_csrf_token_"));
    assert!(state.starts_with("box_csrf_token_"));
    assert!(auth_url.contains("scope=admin_readwrite"));

    // the optional redirect_uri was not requested, so it must be absent
    assert!(!auth_url.contains("redirect_uri="));
}
342
/// `OAuth::new` keeps the client credentials it is given.
#[test]
fn test_oauth_new() {
    let auth = OAuth::new(
        Config::new(),
        String::from("client_id"),
        String::from("client_secret"),
        None,
    );

    assert_eq!(auth.client_id, "client_id");
    assert_eq!(auth.client_secret, "client_secret");
}
353
/// `set_access_token` replaces the held token when no `store` callback is configured.
#[test]
fn test_oauth_set_access_token() {
    let mut auth = OAuth::new(
        Config::new(),
        "client_id".to_owned(),
        "client_secret".to_owned(),
        None,
    );
    let expected = AccessToken {
        access_token: Some("access_token".to_owned()),
        refresh_token: Some("refresh_token".to_owned()),
        ..Default::default()
    };

    if auth.set_access_token(expected.clone()).is_err() {
        panic!("Error setting access token");
    }
    assert_eq!(auth.access_token, expected);
}
373
/// A configured `store` callback receives JSON containing both token values
/// (the checks themselves live in the `store` test helper).
#[test]
fn test_store() {
    let mut auth = OAuth::new(
        Config::new(),
        "client_id".to_owned(),
        "client_secret".to_owned(),
        Some(store),
    );
    let token = AccessToken {
        access_token: Some("ACCESS_TOKEN".to_string()),
        expires_in: Some(3333),
        token_type: Some(super::access_token::TokenType::Bearer),
        restricted_to: None,
        refresh_token: Some("REFRESH_TOKEN".to_string()),
        issued_token_type: None,
    };

    if auth.set_access_token(token).is_err() {
        panic!("Error setting access token");
    }
}