1use crate::error::{Error, Result};
4use chrono::{DateTime, Duration, Utc};
5use oauth2::{
6 basic::BasicClient, AuthUrl, AuthorizationCode, ClientId, ClientSecret, CsrfToken,
7 PkceCodeChallenge, RedirectUrl, Scope, TokenResponse, TokenUrl,
8};
9use serde::{Deserialize, Serialize};
10
11#[cfg(feature = "tracing")]
12use tracing::{debug, error, info};
13
/// Client credentials and endpoint URLs for the OAuth2 authorization-code flow.
///
/// `Debug` is implemented by hand so that `client_secret` never ends up in log output.
#[derive(Clone)]
pub struct OAuth2Config {
    /// OAuth2 client identifier issued by the identity provider.
    pub client_id: String,

    /// OAuth2 client secret. Redacted from `Debug` output.
    pub client_secret: String,

    /// Redirect URI registered for the client; parsed and validated by
    /// [`OAuth2Handler::new`].
    pub redirect_uri: String,

    /// Authorization endpoint (the `Default` impl points at the Visma Online identity server).
    pub auth_url: String,

    /// Token endpoint (the `Default` impl points at the Visma Online identity server).
    pub token_url: String,
}

impl std::fmt::Debug for OAuth2Config {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("OAuth2Config")
            .field("client_id", &self.client_id)
            // Never print the secret itself; only signal that the field exists.
            .field("client_secret", &"<redacted>")
            .field("redirect_uri", &self.redirect_uri)
            .field("auth_url", &self.auth_url)
            .field("token_url", &self.token_url)
            .finish()
    }
}
36
37impl Default for OAuth2Config {
38 fn default() -> Self {
39 Self {
40 client_id: String::new(),
41 client_secret: String::new(),
42 redirect_uri: String::new(),
43 auth_url: "https://identity.vismaonline.com/connect/authorize".to_string(),
44 token_url: "https://identity.vismaonline.com/connect/token".to_string(),
45 }
46 }
47}
48
49impl OAuth2Config {
50 pub fn new(client_id: String, client_secret: String, redirect_uri: String) -> Self {
70 Self {
71 client_id,
72 client_secret,
73 redirect_uri,
74 ..Default::default()
75 }
76 }
77}
78
79#[derive(Debug, Clone, Serialize, Deserialize)]
84pub struct AccessToken {
85 pub token: String,
87
88 pub expires_at: DateTime<Utc>,
90
91 pub refresh_token: Option<String>,
94
95 pub token_type: String,
97}
98
99impl AccessToken {
100 pub fn new(token: String, expires_in: i64, refresh_token: Option<String>) -> Self {
121 let expires_at = Utc::now() + Duration::seconds(expires_in);
122 Self {
123 token,
124 expires_at,
125 refresh_token,
126 token_type: "Bearer".to_string(),
127 }
128 }
129
130 pub fn is_expired(&self) -> bool {
144 let buffer = Duration::minutes(5);
145 Utc::now() + buffer >= self.expires_at
146 }
147
148 pub fn authorization_header(&self) -> String {
153 format!("{} {}", self.token_type, self.token)
154 }
155}
156
157pub struct OAuth2Handler {
159 #[allow(dead_code)]
160 config: OAuth2Config,
161 client: BasicClient,
162}
163
164impl OAuth2Handler {
165 pub fn new(config: OAuth2Config) -> Result<Self> {
167 let client = BasicClient::new(
168 ClientId::new(config.client_id.clone()),
169 Some(ClientSecret::new(config.client_secret.clone())),
170 AuthUrl::new(config.auth_url.clone())
171 .map_err(|e| Error::InvalidConfig(format!("Invalid auth URL: {}", e)))?,
172 Some(
173 TokenUrl::new(config.token_url.clone())
174 .map_err(|e| Error::InvalidConfig(format!("Invalid token URL: {}", e)))?,
175 ),
176 )
177 .set_redirect_uri(
178 RedirectUrl::new(config.redirect_uri.clone())
179 .map_err(|e| Error::InvalidConfig(format!("Invalid redirect URI: {}", e)))?,
180 );
181
182 Ok(Self { config, client })
183 }
184
185 pub fn authorize_url(&self) -> (String, String, String) {
190 #[cfg(feature = "tracing")]
191 debug!("Generating OAuth2 authorization URL with PKCE");
192
193 let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256();
194
195 let (auth_url, csrf_token) = self
196 .client
197 .authorize_url(CsrfToken::new_random)
198 .add_scope(Scope::new("ea:api".to_string()))
199 .add_scope(Scope::new("ea:sales".to_string()))
200 .add_scope(Scope::new("offline_access".to_string()))
201 .set_pkce_challenge(pkce_challenge)
202 .url();
203
204 #[cfg(feature = "tracing")]
205 info!("Authorization URL generated successfully");
206
207 (
208 auth_url.to_string(),
209 csrf_token.secret().to_string(),
210 pkce_verifier.secret().to_string(),
211 )
212 }
213
214 pub async fn exchange_code(&self, code: String, pkce_verifier: String) -> Result<AccessToken> {
229 use oauth2::PkceCodeVerifier;
230
231 #[cfg(feature = "tracing")]
232 info!("Exchanging authorization code for access token");
233
234 let token_result = self
235 .client
236 .exchange_code(AuthorizationCode::new(code))
237 .set_pkce_verifier(PkceCodeVerifier::new(pkce_verifier))
238 .request_async(oauth2::reqwest::async_http_client)
239 .await
240 .map_err(|e| {
241 #[cfg(feature = "tracing")]
242 error!(error = %e, "Token exchange failed");
243 Error::OAuth2Error(format!("Token exchange failed: {}", e))
244 })?;
245
246 let expires_in = token_result
247 .expires_in()
248 .map(|d| d.as_secs() as i64)
249 .unwrap_or(3600); #[cfg(feature = "tracing")]
252 info!(expires_in_secs = expires_in, "Token exchange successful");
253
254 Ok(AccessToken::new(
255 token_result.access_token().secret().to_string(),
256 expires_in,
257 token_result.refresh_token().map(|t| t.secret().to_string()),
258 ))
259 }
260
261 pub async fn refresh_token(&self, refresh_token: String) -> Result<AccessToken> {
280 use oauth2::RefreshToken;
281
282 #[cfg(feature = "tracing")]
283 debug!("Refreshing access token using refresh token");
284
285 let token_result = self
286 .client
287 .exchange_refresh_token(&RefreshToken::new(refresh_token))
288 .request_async(oauth2::reqwest::async_http_client)
289 .await
290 .map_err(|e| {
291 #[cfg(feature = "tracing")]
292 error!(error = %e, "Token refresh failed");
293 Error::OAuth2Error(format!("Token refresh failed: {}", e))
294 })?;
295
296 let expires_in = token_result
297 .expires_in()
298 .map(|d| d.as_secs() as i64)
299 .unwrap_or(3600);
300
301 #[cfg(feature = "tracing")]
302 info!(expires_in_secs = expires_in, "Token refresh successful");
303
304 Ok(AccessToken::new(
305 token_result.access_token().secret().to_string(),
306 expires_in,
307 token_result.refresh_token().map(|t| t.secret().to_string()),
308 ))
309 }
310}
311
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_access_token_expiration() {
        let token = AccessToken::new("test_token".to_string(), 3600, None);
        assert!(!token.is_expired());

        // A zero lifetime is already inside the five-minute refresh buffer, so no sleep
        // is needed to observe expiry.
        let expired_token = AccessToken::new("test_token".to_string(), 0, None);
        assert!(expired_token.is_expired());
    }

    #[test]
    fn test_short_lived_token_is_treated_as_expired() {
        // Lifetimes within the five-minute buffer count as expired right away.
        let token = AccessToken::new("short_lived".to_string(), 60, None);
        assert!(token.is_expired());
    }

    #[test]
    fn test_authorization_header() {
        let token = AccessToken::new("test_token_123".to_string(), 3600, None);
        assert_eq!(token.authorization_header(), "Bearer test_token_123");
    }

    #[test]
    fn test_config_new_uses_default_endpoints() {
        let config = OAuth2Config::new(
            "id".to_string(),
            "secret".to_string(),
            "https://example.com/callback".to_string(),
        );
        let defaults = OAuth2Config::default();

        assert_eq!(config.client_id, "id");
        assert_eq!(config.client_secret, "secret");
        assert_eq!(config.redirect_uri, "https://example.com/callback");
        assert_eq!(config.auth_url, defaults.auth_url);
        assert_eq!(config.token_url, defaults.token_url);
    }

    #[test]
    fn test_handler_rejects_invalid_redirect_uri() {
        let config = OAuth2Config::new(
            "id".to_string(),
            "secret".to_string(),
            "not a url".to_string(),
        );
        assert!(OAuth2Handler::new(config).is_err());
    }

    #[test]
    fn test_authorize_url_includes_state_and_pkce_challenge() {
        let config = OAuth2Config::new(
            "id".to_string(),
            "secret".to_string(),
            "https://example.com/callback".to_string(),
        );
        let handler = match OAuth2Handler::new(config) {
            Ok(handler) => handler,
            Err(_) => panic!("a config with valid URLs must build a handler"),
        };

        let (url, csrf_state, pkce_verifier) = handler.authorize_url();

        assert!(url.starts_with("https://identity.vismaonline.com/connect/authorize"));
        assert!(url.contains(&format!("state={}", csrf_state)));
        assert!(url.contains("code_challenge_method=S256"));
        assert!(url.contains("offline_access"));
        assert!(!pkce_verifier.is_empty());
    }
}