tradestation_rs/
client.rs

1use crate::token::RefreshedToken;
2use crate::{Error, Token};
3use reqwest::{header, Response};
4use serde::Serialize;
5use serde_json::Value;
6use std::collections::HashMap;
7
8#[derive(Clone, Debug)]
9/// TradeStation API Client
10pub struct Client {
11    http_client: reqwest::Client,
12    client_id: String,
13    client_secret: String,
14    /// Bearer Token for TradeStation's API
15    pub token: Token,
16    /// The base url used for all endpoints.
17    ///
18    /// NOTE: You should leave this default unless you
19    /// specifically want to use a custom address for
20    /// testing or mocking purposes.
21    pub base_url: String,
22}
23impl Client {
24    /// Send an HTTP request to TradeStation's API, with automatic
25    /// token refreshing near, at, or after auth token expiration.
26    ///
27    /// NOTE: You should use `Client::post()` or `Client::get()` in favor of this method.
28    pub async fn send_request<F, T>(&mut self, request_fn: F) -> Result<Response, Error>
29    where
30        F: Fn(&Token) -> T,
31        T: std::future::Future<Output = Result<Response, reqwest::Error>>,
32    {
33        match request_fn(&self.token).await {
34            Ok(resp) => {
35                // Check if the client gets a 401 unauthorized to try and re auth the client
36                // this happens when auth token expires.
37                if resp.status() == reqwest::StatusCode::UNAUTHORIZED {
38                    // Refresh the clients token
39                    self.refresh_token().await?;
40
41                    // Retry sending the request to TradeStation's API
42                    let retry_response = request_fn(&self.token).await?;
43                    Ok(retry_response)
44                } else {
45                    Ok(resp)
46                }
47            }
48            Err(e) => Err(Error::Request(e)),
49        }
50    }
51
52    /// Send a POST request from your `Client` to TradeStation's API
53    pub async fn post<T: Serialize>(
54        &mut self,
55        endpoint: &str,
56        payload: &T,
57    ) -> Result<Response, Error> {
58        let url = format!("https://{}/{}", self.base_url, endpoint);
59        let resp = self
60            .clone()
61            .send_request(|token| {
62                self.http_client
63                    .post(&url)
64                    .header("Content-Type", "application/json")
65                    .header(
66                        header::AUTHORIZATION,
67                        format!("Bearer {}", token.access_token),
68                    )
69                    .json(payload)
70                    .send()
71            })
72            .await?;
73
74        Ok(resp)
75    }
76
77    /// Send a GET request from your `Client` to TradeStation's API
78    pub async fn get(&mut self, endpoint: &str) -> Result<Response, Error> {
79        let url = format!("https://{}/{}", self.base_url, endpoint);
80        let resp = self
81            .clone()
82            .send_request(|token| {
83                self.http_client
84                    .get(&url)
85                    .header(
86                        header::AUTHORIZATION,
87                        format!("Bearer {}", token.access_token),
88                    )
89                    .send()
90            })
91            .await?;
92
93        Ok(resp)
94    }
95
96    /// Send a PUT request from your `Client` to TradeStation's API
97    pub async fn put<T: Serialize>(
98        &mut self,
99        endpoint: &str,
100        payload: &T,
101    ) -> Result<Response, Error> {
102        let url = format!("https://{}/{}", self.base_url, endpoint);
103        let resp = self
104            .clone()
105            .send_request(|token| {
106                self.http_client
107                    .put(&url)
108                    .header("Content-Type", "application/json")
109                    .header(
110                        header::AUTHORIZATION,
111                        format!("Bearer {}", token.access_token),
112                    )
113                    .json(&payload)
114                    .send()
115            })
116            .await?;
117
118        Ok(resp)
119    }
120
121    /// Send a DELETE request from your `Client` to TradeStation's API
122    pub async fn delete(&mut self, endpoint: &str) -> Result<Response, Error> {
123        let url = format!("https://{}/{}", self.base_url, endpoint);
124        let resp = self
125            .clone()
126            .send_request(|token| {
127                self.http_client
128                    .delete(&url)
129                    .header("Content-Type", "application/json")
130                    .header(
131                        header::AUTHORIZATION,
132                        format!("Bearer {}", token.access_token),
133                    )
134                    .send()
135            })
136            .await?;
137
138        Ok(resp)
139    }
140
141    /// Start a stream from the TradeStation API to the `Client`
142    ///
143    /// NOTE: You need to provide a processing function for handeling the stream chunks
144    pub async fn stream<F>(&mut self, endpoint: &str, mut process_chunk: F) -> Result<(), Error>
145    where
146        F: FnMut(Value) -> Result<(), Error>,
147    {
148        let url = format!("https://{}/{}", self.base_url, endpoint);
149
150        let mut resp = self
151            .clone()
152            .send_request(|token| {
153                self.http_client
154                    .get(&url)
155                    .header(
156                        reqwest::header::AUTHORIZATION,
157                        format!("Bearer {}", token.access_token),
158                    )
159                    .send()
160            })
161            .await?;
162
163        if !resp.status().is_success() {
164            return Err(Error::StreamIssue(format!(
165                "Request failed with status: {}",
166                resp.status()
167            )));
168        }
169
170        let mut buffer = String::new();
171        while let Some(chunk) = resp.chunk().await? {
172            let chunk_str = std::str::from_utf8(&chunk).unwrap_or("");
173            buffer.push_str(chunk_str);
174
175            while let Some(pos) = buffer.find("\n") {
176                let json_str = buffer[..pos].trim().to_string();
177                buffer = buffer[pos + 1..].to_string();
178                if json_str.is_empty() {
179                    continue;
180                }
181
182                match serde_json::from_str::<Value>(&json_str) {
183                    Ok(json_value) => {
184                        if let Err(e) = process_chunk(json_value) {
185                            if matches!(e, Error::StopStream) {
186                                return Ok(());
187                            } else {
188                                return Err(e);
189                            }
190                        }
191                    }
192                    Err(e) => {
193                        return Err(Error::Json(e));
194                    }
195                }
196            }
197        }
198
199        // Handle any leftover data in the buffer
200        if !buffer.trim().is_empty() {
201            match serde_json::from_str::<Value>(&buffer) {
202                Ok(json_value) => {
203                    if let Err(e) = process_chunk(json_value) {
204                        if matches!(e, Error::StopStream) {
205                            return Ok(());
206                        } else {
207                            return Err(e);
208                        }
209                    }
210                }
211                Err(e) => {
212                    return Err(Error::Json(e));
213                }
214            }
215        }
216
217        Ok(())
218    }
219
220    /// Refresh your clients bearer token used for authentication
221    /// with TradeStation's API.
222    pub async fn refresh_token(&mut self) -> Result<(), Error> {
223        let form_data: HashMap<String, String> = HashMap::from([
224            ("grant_type".into(), "refresh_token".into()),
225            ("client_id".into(), self.client_id.clone()),
226            ("client_secret".into(), self.client_secret.clone()),
227            ("refresh_token".into(), self.token.refresh_token.clone()),
228            ("redirect_uri".into(), "http://localhost:8080/".into()),
229        ]);
230
231        let new_token = self
232            .http_client
233            .post("https://signin.tradestation.com/oauth/token")
234            .header("Content-Type", "application/x-www-form-urlencoded")
235            .form(&form_data)
236            .send()
237            .await?
238            .json::<RefreshedToken>()
239            .await?;
240
241        // Update the clients token
242        self.token = Token {
243            refresh_token: self.token.refresh_token.clone(),
244            access_token: new_token.access_token,
245            id_token: new_token.id_token,
246            scope: new_token.scope,
247            token_type: new_token.token_type,
248            expires_in: new_token.expires_in,
249        };
250
251        Ok(())
252    }
253}
254
/// Entry point for constructing a `Client` one step at a time.
///
/// Call `ClientBuilder::new()` to obtain the first builder step.
#[derive(Default, Debug)]
pub struct ClientBuilder;
258
/// Typestate marker for the first builder step: credentials (or a testing url)
/// are set next.
#[derive(Default, Debug)]
pub struct Step1;

