1#![warn(clippy::all, clippy::perf, clippy::style, clippy::suspicious)]
2
/// Key storage backed by PostgreSQL (only compiled with the `postgres` feature); its test
/// helpers build a storage from an `sqlx::PgPool`.
#[cfg(feature = "postgres")]
pub mod postgres;
5
6use std::{collections::HashMap, future::Future, ops::Deref, sync::Arc, time::Duration};
7
8use futures::{future::BoxFuture, FutureExt, Stream, StreamExt};
9use reqwest::header::{HeaderMap, HeaderValue, AUTHORIZATION};
10use serde::Deserialize;
11use tokio_stream::StreamExt as TokioStreamExt;
12use torn_api::{
13 executor::{BulkExecutor, Executor},
14 request::{ApiRequest, ApiResponse},
15 ApiError,
16};
17
/// Requirements for the identifier type of a stored API key.
///
/// Ids must be cheap to copy around, comparable and hashable, and shareable across threads.
/// Every type meeting these bounds is an `ApiKeyId` automatically.
pub trait ApiKeyId
where
    Self: Clone + PartialEq + Eq + std::hash::Hash + Send + Sync,
{
}

impl<T: Clone + PartialEq + Eq + std::hash::Hash + Send + Sync> ApiKeyId for T {}
21
22pub trait ApiKey: Send + Sync + Clone + 'static {
23 type IdType: ApiKeyId;
24
25 fn value(&self) -> &str;
26
27 fn id(&self) -> Self::IdType;
28
29 fn selector<D>(&self) -> KeySelector<Self, D>
30 where
31 D: KeyDomain,
32 {
33 KeySelector::Id(self.id())
34 }
35}
36
/// A tag that stored keys can carry and that selectors can filter on.
pub trait KeyDomain: Clone + std::fmt::Debug + Send + Sync + 'static {
    /// Domain to substitute for this one when deriving a fallback selector.
    ///
    /// The default implementation has no fallback.
    fn fallback(&self) -> Option<Self> {
        Option::None
    }
}
42
43#[derive(Debug, Clone)]
44pub enum KeySelector<K, D>
45where
46 K: ApiKey,
47 D: KeyDomain,
48{
49 Key(String),
50 Id(K::IdType),
51 UserId(i32),
52 Has(Vec<D>),
53 OneOf(Vec<D>),
54}
55
56impl<K, D> KeySelector<K, D>
57where
58 K: ApiKey,
59 D: KeyDomain,
60{
61 pub(crate) fn fallback(&self) -> Option<Self> {
62 match self {
63 Self::Key(_) | Self::UserId(_) | Self::Id(_) => None,
64 Self::Has(domains) => {
65 let fallbacks: Vec<_> = domains.iter().filter_map(|d| d.fallback()).collect();
66 if fallbacks.is_empty() {
67 None
68 } else {
69 Some(Self::Has(fallbacks))
70 }
71 }
72 Self::OneOf(domains) => {
73 let fallbacks: Vec<_> = domains.iter().filter_map(|d| d.fallback()).collect();
74 if fallbacks.is_empty() {
75 None
76 } else {
77 Some(Self::OneOf(fallbacks))
78 }
79 }
80 }
81 }
82}
83
84impl<K, D> From<&str> for KeySelector<K, D>
85where
86 K: ApiKey,
87 D: KeyDomain,
88{
89 fn from(value: &str) -> Self {
90 Self::Key(value.to_owned())
91 }
92}
93
94impl<K, D> From<D> for KeySelector<K, D>
95where
96 K: ApiKey,
97 D: KeyDomain,
98{
99 fn from(value: D) -> Self {
100 Self::Has(vec![value])
101 }
102}
103
104impl<K, D> From<&[D]> for KeySelector<K, D>
105where
106 K: ApiKey,
107 D: KeyDomain,
108{
109 fn from(value: &[D]) -> Self {
110 Self::Has(value.to_vec())
111 }
112}
113
114impl<K, D> From<Vec<D>> for KeySelector<K, D>
115where
116 K: ApiKey,
117 D: KeyDomain,
118{
119 fn from(value: Vec<D>) -> Self {
120 Self::Has(value)
121 }
122}
123
124pub trait IntoSelector<K, D>: Send
125where
126 K: ApiKey,
127 D: KeyDomain,
128{
129 fn into_selector(self) -> KeySelector<K, D>;
130}
131
132impl<K, D, T> IntoSelector<K, D> for T
133where
134 K: ApiKey,
135 D: KeyDomain,
136 T: Into<KeySelector<K, D>> + Send,
137{
138 fn into_selector(self) -> KeySelector<K, D> {
139 self.into()
140 }
141}
142
143pub trait KeyPoolError:
144 From<reqwest::Error> + From<serde_json::Error> + From<torn_api::ApiError> + From<Arc<Self>> + Send
145{
146}
147
148impl<T> KeyPoolError for T where
149 T: From<reqwest::Error>
150 + From<serde_json::Error>
151 + From<torn_api::ApiError>
152 + From<Arc<Self>>
153 + Send
154{
155}
156
/// Backend that stores API keys and hands them out to a [`KeyPool`].
///
/// Every operation is asynchronous and addresses keys through a [`KeySelector`] (or anything
/// implementing [`IntoSelector`]).
pub trait KeyPoolStorage: Send + Sync {
    /// Key type handed out by this storage.
    type Key: ApiKey;
    /// Domain type that keys can be tagged with and selected by.
    type Domain: KeyDomain;
    /// Error type of all storage operations; the pool also converts request errors into it.
    type Error: KeyPoolError;

