rusty_box/auth/
auth_jwt.rs

1//! JSON Web Token (JWT) authentication
2use super::access_token::AccessToken;
3use super::auth_client::{AuthClient, Form};
4use super::{Auth, AuthError};
5use crate::config::Config;
6
7use async_trait::async_trait;
8use chrono::{DateTime, Duration, Utc};
9use josekit::{
10    jws::RS512,
11    jwt::{self, JwtPayload},
12};
13use openssl::pkey::PKey;
14use serde::Serialize;
15use serde_json::json;
16
17/// The type of subject that is being authenticated (user or enterprise)
18#[derive(Debug, Clone, Serialize, PartialEq)]
19pub enum SubjectType {
20    Enterprise,
21    User,
22}
23impl SubjectType {
24    fn value(&self) -> String {
25        match self {
26            Self::Enterprise => "enterprise".to_owned(),
27            Self::User => "user".to_owned(),
28        }
29    }
30}
31impl Default for SubjectType {
32    fn default() -> SubjectType {
33        Self::Enterprise
34    }
35}
36
37#[derive(Debug, Clone, Serialize, Default)]
38pub struct JWTAuth {
39    pub config: Config,
40    client_id: String,
41    client_secret: String,
42
43    box_subject_type: SubjectType,
44    box_subject_id: String,
45
46    public_key_id: String,
47    #[serde(skip)]
48    private_key: String,
49    #[serde(skip)]
50    passphrase: String,
51    access_token: AccessToken,
52    expires_by: DateTime<Utc>,
53    #[serde(skip)]
54    client: AuthClient,
55}
56
57impl JWTAuth {
58    #[allow(clippy::too_many_arguments)]
59    pub fn new(
60        config: Config,
61        client_id: String,
62        client_secret: String,
63
64        box_subject_type: SubjectType,
65        box_subject_id: String,
66
67        public_key_id: String,
68        private_key: String,
69        passphrase: String,
70    ) -> Self {
71        JWTAuth {
72            config,
73            client_id,
74            client_secret,
75
76            box_subject_type,
77            box_subject_id,
78
79            public_key_id,
80            private_key,
81            passphrase,
82
83            access_token: AccessToken::new(),
84            expires_by: Utc::now(),
85            client: AuthClient::default(),
86        }
87    }
88
89    pub fn is_expired(&self) -> bool {
90        Utc::now() > self.expires_by - Duration::seconds(60 * 5)
91    }
92
93    async fn fetch_access_token(&mut self) -> Result<AccessToken, AuthError> {
94        let url = &(self.config.oauth2_api_url.clone() + "/token");
95
96        let headers = None; // TODO: Add headers to rquest
97
98        let jwt_token = jwt_assertion(self.clone())?;
99
100        let mut payload = Form::new();
101        payload.insert("grant_type", "urn:ietf:params:oauth:grant-type:jwt-bearer");
102        payload.insert("client_id", &self.client_id);
103        payload.insert("client_secret", &self.client_secret);
104
105        payload.insert("assertion", &jwt_token);
106
107        let now = Utc::now();
108
109        let response = self.client.post_form(url, headers, &payload).await;
110
111        let data = match response {
112            Ok(data) => data,
113            Err(e) => return Err(e),
114        };
115
116        let access_token = match serde_json::from_str::<AccessToken>(&data) {
117            Ok(access_token) => access_token,
118            Err(e) => {
119                return Err(AuthError::Serde(e));
120            }
121        };
122        let expires_in = access_token.expires_in.unwrap_or_default();
123        self.expires_by = now + Duration::seconds(expires_in);
124        self.access_token = access_token.clone();
125        Ok(access_token)
126    }
127}
128
129fn jwt_assertion(jwt_aut: JWTAuth) -> Result<String, AuthError> {
130    // JWT Header
131    let mut header = josekit::jws::JwsHeader::new();
132    header.set_token_type("JWT");
133
134    header.set_key_id(jwt_aut.public_key_id);
135
136    // JWT Payload
137    let mut payload = JwtPayload::new();
138
139    payload.set_issuer(jwt_aut.client_id);
140
141    payload.set_subject(jwt_aut.box_subject_id);
142
143    let box_subject_type = Some(json!(jwt_aut.box_subject_type.value()));
144    payload.set_claim("box_sub_type", box_subject_type)?;
145
146    let audience = vec![jwt_aut.config.oauth2_api_url + "/token"];
147    payload.set_audience(audience);
148
149    let jwt_id = uuid::Uuid::new_v4().to_string();
150    payload.set_jwt_id(jwt_id);
151
152    let expires_at = std::time::SystemTime::now() + std::time::Duration::from_secs(59);
153    payload.set_expires_at(&expires_at);
154
155    // decrupt private key
156    let private_key = PKey::private_key_from_pem_passphrase(
157        jwt_aut.private_key.as_bytes(),
158        jwt_aut.passphrase.as_bytes(),
159    )?;
160
161    let private_key_pem = private_key.private_key_to_pem_pkcs8()?;
162
163    let signer = RS512.signer_from_pem(private_key_pem)?;
164    let jwt = jwt::encode_with_signer(&payload, &header, &signer)?;
165
166    Ok(jwt)
167}
168
169#[async_trait]
170impl<'a> Auth<'a> for JWTAuth {
171    async fn access_token(&mut self) -> Result<String, AuthError> {
172        if self.is_expired() {
173            match self.fetch_access_token().await {
174                Ok(access_token) => Ok(access_token.access_token.unwrap_or_default()),
175                Err(e) => Err(e),
176            }
177        } else {
178            let access_token = match self.access_token.access_token.clone() {
179                Some(token) => token,
180                None => return Err(AuthError::Generic("CCG token is not set".to_owned())),
181            };
182            Ok(access_token)
183        }
184    }
185
186    async fn to_json(&mut self) -> Result<String, AuthError> {
187        self.access_token().await?;
188        match serde_json::to_string(&self) {
189            Ok(json) => Ok(json),
190            Err(e) => Err(AuthError::Serde(e)),
191        }
192    }
193
194    fn base_api_url(&self) -> String {
195        self.config.base_api_url()
196    }
197
198    fn user_agent(&self) -> String {
199        self.config.user_agent()
200    }
201}
202
203#[cfg(test)]
204#[path = "./unit_tests/auth_jwt_test.rs"]
205mod auth_jwt_test;