rusty_box/auth/
auth_ccg.rs

1//! Client Credentials Grant (CCG) authentication
2use super::access_token::AccessToken;
3use super::auth_client::{AuthClient, Form};
4use super::{Auth, AuthError};
5use crate::config::Config;
6
7use async_trait::async_trait;
8use chrono::{DateTime, Duration, Utc};
9use serde::Serialize;
10
11/// The type of subject that is being authenticated (user or enterprise)
12#[derive(Debug, Clone, Serialize, PartialEq)]
13pub enum SubjectType {
14    Enterprise,
15    User,
16}
17impl SubjectType {
18    fn value(&self) -> String {
19        match self {
20            Self::Enterprise => "enterprise".to_owned(),
21            Self::User => "user".to_owned(),
22        }
23    }
24}
25impl Default for SubjectType {
26    fn default() -> SubjectType {
27        Self::Enterprise
28    }
29}
30
31/// Client Credentials Grant (CCG) authentication
32#[derive(Debug, Clone, Serialize, Default)]
33pub struct CCGAuth {
34    pub config: Config,
35    client_id: String,
36    client_secret: String,
37    box_subject_type: SubjectType,
38    box_subject_id: String,
39    access_token: AccessToken,
40    expires_by: DateTime<Utc>,
41    #[serde(skip)]
42    client: AuthClient,
43}
44
45impl CCGAuth {
46    pub fn new(
47        config: Config,
48        client_id: String,
49        client_secret: String,
50        box_subject_type: SubjectType,
51        box_subject_id: String,
52    ) -> Self {
53        CCGAuth {
54            config,
55            client_id,
56            client_secret,
57            box_subject_type,
58            box_subject_id,
59            access_token: AccessToken::new(),
60            expires_by: Utc::now(),
61            client: AuthClient::default(),
62        }
63    }
64
65    pub fn is_expired(&self) -> bool {
66        Utc::now() > self.expires_by
67    }
68
69    async fn fetch_access_token(&mut self) -> Result<AccessToken, AuthError> {
70        let url = &(self.config.oauth2_api_url.clone() + "/token");
71
72        let headers = None; // TODO: Add headers to rquest
73
74        let box_subject_type = self.box_subject_type.value();
75
76        let mut payload = Form::new();
77        payload.insert("grant_type", "client_credentials");
78        payload.insert("client_id", &self.client_id);
79        payload.insert("client_secret", &self.client_secret);
80        payload.insert("box_subject_type", &box_subject_type);
81        payload.insert("box_subject_id", &self.box_subject_id);
82
83        let now = Utc::now();
84
85        let response = self.client.post_form(url, headers, &payload).await;
86
87        let data = match response {
88            Ok(data) => data,
89            Err(e) => return Err(e),
90        };
91
92        let access_token = match serde_json::from_str::<AccessToken>(&data) {
93            Ok(access_token) => access_token,
94            Err(e) => {
95                return Err(AuthError::Serde(e));
96            }
97        };
98        let expires_in = access_token.expires_in.unwrap_or_default();
99        self.expires_by = now + Duration::seconds(expires_in);
100        self.access_token = access_token.clone();
101        Ok(access_token)
102    }
103}
104
105#[async_trait]
106impl<'a> Auth<'a> for CCGAuth {
107    async fn access_token(&mut self) -> Result<String, AuthError> {
108        if self.is_expired() {
109            match self.fetch_access_token().await {
110                Ok(access_token) => Ok(access_token.access_token.unwrap_or_default()),
111                Err(e) => Err(e),
112            }
113        } else {
114            let access_token = match self.access_token.access_token.clone() {
115                Some(token) => token,
116                None => return Err(AuthError::Generic("CCG token is not set".to_owned())),
117            };
118            Ok(access_token)
119        }
120    }
121
122    async fn to_json(&mut self) -> Result<String, AuthError> {
123        self.access_token().await?;
124        match serde_json::to_string(&self) {
125            Ok(json) => Ok(json),
126            Err(e) => Err(AuthError::Serde(e)),
127        }
128    }
129
130    fn base_api_url(&self) -> String {
131        self.config.base_api_url()
132    }
133
134    fn user_agent(&self) -> String {
135        self.config.user_agent()
136    }
137}
138
#[cfg(test)]
mod tests {
    use super::*;
    use std::env;

    /// `new` stores the credentials and subject exactly as given.
    #[tokio::test]
    async fn test_ccg_new() {
        let ccg_auth = CCGAuth::new(
            Config::new(),
            "client_id".to_owned(),
            "client_secret".to_owned(),
            SubjectType::Enterprise,
            "box_subject_id".to_owned(),
        );

        assert_eq!(ccg_auth.client_id, "client_id");
        assert_eq!(ccg_auth.client_secret, "client_secret");
        assert_eq!(ccg_auth.box_subject_type, SubjectType::Enterprise);
        assert_eq!(ccg_auth.box_subject_id, "box_subject_id");
    }

    /// Live request against the token endpoint.
    ///
    /// Requires network access and the variables `CLIENT_ID`, `CLIENT_SECRET`,
    /// `BOX_SUBJECT_TYPE` (`user` or `enterprise`) and `BOX_SUBJECT_ID`. They
    /// can be set in the environment or loaded from a `.ccg.env` file.
    #[tokio::test]
    async fn test_ccg_request() {
        // Missing file is fine: the variables may come from the environment.
        dotenv::from_filename(".ccg.env").ok();
        let client_id = env::var("CLIENT_ID").expect("CLIENT_ID must be set");
        let client_secret = env::var("CLIENT_SECRET").expect("CLIENT_SECRET must be set");
        let env_subject_type =
            env::var("BOX_SUBJECT_TYPE").expect("BOX_SUBJECT_TYPE must be set");
        let box_subject_type = match env_subject_type.as_str() {
            "user" => SubjectType::User,
            "enterprise" => SubjectType::Enterprise,
            _ => panic!("BOX_SUBJECT_TYPE must be either 'user' or 'enterprise'"),
        };
        let box_subject_id = env::var("BOX_SUBJECT_ID").expect("BOX_SUBJECT_ID must be set");

        let mut auth = CCGAuth::new(
            Config::new(),
            client_id,
            client_secret,
            box_subject_type,
            box_subject_id,
        );

        // `expect` surfaces the actual `AuthError` on failure.
        let access_token = auth
            .access_token()
            .await
            .expect("CCG token request failed");

        assert!(!auth.is_expired());
        assert_eq!(
            auth.access_token.access_token.as_deref(),
            Some(access_token.as_str())
        );
    }
}
196}