torn_key_pool/
lib.rs

1#![warn(clippy::all, clippy::perf, clippy::style, clippy::suspicious)]
2
3#[cfg(feature = "postgres")]
4pub mod postgres;
5
6use std::{collections::HashMap, future::Future, ops::Deref, sync::Arc, time::Duration};
7
8use futures::{future::BoxFuture, FutureExt, Stream, StreamExt};
9use reqwest::header::{HeaderMap, HeaderValue, AUTHORIZATION};
10use serde::Deserialize;
11use tokio_stream::StreamExt as TokioStreamExt;
12use torn_api::{
13    executor::{BulkExecutor, Executor},
14    request::{ApiRequest, ApiResponse},
15    ApiError,
16};
17
/// Marker trait for the identifier type of an [`ApiKey`].
///
/// Any type that can be cloned, compared for equality, hashed and shared
/// between threads qualifies; the blanket impl below makes this automatic.
pub trait ApiKeyId: Clone + Eq + std::hash::Hash + Send + Sync {}

impl<T: Clone + Eq + std::hash::Hash + Send + Sync> ApiKeyId for T {}
21
22pub trait ApiKey: Send + Sync + Clone + 'static {
23    type IdType: ApiKeyId;
24
25    fn value(&self) -> &str;
26
27    fn id(&self) -> Self::IdType;
28
29    fn selector<D>(&self) -> KeySelector<Self, D>
30    where
31        D: KeyDomain,
32    {
33        KeySelector::Id(self.id())
34    }
35}
36
/// A permission scope ("domain") that keys in the pool can be tagged with.
///
/// [`KeyDomain::fallback`] may name a related domain. This is presumably used by
/// storage backends when no key matches the original domain; confirm against the
/// backend.
pub trait KeyDomain: std::fmt::Debug + Clone + Sync + Send + 'static {
    /// The domain to fall back to, if any. The default implementation has none.
    fn fallback(&self) -> Option<Self> {
        Option::None
    }
}
42
43#[derive(Debug, Clone)]
44pub enum KeySelector<K, D>
45where
46    K: ApiKey,
47    D: KeyDomain,
48{
49    Key(String),
50    Id(K::IdType),
51    UserId(i32),
52    Has(Vec<D>),
53    OneOf(Vec<D>),
54}
55
56impl<K, D> KeySelector<K, D>
57where
58    K: ApiKey,
59    D: KeyDomain,
60{
61    pub(crate) fn fallback(&self) -> Option<Self> {
62        match self {
63            Self::Key(_) | Self::UserId(_) | Self::Id(_) => None,
64            Self::Has(domains) => {
65                let fallbacks: Vec<_> = domains.iter().filter_map(|d| d.fallback()).collect();
66                if fallbacks.is_empty() {
67                    None
68                } else {
69                    Some(Self::Has(fallbacks))
70                }
71            }
72            Self::OneOf(domains) => {
73                let fallbacks: Vec<_> = domains.iter().filter_map(|d| d.fallback()).collect();
74                if fallbacks.is_empty() {
75                    None
76                } else {
77                    Some(Self::OneOf(fallbacks))
78                }
79            }
80        }
81    }
82}
83
84impl<K, D> From<&str> for KeySelector<K, D>
85where
86    K: ApiKey,
87    D: KeyDomain,
88{
89    fn from(value: &str) -> Self {
90        Self::Key(value.to_owned())
91    }
92}
93
94impl<K, D> From<D> for KeySelector<K, D>
95where
96    K: ApiKey,
97    D: KeyDomain,
98{
99    fn from(value: D) -> Self {
100        Self::Has(vec![value])
101    }
102}
103
104impl<K, D> From<&[D]> for KeySelector<K, D>
105where
106    K: ApiKey,
107    D: KeyDomain,
108{
109    fn from(value: &[D]) -> Self {
110        Self::Has(value.to_vec())
111    }
112}
113
114impl<K, D> From<Vec<D>> for KeySelector<K, D>
115where
116    K: ApiKey,
117    D: KeyDomain,
118{
119    fn from(value: Vec<D>) -> Self {
120        Self::Has(value)
121    }
122}
123
124pub trait IntoSelector<K, D>: Send
125where
126    K: ApiKey,
127    D: KeyDomain,
128{
129    fn into_selector(self) -> KeySelector<K, D>;
130}
131
132impl<K, D, T> IntoSelector<K, D> for T
133where
134    K: ApiKey,
135    D: KeyDomain,
136    T: Into<KeySelector<K, D>> + Send,
137{
138    fn into_selector(self) -> KeySelector<K, D> {
139        self.into()
140    }
141}
142
143pub trait KeyPoolError:
144    From<reqwest::Error> + From<serde_json::Error> + From<torn_api::ApiError> + From<Arc<Self>> + Send
145{
146}
147
148impl<T> KeyPoolError for T where
149    T: From<reqwest::Error>
150        + From<serde_json::Error>
151        + From<torn_api::ApiError>
152        + From<Arc<Self>>
153        + Send
154{
155}
156
157pub trait KeyPoolStorage: Send + Sync {
158    type Key: ApiKey;
159    type Domain: KeyDomain;
160    type Error: KeyPoolError;
161
162    fn acquire_key<S>(
163        &self,
164        selector: S,
165    ) -> impl Future<Output = Result<Self::Key, Self::Error>> + Send
166    where
167        S: IntoSelector<Self::Key, Self::Domain>;
168
169    fn acquire_many_keys<S>(
170        &self,
171        selector: S,
172        number: i64,
173    ) -> impl Future<Output = Result<Vec<Self::Key>, Self::Error>> + Send
174    where
175        S: IntoSelector<Self::Key, Self::Domain>;
176
177    fn store_key(
178        &self,
179        user_id: i32,
180        key: String,
181        domains: Vec<Self::Domain>,
182    ) -> impl Future<Output = Result<Self::Key, Self::Error>> + Send;
183
184    fn read_key<S>(
185        &self,
186        selector: S,
187    ) -> impl Future<Output = Result<Option<Self::Key>, Self::Error>> + Send
188    where
189        S: IntoSelector<Self::Key, Self::Domain>;
190
191    fn read_keys<S>(
192        &self,
193        selector: S,
194    ) -> impl Future<Output = Result<Vec<Self::Key>, Self::Error>> + Send
195    where
196        S: IntoSelector<Self::Key, Self::Domain>;
197
198    fn remove_key<S>(
199        &self,
200        selector: S,
201    ) -> impl Future<Output = Result<Self::Key, Self::Error>> + Send
202    where
203        S: IntoSelector<Self::Key, Self::Domain>;
204
205    fn add_domain_to_key<S>(
206        &self,
207        selector: S,
208        domain: Self::Domain,
209    ) -> impl Future<Output = Result<Self::Key, Self::Error>> + Send
210    where
211        S: IntoSelector<Self::Key, Self::Domain>;
212
213    fn remove_domain_from_key<S>(
214        &self,
215        selector: S,
216        domain: Self::Domain,
217    ) -> impl Future<Output = Result<Self::Key, Self::Error>> + Send
218    where
219        S: IntoSelector<Self::Key, Self::Domain>;
220
221    fn set_domains_for_key<S>(
222        &self,
223        selector: S,
224        domains: Vec<Self::Domain>,
225    ) -> impl Future<Output = Result<Self::Key, Self::Error>> + Send
226    where
227        S: IntoSelector<Self::Key, Self::Domain>;
228
229    fn timeout_key<S>(
230        &self,
231        selector: S,
232        duration: Duration,
233    ) -> impl Future<Output = Result<(), Self::Error>> + Send
234    where
235        S: IntoSelector<Self::Key, Self::Domain>;
236}
237
238#[derive(Default)]
239pub struct PoolOptions<S>
240where
241    S: KeyPoolStorage,
242{
243    comment: Option<String>,
244    #[allow(clippy::type_complexity)]
245    error_hooks: HashMap<
246        u16,
247        Box<
248            dyn for<'a> Fn(
249                    &'a S,
250                    &'a S::Key,
251                    &'a ApiRequest,
252                ) -> BoxFuture<'a, Result<bool, S::Error>>
253                + Send
254                + Sync,
255        >,
256    >,
257}
258
259pub struct PoolBuilder<S>
260where
261    S: KeyPoolStorage,
262{
263    client: reqwest::Client,
264    storage: S,
265    options: crate::PoolOptions<S>,
266}
267
268impl<S> PoolBuilder<S>
269where
270    S: KeyPoolStorage,
271{
272    pub fn new(storage: S) -> Self {
273        Self {
274            client: reqwest::Client::builder()
275                .brotli(true)
276                .http2_keep_alive_timeout(Duration::from_secs(60))
277                .http2_keep_alive_interval(Duration::from_secs(5))
278                .https_only(true)
279                .build()
280                .unwrap(),
281            storage,
282            options: PoolOptions {
283                comment: None,
284                error_hooks: Default::default(),
285            },
286        }
287    }
288
289    pub fn comment(mut self, c: impl ToString) -> Self {
290        self.options.comment = Some(c.to_string());
291        self
292    }
293
294    pub fn error_hook<F>(mut self, error: ApiError, handler: F) -> Self
295    where
296        F: for<'a> Fn(&'a S, &'a S::Key, &'a ApiRequest) -> BoxFuture<'a, Result<bool, S::Error>>
297            + Send
298            + Sync
299            + 'static,
300    {
301        self.options
302            .error_hooks
303            .insert(error.code(), Box::new(handler));
304
305        self
306    }
307
308    pub fn use_default_hooks(self) -> Self {
309        self.error_hook(ApiError::IncorrectKey, |storage, key, _| {
310            async move {
311                storage.remove_key(KeySelector::Id(key.id())).await?;
312                Ok(true)
313            }
314            .boxed()
315        })
316        .error_hook(ApiError::TooManyRequest, |storage, key, _| {
317            async move {
318                storage
319                    .timeout_key(KeySelector::Id(key.id()), Duration::from_secs(60))
320                    .await?;
321                Ok(true)
322            }
323            .boxed()
324        })
325        .error_hook(ApiError::KeyOwnerInFederalJail, |storage, key, _| {
326            async move {
327                storage.remove_key(KeySelector::Id(key.id())).await?;
328                Ok(true)
329            }
330            .boxed()
331        })
332        .error_hook(ApiError::TemporaryInactivity, |storage, key, _| {
333            async move {
334                storage
335                    .timeout_key(KeySelector::Id(key.id()), Duration::from_secs(24 * 3_600))
336                    .await?;
337                Ok(true)
338            }
339            .boxed()
340        })
341        .error_hook(ApiError::Paused, |storage, key, _| {
342            async move {
343                storage
344                    .timeout_key(KeySelector::Id(key.id()), Duration::from_secs(24 * 3_600))
345                    .await?;
346                Ok(true)
347            }
348            .boxed()
349        })
350    }
351
352    pub fn build(self) -> KeyPool<S> {
353        KeyPool {
354            inner: Arc::new(KeyPoolInner {
355                client: self.client,
356                storage: self.storage,
357                options: self.options,
358            }),
359        }
360    }
361}
362
363pub struct KeyPoolInner<S>
364where
365    S: KeyPoolStorage,
366{
367    pub client: reqwest::Client,
368    pub storage: S,
369    pub options: PoolOptions<S>,
370}
371
372impl<S> KeyPoolInner<S>
373where
374    S: KeyPoolStorage,
375{
376    async fn execute_with_key(
377        &self,
378        key: &S::Key,
379        request: &ApiRequest,
380    ) -> Result<RequestResult, S::Error> {
381        let mut headers = HeaderMap::with_capacity(1);
382        headers.insert(
383            AUTHORIZATION,
384            HeaderValue::from_str(&format!("ApiKey {}", key.value())).unwrap(),
385        );
386
387        let resp = self
388            .client
389            .get(request.url())
390            .headers(headers)
391            .send()
392            .await?;
393
394        let status = resp.status();
395
396        let bytes = resp.bytes().await?;
397
398        if let Some(err) = decode_error(&bytes)? {
399            if let Some(handler) = self.options.error_hooks.get(&err.code()) {
400                let retry = (*handler)(&self.storage, key, request).await?;
401
402                if retry {
403                    return Ok(RequestResult::Retry);
404                }
405            }
406            Err(err.into())
407        } else {
408            Ok(RequestResult::Response(ApiResponse {
409                body: Some(bytes),
410                status,
411            }))
412        }
413    }
414
415    async fn execute_request(
416        &self,
417        selector: KeySelector<S::Key, S::Domain>,
418        request: ApiRequest,
419    ) -> Result<ApiResponse, S::Error> {
420        loop {
421            let key = self.storage.acquire_key(selector.clone()).await?;
422            match self.execute_with_key(&key, &request).await {
423                Ok(RequestResult::Response(resp)) => return Ok(resp),
424                Ok(RequestResult::Retry) => (),
425                Err(why) => return Err(why),
426            }
427        }
428    }
429
430    async fn execute_bulk_requests<D, T: IntoIterator<Item = (D, ApiRequest)>>(
431        &self,
432        selector: KeySelector<S::Key, S::Domain>,
433        requests: T,
434    ) -> impl Stream<Item = (D, Result<ApiResponse, S::Error>)> + use<'_, D, S, T> {
435        let requests: Vec<_> = requests.into_iter().collect();
436
437        let keys: Vec<_> = match self
438            .storage
439            .acquire_many_keys(selector.clone(), requests.len() as i64)
440            .await
441        {
442            Ok(keys) => keys.into_iter().map(Ok).collect(),
443            Err(why) => {
444                let why = Arc::new(why);
445                std::iter::repeat_n(why, requests.len())
446                    .map(|e| Err(S::Error::from(e)))
447                    .collect()
448            }
449        };
450
451        StreamExt::map(
452            futures::stream::iter(std::iter::zip(requests, keys)),
453            move |((discriminant, request), mut maybe_key)| {
454                let selector = selector.clone();
455                async move {
456                    loop {
457                        let key = match maybe_key {
458                            Ok(key) => key,
459                            Err(why) => return (discriminant, Err(why)),
460                        };
461                        match self.execute_with_key(&key, &request).await {
462                            Ok(RequestResult::Response(resp)) => return (discriminant, Ok(resp)),
463                            Ok(RequestResult::Retry) => (),
464                            Err(why) => return (discriminant, Err(why)),
465                        }
466                        maybe_key = self.storage.acquire_key(selector.clone()).await;
467                    }
468                }
469            },
470        )
471        .buffer_unordered(25)
472    }
473}
474
475pub struct KeyPool<S>
476where
477    S: KeyPoolStorage,
478{
479    inner: Arc<KeyPoolInner<S>>,
480}
481
482impl<S> Deref for KeyPool<S>
483where
484    S: KeyPoolStorage,
485{
486    type Target = KeyPoolInner<S>;
487    fn deref(&self) -> &Self::Target {
488        &self.inner
489    }
490}
491
492enum RequestResult {
493    Response(ApiResponse),
494    Retry,
495}
496
497impl<S> KeyPool<S>
498where
499    S: KeyPoolStorage + Send + Sync + 'static,
500{
501    pub fn torn_api<I>(&self, selector: I) -> KeyPoolExecutor<'_, S>
502    where
503        I: IntoSelector<S::Key, S::Domain>,
504    {
505        KeyPoolExecutor::new(self, selector.into_selector())
506    }
507
508    pub fn throttled_torn_api<I>(
509        &self,
510        selector: I,
511        distance: Duration,
512    ) -> ThrottledKeyPoolExecutor<'_, S>
513    where
514        I: IntoSelector<S::Key, S::Domain>,
515    {
516        ThrottledKeyPoolExecutor::new(self, selector.into_selector(), distance)
517    }
518}
519
520fn decode_error(buf: &[u8]) -> Result<Option<ApiError>, serde_json::Error> {
521    if buf.starts_with(br#"{"error":{"#) {
522        #[derive(Deserialize)]
523        struct ErrorBody<'a> {
524            code: u16,
525            error: &'a str,
526        }
527        #[derive(Deserialize)]
528        struct ErrorContainer<'a> {
529            #[serde(borrow)]
530            error: ErrorBody<'a>,
531        }
532
533        let error: ErrorContainer = serde_json::from_slice(buf)?;
534        Ok(Some(crate::ApiError::new(
535            error.error.code,
536            error.error.error,
537        )))
538    } else {
539        Ok(None)
540    }
541}
542
543pub struct KeyPoolExecutor<'p, S>
544where
545    S: KeyPoolStorage,
546{
547    pool: &'p KeyPoolInner<S>,
548    selector: KeySelector<S::Key, S::Domain>,
549}
550
551impl<'p, S> KeyPoolExecutor<'p, S>
552where
553    S: KeyPoolStorage,
554{
555    pub fn new(pool: &'p KeyPool<S>, selector: KeySelector<S::Key, S::Domain>) -> Self {
556        Self {
557            pool: &pool.inner,
558            selector,
559        }
560    }
561}
562
563impl<S> Executor for KeyPoolExecutor<'_, S>
564where
565    S: KeyPoolStorage + 'static,
566{
567    type Error = S::Error;
568
569    async fn execute<R>(self, request: R) -> (R::Discriminant, Result<ApiResponse, Self::Error>)
570    where
571        R: torn_api::request::IntoRequest,
572    {
573        let (d, request) = request.into_request();
574
575        (d, self.pool.execute_request(self.selector, request).await)
576    }
577}
578
579impl<S> BulkExecutor for KeyPoolExecutor<'_, S>
580where
581    S: KeyPoolStorage + 'static,
582{
583    type Error = S::Error;
584
585    fn execute<R>(
586        self,
587        requests: impl IntoIterator<Item = R>,
588    ) -> impl futures::Stream<Item = (R::Discriminant, Result<ApiResponse, Self::Error>)> + Unpin
589    where
590        R: torn_api::request::IntoRequest,
591    {
592        let requests: Vec<_> = requests.into_iter().map(|r| r.into_request()).collect();
593        self.pool
594            .execute_bulk_requests(self.selector.clone(), requests)
595            .into_stream()
596            .flatten()
597            .boxed()
598    }
599}
600
601pub struct ThrottledKeyPoolExecutor<'p, S>
602where
603    S: KeyPoolStorage,
604{
605    pool: &'p KeyPoolInner<S>,
606    selector: KeySelector<S::Key, S::Domain>,
607    distance: Duration,
608}
609
610impl<S> Clone for ThrottledKeyPoolExecutor<'_, S>
611where
612    S: KeyPoolStorage,
613{
614    fn clone(&self) -> Self {
615        Self {
616            pool: self.pool,
617            selector: self.selector.clone(),
618            distance: self.distance,
619        }
620    }
621}
622
623impl<S> ThrottledKeyPoolExecutor<'_, S>
624where
625    S: KeyPoolStorage,
626{
627    async fn execute_request(self, request: ApiRequest) -> Result<ApiResponse, S::Error> {
628        self.pool.execute_request(self.selector, request).await
629    }
630}
631
632impl<'p, S> ThrottledKeyPoolExecutor<'p, S>
633where
634    S: KeyPoolStorage,
635{
636    pub fn new(
637        pool: &'p KeyPool<S>,
638        selector: KeySelector<S::Key, S::Domain>,
639        distance: Duration,
640    ) -> Self {
641        Self {
642            pool: &pool.inner,
643            selector,
644            distance,
645        }
646    }
647}
648
649impl<S> BulkExecutor for ThrottledKeyPoolExecutor<'_, S>
650where
651    S: KeyPoolStorage + 'static,
652{
653    type Error = S::Error;
654
655    fn execute<R>(
656        self,
657        requests: impl IntoIterator<Item = R>,
658    ) -> impl futures::Stream<Item = (R::Discriminant, Result<ApiResponse, Self::Error>)> + Unpin
659    where
660        R: torn_api::request::IntoRequest,
661    {
662        let requests: Vec<_> = requests.into_iter().map(|r| r.into_request()).collect();
663        StreamExt::map(
664            futures::stream::iter(requests).throttle(self.distance),
665            move |(d, request)| {
666                let this = self.clone();
667                async move {
668                    let result = this.execute_request(request).await;
669                    (d, result)
670                }
671            },
672        )
673        .buffer_unordered(25)
674        .boxed()
675    }
676}
677
#[cfg(test)]
#[cfg(feature = "postgres")]
mod test {
    use torn_api::executor::{BulkExecutorExt, ExecutorExt};

