1use std::{collections::HashMap, sync::Arc};
2
3use async_trait::async_trait;
4
5use torn_api::{
6 send::{ApiClient, ApiProvider, RequestExecutor},
7 ApiRequest, ApiResponse, ApiSelection, ResponseError,
8};
9
10use crate::{
11 ApiKey, IntoSelector, KeyAction, KeyDomain, KeyPoolError, KeyPoolExecutor, KeyPoolStorage,
12 KeySelector, PoolOptions,
13};
14
15#[async_trait]
16impl<'client, C, S> RequestExecutor<C> for KeyPoolExecutor<'client, C, S>
17where
18 C: ApiClient,
19 S: KeyPoolStorage + Send + Sync + 'static,
20{
21 type Error = KeyPoolError<S::Error, C::Error>;
22
23 async fn execute<A>(
24 &self,
25 client: &C,
26 mut request: ApiRequest<A>,
27 id: Option<String>,
28 ) -> Result<A::Response, Self::Error>
29 where
30 A: ApiSelection,
31 {
32 if request.comment.is_none() {
33 request.comment = self.options.comment.clone();
34 }
35 if let Some(hook) = self.options.hooks_before.get(&std::any::TypeId::of::<A>()) {
36 let concrete = hook
37 .downcast_ref::<BeforeHook<A, S::Key, S::Domain>>()
38 .unwrap();
39
40 (concrete.body)(&mut request, &self.selector);
41 }
42 loop {
43 let key = self
44 .storage
45 .acquire_key(self.selector.clone())
46 .await
47 .map_err(KeyPoolError::Storage)?;
48 let url = request.url(key.value(), id.as_deref());
49 let value = client.request(url).await?;
50
51 match ApiResponse::from_value(value) {
52 Err(ResponseError::Api { code, reason }) => {
53 if !self
54 .storage
55 .flag_key(key, code)
56 .await
57 .map_err(KeyPoolError::Storage)?
58 {
59 return Err(KeyPoolError::Response(ResponseError::Api { code, reason }));
60 }
61 }
62 Err(parsing_error) => return Err(KeyPoolError::Response(parsing_error)),
63 Ok(res) => {
64 let res = res.into();
65 if let Some(hook) = self.options.hooks_after.get(&std::any::TypeId::of::<A>()) {
66 let concrete = hook
67 .downcast_ref::<AfterHook<A, S::Key, S::Domain>>()
68 .unwrap();
69
70 match (concrete.body)(&res, &self.selector) {
71 Err(KeyAction::Delete) => {
72 self.storage
73 .remove_key(key.selector())
74 .await
75 .map_err(KeyPoolError::Storage)?;
76 continue;
77 }
78 Err(KeyAction::RemoveDomain(domain)) => {
79 self.storage
80 .remove_domain_from_key(key.selector(), domain)
81 .await
82 .map_err(KeyPoolError::Storage)?;
83 continue;
84 }
85 _ => (),
86 };
87 }
88 return Ok(res);
89 }
90 };
91 }
92 }
93
94 async fn execute_many<A, I>(
95 &self,
96 client: &C,
97 mut request: ApiRequest<A>,
98 ids: Vec<I>,
99 ) -> HashMap<I, Result<A::Response, Self::Error>>
100 where
101 A: ApiSelection,
102 I: ToString + std::hash::Hash + std::cmp::Eq + Send + Sync,
103 {
104 let keys = match self
105 .storage
106 .acquire_many_keys(self.selector.clone(), ids.len() as i64)
107 .await
108 {
109 Ok(keys) => keys,
110 Err(why) => {
111 return ids
112 .into_iter()
113 .map(|i| (i, Err(Self::Error::Storage(why.clone()))))
114 .collect();
115 }
116 };
117
118 if request.comment.is_none() {
119 request.comment = self.options.comment.clone();
120 }
121 let request_ref = &request;
122
123 let tuples =
124 futures::future::join_all(std::iter::zip(ids, keys).map(|(id, mut key)| async move {
125 let id_string = id.to_string();
126 loop {
127 let url = request_ref.url(key.value(), Some(&id_string));
128 let value = match client.request(url).await {
129 Ok(v) => v,
130 Err(why) => return (id, Err(Self::Error::Client(why))),
131 };
132
133 match ApiResponse::from_value(value) {
134 Err(ResponseError::Api { code, reason }) => {
135 match self.storage.flag_key(key, code).await {
136 Ok(false) => {
137 return (
138 id,
139 Err(KeyPoolError::Response(ResponseError::Api {
140 code,
141 reason,
142 })),
143 )
144 }
145 Ok(true) => (),
146 Err(why) => return (id, Err(KeyPoolError::Storage(why))),
147 }
148 }
149 Err(parsing_error) => {
150 return (id, Err(KeyPoolError::Response(parsing_error)))
151 }
152 Ok(res) => return (id, Ok(res.into())),
153 };
154
155 key = match self.storage.acquire_key(self.selector.clone()).await {
156 Ok(k) => k,
157 Err(why) => return (id, Err(Self::Error::Storage(why))),
158 };
159 }
160 }))
161 .await;
162
163 HashMap::from_iter(tuples)
164 }
165}
166
167#[allow(clippy::type_complexity)]
168pub struct BeforeHook<A, K, D>
169where
170 A: ApiSelection,
171 K: ApiKey,
172 D: KeyDomain,
173{
174 body: Box<dyn Fn(&mut ApiRequest<A>, &KeySelector<K, D>) + Send + Sync + 'static>,
175}
176
177#[allow(clippy::type_complexity)]
178pub struct AfterHook<A, K, D>
179where
180 A: ApiSelection,
181 K: ApiKey,
182 D: KeyDomain,
183{
184 body: Box<
185 dyn Fn(&A::Response, &KeySelector<K, D>) -> Result<(), crate::KeyAction<D>>
186 + Send
187 + Sync
188 + 'static,
189 >,
190}
191
192pub struct PoolBuilder<C, S>
193where
194 C: ApiClient,
195 S: KeyPoolStorage,
196{
197 client: C,
198 storage: S,
199 options: crate::PoolOptions,
200}
201
202impl<C, S> PoolBuilder<C, S>
203where
204 C: ApiClient,
205 S: KeyPoolStorage,
206{
207 pub fn new(client: C, storage: S) -> Self {
208 Self {
209 client,
210 storage,
211 options: Default::default(),
212 }
213 }
214
215 pub fn comment(mut self, c: impl ToString) -> Self {
216 self.options.comment = Some(c.to_string());
217 self
218 }
219
220 pub fn hook_before<A>(
221 mut self,
222 hook: impl Fn(&mut ApiRequest<A>, &KeySelector<S::Key, S::Domain>) + Send + Sync + 'static,
223 ) -> Self
224 where
225 A: ApiSelection + 'static,
226 {
227 self.options.hooks_before.insert(
228 std::any::TypeId::of::<A>(),
229 Box::new(BeforeHook {
230 body: Box::new(hook),
231 }),
232 );
233 self
234 }
235
236 pub fn hook_after<A>(
237 mut self,
238 hook: impl Fn(&A::Response, &KeySelector<S::Key, S::Domain>) -> Result<(), KeyAction<S::Domain>>
239 + Send
240 + Sync
241 + 'static,
242 ) -> Self
243 where
244 A: ApiSelection + 'static,
245 {
246 self.options.hooks_after.insert(
247 std::any::TypeId::of::<A>(),
248 Box::new(AfterHook::<A, S::Key, S::Domain> {
249 body: Box::new(hook),
250 }),
251 );
252 self
253 }
254
255 pub fn build(self) -> KeyPool<C, S> {
256 KeyPool {
257 client: self.client,
258 storage: self.storage,
259 options: Arc::new(self.options),
260 }
261 }
262}
263
264#[derive(Clone, Debug)]
265pub struct KeyPool<C, S>
266where
267 C: ApiClient,
268 S: KeyPoolStorage,
269{
270 pub client: C,
271 pub storage: S,
272 pub options: Arc<PoolOptions>,
273}
274
275impl<C, S> KeyPool<C, S>
276where
277 C: ApiClient,
278 S: KeyPoolStorage + Send + Sync + 'static,
279{
280 pub fn torn_api<I>(&self, selector: I) -> ApiProvider<C, KeyPoolExecutor<C, S>>
281 where
282 I: IntoSelector<S::Key, S::Domain>,
283 {
284 ApiProvider::new(
285 &self.client,
286 KeyPoolExecutor::new(
287 &self.storage,
288 selector.into_selector(),
289 self.options.clone(),
290 ),
291 )
292 }
293}
294
295pub trait WithStorage {
296 fn with_storage<'a, S, I>(
297 &'a self,
298 storage: &'a S,
299 selector: I,
300 ) -> ApiProvider<Self, KeyPoolExecutor<Self, S>>
301 where
302 Self: ApiClient + Sized,
303 S: KeyPoolStorage + Send + Sync + 'static,
304 I: IntoSelector<S::Key, S::Domain>,
305 {
306 ApiProvider::new(
307 self,
308 KeyPoolExecutor::new(storage, selector.into_selector(), Default::default()),
309 )
310 }
311}
312
313#[cfg(feature = "reqwest")]
314impl WithStorage for reqwest::Client {}
315
#[cfg(all(test, feature = "postgres", feature = "reqwest"))]
mod test {
    use sqlx::PgPool;

