Skip to main content

sharesight_reqwest/
lib.rs

1use std::sync::Arc;
2
3use log::warn;
4use reqwest_middleware::reqwest;
5use serde::{de::DeserializeOwned, Deserialize, Serialize};
6use sharesight_types::{
7    ApiEndpoint, CashAccountsList, CashAccountsListCashAccountsSuccess, CashAccountsListParameters,
8    CashAccountsListSuccess, PortfolioList, PortfolioListParameters,
9    PortfolioListPortfoliosSuccess, PortfolioListSuccess,
10};
11
12pub use aliri_tokens::TokenWithLifetime;
13
14pub struct Client {
15    client: reqwest_middleware::ClientWithMiddleware,
16    host: Arc<str>,
17}
18
19impl Client {
20    pub async fn new(
21        user_credentials_file: std::path::PathBuf,
22        client_credentials_file: std::path::PathBuf,
23    ) -> Result<Self, SharesightReqwestError> {
24        use predicates::prelude::PredicateBooleanExt;
25
26        let client = reqwest::Client::default();
27
28        let client_credentials = serde_json::from_reader::<_, ClientCredentials>(
29            std::fs::File::open(client_credentials_file)?,
30        )?;
31        let credentials =
32            std::sync::Arc::new(aliri_tokens::sources::oauth2::dto::ClientCredentials {
33                client_id: client_credentials.client_id,
34                client_secret: client_credentials.client_secret,
35            });
36
37        let credentials = aliri_tokens::sources::oauth2::dto::ClientCredentialsWithAudience {
38            credentials,
39            audience: "".into(),
40        };
41
42        let fallback = aliri_tokens::sources::oauth2::ClientCredentialsTokenSource::new(
43            client.clone(),
44            reqwest::Url::parse(&format!("https://{}/oauth2/token", client_credentials.host))
45                .unwrap(),
46            credentials,
47            aliri_tokens::TokenLifetimeConfig::default(),
48        );
49
50        let file_source = aliri_tokens::sources::file::FileTokenSource::new(user_credentials_file);
51
52        let token_source = aliri_tokens::sources::cache::CachedTokenSource::new(fallback)
53            .with_cache("file", file_source);
54
55        let token_watcher = aliri_tokens::TokenWatcher::spawn_from_token_source(
56            token_source,
57            aliri_tokens::jitter::RandomEarlyJitter::new(aliri_clock::DurationSecs(60)),
58            aliri_tokens::backoff::ErrorBackoffConfig::default(),
59        )
60        .await?;
61        let client = reqwest_middleware::ClientBuilder::new(client)
62            .with(
63                aliri_reqwest::AccessTokenMiddleware::new(token_watcher).with_predicate(
64                    aliri_reqwest::HttpsOnly
65                        .and(aliri_reqwest::ExactHostMatch::new(&client_credentials.host)),
66                ),
67            )
68            .build();
69
70        Ok(Client {
71            host: client_credentials.host.into(),
72            client,
73        })
74    }
75
76    pub async fn execute<'a, T: ApiEndpoint<'a>, U: DeserializeOwned>(
77        &'a self,
78        parameters: &'a T::Parameters,
79    ) -> Result<U, SharesightReqwestError> {
80        let method = match T::HTTP_METHOD {
81            sharesight_types::ApiHttpMethod::Get => reqwest::Method::GET,
82            sharesight_types::ApiHttpMethod::Post => reqwest::Method::POST,
83            sharesight_types::ApiHttpMethod::Patch => reqwest::Method::PATCH,
84            sharesight_types::ApiHttpMethod::Put => reqwest::Method::PUT,
85            sharesight_types::ApiHttpMethod::Delete => reqwest::Method::DELETE,
86        };
87        let resp = self
88            .client
89            .request(method, T::url(&self.host, parameters).to_string())
90            .json(parameters)
91            .send()
92            .await?;
93
94        if resp.status().is_success() {
95            let full = resp.bytes().await?;
96
97            let slice = if full.is_empty() {
98                b"null".as_slice()
99            } else {
100                &full
101            };
102
103            Ok(serde_json::from_slice(slice).map_err(|e| {
104                if let Ok(s) = std::str::from_utf8(&full) {
105                    warn!("Error deserializing json: {:?}\n{}", e, s);
106                } else {
107                    warn!("Error deserializing json - not valid utf-8: {:?}", e);
108                }
109                e
110            })?)
111        } else {
112            Err(SharesightReqwestError::Http(
113                resp.url().clone(),
114                resp.status(),
115                resp.text().await?,
116            ))
117        }
118    }
119
120    pub async fn build_portfolio_index(
121        &self,
122    ) -> Result<NameIndex<PortfolioListPortfoliosSuccess>, SharesightReqwestError> {
123        let mut index = NameIndex::default();
124        let parameters = PortfolioListParameters {
125            consolidated: Some(true),
126            instrument_id: None,
127        };
128        let PortfolioListSuccess { portfolios, .. } =
129            self.execute::<PortfolioList, _>(&parameters).await?;
130        index.extend(portfolios);
131
132        let parameters = PortfolioListParameters {
133            consolidated: Some(false),
134            instrument_id: None,
135        };
136        let PortfolioListSuccess { portfolios, .. } =
137            self.execute::<PortfolioList, _>(&parameters).await?;
138        index.extend(portfolios);
139
140        Ok(index)
141    }
142
143    pub async fn build_cash_account_index(
144        &self,
145        portfolio: &PortfolioListPortfoliosSuccess,
146    ) -> Result<NameIndex<CashAccountsListCashAccountsSuccess>, SharesightReqwestError> {
147        let mut index = NameIndex::default();
148
149        let account_params = CashAccountsListParameters { date: None };
150        let CashAccountsListSuccess { cash_accounts, .. } =
151            self.execute::<CashAccountsList, _>(&account_params).await?;
152        let cash_accounts = cash_accounts
153            .into_iter()
154            .filter(|a| a.portfolio_id == portfolio.id);
155
156        index.extend(cash_accounts);
157
158        Ok(index)
159    }
160}
161
162#[derive(Debug, thiserror::Error)]
163pub enum SharesightReqwestError {
164    #[error("Http request returned non-success status code\n{0} {1}\n{2}")]
165    Http(reqwest::Url, reqwest::StatusCode, String),
166    #[error("Http error occurred\n{0:?}")]
167    Reqwest(#[from] reqwest::Error),
168    #[error("Http error occurred\n{0:?}")]
169    ReqwestMiddleware(#[from] reqwest_middleware::Error),
170    #[error("Deserialize error occurred\n{0:?}")]
171    Deserialize(#[from] serde_json::Error),
172    #[error("Token request error occurred\n{0:?}")]
173    TokenRequestError(#[from] aliri_tokens::sources::oauth2::TokenRequestError),
174    #[error("IO error occurred\n{0:?}")]
175    IoError(#[from] std::io::Error),
176}
177
178#[derive(Debug)]
179pub struct NameIndex<T>(Vec<T>);
180
181impl<T> Default for NameIndex<T> {
182    fn default() -> Self {
183        Self(Vec::new())
184    }
185}
186
187impl<T> NameIndex<T> {
188    fn extend(&mut self, portfolios: impl IntoIterator<Item = T>) {
189        for portfolio in portfolios {
190            self.push(portfolio);
191        }
192    }
193    fn push(&mut self, portfolio: T) {
194        self.0.push(portfolio);
195    }
196}
197
198impl<T: NameIndexItem> NameIndex<T> {
199    pub fn find<'a>(&'a self, name: &str) -> Option<&'a T> {
200        self.0.iter().find(|p| p.name() == name)
201    }
202
203    pub fn names<'a>(&'a self) -> impl DoubleEndedIterator<Item = &'a str> + 'a {
204        self.0.iter().map(|p| p.name())
205    }
206
207    pub fn log_error_for(&self, name: &str) {
208        eprint!("Unknown {}: {}, ", T::TYPE_NAME_SINGULAR, name);
209
210        let mut names = self.names();
211
212        match (names.next(), names.next_back()) {
213            (Some(name_start), Some(name_end)) => {
214                eprint!("the {} are: {}", T::TYPE_NAME_PLURAL, name_start);
215                for name in names {
216                    eprint!(", {}", name);
217                }
218                eprintln!(" or {}", name_end);
219            }
220            (Some(name), None) => {
221                eprintln!("the only {} is: {}", T::TYPE_NAME_SINGULAR, name);
222            }
223            (None, None) => {
224                eprintln!("there are no {}", T::TYPE_NAME_PLURAL);
225            }
226            _ => unreachable!(),
227        }
228    }
229}
230
231pub trait NameIndexItem {
232    const TYPE_NAME_SINGULAR: &'static str;
233    const TYPE_NAME_PLURAL: &'static str;
234
235    fn name(&self) -> &str;
236}
237
238impl NameIndexItem for PortfolioListPortfoliosSuccess {
239    const TYPE_NAME_SINGULAR: &'static str = "portfolio";
240    const TYPE_NAME_PLURAL: &'static str = "portfolios";
241
242    fn name(&self) -> &str {
243        &self.name
244    }
245}
246
247impl NameIndexItem for CashAccountsListCashAccountsSuccess {
248    const TYPE_NAME_SINGULAR: &'static str = "cash account";
249    const TYPE_NAME_PLURAL: &'static str = "cash accounts";
250
251    fn name(&self) -> &str {
252        &self.name
253    }
254}
255
256#[derive(Debug, Deserialize, Serialize)]
257pub struct ClientCredentials {
258    pub host: String,
259    pub client_id: aliri_tokens::ClientId,
260    pub client_secret: aliri_tokens::ClientSecret,
261}