torn_key_pool/
lib.rs

1#![warn(clippy::all, clippy::perf, clippy::style, clippy::suspicious)]
2
3#[cfg(feature = "postgres")]
4pub mod postgres;
5
6use std::{collections::HashMap, future::Future, sync::Arc, time::Duration};
7
8use futures::{future::BoxFuture, FutureExt};
9use reqwest::header::{HeaderMap, HeaderValue, AUTHORIZATION};
10use serde::Deserialize;
11use torn_api::{
12    executor::Executor,
13    request::{ApiRequest, ApiResponse},
14    ApiError,
15};
16
/// Marker trait for values that can identify a stored API key.
///
/// It is a pure bound alias: anything that is cloneable, comparable for
/// equality, hashable and shareable across threads qualifies.
pub trait ApiKeyId: Clone + Eq + std::hash::Hash + Send + Sync {}

/// Blanket implementation: every type meeting the bounds is an [`ApiKeyId`].
impl<T: Clone + Eq + std::hash::Hash + Send + Sync> ApiKeyId for T {}
20
21pub trait ApiKey: Send + Sync + Clone + 'static {
22    type IdType: ApiKeyId;
23
24    fn value(&self) -> &str;
25
26    fn id(&self) -> Self::IdType;
27
28    fn selector<D>(&self) -> KeySelector<Self, D>
29    where
30        D: KeyDomain,
31    {
32        KeySelector::Id(self.id())
33    }
34}
35
/// A domain (grouping) that keys can be associated with.
pub trait KeyDomain: 'static + Send + Sync + Clone + std::fmt::Debug {
    /// The domain to try instead of this one, if any.
    ///
    /// The default implementation reports that there is no fallback.
    fn fallback(&self) -> Option<Self> {
        Option::None
    }
}
41
42#[derive(Debug, Clone)]
43pub enum KeySelector<K, D>
44where
45    K: ApiKey,
46    D: KeyDomain,
47{
48    Key(String),
49    Id(K::IdType),
50    UserId(i32),
51    Has(Vec<D>),
52    OneOf(Vec<D>),
53}
54
55impl<K, D> KeySelector<K, D>
56where
57    K: ApiKey,
58    D: KeyDomain,
59{
60    pub(crate) fn fallback(&self) -> Option<Self> {
61        match self {
62            Self::Key(_) | Self::UserId(_) | Self::Id(_) => None,
63            Self::Has(domains) => {
64                let fallbacks: Vec<_> = domains.iter().filter_map(|d| d.fallback()).collect();
65                if fallbacks.is_empty() {
66                    None
67                } else {
68                    Some(Self::Has(fallbacks))
69                }
70            }
71            Self::OneOf(domains) => {
72                let fallbacks: Vec<_> = domains.iter().filter_map(|d| d.fallback()).collect();
73                if fallbacks.is_empty() {
74                    None
75                } else {
76                    Some(Self::OneOf(fallbacks))
77                }
78            }
79        }
80    }
81}
82
83pub trait IntoSelector<K, D>: Send
84where
85    K: ApiKey,
86    D: KeyDomain,
87{
88    fn into_selector(self) -> KeySelector<K, D>;
89}
90
91impl<K, D> IntoSelector<K, D> for D
92where
93    K: ApiKey,
94    D: KeyDomain,
95{
96    fn into_selector(self) -> KeySelector<K, D> {
97        KeySelector::Has(vec![self])
98    }
99}
100
101impl<K, D> IntoSelector<K, D> for KeySelector<K, D>
102where
103    K: ApiKey,
104    D: KeyDomain,
105{
106    fn into_selector(self) -> KeySelector<K, D> {
107        self
108    }
109}
110
111pub trait KeyPoolStorage: Send + Sync {
112    type Key: ApiKey;
113    type Domain: KeyDomain;
114    type Error: From<reqwest::Error> + From<serde_json::Error> + From<torn_api::ApiError> + Send;
115
116    fn acquire_key<S>(
117        &self,
118        selector: S,
119    ) -> impl Future<Output = Result<Self::Key, Self::Error>> + Send
120    where
121        S: IntoSelector<Self::Key, Self::Domain>;
122
123    fn acquire_many_keys<S>(
124        &self,
125        selector: S,
126        number: i64,
127    ) -> impl Future<Output = Result<Vec<Self::Key>, Self::Error>> + Send
128    where
129        S: IntoSelector<Self::Key, Self::Domain>;
130
131    fn store_key(
132        &self,
133        user_id: i32,
134        key: String,
135        domains: Vec<Self::Domain>,
136    ) -> impl Future<Output = Result<Self::Key, Self::Error>> + Send;
137
138    fn read_key<S>(
139        &self,
140        selector: S,
141    ) -> impl Future<Output = Result<Option<Self::Key>, Self::Error>> + Send
142    where
143        S: IntoSelector<Self::Key, Self::Domain>;
144
145    fn read_keys<S>(
146        &self,
147        selector: S,
148    ) -> impl Future<Output = Result<Vec<Self::Key>, Self::Error>> + Send
149    where
150        S: IntoSelector<Self::Key, Self::Domain>;
151
152    fn remove_key<S>(
153        &self,
154        selector: S,
155    ) -> impl Future<Output = Result<Self::Key, Self::Error>> + Send
156    where
157        S: IntoSelector<Self::Key, Self::Domain>;
158
159    fn add_domain_to_key<S>(
160        &self,
161        selector: S,
162        domain: Self::Domain,
163    ) -> impl Future<Output = Result<Self::Key, Self::Error>> + Send
164    where
165        S: IntoSelector<Self::Key, Self::Domain>;
166
167    fn remove_domain_from_key<S>(
168        &self,
169        selector: S,
170        domain: Self::Domain,
171    ) -> impl Future<Output = Result<Self::Key, Self::Error>> + Send
172    where
173        S: IntoSelector<Self::Key, Self::Domain>;
174
175    fn set_domains_for_key<S>(
176        &self,
177        selector: S,
178        domains: Vec<Self::Domain>,
179    ) -> impl Future<Output = Result<Self::Key, Self::Error>> + Send
180    where
181        S: IntoSelector<Self::Key, Self::Domain>;
182
183    fn timeout_key<S>(
184        &self,
185        selector: S,
186        duration: Duration,
187    ) -> impl Future<Output = Result<(), Self::Error>> + Send
188    where
189        S: IntoSelector<Self::Key, Self::Domain>;
190}
191
192#[derive(Default)]
193pub struct PoolOptions<S>
194where
195    S: KeyPoolStorage,
196{
197    comment: Option<String>,
198    #[allow(clippy::type_complexity)]
199    error_hooks: HashMap<
200        u16,
201        Box<
202            dyn for<'a> Fn(&'a S, &'a S::Key) -> BoxFuture<'a, Result<bool, S::Error>>
203                + Send
204                + Sync,
205        >,
206    >,
207}
208
209pub struct KeyPoolExecutor<'p, S>
210where
211    S: KeyPoolStorage,
212{
213    pool: &'p KeyPool<S>,
214    selector: KeySelector<S::Key, S::Domain>,
215}
216
217impl<'p, S> KeyPoolExecutor<'p, S>
218where
219    S: KeyPoolStorage,
220{
221    pub fn new(pool: &'p KeyPool<S>, selector: KeySelector<S::Key, S::Domain>) -> Self {
222        Self { pool, selector }
223    }
224
225    async fn execute_request<D>(&self, request: ApiRequest<D>) -> Result<ApiResponse<D>, S::Error>
226    where
227        D: Send,
228    {
229        let key = self.pool.storage.acquire_key(self.selector.clone()).await?;
230
231        let mut headers = HeaderMap::with_capacity(1);
232        headers.insert(
233            AUTHORIZATION,
234            HeaderValue::from_str(&format!("ApiKey {}", key.value())).unwrap(),
235        );
236
237        let resp = self
238            .pool
239            .client
240            .get(request.url())
241            .headers(headers)
242            .send()
243            .await?;
244
245        let status = resp.status();
246
247        let bytes = resp.bytes().await?;
248
249        if let Some(err) = decode_error(&bytes)? {
250            if let Some(handler) = self.pool.options.error_hooks.get(&err.code()) {
251                let retry = (*handler)(&self.pool.storage, &key).await?;
252
253                if retry {
254                    return Box::pin(self.execute_request(request)).await;
255                }
256            }
257            Err(err.into())
258        } else {
259            Ok(ApiResponse {
260                discriminant: request.disriminant,
261                body: Some(bytes),
262                status,
263            })
264        }
265    }
266}
267
268pub struct PoolBuilder<S>
269where
270    S: KeyPoolStorage,
271{
272    client: reqwest::Client,
273    storage: S,
274    options: crate::PoolOptions<S>,
275}
276
277impl<S> PoolBuilder<S>
278where
279    S: KeyPoolStorage,
280{
281    pub fn new(storage: S) -> Self {
282        Self {
283            client: reqwest::Client::builder()
284                .brotli(true)
285                .http2_keep_alive_timeout(Duration::from_secs(60))
286                .http2_keep_alive_interval(Duration::from_secs(5))
287                .https_only(true)
288                .build()
289                .unwrap(),
290            storage,
291            options: PoolOptions {
292                comment: None,
293                error_hooks: Default::default(),
294            },
295        }
296    }
297
298    pub fn comment(mut self, c: impl ToString) -> Self {
299        self.options.comment = Some(c.to_string());
300        self
301    }
302
303    pub fn error_hook<F>(mut self, code: u16, handler: F) -> Self
304    where
305        F: for<'a> Fn(&'a S, &'a S::Key) -> BoxFuture<'a, Result<bool, S::Error>>
306            + Send
307            + Sync
308            + 'static,
309    {
310        self.options.error_hooks.insert(code, Box::new(handler));
311
312        self
313    }
314
315    pub fn use_default_hooks(self) -> Self {
316        self.error_hook(2, |storage, key| {
317            async move {
318                storage.remove_key(KeySelector::Id(key.id())).await?;
319                Ok(true)
320            }
321            .boxed()
322        })
323        .error_hook(5, |storage, key| {
324            async move {
325                storage
326                    .timeout_key(KeySelector::Id(key.id()), Duration::from_secs(60))
327                    .await?;
328                Ok(true)
329            }
330            .boxed()
331        })
332        .error_hook(10, |storage, key| {
333            async move {
334                storage.remove_key(KeySelector::Id(key.id())).await?;
335                Ok(true)
336            }
337            .boxed()
338        })
339        .error_hook(13, |storage, key| {
340            async move {
341                storage
342                    .timeout_key(KeySelector::Id(key.id()), Duration::from_secs(24 * 3_600))
343                    .await?;
344                Ok(true)
345            }
346            .boxed()
347        })
348        .error_hook(18, |storage, key| {
349            async move {
350                storage
351                    .timeout_key(KeySelector::Id(key.id()), Duration::from_secs(24 * 3_600))
352                    .await?;
353                Ok(true)
354            }
355            .boxed()
356        })
357    }
358
359    pub fn build(self) -> KeyPool<S> {
360        KeyPool {
361            client: self.client,
362            storage: self.storage,
363            options: Arc::new(self.options),
364        }
365    }
366}
367
368pub struct KeyPool<S>
369where
370    S: KeyPoolStorage,
371{
372    pub client: reqwest::Client,
373    pub storage: S,
374    pub options: Arc<PoolOptions<S>>,
375}
376
377impl<S> KeyPool<S>
378where
379    S: KeyPoolStorage + Send + Sync + 'static,
380{
381    pub fn torn_api<I>(&self, selector: I) -> KeyPoolExecutor<S>
382    where
383        I: IntoSelector<S::Key, S::Domain>,
384    {
385        KeyPoolExecutor::new(self, selector.into_selector())
386    }
387}
388
389fn decode_error(buf: &[u8]) -> Result<Option<ApiError>, serde_json::Error> {
390    if buf.starts_with(br#"{"error":{"#) {
391        #[derive(Deserialize)]
392        struct ErrorBody<'a> {
393            code: u16,
394            error: &'a str,
395        }
396        #[derive(Deserialize)]
397        struct ErrorContainer<'a> {
398            #[serde(borrow)]
399            error: ErrorBody<'a>,
400        }
401
402        let error: ErrorContainer = serde_json::from_slice(buf)?;
403        Ok(Some(crate::ApiError::new(
404            error.error.code,
405            error.error.error,
406        )))
407    } else {
408        Ok(None)
409    }
410}
411
412impl<S> Executor for KeyPoolExecutor<'_, S>
413where
414    S: KeyPoolStorage,
415{
416    type Error = S::Error;
417
418    async fn execute<R>(
419        &self,
420        request: R,
421    ) -> Result<torn_api::request::ApiResponse<R::Discriminant>, Self::Error>
422    where
423        R: torn_api::request::IntoRequest,
424    {
425        let request = request.into_request();
426
427        self.execute_request(request).await
428    }
429}
430
// The fixtures come from `crate::postgres`, which only exists when the
// `postgres` feature is enabled; without this gate a plain `cargo test`
// would fail to compile.
#[cfg(all(test, feature = "postgres"))]
mod test {
    use torn_api::executor::ExecutorExt;

    use crate::postgres;

    use super::*;

    /// Builds a pool on top of the postgres test fixture and runs one
    /// faction request through it, expecting the request to succeed.
    #[sqlx::test]
    async fn name(pool: sqlx::PgPool) {
        let (storage, _) = postgres::test::setup(pool).await;

        let pool = PoolBuilder::new(storage)
            .use_default_hooks()
            .comment("test_runner")
            .build();

        pool.torn_api(postgres::test::Domain::All)
            .faction()
            .basic(|b| b)
            .await
            .unwrap();
    }
}