torn_key_pool/
lib.rs

1#![warn(clippy::all, clippy::perf, clippy::style, clippy::suspicious)]
2
3#[cfg(feature = "postgres")]
4pub mod postgres;
5
6use std::{collections::HashMap, future::Future, ops::Deref, sync::Arc, time::Duration};
7
8use futures::{future::BoxFuture, FutureExt, Stream, StreamExt};
9use reqwest::header::{HeaderMap, HeaderValue, AUTHORIZATION};
10use serde::Deserialize;
11use tokio_stream::StreamExt as TokioStreamExt;
12use torn_api::{
13    executor::{BulkExecutor, Executor},
14    request::{ApiRequest, ApiResponse},
15    ApiError,
16};
17
/// Identifier of a stored key: any cloneable, hashable, thread-safe type with full equality.
///
/// `Eq` already implies `PartialEq`, so `==` is available on every `ApiKeyId`.
pub trait ApiKeyId: Clone + Eq + std::hash::Hash + Send + Sync {}

// Every type meeting the bounds is automatically a key id; nobody implements this by hand.
impl<I> ApiKeyId for I where I: Clone + Eq + std::hash::Hash + Send + Sync {}
21
22pub trait ApiKey: Send + Sync + Clone + 'static {
23    type IdType: ApiKeyId;
24
25    fn value(&self) -> &str;
26
27    fn id(&self) -> Self::IdType;
28
29    fn selector<D>(&self) -> KeySelector<Self, D>
30    where
31        D: KeyDomain,
32    {
33        KeySelector::Id(self.id())
34    }
35}
36
/// A category that stored keys can be associated with and selected by.
pub trait KeyDomain: Clone + std::fmt::Debug + Send + Sync + 'static {
    /// Alternative domain that may stand in for `self`.
    ///
    /// `KeySelector::fallback` builds its fallback selector from this. By default a domain
    /// has no fallback.
    fn fallback(&self) -> Option<Self> {
        Option::None
    }
}
42
43#[derive(Debug, Clone)]
44pub enum KeySelector<K, D>
45where
46    K: ApiKey,
47    D: KeyDomain,
48{
49    Key(String),
50    Id(K::IdType),
51    UserId(i32),
52    Has(Vec<D>),
53    OneOf(Vec<D>),
54}
55
56impl<K, D> KeySelector<K, D>
57where
58    K: ApiKey,
59    D: KeyDomain,
60{
61    pub(crate) fn fallback(&self) -> Option<Self> {
62        match self {
63            Self::Key(_) | Self::UserId(_) | Self::Id(_) => None,
64            Self::Has(domains) => {
65                let fallbacks: Vec<_> = domains.iter().filter_map(|d| d.fallback()).collect();
66                if fallbacks.is_empty() {
67                    None
68                } else {
69                    Some(Self::Has(fallbacks))
70                }
71            }
72            Self::OneOf(domains) => {
73                let fallbacks: Vec<_> = domains.iter().filter_map(|d| d.fallback()).collect();
74                if fallbacks.is_empty() {
75                    None
76                } else {
77                    Some(Self::OneOf(fallbacks))
78                }
79            }
80        }
81    }
82}
83
84impl<K, D> From<&str> for KeySelector<K, D>
85where
86    K: ApiKey,
87    D: KeyDomain,
88{
89    fn from(value: &str) -> Self {
90        Self::Key(value.to_owned())
91    }
92}
93
94impl<K, D> From<D> for KeySelector<K, D>
95where
96    K: ApiKey,
97    D: KeyDomain,
98{
99    fn from(value: D) -> Self {
100        Self::Has(vec![value])
101    }
102}
103
104impl<K, D> From<&[D]> for KeySelector<K, D>
105where
106    K: ApiKey,
107    D: KeyDomain,
108{
109    fn from(value: &[D]) -> Self {
110        Self::Has(value.to_vec())
111    }
112}
113
114impl<K, D> From<Vec<D>> for KeySelector<K, D>
115where
116    K: ApiKey,
117    D: KeyDomain,
118{
119    fn from(value: Vec<D>) -> Self {
120        Self::Has(value)
121    }
122}
123
124pub trait IntoSelector<K, D>: Send
125where
126    K: ApiKey,
127    D: KeyDomain,
128{
129    fn into_selector(self) -> KeySelector<K, D>;
130}
131
132impl<K, D, T> IntoSelector<K, D> for T
133where
134    K: ApiKey,
135    D: KeyDomain,
136    T: Into<KeySelector<K, D>> + Send,
137{
138    fn into_selector(self) -> KeySelector<K, D> {
139        self.into()
140    }
141}
142
143pub trait KeyPoolError:
144    From<reqwest::Error> + From<serde_json::Error> + From<torn_api::ApiError> + From<Arc<Self>> + Send
145{
146}
147
148impl<T> KeyPoolError for T where
149    T: From<reqwest::Error>
150        + From<serde_json::Error>
151        + From<torn_api::ApiError>
152        + From<Arc<Self>>
153        + Send
154{
155}
156
157pub trait KeyPoolStorage: Send + Sync {
158    type Key: ApiKey;
159    type Domain: KeyDomain;
160    type Error: KeyPoolError;
161
162    fn acquire_key<S>(
163        &self,
164        selector: S,
165    ) -> impl Future<Output = Result<Self::Key, Self::Error>> + Send
166    where
167        S: IntoSelector<Self::Key, Self::Domain>;
168
169    fn acquire_many_keys<S>(
170        &self,
171        selector: S,
172        number: i64,
173    ) -> impl Future<Output = Result<Vec<Self::Key>, Self::Error>> + Send
174    where
175        S: IntoSelector<Self::Key, Self::Domain>;
176
177    fn store_key(
178        &self,
179        user_id: i32,
180        key: String,
181        domains: Vec<Self::Domain>,
182    ) -> impl Future<Output = Result<Self::Key, Self::Error>> + Send;
183
184    fn read_key<S>(
185        &self,
186        selector: S,
187    ) -> impl Future<Output = Result<Option<Self::Key>, Self::Error>> + Send
188    where
189        S: IntoSelector<Self::Key, Self::Domain>;
190
191    fn read_keys<S>(
192        &self,
193        selector: S,
194    ) -> impl Future<Output = Result<Vec<Self::Key>, Self::Error>> + Send
195    where
196        S: IntoSelector<Self::Key, Self::Domain>;
197
198    fn remove_key<S>(
199        &self,
200        selector: S,
201    ) -> impl Future<Output = Result<Self::Key, Self::Error>> + Send
202    where
203        S: IntoSelector<Self::Key, Self::Domain>;
204
205    fn add_domain_to_key<S>(
206        &self,
207        selector: S,
208        domain: Self::Domain,
209    ) -> impl Future<Output = Result<Self::Key, Self::Error>> + Send
210    where
211        S: IntoSelector<Self::Key, Self::Domain>;
212
213    fn remove_domain_from_key<S>(
214        &self,
215        selector: S,
216        domain: Self::Domain,
217    ) -> impl Future<Output = Result<Self::Key, Self::Error>> + Send
218    where
219        S: IntoSelector<Self::Key, Self::Domain>;
220
221    fn set_domains_for_key<S>(
222        &self,
223        selector: S,
224        domains: Vec<Self::Domain>,
225    ) -> impl Future<Output = Result<Self::Key, Self::Error>> + Send
226    where
227        S: IntoSelector<Self::Key, Self::Domain>;
228
229    fn timeout_key<S>(
230        &self,
231        selector: S,
232        duration: Duration,
233    ) -> impl Future<Output = Result<(), Self::Error>> + Send
234    where
235        S: IntoSelector<Self::Key, Self::Domain>;
236}
237
238#[derive(Default)]
239pub struct PoolOptions<S>
240where
241    S: KeyPoolStorage,
242{
243    comment: Option<String>,
244    #[allow(clippy::type_complexity)]
245    error_hooks: HashMap<
246        u16,
247        Box<
248            dyn for<'a> Fn(&'a S, &'a S::Key) -> BoxFuture<'a, Result<bool, S::Error>>
249                + Send
250                + Sync,
251        >,
252    >,
253}
254
255pub struct PoolBuilder<S>
256where
257    S: KeyPoolStorage,
258{
259    client: reqwest::Client,
260    storage: S,
261    options: crate::PoolOptions<S>,
262}
263
264impl<S> PoolBuilder<S>
265where
266    S: KeyPoolStorage,
267{
268    pub fn new(storage: S) -> Self {
269        Self {
270            client: reqwest::Client::builder()
271                .brotli(true)
272                .http2_keep_alive_timeout(Duration::from_secs(60))
273                .http2_keep_alive_interval(Duration::from_secs(5))
274                .https_only(true)
275                .build()
276                .unwrap(),
277            storage,
278            options: PoolOptions {
279                comment: None,
280                error_hooks: Default::default(),
281            },
282        }
283    }
284
285    pub fn comment(mut self, c: impl ToString) -> Self {
286        self.options.comment = Some(c.to_string());
287        self
288    }
289
290    pub fn error_hook<F>(mut self, code: u16, handler: F) -> Self
291    where
292        F: for<'a> Fn(&'a S, &'a S::Key) -> BoxFuture<'a, Result<bool, S::Error>>
293            + Send
294            + Sync
295            + 'static,
296    {
297        self.options.error_hooks.insert(code, Box::new(handler));
298
299        self
300    }
301
302    pub fn use_default_hooks(self) -> Self {
303        self.error_hook(2, |storage, key| {
304            async move {
305                storage.remove_key(KeySelector::Id(key.id())).await?;
306                Ok(true)
307            }
308            .boxed()
309        })
310        .error_hook(5, |storage, key| {
311            async move {
312                storage
313                    .timeout_key(KeySelector::Id(key.id()), Duration::from_secs(60))
314                    .await?;
315                Ok(true)
316            }
317            .boxed()
318        })
319        .error_hook(10, |storage, key| {
320            async move {
321                storage.remove_key(KeySelector::Id(key.id())).await?;
322                Ok(true)
323            }
324            .boxed()
325        })
326        .error_hook(13, |storage, key| {
327            async move {
328                storage
329                    .timeout_key(KeySelector::Id(key.id()), Duration::from_secs(24 * 3_600))
330                    .await?;
331                Ok(true)
332            }
333            .boxed()
334        })
335        .error_hook(18, |storage, key| {
336            async move {
337                storage
338                    .timeout_key(KeySelector::Id(key.id()), Duration::from_secs(24 * 3_600))
339                    .await?;
340                Ok(true)
341            }
342            .boxed()
343        })
344    }
345
346    pub fn build(self) -> KeyPool<S> {
347        KeyPool {
348            inner: Arc::new(KeyPoolInner {
349                client: self.client,
350                storage: self.storage,
351                options: self.options,
352            }),
353        }
354    }
355}
356
357pub struct KeyPoolInner<S>
358where
359    S: KeyPoolStorage,
360{
361    pub client: reqwest::Client,
362    pub storage: S,
363    pub options: PoolOptions<S>,
364}
365
366impl<S> KeyPoolInner<S>
367where
368    S: KeyPoolStorage,
369{
370    async fn execute_with_key(
371        &self,
372        key: &S::Key,
373        request: &ApiRequest,
374    ) -> Result<RequestResult, S::Error> {
375        let mut headers = HeaderMap::with_capacity(1);
376        headers.insert(
377            AUTHORIZATION,
378            HeaderValue::from_str(&format!("ApiKey {}", key.value())).unwrap(),
379        );
380
381        let resp = self
382            .client
383            .get(request.url())
384            .headers(headers)
385            .send()
386            .await?;
387
388        let status = resp.status();
389
390        let bytes = resp.bytes().await?;
391
392        if let Some(err) = decode_error(&bytes)? {
393            if let Some(handler) = self.options.error_hooks.get(&err.code()) {
394                let retry = (*handler)(&self.storage, key).await?;
395
396                if retry {
397                    return Ok(RequestResult::Retry);
398                }
399            }
400            Err(err.into())
401        } else {
402            Ok(RequestResult::Response(ApiResponse {
403                body: Some(bytes),
404                status,
405            }))
406        }
407    }
408
409    async fn execute_request(
410        &self,
411        selector: KeySelector<S::Key, S::Domain>,
412        request: ApiRequest,
413    ) -> Result<ApiResponse, S::Error> {
414        loop {
415            let key = self.storage.acquire_key(selector.clone()).await?;
416            match self.execute_with_key(&key, &request).await {
417                Ok(RequestResult::Response(resp)) => return Ok(resp),
418                Ok(RequestResult::Retry) => (),
419                Err(why) => return Err(why),
420            }
421        }
422    }
423
424    async fn execute_bulk_requests<D, T: IntoIterator<Item = (D, ApiRequest)>>(
425        &self,
426        selector: KeySelector<S::Key, S::Domain>,
427        requests: T,
428    ) -> impl Stream<Item = (D, Result<ApiResponse, S::Error>)> + use<'_, D, S, T> {
429        let requests: Vec<_> = requests.into_iter().collect();
430
431        let keys: Vec<_> = match self
432            .storage
433            .acquire_many_keys(selector.clone(), requests.len() as i64)
434            .await
435        {
436            Ok(keys) => keys.into_iter().map(Ok).collect(),
437            Err(why) => {
438                let why = Arc::new(why);
439                std::iter::repeat_n(why, requests.len())
440                    .map(|e| Err(S::Error::from(e)))
441                    .collect()
442            }
443        };
444
445        StreamExt::map(
446            futures::stream::iter(std::iter::zip(requests, keys)),
447            move |((discriminant, request), mut maybe_key)| {
448                let selector = selector.clone();
449                async move {
450                    loop {
451                        let key = match maybe_key {
452                            Ok(key) => key,
453                            Err(why) => return (discriminant, Err(why)),
454                        };
455                        match self.execute_with_key(&key, &request).await {
456                            Ok(RequestResult::Response(resp)) => return (discriminant, Ok(resp)),
457                            Ok(RequestResult::Retry) => (),
458                            Err(why) => return (discriminant, Err(why)),
459                        }
460                        maybe_key = self.storage.acquire_key(selector.clone()).await;
461                    }
462                }
463            },
464        )
465        .buffer_unordered(25)
466    }
467}
468
469pub struct KeyPool<S>
470where
471    S: KeyPoolStorage,
472{
473    inner: Arc<KeyPoolInner<S>>,
474}
475
476impl<S> Deref for KeyPool<S>
477where
478    S: KeyPoolStorage,
479{
480    type Target = KeyPoolInner<S>;
481    fn deref(&self) -> &Self::Target {
482        &self.inner
483    }
484}
485
486enum RequestResult {
487    Response(ApiResponse),
488    Retry,
489}
490
491impl<S> KeyPool<S>
492where
493    S: KeyPoolStorage + Send + Sync + 'static,
494{
495    pub fn torn_api<I>(&self, selector: I) -> KeyPoolExecutor<S>
496    where
497        I: IntoSelector<S::Key, S::Domain>,
498    {
499        KeyPoolExecutor::new(self, selector.into_selector())
500    }
501
502    pub fn throttled_torn_api<I>(
503        &self,
504        selector: I,
505        distance: Duration,
506    ) -> ThrottledKeyPoolExecutor<S>
507    where
508        I: IntoSelector<S::Key, S::Domain>,
509    {
510        ThrottledKeyPoolExecutor::new(self, selector.into_selector(), distance)
511    }
512}
513
514fn decode_error(buf: &[u8]) -> Result<Option<ApiError>, serde_json::Error> {
515    if buf.starts_with(br#"{"error":{"#) {
516        #[derive(Deserialize)]
517        struct ErrorBody<'a> {
518            code: u16,
519            error: &'a str,
520        }
521        #[derive(Deserialize)]
522        struct ErrorContainer<'a> {
523            #[serde(borrow)]
524            error: ErrorBody<'a>,
525        }
526
527        let error: ErrorContainer = serde_json::from_slice(buf)?;
528        Ok(Some(crate::ApiError::new(
529            error.error.code,
530            error.error.error,
531        )))
532    } else {
533        Ok(None)
534    }
535}
536
537pub struct KeyPoolExecutor<'p, S>
538where
539    S: KeyPoolStorage,
540{
541    pool: &'p KeyPoolInner<S>,
542    selector: KeySelector<S::Key, S::Domain>,
543}
544
545impl<'p, S> KeyPoolExecutor<'p, S>
546where
547    S: KeyPoolStorage,
548{
549    pub fn new(pool: &'p KeyPool<S>, selector: KeySelector<S::Key, S::Domain>) -> Self {
550        Self {
551            pool: &pool.inner,
552            selector,
553        }
554    }
555}
556
557impl<S> Executor for KeyPoolExecutor<'_, S>
558where
559    S: KeyPoolStorage + 'static,
560{
561    type Error = S::Error;
562
563    async fn execute<R>(self, request: R) -> (R::Discriminant, Result<ApiResponse, Self::Error>)
564    where
565        R: torn_api::request::IntoRequest,
566    {
567        let (d, request) = request.into_request();
568
569        (d, self.pool.execute_request(self.selector, request).await)
570    }
571}
572
573impl<S> BulkExecutor for KeyPoolExecutor<'_, S>
574where
575    S: KeyPoolStorage + 'static,
576{
577    type Error = S::Error;
578
579    fn execute<R>(
580        self,
581        requests: impl IntoIterator<Item = R>,
582    ) -> impl futures::Stream<Item = (R::Discriminant, Result<ApiResponse, Self::Error>)> + Unpin
583    where
584        R: torn_api::request::IntoRequest,
585    {
586        let requests: Vec<_> = requests.into_iter().map(|r| r.into_request()).collect();
587        self.pool
588            .execute_bulk_requests(self.selector.clone(), requests)
589            .into_stream()
590            .flatten()
591            .boxed()
592    }
593}
594
595pub struct ThrottledKeyPoolExecutor<'p, S>
596where
597    S: KeyPoolStorage,
598{
599    pool: &'p KeyPoolInner<S>,
600    selector: KeySelector<S::Key, S::Domain>,
601    distance: Duration,
602}
603
604impl<S> Clone for ThrottledKeyPoolExecutor<'_, S>
605where
606    S: KeyPoolStorage,
607{
608    fn clone(&self) -> Self {
609        Self {
610            pool: self.pool,
611            selector: self.selector.clone(),
612            distance: self.distance,
613        }
614    }
615}
616
617impl<S> ThrottledKeyPoolExecutor<'_, S>
618where
619    S: KeyPoolStorage,
620{
621    async fn execute_request(self, request: ApiRequest) -> Result<ApiResponse, S::Error> {
622        self.pool.execute_request(self.selector, request).await
623    }
624}
625
626impl<'p, S> ThrottledKeyPoolExecutor<'p, S>
627where
628    S: KeyPoolStorage,
629{
630    pub fn new(
631        pool: &'p KeyPool<S>,
632        selector: KeySelector<S::Key, S::Domain>,
633        distance: Duration,
634    ) -> Self {
635        Self {
636            pool: &pool.inner,
637            selector,
638            distance,
639        }
640    }
641}
642
643impl<S> BulkExecutor for ThrottledKeyPoolExecutor<'_, S>
644where
645    S: KeyPoolStorage + 'static,
646{
647    type Error = S::Error;
648
649    fn execute<R>(
650        self,
651        requests: impl IntoIterator<Item = R>,
652    ) -> impl futures::Stream<Item = (R::Discriminant, Result<ApiResponse, Self::Error>)> + Unpin
653    where
654        R: torn_api::request::IntoRequest,
655    {
656        let requests: Vec<_> = requests.into_iter().map(|r| r.into_request()).collect();
657        StreamExt::map(
658            futures::stream::iter(requests).throttle(self.distance),
659            move |(d, request)| {
660                let this = self.clone();
661                async move {
662                    let result = this.execute_request(request).await;
663                    (d, result)
664                }
665            },
666        )
667        .buffer_unordered(25)
668        .boxed()
669    }
670}
671
#[cfg(test)]
#[cfg(feature = "postgres")]
mod test {
    use torn_api::executor::{BulkExecutorExt, ExecutorExt};

