1#![warn(clippy::all, clippy::perf, clippy::style, clippy::suspicious)]
2
3#[cfg(feature = "postgres")]
4pub mod postgres;
5
6use std::{collections::HashMap, future::Future, ops::Deref, sync::Arc, time::Duration};
7
8use futures::{future::BoxFuture, FutureExt, Stream, StreamExt};
9use reqwest::header::{HeaderMap, HeaderValue, AUTHORIZATION};
10use serde::Deserialize;
11use tokio_stream::StreamExt as TokioStreamExt;
12use torn_api::{
13 executor::{BulkExecutor, Executor},
14 request::{ApiRequest, ApiResponse},
15 ApiError,
16};
17
/// Marker trait for types that can identify an API key in storage.
///
/// It is blanket-implemented for every cloneable, hashable, equality-comparable type that can be
/// shared across threads, so it never has to be implemented by hand. (`Eq` already implies
/// `PartialEq`.)
pub trait ApiKeyId: Eq + std::hash::Hash + Clone + Send + Sync {}

impl<Id: Eq + std::hash::Hash + Clone + Send + Sync> ApiKeyId for Id {}
21
22pub trait ApiKey: Send + Sync + Clone + 'static {
23 type IdType: ApiKeyId;
24
25 fn value(&self) -> &str;
26
27 fn id(&self) -> Self::IdType;
28
29 fn selector<D>(&self) -> KeySelector<Self, D>
30 where
31 D: KeyDomain,
32 {
33 KeySelector::Id(self.id())
34 }
35}
36
/// A domain that keys can be grouped by and selected with (see [`KeySelector::Has`] and
/// [`KeySelector::OneOf`]).
pub trait KeyDomain
where
    Self: Clone + std::fmt::Debug + Send + Sync + 'static,
{
    /// The domain that replaces this one when a selector is degraded with
    /// `KeySelector::fallback`.
    ///
    /// The default implementation has no fallback and returns `None`.
    fn fallback(&self) -> Option<Self> {
        Option::None
    }
}
42
43#[derive(Debug, Clone)]
44pub enum KeySelector<K, D>
45where
46 K: ApiKey,
47 D: KeyDomain,
48{
49 Key(String),
50 Id(K::IdType),
51 UserId(i32),
52 Has(Vec<D>),
53 OneOf(Vec<D>),
54}
55
56impl<K, D> KeySelector<K, D>
57where
58 K: ApiKey,
59 D: KeyDomain,
60{
61 pub(crate) fn fallback(&self) -> Option<Self> {
62 match self {
63 Self::Key(_) | Self::UserId(_) | Self::Id(_) => None,
64 Self::Has(domains) => {
65 let fallbacks: Vec<_> = domains.iter().filter_map(|d| d.fallback()).collect();
66 if fallbacks.is_empty() {
67 None
68 } else {
69 Some(Self::Has(fallbacks))
70 }
71 }
72 Self::OneOf(domains) => {
73 let fallbacks: Vec<_> = domains.iter().filter_map(|d| d.fallback()).collect();
74 if fallbacks.is_empty() {
75 None
76 } else {
77 Some(Self::OneOf(fallbacks))
78 }
79 }
80 }
81 }
82}
83
84impl<K, D> From<&str> for KeySelector<K, D>
85where
86 K: ApiKey,
87 D: KeyDomain,
88{
89 fn from(value: &str) -> Self {
90 Self::Key(value.to_owned())
91 }
92}
93
94impl<K, D> From<D> for KeySelector<K, D>
95where
96 K: ApiKey,
97 D: KeyDomain,
98{
99 fn from(value: D) -> Self {
100 Self::Has(vec![value])
101 }
102}
103
104impl<K, D> From<&[D]> for KeySelector<K, D>
105where
106 K: ApiKey,
107 D: KeyDomain,
108{
109 fn from(value: &[D]) -> Self {
110 Self::Has(value.to_vec())
111 }
112}
113
114impl<K, D> From<Vec<D>> for KeySelector<K, D>
115where
116 K: ApiKey,
117 D: KeyDomain,
118{
119 fn from(value: Vec<D>) -> Self {
120 Self::Has(value)
121 }
122}
123
124pub trait IntoSelector<K, D>: Send
125where
126 K: ApiKey,
127 D: KeyDomain,
128{
129 fn into_selector(self) -> KeySelector<K, D>;
130}
131
132impl<K, D, T> IntoSelector<K, D> for T
133where
134 K: ApiKey,
135 D: KeyDomain,
136 T: Into<KeySelector<K, D>> + Send,
137{
138 fn into_selector(self) -> KeySelector<K, D> {
139 self.into()
140 }
141}
142
143pub trait KeyPoolError:
144 From<reqwest::Error> + From<serde_json::Error> + From<torn_api::ApiError> + From<Arc<Self>> + Send
145{
146}
147
148impl<T> KeyPoolError for T where
149 T: From<reqwest::Error>
150 + From<serde_json::Error>
151 + From<torn_api::ApiError>
152 + From<Arc<Self>>
153 + Send
154{
155}
156
157pub trait KeyPoolStorage: Send + Sync {
158 type Key: ApiKey;
159 type Domain: KeyDomain;
160 type Error: KeyPoolError;
161
162 fn acquire_key<S>(
163 &self,
164 selector: S,
165 ) -> impl Future<Output = Result<Self::Key, Self::Error>> + Send
166 where
167 S: IntoSelector<Self::Key, Self::Domain>;
168
169 fn acquire_many_keys<S>(
170 &self,
171 selector: S,
172 number: i64,
173 ) -> impl Future<Output = Result<Vec<Self::Key>, Self::Error>> + Send
174 where
175 S: IntoSelector<Self::Key, Self::Domain>;
176
177 fn store_key(
178 &self,
179 user_id: i32,
180 key: String,
181 domains: Vec<Self::Domain>,
182 ) -> impl Future<Output = Result<Self::Key, Self::Error>> + Send;
183
184 fn read_key<S>(
185 &self,
186 selector: S,
187 ) -> impl Future<Output = Result<Option<Self::Key>, Self::Error>> + Send
188 where
189 S: IntoSelector<Self::Key, Self::Domain>;
190
191 fn read_keys<S>(
192 &self,
193 selector: S,
194 ) -> impl Future<Output = Result<Vec<Self::Key>, Self::Error>> + Send
195 where
196 S: IntoSelector<Self::Key, Self::Domain>;
197
198 fn remove_key<S>(
199 &self,
200 selector: S,
201 ) -> impl Future<Output = Result<Self::Key, Self::Error>> + Send
202 where
203 S: IntoSelector<Self::Key, Self::Domain>;
204
205 fn add_domain_to_key<S>(
206 &self,
207 selector: S,
208 domain: Self::Domain,
209 ) -> impl Future<Output = Result<Self::Key, Self::Error>> + Send
210 where
211 S: IntoSelector<Self::Key, Self::Domain>;
212
213 fn remove_domain_from_key<S>(
214 &self,
215 selector: S,
216 domain: Self::Domain,
217 ) -> impl Future<Output = Result<Self::Key, Self::Error>> + Send
218 where
219 S: IntoSelector<Self::Key, Self::Domain>;
220
221 fn set_domains_for_key<S>(
222 &self,
223 selector: S,
224 domains: Vec<Self::Domain>,
225 ) -> impl Future<Output = Result<Self::Key, Self::Error>> + Send
226 where
227 S: IntoSelector<Self::Key, Self::Domain>;
228
229 fn timeout_key<S>(
230 &self,
231 selector: S,
232 duration: Duration,
233 ) -> impl Future<Output = Result<(), Self::Error>> + Send
234 where
235 S: IntoSelector<Self::Key, Self::Domain>;
236}
237
238#[derive(Default)]
239pub struct PoolOptions<S>
240where
241 S: KeyPoolStorage,
242{
243 comment: Option<String>,
244 #[allow(clippy::type_complexity)]
245 error_hooks: HashMap<
246 u16,
247 Box<
248 dyn for<'a> Fn(
249 &'a S,
250 &'a S::Key,
251 &'a ApiRequest,
252 ) -> BoxFuture<'a, Result<bool, S::Error>>
253 + Send
254 + Sync,
255 >,
256 >,
257}
258
259pub struct PoolBuilder<S>
260where
261 S: KeyPoolStorage,
262{
263 client: reqwest::Client,
264 storage: S,
265 options: crate::PoolOptions<S>,
266}
267
268impl<S> PoolBuilder<S>
269where
270 S: KeyPoolStorage,
271{
272 pub fn new(storage: S) -> Self {
273 Self {
274 client: reqwest::Client::builder()
275 .brotli(true)
276 .http2_keep_alive_timeout(Duration::from_secs(60))
277 .http2_keep_alive_interval(Duration::from_secs(5))
278 .https_only(true)
279 .build()
280 .unwrap(),
281 storage,
282 options: PoolOptions {
283 comment: None,
284 error_hooks: Default::default(),
285 },
286 }
287 }
288
289 pub fn comment(mut self, c: impl ToString) -> Self {
290 self.options.comment = Some(c.to_string());
291 self
292 }
293
294 pub fn error_hook<F>(mut self, error: ApiError, handler: F) -> Self
295 where
296 F: for<'a> Fn(&'a S, &'a S::Key, &'a ApiRequest) -> BoxFuture<'a, Result<bool, S::Error>>
297 + Send
298 + Sync
299 + 'static,
300 {
301 self.options
302 .error_hooks
303 .insert(error.code(), Box::new(handler));
304
305 self
306 }
307
308 pub fn use_default_hooks(self) -> Self {
309 self.error_hook(ApiError::IncorrectKey, |storage, key, _| {
310 async move {
311 storage.remove_key(KeySelector::Id(key.id())).await?;
312 Ok(true)
313 }
314 .boxed()
315 })
316 .error_hook(ApiError::TooManyRequest, |storage, key, _| {
317 async move {
318 storage
319 .timeout_key(KeySelector::Id(key.id()), Duration::from_secs(60))
320 .await?;
321 Ok(true)
322 }
323 .boxed()
324 })
325 .error_hook(ApiError::KeyOwnerInFederalJail, |storage, key, _| {
326 async move {
327 storage.remove_key(KeySelector::Id(key.id())).await?;
328 Ok(true)
329 }
330 .boxed()
331 })
332 .error_hook(ApiError::TemporaryInactivity, |storage, key, _| {
333 async move {
334 storage
335 .timeout_key(KeySelector::Id(key.id()), Duration::from_secs(24 * 3_600))
336 .await?;
337 Ok(true)
338 }
339 .boxed()
340 })
341 .error_hook(ApiError::Paused, |storage, key, _| {
342 async move {
343 storage
344 .timeout_key(KeySelector::Id(key.id()), Duration::from_secs(24 * 3_600))
345 .await?;
346 Ok(true)
347 }
348 .boxed()
349 })
350 }
351
352 pub fn build(self) -> KeyPool<S> {
353 KeyPool {
354 inner: Arc::new(KeyPoolInner {
355 client: self.client,
356 storage: self.storage,
357 options: self.options,
358 }),
359 }
360 }
361}
362
363pub struct KeyPoolInner<S>
364where
365 S: KeyPoolStorage,
366{
367 pub client: reqwest::Client,
368 pub storage: S,
369 pub options: PoolOptions<S>,
370}
371
372impl<S> KeyPoolInner<S>
373where
374 S: KeyPoolStorage,
375{
376 async fn execute_with_key(
377 &self,
378 key: &S::Key,
379 request: &ApiRequest,
380 ) -> Result<RequestResult, S::Error> {
381 let mut headers = HeaderMap::with_capacity(1);
382 headers.insert(
383 AUTHORIZATION,
384 HeaderValue::from_str(&format!("ApiKey {}", key.value())).unwrap(),
385 );
386
387 let resp = self
388 .client
389 .get(request.url())
390 .headers(headers)
391 .send()
392 .await?;
393
394 let status = resp.status();
395
396 let bytes = resp.bytes().await?;
397
398 if let Some(err) = decode_error(&bytes)? {
399 if let Some(handler) = self.options.error_hooks.get(&err.code()) {
400 let retry = (*handler)(&self.storage, key, request).await?;
401
402 if retry {
403 return Ok(RequestResult::Retry);
404 }
405 }
406 Err(err.into())
407 } else {
408 Ok(RequestResult::Response(ApiResponse {
409 body: Some(bytes),
410 status,
411 }))
412 }
413 }
414
415 async fn execute_request(
416 &self,
417 selector: KeySelector<S::Key, S::Domain>,
418 request: ApiRequest,
419 ) -> Result<ApiResponse, S::Error> {
420 loop {
421 let key = self.storage.acquire_key(selector.clone()).await?;
422 match self.execute_with_key(&key, &request).await {
423 Ok(RequestResult::Response(resp)) => return Ok(resp),
424 Ok(RequestResult::Retry) => (),
425 Err(why) => return Err(why),
426 }
427 }
428 }
429
430 async fn execute_bulk_requests<D, T: IntoIterator<Item = (D, ApiRequest)>>(
431 &self,
432 selector: KeySelector<S::Key, S::Domain>,
433 requests: T,
434 ) -> impl Stream<Item = (D, Result<ApiResponse, S::Error>)> + use<'_, D, S, T> {
435 let requests: Vec<_> = requests.into_iter().collect();
436
437 let keys: Vec<_> = match self
438 .storage
439 .acquire_many_keys(selector.clone(), requests.len() as i64)
440 .await
441 {
442 Ok(keys) => keys.into_iter().map(Ok).collect(),
443 Err(why) => {
444 let why = Arc::new(why);
445 std::iter::repeat_n(why, requests.len())
446 .map(|e| Err(S::Error::from(e)))
447 .collect()
448 }
449 };
450
451 StreamExt::map(
452 futures::stream::iter(std::iter::zip(requests, keys)),
453 move |((discriminant, request), mut maybe_key)| {
454 let selector = selector.clone();
455 async move {
456 loop {
457 let key = match maybe_key {
458 Ok(key) => key,
459 Err(why) => return (discriminant, Err(why)),
460 };
461 match self.execute_with_key(&key, &request).await {
462 Ok(RequestResult::Response(resp)) => return (discriminant, Ok(resp)),
463 Ok(RequestResult::Retry) => (),
464 Err(why) => return (discriminant, Err(why)),
465 }
466 maybe_key = self.storage.acquire_key(selector.clone()).await;
467 }
468 }
469 },
470 )
471 .buffer_unordered(25)
472 }
473}
474
475pub struct KeyPool<S>
476where
477 S: KeyPoolStorage,
478{
479 inner: Arc<KeyPoolInner<S>>,
480}
481
482impl<S> Deref for KeyPool<S>
483where
484 S: KeyPoolStorage,
485{
486 type Target = KeyPoolInner<S>;
487 fn deref(&self) -> &Self::Target {
488 &self.inner
489 }
490}
491
492enum RequestResult {
493 Response(ApiResponse),
494 Retry,
495}
496
497impl<S> KeyPool<S>
498where
499 S: KeyPoolStorage + Send + Sync + 'static,
500{
501 pub fn torn_api<I>(&self, selector: I) -> KeyPoolExecutor<'_, S>
502 where
503 I: IntoSelector<S::Key, S::Domain>,
504 {
505 KeyPoolExecutor::new(self, selector.into_selector())
506 }
507
508 pub fn throttled_torn_api<I>(
509 &self,
510 selector: I,
511 distance: Duration,
512 ) -> ThrottledKeyPoolExecutor<'_, S>
513 where
514 I: IntoSelector<S::Key, S::Domain>,
515 {
516 ThrottledKeyPoolExecutor::new(self, selector.into_selector(), distance)
517 }
518}
519
520fn decode_error(buf: &[u8]) -> Result<Option<ApiError>, serde_json::Error> {
521 if buf.starts_with(br#"{"error":{"#) {
522 #[derive(Deserialize)]
523 struct ErrorBody<'a> {
524 code: u16,
525 error: &'a str,
526 }
527 #[derive(Deserialize)]
528 struct ErrorContainer<'a> {
529 #[serde(borrow)]
530 error: ErrorBody<'a>,
531 }
532
533 let error: ErrorContainer = serde_json::from_slice(buf)?;
534 Ok(Some(crate::ApiError::new(
535 error.error.code,
536 error.error.error,
537 )))
538 } else {
539 Ok(None)
540 }
541}
542
543pub struct KeyPoolExecutor<'p, S>
544where
545 S: KeyPoolStorage,
546{
547 pool: &'p KeyPoolInner<S>,
548 selector: KeySelector<S::Key, S::Domain>,
549}
550
551impl<'p, S> KeyPoolExecutor<'p, S>
552where
553 S: KeyPoolStorage,
554{
555 pub fn new(pool: &'p KeyPool<S>, selector: KeySelector<S::Key, S::Domain>) -> Self {
556 Self {
557 pool: &pool.inner,
558 selector,
559 }
560 }
561}
562
563impl<S> Executor for KeyPoolExecutor<'_, S>
564where
565 S: KeyPoolStorage + 'static,
566{
567 type Error = S::Error;
568
569 async fn execute<R>(self, request: R) -> (R::Discriminant, Result<ApiResponse, Self::Error>)
570 where
571 R: torn_api::request::IntoRequest,
572 {
573 let (d, request) = request.into_request();
574
575 (d, self.pool.execute_request(self.selector, request).await)
576 }
577}
578
579impl<S> BulkExecutor for KeyPoolExecutor<'_, S>
580where
581 S: KeyPoolStorage + 'static,
582{
583 type Error = S::Error;
584
585 fn execute<R>(
586 self,
587 requests: impl IntoIterator<Item = R>,
588 ) -> impl futures::Stream<Item = (R::Discriminant, Result<ApiResponse, Self::Error>)> + Unpin
589 where
590 R: torn_api::request::IntoRequest,
591 {
592 let requests: Vec<_> = requests.into_iter().map(|r| r.into_request()).collect();
593 self.pool
594 .execute_bulk_requests(self.selector.clone(), requests)
595 .into_stream()
596 .flatten()
597 .boxed()
598 }
599}
600
601pub struct ThrottledKeyPoolExecutor<'p, S>
602where
603 S: KeyPoolStorage,
604{
605 pool: &'p KeyPoolInner<S>,
606 selector: KeySelector<S::Key, S::Domain>,
607 distance: Duration,
608}
609
610impl<S> Clone for ThrottledKeyPoolExecutor<'_, S>
611where
612 S: KeyPoolStorage,
613{
614 fn clone(&self) -> Self {
615 Self {
616 pool: self.pool,
617 selector: self.selector.clone(),
618 distance: self.distance,
619 }
620 }
621}
622
623impl<S> ThrottledKeyPoolExecutor<'_, S>
624where
625 S: KeyPoolStorage,
626{
627 async fn execute_request(self, request: ApiRequest) -> Result<ApiResponse, S::Error> {
628 self.pool.execute_request(self.selector, request).await
629 }
630}
631
632impl<'p, S> ThrottledKeyPoolExecutor<'p, S>
633where
634 S: KeyPoolStorage,
635{
636 pub fn new(
637 pool: &'p KeyPool<S>,
638 selector: KeySelector<S::Key, S::Domain>,
639 distance: Duration,
640 ) -> Self {
641 Self {
642 pool: &pool.inner,
643 selector,
644 distance,
645 }
646 }
647}
648
649impl<S> BulkExecutor for ThrottledKeyPoolExecutor<'_, S>
650where
651 S: KeyPoolStorage + 'static,
652{
653 type Error = S::Error;
654
655 fn execute<R>(
656 self,
657 requests: impl IntoIterator<Item = R>,
658 ) -> impl futures::Stream<Item = (R::Discriminant, Result<ApiResponse, Self::Error>)> + Unpin
659 where
660 R: torn_api::request::IntoRequest,
661 {
662 let requests: Vec<_> = requests.into_iter().map(|r| r.into_request()).collect();
663 StreamExt::map(
664 futures::stream::iter(requests).throttle(self.distance),
665 move |(d, request)| {
666 let this = self.clone();
667 async move {
668 let result = this.execute_request(request).await;
669 (d, result)
670 }
671 },
672 )
673 .buffer_unordered(25)
674 .boxed()
675 }
676}
677
#[cfg(test)]
#[cfg(feature = "postgres")]
mod test {
    use torn_api::executor::{BulkExecutorExt, ExecutorExt};

