1#![warn(clippy::all, clippy::perf, clippy::style, clippy::suspicious)]
2
3#[cfg(feature = "postgres")]
4pub mod postgres;
5
6use std::{collections::HashMap, future::Future, sync::Arc, time::Duration};
7
8use futures::{future::BoxFuture, FutureExt};
9use reqwest::header::{HeaderMap, HeaderValue, AUTHORIZATION};
10use serde::Deserialize;
11use torn_api::{
12 executor::Executor,
13 request::{ApiRequest, ApiResponse},
14 ApiError,
15};
16
/// Bound required of the identifier type behind [`ApiKey::IdType`].
///
/// Any cloneable, hashable, thread-safe type with total equality qualifies automatically
/// through the blanket implementation below; `Eq` already implies `PartialEq`.
pub trait ApiKeyId: Clone + Eq + std::hash::Hash + Send + Sync {}

impl<T: Clone + Eq + std::hash::Hash + Send + Sync> ApiKeyId for T {}
20
21pub trait ApiKey: Send + Sync + Clone + 'static {
22 type IdType: ApiKeyId;
23
24 fn value(&self) -> &str;
25
26 fn id(&self) -> Self::IdType;
27
28 fn selector<D>(&self) -> KeySelector<Self, D>
29 where
30 D: KeyDomain,
31 {
32 KeySelector::Id(self.id())
33 }
34}
35
/// A category that keys are filed under and can be selected by.
pub trait KeyDomain: Clone + std::fmt::Debug + Send + Sync + 'static {
    /// Domain to substitute for `self` when `KeySelector::fallback` derives a fallback
    /// selector. By default a domain has no fallback.
    fn fallback(&self) -> Option<Self> {
        Option::None
    }
}
41
42#[derive(Debug, Clone)]
43pub enum KeySelector<K, D>
44where
45 K: ApiKey,
46 D: KeyDomain,
47{
48 Key(String),
49 Id(K::IdType),
50 UserId(i32),
51 Has(Vec<D>),
52 OneOf(Vec<D>),
53}
54
55impl<K, D> KeySelector<K, D>
56where
57 K: ApiKey,
58 D: KeyDomain,
59{
60 pub(crate) fn fallback(&self) -> Option<Self> {
61 match self {
62 Self::Key(_) | Self::UserId(_) | Self::Id(_) => None,
63 Self::Has(domains) => {
64 let fallbacks: Vec<_> = domains.iter().filter_map(|d| d.fallback()).collect();
65 if fallbacks.is_empty() {
66 None
67 } else {
68 Some(Self::Has(fallbacks))
69 }
70 }
71 Self::OneOf(domains) => {
72 let fallbacks: Vec<_> = domains.iter().filter_map(|d| d.fallback()).collect();
73 if fallbacks.is_empty() {
74 None
75 } else {
76 Some(Self::OneOf(fallbacks))
77 }
78 }
79 }
80 }
81}
82
83pub trait IntoSelector<K, D>: Send
84where
85 K: ApiKey,
86 D: KeyDomain,
87{
88 fn into_selector(self) -> KeySelector<K, D>;
89}
90
91impl<K, D> IntoSelector<K, D> for D
92where
93 K: ApiKey,
94 D: KeyDomain,
95{
96 fn into_selector(self) -> KeySelector<K, D> {
97 KeySelector::Has(vec![self])
98 }
99}
100
101impl<K, D> IntoSelector<K, D> for KeySelector<K, D>
102where
103 K: ApiKey,
104 D: KeyDomain,
105{
106 fn into_selector(self) -> KeySelector<K, D> {
107 self
108 }
109}
110
111pub trait KeyPoolStorage: Send + Sync {
112 type Key: ApiKey;
113 type Domain: KeyDomain;
114 type Error: From<reqwest::Error> + From<serde_json::Error> + From<torn_api::ApiError> + Send;
115
116 fn acquire_key<S>(
117 &self,
118 selector: S,
119 ) -> impl Future<Output = Result<Self::Key, Self::Error>> + Send
120 where
121 S: IntoSelector<Self::Key, Self::Domain>;
122
123 fn acquire_many_keys<S>(
124 &self,
125 selector: S,
126 number: i64,
127 ) -> impl Future<Output = Result<Vec<Self::Key>, Self::Error>> + Send
128 where
129 S: IntoSelector<Self::Key, Self::Domain>;
130
131 fn store_key(
132 &self,
133 user_id: i32,
134 key: String,
135 domains: Vec<Self::Domain>,
136 ) -> impl Future<Output = Result<Self::Key, Self::Error>> + Send;
137
138 fn read_key<S>(
139 &self,
140 selector: S,
141 ) -> impl Future<Output = Result<Option<Self::Key>, Self::Error>> + Send
142 where
143 S: IntoSelector<Self::Key, Self::Domain>;
144
145 fn read_keys<S>(
146 &self,
147 selector: S,
148 ) -> impl Future<Output = Result<Vec<Self::Key>, Self::Error>> + Send
149 where
150 S: IntoSelector<Self::Key, Self::Domain>;
151
152 fn remove_key<S>(
153 &self,
154 selector: S,
155 ) -> impl Future<Output = Result<Self::Key, Self::Error>> + Send
156 where
157 S: IntoSelector<Self::Key, Self::Domain>;
158
159 fn add_domain_to_key<S>(
160 &self,
161 selector: S,
162 domain: Self::Domain,
163 ) -> impl Future<Output = Result<Self::Key, Self::Error>> + Send
164 where
165 S: IntoSelector<Self::Key, Self::Domain>;
166
167 fn remove_domain_from_key<S>(
168 &self,
169 selector: S,
170 domain: Self::Domain,
171 ) -> impl Future<Output = Result<Self::Key, Self::Error>> + Send
172 where
173 S: IntoSelector<Self::Key, Self::Domain>;
174
175 fn set_domains_for_key<S>(
176 &self,
177 selector: S,
178 domains: Vec<Self::Domain>,
179 ) -> impl Future<Output = Result<Self::Key, Self::Error>> + Send
180 where
181 S: IntoSelector<Self::Key, Self::Domain>;
182
183 fn timeout_key<S>(
184 &self,
185 selector: S,
186 duration: Duration,
187 ) -> impl Future<Output = Result<(), Self::Error>> + Send
188 where
189 S: IntoSelector<Self::Key, Self::Domain>;
190}
191
192#[derive(Default)]
193pub struct PoolOptions<S>
194where
195 S: KeyPoolStorage,
196{
197 comment: Option<String>,
198 #[allow(clippy::type_complexity)]
199 error_hooks: HashMap<
200 u16,
201 Box<
202 dyn for<'a> Fn(&'a S, &'a S::Key) -> BoxFuture<'a, Result<bool, S::Error>>
203 + Send
204 + Sync,
205 >,
206 >,
207}
208
209pub struct KeyPoolExecutor<'p, S>
210where
211 S: KeyPoolStorage,
212{
213 pool: &'p KeyPool<S>,
214 selector: KeySelector<S::Key, S::Domain>,
215}
216
217impl<'p, S> KeyPoolExecutor<'p, S>
218where
219 S: KeyPoolStorage,
220{
221 pub fn new(pool: &'p KeyPool<S>, selector: KeySelector<S::Key, S::Domain>) -> Self {
222 Self { pool, selector }
223 }
224
225 async fn execute_request<D>(&self, request: ApiRequest<D>) -> Result<ApiResponse<D>, S::Error>
226 where
227 D: Send,
228 {
229 let key = self.pool.storage.acquire_key(self.selector.clone()).await?;
230
231 let mut headers = HeaderMap::with_capacity(1);
232 headers.insert(
233 AUTHORIZATION,
234 HeaderValue::from_str(&format!("ApiKey {}", key.value())).unwrap(),
235 );
236
237 let resp = self
238 .pool
239 .client
240 .get(request.url())
241 .headers(headers)
242 .send()
243 .await?;
244
245 let status = resp.status();
246
247 let bytes = resp.bytes().await?;
248
249 if let Some(err) = decode_error(&bytes)? {
250 if let Some(handler) = self.pool.options.error_hooks.get(&err.code()) {
251 let retry = (*handler)(&self.pool.storage, &key).await?;
252
253 if retry {
254 return Box::pin(self.execute_request(request)).await;
255 }
256 }
257 Err(err.into())
258 } else {
259 Ok(ApiResponse {
260 discriminant: request.disriminant,
261 body: Some(bytes),
262 status,
263 })
264 }
265 }
266}
267
268pub struct PoolBuilder<S>
269where
270 S: KeyPoolStorage,
271{
272 client: reqwest::Client,
273 storage: S,
274 options: crate::PoolOptions<S>,
275}
276
277impl<S> PoolBuilder<S>
278where
279 S: KeyPoolStorage,
280{
281 pub fn new(storage: S) -> Self {
282 Self {
283 client: reqwest::Client::builder()
284 .brotli(true)
285 .http2_keep_alive_timeout(Duration::from_secs(60))
286 .http2_keep_alive_interval(Duration::from_secs(5))
287 .https_only(true)
288 .build()
289 .unwrap(),
290 storage,
291 options: PoolOptions {
292 comment: None,
293 error_hooks: Default::default(),
294 },
295 }
296 }
297
298 pub fn comment(mut self, c: impl ToString) -> Self {
299 self.options.comment = Some(c.to_string());
300 self
301 }
302
303 pub fn error_hook<F>(mut self, code: u16, handler: F) -> Self
304 where
305 F: for<'a> Fn(&'a S, &'a S::Key) -> BoxFuture<'a, Result<bool, S::Error>>
306 + Send
307 + Sync
308 + 'static,
309 {
310 self.options.error_hooks.insert(code, Box::new(handler));
311
312 self
313 }
314
315 pub fn use_default_hooks(self) -> Self {
316 self.error_hook(2, |storage, key| {
317 async move {
318 storage.remove_key(KeySelector::Id(key.id())).await?;
319 Ok(true)
320 }
321 .boxed()
322 })
323 .error_hook(5, |storage, key| {
324 async move {
325 storage
326 .timeout_key(KeySelector::Id(key.id()), Duration::from_secs(60))
327 .await?;
328 Ok(true)
329 }
330 .boxed()
331 })
332 .error_hook(10, |storage, key| {
333 async move {
334 storage.remove_key(KeySelector::Id(key.id())).await?;
335 Ok(true)
336 }
337 .boxed()
338 })
339 .error_hook(13, |storage, key| {
340 async move {
341 storage
342 .timeout_key(KeySelector::Id(key.id()), Duration::from_secs(24 * 3_600))
343 .await?;
344 Ok(true)
345 }
346 .boxed()
347 })
348 .error_hook(18, |storage, key| {
349 async move {
350 storage
351 .timeout_key(KeySelector::Id(key.id()), Duration::from_secs(24 * 3_600))
352 .await?;
353 Ok(true)
354 }
355 .boxed()
356 })
357 }
358
359 pub fn build(self) -> KeyPool<S> {
360 KeyPool {
361 client: self.client,
362 storage: self.storage,
363 options: Arc::new(self.options),
364 }
365 }
366}
367
368pub struct KeyPool<S>
369where
370 S: KeyPoolStorage,
371{
372 pub client: reqwest::Client,
373 pub storage: S,
374 pub options: Arc<PoolOptions<S>>,
375}
376
377impl<S> KeyPool<S>
378where
379 S: KeyPoolStorage + Send + Sync + 'static,
380{
381 pub fn torn_api<I>(&self, selector: I) -> KeyPoolExecutor<S>
382 where
383 I: IntoSelector<S::Key, S::Domain>,
384 {
385 KeyPoolExecutor::new(self, selector.into_selector())
386 }
387}
388
389fn decode_error(buf: &[u8]) -> Result<Option<ApiError>, serde_json::Error> {
390 if buf.starts_with(br#"{"error":{"#) {
391 #[derive(Deserialize)]
392 struct ErrorBody<'a> {
393 code: u16,
394 error: &'a str,
395 }
396 #[derive(Deserialize)]
397 struct ErrorContainer<'a> {
398 #[serde(borrow)]
399 error: ErrorBody<'a>,
400 }
401
402 let error: ErrorContainer = serde_json::from_slice(buf)?;
403 Ok(Some(crate::ApiError::new(
404 error.error.code,
405 error.error.error,
406 )))
407 } else {
408 Ok(None)
409 }
410}
411
412impl<S> Executor for KeyPoolExecutor<'_, S>
413where
414 S: KeyPoolStorage,
415{
416 type Error = S::Error;
417
418 async fn execute<R>(
419 &self,
420 request: R,
421 ) -> Result<torn_api::request::ApiResponse<R::Discriminant>, Self::Error>
422 where
423 R: torn_api::request::IntoRequest,
424 {
425 let request = request.into_request();
426
427 self.execute_request(request).await
428 }
429}
430
// Gated on the `postgres` feature too: this module uses `crate::postgres` and `sqlx`, which
// only exist when that feature is enabled, so a plain `cfg(test)` breaks `cargo test` without it.
#[cfg(all(test, feature = "postgres"))]
mod test {
    use torn_api::executor::ExecutorExt;

    use crate::postgres;

    use super::*;

    /// Smoke test: a pool built with the default hooks and a comment executes a
    /// `faction().basic()` request through [`KeyPoolExecutor`], using the storage returned by
    /// `postgres::test::setup`. The request goes through the pool's real HTTP client.
    #[sqlx::test]
    async fn default_hooks_execute_faction_basic(pool: sqlx::PgPool) {
        let (storage, _) = postgres::test::setup(pool).await;

        let pool = PoolBuilder::new(storage)
            .use_default_hooks()
            .comment("test_runner")
            .build();

        pool.torn_api(postgres::test::Domain::All)
            .faction()
            .basic(|b| b)
            .await
            .unwrap();
    }
}