1use chrono::{DateTime, Utc};
2use serde::Deserialize;
3
/// Lifetime assumed for a refresh token, in seconds (seven days), counted
/// from `Tokens::obtained_at`.
pub const REFRESH_TOKEN_LIFETIME_SECS: i64 = 7 * 24 * 60 * 60;
6
7#[derive(Debug, Clone, serde::Serialize, Deserialize)]
9pub struct Tokens {
10 pub access_token: String,
11 pub refresh_token: String,
12 pub token_type: String,
13 pub expires_at: DateTime<Utc>,
14 pub scope: Option<String>,
15 #[serde(default = "default_obtained_at")]
17 pub obtained_at: DateTime<Utc>,
18}
19
20fn default_obtained_at() -> DateTime<Utc> {
21 Utc::now()
22}
23
24impl Tokens {
25 pub fn is_expired(&self) -> bool {
26 Utc::now() >= self.expires_at
27 }
28
29 pub fn expires_in_seconds(&self) -> i64 {
30 (self.expires_at - Utc::now()).num_seconds().max(0)
31 }
32
33 pub fn refresh_age_seconds(&self) -> i64 {
34 (Utc::now() - self.obtained_at).num_seconds().max(0)
35 }
36
37 pub fn refresh_expires_in_seconds(&self) -> i64 {
38 (REFRESH_TOKEN_LIFETIME_SECS - self.refresh_age_seconds()).max(0)
39 }
40}
41
42#[derive(Debug, Deserialize)]
43struct TokenResponse {
44 access_token: String,
45 refresh_token: String,
46 token_type: String,
47 expires_in: i64,
48 scope: Option<String>,
49}
50
51#[derive(Debug, Deserialize)]
52struct OAuthErrorResponse {
53 error: Option<String>,
54 error_description: Option<String>,
55 message: Option<String>,
56}
57
58use std::path::PathBuf;
59
60use reqwest::Client;
61use tokio::fs;
62use tracing::{debug, info};
63
64use crate::config::ClientConfig;
65use crate::error::{ApiError, Result};
66
67#[derive(Debug, Clone)]
69pub struct TokenStore {
70 path: PathBuf,
71}
72
73impl TokenStore {
74 pub fn new(token_dir: PathBuf) -> Self {
75 Self {
76 path: token_dir.join("tokens.json"),
77 }
78 }
79
80 pub fn path(&self) -> &PathBuf {
81 &self.path
82 }
83
84 pub async fn load(&self) -> Result<Option<Tokens>> {
85 match fs::read_to_string(&self.path).await {
86 Ok(raw) => Ok(Some(serde_json::from_str(&raw)?)),
87 Err(err) if err.kind() == std::io::ErrorKind::NotFound => Ok(None),
88 Err(err) => Err(ApiError::TokenStore(err.to_string())),
89 }
90 }
91
92 pub async fn save(&self, tokens: &Tokens) -> Result<()> {
93 if let Some(parent) = self.path.parent() {
94 fs::create_dir_all(parent)
95 .await
96 .map_err(|e| ApiError::TokenStore(e.to_string()))?;
97 }
98 let raw = serde_json::to_string_pretty(tokens)?;
99 fs::write(&self.path, raw)
100 .await
101 .map_err(|e| ApiError::TokenStore(e.to_string()))?;
102 Ok(())
103 }
104
105 pub async fn clear(&self) -> Result<()> {
106 match fs::remove_file(&self.path).await {
107 Ok(()) => Ok(()),
108 Err(err) if err.kind() == std::io::ErrorKind::NotFound => Ok(()),
109 Err(err) => Err(ApiError::TokenStore(err.to_string())),
110 }
111 }
112}
113
114#[derive(Debug, Clone)]
116pub struct OAuthClient {
117 http: Client,
118 config: ClientConfig,
119 store: TokenStore,
120}
121
122impl OAuthClient {
123 pub fn new(config: ClientConfig) -> Self {
124 let store = TokenStore::new(config.token_dir.clone());
125 let http = Client::builder()
126 .gzip(true)
127 .build()
128 .expect("reqwest client");
129 Self {
130 http,
131 config,
132 store,
133 }
134 }
135
136 pub fn store(&self) -> &TokenStore {
137 &self.store
138 }
139
140 pub fn authorize_url(&self) -> String {
141 let mut url =
142 url::Url::parse(&self.config.oauth_authorize_url).expect("valid oauth authorize url");
143 {
144 let mut pairs = url.query_pairs_mut();
145 pairs.append_pair("client_id", &self.config.app_key);
146 pairs.append_pair("redirect_uri", &self.config.redirect_uri);
147 pairs.append_pair("response_type", "code");
148 }
149 url.to_string()
150 }
151
152 pub async fn exchange_code(&self, code: &str) -> Result<Tokens> {
153 let tokens = self
154 .token_request(&[
155 ("grant_type", "authorization_code"),
156 ("code", code),
157 ("redirect_uri", &self.config.redirect_uri),
158 ])
159 .await?;
160 self.store.save(&tokens).await?;
161 info!("OAuth tokens saved");
162 Ok(tokens)
163 }
164
165 pub async fn refresh(&self) -> Result<Tokens> {
166 let existing = self
167 .store
168 .load()
169 .await?
170 .ok_or_else(|| ApiError::NotAuthenticated("No refresh token on disk".into()))?;
171
172 let tokens = self
173 .token_request(&[
174 ("grant_type", "refresh_token"),
175 ("refresh_token", &existing.refresh_token),
176 ])
177 .await?;
178 self.store.save(&tokens).await?;
179 info!("OAuth tokens refreshed");
180 Ok(tokens)
181 }
182
183 pub async fn ensure_access_token(&self) -> Result<String> {
184 let tokens = match self.store.load().await? {
185 Some(tokens) if !tokens.is_expired() => tokens,
186 Some(_) => self.refresh().await?,
187 None => {
188 return Err(ApiError::NotAuthenticated(
189 "Run `schwab auth login` to authenticate".into(),
190 ))
191 }
192 };
193 Ok(tokens.access_token)
194 }
195
196 pub async fn status(&self) -> Result<Option<Tokens>> {
197 self.store.load().await
198 }
199
200 pub async fn logout(&self) -> Result<()> {
201 self.store.clear().await
202 }
203
204 async fn token_request(&self, params: &[(&str, &str)]) -> Result<Tokens> {
205 debug!("Requesting OAuth token");
206 let response = self
207 .http
208 .post(&self.config.oauth_token_url)
209 .basic_auth(&self.config.app_key, Some(&self.config.app_secret))
210 .header("Content-Type", "application/x-www-form-urlencoded")
211 .header("Accept", "application/json")
212 .form(params)
213 .send()
214 .await?;
215
216 let status = response.status();
217 let body = response.text().await?;
218 if !status.is_success() {
219 return Err(ApiError::OAuth(format_oauth_error(status.as_u16(), &body)));
220 }
221
222 let parsed: TokenResponse = serde_json::from_str(&body).map_err(|e| {
223 ApiError::OAuth(format!("Token response parse error: {e}; body={body}"))
224 })?;
225 Ok(Tokens {
226 access_token: parsed.access_token,
227 refresh_token: parsed.refresh_token,
228 token_type: parsed.token_type,
229 expires_at: Utc::now() + chrono::Duration::seconds(parsed.expires_in),
230 scope: parsed.scope,
231 obtained_at: Utc::now(),
232 })
233 }
234}
235
236fn format_oauth_error(status: u16, body: &str) -> String {
237 if let Ok(parsed) = serde_json::from_str::<OAuthErrorResponse>(body) {
238 let msg = parsed
239 .error_description
240 .or(parsed.message)
241 .or(parsed.error)
242 .unwrap_or_else(|| body.to_string());
243 return format!("HTTP {status}: {msg}");
244 }
245 if body.chars().all(|c| c.is_ascii() || c.is_whitespace()) {
246 format!("HTTP {status}: {body}")
247 } else {
248 format!(
249 "HTTP {status}: non-text error body ({} bytes). \
250 Common causes: expired authorization code (retry login immediately), \
251 redirect URI mismatch, or invalid app secret.",
252 body.len()
253 )
254 }
255}
256
#[cfg(test)]
mod tests {
    use super::*;
    use chrono::Duration;

    #[test]
    fn formats_json_oauth_error() {
        let message = format_oauth_error(
            400,
            r#"{"error":"invalid_grant","error_description":"code expired"}"#,
        );
        assert!(message.contains("code expired"));
    }

    #[test]
    fn refresh_expiry_counts_down_from_obtained_at() {
        let now = Utc::now();
        let six_days_old = Tokens {
            access_token: "a".into(),
            refresh_token: "r".into(),
            token_type: "Bearer".into(),
            expires_at: now + Duration::hours(1),
            scope: None,
            obtained_at: now - Duration::days(6),
        };
        // Six of the seven days have elapsed, so under two days remain.
        assert!(six_days_old.refresh_expires_in_seconds() < 2 * 86400);
    }
}