torn_api/
executor.rs

1use std::future::Future;
2
3use futures::{Stream, StreamExt};
4use http::{header::AUTHORIZATION, HeaderMap, HeaderValue};
5use serde::Deserialize;
6
7use crate::request::{ApiRequest, ApiResponse, IntoRequest};
8#[cfg(feature = "scopes")]
9use crate::scopes::{
10    BulkFactionScope, BulkForumScope, BulkMarketScope, BulkRacingScope, BulkTornScope,
11    BulkUserScope, FactionScope, ForumScope, MarketScope, RacingScope, TornScope, UserScope,
12};
13
14pub trait Executor: Sized {
15    type Error: From<serde_json::Error> + From<crate::ApiError> + Send;
16
17    fn execute<R>(
18        self,
19        request: R,
20    ) -> impl Future<Output = (R::Discriminant, Result<ApiResponse, Self::Error>)> + Send
21    where
22        R: IntoRequest;
23
24    fn fetch<R>(self, request: R) -> impl Future<Output = Result<R::Response, Self::Error>> + Send
25    where
26        R: IntoRequest,
27    {
28        // HACK: workaround for not using `async` in trait declaration.
29        // The future is `Send` but `&self` might not be.
30        let fut = self.execute(request);
31        async {
32            let resp = fut.await.1?;
33
34            let bytes = resp.body.unwrap();
35
36            if bytes.starts_with(br#"{"error":{"#) {
37                #[derive(Deserialize)]
38                struct ErrorBody<'a> {
39                    code: u16,
40                    error: &'a str,
41                }
42                #[derive(Deserialize)]
43                struct ErrorContainer<'a> {
44                    #[serde(borrow)]
45                    error: ErrorBody<'a>,
46                }
47
48                let error: ErrorContainer = serde_json::from_slice(&bytes)?;
49                return Err(crate::ApiError::new(error.error.code, error.error.error).into());
50            }
51
52            let resp = serde_json::from_slice(&bytes)?;
53
54            Ok(resp)
55        }
56    }
57}
58
59pub trait BulkExecutor<'e>: 'e + Sized {
60    type Error: From<serde_json::Error> + From<crate::ApiError> + Send;
61
62    fn execute<R>(
63        self,
64        requests: impl IntoIterator<Item = R>,
65    ) -> impl Stream<Item = (R::Discriminant, Result<ApiResponse, Self::Error>)>
66    where
67        R: IntoRequest;
68
69    fn fetch_many<R>(
70        self,
71        requests: impl IntoIterator<Item = R>,
72    ) -> impl Stream<Item = (R::Discriminant, Result<R::Response, Self::Error>)>
73    where
74        R: IntoRequest,
75    {
76        self.execute(requests).map(|(d, r)| {
77            let r = match r {
78                Ok(r) => r,
79                Err(why) => return (d, Err(why)),
80            };
81            let bytes = r.body.unwrap();
82
83            if bytes.starts_with(br#"{"error":{"#) {
84                #[derive(Deserialize)]
85                struct ErrorBody<'a> {
86                    code: u16,
87                    error: &'a str,
88                }
89                #[derive(Deserialize)]
90                struct ErrorContainer<'a> {
91                    #[serde(borrow)]
92                    error: ErrorBody<'a>,
93                }
94
95                let error: ErrorContainer = match serde_json::from_slice(&bytes) {
96                    Ok(error) => error,
97                    Err(why) => return (d, Err(why.into())),
98                };
99                return (
100                    d,
101                    Err(crate::ApiError::new(error.error.code, error.error.error).into()),
102                );
103            }
104
105            let resp = match serde_json::from_slice(&bytes) {
106                Ok(resp) => resp,
107                Err(why) => return (d, Err(why.into())),
108            };
109
110            (d, Ok(resp))
111        })
112    }
113}
114
#[cfg(feature = "scopes")]
/// Shortcuts that turn a single-request [`Executor`] into one of the typed request scopes.
///
/// Every method takes the executor by value and returns the matching scope type.
pub trait ExecutorExt
where
    Self: Executor + Sized,
{
    /// Returns a [`UserScope`] over this executor.
    fn user(self) -> UserScope<Self>;

    /// Returns a [`FactionScope`] over this executor.
    fn faction(self) -> FactionScope<Self>;

    /// Returns a [`TornScope`] over this executor.
    fn torn(self) -> TornScope<Self>;

    /// Returns a [`MarketScope`] over this executor.
    fn market(self) -> MarketScope<Self>;

    /// Returns a [`RacingScope`] over this executor.
    fn racing(self) -> RacingScope<Self>;

    /// Returns a [`ForumScope`] over this executor.
    fn forum(self) -> ForumScope<Self>;
}
129
#[cfg(feature = "scopes")]
/// Blanket implementation: every [`Executor`] gets the scope shortcuts, each of which
/// simply hands the executor to the scope's `new` constructor.
impl<T: Executor + Sized> ExecutorExt for T {
    fn user(self) -> UserScope<Self> {
        UserScope::<T>::new(self)
    }

    fn faction(self) -> FactionScope<Self> {
        FactionScope::<T>::new(self)
    }

    fn torn(self) -> TornScope<Self> {
        TornScope::<T>::new(self)
    }

    fn market(self) -> MarketScope<Self> {
        MarketScope::<T>::new(self)
    }

    fn racing(self) -> RacingScope<Self> {
        RacingScope::<T>::new(self)
    }

