torn_key_pool/
postgres.rs

1use std::sync::Arc;
2
3use async_trait::async_trait;
4use indoc::indoc;
5use sqlx::{FromRow, PgPool, Postgres, QueryBuilder};
6use thiserror::Error;
7
8use crate::{ApiKey, IntoSelector, KeyDomain, KeyPoolStorage, KeySelector};
9
10pub trait PgKeyDomain:
11    KeyDomain + serde::Serialize + serde::de::DeserializeOwned + Eq + Unpin
12{
13}
14
15impl<T> PgKeyDomain for T where
16    T: KeyDomain + serde::Serialize + serde::de::DeserializeOwned + Eq + Unpin
17{
18}
19
20#[derive(Debug, Error, Clone)]
21pub enum PgStorageError<D>
22where
23    D: PgKeyDomain,
24{
25    #[error(transparent)]
26    Pg(Arc<sqlx::Error>),
27
28    #[error("No key avalaible for domain {0:?}")]
29    Unavailable(KeySelector<PgKey<D>, D>),
30
31    #[error("Key not found: '{0:?}'")]
32    KeyNotFound(KeySelector<PgKey<D>, D>),
33}
34
35impl<D> From<sqlx::Error> for PgStorageError<D>
36where
37    D: PgKeyDomain,
38{
39    fn from(value: sqlx::Error) -> Self {
40        Self::Pg(Arc::new(value))
41    }
42}
43
44#[derive(Debug, Clone, FromRow)]
45pub struct PgKey<D>
46where
47    D: PgKeyDomain,
48{
49    pub id: i32,
50    pub user_id: i32,
51    pub key: String,
52    pub uses: i16,
53    pub domains: sqlx::types::Json<Vec<D>>,
54}
55
56#[inline(always)]
57fn build_predicate<'b, D>(
58    builder: &mut QueryBuilder<'b, Postgres>,
59    selector: &'b KeySelector<PgKey<D>, D>,
60) where
61    D: PgKeyDomain,
62{
63    match selector {
64        KeySelector::Id(id) => builder.push("id=").push_bind(id),
65        KeySelector::UserId(user_id) => builder.push("user_id=").push_bind(user_id),
66        KeySelector::Key(key) => builder.push("key=").push_bind(key),
67        KeySelector::Has(domains) => builder
68            .push("domains @> ")
69            .push_bind(sqlx::types::Json(domains)),
70        KeySelector::OneOf(domains) => {
71            if domains.is_empty() {
72                builder.push("false");
73                return;
74            }
75
76            for (idx, domain) in domains.iter().enumerate() {
77                if idx == 0 {
78                    builder.push("(");
79                } else {
80                    builder.push(" or ");
81                }
82                builder
83                    .push("domains @> ")
84                    .push_bind(sqlx::types::Json(vec![domain]));
85            }
86            builder.push(")")
87        }
88    };
89}
90
91#[derive(Debug, Clone, FromRow)]
92pub struct PgKeyPoolStorage<D>
93where
94    D: serde::Serialize + serde::de::DeserializeOwned + Send + Sync + 'static,
95{
96    pool: PgPool,
97    limit: i16,
98    _phantom: std::marker::PhantomData<D>,
99}
100
101impl<D> ApiKey for PgKey<D>
102where
103    D: PgKeyDomain,
104{
105    type IdType = i32;
106
107    #[inline(always)]
108    fn value(&self) -> &str {
109        &self.key
110    }
111
112    #[inline(always)]
113    fn id(&self) -> Self::IdType {
114        self.id
115    }
116}
117
118impl<D> PgKeyPoolStorage<D>
119where
120    D: PgKeyDomain,
121{
122    pub fn new(pool: PgPool, limit: i16) -> Self {
123        Self {
124            pool,
125            limit,
126            _phantom: Default::default(),
127        }
128    }
129
130    pub async fn initialise(&self) -> Result<(), PgStorageError<D>> {
131        sqlx::query(indoc! {r#"
132            CREATE TABLE IF NOT EXISTS api_keys (
133                id serial primary key,
134                user_id int4 not null,
135                key char(16) not null,
136                uses int2 not null default 0,
137                domains jsonb not null default '{}'::jsonb,
138                last_used timestamptz not null default now(),
139                flag int2,
140                cooldown timestamptz,
141                constraint "uq:api_keys.key" UNIQUE(key)
142            )"#
143        })
144        .execute(&self.pool)
145        .await?;
146
147        sqlx::query(indoc! {r#"
148            CREATE INDEX IF NOT EXISTS "idx:api_keys.domains" ON api_keys USING GIN(domains jsonb_path_ops)
149        "#})
150        .execute(&self.pool)
151        .await?;
152
153        sqlx::query(indoc! {r#"
154            CREATE INDEX IF NOT EXISTS "idx:api_keys.user_id" ON api_keys USING BTREE(user_id)
155        "#})
156        .execute(&self.pool)
157        .await?;
158
159        sqlx::query(indoc! {r#"
160            create or replace function __unique_jsonb_array(jsonb) returns jsonb
161                AS $$
162                    select jsonb_agg(d::jsonb) from (
163                        select distinct jsonb_array_elements_text($1) as d
164                    ) t
165                $$ language sql;
166        "#})
167        .execute(&self.pool)
168        .await?;
169
170        sqlx::query(indoc! {r#"
171            create or replace function __filter_jsonb_array(jsonb, jsonb) returns jsonb
172                AS $$
173                    select jsonb_agg(d::jsonb) from (
174                        select distinct jsonb_array_elements_text($1) as d
175                    ) t where d<>$2::text
176                $$ language sql;
177        "#})
178        .execute(&self.pool)
179        .await?;
180
181        Ok(())
182    }
183}
184
/// Sleeps for a random 1–49 ms; used to spread out retries after a serialization failure.
#[cfg(feature = "tokio-runtime")]
async fn random_sleep() {
    use rand::Rng;
    let millis: u64 = rand::thread_rng().gen_range(1..50);
    tokio::time::sleep(tokio::time::Duration::from_millis(millis)).await;
}
191
/// Sleeps for a random 1–49 ms; used to spread out retries after a serialization failure.
#[cfg(all(not(feature = "tokio-runtime"), feature = "actix-runtime"))]
async fn random_sleep() {
    use rand::Rng;
    let millis: u64 = rand::thread_rng().gen_range(1..50);
    actix_rt::time::sleep(std::time::Duration::from_millis(millis)).await;
}
198
199#[async_trait]
200impl<D> KeyPoolStorage for PgKeyPoolStorage<D>
201where
202    D: PgKeyDomain,
203{
204    type Key = PgKey<D>;
205    type Domain = D;
206
207    type Error = PgStorageError<D>;
208
209    async fn acquire_key<S>(&self, selector: S) -> Result<Self::Key, Self::Error>
210    where
211        S: IntoSelector<Self::Key, Self::Domain>,
212    {
213        let selector = selector.into_selector();
214        loop {
215            let attempt = async {
216                let mut tx = self.pool.begin().await?;
217
218                sqlx::query("set transaction isolation level repeatable read")
219                    .execute(&mut *tx)
220                    .await?;
221
222                let mut qb = QueryBuilder::new(indoc::indoc! {
223                    r#"
224                    with key as (
225                        select 
226                            id,
227                            0::int2 as uses
228                        from api_keys where last_used < date_trunc('minute', now()) 
229                            and (cooldown is null or now() >= cooldown)
230                            and "#
231                });
232
233                build_predicate(&mut qb, &selector);
234
235                qb.push(indoc::indoc! {
236                    "
237                    \n    union (
238                            select id, uses from api_keys 
239                            where last_used >= date_trunc('minute', now()) 
240                                and (cooldown is null or now() >= cooldown) 
241                                and "
242                });
243
244                build_predicate(&mut qb, &selector);
245
246                qb.push(indoc::indoc! {
247                    "
248                    \n        order by uses asc limit 1
249                        )
250                        order by uses asc limit 1
251                    )
252                    update api_keys set
253                        uses = key.uses + 1,
254                        cooldown = null,
255                        flag = null,
256                        last_used = now()
257                    from key where 
258                        api_keys.id=key.id and key.uses < "
259                });
260
261                qb.push_bind(self.limit);
262
263                qb.push(indoc::indoc! { "
264                    \nreturning
265                        api_keys.id,
266                        api_keys.user_id,
267                        api_keys.key,
268                        api_keys.uses,
269                        api_keys.domains"
270                });
271
272                let key = qb.build_query_as().fetch_optional(&mut *tx).await?;
273
274                tx.commit().await?;
275
276                Result::<Option<Self::Key>, sqlx::Error>::Ok(key)
277            }
278            .await;
279
280            match attempt {
281                Ok(Some(result)) => return Ok(result),
282                Ok(None) => {
283                    return self
284                        .acquire_key(
285                            selector
286                                .fallback()
287                                .ok_or_else(|| PgStorageError::Unavailable(selector))?,
288                        )
289                        .await
290                }
291                Err(error) => {
292                    if let Some(db_error) = error.as_database_error() {
293                        let pg_error: &sqlx::postgres::PgDatabaseError = db_error.downcast_ref();
294                        if pg_error.code() == "40001" {
295                            random_sleep().await;
296                        } else {
297                            return Err(error.into());
298                        }
299                    } else {
300                        return Err(error.into());
301                    }
302                }
303            }
304        }
305    }
306
307    async fn acquire_many_keys<S>(
308        &self,
309        selector: S,
310        number: i64,
311    ) -> Result<Vec<Self::Key>, Self::Error>
312    where
313        S: IntoSelector<Self::Key, Self::Domain>,
314    {
315        let selector = selector.into_selector();
316        loop {
317            let attempt = async {
318                let mut tx = self.pool.begin().await?;
319
320                sqlx::query("set transaction isolation level repeatable read")
321                    .execute(&mut *tx)
322                    .await?;
323
324                let mut qb = QueryBuilder::new(indoc::indoc! {
325                    r#"select
326                        id,
327                        user_id,
328                        key,
329                        0::int2 as uses,
330                        domains
331                    from api_keys where last_used < date_trunc('minute', now())
332                        and (cooldown is null or now() >= cooldown)
333                        and "#
334                });
335                build_predicate(&mut qb, &selector);
336                qb.push(indoc::indoc! {
337                    "
338                    \nunion
339                    select
340                        id,
341                        user_id,
342                        key,
343                        uses,
344                        domains
345                    from api_keys where last_used >= date_trunc('minute', now())
346                        and (cooldown is null or now() >= cooldown)
347                        and "
348                });
349                build_predicate(&mut qb, &selector);
350                qb.push("\norder by uses limit ");
351                qb.push_bind(self.limit);
352
353                let mut keys: Vec<Self::Key> = qb.build_query_as().fetch_all(&mut *tx).await?;
354
355                if keys.is_empty() {
356                    tx.commit().await?;
357                    return Ok(None);
358                }
359
360                keys.sort_unstable_by(|k1, k2| k1.uses.cmp(&k2.uses));
361
362                let mut result = Vec::with_capacity(number as usize);
363                let (max, rest) = keys.split_last_mut().unwrap();
364                for key in rest {
365                    let available = max.uses - key.uses;
366                    let using = std::cmp::min(available, (number as i16) - (result.len() as i16));
367                    key.uses += using;
368                    result.extend(std::iter::repeat(key.clone()).take(using as usize));
369
370                    if result.len() == number as usize {
371                        break;
372                    }
373                }
374
375                while result.len() < (number as usize) {
376                    if keys[0].uses == self.limit {
377                        break;
378                    }
379
380                    let take = std::cmp::min(keys.len(), (number as usize) - result.len());
381                    let slice = &mut keys[0..take];
382                    slice.iter_mut().for_each(|k| k.uses += 1);
383                    result.extend_from_slice(slice);
384                }
385
386                sqlx::query(indoc! {r#"
387                    update api_keys set
388                        uses = tmp.uses,
389                        cooldown = null,
390                        flag = null,
391                        last_used = now()
392                    from (select unnest($1::int4[]) as id, unnest($2::int2[]) as uses) as tmp
393                    where api_keys.id = tmp.id
394                "#})
395                .bind(keys.iter().map(|k| k.id).collect::<Vec<_>>())
396                .bind(keys.iter().map(|k| k.uses).collect::<Vec<_>>())
397                .execute(&mut *tx)
398                .await?;
399
400                tx.commit().await?;
401
402                Result::<Option<Vec<Self::Key>>, sqlx::Error>::Ok(Some(result))
403            }
404            .await;
405
406            match attempt {
407                Ok(Some(result)) => return Ok(result),
408                Ok(None) => {
409                    return self
410                        .acquire_many_keys(
411                            selector
412                                .fallback()
413                                .ok_or_else(|| Self::Error::Unavailable(selector))?,
414                            number,
415                        )
416                        .await
417                }
418                Err(error) => {
419                    if let Some(db_error) = error.as_database_error() {
420                        let pg_error: &sqlx::postgres::PgDatabaseError = db_error.downcast_ref();
421                        if pg_error.code() == "40001" {
422                            random_sleep().await;
423                        } else {
424                            return Err(error.into());
425                        }
426                    } else {
427                        return Err(error.into());
428                    }
429                }
430            }
431        }
432    }
433
434    async fn flag_key(&self, key: Self::Key, code: u8) -> Result<bool, Self::Error> {
435        match code {
436            2 | 10 | 13 => {
437                // invalid key, owner fedded or owner inactive
438                sqlx::query(
439                    "update api_keys set cooldown='infinity'::timestamptz, flag=$1 where id=$2",
440                )
441                .bind(code as i16)
442                .bind(key.id)
443                .execute(&self.pool)
444                .await?;
445                Ok(true)
446            }
447            5 => {
448                // too many requests
449                sqlx::query(
450                    "update api_keys set cooldown=date_trunc('min', now()) + interval '1 min', \
451                     flag=5 where id=$1",
452                )
453                .bind(key.id)
454                .execute(&self.pool)
455                .await?;
456                Ok(true)
457            }
458            8 => {
459                // IP block
460                sqlx::query("update api_keys set cooldown=now() + interval '5 min', flag=8")
461                    .execute(&self.pool)
462                    .await?;
463                Ok(false)
464            }
465            9 => {
466                // API disabled
467                sqlx::query("update api_keys set cooldown=now() + interval '1 min', flag=9")
468                    .execute(&self.pool)
469                    .await?;
470                Ok(false)
471            }
472            14 => {
473                // daily read limit reached
474                sqlx::query(
475                    "update api_keys set cooldown=date_trunc('day', now()) + interval '1 day', \
476                     flag=14 where id=$1",
477                )
478                .bind(key.id)
479                .execute(&self.pool)
480                .await?;
481                Ok(true)
482            }
483            _ => Ok(false),
484        }
485    }
486
487    async fn store_key(
488        &self,
489        user_id: i32,
490        key: String,
491        domains: Vec<D>,
492    ) -> Result<Self::Key, Self::Error> {
493        sqlx::query_as(
494            "insert into api_keys(user_id, key, domains) values ($1, $2, $3) on conflict on \
495             constraint \"uq:api_keys.key\" do update set domains = \
496             __unique_jsonb_array(excluded.domains || api_keys.domains) returning *",
497        )
498        .bind(user_id)
499        .bind(&key)
500        .bind(sqlx::types::Json(domains))
501        .fetch_one(&self.pool)
502        .await
503        .map_err(Into::into)
504    }
505
506    async fn read_key<S>(&self, selector: S) -> Result<Option<Self::Key>, Self::Error>
507    where
508        S: IntoSelector<Self::Key, Self::Domain>,
509    {
510        let selector = selector.into_selector();
511
512        let mut qb = QueryBuilder::new("select * from api_keys where ");
513        build_predicate(&mut qb, &selector);
514
515        qb.build_query_as()
516            .fetch_optional(&self.pool)
517            .await
518            .map_err(Into::into)
519    }
520
521    async fn read_keys<S>(&self, selector: S) -> Result<Vec<Self::Key>, Self::Error>
522    where
523        S: IntoSelector<Self::Key, Self::Domain>,
524    {
525        let selector = selector.into_selector();
526
527        let mut qb = QueryBuilder::new("select * from api_keys where ");
528        build_predicate(&mut qb, &selector);
529
530        qb.build_query_as()
531            .fetch_all(&self.pool)
532            .await
533            .map_err(Into::into)
534    }
535
536    async fn remove_key<S>(&self, selector: S) -> Result<Self::Key, Self::Error>
537    where
538        S: IntoSelector<Self::Key, Self::Domain>,
539    {
540        let selector = selector.into_selector();
541
542        let mut qb = QueryBuilder::new("delete from api_keys where ");
543        build_predicate(&mut qb, &selector);
544        qb.push(" returning *");
545
546        qb.build_query_as()
547            .fetch_optional(&self.pool)
548            .await?
549            .ok_or_else(|| PgStorageError::KeyNotFound(selector))
550    }
551
552    async fn add_domain_to_key<S>(&self, selector: S, domain: D) -> Result<Self::Key, Self::Error>
553    where
554        S: IntoSelector<Self::Key, Self::Domain>,
555    {
556        let selector = selector.into_selector();
557
558        let mut qb = QueryBuilder::new(
559            "update api_keys set domains = __unique_jsonb_array(domains || jsonb_build_array(",
560        );
561        qb.push_bind(sqlx::types::Json(domain));
562        qb.push(")) where ");
563        build_predicate(&mut qb, &selector);
564        qb.push(" returning *");
565
566        qb.build_query_as()
567            .fetch_optional(&self.pool)
568            .await?
569            .ok_or_else(|| PgStorageError::KeyNotFound(selector))
570    }
571
572    async fn remove_domain_from_key<S>(
573        &self,
574        selector: S,
575        domain: D,
576    ) -> Result<Self::Key, Self::Error>
577    where
578        S: IntoSelector<Self::Key, Self::Domain>,
579    {
580        let selector = selector.into_selector();
581
582        let mut qb = QueryBuilder::new(
583            "update api_keys set domains = coalesce(__filter_jsonb_array(domains, ",
584        );
585        qb.push_bind(sqlx::types::Json(domain));
586        qb.push("), '[]'::jsonb) where ");
587        build_predicate(&mut qb, &selector);
588        qb.push(" returning *");
589
590        qb.build_query_as()
591            .fetch_optional(&self.pool)
592            .await?
593            .ok_or_else(|| PgStorageError::KeyNotFound(selector))
594    }
595
596    async fn set_domains_for_key<S>(
597        &self,
598        selector: S,
599        domains: Vec<D>,
600    ) -> Result<Self::Key, Self::Error>
601    where
602        S: IntoSelector<Self::Key, Self::Domain>,
603    {
604        let selector = selector.into_selector();
605
606        let mut qb = QueryBuilder::new("update api_keys set domains = ");
607        qb.push_bind(sqlx::types::Json(domains));
608        qb.push(" where ");
609        build_predicate(&mut qb, &selector);
610        qb.push(" returning *");
611
612        qb.build_query_as()
613            .fetch_optional(&self.pool)
614            .await?
615            .ok_or_else(|| PgStorageError::KeyNotFound(selector))
616    }
617}
618
// Integration tests; they need a Postgres database (via `sqlx::test`) and an `APIKEY`
// environment variable holding a real key.
#[cfg(test)]
pub(crate) mod test {
    use std::sync::Arc;

    use sqlx::Row;

    use super::*;

    #[derive(Debug, PartialEq, Eq, Clone, serde::Serialize, serde::Deserialize)]
    #[serde(tag = "type", rename_all = "snake_case")]
    pub(crate) enum Domain {
        All,
        Guild { id: i64 },
        User { id: i32 },
        Faction { id: i32 },
    }

    impl KeyDomain for Domain {
        /// Guild lookups fall back to the catch-all domain; nothing else has a fallback.
        fn fallback(&self) -> Option<Self> {
            if matches!(self, Self::Guild { .. }) {
                Some(Self::All)
            } else {
                None
            }
        }
    }

    /// Recreates the schema from scratch and stores the `APIKEY` key for user 1 with the
    /// `All` domain, returning the storage (limit 1000) and the stored key.
    pub(crate) async fn setup(pool: PgPool) -> (PgKeyPoolStorage<Domain>, PgKey<Domain>) {
        sqlx::query("DROP TABLE IF EXISTS api_keys")
            .execute(&pool)
            .await
            .unwrap();

        let storage = PgKeyPoolStorage::new(pool.clone(), 1000);
        storage.initialise().await.unwrap();

        let api_key = std::env::var("APIKEY").unwrap();
        let key = storage
            .store_key(1, api_key, vec![Domain::All])
            .await
            .unwrap();

        (storage, key)
    }

    /// Reads the `uses` column of the (single) stored key.
    async fn single_key_uses(pool: &PgPool) -> i16 {
        sqlx::query("select uses from api_keys")
            .fetch_one(pool)
            .await
            .unwrap()
            .get("uses")
    }

    /// Resets every key's use counter so the next round starts fresh.
    async fn reset_uses(pool: &PgPool) {
        sqlx::query("update api_keys set uses=0")
            .execute(pool)
            .await
            .unwrap();
    }

    #[sqlx::test]
    async fn test_initialise(pool: PgPool) {
        let (storage, _) = setup(pool).await;

        // `setup` already initialised once; a second run must succeed as well.
        storage
            .initialise()
            .await
            .unwrap_or_else(|e| panic!("Initialising key storage failed: {:?}", e));
    }

    #[sqlx::test]
    async fn test_store_duplicate_key(pool: PgPool) {
        let (storage, existing) = setup(pool).await;
        let merged = storage
            .store_key(1, existing.key, vec![Domain::User { id: 1 }])
            .await
            .unwrap();

        assert_eq!(merged.domains.0.len(), 2);
    }

    #[sqlx::test]
    async fn test_store_duplicate_key_duplicate_domain(pool: PgPool) {
        let (storage, existing) = setup(pool).await;
        let merged = storage
            .store_key(1, existing.key, vec![Domain::All])
            .await
            .unwrap();

        assert_eq!(merged.domains.0.len(), 1);
    }

    #[sqlx::test]
    async fn test_add_domain(pool: PgPool) {
        let (storage, existing) = setup(pool).await;
        let added = Domain::User { id: 12345 };
        let updated = storage
            .add_domain_to_key(KeySelector::Key(existing.key), added.clone())
            .await
            .unwrap();

        assert!(updated.domains.0.contains(&added));
    }

    #[sqlx::test]
    async fn test_add_domain_id(pool: PgPool) {
        let (storage, existing) = setup(pool).await;
        let added = Domain::User { id: 12345 };
        let updated = storage
            .add_domain_to_key(KeySelector::Id(existing.id), added.clone())
            .await
            .unwrap();

        assert!(updated.domains.0.contains(&added));
    }

    #[sqlx::test]
    async fn test_add_duplicate_domain(pool: PgPool) {
        let (storage, existing) = setup(pool).await;
        let updated = storage
            .add_domain_to_key(KeySelector::Key(existing.key), Domain::All)
            .await
            .unwrap();

        let occurrences = updated
            .domains
            .0
            .iter()
            .filter(|domain| **domain == Domain::All)
            .count();
        assert_eq!(occurrences, 1);
    }

    #[sqlx::test]
    async fn test_remove_domain(pool: PgPool) {
        let (storage, existing) = setup(pool).await;
        let extra = Domain::User { id: 1 };

        storage
            .add_domain_to_key(KeySelector::Key(existing.key.clone()), extra.clone())
            .await
            .unwrap();
        let updated = storage
            .remove_domain_from_key(KeySelector::Key(existing.key.clone()), extra)
            .await
            .unwrap();

        assert_eq!(updated.domains.0, vec![Domain::All]);
    }

    #[sqlx::test]
    async fn test_remove_domain_id(pool: PgPool) {
        let (storage, existing) = setup(pool).await;
        let extra = Domain::User { id: 1 };

        storage
            .add_domain_to_key(KeySelector::Id(existing.id), extra.clone())
            .await
            .unwrap();
        let updated = storage
            .remove_domain_from_key(KeySelector::Id(existing.id), extra)
            .await
            .unwrap();

        assert_eq!(updated.domains.0, vec![Domain::All]);
    }

    #[sqlx::test]
    async fn test_remove_last_domain(pool: PgPool) {
        let (storage, existing) = setup(pool).await;
        let updated = storage
            .remove_domain_from_key(KeySelector::Key(existing.key), Domain::All)
            .await
            .unwrap();

        assert!(updated.domains.0.is_empty());
    }

    #[sqlx::test]
    async fn test_store_key(pool: PgPool) {
        let (storage, _) = setup(pool).await;
        let stored = storage
            .store_key(1, "ABCDABCDABCDABCD".to_owned(), vec![])
            .await
            .unwrap();

        assert_eq!(stored.value(), "ABCDABCDABCDABCD");
    }

    #[sqlx::test]
    async fn test_read_user_keys(pool: PgPool) {
        let (storage, _) = setup(pool).await;

        let owned = storage.read_keys(KeySelector::UserId(1)).await.unwrap();
        assert_eq!(owned.len(), 1);
    }

    #[sqlx::test]
    async fn acquire_one(pool: PgPool) {
        let (storage, _) = setup(pool).await;

        let outcome = storage.acquire_key(Domain::All).await;
        if let Err(e) = outcome {
            panic!("Acquiring key failed: {:?}", e);
        }
    }

    #[sqlx::test]
    async fn uses_spread(pool: PgPool) {
        let (storage, _) = setup(pool).await;
        storage
            .store_key(1, "ABC".to_owned(), vec![Domain::All])
            .await
            .unwrap();

        for _ in 0..10 {
            storage.acquire_key(Domain::All).await.unwrap();
        }

        // Ten acquisitions over two keys must be split evenly.
        let owned = storage.read_keys(KeySelector::UserId(1)).await.unwrap();
        assert_eq!(owned.len(), 2);
        for key in &owned {
            assert_eq!(key.uses, 5);
        }
    }

    #[sqlx::test]
    async fn test_flag_key_one(pool: PgPool) {
        let (storage, key) = setup(pool).await;

        assert!(storage.flag_key(key, 2).await.unwrap());

        let failure = storage.acquire_key(Domain::All).await.unwrap_err();
        match failure {
            PgStorageError::Unavailable(KeySelector::Has(domains)) => {
                assert_eq!(domains, vec![Domain::All])
            }
            why => panic!("Expected domain unavailable error but found '{why}'"),
        }
    }

    #[sqlx::test]
    async fn test_flag_key_many(pool: PgPool) {
        let (storage, key) = setup(pool).await;

        assert!(storage.flag_key(key, 2).await.unwrap());

        let failure = storage.acquire_many_keys(Domain::All, 5).await.unwrap_err();
        match failure {
            PgStorageError::Unavailable(KeySelector::Has(domains)) => {
                assert_eq!(domains, vec![Domain::All])
            }
            why => panic!("Expected domain unavailable error but found '{why}'"),
        }
    }

    #[sqlx::test]
    async fn acquire_many(pool: PgPool) {
        let (storage, _) = setup(pool).await;

        let acquired = storage
            .acquire_many_keys(Domain::All, 30)
            .await
            .unwrap_or_else(|e| panic!("Acquiring key failed: {:?}", e));
        assert_eq!(acquired.len(), 30);
    }

    // HACK: this test is time sensitive and will fail if it runs across the top of a minute
    #[sqlx::test]
    async fn test_concurrent(pool: PgPool) {
        let storage = Arc::new(setup(pool).await.0);

        for _round in 0..10 {
            let mut tasks = tokio::task::JoinSet::new();
            for _ in 0..100 {
                let storage = Arc::clone(&storage);
                tasks.spawn(async move {
                    storage.acquire_key(Domain::All).await.unwrap();
                });
            }
            while let Some(joined) = tasks.join_next().await {
                joined.unwrap();
            }

            assert_eq!(single_key_uses(&storage.pool).await, 100);
            reset_uses(&storage.pool).await;
        }
    }

    #[sqlx::test]
    async fn test_concurrent_spread(pool: PgPool) {
        let storage = Arc::new(setup(pool).await.0);

        for i in 0..24 {
            storage
                .store_key(1, i.to_string(), vec![Domain::All])
                .await
                .unwrap();
        }

        for _round in 0..10 {
            let mut tasks = tokio::task::JoinSet::new();
            for _ in 0..50 {
                let storage = Arc::clone(&storage);
                tasks.spawn(async move {
                    storage.acquire_key(Domain::All).await.unwrap();
                });
            }
            while let Some(joined) = tasks.join_next().await {
                joined.unwrap();
            }

            // 50 acquisitions over 25 keys: exactly two uses each.
            let owned = storage.read_keys(KeySelector::UserId(1)).await.unwrap();
            assert_eq!(owned.len(), 25);
            for key in &owned {
                assert_eq!(key.uses, 2);
            }

            reset_uses(&storage.pool).await;
        }
    }

    // HACK: this test is time sensitive and will fail if it runs across the top of a minute
    #[sqlx::test]
    async fn test_concurrent_many(pool: PgPool) {
        let storage = Arc::new(setup(pool).await.0);

        for _round in 0..10 {
            let mut tasks = tokio::task::JoinSet::new();
            for _ in 0..100 {
                let storage = Arc::clone(&storage);
                tasks.spawn(async move {
                    storage.acquire_many_keys(Domain::All, 5).await.unwrap();
                });
            }
            while let Some(joined) = tasks.join_next().await {
                joined.unwrap();
            }

            assert_eq!(single_key_uses(&storage.pool).await, 500);
            reset_uses(&storage.pool).await;
        }
    }

    #[sqlx::test]
    async fn read_key(pool: PgPool) {
        let (storage, existing) = setup(pool).await;

        let found = storage.read_key(KeySelector::Key(existing.key)).await.unwrap();
        assert!(found.is_some());
    }

    #[sqlx::test]
    async fn read_key_id(pool: PgPool) {
        let (storage, existing) = setup(pool).await;

        let found = storage.read_key(KeySelector::Id(existing.id)).await.unwrap();
        assert!(found.is_some());
    }

    #[sqlx::test]
    async fn read_nonexistent_key(pool: PgPool) {
        let (storage, _) = setup(pool).await;

        let found = storage.read_key(KeySelector::Id(-1)).await.unwrap();
        assert!(found.is_none());
    }

    #[sqlx::test]
    async fn query_key(pool: PgPool) {
        let (storage, _) = setup(pool).await;

        let found = storage.read_key(Domain::All).await.unwrap();
        assert!(found.is_some());
    }

    #[sqlx::test]
    async fn query_nonexistent_key(pool: PgPool) {
        let (storage, _) = setup(pool).await;

        let found = storage.read_key(Domain::Guild { id: 0 }).await.unwrap();
        assert!(found.is_none());
    }

    #[sqlx::test]
    async fn query_all(pool: PgPool) {
        let (storage, _) = setup(pool).await;

        let found = storage.read_keys(Domain::All).await.unwrap();
        assert_eq!(found.len(), 1);
    }

    #[sqlx::test]
    async fn query_by_id(pool: PgPool) {
        let (storage, _) = setup(pool).await;

        let found = storage.read_key(KeySelector::Id(1)).await.unwrap();
        assert!(found.is_some());
    }

    #[sqlx::test]
    async fn query_by_key(pool: PgPool) {
        let (storage, existing) = setup(pool).await;

        let found = storage.read_key(KeySelector::Key(existing.key)).await.unwrap();
        assert!(found.is_some());
    }

    #[sqlx::test]
    async fn query_by_set(pool: PgPool) {
        let (storage, _key) = setup(pool).await;

        let alternatives = vec![
            Domain::All,
            Domain::Guild { id: 0 },
            Domain::Faction { id: 0 },
        ];
        let found = storage
            .read_key(KeySelector::OneOf(alternatives))
            .await
            .unwrap();

        assert!(found.is_some());
    }

    #[sqlx::test]
    async fn all_selector(pool: PgPool) {
        let (storage, key) = setup(pool).await;

        storage
            .add_domain_to_key(key.selector(), Domain::Faction { id: 1 })
            .await
            .unwrap();

        // `Has` must match regardless of element order, and fail if any domain is missing.
        let cases = vec![
            (vec![Domain::Faction { id: 1 }, Domain::All], true),
            (vec![Domain::All, Domain::Faction { id: 1 }], true),
            (
                vec![
                    Domain::All,
                    Domain::Faction { id: 2 },
                    Domain::Faction { id: 1 },
                ],
                false,
            ),
        ];

        for (domains, should_match) in cases {
            let found = storage
                .read_key(KeySelector::Has(domains))
                .await
                .unwrap();
            assert_eq!(found.is_some(), should_match);
        }
    }
}