// spiris/auth.rs

//! OAuth2 authentication for the Spiris Bokföring och Fakturering API.

use crate::error::{Error, Result};
use chrono::{DateTime, Duration, Utc};
use oauth2::{
    basic::{BasicClient, BasicTokenResponse},
    AuthUrl, AuthorizationCode, ClientId, ClientSecret, CsrfToken, PkceCodeChallenge,
    RedirectUrl, Scope, TokenResponse, TokenUrl,
};
use serde::{Deserialize, Serialize};

#[cfg(feature = "tracing")]
use tracing::{debug, error, info};

/// OAuth2 configuration for Spiris Bokföring och Fakturering.
///
/// You can obtain OAuth2 credentials by registering your application
/// in the [Visma Developer Portal](https://developer.visma.com/).
///
/// The `Debug` output redacts `client_secret` so the configuration can be
/// logged without leaking credentials; every other field is printed as-is.
#[derive(Clone)]
pub struct OAuth2Config {
    /// Client ID from Visma developer portal.
    pub client_id: String,

    /// Client secret from Visma developer portal.
    pub client_secret: String,

    /// Redirect URI registered in Visma developer portal.
    /// Must exactly match the URI registered in your application settings.
    pub redirect_uri: String,

    /// Authorization endpoint URL.
    pub auth_url: String,

    /// Token endpoint URL.
    pub token_url: String,
}

impl std::fmt::Debug for OAuth2Config {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("OAuth2Config")
            .field("client_id", &self.client_id)
            // The secret is a credential: never let it reach logs or panic messages.
            .field("client_secret", &"[redacted]")
            .field("redirect_uri", &self.redirect_uri)
            .field("auth_url", &self.auth_url)
            .field("token_url", &self.token_url)
            .finish()
    }
}