    use crate::postgres;

    use super::*;

    /// Builds a pool over `storage` with the default hooks and a test comment.
    fn build_pool<S: KeyPoolStorage>(storage: S) -> KeyPool<S> {
        PoolBuilder::new(storage)
            .use_default_hooks()
            .comment("test_runner")
            .build()
    }

    #[sqlx::test]
    fn single_request(pool: sqlx::PgPool) {
        let (storage, _) = postgres::test::setup(pool).await;
        let pool = build_pool(storage);

        pool.torn_api(postgres::test::Domain::All)
            .faction()
            .basic(|b| b)
            .await
            .unwrap();
    }

    #[sqlx::test]
    fn bulk(pool: sqlx::PgPool) {
        let (storage, _) = postgres::test::setup(pool).await;
        let pool = build_pool(storage);

        let responses = pool
            .torn_api(postgres::test::Domain::All)
            .faction_bulk()
            .basic_for_id(vec![19.into(), 89.into()], |b| b);
        let responses: Vec<_> = StreamExt::collect(responses).await;

        // Results arrive in completion order: check that every request yielded
        // exactly one successful response, regardless of order.
        assert_eq!(responses.len(), 2);
        for (_id, basic) in responses {
            basic.unwrap();
        }
    }

    #[sqlx::test]
    fn bulk_throttled(pool: sqlx::PgPool) {
        let (storage, _) = postgres::test::setup(pool).await;
        let pool = build_pool(storage);

        let responses = pool
            .throttled_torn_api(postgres::test::Domain::All, Duration::from_millis(500))
            .faction_bulk()
            .basic_for_id(vec![19.into(), 89.into()], |b| b);
        let responses: Vec<_> = StreamExt::collect(responses).await;

        assert_eq!(responses.len(), 2);
        for (_id, basic) in responses {
            basic.unwrap();
        }
    }
}