// schwab_api/auth.rs

1use chrono::{DateTime, Utc};
2use serde::Deserialize;
3
/// Approximate lifetime of a Schwab refresh token, in seconds (seven days).
pub const REFRESH_TOKEN_LIFETIME_SECS: i64 = 7 * 24 * 60 * 60;
6
7/// OAuth token bundle persisted on disk.
8#[derive(Debug, Clone, serde::Serialize, Deserialize)]
9pub struct Tokens {
10    pub access_token: String,
11    pub refresh_token: String,
12    pub token_type: String,
13    pub expires_at: DateTime<Utc>,
14    pub scope: Option<String>,
15    /// When this token bundle was issued (login or last refresh).
16    #[serde(default = "default_obtained_at")]
17    pub obtained_at: DateTime<Utc>,
18}
19
20fn default_obtained_at() -> DateTime<Utc> {
21    Utc::now()
22}
23
24impl Tokens {
25    pub fn is_expired(&self) -> bool {
26        Utc::now() >= self.expires_at
27    }
28
29    pub fn expires_in_seconds(&self) -> i64 {
30        (self.expires_at - Utc::now()).num_seconds().max(0)
31    }
32
33    pub fn refresh_age_seconds(&self) -> i64 {
34        (Utc::now() - self.obtained_at).num_seconds().max(0)
35    }
36
37    pub fn refresh_expires_in_seconds(&self) -> i64 {
38        (REFRESH_TOKEN_LIFETIME_SECS - self.refresh_age_seconds()).max(0)
39    }
40}
41
42#[derive(Debug, Deserialize)]
43struct TokenResponse {
44    access_token: String,
45    refresh_token: String,
46    token_type: String,
47    expires_in: i64,
48    scope: Option<String>,
49}
50
51#[derive(Debug, Deserialize)]
52struct OAuthErrorResponse {
53    error: Option<String>,
54    error_description: Option<String>,
55    message: Option<String>,
56}
57
58use std::path::PathBuf;
59
60use reqwest::Client;
61use tokio::fs;
62use tracing::{debug, info};
63
64use crate::config::ClientConfig;
65use crate::error::{ApiError, Result};
66
/// Keeps the OAuth [`Tokens`] bundle in a single JSON file on disk.
#[derive(Clone, Debug)]
pub struct TokenStore {
    /// Location of the token file (`<token_dir>/tokens.json` when built via `new`).
    path: PathBuf,
}
72
73impl TokenStore {
74    pub fn new(token_dir: PathBuf) -> Self {
75        Self {
76            path: token_dir.join("tokens.json"),
77        }
78    }
79
80    pub fn path(&self) -> &PathBuf {
81        &self.path
82    }
83
84    pub async fn load(&self) -> Result<Option<Tokens>> {
85        match fs::read_to_string(&self.path).await {
86            Ok(raw) => Ok(Some(serde_json::from_str(&raw)?)),
87            Err(err) if err.kind() == std::io::ErrorKind::NotFound => Ok(None),
88            Err(err) => Err(ApiError::TokenStore(err.to_string())),
89        }
90    }
91
92    pub async fn save(&self, tokens: &Tokens) -> Result<()> {
93        if let Some(parent) = self.path.parent() {
94            fs::create_dir_all(parent)
95                .await
96                .map_err(|e| ApiError::TokenStore(e.to_string()))?;
97        }
98        let raw = serde_json::to_string_pretty(tokens)?;
99        fs::write(&self.path, raw)
100            .await
101            .map_err(|e| ApiError::TokenStore(e.to_string()))?;
102        Ok(())
103    }
104
105    pub async fn clear(&self) -> Result<()> {
106        match fs::remove_file(&self.path).await {
107            Ok(()) => Ok(()),
108            Err(err) if err.kind() == std::io::ErrorKind::NotFound => Ok(()),
109            Err(err) => Err(ApiError::TokenStore(err.to_string())),
110        }
111    }
112}
113
114/// Schwab OAuth 2.0 authorization-code client.
115#[derive(Debug, Clone)]
116pub struct OAuthClient {
117    http: Client,
118    config: ClientConfig,
119    store: TokenStore,
120}
121
122impl OAuthClient {
123    pub fn new(config: ClientConfig) -> Self {
124        let store = TokenStore::new(config.token_dir.clone());
125        let http = Client::builder()
126            .gzip(true)
127            .build()
128            .expect("reqwest client");
129        Self {
130            http,
131            config,
132            store,
133        }
134    }
135
136    pub fn store(&self) -> &TokenStore {
137        &self.store
138    }
139
140    pub fn authorize_url(&self) -> String {
141        let mut url =
142            url::Url::parse(&self.config.oauth_authorize_url).expect("valid oauth authorize url");
143        {
144            let mut pairs = url.query_pairs_mut();
145            pairs.append_pair("client_id", &self.config.app_key);
146            pairs.append_pair("redirect_uri", &self.config.redirect_uri);
147            pairs.append_pair("response_type", "code");
148        }
149        url.to_string()
150    }
151
152    pub async fn exchange_code(&self, code: &str) -> Result<Tokens> {
153        let tokens = self
154            .token_request(&[
155                ("grant_type", "authorization_code"),
156                ("code", code),
157                ("redirect_uri", &self.config.redirect_uri),
158            ])
159            .await?;
160        self.store.save(&tokens).await?;
161        info!("OAuth tokens saved");
162        Ok(tokens)
163    }
164
165    pub async fn refresh(&self) -> Result<Tokens> {
166        let existing = self
167            .store
168            .load()
169            .await?
170            .ok_or_else(|| ApiError::NotAuthenticated("No refresh token on disk".into()))?;
171
172        let tokens = self
173            .token_request(&[
174                ("grant_type", "refresh_token"),
175                ("refresh_token", &existing.refresh_token),
176            ])
177            .await?;
178        self.store.save(&tokens).await?;
179        info!("OAuth tokens refreshed");
180        Ok(tokens)
181    }
182
183    pub async fn ensure_access_token(&self) -> Result<String> {
184        let tokens = match self.store.load().await? {
185            Some(tokens) if !tokens.is_expired() => tokens,
186            Some(_) => self.refresh().await?,
187            None => {
188                return Err(ApiError::NotAuthenticated(
189                    "Run `schwab auth login` to authenticate".into(),
190                ))
191            }
192        };
193        Ok(tokens.access_token)
194    }
195
196    pub async fn status(&self) -> Result<Option<Tokens>> {
197        self.store.load().await
198    }
199
200    pub async fn logout(&self) -> Result<()> {
201        self.store.clear().await
202    }
203
204    async fn token_request(&self, params: &[(&str, &str)]) -> Result<Tokens> {
205        debug!("Requesting OAuth token");
206        let response = self
207            .http
208            .post(&self.config.oauth_token_url)
209            .basic_auth(&self.config.app_key, Some(&self.config.app_secret))
210            .header("Content-Type", "application/x-www-form-urlencoded")
211            .header("Accept", "application/json")
212            .form(params)
213            .send()
214            .await?;
215
216        let status = response.status();
217        let body = response.text().await?;
218        if !status.is_success() {
219            return Err(ApiError::OAuth(format_oauth_error(status.as_u16(), &body)));
220        }
221
222        let parsed: TokenResponse = serde_json::from_str(&body).map_err(|e| {
223            ApiError::OAuth(format!("Token response parse error: {e}; body={body}"))
224        })?;
225        Ok(Tokens {
226            access_token: parsed.access_token,
227            refresh_token: parsed.refresh_token,
228            token_type: parsed.token_type,
229            expires_at: Utc::now() + chrono::Duration::seconds(parsed.expires_in),
230            scope: parsed.scope,
231            obtained_at: Utc::now(),
232        })
233    }
234}
235
236fn format_oauth_error(status: u16, body: &str) -> String {
237    if let Ok(parsed) = serde_json::from_str::<OAuthErrorResponse>(body) {
238        let msg = parsed
239            .error_description
240            .or(parsed.message)
241            .or(parsed.error)
242            .unwrap_or_else(|| body.to_string());
243        return format!("HTTP {status}: {msg}");
244    }
245    if body.chars().all(|c| c.is_ascii() || c.is_whitespace()) {
246        format!("HTTP {status}: {body}")
247    } else {
248        format!(
249            "HTTP {status}: non-text error body ({} bytes). \
250             Common causes: expired authorization code (retry login immediately), \
251             redirect URI mismatch, or invalid app secret.",
252            body.len()
253        )
254    }
255}
256
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a token bundle whose refresh window started `age` ago.
    fn tokens_obtained(age: chrono::Duration) -> Tokens {
        Tokens {
            access_token: "a".into(),
            refresh_token: "r".into(),
            token_type: "Bearer".into(),
            expires_at: Utc::now() + chrono::Duration::hours(1),
            scope: None,
            obtained_at: Utc::now() - age,
        }
    }

    #[test]
    fn formats_json_oauth_error() {
        let body = r#"{"error":"invalid_grant","error_description":"code expired"}"#;
        let msg = format_oauth_error(400, body);
        assert!(msg.contains("code expired"));
    }

    #[test]
    fn falls_back_to_error_code_without_description() {
        let msg = format_oauth_error(400, r#"{"error":"invalid_grant"}"#);
        assert_eq!(msg, "HTTP 400: invalid_grant");
    }

    #[test]
    fn echoes_plain_text_error_body() {
        assert_eq!(format_oauth_error(500, "oops"), "HTTP 500: oops");
    }

    #[test]
    fn summarizes_undecodable_error_body() {
        let msg = format_oauth_error(400, "\u{FFFD}\u{FFFD}");
        assert!(msg.contains("non-text error body"), "msg={msg}");
    }

    #[test]
    fn refresh_expiry_counts_down_from_obtained_at() {
        let tokens = tokens_obtained(chrono::Duration::days(6));
        let remaining = tokens.refresh_expires_in_seconds();
        // One day should remain; allow a few seconds of slack for test runtime.
        assert!(
            (86_400 - 5..=86_400).contains(&remaining),
            "remaining={remaining}"
        );
    }

    #[test]
    fn refresh_expiry_clamps_to_zero() {
        let tokens = tokens_obtained(chrono::Duration::days(8));
        assert_eq!(tokens.refresh_expires_in_seconds(), 0);
    }
}