1use std::sync::Arc;
2
3use async_trait::async_trait;
4use indoc::indoc;
5use sqlx::{FromRow, PgPool, Postgres, QueryBuilder};
6use thiserror::Error;
7
8use crate::{ApiKey, IntoSelector, KeyDomain, KeyPoolStorage, KeySelector};
9
/// Marker trait for domain types usable with the Postgres storage backend.
///
/// A domain must be a [`KeyDomain`] that can be round-tripped through the
/// jsonb `domains` column (`Serialize` + `DeserializeOwned`), compared for
/// equality, and held across `.await` points (`Unpin`).
pub trait PgKeyDomain:
    KeyDomain + serde::Serialize + serde::de::DeserializeOwned + Eq + Unpin
{
}
14
// Blanket impl: any type satisfying the bounds is automatically a
// `PgKeyDomain`; downstream crates never implement it by hand.
impl<T> PgKeyDomain for T where
    T: KeyDomain + serde::Serialize + serde::de::DeserializeOwned + Eq + Unpin
{
}
19
20#[derive(Debug, Error, Clone)]
21pub enum PgStorageError<D>
22where
23 D: PgKeyDomain,
24{
25 #[error(transparent)]
26 Pg(Arc<sqlx::Error>),
27
28 #[error("No key avalaible for domain {0:?}")]
29 Unavailable(KeySelector<PgKey<D>, D>),
30
31 #[error("Key not found: '{0:?}'")]
32 KeyNotFound(KeySelector<PgKey<D>, D>),
33}
34
35impl<D> From<sqlx::Error> for PgStorageError<D>
36where
37 D: PgKeyDomain,
38{
39 fn from(value: sqlx::Error) -> Self {
40 Self::Pg(Arc::new(value))
41 }
42}
43
/// A single API key row, decoded from the `api_keys` table.
#[derive(Debug, Clone, FromRow)]
pub struct PgKey<D>
where
    D: PgKeyDomain,
{
    /// Primary key (`api_keys.id`, serial).
    pub id: i32,
    /// Id of the user owning this key.
    pub user_id: i32,
    /// The API key string itself (`char(16)` in the schema).
    pub key: String,
    /// Uses consumed within the current minute window.
    pub uses: i16,
    /// Domains the key may serve, stored as a jsonb array.
    pub domains: sqlx::types::Json<Vec<D>>,
}
55
56#[inline(always)]
57fn build_predicate<'b, D>(
58 builder: &mut QueryBuilder<'b, Postgres>,
59 selector: &'b KeySelector<PgKey<D>, D>,
60) where
61 D: PgKeyDomain,
62{
63 match selector {
64 KeySelector::Id(id) => builder.push("id=").push_bind(id),
65 KeySelector::UserId(user_id) => builder.push("user_id=").push_bind(user_id),
66 KeySelector::Key(key) => builder.push("key=").push_bind(key),
67 KeySelector::Has(domains) => builder
68 .push("domains @> ")
69 .push_bind(sqlx::types::Json(domains)),
70 KeySelector::OneOf(domains) => {
71 if domains.is_empty() {
72 builder.push("false");
73 return;
74 }
75
76 for (idx, domain) in domains.iter().enumerate() {
77 if idx == 0 {
78 builder.push("(");
79 } else {
80 builder.push(" or ");
81 }
82 builder
83 .push("domains @> ")
84 .push_bind(sqlx::types::Json(vec![domain]));
85 }
86 builder.push(")")
87 }
88 };
89}
90
/// Postgres-backed implementation of the key pool storage.
// NOTE(review): the `FromRow` derive looks copy-pasted from `PgKey` — this
// struct holds a `PgPool` and is never decoded from a query row; confirm
// whether the derive can be dropped.
#[derive(Debug, Clone, FromRow)]
pub struct PgKeyPoolStorage<D>
where
    D: serde::Serialize + serde::de::DeserializeOwned + Send + Sync + 'static,
{
    // Connection pool used for every query issued by this storage.
    pool: PgPool,
    // Maximum uses per key within one minute window.
    limit: i16,
    // Ties the storage to its domain type `D` without storing one.
    _phantom: std::marker::PhantomData<D>,
}
100
impl<D> ApiKey for PgKey<D>
where
    D: PgKeyDomain,
{
    // Keys are identified by their serial primary key.
    type IdType = i32;

    /// The raw key string sent to the API.
    #[inline(always)]
    fn value(&self) -> &str {
        &self.key
    }

    /// The database id of this key.
    #[inline(always)]
    fn id(&self) -> Self::IdType {
        self.id
    }
}
117
118impl<D> PgKeyPoolStorage<D>
119where
120 D: PgKeyDomain,
121{
122 pub fn new(pool: PgPool, limit: i16) -> Self {
123 Self {
124 pool,
125 limit,
126 _phantom: Default::default(),
127 }
128 }
129
130 pub async fn initialise(&self) -> Result<(), PgStorageError<D>> {
131 sqlx::query(indoc! {r#"
132 CREATE TABLE IF NOT EXISTS api_keys (
133 id serial primary key,
134 user_id int4 not null,
135 key char(16) not null,
136 uses int2 not null default 0,
137 domains jsonb not null default '{}'::jsonb,
138 last_used timestamptz not null default now(),
139 flag int2,
140 cooldown timestamptz,
141 constraint "uq:api_keys.key" UNIQUE(key)
142 )"#
143 })
144 .execute(&self.pool)
145 .await?;
146
147 sqlx::query(indoc! {r#"
148 CREATE INDEX IF NOT EXISTS "idx:api_keys.domains" ON api_keys USING GIN(domains jsonb_path_ops)
149 "#})
150 .execute(&self.pool)
151 .await?;
152
153 sqlx::query(indoc! {r#"
154 CREATE INDEX IF NOT EXISTS "idx:api_keys.user_id" ON api_keys USING BTREE(user_id)
155 "#})
156 .execute(&self.pool)
157 .await?;
158
159 sqlx::query(indoc! {r#"
160 create or replace function __unique_jsonb_array(jsonb) returns jsonb
161 AS $$
162 select jsonb_agg(d::jsonb) from (
163 select distinct jsonb_array_elements_text($1) as d
164 ) t
165 $$ language sql;
166 "#})
167 .execute(&self.pool)
168 .await?;
169
170 sqlx::query(indoc! {r#"
171 create or replace function __filter_jsonb_array(jsonb, jsonb) returns jsonb
172 AS $$
173 select jsonb_agg(d::jsonb) from (
174 select distinct jsonb_array_elements_text($1) as d
175 ) t where d<>$2::text
176 $$ language sql;
177 "#})
178 .execute(&self.pool)
179 .await?;
180
181 Ok(())
182 }
183}
184
/// Sleep for a random 1–49 ms to de-synchronise transactions retrying after
/// a serialization failure (tokio runtime variant).
#[cfg(feature = "tokio-runtime")]
async fn random_sleep() {
    use rand::{thread_rng, Rng};
    let millis: u64 = thread_rng().gen_range(1..50);
    tokio::time::sleep(tokio::time::Duration::from_millis(millis)).await;
}
191
/// Sleep for a random 1–49 ms to de-synchronise transactions retrying after
/// a serialization failure (actix runtime variant).
#[cfg(all(not(feature = "tokio-runtime"), feature = "actix-runtime"))]
async fn random_sleep() {
    use rand::{thread_rng, Rng};
    let millis: u64 = thread_rng().gen_range(1..50);
    actix_rt::time::sleep(std::time::Duration::from_millis(millis)).await;
}
198
#[async_trait]
impl<D> KeyPoolStorage for PgKeyPoolStorage<D>
where
    D: PgKeyDomain,
{
    type Key = PgKey<D>;
    type Domain = D;

    type Error = PgStorageError<D>;

    /// Acquire the least-used available key matching `selector`.
    ///
    /// Runs a repeatable-read transaction that picks the key with the fewest
    /// uses this minute (keys whose `last_used` predates the current minute
    /// count as 0 uses), increments its use counter, clears any cooldown and
    /// flag, and returns it. On a serialization failure (SQLSTATE 40001) the
    /// whole transaction is retried after a random backoff; if no key is
    /// available, the selector's fallback is tried recursively before
    /// returning [`PgStorageError::Unavailable`].
    async fn acquire_key<S>(&self, selector: S) -> Result<Self::Key, Self::Error>
    where
        S: IntoSelector<Self::Key, Self::Domain>,
    {
        let selector = selector.into_selector();
        loop {
            let attempt = async {
                let mut tx = self.pool.begin().await?;

                // Repeatable read makes concurrent acquirers conflict
                // (40001) instead of double-counting uses.
                sqlx::query("set transaction isolation level repeatable read")
                    .execute(&mut *tx)
                    .await?;

                // CTE branch 1: keys last used before the current minute —
                // their counter is stale, so they effectively have 0 uses.
                let mut qb = QueryBuilder::new(indoc::indoc! {
                    r#"
                    with key as (
                        select
                            id,
                            0::int2 as uses
                        from api_keys where last_used < date_trunc('minute', now())
                        and (cooldown is null or now() >= cooldown)
                        and "#
                });

                build_predicate(&mut qb, &selector);

                // CTE branch 2: keys already used this minute, with their
                // real counters.
                qb.push(indoc::indoc! {
                    "
                    \n union (
                    select id, uses from api_keys
                    where last_used >= date_trunc('minute', now())
                    and (cooldown is null or now() >= cooldown)
                    and "
                });

                build_predicate(&mut qb, &selector);

                // Pick the overall least-used candidate and bump it, but
                // only while it is still under the per-minute limit.
                qb.push(indoc::indoc! {
                    "
                    \n order by uses asc limit 1
                    )
                    order by uses asc limit 1
                    )
                    update api_keys set
                        uses = key.uses + 1,
                        cooldown = null,
                        flag = null,
                        last_used = now()
                    from key where
                        api_keys.id=key.id and key.uses < "
                });

                qb.push_bind(self.limit);

                qb.push(indoc::indoc! { "
                    \nreturning
                        api_keys.id,
                        api_keys.user_id,
                        api_keys.key,
                        api_keys.uses,
                        api_keys.domains"
                });

                let key = qb.build_query_as().fetch_optional(&mut *tx).await?;

                tx.commit().await?;

                Result::<Option<Self::Key>, sqlx::Error>::Ok(key)
            }
            .await;

            match attempt {
                Ok(Some(result)) => return Ok(result),
                // No usable key: try the selector's fallback (e.g. a more
                // general domain), or report unavailability.
                Ok(None) => {
                    return self
                        .acquire_key(
                            selector
                                .fallback()
                                .ok_or_else(|| PgStorageError::Unavailable(selector))?,
                        )
                        .await
                }
                Err(error) => {
                    if let Some(db_error) = error.as_database_error() {
                        let pg_error: &sqlx::postgres::PgDatabaseError = db_error.downcast_ref();
                        // 40001 = serialization_failure: back off and retry.
                        if pg_error.code() == "40001" {
                            random_sleep().await;
                        } else {
                            return Err(error.into());
                        }
                    } else {
                        return Err(error.into());
                    }
                }
            }
        }
    }

    /// Acquire `number` key uses, spread across the matching keys so their
    /// per-minute counters stay as level as possible.
    ///
    /// The returned vector has length `number`; a key appears once per use
    /// granted to it. Retries on serialization failures and falls back like
    /// [`acquire_key`](Self::acquire_key).
    async fn acquire_many_keys<S>(
        &self,
        selector: S,
        number: i64,
    ) -> Result<Vec<Self::Key>, Self::Error>
    where
        S: IntoSelector<Self::Key, Self::Domain>,
    {
        let selector = selector.into_selector();
        loop {
            let attempt = async {
                let mut tx = self.pool.begin().await?;

                sqlx::query("set transaction isolation level repeatable read")
                    .execute(&mut *tx)
                    .await?;

                // Same two-branch candidate query as `acquire_key`, but
                // fetching whole rows instead of updating in SQL.
                let mut qb = QueryBuilder::new(indoc::indoc! {
                    r#"select
                        id,
                        user_id,
                        key,
                        0::int2 as uses,
                        domains
                    from api_keys where last_used < date_trunc('minute', now())
                    and (cooldown is null or now() >= cooldown)
                    and "#
                });
                build_predicate(&mut qb, &selector);
                qb.push(indoc::indoc! {
                    "
                    \nunion
                    select
                        id,
                        user_id,
                        key,
                        uses,
                        domains
                    from api_keys where last_used >= date_trunc('minute', now())
                    and (cooldown is null or now() >= cooldown)
                    and "
                });
                build_predicate(&mut qb, &selector);
                // NOTE(review): `self.limit` (uses-per-key cap) is reused
                // here as the SQL row LIMIT — confirm this conflation is
                // intentional.
                qb.push("\norder by uses limit ");
                qb.push_bind(self.limit);

                let mut keys: Vec<Self::Key> = qb.build_query_as().fetch_all(&mut *tx).await?;

                if keys.is_empty() {
                    tx.commit().await?;
                    // `None` signals "try the selector fallback" below.
                    return Ok(None);
                }

                keys.sort_unstable_by(|k1, k2| k1.uses.cmp(&k2.uses));

                // Phase 1: level the lower-used keys up towards the
                // most-used key, handing out one result entry per use.
                let mut result = Vec::with_capacity(number as usize);
                let (max, rest) = keys.split_last_mut().unwrap();
                for key in rest {
                    let available = max.uses - key.uses;
                    let using = std::cmp::min(available, (number as i16) - (result.len() as i16));
                    key.uses += using;
                    result.extend(std::iter::repeat(key.clone()).take(using as usize));

                    if result.len() == number as usize {
                        break;
                    }
                }

                // Phase 2: all keys are level — distribute the remaining
                // demand round-robin, one use per key per pass, until the
                // per-minute limit is reached.
                while result.len() < (number as usize) {
                    if keys[0].uses == self.limit {
                        break;
                    }

                    let take = std::cmp::min(keys.len(), (number as usize) - result.len());
                    let slice = &mut keys[0..take];
                    slice.iter_mut().for_each(|k| k.uses += 1);
                    result.extend_from_slice(slice);
                }

                // Persist the new counters in one statement via unnest'ed
                // parallel arrays.
                sqlx::query(indoc! {r#"
                    update api_keys set
                        uses = tmp.uses,
                        cooldown = null,
                        flag = null,
                        last_used = now()
                    from (select unnest($1::int4[]) as id, unnest($2::int2[]) as uses) as tmp
                    where api_keys.id = tmp.id
                "#})
                .bind(keys.iter().map(|k| k.id).collect::<Vec<_>>())
                .bind(keys.iter().map(|k| k.uses).collect::<Vec<_>>())
                .execute(&mut *tx)
                .await?;

                tx.commit().await?;

                Result::<Option<Vec<Self::Key>>, sqlx::Error>::Ok(Some(result))
            }
            .await;

            match attempt {
                Ok(Some(result)) => return Ok(result),
                Ok(None) => {
                    return self
                        .acquire_many_keys(
                            selector
                                .fallback()
                                .ok_or_else(|| Self::Error::Unavailable(selector))?,
                            number,
                        )
                        .await
                }
                Err(error) => {
                    if let Some(db_error) = error.as_database_error() {
                        let pg_error: &sqlx::postgres::PgDatabaseError = db_error.downcast_ref();
                        // 40001 = serialization_failure: back off and retry.
                        if pg_error.code() == "40001" {
                            random_sleep().await;
                        } else {
                            return Err(error.into());
                        }
                    } else {
                        return Err(error.into());
                    }
                }
            }
        }
    }

    /// React to an API error `code` for `key` by applying a cooldown.
    ///
    /// Returns `Ok(true)` when the error was key-specific and the request
    /// should be retried with another key, `Ok(false)` otherwise.
    /// NOTE(review): codes appear to mirror upstream API error codes
    /// (e.g. 2 = invalid key, 5 = too many requests) — confirm against the
    /// API documentation.
    async fn flag_key(&self, key: Self::Key, code: u8) -> Result<bool, Self::Error> {
        match code {
            // Permanently disable this key (e.g. revoked/invalid).
            2 | 10 | 13 => {
                sqlx::query(
                    "update api_keys set cooldown='infinity'::timestamptz, flag=$1 where id=$2",
                )
                .bind(code as i16)
                .bind(key.id)
                .execute(&self.pool)
                .await?;
                Ok(true)
            }
            // Rate-limited: park this key until the next minute boundary.
            5 => {
                sqlx::query(
                    "update api_keys set cooldown=date_trunc('min', now()) + interval '1 min', \
                     flag=5 where id=$1",
                )
                .bind(key.id)
                .execute(&self.pool)
                .await?;
                Ok(true)
            }
            // NOTE(review): no `where` clause — cools down EVERY key for
            // 5 minutes; presumably an account/IP-wide condition. Confirm.
            8 => {
                sqlx::query("update api_keys set cooldown=now() + interval '5 min', flag=8")
                    .execute(&self.pool)
                    .await?;
                Ok(false)
            }
            // NOTE(review): likewise applies to every key, for 1 minute.
            9 => {
                sqlx::query("update api_keys set cooldown=now() + interval '1 min', flag=9")
                    .execute(&self.pool)
                    .await?;
                Ok(false)
            }
            // Daily cap reached: park this key until the next day.
            14 => {
                sqlx::query(
                    "update api_keys set cooldown=date_trunc('day', now()) + interval '1 day', \
                     flag=14 where id=$1",
                )
                .bind(key.id)
                .execute(&self.pool)
                .await?;
                Ok(true)
            }
            // Unknown codes are not key-specific; do nothing.
            _ => Ok(false),
        }
    }

    /// Insert a key, or — if the key string already exists — merge the new
    /// `domains` into the existing row (deduplicated).
    async fn store_key(
        &self,
        user_id: i32,
        key: String,
        domains: Vec<D>,
    ) -> Result<Self::Key, Self::Error> {
        sqlx::query_as(
            "insert into api_keys(user_id, key, domains) values ($1, $2, $3) on conflict on \
             constraint \"uq:api_keys.key\" do update set domains = \
             __unique_jsonb_array(excluded.domains || api_keys.domains) returning *",
        )
        .bind(user_id)
        .bind(&key)
        .bind(sqlx::types::Json(domains))
        .fetch_one(&self.pool)
        .await
        .map_err(Into::into)
    }

    /// Fetch the first key matching `selector`, if any.
    async fn read_key<S>(&self, selector: S) -> Result<Option<Self::Key>, Self::Error>
    where
        S: IntoSelector<Self::Key, Self::Domain>,
    {
        let selector = selector.into_selector();

        let mut qb = QueryBuilder::new("select * from api_keys where ");
        build_predicate(&mut qb, &selector);

        qb.build_query_as()
            .fetch_optional(&self.pool)
            .await
            .map_err(Into::into)
    }

    /// Fetch all keys matching `selector`.
    async fn read_keys<S>(&self, selector: S) -> Result<Vec<Self::Key>, Self::Error>
    where
        S: IntoSelector<Self::Key, Self::Domain>,
    {
        let selector = selector.into_selector();

        let mut qb = QueryBuilder::new("select * from api_keys where ");
        build_predicate(&mut qb, &selector);

        qb.build_query_as()
            .fetch_all(&self.pool)
            .await
            .map_err(Into::into)
    }

    /// Delete and return the key matching `selector`.
    ///
    /// # Errors
    /// [`PgStorageError::KeyNotFound`] if nothing matched.
    async fn remove_key<S>(&self, selector: S) -> Result<Self::Key, Self::Error>
    where
        S: IntoSelector<Self::Key, Self::Domain>,
    {
        let selector = selector.into_selector();

        let mut qb = QueryBuilder::new("delete from api_keys where ");
        build_predicate(&mut qb, &selector);
        qb.push(" returning *");

        qb.build_query_as()
            .fetch_optional(&self.pool)
            .await?
            .ok_or_else(|| PgStorageError::KeyNotFound(selector))
    }

    /// Append `domain` to the matching key's domain list (deduplicated).
    ///
    /// # Errors
    /// [`PgStorageError::KeyNotFound`] if nothing matched.
    async fn add_domain_to_key<S>(&self, selector: S, domain: D) -> Result<Self::Key, Self::Error>
    where
        S: IntoSelector<Self::Key, Self::Domain>,
    {
        let selector = selector.into_selector();

        let mut qb = QueryBuilder::new(
            "update api_keys set domains = __unique_jsonb_array(domains || jsonb_build_array(",
        );
        qb.push_bind(sqlx::types::Json(domain));
        qb.push(")) where ");
        build_predicate(&mut qb, &selector);
        qb.push(" returning *");

        qb.build_query_as()
            .fetch_optional(&self.pool)
            .await?
            .ok_or_else(|| PgStorageError::KeyNotFound(selector))
    }

    /// Remove `domain` from the matching key's domain list.
    ///
    /// # Errors
    /// [`PgStorageError::KeyNotFound`] if nothing matched.
    async fn remove_domain_from_key<S>(
        &self,
        selector: S,
        domain: D,
    ) -> Result<Self::Key, Self::Error>
    where
        S: IntoSelector<Self::Key, Self::Domain>,
    {
        let selector = selector.into_selector();

        // coalesce: `__filter_jsonb_array` yields NULL when the result is
        // empty, but `domains` is not-null.
        let mut qb = QueryBuilder::new(
            "update api_keys set domains = coalesce(__filter_jsonb_array(domains, ",
        );
        qb.push_bind(sqlx::types::Json(domain));
        qb.push("), '[]'::jsonb) where ");
        build_predicate(&mut qb, &selector);
        qb.push(" returning *");

        qb.build_query_as()
            .fetch_optional(&self.pool)
            .await?
            .ok_or_else(|| PgStorageError::KeyNotFound(selector))
    }

    /// Replace the matching key's domain list wholesale.
    ///
    /// # Errors
    /// [`PgStorageError::KeyNotFound`] if nothing matched.
    async fn set_domains_for_key<S>(
        &self,
        selector: S,
        domains: Vec<D>,
    ) -> Result<Self::Key, Self::Error>
    where
        S: IntoSelector<Self::Key, Self::Domain>,
    {
        let selector = selector.into_selector();

        let mut qb = QueryBuilder::new("update api_keys set domains = ");
        qb.push_bind(sqlx::types::Json(domains));
        qb.push(" where ");
        build_predicate(&mut qb, &selector);
        qb.push(" returning *");

        qb.build_query_as()
            .fetch_optional(&self.pool)
            .await?
            .ok_or_else(|| PgStorageError::KeyNotFound(selector))
    }
}
618
// Integration tests. They require a live Postgres instance (provided by
// `#[sqlx::test]`) and, in `setup`, a valid key in the `APIKEY` environment
// variable. Each test recreates the `api_keys` table from scratch.
#[cfg(test)]
pub(crate) mod test {
    use std::sync::Arc;

    use sqlx::Row;

    use super::*;

    // Minimal domain type mirroring the shapes used by downstream crates.
    #[derive(Debug, PartialEq, Eq, Clone, serde::Serialize, serde::Deserialize)]
    #[serde(tag = "type", rename_all = "snake_case")]
    pub(crate) enum Domain {
        All,
        Guild { id: i64 },
        User { id: i32 },
        Faction { id: i32 },
    }

    impl KeyDomain for Domain {
        // Guild-scoped lookups fall back to the catch-all domain.
        fn fallback(&self) -> Option<Self> {
            match self {
                Self::Guild { id: _ } => Some(Self::All),
                _ => None,
            }
        }
    }

    // Drop and recreate the schema, then seed one key (from $APIKEY) with
    // the `All` domain. Returns the storage (limit 1000) and the seed key.
    pub(crate) async fn setup(pool: PgPool) -> (PgKeyPoolStorage<Domain>, PgKey<Domain>) {
        sqlx::query("DROP TABLE IF EXISTS api_keys")
            .execute(&pool)
            .await
            .unwrap();

        let storage = PgKeyPoolStorage::new(pool.clone(), 1000);
        storage.initialise().await.unwrap();

        let key = storage
            .store_key(1, std::env::var("APIKEY").unwrap(), vec![Domain::All])
            .await
            .unwrap();

        (storage, key)
    }

    // `initialise` must be idempotent.
    #[sqlx::test]
    async fn test_initialise(pool: PgPool) {
        let (storage, _) = setup(pool).await;

        if let Err(e) = storage.initialise().await {
            panic!("Initialising key storage failed: {:?}", e);
        }
    }

    // Re-storing an existing key merges the new domain into the old list.
    #[sqlx::test]
    async fn test_store_duplicate_key(pool: PgPool) {
        let (storage, key) = setup(pool).await;
        let key = storage
            .store_key(1, key.key, vec![Domain::User { id: 1 }])
            .await
            .unwrap();

        assert_eq!(key.domains.0.len(), 2);
    }

    // Re-storing with an already-present domain must not duplicate it.
    #[sqlx::test]
    async fn test_store_duplicate_key_duplicate_domain(pool: PgPool) {
        let (storage, key) = setup(pool).await;
        let key = storage
            .store_key(1, key.key, vec![Domain::All])
            .await
            .unwrap();

        assert_eq!(key.domains.0.len(), 1);
    }

    #[sqlx::test]
    async fn test_add_domain(pool: PgPool) {
        let (storage, key) = setup(pool).await;
        let key = storage
            .add_domain_to_key(KeySelector::Key(key.key), Domain::User { id: 12345 })
            .await
            .unwrap();

        assert!(key.domains.0.contains(&Domain::User { id: 12345 }));
    }

    #[sqlx::test]
    async fn test_add_domain_id(pool: PgPool) {
        let (storage, key) = setup(pool).await;
        let key = storage
            .add_domain_to_key(KeySelector::Id(key.id), Domain::User { id: 12345 })
            .await
            .unwrap();

        assert!(key.domains.0.contains(&Domain::User { id: 12345 }));
    }

    // Adding an existing domain must keep the list deduplicated.
    #[sqlx::test]
    async fn test_add_duplicate_domain(pool: PgPool) {
        let (storage, key) = setup(pool).await;
        let key = storage
            .add_domain_to_key(KeySelector::Key(key.key), Domain::All)
            .await
            .unwrap();
        assert_eq!(
            key.domains
                .0
                .into_iter()
                .filter(|d| *d == Domain::All)
                .count(),
            1
        );
    }

    #[sqlx::test]
    async fn test_remove_domain(pool: PgPool) {
        let (storage, key) = setup(pool).await;
        storage
            .add_domain_to_key(KeySelector::Key(key.key.clone()), Domain::User { id: 1 })
            .await
            .unwrap();
        let key = storage
            .remove_domain_from_key(KeySelector::Key(key.key.clone()), Domain::User { id: 1 })
            .await
            .unwrap();

        assert_eq!(key.domains.0, vec![Domain::All]);
    }

    #[sqlx::test]
    async fn test_remove_domain_id(pool: PgPool) {
        let (storage, key) = setup(pool).await;
        storage
            .add_domain_to_key(KeySelector::Id(key.id), Domain::User { id: 1 })
            .await
            .unwrap();
        let key = storage
            .remove_domain_from_key(KeySelector::Id(key.id), Domain::User { id: 1 })
            .await
            .unwrap();

        assert_eq!(key.domains.0, vec![Domain::All]);
    }

    // Removing the only domain leaves an empty (not NULL) domains array.
    #[sqlx::test]
    async fn test_remove_last_domain(pool: PgPool) {
        let (storage, key) = setup(pool).await;
        let key = storage
            .remove_domain_from_key(KeySelector::Key(key.key), Domain::All)
            .await
            .unwrap();

        assert!(key.domains.0.is_empty());
    }

    #[sqlx::test]
    async fn test_store_key(pool: PgPool) {
        let (storage, _) = setup(pool).await;
        let key = storage
            .store_key(1, "ABCDABCDABCDABCD".to_owned(), vec![])
            .await
            .unwrap();
        assert_eq!(key.value(), "ABCDABCDABCDABCD");
    }

    #[sqlx::test]
    async fn test_read_user_keys(pool: PgPool) {
        let (storage, _) = setup(pool).await;

        let keys = storage.read_keys(KeySelector::UserId(1)).await.unwrap();
        assert_eq!(keys.len(), 1);
    }

    #[sqlx::test]
    async fn acquire_one(pool: PgPool) {
        let (storage, _) = setup(pool).await;

        if let Err(e) = storage.acquire_key(Domain::All).await {
            panic!("Acquiring key failed: {:?}", e);
        }
    }

    // Ten sequential acquisitions over two keys must level out at 5 uses
    // each (least-used-first selection).
    #[sqlx::test]
    async fn uses_spread(pool: PgPool) {
        let (storage, _) = setup(pool).await;
        storage
            .store_key(1, "ABC".to_owned(), vec![Domain::All])
            .await
            .unwrap();

        for _ in 0..10 {
            _ = storage.acquire_key(Domain::All).await.unwrap();
        }

        let keys = storage.read_keys(KeySelector::UserId(1)).await.unwrap();
        assert_eq!(keys.len(), 2);
        for key in keys {
            assert_eq!(key.uses, 5);
        }
    }

    // A permanently flagged key must become unavailable for acquisition.
    #[sqlx::test]
    async fn test_flag_key_one(pool: PgPool) {
        let (storage, key) = setup(pool).await;

        assert!(storage.flag_key(key, 2).await.unwrap());

        match storage.acquire_key(Domain::All).await.unwrap_err() {
            PgStorageError::Unavailable(KeySelector::Has(domains)) => {
                assert_eq!(domains, vec![Domain::All])
            }
            why => panic!("Expected domain unavailable error but found '{why}'"),
        }
    }

    #[sqlx::test]
    async fn test_flag_key_many(pool: PgPool) {
        let (storage, key) = setup(pool).await;

        assert!(storage.flag_key(key, 2).await.unwrap());

        match storage.acquire_many_keys(Domain::All, 5).await.unwrap_err() {
            PgStorageError::Unavailable(KeySelector::Has(domains)) => {
                assert_eq!(domains, vec![Domain::All])
            }
            why => panic!("Expected domain unavailable error but found '{why}'"),
        }
    }

    // The returned vector contains one entry per granted use.
    #[sqlx::test]
    async fn acquire_many(pool: PgPool) {
        let (storage, _) = setup(pool).await;

        match storage.acquire_many_keys(Domain::All, 30).await {
            Err(e) => panic!("Acquiring key failed: {:?}", e),
            Ok(keys) => assert_eq!(keys.len(), 30),
        }
    }

    // 100 concurrent single acquisitions must count exactly 100 uses —
    // exercises the repeatable-read/40001-retry path.
    #[sqlx::test]
    async fn test_concurrent(pool: PgPool) {
        let storage = Arc::new(setup(pool).await.0);

        for _ in 0..10 {
            let mut set = tokio::task::JoinSet::new();

            for _ in 0..100 {
                let storage = storage.clone();
                set.spawn(async move {
                    storage.acquire_key(Domain::All).await.unwrap();
                });
            }

            for _ in 0..100 {
                set.join_next().await.unwrap().unwrap();
            }

            let uses: i16 = sqlx::query("select uses from api_keys")
                .fetch_one(&storage.pool)
                .await
                .unwrap()
                .get("uses");

            assert_eq!(uses, 100);

            sqlx::query("update api_keys set uses=0")
                .execute(&storage.pool)
                .await
                .unwrap();
        }
    }

    // 50 concurrent acquisitions over 25 keys must spread to 2 uses each.
    #[sqlx::test]
    async fn test_concurrent_spread(pool: PgPool) {
        let storage = Arc::new(setup(pool).await.0);

        for i in 0..24 {
            storage
                .store_key(1, format!("{}", i), vec![Domain::All])
                .await
                .unwrap();
        }

        for _ in 0..10 {
            let mut set = tokio::task::JoinSet::new();

            for _ in 0..50 {
                let storage = storage.clone();
                set.spawn(async move {
                    storage.acquire_key(Domain::All).await.unwrap();
                });
            }

            for _ in 0..50 {
                set.join_next().await.unwrap().unwrap();
            }

            let keys = storage.read_keys(KeySelector::UserId(1)).await.unwrap();

            assert_eq!(keys.len(), 25);

            for key in keys {
                assert_eq!(key.uses, 2);
            }

            sqlx::query("update api_keys set uses=0")
                .execute(&storage.pool)
                .await
                .unwrap();
        }
    }

    // 100 concurrent batch acquisitions of 5 must total exactly 500 uses.
    #[sqlx::test]
    async fn test_concurrent_many(pool: PgPool) {
        let storage = Arc::new(setup(pool).await.0);
        for _ in 0..10 {
            let mut set = tokio::task::JoinSet::new();

            for _ in 0..100 {
                let storage = storage.clone();
                set.spawn(async move {
                    storage.acquire_many_keys(Domain::All, 5).await.unwrap();
                });
            }

            for _ in 0..100 {
                set.join_next().await.unwrap().unwrap();
            }

            let uses: i16 = sqlx::query("select uses from api_keys")
                .fetch_one(&storage.pool)
                .await
                .unwrap()
                .get("uses");

            assert_eq!(uses, 500);

            sqlx::query("update api_keys set uses=0")
                .execute(&storage.pool)
                .await
                .unwrap();
        }
    }

    #[sqlx::test]
    async fn read_key(pool: PgPool) {
        let (storage, key) = setup(pool).await;

        let key = storage.read_key(KeySelector::Key(key.key)).await.unwrap();
        assert!(key.is_some());
    }

    #[sqlx::test]
    async fn read_key_id(pool: PgPool) {
        let (storage, key) = setup(pool).await;

        let key = storage.read_key(KeySelector::Id(key.id)).await.unwrap();
        assert!(key.is_some());
    }

    #[sqlx::test]
    async fn read_nonexistent_key(pool: PgPool) {
        let (storage, _) = setup(pool).await;

        let key = storage.read_key(KeySelector::Id(-1)).await.unwrap();
        assert!(key.is_none());
    }

    // A bare domain converts into a selector via `IntoSelector`.
    #[sqlx::test]
    async fn query_key(pool: PgPool) {
        let (storage, _) = setup(pool).await;

        let key = storage.read_key(Domain::All).await.unwrap();
        assert!(key.is_some());
    }

    // `read_key` does not apply domain fallback — Guild does not match All.
    #[sqlx::test]
    async fn query_nonexistent_key(pool: PgPool) {
        let (storage, _) = setup(pool).await;

        let key = storage.read_key(Domain::Guild { id: 0 }).await.unwrap();
        assert!(key.is_none());
    }

    #[sqlx::test]
    async fn query_all(pool: PgPool) {
        let (storage, _) = setup(pool).await;

        let keys = storage.read_keys(Domain::All).await.unwrap();
        assert!(keys.len() == 1);
    }

    #[sqlx::test]
    async fn query_by_id(pool: PgPool) {
        let (storage, _) = setup(pool).await;
        let key = storage.read_key(KeySelector::Id(1)).await.unwrap();

        assert!(key.is_some());
    }

    #[sqlx::test]
    async fn query_by_key(pool: PgPool) {
        let (storage, key) = setup(pool).await;
        let key = storage.read_key(KeySelector::Key(key.key)).await.unwrap();

        assert!(key.is_some());
    }

    // OneOf matches when any listed domain is present.
    #[sqlx::test]
    async fn query_by_set(pool: PgPool) {
        let (storage, _key) = setup(pool).await;
        let key = storage
            .read_key(KeySelector::OneOf(vec![
                Domain::All,
                Domain::Guild { id: 0 },
                Domain::Faction { id: 0 },
            ]))
            .await
            .unwrap();

        assert!(key.is_some());
    }

    // Has is set containment: order-insensitive, but every listed domain
    // must be present on the key.
    #[sqlx::test]
    async fn all_selector(pool: PgPool) {
        let (storage, key) = setup(pool).await;

        storage
            .add_domain_to_key(key.selector(), Domain::Faction { id: 1 })
            .await
            .unwrap();

        let key = storage
            .read_key(KeySelector::Has(vec![
                Domain::Faction { id: 1 },
                Domain::All,
            ]))
            .await
            .unwrap();

        assert!(key.is_some());

        let key = storage
            .read_key(KeySelector::Has(vec![
                Domain::All,
                Domain::Faction { id: 1 },
            ]))
            .await
            .unwrap();

        assert!(key.is_some());

        let key = storage
            .read_key(KeySelector::Has(vec![
                Domain::All,
                Domain::Faction { id: 2 },
                Domain::Faction { id: 1 },
            ]))
            .await
            .unwrap();

        assert!(key.is_none());
    }
}