/// Typestate marker for the second builder step: an auth token is set next.
#[derive(Default, Debug)]
pub struct Step2;

/// Typestate marker for the final builder step: ready to `build()` a `Client`.
#[derive(Default, Debug)]
pub struct Step3;
268
269#[derive(Debug, Default)]
270/// Phantom Type for compile time enforcement
271/// on the order of builder steps used.
272pub struct ClientBuilderStep<CurrentStep> {
273    _current_step: CurrentStep,
274    http_client: Option<reqwest::Client>,
275    client_id: Option<String>,
276    client_secret: Option<String>,
277    token: Option<Token>,
278    testing_url: Option<String>,
279}
280
281impl ClientBuilder {
282    #[allow(clippy::new_ret_no_self)]
283    /// Instantiate a new instance of `ClientBuilder`
284    pub fn new() -> Result<ClientBuilderStep<Step1>, Error> {
285        Ok(ClientBuilderStep {
286            _current_step: Step1,
287            http_client: Some(reqwest::Client::new()),
288            ..Default::default()
289        })
290    }
291}
292impl ClientBuilderStep<Step1> {
293    /// Set your client id/key and secret
294    pub fn credentials(
295        self,
296        client_id: &str,
297        client_secret: &str,
298    ) -> Result<ClientBuilderStep<Step2>, Error> {
299        Ok(ClientBuilderStep {
300            _current_step: Step2,
301            http_client: Some(self.http_client.unwrap()),
302            client_id: Some(client_id.into()),
303            client_secret: Some(client_secret.into()),
304            ..Default::default()
305        })
306    }
307
308    /// Set the testing url for the client to use for sending
309    /// ALL the requests to your test/mock server instead of
310    /// the default TradeStation API url.
311    ///
312    /// NOTE: This should ONLY be set for testing and
313    /// mocking purposes. This should NOT be set used
314    /// with a production `Client`.
315    pub fn testing_url(self, url: &str) -> ClientBuilderStep<Step3> {
316        ClientBuilderStep {
317            _current_step: Step3,
318            http_client: self.http_client,
319            client_id: self.client_id,
320            client_secret: self.client_secret,
321            token: self.token,
322            testing_url: Some(url.into()),
323        }
324    }
325}
326impl ClientBuilderStep<Step2> {
327    /// Use your authorization code to get and set auth token
328    pub async fn authorize(
329        self,
330        authorization_code: &str,
331    ) -> Result<ClientBuilderStep<Step3>, Error> {
332        // NOTE: These unwraps are panic safe due to type invariant
333        // with compile time enforced order of steps for `ClientBuilderStep`
334        let http_client = self.http_client.unwrap();
335        let client_id = self.client_id.as_ref().unwrap();
336        let client_secret = self.client_secret.as_ref().unwrap();
337
338        // Send HTTP request to TradeStation API to get auth token
339        let form_data = HashMap::from([
340            ("grant_type", "authorization_code"),
341            ("client_id", client_id),
342            ("client_secret", client_secret),
343            ("code", authorization_code),
344            ("redirect_uri", "http://localhost:8080/"),
345        ]);
346        let token = http_client
347            .post("https://signin.tradestation.com/oauth/token")
348            .header("Content-Type", "application/x-www-form-urlencoded")
349            .form(&form_data)
350            .send()
351            .await?
352            .json::<Token>()
353            .await?;
354
355        Ok(ClientBuilderStep {
356            _current_step: Step3,
357            http_client: Some(http_client),
358            client_id: self.client_id,
359            client_secret: self.client_secret,
360            token: Some(token),
361            testing_url: self.testing_url,
362        })
363    }
364
365    /// Set the current `Token` for the `Client` to use
366    pub fn token(self, token: Token) -> Result<ClientBuilderStep<Step3>, Error> {
367        Ok(ClientBuilderStep {
368            _current_step: Step3,
369            http_client: self.http_client,
370            client_id: self.client_id,
371            client_secret: self.client_secret,
372            token: Some(token),
373            testing_url: self.testing_url,
374        })
375    }
376}
377impl ClientBuilderStep<Step3> {
378    /// Finish building into a `Client`.
379    pub async fn build(self) -> Result<Client, Error> {
380        let http_client = self.http_client.unwrap();
381
382        if self.testing_url.is_none() {
383            let client_id = self.client_id.unwrap();
384            let client_secret = self.client_secret.unwrap();
385            let token = self.token.unwrap();
386            let base_url = "api.tradestation.com/v3".to_string();
387
388            Ok(Client {
389                http_client,
390                client_id,
391                client_secret,
392                token,
393                base_url,
394            })
395        } else {
396            let client_id = "NO_CLIENT_ID_IN_TEST_MODE".to_string();
397            let client_secret = "NO_CLIENT_SECRET_IN_TEST_MODE".to_string();
398            let token = Token {
399                access_token: String::from("NO_ACCESS_TOKEN_IN_TEST_MODE"),
400                refresh_token: String::from("NO_REFRESH_TOKEN_IN_TEST_MODE"),
401                id_token: String::from("NO_ID_TOKEN_IN_TEST_MODE"),
402                token_type: String::from("TESTING"),
403                scope: String::from("NO SCOPES IN TEST MODE"),
404                expires_in: 9999,
405            };
406            let base_url = self
407                .testing_url
408                .expect("Some `Client::testing_url` to be set due to invariant check.");
409
410            Ok(Client {
411                http_client,
412                client_id,
413                client_secret,
414                token,
415                base_url,
416            })
417        }
418    }
419}