Skip to main content

schwab_api/
auth.rs

1use chrono::{DateTime, Utc};
2use serde::Deserialize;
3
4/// OAuth token bundle persisted on disk.
5#[derive(Debug, Clone, serde::Serialize, Deserialize)]
6pub struct Tokens {
7    pub access_token: String,
8    pub refresh_token: String,
9    pub token_type: String,
10    pub expires_at: DateTime<Utc>,
11    pub scope: Option<String>,
12}
13
14impl Tokens {
15    pub fn is_expired(&self) -> bool {
16        Utc::now() >= self.expires_at
17    }
18
19    pub fn expires_in_seconds(&self) -> i64 {
20        (self.expires_at - Utc::now()).num_seconds().max(0)
21    }
22}
23
24#[derive(Debug, Deserialize)]
25struct TokenResponse {
26    access_token: String,
27    refresh_token: String,
28    token_type: String,
29    expires_in: i64,
30    scope: Option<String>,
31}
32
33#[derive(Debug, Deserialize)]
34struct OAuthErrorResponse {
35    error: Option<String>,
36    error_description: Option<String>,
37    message: Option<String>,
38}
39
40use std::path::PathBuf;
41
42use reqwest::Client;
43use tokio::fs;
44use tracing::{debug, info};
45
46use crate::config::ClientConfig;
47use crate::error::{ApiError, Result};
48
/// File-backed OAuth token storage.
///
/// Holds only the path of the JSON file; every operation goes to the
/// filesystem, so clones of a store refer to the same file.
#[derive(Clone, Debug)]
pub struct TokenStore {
    /// Location of the token file; `TokenStore::new` sets it to
    /// `<token_dir>/tokens.json`.
    path: PathBuf,
}
54
55impl TokenStore {
56    pub fn new(token_dir: PathBuf) -> Self {
57        Self {
58            path: token_dir.join("tokens.json"),
59        }
60    }
61
62    pub fn path(&self) -> &PathBuf {
63        &self.path
64    }
65
66    pub async fn load(&self) -> Result<Option<Tokens>> {
67        match fs::read_to_string(&self.path).await {
68            Ok(raw) => Ok(Some(serde_json::from_str(&raw)?)),
69            Err(err) if err.kind() == std::io::ErrorKind::NotFound => Ok(None),
70            Err(err) => Err(ApiError::TokenStore(err.to_string())),
71        }
72    }
73
74    pub async fn save(&self, tokens: &Tokens) -> Result<()> {
75        if let Some(parent) = self.path.parent() {
76            fs::create_dir_all(parent)
77                .await
78                .map_err(|e| ApiError::TokenStore(e.to_string()))?;
79        }
80        let raw = serde_json::to_string_pretty(tokens)?;
81        fs::write(&self.path, raw)
82            .await
83            .map_err(|e| ApiError::TokenStore(e.to_string()))?;
84        Ok(())
85    }
86
87    pub async fn clear(&self) -> Result<()> {
88        match fs::remove_file(&self.path).await {
89            Ok(()) => Ok(()),
90            Err(err) if err.kind() == std::io::ErrorKind::NotFound => Ok(()),
91            Err(err) => Err(ApiError::TokenStore(err.to_string())),
92        }
93    }
94}
95
96/// Schwab OAuth 2.0 authorization-code client.
97#[derive(Debug, Clone)]
98pub struct OAuthClient {
99    http: Client,
100    config: ClientConfig,
101    store: TokenStore,
102}
103
104impl OAuthClient {
105    pub fn new(config: ClientConfig) -> Self {
106        let store = TokenStore::new(config.token_dir.clone());
107        let http = Client::builder()
108            .gzip(true)
109            .build()
110            .expect("reqwest client");
111        Self {
112            http,
113            config,
114            store,
115        }
116    }
117
118    pub fn store(&self) -> &TokenStore {
119        &self.store
120    }
121
122    pub fn authorize_url(&self) -> String {
123        let mut url = url::Url::parse(&self.config.oauth_authorize_url)
124            .expect("valid oauth authorize url");
125        {
126            let mut pairs = url.query_pairs_mut();
127            pairs.append_pair("client_id", &self.config.app_key);
128            pairs.append_pair("redirect_uri", &self.config.redirect_uri);
129            pairs.append_pair("response_type", "code");
130        }
131        url.to_string()
132    }
133
134    pub async fn exchange_code(&self, code: &str) -> Result<Tokens> {
135        let tokens = self
136            .token_request(&[
137                ("grant_type", "authorization_code"),
138                ("code", code),
139                ("redirect_uri", &self.config.redirect_uri),
140            ])
141            .await?;
142        self.store.save(&tokens).await?;
143        info!("OAuth tokens saved");
144        Ok(tokens)
145    }
146
147    pub async fn refresh(&self) -> Result<Tokens> {
148        let existing = self
149            .store
150            .load()
151            .await?
152            .ok_or_else(|| ApiError::NotAuthenticated("No refresh token on disk".into()))?;
153
154        let tokens = self
155            .token_request(&[
156                ("grant_type", "refresh_token"),
157                ("refresh_token", &existing.refresh_token),
158            ])
159            .await?;
160        self.store.save(&tokens).await?;
161        info!("OAuth tokens refreshed");
162        Ok(tokens)
163    }
164
165    pub async fn ensure_access_token(&self) -> Result<String> {
166        let tokens = match self.store.load().await? {
167            Some(tokens) if !tokens.is_expired() => tokens,
168            Some(_) => self.refresh().await?,
169            None => {
170                return Err(ApiError::NotAuthenticated(
171                    "Run `schwab auth login` to authenticate".into(),
172                ))
173            }
174        };
175        Ok(tokens.access_token)
176    }
177
178    pub async fn status(&self) -> Result<Option<Tokens>> {
179        self.store.load().await
180    }
181
182    pub async fn logout(&self) -> Result<()> {
183        self.store.clear().await
184    }
185
186    async fn token_request(&self, params: &[(&str, &str)]) -> Result<Tokens> {
187        debug!("Requesting OAuth token");
188        let response = self
189            .http
190            .post(&self.config.oauth_token_url)
191            .basic_auth(&self.config.app_key, Some(&self.config.app_secret))
192            .header("Content-Type", "application/x-www-form-urlencoded")
193            .header("Accept", "application/json")
194            .form(params)
195            .send()
196            .await?;
197
198        let status = response.status();
199        let body = response.text().await?;
200        if !status.is_success() {
201            return Err(ApiError::OAuth(format_oauth_error(status.as_u16(), &body)));
202        }
203
204        let parsed: TokenResponse = serde_json::from_str(&body).map_err(|e| {
205            ApiError::OAuth(format!("Token response parse error: {e}; body={body}"))
206        })?;
207        Ok(Tokens {
208            access_token: parsed.access_token,
209            refresh_token: parsed.refresh_token,
210            token_type: parsed.token_type,
211            expires_at: Utc::now() + chrono::Duration::seconds(parsed.expires_in),
212            scope: parsed.scope,
213        })
214    }
215}
216
217fn format_oauth_error(status: u16, body: &str) -> String {
218    if let Ok(parsed) = serde_json::from_str::<OAuthErrorResponse>(body) {
219        let msg = parsed
220            .error_description
221            .or(parsed.message)
222            .or(parsed.error)
223            .unwrap_or_else(|| body.to_string());
224        return format!("HTTP {status}: {msg}");
225    }
226    if body.chars().all(|c| c.is_ascii() || c.is_whitespace()) {
227        format!("HTTP {status}: {body}")
228    } else {
229        format!(
230            "HTTP {status}: non-text error body ({} bytes). \
231             Common causes: expired authorization code (retry login immediately), \
232             redirect URI mismatch, or invalid app secret.",
233            body.len()
234        )
235    }
236}
237
#[cfg(test)]
mod tests {
    use super::format_oauth_error;

    /// A JSON error payload should surface its `error_description` in the
    /// formatted message.
    #[test]
    fn formats_json_oauth_error() {
        let formatted = format_oauth_error(
            400,
            r#"{"error":"invalid_grant","error_description":"code expired"}"#,
        );
        let mentions_description = formatted.contains("code expired");
        assert!(mentions_description);
    }
}
248}