torn_key_pool/
send.rs

1use std::{collections::HashMap, sync::Arc};
2
3use async_trait::async_trait;
4
5use torn_api::{
6    send::{ApiClient, ApiProvider, RequestExecutor},
7    ApiRequest, ApiResponse, ApiSelection, ResponseError,
8};
9
10use crate::{
11    ApiKey, IntoSelector, KeyAction, KeyDomain, KeyPoolError, KeyPoolExecutor, KeyPoolStorage,
12    KeySelector, PoolOptions,
13};
14
15#[async_trait]
16impl<'client, C, S> RequestExecutor<C> for KeyPoolExecutor<'client, C, S>
17where
18    C: ApiClient,
19    S: KeyPoolStorage + Send + Sync + 'static,
20{
21    type Error = KeyPoolError<S::Error, C::Error>;
22
23    async fn execute<A>(
24        &self,
25        client: &C,
26        mut request: ApiRequest<A>,
27        id: Option<String>,
28    ) -> Result<A::Response, Self::Error>
29    where
30        A: ApiSelection,
31    {
32        if request.comment.is_none() {
33            request.comment = self.options.comment.clone();
34        }
35        if let Some(hook) = self.options.hooks_before.get(&std::any::TypeId::of::<A>()) {
36            let concrete = hook
37                .downcast_ref::<BeforeHook<A, S::Key, S::Domain>>()
38                .unwrap();
39
40            (concrete.body)(&mut request, &self.selector);
41        }
42        loop {
43            let key = self
44                .storage
45                .acquire_key(self.selector.clone())
46                .await
47                .map_err(KeyPoolError::Storage)?;
48            let url = request.url(key.value(), id.as_deref());
49            let value = client.request(url).await?;
50
51            match ApiResponse::from_value(value) {
52                Err(ResponseError::Api { code, reason }) => {
53                    if !self
54                        .storage
55                        .flag_key(key, code)
56                        .await
57                        .map_err(KeyPoolError::Storage)?
58                    {
59                        return Err(KeyPoolError::Response(ResponseError::Api { code, reason }));
60                    }
61                }
62                Err(parsing_error) => return Err(KeyPoolError::Response(parsing_error)),
63                Ok(res) => {
64                    let res = res.into();
65                    if let Some(hook) = self.options.hooks_after.get(&std::any::TypeId::of::<A>()) {
66                        let concrete = hook
67                            .downcast_ref::<AfterHook<A, S::Key, S::Domain>>()
68                            .unwrap();
69
70                        match (concrete.body)(&res, &self.selector) {
71                            Err(KeyAction::Delete) => {
72                                self.storage
73                                    .remove_key(key.selector())
74                                    .await
75                                    .map_err(KeyPoolError::Storage)?;
76                                continue;
77                            }
78                            Err(KeyAction::RemoveDomain(domain)) => {
79                                self.storage
80                                    .remove_domain_from_key(key.selector(), domain)
81                                    .await
82                                    .map_err(KeyPoolError::Storage)?;
83                                continue;
84                            }
85                            _ => (),
86                        };
87                    }
88                    return Ok(res);
89                }
90            };
91        }
92    }
93
94    async fn execute_many<A, I>(
95        &self,
96        client: &C,
97        mut request: ApiRequest<A>,
98        ids: Vec<I>,
99    ) -> HashMap<I, Result<A::Response, Self::Error>>
100    where
101        A: ApiSelection,
102        I: ToString + std::hash::Hash + std::cmp::Eq + Send + Sync,
103    {
104        let keys = match self
105            .storage
106            .acquire_many_keys(self.selector.clone(), ids.len() as i64)
107            .await
108        {
109            Ok(keys) => keys,
110            Err(why) => {
111                return ids
112                    .into_iter()
113                    .map(|i| (i, Err(Self::Error::Storage(why.clone()))))
114                    .collect();
115            }
116        };
117
118        if request.comment.is_none() {
119            request.comment = self.options.comment.clone();
120        }
121        let request_ref = &request;
122
123        let tuples =
124            futures::future::join_all(std::iter::zip(ids, keys).map(|(id, mut key)| async move {
125                let id_string = id.to_string();
126                loop {
127                    let url = request_ref.url(key.value(), Some(&id_string));
128                    let value = match client.request(url).await {
129                        Ok(v) => v,
130                        Err(why) => return (id, Err(Self::Error::Client(why))),
131                    };
132
133                    match ApiResponse::from_value(value) {
134                        Err(ResponseError::Api { code, reason }) => {
135                            match self.storage.flag_key(key, code).await {
136                                Ok(false) => {
137                                    return (
138                                        id,
139                                        Err(KeyPoolError::Response(ResponseError::Api {
140                                            code,
141                                            reason,
142                                        })),
143                                    )
144                                }
145                                Ok(true) => (),
146                                Err(why) => return (id, Err(KeyPoolError::Storage(why))),
147                            }
148                        }
149                        Err(parsing_error) => {
150                            return (id, Err(KeyPoolError::Response(parsing_error)))
151                        }
152                        Ok(res) => return (id, Ok(res.into())),
153                    };
154
155                    key = match self.storage.acquire_key(self.selector.clone()).await {
156                        Ok(k) => k,
157                        Err(why) => return (id, Err(Self::Error::Storage(why))),
158                    };
159                }
160            }))
161            .await;
162
163        HashMap::from_iter(tuples)
164    }
165}
166
167#[allow(clippy::type_complexity)]
168pub struct BeforeHook<A, K, D>
169where
170    A: ApiSelection,
171    K: ApiKey,
172    D: KeyDomain,
173{
174    body: Box<dyn Fn(&mut ApiRequest<A>, &KeySelector<K, D>) + Send + Sync + 'static>,
175}
176
177#[allow(clippy::type_complexity)]
178pub struct AfterHook<A, K, D>
179where
180    A: ApiSelection,
181    K: ApiKey,
182    D: KeyDomain,
183{
184    body: Box<
185        dyn Fn(&A::Response, &KeySelector<K, D>) -> Result<(), crate::KeyAction<D>>
186            + Send
187            + Sync
188            + 'static,
189    >,
190}
191
192pub struct PoolBuilder<C, S>
193where
194    C: ApiClient,
195    S: KeyPoolStorage,
196{
197    client: C,
198    storage: S,
199    options: crate::PoolOptions,
200}
201
202impl<C, S> PoolBuilder<C, S>
203where
204    C: ApiClient,
205    S: KeyPoolStorage,
206{
207    pub fn new(client: C, storage: S) -> Self {
208        Self {
209            client,
210            storage,
211            options: Default::default(),
212        }
213    }
214
215    pub fn comment(mut self, c: impl ToString) -> Self {
216        self.options.comment = Some(c.to_string());
217        self
218    }
219
220    pub fn hook_before<A>(
221        mut self,
222        hook: impl Fn(&mut ApiRequest<A>, &KeySelector<S::Key, S::Domain>) + Send + Sync + 'static,
223    ) -> Self
224    where
225        A: ApiSelection + 'static,
226    {
227        self.options.hooks_before.insert(
228            std::any::TypeId::of::<A>(),
229            Box::new(BeforeHook {
230                body: Box::new(hook),
231            }),
232        );
233        self
234    }
235
236    pub fn hook_after<A>(
237        mut self,
238        hook: impl Fn(&A::Response, &KeySelector<S::Key, S::Domain>) -> Result<(), KeyAction<S::Domain>>
239            + Send
240            + Sync
241            + 'static,
242    ) -> Self
243    where
244        A: ApiSelection + 'static,
245    {
246        self.options.hooks_after.insert(
247            std::any::TypeId::of::<A>(),
248            Box::new(AfterHook::<A, S::Key, S::Domain> {
249                body: Box::new(hook),
250            }),
251        );
252        self
253    }
254
255    pub fn build(self) -> KeyPool<C, S> {
256        KeyPool {
257            client: self.client,
258            storage: self.storage,
259            options: Arc::new(self.options),
260        }
261    }
262}
263
264#[derive(Clone, Debug)]
265pub struct KeyPool<C, S>
266where
267    C: ApiClient,
268    S: KeyPoolStorage,
269{
270    pub client: C,
271    pub storage: S,
272    pub options: Arc<PoolOptions>,
273}
274
275impl<C, S> KeyPool<C, S>
276where
277    C: ApiClient,
278    S: KeyPoolStorage + Send + Sync + 'static,
279{
280    pub fn torn_api<I>(&self, selector: I) -> ApiProvider<C, KeyPoolExecutor<C, S>>
281    where
282        I: IntoSelector<S::Key, S::Domain>,
283    {
284        ApiProvider::new(
285            &self.client,
286            KeyPoolExecutor::new(
287                &self.storage,
288                selector.into_selector(),
289                self.options.clone(),
290            ),
291        )
292    }
293}
294
295pub trait WithStorage {
296    fn with_storage<'a, S, I>(
297        &'a self,
298        storage: &'a S,
299        selector: I,
300    ) -> ApiProvider<Self, KeyPoolExecutor<Self, S>>
301    where
302        Self: ApiClient + Sized,
303        S: KeyPoolStorage + Send + Sync + 'static,
304        I: IntoSelector<S::Key, S::Domain>,
305    {
306        ApiProvider::new(
307            self,
308            KeyPoolExecutor::new(storage, selector.into_selector(), Default::default()),
309        )
310    }
311}
312
// Opts `reqwest::Client` into `WithStorage`; the trait's default
// `with_storage` implementation is used unchanged.
#[cfg(feature = "reqwest")]
impl WithStorage for reqwest::Client {}
315
// Integration tests; compiled only when both the `postgres` and `reqwest`
// features are enabled.
// NOTE(review): these appear to need a live database (via `sqlx::test`) and to
// send real API requests — confirm before running them in CI.
#[cfg(all(test, feature = "postgres", feature = "reqwest"))]
mod test {
    use sqlx::PgPool;

