1#[macro_use] extern crate serde;
93
94pub mod params;
95pub mod responses;
96
97use chrono::Utc;
98use params::{
99 GetAccountParams,
100 GetAccountsParams,
101 GetMoversParams,
102 GetPriceHistoryParams,
103};
104use thiserror::Error;
105
106use std::io;
107
/// Base URL of the TD Ameritrade REST API (version 1).
///
/// Every request URL built by [`Client`] is this prefix followed by an endpoint path such as
/// `/oauth2/token`, `/accounts` or `/marketdata/{symbol}/pricehistory`.
pub const TDA_API_BASE: &str = "https://api.tdameritrade.com/v1";
110
111#[derive(Debug)]
115pub struct Client {
116 pub access_token: Option<AccessToken>,
117 client_id: String,
118 refresh_token: String,
119}
120
121impl<'a> Client {
122 pub fn new(client_id: &'a str, refresh_token: &'a str, access_token: Option<AccessToken>) -> Self {
124 Self {
125 access_token,
126 client_id: client_id.to_string(),
127 refresh_token: refresh_token.to_string(),
128 }
129 }
130
131 pub fn set_access_token(&mut self, access_token: &Option<AccessToken>) -> &mut Self {
133 self.access_token = access_token.clone();
134
135 self
136 }
137
138 pub fn get_access_token(&self) -> Result<responses::AccessTokenResponse, ClientError> {
140 let url = format!("{}/oauth2/token", TDA_API_BASE);
141
142 let response = ureq::post(&url)
143 .send_form(&[
144 ("grant_type", "refresh_token"),
145 ("refresh_token", &self.refresh_token),
146 ("client_id", &self.client_id),
147 ]);
148 let status = response.status();
149 let body = response.into_string().map_err(ClientError::ReadResponse)?;
150
151 if status != 200 {
152 return Err(ClientError::NotHttpOk(status, body))
153 }
154
155 serde_json::from_str(&body).map_err(ClientError::ParseResponse)
156 }
157
158 pub fn get_account(&self, account_id: &'a str, params: GetAccountParams) -> Result<responses::Account, ClientError> {
162 if self.access_token.is_none() {
163 panic!("Client does not have a token set!");
164 }
165
166 let access_token = self.access_token.as_ref().unwrap();
167 let url = format!("{}/accounts/{}", TDA_API_BASE, account_id);
168
169 let mut request = ureq::get(&url);
170 request.set("Authorization", &format!("Bearer {}", access_token.token));
171
172 if let Some(fields) = params.fields {
173 request.query("fields", &fields);
174 }
175
176 let response = request.call();
177 let status = response.status();
178 let body = response.into_string().map_err(ClientError::ReadResponse)?;
179
180 if status != 200 {
181 return Err(ClientError::NotHttpOk(status, body));
182 }
183
184 serde_json::from_str(&body).map_err(ClientError::ParseResponse)
185 }
186
187 pub fn get_accounts(&self, params: GetAccountsParams) -> Result<Vec<responses::Account>, ClientError> {
191 if self.access_token.is_none() {
192 panic!("Client does not have a token set!");
193 }
194
195 let access_token = self.access_token.as_ref().unwrap();
196 let url = format!("{}/accounts", TDA_API_BASE);
197
198 let mut request = ureq::get(&url);
199 request.set("Authorization", &format!("Bearer {}", access_token.token));
200
201 if let Some(fields) = params.fields {
202 request.query("fields", &fields);
203 }
204
205 let response = request.call();
206 let status = response.status();
207 let body = response.into_string().map_err(ClientError::ReadResponse)?;
208
209 if status != 200 {
210 return Err(ClientError::NotHttpOk(status, body));
211 }
212
213 serde_json::from_str(&body).map_err(ClientError::ParseResponse)
214 }
215
216 pub fn get_movers(&self, index: &'a str, params: GetMoversParams) -> Result<Vec<responses::Mover>, ClientError> {
220 if self.access_token.is_none() {
221 panic!("Client does not have a token set!");
222 }
223
224 let access_token = self.access_token.as_ref().unwrap();
225 let url = format!("{}/marketdata/{}/movers", TDA_API_BASE, index);
226
227 let mut request = ureq::get(&url);
228 request.set("Authorization", &format!("Bearer {}", access_token.token));
229
230 if let Some(direction) = params.direction {
231 request.query("direction", &direction);
232 }
233
234 if let Some(change) = params.change {
235 request.query("change", &change);
236 }
237
238 let response = request.call();
239 let status = response.status();
240 let body = response.into_string().map_err(ClientError::ReadResponse)?;
241
242 if status != 200 {
243 return Err(ClientError::NotHttpOk(status, body));
244 }
245
246 serde_json::from_str(&body).map_err(ClientError::ParseResponse)
247 }
248
249 pub fn get_price_history(&self, symbol: &str, params: GetPriceHistoryParams) -> Result<responses::GetPriceHistoryResponse, ClientError> {
253 if self.access_token.is_none() {
254 panic!("Client does not have a token set!");
255 }
256
257 let access_token = self.access_token.as_ref().unwrap();
258 let url = format!("{}/marketdata/{}/pricehistory", TDA_API_BASE, symbol);
259
260 let mut request = ureq::get(&url);
261 request.set("Authorization", &format!("Bearer {}", access_token.token));
262
263 if let Some(period_type) = params.period_type {
264 request.query("periodType", &period_type);
265 }
266
267 if let Some(period) = params.period {
268 request.query("period", &period);
269 }
270
271 if let Some(frequency_type) = params.frequency_type {
272 request.query("frequencyType", &frequency_type);
273 }
274
275 if let Some(frequency) = params.frequency {
276 request.query("frequency", &frequency);
277 }
278
279 if let Some(end_date) = params.end_date {
280 request.query("endDate", &end_date);
281 }
282
283 if let Some(start_date) = params.start_date {
284 request.query("startDate", &start_date);
285 }
286
287 if let Some(need_extended_hours_data) = params.need_extended_hours_data {
288 request.query("needExtendedHoursData", &need_extended_hours_data.to_string());
289 }
290
291 let response = request.call();
292 let status = response.status();
293 let body = response.into_string().map_err(ClientError::ReadResponse)?;
294
295 if status != 200 {
296 return Err(ClientError::NotHttpOk(status, body));
297 }
298
299 serde_json::from_str(&body).map_err(ClientError::ParseResponse)
300 }
301}
302
303#[derive(Clone, Debug, Serialize)]
305pub struct AccessToken {
306 pub expires_at: i64,
308 pub scope: Vec<String>,
309 pub token: String,
310}
311
312impl From<responses::AccessTokenResponse> for AccessToken {
313 fn from(response: responses::AccessTokenResponse) -> Self {
314 let now = Utc::now().naive_utc().timestamp_millis();
315
316 Self {
317 token: response.access_token,
318 expires_at: now + response.expires_in,
319 scope: response.scope.split(' ').map(|v| v.to_string()).collect(),
320 }
321 }
322}
323
324impl AccessToken {
325 #[allow(dead_code)]
327 pub fn has_expired(&self) -> bool {
328 self.expires_at >= Utc::now().naive_utc().timestamp_millis()
329 }
330}
331
332#[derive(Debug, Error)]
334pub enum ClientError {
335 #[error("Received a {0} HTTP code: {1}")]
337 NotHttpOk(u16, String),
338
339 #[error("Failed to parse response: {0}")]
341 ParseResponse(#[from] serde_json::error::Error),
342
343 #[error("Failed to read response string: {0}")]
345 ReadResponse(#[from] io::Error),
346}
347
#[cfg(test)]
mod tests {
    //! Integration tests against the live TD Ameritrade API.
    //!
    //! Credentials are read through `dotenv` after loading `./.test.env`, and network access is
    //! required. An access token is cached in `./.token.json` so that consecutive tests can
    //! reuse it instead of requesting a new one each time.

    use super::*;
    use std::fs;

    const CONFIG_FILE: &str = "./.test.env";

    const TOKEN_FILE_PATH: &str = "./.token.json";

    /// A cached token is refreshed this long before its recorded expiry so it cannot lapse
    /// in the middle of a test.
    const TOKEN_REFRESH_MARGIN_MS: i64 = 60_000;

    #[derive(Debug)]
    struct Config {
        tda_client_id: String,
        tda_refresh_token: String,
    }

    /// On-disk shape of a cached token: mirrors the JSON that `save_token` writes for an
    /// [`AccessToken`] (`expires_at`, `scope`, `token`).
    #[derive(Deserialize)]
    struct StoredToken {
        expires_at: i64,
        scope: Vec<String>,
        token: String,
    }

    /// Builds a client holding a usable access token: the cached one if it is readable and
    /// not about to expire, otherwise a freshly requested one (which is then cached).
    fn get_working_client() -> Client {
        let config = load_config();
        let mut client = Client::new(&config.tda_client_id, &config.tda_refresh_token, None);

        let token = load_token()
            .filter(|token| !expires_soon(token))
            .unwrap_or_else(|| {
                let token: AccessToken = client.get_access_token().unwrap().into();
                save_token(&token);
                token
            });

        client.set_access_token(&Some(token));

        client
    }

    /// `true` when `token` expires within `TOKEN_REFRESH_MARGIN_MS` of now (or already has).
    fn expires_soon(token: &AccessToken) -> bool {
        token.expires_at.saturating_sub(TOKEN_REFRESH_MARGIN_MS) <= Utc::now().timestamp_millis()
    }

    fn load_config() -> Config {
        dotenv::from_path(CONFIG_FILE).ok();

        Config {
            tda_client_id: dotenv::var("TDA_CLIENT_ID").unwrap(),
            tda_refresh_token: dotenv::var("TDA_REFRESH_TOKEN").unwrap(),
        }
    }

    /// Reads the cached token; `None` when the file is missing or unparsable (e.g. a partial
    /// write from a concurrently running test), in which case the caller fetches a new one.
    fn load_token() -> Option<AccessToken> {
        let raw = fs::read_to_string(TOKEN_FILE_PATH).ok()?;
        let stored: StoredToken = serde_json::from_str(&raw).ok()?;

        Some(AccessToken {
            expires_at: stored.expires_at,
            scope: stored.scope,
            token: stored.token,
        })
    }

    fn save_token(token: &AccessToken) {
        fs::write(TOKEN_FILE_PATH, serde_json::to_string(token).unwrap()).unwrap();
    }

    #[test]
    fn get_access_token() {
        let config = load_config();
        let client = Client::new(&config.tda_client_id, &config.tda_refresh_token, None);

        let token = client.get_access_token().unwrap();

        assert_ne!(token.access_token.len(), 0);
    }

    #[test]
    fn set_access_token() {
        let config = load_config();
        let mut client = Client::new(&config.tda_client_id, &config.tda_refresh_token, None);

        let response = client.get_access_token().unwrap();
        let new_access_token = response.access_token.clone();

        client.set_access_token(&Some(response.into()));

        assert_eq!(new_access_token, client.access_token.unwrap().token);
    }

    #[test]
    fn get_account() {
        let client = get_working_client();

        let accounts = client.get_accounts(GetAccountsParams::default()).unwrap();

        match &accounts.get(0).unwrap().securities_account {
            responses::SecuritiesAccount::MarginAccount { account_id, .. } => {
                client.get_account(account_id, GetAccountParams::default()).unwrap();
            }
        }
    }

    #[test]
    fn get_accounts() {
        let client = get_working_client();

        let accounts = client.get_accounts(GetAccountsParams::default()).unwrap();

        assert_ne!(accounts.len(), 0);
    }

    #[test]
    fn get_movers() {
        let client = get_working_client();

        let _movers = client.get_movers("$DJI", GetMoversParams::default()).unwrap();
    }

    #[test]
    fn get_price_history() {
        let client = get_working_client();

        let response = client.get_price_history("AAPL", GetPriceHistoryParams::default()).unwrap();

        assert_ne!(response.candles.len(), 0);
    }
}