    /// Returns one key matching `selector`.
    ///
    /// The pool calls this before every single request, and again whenever an error hook asks
    /// for a retry. NOTE(review): whether acquiring also records usage is implementation-defined.
    fn acquire_key<S>(
        &self,
        selector: S,
    ) -> impl Future<Output = Result<Self::Key, Self::Error>> + Send
    where
        S: IntoSelector<Self::Key, Self::Domain>;

    /// Returns keys matching `selector` for a batch of `number` requests.
    ///
    /// The pool pairs the returned keys with the batch's requests in order, one key per
    /// request. NOTE(review): whether exactly `number` keys are always returned depends on the
    /// implementation.
    fn acquire_many_keys<S>(
        &self,
        selector: S,
        number: i64,
    ) -> impl Future<Output = Result<Vec<Self::Key>, Self::Error>> + Send
    where
        S: IntoSelector<Self::Key, Self::Domain>;

    /// Stores the raw `key` together with `user_id` and `domains`, returning the stored key.
    fn store_key(
        &self,
        user_id: i32,
        key: String,
        domains: Vec<Self::Domain>,
    ) -> impl Future<Output = Result<Self::Key, Self::Error>> + Send;

    /// Looks up a key matching `selector`; `Ok(None)` when there is none.
    ///
    /// NOTE(review): presumably a plain lookup that, unlike `acquire_key`, does not count as use.
    fn read_key<S>(
        &self,
        selector: S,
    ) -> impl Future<Output = Result<Option<Self::Key>, Self::Error>> + Send
    where
        S: IntoSelector<Self::Key, Self::Domain>;

    /// Looks up all keys matching `selector`.
    fn read_keys<S>(
        &self,
        selector: S,
    ) -> impl Future<Output = Result<Vec<Self::Key>, Self::Error>> + Send
    where
        S: IntoSelector<Self::Key, Self::Domain>;

    /// Removes the key matching `selector` and returns a key (presumably the removed one).
    ///
    /// The default error hooks call this for Torn error codes 2 and 10.
    fn remove_key<S>(
        &self,
        selector: S,
    ) -> impl Future<Output = Result<Self::Key, Self::Error>> + Send
    where
        S: IntoSelector<Self::Key, Self::Domain>;

    /// Adds `domain` to the key matching `selector` and returns the key (presumably updated).
    fn add_domain_to_key<S>(
        &self,
        selector: S,
        domain: Self::Domain,
    ) -> impl Future<Output = Result<Self::Key, Self::Error>> + Send
    where
        S: IntoSelector<Self::Key, Self::Domain>;

    /// Removes `domain` from the key matching `selector` and returns the key (presumably updated).
    fn remove_domain_from_key<S>(
        &self,
        selector: S,
        domain: Self::Domain,
    ) -> impl Future<Output = Result<Self::Key, Self::Error>> + Send
    where
        S: IntoSelector<Self::Key, Self::Domain>;

