1use futures::future::BoxFuture;
2use indoc::indoc;
3use sqlx::{FromRow, PgPool, Postgres, QueryBuilder};
4use thiserror::Error;
5
6use crate::{ApiKey, IntoSelector, KeyDomain, KeyPoolStorage, KeySelector};
7
8pub trait PgKeyDomain:
9 KeyDomain + serde::Serialize + serde::de::DeserializeOwned + Eq + Unpin
10{
11}
12
13impl<T> PgKeyDomain for T where
14 T: KeyDomain + serde::Serialize + serde::de::DeserializeOwned + Eq + Unpin
15{
16}
17
18#[derive(Debug, Error)]
19pub enum PgKeyPoolError<D>
20where
21 D: PgKeyDomain,
22{
23 #[error("Databank: {0}")]
24 Pg(#[from] sqlx::Error),
25
26 #[error("Network: {0}")]
27 Network(#[from] reqwest::Error),
28
29 #[error("Parsing: {0}")]
30 Parsing(#[from] serde_json::Error),
31
32 #[error("Api: {0}")]
33 Api(#[from] torn_api::ApiError),
34
35 #[error("No key avalaible for domain {0:?}")]
36 Unavailable(KeySelector<PgKey<D>, D>),
37
38 #[error("Key not found: '{0:?}'")]
39 KeyNotFound(KeySelector<PgKey<D>, D>),
40}
41
42#[derive(Debug, Clone, FromRow)]
43pub struct PgKey<D>
44where
45 D: PgKeyDomain,
46{
47 pub id: i32,
48 pub user_id: i32,
49 pub key: String,
50 pub uses: i16,
51 pub domains: sqlx::types::Json<Vec<D>>,
52}
53
54#[inline(always)]
55fn build_predicate<'b, D>(
56 builder: &mut QueryBuilder<'b, Postgres>,
57 selector: &'b KeySelector<PgKey<D>, D>,
58) where
59 D: PgKeyDomain,
60{
61 match selector {
62 KeySelector::Id(id) => builder.push("id=").push_bind(id),
63 KeySelector::UserId(user_id) => builder.push("user_id=").push_bind(user_id),
64 KeySelector::Key(key) => builder.push("key=").push_bind(key),
65 KeySelector::Has(domains) => builder
66 .push("domains @> ")
67 .push_bind(sqlx::types::Json(domains)),
68 KeySelector::OneOf(domains) => {
69 if domains.is_empty() {
70 builder.push("false");
71 return;
72 }
73
74 for (idx, domain) in domains.iter().enumerate() {
75 if idx == 0 {
76 builder.push("(");
77 } else {
78 builder.push(" or ");
79 }
80 builder
81 .push("domains @> ")
82 .push_bind(sqlx::types::Json(vec![domain]));
83 }
84 builder.push(")")
85 }
86 };
87}
88
89#[derive(Debug, Clone, FromRow)]
90pub struct PgKeyPoolStorage<D>
91where
92 D: serde::Serialize + serde::de::DeserializeOwned + Send + Sync + 'static,
93{
94 pool: PgPool,
95 limit: i16,
96 _phantom: std::marker::PhantomData<D>,
97}
98
99impl<D> ApiKey for PgKey<D>
100where
101 D: PgKeyDomain,
102{
103 type IdType = i32;
104
105 #[inline(always)]
106 fn value(&self) -> &str {
107 &self.key
108 }
109
110 #[inline(always)]
111 fn id(&self) -> Self::IdType {
112 self.id
113 }
114}
115
116impl<D> PgKeyPoolStorage<D>
117where
118 D: PgKeyDomain,
119{
120 pub fn new(pool: PgPool, limit: i16) -> Self {
121 Self {
122 pool,
123 limit,
124 _phantom: Default::default(),
125 }
126 }
127
128 pub async fn initialise(&self) -> Result<(), PgKeyPoolError<D>> {
129 sqlx::query(indoc! {r#"
130 CREATE TABLE IF NOT EXISTS api_keys (
131 id serial primary key,
132 user_id int4 not null,
133 key char(16) not null,
134 uses int2 not null default 0,
135 domains jsonb not null default '{}'::jsonb,
136 last_used timestamptz not null default now(),
137 flag int2,
138 cooldown timestamptz,
139 constraint "uq:api_keys.key" UNIQUE(key)
140 )"#
141 })
142 .execute(&self.pool)
143 .await?;
144
145 sqlx::query(indoc! {r#"
146 CREATE INDEX IF NOT EXISTS "idx:api_keys.domains" ON api_keys USING GIN(domains jsonb_path_ops)
147 "#})
148 .execute(&self.pool)
149 .await?;
150
151 sqlx::query(indoc! {r#"
152 CREATE INDEX IF NOT EXISTS "idx:api_keys.user_id" ON api_keys USING BTREE(user_id)
153 "#})
154 .execute(&self.pool)
155 .await?;
156
157 sqlx::query(indoc! {r#"
158 create or replace function __unique_jsonb_array(jsonb) returns jsonb
159 AS $$
160 select jsonb_agg(d::jsonb) from (
161 select distinct jsonb_array_elements_text($1) as d
162 ) t
163 $$ language sql;
164 "#})
165 .execute(&self.pool)
166 .await?;
167
168 sqlx::query(indoc! {r#"
169 create or replace function __filter_jsonb_array(jsonb, jsonb) returns jsonb
170 AS $$
171 select jsonb_agg(d::jsonb) from (
172 select distinct jsonb_array_elements_text($1) as d
173 ) t where d<>$2::text
174 $$ language sql;
175 "#})
176 .execute(&self.pool)
177 .await?;
178
179 Ok(())
180 }
181}
182
/// Back-off used before retrying a transaction that failed with a retryable error.
///
/// With the `tokio-runtime` feature this sleeps for a random 1–49 ms, so that
/// conflicting transactions are unlikely to collide again.
#[cfg(feature = "tokio-runtime")]
async fn random_sleep() {
    use rand::{rng, Rng};
    let dur = tokio::time::Duration::from_millis(rng().random_range(1..50));
    tokio::time::sleep(dur).await;
}

/// Runtime-agnostic fallback for builds without the `tokio-runtime` feature.
///
/// The retry loops call `random_sleep` unconditionally, so without this fallback the
/// crate would not compile without tokio. No timer is available here, so the best
/// option is to yield to the executor exactly once before the retry.
#[cfg(not(feature = "tokio-runtime"))]
async fn random_sleep() {
    /// Future that is pending exactly once and immediately re-schedules itself.
    struct YieldNow {
        yielded: bool,
    }

    impl std::future::Future for YieldNow {
        type Output = ();

        fn poll(
            mut self: std::pin::Pin<&mut Self>,
            cx: &mut std::task::Context<'_>,
        ) -> std::task::Poll<()> {
            if self.yielded {
                return std::task::Poll::Ready(());
            }
            self.yielded = true;
            cx.waker().wake_by_ref();
            std::task::Poll::Pending
        }
    }

    YieldNow { yielded: false }.await
}
189
190impl<D> KeyPoolStorage for PgKeyPoolStorage<D>
191where
192 D: PgKeyDomain,
193{
194 type Key = PgKey<D>;
195 type Domain = D;
196
197 type Error = PgKeyPoolError<D>;
198
199 async fn acquire_key<S>(&self, selector: S) -> Result<Self::Key, Self::Error>
200 where
201 S: IntoSelector<Self::Key, Self::Domain>,
202 {
203 let selector = selector.into_selector();
204 loop {
205 let attempt = async {
206 let mut tx = self.pool.begin().await?;
207
208 sqlx::query("set transaction isolation level repeatable read")
209 .execute(&mut *tx)
210 .await?;
211
212 let mut qb = QueryBuilder::new(indoc::indoc! {
213 r#"
214 with key as (
215 select
216 id,
217 0::int2 as uses
218 from api_keys where last_used < date_trunc('minute', now())
219 and (cooldown is null or now() >= cooldown)
220 and "#
221 });
222
223 build_predicate(&mut qb, &selector);
224
225 qb.push(indoc::indoc! {
226 "
227 \n union (
228 select id, uses from api_keys
229 where last_used >= date_trunc('minute', now())
230 and (cooldown is null or now() >= cooldown)
231 and "
232 });
233
234 build_predicate(&mut qb, &selector);
235
236 qb.push(indoc::indoc! {
237 "
238 \n order by uses asc limit 1
239 )
240 order by uses asc limit 1
241 )
242 update api_keys set
243 uses = key.uses + 1,
244 cooldown = null,
245 flag = null,
246 last_used = now()
247 from key where
248 api_keys.id=key.id and key.uses < "
249 });
250
251 qb.push_bind(self.limit);
252
253 qb.push(indoc::indoc! { "
254 \nreturning
255 api_keys.id,
256 api_keys.user_id,
257 api_keys.key,
258 api_keys.uses,
259 api_keys.domains"
260 });
261
262 let key = qb.build_query_as().fetch_optional(&mut *tx).await?;
263
264 tx.commit().await?;
265
266 Result::<Option<Self::Key>, sqlx::Error>::Ok(key)
267 }
268 .await;
269
270 match attempt {
271 Ok(Some(result)) => return Ok(result),
272 Ok(None) => {
273 fn recurse<D>(
274 storage: &PgKeyPoolStorage<D>,
275 selector: KeySelector<PgKey<D>, D>,
276 ) -> BoxFuture<Result<PgKey<D>, PgKeyPoolError<D>>>
277 where
278 D: PgKeyDomain,
279 {
280 Box::pin(storage.acquire_key(selector))
281 }
282
283 return recurse(
284 self,
285 selector
286 .fallback()
287 .ok_or_else(|| PgKeyPoolError::Unavailable(selector))?,
288 )
289 .await;
290 }
291 Err(error) => {
292 if let Some(db_error) = error.as_database_error() {
293 let pg_error: &sqlx::postgres::PgDatabaseError = db_error.downcast_ref();
294 if pg_error.code() == "40001" {
295 random_sleep().await;
296 } else {
297 return Err(error.into());
298 }
299 } else {
300 return Err(error.into());
301 }
302 }
303 }
304 }
305 }
306
307 async fn acquire_many_keys<S>(
308 &self,
309 selector: S,
310 number: i64,
311 ) -> Result<Vec<Self::Key>, Self::Error>
312 where
313 S: IntoSelector<Self::Key, Self::Domain>,
314 {
315 let selector = selector.into_selector();
316 loop {
317 let attempt = async {
318 let mut tx = self.pool.begin().await?;
319
320 sqlx::query("set transaction isolation level repeatable read")
321 .execute(&mut *tx)
322 .await?;
323
324 let mut qb = QueryBuilder::new(indoc::indoc! {
325 r#"select
326 id,
327 user_id,
328 key,
329 0::int2 as uses,
330 domains
331 from api_keys where last_used < date_trunc('minute', now())
332 and (cooldown is null or now() >= cooldown)
333 and "#
334 });
335 build_predicate(&mut qb, &selector);
336 qb.push(indoc::indoc! {
337 "
338 \nunion
339 select
340 id,
341 user_id,
342 key,
343 uses,
344 domains
345 from api_keys where last_used >= date_trunc('minute', now())
346 and (cooldown is null or now() >= cooldown)
347 and "
348 });
349 build_predicate(&mut qb, &selector);
350 qb.push("\norder by uses limit ");
351 qb.push_bind(self.limit);
352
353 let mut keys: Vec<Self::Key> = qb.build_query_as().fetch_all(&mut *tx).await?;
354
355 if keys.is_empty() {
356 tx.commit().await?;
357 return Ok(None);
358 }
359
360 keys.sort_unstable_by(|k1, k2| k1.uses.cmp(&k2.uses));
361
362 let mut result = Vec::with_capacity(number as usize);
363 let (max, rest) = keys.split_last_mut().unwrap();
364 for key in rest {
365 let available = max.uses - key.uses;
366 let using = std::cmp::min(available, (number as i16) - (result.len() as i16));
367 key.uses += using;
368 result.extend(std::iter::repeat_n(key.clone(), using as usize));
369
370 if result.len() == number as usize {
371 break;
372 }
373 }
374
375 while result.len() < (number as usize) {
376 if keys[0].uses == self.limit {
377 break;
378 }
379
380 let take = std::cmp::min(keys.len(), (number as usize) - result.len());
381 let slice = &mut keys[0..take];
382 slice.iter_mut().for_each(|k| k.uses += 1);
383 result.extend_from_slice(slice);
384 }
385
386 sqlx::query(indoc! {r#"
387 update api_keys set
388 uses = tmp.uses,
389 cooldown = null,
390 flag = null,
391 last_used = now()
392 from (select unnest($1::int4[]) as id, unnest($2::int2[]) as uses) as tmp
393 where api_keys.id = tmp.id
394 "#})
395 .bind(keys.iter().map(|k| k.id).collect::<Vec<_>>())
396 .bind(keys.iter().map(|k| k.uses).collect::<Vec<_>>())
397 .execute(&mut *tx)
398 .await?;
399
400 tx.commit().await?;
401
402 Result::<Option<Vec<Self::Key>>, sqlx::Error>::Ok(Some(result))
403 }
404 .await;
405
406 match attempt {
407 Ok(Some(result)) => return Ok(result),
408 Ok(None) => {
409 fn recurse<D>(
410 storage: &PgKeyPoolStorage<D>,
411 selector: KeySelector<PgKey<D>, D>,
412 number: i64,
413 ) -> BoxFuture<Result<Vec<PgKey<D>>, PgKeyPoolError<D>>>
414 where
415 D: PgKeyDomain,
416 {
417 Box::pin(storage.acquire_many_keys(selector, number))
418 }
419
420 return recurse(
421 self,
422 selector
423 .fallback()
424 .ok_or_else(|| Self::Error::Unavailable(selector))?,
425 number,
426 )
427 .await;
428 }
429 Err(error) => {
430 if let Some(db_error) = error.as_database_error() {
431 let pg_error: &sqlx::postgres::PgDatabaseError = db_error.downcast_ref();
432 if pg_error.code() == "40001" {
433 random_sleep().await;
434 } else {
435 return Err(error.into());
436 }
437 } else {
438 return Err(error.into());
439 }
440 }
441 }
442 }
443 }
444
445 async fn timeout_key<S>(
446 &self,
447 selector: S,
448 duration: std::time::Duration,
449 ) -> Result<(), Self::Error>
450 where
451 S: IntoSelector<Self::Key, Self::Domain>,
452 {
453 let selector = selector.into_selector();
454
455 let mut qb = QueryBuilder::new("update api_keys set cooldown=now() + ");
456 qb.push_bind(duration);
457 qb.push(" where ");
458 build_predicate(&mut qb, &selector);
459
460 qb.build().fetch_optional(&self.pool).await?;
461
462 Ok(())
463 }
464
465 async fn store_key(
466 &self,
467 user_id: i32,
468 key: String,
469 domains: Vec<D>,
470 ) -> Result<Self::Key, Self::Error> {
471 sqlx::query_as(
472 "insert into api_keys(user_id, key, domains) values ($1, $2, $3) on conflict on \
473 constraint \"uq:api_keys.key\" do update set domains = \
474 __unique_jsonb_array(excluded.domains || api_keys.domains) returning *",
475 )
476 .bind(user_id)
477 .bind(&key)
478 .bind(sqlx::types::Json(domains))
479 .fetch_one(&self.pool)
480 .await
481 .map_err(Into::into)
482 }
483
484 async fn read_key<S>(&self, selector: S) -> Result<Option<Self::Key>, Self::Error>
485 where
486 S: IntoSelector<Self::Key, Self::Domain>,
487 {
488 let selector = selector.into_selector();
489
490 let mut qb = QueryBuilder::new("select * from api_keys where ");
491 build_predicate(&mut qb, &selector);
492
493 qb.build_query_as()
494 .fetch_optional(&self.pool)
495 .await
496 .map_err(Into::into)
497 }
498
499 async fn read_keys<S>(&self, selector: S) -> Result<Vec<Self::Key>, Self::Error>
500 where
501 S: IntoSelector<Self::Key, Self::Domain>,
502 {
503 let selector = selector.into_selector();
504
505 let mut qb = QueryBuilder::new("select * from api_keys where ");
506 build_predicate(&mut qb, &selector);
507
508 qb.build_query_as()
509 .fetch_all(&self.pool)
510 .await
511 .map_err(Into::into)
512 }
513
514 async fn remove_key<S>(&self, selector: S) -> Result<Self::Key, Self::Error>
515 where
516 S: IntoSelector<Self::Key, Self::Domain>,
517 {
518 let selector = selector.into_selector();
519
520 let mut qb = QueryBuilder::new("delete from api_keys where ");
521 build_predicate(&mut qb, &selector);
522 qb.push(" returning *");
523
524 qb.build_query_as()
525 .fetch_optional(&self.pool)
526 .await?
527 .ok_or_else(|| PgKeyPoolError::KeyNotFound(selector))
528 }
529
530 async fn add_domain_to_key<S>(&self, selector: S, domain: D) -> Result<Self::Key, Self::Error>
531 where
532 S: IntoSelector<Self::Key, Self::Domain>,
533 {
534 let selector = selector.into_selector();
535
536 let mut qb = QueryBuilder::new(
537 "update api_keys set domains = __unique_jsonb_array(domains || jsonb_build_array(",
538 );
539 qb.push_bind(sqlx::types::Json(domain));
540 qb.push(")) where ");
541 build_predicate(&mut qb, &selector);
542 qb.push(" returning *");
543
544 qb.build_query_as()
545 .fetch_optional(&self.pool)
546 .await?
547 .ok_or_else(|| PgKeyPoolError::KeyNotFound(selector))
548 }
549
550 async fn remove_domain_from_key<S>(
551 &self,
552 selector: S,
553 domain: D,
554 ) -> Result<Self::Key, Self::Error>
555 where
556 S: IntoSelector<Self::Key, Self::Domain>,
557 {
558 let selector = selector.into_selector();
559
560 let mut qb = QueryBuilder::new(
561 "update api_keys set domains = coalesce(__filter_jsonb_array(domains, ",
562 );
563 qb.push_bind(sqlx::types::Json(domain));
564 qb.push("), '[]'::jsonb) where ");
565 build_predicate(&mut qb, &selector);
566 qb.push(" returning *");
567
568 qb.build_query_as()
569 .fetch_optional(&self.pool)
570 .await?
571 .ok_or_else(|| PgKeyPoolError::KeyNotFound(selector))
572 }
573
574 async fn set_domains_for_key<S>(
575 &self,
576 selector: S,
577 domains: Vec<D>,
578 ) -> Result<Self::Key, Self::Error>
579 where
580 S: IntoSelector<Self::Key, Self::Domain>,
581 {
582 let selector = selector.into_selector();
583
584 let mut qb = QueryBuilder::new("update api_keys set domains = ");
585 qb.push_bind(sqlx::types::Json(domains));
586 qb.push(" where ");
587 build_predicate(&mut qb, &selector);
588 qb.push(" returning *");
589
590 qb.build_query_as()
591 .fetch_optional(&self.pool)
592 .await?
593 .ok_or_else(|| PgKeyPoolError::KeyNotFound(selector))
594 }
595}
596
/// Database-backed tests for [`PgKeyPoolStorage`].
///
/// `#[sqlx::test]` supplies the `PgPool`. `setup` additionally reads a key from the
/// `API_KEY` environment variable and panics if it is missing.
#[cfg(test)]
pub(crate) mod test {
    use std::{sync::Arc, time::Duration};

    use sqlx::Row;

    use super::*;

    /// Domain type used by the tests. It is serialized as an internally tagged object,
    /// e.g. `{"type": "guild", "id": 1}`.
    #[derive(Debug, PartialEq, Eq, Clone, serde::Serialize, serde::Deserialize)]
    #[serde(tag = "type", rename_all = "snake_case")]
    pub(crate) enum Domain {
        All,
        Guild { id: i64 },
        User { id: i32 },
        Faction { id: i32 },
    }

    impl KeyDomain for Domain {
        // Guild lookups fall back to `All`; every other domain has no fallback.
        fn fallback(&self) -> Option<Self> {
            match self {
                Self::Guild { id: _ } => Some(Self::All),
                _ => None,
            }
        }
    }

    /// Drops and recreates `api_keys`, then stores the `API_KEY` key for user 1 with
    /// the `All` domain. Returns the storage (limit 1000) and the stored key.
    pub(crate) async fn setup(pool: PgPool) -> (PgKeyPoolStorage<Domain>, PgKey<Domain>) {
        sqlx::query("DROP TABLE IF EXISTS api_keys")
            .execute(&pool)
            .await
            .unwrap();

        let storage = PgKeyPoolStorage::new(pool.clone(), 1000);
        storage.initialise().await.unwrap();

        let key = storage
            .store_key(1, std::env::var("API_KEY").unwrap(), vec![Domain::All])
            .await
            .unwrap();

        (storage, key)
    }

    // `initialise` must succeed a second time on an existing schema.
    #[sqlx::test]
    async fn test_initialise(pool: PgPool) {
        let (storage, _) = setup(pool).await;

        if let Err(e) = storage.initialise().await {
            panic!("Initialising key storage failed: {:?}", e);
        }
    }

    // Storing an existing key merges the new domain into the existing ones.
    #[sqlx::test]
    async fn test_store_duplicate_key(pool: PgPool) {
        let (storage, key) = setup(pool).await;
        let key = storage
            .store_key(1, key.key, vec![Domain::User { id: 1 }])
            .await
            .unwrap();

        assert_eq!(key.domains.0.len(), 2);
    }

    // Merging an already present domain must not duplicate it.
    #[sqlx::test]
    async fn test_store_duplicate_key_duplicate_domain(pool: PgPool) {
        let (storage, key) = setup(pool).await;
        let key = storage
            .store_key(1, key.key, vec![Domain::All])
            .await
            .unwrap();

        assert_eq!(key.domains.0.len(), 1);
    }

    #[sqlx::test]
    async fn test_add_domain(pool: PgPool) {
        let (storage, key) = setup(pool).await;
        let key = storage
            .add_domain_to_key(KeySelector::Key(key.key), Domain::User { id: 12345 })
            .await
            .unwrap();

        assert!(key.domains.0.contains(&Domain::User { id: 12345 }));
    }

    #[sqlx::test]
    async fn test_add_domain_id(pool: PgPool) {
        let (storage, key) = setup(pool).await;
        let key = storage
            .add_domain_to_key(KeySelector::Id(key.id), Domain::User { id: 12345 })
            .await
            .unwrap();

        assert!(key.domains.0.contains(&Domain::User { id: 12345 }));
    }

    // Adding a domain the key already has keeps exactly one copy.
    #[sqlx::test]
    async fn test_add_duplicate_domain(pool: PgPool) {
        let (storage, key) = setup(pool).await;
        let key = storage
            .add_domain_to_key(KeySelector::Key(key.key), Domain::All)
            .await
            .unwrap();
        assert_eq!(
            key.domains
                .0
                .into_iter()
                .filter(|d| *d == Domain::All)
                .count(),
            1
        );
    }

    #[sqlx::test]
    async fn test_remove_domain(pool: PgPool) {
        let (storage, key) = setup(pool).await;
        storage
            .add_domain_to_key(KeySelector::Key(key.key.clone()), Domain::User { id: 1 })
            .await
            .unwrap();
        let key = storage
            .remove_domain_from_key(KeySelector::Key(key.key.clone()), Domain::User { id: 1 })
            .await
            .unwrap();

        assert_eq!(key.domains.0, vec![Domain::All]);
    }

    #[sqlx::test]
    async fn test_remove_domain_id(pool: PgPool) {
        let (storage, key) = setup(pool).await;
        storage
            .add_domain_to_key(KeySelector::Id(key.id), Domain::User { id: 1 })
            .await
            .unwrap();
        let key = storage
            .remove_domain_from_key(KeySelector::Id(key.id), Domain::User { id: 1 })
            .await
            .unwrap();

        assert_eq!(key.domains.0, vec![Domain::All]);
    }

    // Removing the only domain leaves an empty array rather than NULL.
    #[sqlx::test]
    async fn test_remove_last_domain(pool: PgPool) {
        let (storage, key) = setup(pool).await;
        let key = storage
            .remove_domain_from_key(KeySelector::Key(key.key), Domain::All)
            .await
            .unwrap();

        assert!(key.domains.0.is_empty());
    }

    #[sqlx::test]
    async fn test_store_key(pool: PgPool) {
        let (storage, _) = setup(pool).await;
        let key = storage
            .store_key(1, "ABCDABCDABCDABCD".to_owned(), vec![])
            .await
            .unwrap();
        assert_eq!(key.value(), "ABCDABCDABCDABCD");
    }

    #[sqlx::test]
    async fn test_read_user_keys(pool: PgPool) {
        let (storage, _) = setup(pool).await;

        let keys = storage.read_keys(KeySelector::UserId(1)).await.unwrap();
        assert_eq!(keys.len(), 1);
    }

    #[sqlx::test]
    async fn acquire_one(pool: PgPool) {
        let (storage, _) = setup(pool).await;

        if let Err(e) = storage.acquire_key(Domain::All).await {
            panic!("Acquiring key failed: {:?}", e);
        }
    }

    // With two keys, ten acquisitions must be split evenly (least used first).
    #[sqlx::test]
    async fn uses_spread(pool: PgPool) {
        let (storage, _) = setup(pool).await;
        storage
            .store_key(1, "ABC".to_owned(), vec![Domain::All])
            .await
            .unwrap();

        for _ in 0..10 {
            _ = storage.acquire_key(Domain::All).await.unwrap();
        }

        let keys = storage.read_keys(KeySelector::UserId(1)).await.unwrap();
        assert_eq!(keys.len(), 2);
        for key in keys {
            assert_eq!(key.uses, 5);
        }
    }

    // The result holds one entry per use, so a single key can satisfy all 30.
    #[sqlx::test]
    async fn acquire_many(pool: PgPool) {
        let (storage, _) = setup(pool).await;

        match storage.acquire_many_keys(Domain::All, 30).await {
            Err(e) => panic!("Acquiring key failed: {:?}", e),
            Ok(keys) => assert_eq!(keys.len(), 30),
        }
    }

    // 100 concurrent acquisitions on the single key must all be counted (no lost
    // updates). The counter is reset between the 10 rounds.
    #[sqlx::test]
    async fn test_concurrent(pool: PgPool) {
        let storage = Arc::new(setup(pool).await.0);

        for _ in 0..10 {
            let mut set = tokio::task::JoinSet::new();

            for _ in 0..100 {
                let storage = storage.clone();
                set.spawn(async move {
                    storage.acquire_key(Domain::All).await.unwrap();
                });
            }

            for _ in 0..100 {
                set.join_next().await.unwrap().unwrap();
            }

            let uses: i16 = sqlx::query("select uses from api_keys")
                .fetch_one(&storage.pool)
                .await
                .unwrap()
                .get("uses");

            assert_eq!(uses, 100);

            sqlx::query("update api_keys set uses=0")
                .execute(&storage.pool)
                .await
                .unwrap();
        }
    }

    // 25 keys (the setup key plus 24 more) and 50 concurrent acquisitions per round:
    // every key must end up with exactly 2 uses.
    #[sqlx::test]
    async fn test_concurrent_spread(pool: PgPool) {
        let storage = Arc::new(setup(pool).await.0);

        for i in 0..24 {
            storage
                .store_key(1, format!("{}", i), vec![Domain::All])
                .await
                .unwrap();
        }

        for _ in 0..10 {
            let mut set = tokio::task::JoinSet::new();

            for _ in 0..50 {
                let storage = storage.clone();
                set.spawn(async move {
                    storage.acquire_key(Domain::All).await.unwrap();
                });
            }

            for _ in 0..50 {
                set.join_next().await.unwrap().unwrap();
            }

            let keys = storage.read_keys(KeySelector::UserId(1)).await.unwrap();

            assert_eq!(keys.len(), 25);

            for key in keys {
                assert_eq!(key.uses, 2);
            }

            sqlx::query("update api_keys set uses=0")
                .execute(&storage.pool)
                .await
                .unwrap();
        }
    }

    // 100 concurrent batches of 5 on the single key must add up to 500 uses.
    #[sqlx::test]
    async fn test_concurrent_many(pool: PgPool) {
        let storage = Arc::new(setup(pool).await.0);
        for _ in 0..10 {
            let mut set = tokio::task::JoinSet::new();

            for _ in 0..100 {
                let storage = storage.clone();
                set.spawn(async move {
                    storage.acquire_many_keys(Domain::All, 5).await.unwrap();
                });
            }

            for _ in 0..100 {
                set.join_next().await.unwrap().unwrap();
            }

            let uses: i16 = sqlx::query("select uses from api_keys")
                .fetch_one(&storage.pool)
                .await
                .unwrap()
                .get("uses");

            assert_eq!(uses, 500);

            sqlx::query("update api_keys set uses=0")
                .execute(&storage.pool)
                .await
                .unwrap();
        }
    }

    #[sqlx::test]
    async fn read_key(pool: PgPool) {
        let (storage, key) = setup(pool).await;

        let key = storage.read_key(KeySelector::Key(key.key)).await.unwrap();
        assert!(key.is_some());
    }

    #[sqlx::test]
    async fn read_key_id(pool: PgPool) {
        let (storage, key) = setup(pool).await;

        let key = storage.read_key(KeySelector::Id(key.id)).await.unwrap();
        assert!(key.is_some());
    }

    #[sqlx::test]
    async fn read_nonexistent_key(pool: PgPool) {
        let (storage, _) = setup(pool).await;

        let key = storage.read_key(KeySelector::Id(-1)).await.unwrap();
        assert!(key.is_none());
    }

    #[sqlx::test]
    async fn query_key(pool: PgPool) {
        let (storage, _) = setup(pool).await;

        let key = storage.read_key(Domain::All).await.unwrap();
        assert!(key.is_some());
    }

    // `read_key` does not apply the Guild -> All fallback.
    #[sqlx::test]
    async fn query_nonexistent_key(pool: PgPool) {
        let (storage, _) = setup(pool).await;

        let key = storage.read_key(Domain::Guild { id: 0 }).await.unwrap();
        assert!(key.is_none());
    }

    #[sqlx::test]
    async fn query_all(pool: PgPool) {
        let (storage, _) = setup(pool).await;

        let keys = storage.read_keys(Domain::All).await.unwrap();
        assert!(keys.len() == 1);
    }

    // The table is recreated in `setup`, so the first stored key has id 1.
    #[sqlx::test]
    async fn query_by_id(pool: PgPool) {
        let (storage, _) = setup(pool).await;
        let key = storage.read_key(KeySelector::Id(1)).await.unwrap();

        assert!(key.is_some());
    }

    #[sqlx::test]
    async fn query_by_key(pool: PgPool) {
        let (storage, key) = setup(pool).await;
        let key = storage.read_key(KeySelector::Key(key.key)).await.unwrap();

        assert!(key.is_some());
    }

    #[sqlx::test]
    async fn timeout(pool: PgPool) {
        let (storage, key) = setup(pool).await;

        storage
            .timeout_key(KeySelector::Id(key.id()), Duration::from_secs(60))
            .await
            .unwrap();
    }

    // `OneOf` matches when any listed domain is present.
    #[sqlx::test]
    async fn query_by_set(pool: PgPool) {
        let (storage, _key) = setup(pool).await;
        let key = storage
            .read_key(KeySelector::OneOf(vec![
                Domain::All,
                Domain::Guild { id: 0 },
                Domain::Faction { id: 0 },
            ]))
            .await
            .unwrap();

        assert!(key.is_some());
    }

    // `Has` requires every listed domain, regardless of order.
    #[sqlx::test]
    async fn all_selector(pool: PgPool) {
        let (storage, key) = setup(pool).await;

        storage
            .add_domain_to_key(key.selector(), Domain::Faction { id: 1 })
            .await
            .unwrap();

        let key = storage
            .read_key(KeySelector::Has(vec![
                Domain::Faction { id: 1 },
                Domain::All,
            ]))
            .await
            .unwrap();

        assert!(key.is_some());

        let key = storage
            .read_key(KeySelector::Has(vec![
                Domain::All,
                Domain::Faction { id: 1 },
            ]))
            .await
            .unwrap();

        assert!(key.is_some());

        let key = storage
            .read_key(KeySelector::Has(vec![
                Domain::All,
                Domain::Faction { id: 2 },
                Domain::Faction { id: 1 },
            ]))
            .await
            .unwrap();

        assert!(key.is_none());
    }
}