salesforce_client/
auth.rs

//! OAuth 2.0 authentication module
//!
//! Obtains Salesforce access tokens via the OAuth 2.0 refresh-token and
//! username-password flows, caches them, and refreshes them when they expire.
4
5use crate::error::SfError;
6use chrono::{DateTime, Duration, Utc};
7use serde::{Deserialize, Serialize};
8use std::sync::Arc;
9use tokio::sync::RwLock;
10use tracing::{debug, info, warn};
11
/// OAuth 2.0 credentials for Salesforce.
///
/// The `Debug` implementation redacts `client_secret`, `refresh_token` and
/// `password` so that credentials can be logged or printed without leaking
/// secrets. Optional secrets are shown only as present/absent.
#[derive(Clone)]
pub struct OAuthCredentials {
    /// OAuth client ID (Consumer Key)
    pub client_id: String,

    /// OAuth client secret (Consumer Secret)
    pub client_secret: String,

    /// Refresh token for obtaining new access tokens
    pub refresh_token: Option<String>,

    /// Username for password flow
    pub username: Option<String>,

    /// Password + security token for password flow
    pub password: Option<String>,
}

impl std::fmt::Debug for OAuthCredentials {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Placeholder printed in place of any secret value.
        const REDACTED: &str = "<redacted>";
        f.debug_struct("OAuthCredentials")
            .field("client_id", &self.client_id)
            .field("client_secret", &REDACTED)
            .field("refresh_token", &self.refresh_token.as_ref().map(|_| REDACTED))
            .field("username", &self.username)
            .field("password", &self.password.as_ref().map(|_| REDACTED))
            .finish()
    }
}
30
31/// Response from OAuth token endpoint
32#[derive(Debug, Deserialize, Serialize)]
33struct TokenResponse {
34    access_token: String,
35    refresh_token: Option<String>,
36    instance_url: String,
37
38    #[serde(default)]
39    expires_in: Option<i64>,
40
41    token_type: String,
42
43    #[serde(default)]
44    issued_at: Option<String>,
45}
46
47/// Managed access token with automatic refresh
48#[derive(Debug, Clone)]
49pub struct AccessToken {
50    token: String,
51    expires_at: Option<DateTime<Utc>>,
52    instance_url: String,
53}
54
55impl AccessToken {
56    /// Create a new access token
57    pub fn new(token: String, instance_url: String, expires_in: Option<i64>) -> Self {
58        let expires_at = expires_in.map(|secs| Utc::now() + Duration::seconds(secs));
59
60        Self {
61            token,
62            expires_at,
63            instance_url,
64        }
65    }
66
67    /// Check if token is expired or about to expire (within 5 minutes)
68    pub fn is_expired(&self) -> bool {
69        if let Some(expires_at) = self.expires_at {
70            let buffer = Duration::minutes(5);
71            Utc::now() + buffer >= expires_at
72        } else {
73            false // If no expiry, assume valid
74        }
75    }
76
77    /// Get the token value
78    pub fn token(&self) -> &str {
79        &self.token
80    }
81
82    /// Get the instance URL
83    pub fn instance_url(&self) -> &str {
84        &self.instance_url
85    }
86}
87
88/// Token manager that handles automatic refresh
89pub struct TokenManager {
90    credentials: OAuthCredentials,
91    current_token: Arc<RwLock<Option<AccessToken>>>,
92    http_client: reqwest::Client,
93    auth_url: String,
94}
95
96impl TokenManager {
97    /// Create a new token manager
98    pub fn new(credentials: OAuthCredentials) -> Self {
99        Self {
100            credentials,
101            current_token: Arc::new(RwLock::new(None)),
102            http_client: reqwest::Client::new(),
103            auth_url: "https://login.salesforce.com".to_string(),
104        }
105    }
106
107    /// Create a token manager for sandbox environment
108    pub fn sandbox(credentials: OAuthCredentials) -> Self {
109        let mut manager = Self::new(credentials);
110        manager.auth_url = "https://test.salesforce.com".to_string();
111        manager
112    }
113
114    /// Get a valid access token, refreshing if necessary
115    ///
116    /// This method ensures you always have a valid token by:
117    /// 1. Checking if current token exists and is valid
118    /// 2. If expired, automatically refreshing
119    /// 3. Thread-safe access via RwLock
120    pub async fn get_token(&self) -> Result<AccessToken, SfError> {
121        // Fast path: check if token is valid (read lock)
122        {
123            let token_guard = self.current_token.read().await;
124            if let Some(token) = token_guard.as_ref() {
125                if !token.is_expired() {
126                    debug!("Using cached access token");
127                    return Ok(token.clone());
128                }
129            }
130        }
131
132        // Slow path: token expired or doesn't exist (write lock)
133        info!("Access token expired or missing, refreshing...");
134        let mut token_guard = self.current_token.write().await;
135
136        // Double-check after acquiring write lock (another thread may have refreshed)
137        if let Some(token) = token_guard.as_ref() {
138            if !token.is_expired() {
139                return Ok(token.clone());
140            }
141        }
142
143        // Actually refresh the token
144        let new_token = self.fetch_new_token().await?;
145        *token_guard = Some(new_token.clone());
146
147        info!("Successfully refreshed access token");
148        Ok(new_token)
149    }
150
151    /// Fetch a new token from Salesforce
152    async fn fetch_new_token(&self) -> Result<AccessToken, SfError> {
153        // Try refresh token flow first
154        if let Some(refresh_token) = &self.credentials.refresh_token {
155            match self.refresh_token_flow(refresh_token).await {
156                Ok(token) => return Ok(token),
157                Err(e) => {
158                    warn!(
159                        "Refresh token flow failed: {}, falling back to password flow",
160                        e
161                    );
162                }
163            }
164        }
165
166        // Fall back to password flow
167        if self.credentials.username.is_some() && self.credentials.password.is_some() {
168            return self.password_flow().await;
169        }
170
171        Err(SfError::Auth(
172            "No valid authentication method available".to_string(),
173        ))
174    }
175
176    /// OAuth 2.0 Refresh Token Flow
177    async fn refresh_token_flow(&self, refresh_token: &str) -> Result<AccessToken, SfError> {
178        let url = format!("{}/services/oauth2/token", self.auth_url);
179
180        let params = [
181            ("grant_type", "refresh_token"),
182            ("client_id", &self.credentials.client_id),
183            ("client_secret", &self.credentials.client_secret),
184            ("refresh_token", refresh_token),
185        ];
186
187        let response = self.http_client.post(&url).form(&params).send().await?;
188
189        if !response.status().is_success() {
190            let body = response.text().await?;
191            return Err(SfError::Auth(format!("Token refresh failed: {}", body)));
192        }
193
194        let token_response: TokenResponse = response.json().await?;
195
196        Ok(AccessToken::new(
197            token_response.access_token,
198            token_response.instance_url,
199            token_response.expires_in,
200        ))
201    }
202
203    /// OAuth 2.0 Password Flow (less secure, use for development only)
204    async fn password_flow(&self) -> Result<AccessToken, SfError> {
205        let username = self
206            .credentials
207            .username
208            .as_ref()
209            .ok_or_else(|| SfError::Auth("Username not provided".to_string()))?;
210        let password = self
211            .credentials
212            .password
213            .as_ref()
214            .ok_or_else(|| SfError::Auth("Password not provided".to_string()))?;
215
216        let url = format!("{}/services/oauth2/token", self.auth_url);
217
218        let params = [
219            ("grant_type", "password"),
220            ("client_id", &self.credentials.client_id),
221            ("client_secret", &self.credentials.client_secret),
222            ("username", username),
223            ("password", password),
224        ];
225
226        let response = self.http_client.post(&url).form(&params).send().await?;
227
228        if !response.status().is_success() {
229            let body = response.text().await?;
230            return Err(SfError::Auth(format!("Authentication failed: {}", body)));
231        }
232
233        let token_response: TokenResponse = response.json().await?;
234
235        Ok(AccessToken::new(
236            token_response.access_token,
237            token_response.instance_url,
238            token_response.expires_in,
239        ))
240    }
241
242    /// Invalidate the current token (force refresh on next request)
243    pub async fn invalidate(&self) {
244        let mut token_guard = self.current_token.write().await;
245        *token_guard = None;
246        info!("Access token invalidated");
247    }
248}
249
#[cfg(test)]
mod tests {
    use super::*;

    const INSTANCE_URL: &str = "https://test.salesforce.com";

    /// Build a token with a fixed value and the given lifetime in seconds.
    fn token_expiring_in(expires_in: Option<i64>) -> AccessToken {
        AccessToken::new("test_token".to_string(), INSTANCE_URL.to_string(), expires_in)
    }

    #[test]
    fn test_access_token_expiry() {
        // One hour is well outside the 5-minute refresh buffer.
        assert!(!token_expiring_in(Some(3600)).is_expired());
    }

    #[test]
    fn test_access_token_no_expiry() {
        assert!(!token_expiring_in(None).is_expired());
    }

    #[test]
    fn test_access_token_inside_refresh_buffer_is_expired() {
        // 60 seconds remaining falls inside the 5-minute buffer.
        assert!(token_expiring_in(Some(60)).is_expired());
    }

    #[test]
    fn test_access_token_zero_lifetime_is_expired() {
        assert!(token_expiring_in(Some(0)).is_expired());
    }

    #[test]
    fn test_access_token_accessors() {
        let token = token_expiring_in(Some(3600));
        assert_eq!(token.token(), "test_token");
        assert_eq!(token.instance_url(), INSTANCE_URL);
    }
}