    /// Replaces the domains of the key matching `selector` with `domains` and returns the key.
    fn set_domains_for_key<S>(
        &self,
        selector: S,
        domains: Vec<Self::Domain>,
    ) -> impl Future<Output = Result<Self::Key, Self::Error>> + Send
    where
        S: IntoSelector<Self::Key, Self::Domain>;

    /// Times out the key matching `selector` for `duration`.
    ///
    /// The default error hooks use this for Torn error codes 5 (60 s) and 13/18 (24 h).
    /// NOTE(review): presumably the key is not handed out again until the timeout expires.
    fn timeout_key<S>(
        &self,
        selector: S,
        duration: Duration,
    ) -> impl Future<Output = Result<(), Self::Error>> + Send
    where
        S: IntoSelector<Self::Key, Self::Domain>;
}
237
238#[derive(Default)]
239pub struct PoolOptions<S>
240where
241 S: KeyPoolStorage,
242{
243 comment: Option<String>,
244 #[allow(clippy::type_complexity)]
245 error_hooks: HashMap<
246 u16,
247 Box<
248 dyn for<'a> Fn(&'a S, &'a S::Key) -> BoxFuture<'a, Result<bool, S::Error>>
249 + Send
250 + Sync,
251 >,
252 >,
253}
254
/// Builder for a [`KeyPool`]: start with [`PoolBuilder::new`], finish with [`PoolBuilder::build`].
pub struct PoolBuilder<S>
where
    S: KeyPoolStorage,
{
    /// HTTP client the built pool will send requests with.
    client: reqwest::Client,
    /// Storage the built pool will acquire keys from.
    storage: S,
    /// Comment and error hooks collected so far.
    options: crate::PoolOptions<S>,
}
263
264impl<S> PoolBuilder<S>
265where
266 S: KeyPoolStorage,
267{
268 pub fn new(storage: S) -> Self {
269 Self {
270 client: reqwest::Client::builder()
271 .brotli(true)
272 .http2_keep_alive_timeout(Duration::from_secs(60))
273 .http2_keep_alive_interval(Duration::from_secs(5))
274 .https_only(true)
275 .build()
276 .unwrap(),
277 storage,
278 options: PoolOptions {
279 comment: None,
280 error_hooks: Default::default(),
281 },
282 }
283 }
284
285 pub fn comment(mut self, c: impl ToString) -> Self {
286 self.options.comment = Some(c.to_string());
287 self
288 }
289
290 pub fn error_hook<F>(mut self, code: u16, handler: F) -> Self
291 where
292 F: for<'a> Fn(&'a S, &'a S::Key) -> BoxFuture<'a, Result<bool, S::Error>>
293 + Send
294 + Sync
295 + 'static,
296 {
297 self.options.error_hooks.insert(code, Box::new(handler));
298
299 self
300 }
301
302 pub fn use_default_hooks(self) -> Self {
303 self.error_hook(2, |storage, key| {
304 async move {
305 storage.remove_key(KeySelector::Id(key.id())).await?;
306 Ok(true)
307 }
308 .boxed()
309 })
310 .error_hook(5, |storage, key| {
311 async move {
312 storage
313 .timeout_key(KeySelector::Id(key.id()), Duration::from_secs(60))
314 .await?;
315 Ok(true)
316 }
317 .boxed()
318 })
319 .error_hook(10, |storage, key| {
320 async move {
321 storage.remove_key(KeySelector::Id(key.id())).await?;
322 Ok(true)
323 }
324 .boxed()
325 })
326 .error_hook(13, |storage, key| {
327 async move {
328 storage
329 .timeout_key(KeySelector::Id(key.id()), Duration::from_secs(24 * 3_600))
330 .await?;
331 Ok(true)
332 }
333 .boxed()
334 })
335 .error_hook(18, |storage, key| {
336 async move {
337 storage
338 .timeout_key(KeySelector::Id(key.id()), Duration::from_secs(24 * 3_600))
339 .await?;
340 Ok(true)
341 }
342 .boxed()
343 })
344 }
345
346 pub fn build(self) -> KeyPool<S> {
347 KeyPool {
348 inner: Arc::new(KeyPoolInner {
349 client: self.client,
350 storage: self.storage,
351 options: self.options,
352 }),
353 }
354 }
355}
356
/// Shared state behind a [`KeyPool`]: HTTP client, key storage and pool options.
pub struct KeyPoolInner<S>
where
    S: KeyPoolStorage,
{
    /// HTTP client every API request is sent with.
    pub client: reqwest::Client,
    /// Storage keys are acquired from and that error hooks operate on.
    pub storage: S,
    /// Comment and error hooks configured through [`PoolBuilder`].
    pub options: PoolOptions<S>,
}
365
366impl<S> KeyPoolInner<S>
367where
368 S: KeyPoolStorage,
369{
370 async fn execute_with_key(
371 &self,
372 key: &S::Key,
373 request: &ApiRequest,
374 ) -> Result<RequestResult, S::Error> {
375 let mut headers = HeaderMap::with_capacity(1);
376 headers.insert(
377 AUTHORIZATION,
378 HeaderValue::from_str(&format!("ApiKey {}", key.value())).unwrap(),
379 );
380
381 let resp = self
382 .client
383 .get(request.url())
384 .headers(headers)
385 .send()
386 .await?;
387
388 let status = resp.status();
389
390 let bytes = resp.bytes().await?;
391
392 if let Some(err) = decode_error(&bytes)? {
393 if let Some(handler) = self.options.error_hooks.get(&err.code()) {
394 let retry = (*handler)(&self.storage, key).await?;
395
396 if retry {
397 return Ok(RequestResult::Retry);
398 }
399 }
400 Err(err.into())
401 } else {
402 Ok(RequestResult::Response(ApiResponse {
403 body: Some(bytes),
404 status,
405 }))
406 }
407 }
408
409 async fn execute_request(
410 &self,
411 selector: KeySelector<S::Key, S::Domain>,
412 request: ApiRequest,
413 ) -> Result<ApiResponse, S::Error> {
414 loop {
415 let key = self.storage.acquire_key(selector.clone()).await?;
416 match self.execute_with_key(&key, &request).await {
417 Ok(RequestResult::Response(resp)) => return Ok(resp),
418 Ok(RequestResult::Retry) => (),
419 Err(why) => return Err(why),
420 }
421 }
422 }
423
424 async fn execute_bulk_requests<D, T: IntoIterator<Item = (D, ApiRequest)>>(
425 &self,
426 selector: KeySelector<S::Key, S::Domain>,
427 requests: T,
428 ) -> impl Stream<Item = (D, Result<ApiResponse, S::Error>)> + use<'_, D, S, T> {
429 let requests: Vec<_> = requests.into_iter().collect();
430
431 let keys: Vec<_> = match self
432 .storage
433 .acquire_many_keys(selector.clone(), requests.len() as i64)
434 .await
435 {
436 Ok(keys) => keys.into_iter().map(Ok).collect(),
437 Err(why) => {
438 let why = Arc::new(why);
439 std::iter::repeat_n(why, requests.len())
440 .map(|e| Err(S::Error::from(e)))
441 .collect()
442 }
443 };
444
445 StreamExt::map(
446 futures::stream::iter(std::iter::zip(requests, keys)),
447 move |((discriminant, request), mut maybe_key)| {
448 let selector = selector.clone();
449 async move {
450 loop {
451 let key = match maybe_key {
452 Ok(key) => key,
453 Err(why) => return (discriminant, Err(why)),
454 };
455 match self.execute_with_key(&key, &request).await {
456 Ok(RequestResult::Response(resp)) => return (discriminant, Ok(resp)),
457 Ok(RequestResult::Retry) => (),
458 Err(why) => return (discriminant, Err(why)),
459 }
460 maybe_key = self.storage.acquire_key(selector.clone()).await;
461 }
462 }
463 },
464 )
465 .buffer_unordered(25)
466 }
467}
468
469pub struct KeyPool<S>
470where
471 S: KeyPoolStorage,
472{
473 inner: Arc<KeyPoolInner<S>>,
474}
475
476impl<S> Deref for KeyPool<S>
477where
478 S: KeyPoolStorage,
479{
480 type Target = KeyPoolInner<S>;
481 fn deref(&self) -> &Self::Target {
482 &self.inner
483 }
484}
485
/// Outcome of one request attempt made with a single key.
enum RequestResult {
    /// The API answered with a non-error payload.
    Response(ApiResponse),
    /// An error hook asked for the request to be retried with a newly acquired key.
    Retry,
}
490
491impl<S> KeyPool<S>
492where
493 S: KeyPoolStorage + Send + Sync + 'static,
494{
495 pub fn torn_api<I>(&self, selector: I) -> KeyPoolExecutor<S>
496 where
497 I: IntoSelector<S::Key, S::Domain>,
498 {
499 KeyPoolExecutor::new(self, selector.into_selector())
500 }
501
502 pub fn throttled_torn_api<I>(
503 &self,
504 selector: I,
505 distance: Duration,
506 ) -> ThrottledKeyPoolExecutor<S>
507 where
508 I: IntoSelector<S::Key, S::Domain>,
509 {
510 ThrottledKeyPoolExecutor::new(self, selector.into_selector(), distance)
511 }
512}
513
514fn decode_error(buf: &[u8]) -> Result<Option<ApiError>, serde_json::Error> {
515 if buf.starts_with(br#"{"error":{"#) {
516 #[derive(Deserialize)]
517 struct ErrorBody<'a> {
518 code: u16,
519 error: &'a str,
520 }
521 #[derive(Deserialize)]
522 struct ErrorContainer<'a> {
523 #[serde(borrow)]
524 error: ErrorBody<'a>,
525 }
526
527 let error: ErrorContainer = serde_json::from_slice(buf)?;
528 Ok(Some(crate::ApiError::new(
529 error.error.code,
530 error.error.error,
531 )))
532 } else {
533 Ok(None)
534 }
535}
536
/// Executor that runs requests through a [`KeyPool`] with keys matching `selector`.
pub struct KeyPoolExecutor<'p, S>
where
    S: KeyPoolStorage,
{
    /// Shared pool state the requests are executed with.
    pool: &'p KeyPoolInner<S>,
    /// Selects which stored keys may be used for the requests.
    selector: KeySelector<S::Key, S::Domain>,
}
544
545impl<'p, S> KeyPoolExecutor<'p, S>
546where
547 S: KeyPoolStorage,
548{
549 pub fn new(pool: &'p KeyPool<S>, selector: KeySelector<S::Key, S::Domain>) -> Self {
550 Self {
551 pool: &pool.inner,
552 selector,
553 }
554 }
555}
556
557impl<S> Executor for KeyPoolExecutor<'_, S>
558where
559 S: KeyPoolStorage + 'static,
560{
561 type Error = S::Error;
562
563 async fn execute<R>(self, request: R) -> (R::Discriminant, Result<ApiResponse, Self::Error>)
564 where
565 R: torn_api::request::IntoRequest,
566 {
567 let (d, request) = request.into_request();
568
569 (d, self.pool.execute_request(self.selector, request).await)
570 }
571}
572
573impl<S> BulkExecutor for KeyPoolExecutor<'_, S>
574where
575 S: KeyPoolStorage + 'static,
576{
577 type Error = S::Error;
578
579 fn execute<R>(
580 self,
581 requests: impl IntoIterator<Item = R>,
582 ) -> impl futures::Stream<Item = (R::Discriminant, Result<ApiResponse, Self::Error>)> + Unpin
583 where
584 R: torn_api::request::IntoRequest,
585 {
586 let requests: Vec<_> = requests.into_iter().map(|r| r.into_request()).collect();
587 self.pool
588 .execute_bulk_requests(self.selector.clone(), requests)
589 .into_stream()
590 .flatten()
591 .boxed()
592 }
593}
594
/// Executor like [`KeyPoolExecutor`] whose bulk requests are spaced out by `distance`.
pub struct ThrottledKeyPoolExecutor<'p, S>
where
    S: KeyPoolStorage,
{
    /// Shared pool state the requests are executed with.
    pool: &'p KeyPoolInner<S>,
    /// Selects which stored keys may be used for the requests.
    selector: KeySelector<S::Key, S::Domain>,
    /// Delay enforced between consecutive bulk requests (passed to `throttle`).
    distance: Duration,
}
603
604impl<S> Clone for ThrottledKeyPoolExecutor<'_, S>
605where
606 S: KeyPoolStorage,
607{
608 fn clone(&self) -> Self {
609 Self {
610 pool: self.pool,
611 selector: self.selector.clone(),
612 distance: self.distance,
613 }
614 }
615}
616
617impl<S> ThrottledKeyPoolExecutor<'_, S>
618where
619 S: KeyPoolStorage,
620{
621 async fn execute_request(self, request: ApiRequest) -> Result<ApiResponse, S::Error> {
622 self.pool.execute_request(self.selector, request).await
623 }
624}
625
626impl<'p, S> ThrottledKeyPoolExecutor<'p, S>
627where
628 S: KeyPoolStorage,
629{
630 pub fn new(
631 pool: &'p KeyPool<S>,
632 selector: KeySelector<S::Key, S::Domain>,
633 distance: Duration,
634 ) -> Self {
635 Self {
636 pool: &pool.inner,
637 selector,
638 distance,
639 }
640 }
641}
642
643impl<S> BulkExecutor for ThrottledKeyPoolExecutor<'_, S>
644where
645 S: KeyPoolStorage + 'static,
646{
647 type Error = S::Error;
648
649 fn execute<R>(
650 self,
651 requests: impl IntoIterator<Item = R>,
652 ) -> impl futures::Stream<Item = (R::Discriminant, Result<ApiResponse, Self::Error>)> + Unpin
653 where
654 R: torn_api::request::IntoRequest,
655 {
656 let requests: Vec<_> = requests.into_iter().map(|r| r.into_request()).collect();
657 StreamExt::map(
658 futures::stream::iter(requests).throttle(self.distance),
659 move |(d, request)| {
660 let this = self.clone();
661 async move {
662 let result = this.execute_request(request).await;
663 (d, result)
664 }
665 },
666 )
667 .buffer_unordered(25)
668 .boxed()
669 }
670}
671
#[cfg(test)]
#[cfg(feature = "postgres")]
mod test {
    use torn_api::executor::{BulkExecutorExt, ExecutorExt};

