// salesforce_client/auth.rs
use std::fmt;
use std::sync::Arc;

use chrono::{DateTime, Duration, Utc};
use serde::{Deserialize, Serialize};
use tokio::sync::RwLock;
use tracing::{debug, info, warn};

use crate::error::SfError;

/// Client credentials and user secrets used to obtain Salesforce OAuth tokens.
///
/// `Debug` is implemented by hand instead of derived so that the secrets
/// (`client_secret`, `refresh_token`, `password`) never end up in logs or
/// panic messages. For the optional secrets, only their presence is shown.
#[derive(Clone)]
pub struct OAuthCredentials {
    /// OAuth client id sent with every token request.
    pub client_id: String,

    /// OAuth client secret paired with `client_id`.
    pub client_secret: String,

    /// Refresh token; when set, the refresh-token flow is attempted first.
    pub refresh_token: Option<String>,

    /// Username for the password flow (used only together with `password`).
    pub username: Option<String>,

    /// Password for the password flow (used only together with `username`).
    pub password: Option<String>,
}

impl fmt::Debug for OAuthCredentials {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        const REDACTED: &str = "<redacted>";
        // Reveal whether an optional secret is configured, never its value.
        let mask = |secret: &Option<String>| secret.as_ref().map(|_| REDACTED);

        f.debug_struct("OAuthCredentials")
            .field("client_id", &self.client_id)
            .field("client_secret", &REDACTED)
            .field("refresh_token", &mask(&self.refresh_token))
            .field("username", &self.username)
            .field("password", &mask(&self.password))
            .finish()
    }
}

31#[derive(Debug, Deserialize, Serialize)]
33struct TokenResponse {
34 access_token: String,
35 refresh_token: Option<String>,
36 instance_url: String,
37
38 #[serde(default)]
39 expires_in: Option<i64>,
40
41 token_type: String,
42
43 #[serde(default)]
44 issued_at: Option<String>,
45}
46
47#[derive(Debug, Clone)]
49pub struct AccessToken {
50 token: String,
51 expires_at: Option<DateTime<Utc>>,
52 instance_url: String,
53}
54
55impl AccessToken {
56 pub fn new(token: String, instance_url: String, expires_in: Option<i64>) -> Self {
58 let expires_at = expires_in.map(|secs| Utc::now() + Duration::seconds(secs));
59
60 Self {
61 token,
62 expires_at,
63 instance_url,
64 }
65 }
66
67 pub fn is_expired(&self) -> bool {
69 if let Some(expires_at) = self.expires_at {
70 let buffer = Duration::minutes(5);
71 Utc::now() + buffer >= expires_at
72 } else {
73 false }
75 }
76
77 pub fn token(&self) -> &str {
79 &self.token
80 }
81
82 pub fn instance_url(&self) -> &str {
84 &self.instance_url
85 }
86}
87
88pub struct TokenManager {
90 credentials: OAuthCredentials,
91 current_token: Arc<RwLock<Option<AccessToken>>>,
92 http_client: reqwest::Client,
93 auth_url: String,
94}
95
96impl TokenManager {
97 pub fn new(credentials: OAuthCredentials) -> Self {
99 Self {
100 credentials,
101 current_token: Arc::new(RwLock::new(None)),
102 http_client: reqwest::Client::new(),
103 auth_url: "https://login.salesforce.com".to_string(),
104 }
105 }
106
107 pub fn sandbox(credentials: OAuthCredentials) -> Self {
109 let mut manager = Self::new(credentials);
110 manager.auth_url = "https://test.salesforce.com".to_string();
111 manager
112 }
113
114 pub async fn get_token(&self) -> Result<AccessToken, SfError> {
121 {
123 let token_guard = self.current_token.read().await;
124 if let Some(token) = token_guard.as_ref() {
125 if !token.is_expired() {
126 debug!("Using cached access token");
127 return Ok(token.clone());
128 }
129 }
130 }
131
132 info!("Access token expired or missing, refreshing...");
134 let mut token_guard = self.current_token.write().await;
135
136 if let Some(token) = token_guard.as_ref() {
138 if !token.is_expired() {
139 return Ok(token.clone());
140 }
141 }
142
143 let new_token = self.fetch_new_token().await?;
145 *token_guard = Some(new_token.clone());
146
147 info!("Successfully refreshed access token");
148 Ok(new_token)
149 }
150
151 async fn fetch_new_token(&self) -> Result<AccessToken, SfError> {
153 if let Some(refresh_token) = &self.credentials.refresh_token {
155 match self.refresh_token_flow(refresh_token).await {
156 Ok(token) => return Ok(token),
157 Err(e) => {
158 warn!(
159 "Refresh token flow failed: {}, falling back to password flow",
160 e
161 );
162 }
163 }
164 }
165
166 if self.credentials.username.is_some() && self.credentials.password.is_some() {
168 return self.password_flow().await;
169 }
170
171 Err(SfError::Auth(
172 "No valid authentication method available".to_string(),
173 ))
174 }
175
176 async fn refresh_token_flow(&self, refresh_token: &str) -> Result<AccessToken, SfError> {
178 let url = format!("{}/services/oauth2/token", self.auth_url);
179
180 let params = [
181 ("grant_type", "refresh_token"),
182 ("client_id", &self.credentials.client_id),
183 ("client_secret", &self.credentials.client_secret),
184 ("refresh_token", refresh_token),
185 ];
186
187 let response = self.http_client.post(&url).form(¶ms).send().await?;
188
189 if !response.status().is_success() {
190 let body = response.text().await?;
191 return Err(SfError::Auth(format!("Token refresh failed: {}", body)));
192 }
193
194 let token_response: TokenResponse = response.json().await?;
195
196 Ok(AccessToken::new(
197 token_response.access_token,
198 token_response.instance_url,
199 token_response.expires_in,
200 ))
201 }
202
203 async fn password_flow(&self) -> Result<AccessToken, SfError> {
205 let username = self
206 .credentials
207 .username
208 .as_ref()
209 .ok_or_else(|| SfError::Auth("Username not provided".to_string()))?;
210 let password = self
211 .credentials
212 .password
213 .as_ref()
214 .ok_or_else(|| SfError::Auth("Password not provided".to_string()))?;
215
216 let url = format!("{}/services/oauth2/token", self.auth_url);
217
218 let params = [
219 ("grant_type", "password"),
220 ("client_id", &self.credentials.client_id),
221 ("client_secret", &self.credentials.client_secret),
222 ("username", username),
223 ("password", password),
224 ];
225
226 let response = self.http_client.post(&url).form(¶ms).send().await?;
227
228 if !response.status().is_success() {
229 let body = response.text().await?;
230 return Err(SfError::Auth(format!("Authentication failed: {}", body)));
231 }
232
233 let token_response: TokenResponse = response.json().await?;
234
235 Ok(AccessToken::new(
236 token_response.access_token,
237 token_response.instance_url,
238 token_response.expires_in,
239 ))
240 }
241
242 pub async fn invalidate(&self) {
244 let mut token_guard = self.current_token.write().await;
245 *token_guard = None;
246 info!("Access token invalidated");
247 }
248}
249
#[cfg(test)]
mod tests {
    use super::*;