37impl Default for OAuth2Config {
38    fn default() -> Self {
39        Self {
40            client_id: String::new(),
41            client_secret: String::new(),
42            redirect_uri: String::new(),
43            auth_url: "https://identity.vismaonline.com/connect/authorize".to_string(),
44            token_url: "https://identity.vismaonline.com/connect/token".to_string(),
45        }
46    }
47}
48
49impl OAuth2Config {
50    /// Create a new OAuth2 configuration.
51    ///
52    /// # Arguments
53    ///
54    /// * `client_id` - OAuth2 client ID from developer portal
55    /// * `client_secret` - OAuth2 client secret from developer portal
56    /// * `redirect_uri` - Callback URI for OAuth2 flow
57    ///
58    /// # Example
59    ///
60    /// ```
61    /// use spiris::auth::OAuth2Config;
62    ///
63    /// let config = OAuth2Config::new(
64    ///     "your_client_id".to_string(),
65    ///     "your_client_secret".to_string(),
66    ///     "http://localhost:8080/callback".to_string(),
67    /// );
68    /// ```
69    pub fn new(client_id: String, client_secret: String, redirect_uri: String) -> Self {
70        Self {
71            client_id,
72            client_secret,
73            redirect_uri,
74            ..Default::default()
75        }
76    }
77}
78
79/// Access token with expiration tracking.
80///
81/// Tokens typically expire after 1 hour. Use `is_expired()` to check
82/// if a token needs to be refreshed before making API requests.
83#[derive(Debug, Clone, Serialize, Deserialize)]
84pub struct AccessToken {
85    /// The access token string.
86    pub token: String,
87
88    /// When the token expires (UTC).
89    pub expires_at: DateTime<Utc>,
90
91    /// Refresh token for obtaining new access tokens.
92    /// Required for token refresh flow.
93    pub refresh_token: Option<String>,
94
95    /// Token type (usually "Bearer").
96    pub token_type: String,
97}
98
99impl AccessToken {
100    /// Create a new access token.
101    ///
102    /// # Arguments
103    ///
104    /// * `token` - The access token string
105    /// * `expires_in` - Token lifetime in seconds
106    /// * `refresh_token` - Optional refresh token for token renewal
107    ///
108    /// # Example
109    ///
110    /// ```
111    /// use spiris::AccessToken;
112    ///
113    /// // Token expires in 1 hour (3600 seconds)
114    /// let token = AccessToken::new(
115    ///     "access_token_string".to_string(),
116    ///     3600,
117    ///     Some("refresh_token_string".to_string())
118    /// );
119    /// ```
120    pub fn new(token: String, expires_in: i64, refresh_token: Option<String>) -> Self {
121        let expires_at = Utc::now() + Duration::seconds(expires_in);
122        Self {
123            token,
124            expires_at,
125            refresh_token,
126            token_type: "Bearer".to_string(),
127        }
128    }
129
130    /// Check if the token is expired or will expire soon (within 5 minutes).
131    ///
132    /// Returns `true` if the token should be refreshed.
133    ///
134    /// # Example
135    ///
136    /// ```
137    /// # use spiris::AccessToken;
138    /// # let token = AccessToken::new("token".to_string(), 3600, None);
139    /// if token.is_expired() {
140    ///     println!("Token needs to be refreshed!");
141    /// }
142    /// ```
143    pub fn is_expired(&self) -> bool {
144        let buffer = Duration::minutes(5);
145        Utc::now() + buffer >= self.expires_at
146    }
147
148    /// Get the authorization header value.
149    ///
150    /// Returns a string in the format "Bearer {token}" suitable
151    /// for use in HTTP Authorization headers.
152    pub fn authorization_header(&self) -> String {
153        format!("{} {}", self.token_type, self.token)
154    }
155}
156
157/// OAuth2 authentication handler.
158pub struct OAuth2Handler {
159    #[allow(dead_code)]
160    config: OAuth2Config,
161    client: BasicClient,
162}
163
164impl OAuth2Handler {
165    /// Create a new OAuth2 handler.
166    pub fn new(config: OAuth2Config) -> Result<Self> {
167        let client = BasicClient::new(
168            ClientId::new(config.client_id.clone()),
169            Some(ClientSecret::new(config.client_secret.clone())),
170            AuthUrl::new(config.auth_url.clone())
171                .map_err(|e| Error::InvalidConfig(format!("Invalid auth URL: {}", e)))?,
172            Some(
173                TokenUrl::new(config.token_url.clone())
174                    .map_err(|e| Error::InvalidConfig(format!("Invalid token URL: {}", e)))?,
175            ),
176        )
177        .set_redirect_uri(
178            RedirectUrl::new(config.redirect_uri.clone())
179                .map_err(|e| Error::InvalidConfig(format!("Invalid redirect URI: {}", e)))?,
180        );
181
182        Ok(Self { config, client })
183    }
184
185    /// Generate an authorization URL for the OAuth2 flow.
186    ///
187    /// Returns a tuple of (authorization_url, csrf_token, pkce_verifier).
188    /// The user should be redirected to the authorization URL to approve access.
189    pub fn authorize_url(&self) -> (String, String, String) {
190        #[cfg(feature = "tracing")]
191        debug!("Generating OAuth2 authorization URL with PKCE");
192
193        let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256();
194
195        let (auth_url, csrf_token) = self
196            .client
197            .authorize_url(CsrfToken::new_random)
198            .add_scope(Scope::new("ea:api".to_string()))
199            .add_scope(Scope::new("ea:sales".to_string()))
200            .add_scope(Scope::new("offline_access".to_string()))
201            .set_pkce_challenge(pkce_challenge)
202            .url();
203
204        #[cfg(feature = "tracing")]
205        info!("Authorization URL generated successfully");
206
207        (
208            auth_url.to_string(),
209            csrf_token.secret().to_string(),
210            pkce_verifier.secret().to_string(),
211        )
212    }
213
214    /// Exchange an authorization code for an access token.
215    ///
216    /// This should be called after the user approves access and is redirected
217    /// back to your application with an authorization code.
218    ///
219    /// # Arguments
220    ///
221    /// * `code` - The authorization code received from the OAuth2 callback
222    /// * `pkce_verifier` - The PKCE verifier generated during `authorize_url()`
223    ///
224    /// # Security
225    ///
226    /// The PKCE verifier MUST be the same one generated by `authorize_url()`.
227    /// This protects against authorization code interception attacks.
228    pub async fn exchange_code(&self, code: String, pkce_verifier: String) -> Result<AccessToken> {
229        use oauth2::PkceCodeVerifier;
230
231        #[cfg(feature = "tracing")]
232        info!("Exchanging authorization code for access token");
233
234        let token_result = self
235            .client
236            .exchange_code(AuthorizationCode::new(code))
237            .set_pkce_verifier(PkceCodeVerifier::new(pkce_verifier))
238            .request_async(oauth2::reqwest::async_http_client)
239            .await
240            .map_err(|e| {
241                #[cfg(feature = "tracing")]
242                error!(error = %e, "Token exchange failed");
243                Error::OAuth2Error(format!("Token exchange failed: {}", e))
244            })?;
245
246        let expires_in = token_result
247            .expires_in()
248            .map(|d| d.as_secs() as i64)
249            .unwrap_or(3600); // Default to 1 hour
250
251        #[cfg(feature = "tracing")]
252        info!(expires_in_secs = expires_in, "Token exchange successful");
253
254        Ok(AccessToken::new(
255            token_result.access_token().secret().to_string(),
256            expires_in,
257            token_result.refresh_token().map(|t| t.secret().to_string()),
258        ))
259    }
260
261    /// Refresh an access token using a refresh token.
262    ///
263    /// # Arguments
264    ///
265    /// * `refresh_token` - The refresh token obtained during initial authorization
266    ///
267    /// # Example
268    ///
269    /// ```no_run
270    /// # use spiris::auth::{OAuth2Config, OAuth2Handler};
271    /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
272    /// # let config = OAuth2Config::new("id".to_string(), "secret".to_string(), "uri".to_string());
273    /// let handler = OAuth2Handler::new(config)?;
274    /// let refresh_token = "existing_refresh_token".to_string();
275    /// let new_token = handler.refresh_token(refresh_token).await?;
276    /// # Ok(())
277    /// # }
278    /// ```
279    pub async fn refresh_token(&self, refresh_token: String) -> Result<AccessToken> {
280        use oauth2::RefreshToken;
281
282        #[cfg(feature = "tracing")]
283        debug!("Refreshing access token using refresh token");
284
285        let token_result = self
286            .client
287            .exchange_refresh_token(&RefreshToken::new(refresh_token))
288            .request_async(oauth2::reqwest::async_http_client)
289            .await
290            .map_err(|e| {
291                #[cfg(feature = "tracing")]
292                error!(error = %e, "Token refresh failed");
293                Error::OAuth2Error(format!("Token refresh failed: {}", e))
294            })?;
295
296        let expires_in = token_result
297            .expires_in()
298            .map(|d| d.as_secs() as i64)
299            .unwrap_or(3600);
300
301        #[cfg(feature = "tracing")]
302        info!(expires_in_secs = expires_in, "Token refresh successful");
303
304        Ok(AccessToken::new(
305            token_result.access_token().secret().to_string(),
306            expires_in,
307            token_result.refresh_token().map(|t| t.secret().to_string()),
308        ))
309    }
310}
311
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_access_token_expiration() {
        let token = AccessToken::new("test_token".to_string(), 3600, None);
        assert!(!token.is_expired());

        // `is_expired` applies a 5-minute safety margin, so a zero-lifetime token is
        // already expired; no sleep is needed.
        let expired_token = AccessToken::new("test_token".to_string(), 0, None);
        assert!(expired_token.is_expired());
    }

    #[test]
    fn test_expiry_buffer_boundaries() {
        // Inside the 5-minute margin: treated as expired.
        assert!(AccessToken::new("t".to_string(), 4 * 60, None).is_expired());
        // Comfortably outside the margin: still valid.
        assert!(!AccessToken::new("t".to_string(), 10 * 60, None).is_expired());
        // A negative lifetime describes a token that expired in the past.
        assert!(AccessToken::new("t".to_string(), -60, None).is_expired());
    }

    #[test]
    fn test_new_defaults_to_bearer_and_keeps_refresh_token() {
        let token = AccessToken::new("abc".to_string(), 3600, Some("r".to_string()));
        assert_eq!(token.token_type, "Bearer");
        assert_eq!(token.refresh_token.as_deref(), Some("r"));
    }

    #[test]
    fn test_authorization_header() {
        let token = AccessToken::new("test_token_123".to_string(), 3600, None);
        assert_eq!(token.authorization_header(), "Bearer test_token_123");
    }

    #[test]
    fn test_authorization_header_uses_token_type() {
        let mut token = AccessToken::new("abc".to_string(), 3600, None);
        token.token_type = "MAC".to_string();
        assert_eq!(token.authorization_header(), "MAC abc");
    }

    #[test]
    fn test_config_new_uses_default_endpoints() {
        let config = OAuth2Config::new(
            "id".to_string(),
            "secret".to_string(),
            "http://localhost:8080/callback".to_string(),
        );
        let defaults = OAuth2Config::default();
        assert_eq!(config.client_id, "id");
        assert_eq!(config.client_secret, "secret");
        assert_eq!(config.redirect_uri, "http://localhost:8080/callback");
        assert_eq!(config.auth_url, defaults.auth_url);
        assert_eq!(config.token_url, defaults.token_url);
    }

    #[test]
    fn test_handler_rejects_invalid_redirect_uri() {
        let config = OAuth2Config::new(
            "id".to_string(),
            "secret".to_string(),
            "not a url".to_string(),
        );
        assert!(OAuth2Handler::new(config).is_err());
    }

    #[test]
    fn test_authorize_url_contains_pkce_and_state() {
        let config = OAuth2Config::new(
            "my_client".to_string(),
            "secret".to_string(),
            "http://localhost:8080/callback".to_string(),
        );
        let handler = match OAuth2Handler::new(config) {
            Ok(handler) => handler,
            Err(_) => panic!("a valid configuration was rejected"),
        };

        let (url, csrf, verifier) = handler.authorize_url();
        assert!(url.starts_with("https://identity.vismaonline.com/connect/authorize?"));
        assert!(url.contains("client_id=my_client"));
        assert!(url.contains("code_challenge_method=S256"));
        assert!(url.contains("offline_access"));
        assert!(url.contains(&format!("state={}", csrf)));
        assert!(!verifier.is_empty());
    }
}
332}