    use crate::postgres;

    use super::*;

    /// Wraps `storage` in a pool configured the same way for every test.
    fn configured_pool<S>(storage: S) -> KeyPool<S>
    where
        S: KeyPoolStorage,
    {
        PoolBuilder::new(storage)
            .use_default_hooks()
            .comment("test_runner")
            .build()
    }

    /// A single faction request succeeds through the pool.
    #[sqlx::test]
    fn single_request(pool: sqlx::PgPool) {
        let (storage, _) = postgres::test::setup(pool).await;
        let key_pool = configured_pool(storage);

        key_pool
            .torn_api(postgres::test::Domain::All)
            .faction()
            .basic(|b| b)
            .await
            .unwrap();
    }

    /// Two faction requests succeed through the bulk executor.
    #[sqlx::test]
    fn bulk(pool: sqlx::PgPool) {
        let (storage, _) = postgres::test::setup(pool).await;
        let key_pool = configured_pool(storage);

        let stream = key_pool
            .torn_api(postgres::test::Domain::All)
            .faction_bulk()
            .basic_for_id(vec![19.into(), 89.into()], |b| b);
        let mut responses: Vec<_> = StreamExt::collect(stream).await;

        for _ in 0..2 {
            let (_id, basic) = responses.pop().unwrap();
            basic.unwrap();
        }
    }

    /// Two faction requests succeed through the throttled bulk executor.
    #[sqlx::test]
    fn bulk_throttled(pool: sqlx::PgPool) {
        let (storage, _) = postgres::test::setup(pool).await;
        let key_pool = configured_pool(storage);

        let stream = key_pool
            .throttled_torn_api(postgres::test::Domain::All, Duration::from_millis(500))
            .faction_bulk()
            .basic_for_id(vec![19.into(), 89.into()], |b| b);
        let mut responses: Vec<_> = StreamExt::collect(stream).await;

        for _ in 0..2 {
            let (_id, basic) = responses.pop().unwrap();
            basic.unwrap();
        }
    }
}