    fn forum(self) -> ForumScope<Self> {
        ForumScope::<T>::new(self)
    }
}
159
#[cfg(feature = "scopes")]
/// Shortcuts that turn a [`BulkExecutor`] into one of the typed bulk request scopes.
///
/// Every method takes the executor by value and returns the matching bulk scope type.
pub trait BulkExecutorExt<'e>
where
    Self: BulkExecutor<'e> + Sized,
{
    /// Returns a [`BulkUserScope`] over this executor.
    fn user_bulk(self) -> BulkUserScope<'e, Self>;

    /// Returns a [`BulkFactionScope`] over this executor.
    fn faction_bulk(self) -> BulkFactionScope<'e, Self>;

    /// Returns a [`BulkTornScope`] over this executor.
    fn torn_bulk(self) -> BulkTornScope<'e, Self>;

    /// Returns a [`BulkMarketScope`] over this executor.
    fn market_bulk(self) -> BulkMarketScope<'e, Self>;

    /// Returns a [`BulkRacingScope`] over this executor.
    fn racing_bulk(self) -> BulkRacingScope<'e, Self>;

    /// Returns a [`BulkForumScope`] over this executor.
    fn forum_bulk(self) -> BulkForumScope<'e, Self>;
}
174
#[cfg(feature = "scopes")]
/// Blanket implementation: every [`BulkExecutor`] gets the bulk scope shortcuts, each of
/// which simply hands the executor to the scope's `new` constructor.
impl<'e, T: BulkExecutor<'e> + Sized> BulkExecutorExt<'e> for T {
    fn user_bulk(self) -> BulkUserScope<'e, Self> {
        BulkUserScope::<'e, T>::new(self)
    }

    fn faction_bulk(self) -> BulkFactionScope<'e, Self> {
        BulkFactionScope::<'e, T>::new(self)
    }

    fn torn_bulk(self) -> BulkTornScope<'e, Self> {
        BulkTornScope::<'e, T>::new(self)
    }

    fn market_bulk(self) -> BulkMarketScope<'e, Self> {
        BulkMarketScope::<'e, T>::new(self)
    }

    fn racing_bulk(self) -> BulkRacingScope<'e, Self> {
        BulkRacingScope::<'e, T>::new(self)
    }

    fn forum_bulk(self) -> BulkForumScope<'e, Self> {
        BulkForumScope::<'e, T>::new(self)
    }
}
204
205pub struct ReqwestClient(reqwest::Client);
206
207impl ReqwestClient {
208    pub fn new(api_key: &str) -> Self {
209        let mut headers = HeaderMap::with_capacity(1);
210        headers.insert(
211            AUTHORIZATION,
212            HeaderValue::from_str(&format!("ApiKey {api_key}")).unwrap(),
213        );
214
215        let client = reqwest::Client::builder()
216            .default_headers(headers)
217            .brotli(true)
218            .build()
219            .unwrap();
220
221        Self(client)
222    }
223}
224
225impl ReqwestClient {
226    async fn execute_api_request(&self, request: ApiRequest) -> Result<ApiResponse, crate::Error> {
227        let url = request.url();
228
229        let response = self.0.get(url).send().await?;
230        let status = response.status();
231        let body = response.bytes().await.ok();
232
233        Ok(ApiResponse { status, body })
234    }
235}
236
237impl Executor for &ReqwestClient {
238    type Error = crate::Error;
239
240    async fn execute<R>(self, request: R) -> (R::Discriminant, Result<ApiResponse, Self::Error>)
241    where
242        R: IntoRequest,
243    {
244        let (d, request) = request.into_request();
245        (d, self.execute_api_request(request).await)
246    }
247}
248
249impl<'e> BulkExecutor<'e> for &'e ReqwestClient {
250    type Error = crate::Error;
251
252    fn execute<R>(
253        self,
254        requests: impl IntoIterator<Item = R>,
255    ) -> impl Stream<Item = (R::Discriminant, Result<ApiResponse, Self::Error>)>
256    where
257        R: IntoRequest,
258    {
259        futures::stream::iter(requests)
260            .map(move |r| <Self as Executor>::execute(self, r))
261            .buffer_unordered(25)
262    }
263}
264
// Every test here needs the feature-gated `scopes` module (including its `test_client`
// helper), so gate the whole module rather than each test; otherwise the imports below
// fail to resolve when testing without the `scopes` feature.
#[cfg(all(test, feature = "scopes"))]
mod test {
    use crate::{scopes::test::test_client, ApiError, Error};

    use super::*;

    /// Requesting a faction with an invalid id must surface the API's error object as
    /// `ApiError::IncorrectIdEntityRelation`.
    #[tokio::test]
    async fn api_error() {
        let client = test_client().await;

        let resp = client.faction().basic_for_id((-1).into(), |b| b).await;

        match resp {
            Err(Error::Api(ApiError::IncorrectIdEntityRelation)) => (),
            other => panic!("Expected incorrect id entity relation error, got {other:?}"),
        }
    }

    /// A bulk request yields exactly one successful response per requested id.
    #[tokio::test]
    async fn bulk_request() {
        let client = test_client().await;

        let stream = client
            .faction_bulk()
            .basic_for_id(vec![19.into(), 89.into()], |b| b);

        let responses: Vec<_> = stream.collect().await;
        assert_eq!(responses.len(), 2, "expected one response per requested id");

        // Responses arrive in completion order, so check each one independently.
        for (_id, basic) in responses {
            basic.unwrap();
        }
    }
}