1use chrono::{DateTime, Utc};
2use serde::Deserialize;
3
4#[derive(Debug, Clone, serde::Serialize, Deserialize)]
6pub struct Tokens {
7 pub access_token: String,
8 pub refresh_token: String,
9 pub token_type: String,
10 pub expires_at: DateTime<Utc>,
11 pub scope: Option<String>,
12}
13
14impl Tokens {
15 pub fn is_expired(&self) -> bool {
16 Utc::now() >= self.expires_at
17 }
18
19 pub fn expires_in_seconds(&self) -> i64 {
20 (self.expires_at - Utc::now()).num_seconds().max(0)
21 }
22}
23
24#[derive(Debug, Deserialize)]
25struct TokenResponse {
26 access_token: String,
27 refresh_token: String,
28 token_type: String,
29 expires_in: i64,
30 scope: Option<String>,
31}
32
33#[derive(Debug, Deserialize)]
34struct OAuthErrorResponse {
35 error: Option<String>,
36 error_description: Option<String>,
37 message: Option<String>,
38}
39
40use std::path::PathBuf;
41
42use reqwest::Client;
43use tokio::fs;
44use tracing::{debug, info};
45
46use crate::config::ClientConfig;
47use crate::error::{ApiError, Result};
48
/// Handle to the JSON file in which [`Tokens`] are persisted.
///
/// Cloning the handle only copies the path; all clones refer to the same file.
#[derive(Clone, Debug)]
pub struct TokenStore {
    /// Full path of the token file (not the containing directory).
    path: PathBuf,
}
54
55impl TokenStore {
56 pub fn new(token_dir: PathBuf) -> Self {
57 Self {
58 path: token_dir.join("tokens.json"),
59 }
60 }
61
62 pub fn path(&self) -> &PathBuf {
63 &self.path
64 }
65
66 pub async fn load(&self) -> Result<Option<Tokens>> {
67 match fs::read_to_string(&self.path).await {
68 Ok(raw) => Ok(Some(serde_json::from_str(&raw)?)),
69 Err(err) if err.kind() == std::io::ErrorKind::NotFound => Ok(None),
70 Err(err) => Err(ApiError::TokenStore(err.to_string())),
71 }
72 }
73
74 pub async fn save(&self, tokens: &Tokens) -> Result<()> {
75 if let Some(parent) = self.path.parent() {
76 fs::create_dir_all(parent)
77 .await
78 .map_err(|e| ApiError::TokenStore(e.to_string()))?;
79 }
80 let raw = serde_json::to_string_pretty(tokens)?;
81 fs::write(&self.path, raw)
82 .await
83 .map_err(|e| ApiError::TokenStore(e.to_string()))?;
84 Ok(())
85 }
86
87 pub async fn clear(&self) -> Result<()> {
88 match fs::remove_file(&self.path).await {
89 Ok(()) => Ok(()),
90 Err(err) if err.kind() == std::io::ErrorKind::NotFound => Ok(()),
91 Err(err) => Err(ApiError::TokenStore(err.to_string())),
92 }
93 }
94}
95
96#[derive(Debug, Clone)]
98pub struct OAuthClient {
99 http: Client,
100 config: ClientConfig,
101 store: TokenStore,
102}
103
104impl OAuthClient {
105 pub fn new(config: ClientConfig) -> Self {
106 let store = TokenStore::new(config.token_dir.clone());
107 let http = Client::builder()
108 .gzip(true)
109 .build()
110 .expect("reqwest client");
111 Self {
112 http,
113 config,
114 store,
115 }
116 }
117
118 pub fn store(&self) -> &TokenStore {
119 &self.store
120 }
121
122 pub fn authorize_url(&self) -> String {
123 let mut url = url::Url::parse(&self.config.oauth_authorize_url)
124 .expect("valid oauth authorize url");
125 {
126 let mut pairs = url.query_pairs_mut();
127 pairs.append_pair("client_id", &self.config.app_key);
128 pairs.append_pair("redirect_uri", &self.config.redirect_uri);
129 pairs.append_pair("response_type", "code");
130 }
131 url.to_string()
132 }
133
134 pub async fn exchange_code(&self, code: &str) -> Result<Tokens> {
135 let tokens = self
136 .token_request(&[
137 ("grant_type", "authorization_code"),
138 ("code", code),
139 ("redirect_uri", &self.config.redirect_uri),
140 ])
141 .await?;
142 self.store.save(&tokens).await?;
143 info!("OAuth tokens saved");
144 Ok(tokens)
145 }
146
147 pub async fn refresh(&self) -> Result<Tokens> {
148 let existing = self
149 .store
150 .load()
151 .await?
152 .ok_or_else(|| ApiError::NotAuthenticated("No refresh token on disk".into()))?;
153
154 let tokens = self
155 .token_request(&[
156 ("grant_type", "refresh_token"),
157 ("refresh_token", &existing.refresh_token),
158 ])
159 .await?;
160 self.store.save(&tokens).await?;
161 info!("OAuth tokens refreshed");
162 Ok(tokens)
163 }
164
165 pub async fn ensure_access_token(&self) -> Result<String> {
166 let tokens = match self.store.load().await? {
167 Some(tokens) if !tokens.is_expired() => tokens,
168 Some(_) => self.refresh().await?,
169 None => {
170 return Err(ApiError::NotAuthenticated(
171 "Run `schwab auth login` to authenticate".into(),
172 ))
173 }
174 };
175 Ok(tokens.access_token)
176 }
177
178 pub async fn status(&self) -> Result<Option<Tokens>> {
179 self.store.load().await
180 }
181
182 pub async fn logout(&self) -> Result<()> {
183 self.store.clear().await
184 }
185
186 async fn token_request(&self, params: &[(&str, &str)]) -> Result<Tokens> {
187 debug!("Requesting OAuth token");
188 let response = self
189 .http
190 .post(&self.config.oauth_token_url)
191 .basic_auth(&self.config.app_key, Some(&self.config.app_secret))
192 .header("Content-Type", "application/x-www-form-urlencoded")
193 .header("Accept", "application/json")
194 .form(params)
195 .send()
196 .await?;
197
198 let status = response.status();
199 let body = response.text().await?;
200 if !status.is_success() {
201 return Err(ApiError::OAuth(format_oauth_error(status.as_u16(), &body)));
202 }
203
204 let parsed: TokenResponse = serde_json::from_str(&body).map_err(|e| {
205 ApiError::OAuth(format!("Token response parse error: {e}; body={body}"))
206 })?;
207 Ok(Tokens {
208 access_token: parsed.access_token,
209 refresh_token: parsed.refresh_token,
210 token_type: parsed.token_type,
211 expires_at: Utc::now() + chrono::Duration::seconds(parsed.expires_in),
212 scope: parsed.scope,
213 })
214 }
215}
216
217fn format_oauth_error(status: u16, body: &str) -> String {
218 if let Ok(parsed) = serde_json::from_str::<OAuthErrorResponse>(body) {
219 let msg = parsed
220 .error_description
221 .or(parsed.message)
222 .or(parsed.error)
223 .unwrap_or_else(|| body.to_string());
224 return format!("HTTP {status}: {msg}");
225 }
226 if body.chars().all(|c| c.is_ascii() || c.is_whitespace()) {
227 format!("HTTP {status}: {body}")
228 } else {
229 format!(
230 "HTTP {status}: non-text error body ({} bytes). \
231 Common causes: expired authorization code (retry login immediately), \
232 redirect URI mismatch, or invalid app secret.",
233 body.len()
234 )
235 }
236}
237
#[cfg(test)]
mod tests {
    use super::format_oauth_error;

    /// A JSON error body must surface its `error_description` in the formatted message.
    #[test]
    fn formats_json_oauth_error() {
        let payload = r#"{"error":"invalid_grant","error_description":"code expired"}"#;
        let rendered = format_oauth_error(400, payload);
        assert!(rendered.contains("code expired"));
    }
}