tda_sdk/
lib.rs

//! SDK for interacting with the TD Ameritrade API.
//!
//! [Developer Documentation](https://developer.tdameritrade.com/)
//!
//! **Important**: Before starting, make sure you have created a developer
//! application (which gives you a client ID) and that you have a valid refresh
//! token. If you need help with either of these steps, refer to the following
//! [API Guide Pages](https://developer.tdameritrade.com/guides):
//!
//! - [Getting Started](https://developer.tdameritrade.com/content/getting-started)
//! - [Simple Auth for Local Apps](https://developer.tdameritrade.com/content/simple-auth-local-apps)
//! - [Authentication FAQ](https://developer.tdameritrade.com/content/authentication-faq)
//!
//! ### Create a Client and Fetch a New Token
//!
//! After creating a [`Client`](struct.Client.html), you need to give it an
//! access token. You can either reuse one stored in your database or on the
//! filesystem, or fetch a new one from the API.
//!
//! ```no_run
//! # use tda_sdk::Client;
//! let mut client = Client::new("CLIENT_ID", "REFRESH_TOKEN", None);
//!
//! let access_token = client.get_access_token().unwrap();
//!
//! // The token response must be converted into a token usable by the client.
//! client.set_access_token(&Some(access_token.into()));
//! ```
//!
//! ### Create a Client and Use an Old Token
//!
//! ```no_run
//! use tda_sdk::{AccessToken, Client};
//!
//! let access_token = AccessToken {
//!     // Timestamp in milliseconds when the token expires.
//!     expires_at: 0,
//!     token: "YOUR_TOKEN_STRING".to_string(),
//!     scope: Vec::new(),
//! };
//!
//! let client = Client::new("CLIENT_ID", "REFRESH_TOKEN", Some(access_token));
//! ```
//!
//! ### Full Example for Fetching All Accounts
//!
//! Once a token has been set, you may call any of the API methods. All request
//! parameters are documented in the [params](params/index.html) module.
//!
//! ```no_run
//! use tda_sdk::{
//!     Client,
//!     params::GetAccountsParams,
//!     responses::SecuritiesAccount,
//! };
//!
//! let mut client = Client::new("CLIENT_ID", "REFRESH_TOKEN", None);
//!
//! let access_token = client.get_access_token().unwrap();
//! client.set_access_token(&Some(access_token.into()));
//!
//! let accounts = client.get_accounts(GetAccountsParams::default()).unwrap();
//!
//! for account in accounts {
//!     match account.securities_account {
//!         SecuritiesAccount::MarginAccount { r#type, account_id, .. } => {
//!             println!("Account ID: {}", account_id);
//!             println!("Account Type: {}", r#type);
//!         }
//!     }
//! }
//! ```
//!
//! ### Token Structure and Expiration
//!
//! This library does not refresh expired tokens; that is up to the user.
//! However, the [`AccessToken`](struct.AccessToken.html) struct provides a
//! `has_expired()` method for checking whether the token has expired.
//!
//! **Note**: The `get_access_token()` response has a different structure than
//! the token expected by the client. Convert it into an
//! [`AccessToken`](struct.AccessToken.html) with `.into()`, as shown below.
//!
//! ```no_run
//! # use tda_sdk::{AccessToken, Client};
//! # let client = Client::new("CLIENT_ID", "REFRESH_TOKEN", None);
//! let access_token: AccessToken = client.get_access_token().unwrap().into();
//!
//! if access_token.has_expired() {
//!     panic!("Token has expired!");
//! }
//! ```
91
92#[macro_use] extern crate serde;
93
94pub mod params;
95pub mod responses;
96
97use chrono::Utc;
98use params::{
99    GetAccountParams,
100    GetAccountsParams,
101    GetMoversParams,
102    GetPriceHistoryParams,
103};
104use thiserror::Error;
105
106use std::io;
107
/// Root URL of the TDA REST API; endpoint paths such as `/accounts` are
/// appended to it.
pub const TDA_API_BASE: &str = "https://api.tdameritrade.com/v1";
110
111/// Client for interacting with the TDA API.
112///
113/// Most API methods will panic if an access token is not set.
114#[derive(Debug)]
115pub struct Client {
116    pub access_token: Option<AccessToken>,
117    client_id: String,
118    refresh_token: String,
119}
120
121impl<'a> Client {
122    /// Create a new client with a client ID and refresh token.
123    pub fn new(client_id: &'a str, refresh_token: &'a str, access_token: Option<AccessToken>) -> Self {
124        Self {
125            access_token,
126            client_id: client_id.to_string(),
127            refresh_token: refresh_token.to_string(),
128        }
129    }
130
131    /// Set the internal access token of the client.
132    pub fn set_access_token(&mut self, access_token: &Option<AccessToken>) -> &mut Self {
133        self.access_token = access_token.clone();
134
135        self
136    }
137
138    /// Get a new access token from the API.
139    pub fn get_access_token(&self) -> Result<responses::AccessTokenResponse, ClientError> {
140        let url = format!("{}/oauth2/token", TDA_API_BASE);
141
142        let response = ureq::post(&url)
143            .send_form(&[
144                ("grant_type", "refresh_token"),
145                ("refresh_token", &self.refresh_token),
146                ("client_id", &self.client_id),
147           ]);
148        let status = response.status();
149        let body = response.into_string().map_err(ClientError::ReadResponse)?;
150
151        if status != 200 {
152            return Err(ClientError::NotHttpOk(status, body))
153        }
154
155        serde_json::from_str(&body).map_err(ClientError::ParseResponse)
156    }
157
158    /// Account balances, positions, and orders for a specific account.
159    ///
160    /// [API documentation](https://developer.tdameritrade.com/account-access/apis/get/accounts/%7BaccountId%7D-0)
161    pub fn get_account(&self, account_id: &'a str, params: GetAccountParams) -> Result<responses::Account, ClientError> {
162        if self.access_token.is_none() {
163            panic!("Client does not have a token set!");
164        }
165
166        let access_token = self.access_token.as_ref().unwrap();
167        let url = format!("{}/accounts/{}", TDA_API_BASE, account_id);
168
169        let mut request = ureq::get(&url);
170        request.set("Authorization", &format!("Bearer {}", access_token.token));
171
172        if let Some(fields) = params.fields {
173            request.query("fields", &fields);
174        }
175
176        let response = request.call();
177        let status = response.status();
178        let body = response.into_string().map_err(ClientError::ReadResponse)?;
179
180        if status != 200 {
181            return Err(ClientError::NotHttpOk(status, body));
182        }
183
184        serde_json::from_str(&body).map_err(ClientError::ParseResponse)
185    }
186
187    /// Account balances, positions, and orders for all linked accounts.
188    ///
189    /// [Api Documentation](https://developer.tdameritrade.com/account-access/apis/get/accounts-0)
190    pub fn get_accounts(&self, params: GetAccountsParams) -> Result<Vec<responses::Account>, ClientError> {
191        if self.access_token.is_none() {
192            panic!("Client does not have a token set!");
193        }
194
195        let access_token = self.access_token.as_ref().unwrap();
196        let url = format!("{}/accounts", TDA_API_BASE);
197
198        let mut request = ureq::get(&url);
199        request.set("Authorization", &format!("Bearer {}", access_token.token));
200
201        if let Some(fields) = params.fields {
202            request.query("fields", &fields);
203        }
204
205        let response = request.call();
206        let status = response.status();
207        let body = response.into_string().map_err(ClientError::ReadResponse)?;
208
209        if status != 200 {
210            return Err(ClientError::NotHttpOk(status, body));
211        }
212
213        serde_json::from_str(&body).map_err(ClientError::ParseResponse)
214    }
215
216    /// Top 10 (up or down) movers by value or percent for a particular market
217    ///
218    /// [API Documentation](https://developer.tdameritrade.com/movers/apis/get/marketdata/%7Bindex%7D/movers)
219    pub fn get_movers(&self, index: &'a str, params: GetMoversParams) -> Result<Vec<responses::Mover>, ClientError> {
220        if self.access_token.is_none() {
221            panic!("Client does not have a token set!");
222        }
223
224        let access_token = self.access_token.as_ref().unwrap();
225        let url = format!("{}/marketdata/{}/movers", TDA_API_BASE, index);
226
227        let mut request = ureq::get(&url);
228        request.set("Authorization", &format!("Bearer {}", access_token.token));
229
230        if let Some(direction) = params.direction {
231            request.query("direction", &direction);
232        }
233
234        if let Some(change) = params.change {
235            request.query("change", &change);
236        }
237
238        let response = request.call();
239        let status = response.status();
240        let body = response.into_string().map_err(ClientError::ReadResponse)?;
241
242        if status != 200 {
243            return Err(ClientError::NotHttpOk(status, body));
244        }
245
246        serde_json::from_str(&body).map_err(ClientError::ParseResponse)
247    }
248
249    /// Get price history for a symbol
250    ///
251    /// [API Documentation](https://developer.tdameritrade.com/price-history/apis/get/marketdata/%7Bsymbol%7D/pricehistory)
252    pub fn get_price_history(&self, symbol: &str, params: GetPriceHistoryParams) -> Result<responses::GetPriceHistoryResponse, ClientError> {
253        if self.access_token.is_none() {
254            panic!("Client does not have a token set!");
255        }
256
257        let access_token = self.access_token.as_ref().unwrap();
258        let url = format!("{}/marketdata/{}/pricehistory", TDA_API_BASE, symbol);
259
260        let mut request = ureq::get(&url);
261        request.set("Authorization", &format!("Bearer {}", access_token.token));
262
263        if let Some(period_type) = params.period_type {
264            request.query("periodType", &period_type);
265        }
266
267        if let Some(period) = params.period {
268            request.query("period", &period);
269        }
270
271        if let Some(frequency_type) = params.frequency_type {
272            request.query("frequencyType", &frequency_type);
273        }
274
275        if let Some(frequency) = params.frequency {
276            request.query("frequency", &frequency);
277        }
278
279        if let Some(end_date) = params.end_date {
280            request.query("endDate", &end_date);
281        }
282
283        if let Some(start_date) = params.start_date {
284            request.query("startDate", &start_date);
285        }
286
287        if let Some(need_extended_hours_data) = params.need_extended_hours_data {
288            request.query("needExtendedHoursData", &need_extended_hours_data.to_string());
289        }
290
291        let response = request.call();
292        let status = response.status();
293        let body = response.into_string().map_err(ClientError::ReadResponse)?;
294
295        if status != 200 {
296            return Err(ClientError::NotHttpOk(status, body));
297        }
298
299        serde_json::from_str(&body).map_err(ClientError::ParseResponse)
300    }
301}
302
303/// API access token.
304#[derive(Clone, Debug, Serialize)]
305pub struct AccessToken {
306    /// Timestamp in milliseconds when the token expires.
307    pub expires_at: i64,
308    pub scope: Vec<String>,
309    pub token: String,
310}
311
312impl From<responses::AccessTokenResponse> for AccessToken {
313    fn from(response: responses::AccessTokenResponse) -> Self {
314        let now = Utc::now().naive_utc().timestamp_millis();
315
316        Self {
317            token: response.access_token,
318            expires_at: now + response.expires_in,
319            scope: response.scope.split(' ').map(|v| v.to_string()).collect(),
320        }
321    }
322}
323
324impl AccessToken {
325    /// Return true if the access token has expired.
326    #[allow(dead_code)]
327    pub fn has_expired(&self) -> bool {
328        self.expires_at >= Utc::now().naive_utc().timestamp_millis()
329    }
330}
331
332/// Represents all possible errors the `Client` might encounter.
333#[derive(Debug, Error)]
334pub enum ClientError {
335    /// Received a non-200 HTTP status code from the server.
336    #[error("Received a {0} HTTP code: {1}")]
337    NotHttpOk(u16, String),
338
339    /// Was unable to parse the response into a usable struct.
340    #[error("Failed to parse response: {0}")]
341    ParseResponse(#[from] serde_json::error::Error),
342
343    /// Was unable to read the response string.
344    #[error("Failed to read response string: {0}")]
345    ReadResponse(#[from] io::Error),
346}
347
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;

    /// Configuration file path.
    const CONFIG_FILE: &str = "./.test.env";

    /// Local token file path.
    const TOKEN_FILE_PATH: &str = "./.token.json";

    /// Configuration settings found in `.test.env`.
    #[derive(Debug)]
    struct Config {
        tda_client_id: String,
        tda_refresh_token: String,
    }

    /// On-disk shape of a cached token. It mirrors the fields `save_token`
    /// serializes from an `AccessToken`.
    #[derive(Deserialize)]
    struct StoredToken {
        expires_at: i64,
        scope: Vec<String>,
        token: String,
    }

    /// Get a client with a working access token.
    ///
    /// Reuses the token cached in `.token.json` while it is still valid.
    /// Otherwise it fetches a fresh token and caches that for later runs.
    fn get_working_client() -> Client {
        let config = load_config();
        let mut client = Client::new(&config.tda_client_id, &config.tda_refresh_token, None);

        let token = load_token()
            .filter(|token| !token.has_expired())
            .unwrap_or_else(|| {
                let token: AccessToken = client.get_access_token().unwrap().into();
                save_token(&token);

                token
            });

        client.set_access_token(&Some(token));

        client
    }

    /// Load config settings from `.test.env`.
    fn load_config() -> Config {
        dotenv::from_path(CONFIG_FILE).ok();

        Config {
            tda_client_id: dotenv::var("TDA_CLIENT_ID").unwrap(),
            tda_refresh_token: dotenv::var("TDA_REFRESH_TOKEN").unwrap(),
        }
    }

    /// Load the cached token from the local `.token.json` file.
    ///
    /// Returns `None` if the file is missing, unreadable, or not in the format
    /// written by `save_token`. The caller then falls back to a fresh token;
    /// the cache is best-effort.
    fn load_token() -> Option<AccessToken> {
        let contents = fs::read_to_string(TOKEN_FILE_PATH).ok()?;
        let stored: StoredToken = serde_json::from_str(&contents).ok()?;

        Some(AccessToken {
            expires_at: stored.expires_at,
            scope: stored.scope,
            token: stored.token,
        })
    }

    /// Save a token to the local `.token.json` file.
    ///
    /// Panics if the token cannot be serialized or the file cannot be written.
    fn save_token(token: &AccessToken) {
        fs::write(TOKEN_FILE_PATH, serde_json::to_string(token).unwrap()).unwrap();
    }

    #[test]
    fn get_access_token() {
        let config = load_config();
        let client = Client::new(&config.tda_client_id, &config.tda_refresh_token, None);

        let token = client.get_access_token().unwrap();

        assert_ne!(token.access_token.len(), 0);
    }

    #[test]
    fn set_access_token() {
        let config = load_config();
        let mut client = Client::new(&config.tda_client_id, &config.tda_refresh_token, None);

        let response = client.get_access_token().unwrap();
        let new_access_token = response.access_token.clone();

        client.set_access_token(&Some(response.into()));

        assert_eq!(new_access_token, client.access_token.unwrap().token);
    }

    #[test]
    fn get_account() {
        let client = get_working_client();

        let accounts = client.get_accounts(GetAccountsParams::default()).unwrap();

        match &accounts.get(0).unwrap().securities_account {
            responses::SecuritiesAccount::MarginAccount { account_id, .. } => {
                client.get_account(account_id, GetAccountParams::default()).unwrap();
            }
        }
    }

    #[test]
    fn get_accounts() {
        let client = get_working_client();

        let accounts = client.get_accounts(GetAccountsParams::default()).unwrap();

        assert_ne!(accounts.len(), 0);
    }

    #[test]
    fn get_movers() {
        let client = get_working_client();

        let _movers = client.get_movers("$DJI", GetMoversParams::default()).unwrap();

        // TODO: Make sure test the response is parsing, when we get data again.
    }

    #[test]
    fn get_price_history() {
        let client = get_working_client();

        let response = client.get_price_history("AAPL", GetPriceHistoryParams::default()).unwrap();

        assert_ne!(response.candles.len(), 0);
    }
}