    use super::*;
    use crate::{
        postgres::test::{setup, Domain},
        KeySelector,
    };

    /// A pool built through `PoolBuilder` can fetch a user profile.
    #[sqlx::test]
    async fn test_pool_request(pool: PgPool) {
        let (storage, _) = setup(pool).await;
        let builder = PoolBuilder::new(reqwest::Client::default(), storage).comment("api.rs");
        let key_pool = builder.build();

        let response = key_pool.torn_api(Domain::All).user(|b| b).await.unwrap();
        let _ = response.profile().unwrap();
    }

    /// A bare client combined with a storage can fetch a user profile.
    #[sqlx::test]
    async fn test_with_storage_request(pool: PgPool) {
        let (storage, _) = setup(pool).await;
        let client = reqwest::Client::new();

        let response = client
            .with_storage(&storage, Domain::All)
            .user(|b| b)
            .await
            .unwrap();
        let _ = response.profile().unwrap();
    }

    /// A before-hook can add a selection to the outgoing request.
    #[sqlx::test]
    async fn before_hook(pool: PgPool) {
        let (storage, _) = setup(pool).await;

        let builder = PoolBuilder::new(reqwest::Client::default(), storage)
            .hook_before::<torn_api::user::UserSelection>(|req, _s| {
                req.selections.push("crimes");
            });
        let key_pool = builder.build();

        let response = key_pool.torn_api(Domain::All).user(|b| b).await.unwrap();
        let _ = response.crimes().unwrap();
    }

    /// An after-hook returning `KeyAction::Delete` removes the key that was
    /// used; the request then ends in a storage error.
    // Presumably `setup` seeds exactly one key (id 1), so no key is left for
    // the retry — verify against `postgres::test::setup`.
    #[sqlx::test]
    async fn after_hook(pool: PgPool) {
        let (storage, _) = setup(pool).await;

        let builder = PoolBuilder::new(reqwest::Client::default(), storage)
            .hook_after::<torn_api::user::UserSelection>(|_res, _s| Err(KeyAction::Delete));
        let key_pool = builder.build();

        let key_before = key_pool
            .storage
            .read_key(KeySelector::Id(1))
            .await
            .unwrap();
        assert!(key_before.is_some());

        let response = key_pool.torn_api(Domain::All).user(|b| b).await;
        assert!(matches!(response, Err(KeyPoolError::Storage(_))));

        let key_after = key_pool
            .storage
            .read_key(KeySelector::Id(1))
            .await
            .unwrap();
        assert!(key_after.is_none());
    }
}