    use crate::postgres;

    use super::*;

    /// A single request through the pool succeeds.
    #[sqlx::test]
    fn single_request(pool: sqlx::PgPool) {
        let (storage, _) = postgres::test::setup(pool).await;

        let pool = PoolBuilder::new(storage)
            .use_default_hooks()
            .comment("test_runner")
            .build();

        pool.torn_api(postgres::test::Domain::All)
            .faction()
            .basic(|b| b)
            .await
            .unwrap();
    }

    /// Every request of a bulk batch yields exactly one successful result.
    #[sqlx::test]
    fn bulk(pool: sqlx::PgPool) {
        let (storage, _) = postgres::test::setup(pool).await;

        let pool = PoolBuilder::new(storage)
            .use_default_hooks()
            .comment("test_runner")
            .build();

        let responses = pool
            .torn_api(postgres::test::Domain::All)
            .faction_bulk()
            .basic_for_id(vec![19.into(), 89.into()], |b| b);
        let responses: Vec<_> = StreamExt::collect(responses).await;

        // Results arrive in completion order, so only count and success are checked.
        assert_eq!(responses.len(), 2);
        for (_id, basic) in responses {
            basic.unwrap();
        }
    }

    /// Same as `bulk`, but through the throttled executor.
    #[sqlx::test]
    fn bulk_throttled(pool: sqlx::PgPool) {
        let (storage, _) = postgres::test::setup(pool).await;

        let pool = PoolBuilder::new(storage)
            .use_default_hooks()
            .comment("test_runner")
            .build();

        let responses = pool
            .throttled_torn_api(postgres::test::Domain::All, Duration::from_millis(500))
            .faction_bulk()
            .basic_for_id(vec![19.into(), 89.into()], |b| b);
        let responses: Vec<_> = StreamExt::collect(responses).await;

        assert_eq!(responses.len(), 2);
        for (_id, basic) in responses {
            basic.unwrap();
        }
    }
}