    const INSTANCE_URL: &str = "https://test.salesforce.com";

    /// Builds a token for `INSTANCE_URL` with the given server-reported lifetime.
    fn token_expiring_in(expires_in: Option<i64>) -> AccessToken {
        AccessToken::new("test_token".to_string(), INSTANCE_URL.to_string(), expires_in)
    }

    /// Credentials with only the mandatory client fields populated.
    fn client_only_credentials() -> OAuthCredentials {
        OAuthCredentials {
            client_id: "client".to_string(),
            client_secret: "secret".to_string(),
            refresh_token: None,
            username: None,
            password: None,
        }
    }

    #[test]
    fn test_access_token_expiry() {
        assert!(!token_expiring_in(Some(3600)).is_expired());
    }

    #[test]
    fn test_access_token_no_expiry() {
        assert!(!token_expiring_in(None).is_expired());
    }

    #[test]
    fn test_access_token_inside_refresh_buffer_is_expired() {
        // 60 s of remaining lifetime falls inside the 5-minute refresh buffer.
        assert!(token_expiring_in(Some(60)).is_expired());
    }

    #[test]
    fn test_access_token_negative_lifetime_is_expired() {
        assert!(token_expiring_in(Some(-10)).is_expired());
    }

    #[test]
    fn test_access_token_accessors() {
        let token = token_expiring_in(Some(3600));
        assert_eq!(token.token(), "test_token");
        assert_eq!(token.instance_url(), INSTANCE_URL);
    }

    #[test]
    fn test_login_hosts() {
        assert_eq!(
            TokenManager::new(client_only_credentials()).auth_url,
            "https://login.salesforce.com"
        );
        assert_eq!(
            TokenManager::sandbox(client_only_credentials()).auth_url,
            "https://test.salesforce.com"
        );
    }
}