    use super::*;
    use crate::{
        postgres::test::{setup, Domain},
        KeySelector,
    };

    /// A pool built with `PoolBuilder` (and a comment) can fetch a user profile.
    #[sqlx::test]
    async fn test_pool_request(pool: PgPool) {
        let (storage, _) = setup(pool).await;
        let key_pool = PoolBuilder::new(reqwest::Client::default(), storage)
            .comment("api.rs")
            .build();

        let user = key_pool.torn_api(Domain::All).user(|b| b).await.unwrap();
        let _ = user.profile().unwrap();
    }

    /// `WithStorage` lets a bare reqwest client borrow the storage directly.
    #[sqlx::test]
    async fn test_with_storage_request(pool: PgPool) {
        let (storage, _) = setup(pool).await;

        let client = reqwest::Client::new();
        let user = client
            .with_storage(&storage, Domain::All)
            .user(|b| b)
            .await
            .unwrap();
        let _ = user.profile().unwrap();
    }

    /// A before-hook can add selections that the caller did not request.
    #[sqlx::test]
    async fn before_hook(pool: PgPool) {
        let (storage, _) = setup(pool).await;

        let key_pool = PoolBuilder::new(reqwest::Client::default(), storage)
            .hook_before::<torn_api::user::UserSelection>(|request, _selector| {
                request.selections.push("crimes");
            })
            .build();

        let user = key_pool.torn_api(Domain::All).user(|b| b).await.unwrap();
        let _ = user.crimes().unwrap();
    }

    /// An after-hook returning `Delete` evicts keys until storage runs dry.
    #[sqlx::test]
    async fn after_hook(pool: PgPool) {
        let (storage, _) = setup(pool).await;

        let key_pool = PoolBuilder::new(reqwest::Client::default(), storage)
            .hook_after::<torn_api::user::UserSelection>(|_res, _selector| Err(KeyAction::Delete))
            .build();

        let before = key_pool.storage.read_key(KeySelector::Id(1)).await.unwrap();
        assert!(before.is_some());

        let outcome = key_pool.torn_api(Domain::All).user(|b| b).await;
        assert!(matches!(outcome, Err(KeyPoolError::Storage(_))));

        let after = key_pool.storage.read_key(KeySelector::Id(1)).await.unwrap();
        assert!(after.is_none());
    }
}