    use crate::postgres;

    use super::*;

    #[sqlx::test]
    fn single_request(pool: sqlx::PgPool) {
        let (storage, _) = postgres::test::setup(pool).await;

        let pool = PoolBuilder::new(storage)
            .use_default_hooks()
            .comment("test_runner")
            .build();

        pool.torn_api(postgres::test::Domain::All)
            .faction()
            .basic(|b| b)
            .await
            .unwrap();
    }

    #[sqlx::test]
    fn bulk(pool: sqlx::PgPool) {
        let (storage, _) = postgres::test::setup(pool).await;

        let pool = PoolBuilder::new(storage)
            .use_default_hooks()
            .comment("test_runner")
            .build();

        let responses = pool
            .torn_api(postgres::test::Domain::All)
            .faction_bulk()
            .basic_for_id(vec![19.into(), 89.into()], |b| b);
        let responses: Vec<_> = StreamExt::collect(responses).await;

        // Exactly one successful response per requested faction id.
        assert_eq!(responses.len(), 2);
        for (_id, basic) in responses {
            basic.unwrap();
        }
    }

    #[sqlx::test]
    fn bulk_throttled(pool: sqlx::PgPool) {
        let (storage, _) = postgres::test::setup(pool).await;

        let pool = PoolBuilder::new(storage)
            .use_default_hooks()
            .comment("test_runner")
            .build();

        let responses = pool
            .throttled_torn_api(postgres::test::Domain::All, Duration::from_millis(500))
            .faction_bulk()
            .basic_for_id(vec![19.into(), 89.into()], |b| b);
        let responses: Vec<_> = StreamExt::collect(responses).await;

        // Exactly one successful response per requested faction id.
        assert_eq!(responses.len(), 2);
        for (_id, basic) in responses {
            basic.unwrap();
        }
    }
}