torn_key_pool/
postgres.rs

1use futures::future::BoxFuture;
2use indoc::indoc;
3use sqlx::{FromRow, PgPool, Postgres, QueryBuilder};
4use thiserror::Error;
5
6use crate::{ApiKey, IntoSelector, KeyDomain, KeyPoolStorage, KeySelector};
7
/// Marker trait for key domains usable with the Postgres-backed key pool.
///
/// It declares no methods of its own; it only bundles [`KeyDomain`] with the
/// serde (`Serialize` + `DeserializeOwned`), `Eq` and `Unpin` bounds that the
/// code in this module places on a domain type.
pub trait PgKeyDomain:
    KeyDomain + serde::Serialize + serde::de::DeserializeOwned + Eq + Unpin
{
}
12
// Blanket implementation: every type that satisfies the supertrait bounds is
// automatically a `PgKeyDomain`, so no manual impls are needed.
impl<T> PgKeyDomain for T where
    T: KeyDomain + serde::Serialize + serde::de::DeserializeOwned + Eq + Unpin
{
}
17
18#[derive(Debug, Error)]
19pub enum PgKeyPoolError<D>
20where
21    D: PgKeyDomain,
22{
23    #[error("Databank: {0}")]
24    Pg(#[from] sqlx::Error),
25
26    #[error("Network: {0}")]
27    Network(#[from] reqwest::Error),
28
29    #[error("Parsing: {0}")]
30    Parsing(#[from] serde_json::Error),
31
32    #[error("Api: {0}")]
33    Api(#[from] torn_api::ApiError),
34
35    #[error("No key avalaible for domain {0:?}")]
36    Unavailable(KeySelector<PgKey<D>, D>),
37
38    #[error("Key not found: '{0:?}'")]
39    KeyNotFound(KeySelector<PgKey<D>, D>),
40}
41
/// A key as read from the `api_keys` table.
///
/// Field names match the column names selected by the queries in this
/// module, which is what the `FromRow` derive maps on.
#[derive(Debug, Clone, FromRow)]
pub struct PgKey<D>
where
    D: PgKeyDomain,
{
    /// Primary key of the row (`serial`).
    pub id: i32,
    /// Id of the user the key was stored for.
    pub user_id: i32,
    /// The API key string itself.
    pub key: String,
    /// Usage counter of the key; compared against the pool's `limit` when acquiring.
    pub uses: i16,
    /// Domains attached to the key, stored as a `jsonb` array.
    pub domains: sqlx::types::Json<Vec<D>>,
}
53
54#[inline(always)]
55fn build_predicate<'b, D>(
56    builder: &mut QueryBuilder<'b, Postgres>,
57    selector: &'b KeySelector<PgKey<D>, D>,
58) where
59    D: PgKeyDomain,
60{
61    match selector {
62        KeySelector::Id(id) => builder.push("id=").push_bind(id),
63        KeySelector::UserId(user_id) => builder.push("user_id=").push_bind(user_id),
64        KeySelector::Key(key) => builder.push("key=").push_bind(key),
65        KeySelector::Has(domains) => builder
66            .push("domains @> ")
67            .push_bind(sqlx::types::Json(domains)),
68        KeySelector::OneOf(domains) => {
69            if domains.is_empty() {
70                builder.push("false");
71                return;
72            }
73
74            for (idx, domain) in domains.iter().enumerate() {
75                if idx == 0 {
76                    builder.push("(");
77                } else {
78                    builder.push(" or ");
79                }
80                builder
81                    .push("domains @> ")
82                    .push_bind(sqlx::types::Json(vec![domain]));
83            }
84            builder.push(")")
85        }
86    };
87}
88
/// Key pool storage backed by a Postgres connection pool.
///
/// NOTE(review): the `FromRow` derive looks unintended for a storage handle
/// (nothing in this file decodes it from a row) — confirm before removing.
#[derive(Debug, Clone, FromRow)]
pub struct PgKeyPoolStorage<D>
where
    D: serde::Serialize + serde::de::DeserializeOwned + Send + Sync + 'static,
{
    /// Connection pool every statement of this storage runs on.
    pool: PgPool,
    /// Upper bound for a key's `uses` counter when acquiring keys.
    limit: i16,
    /// Ties the storage to its domain type without storing a value of it.
    _phantom: std::marker::PhantomData<D>,
}
98
99impl<D> ApiKey for PgKey<D>
100where
101    D: PgKeyDomain,
102{
103    type IdType = i32;
104
105    #[inline(always)]
106    fn value(&self) -> &str {
107        &self.key
108    }
109
110    #[inline(always)]
111    fn id(&self) -> Self::IdType {
112        self.id
113    }
114}
115
116impl<D> PgKeyPoolStorage<D>
117where
118    D: PgKeyDomain,
119{
120    pub fn new(pool: PgPool, limit: i16) -> Self {
121        Self {
122            pool,
123            limit,
124            _phantom: Default::default(),
125        }
126    }
127
128    pub async fn initialise(&self) -> Result<(), PgKeyPoolError<D>> {
129        sqlx::query(indoc! {r#"
130            CREATE TABLE IF NOT EXISTS api_keys (
131                id serial primary key,
132                user_id int4 not null,
133                key char(16) not null,
134                uses int2 not null default 0,
135                domains jsonb not null default '{}'::jsonb,
136                last_used timestamptz not null default now(),
137                flag int2,
138                cooldown timestamptz,
139                constraint "uq:api_keys.key" UNIQUE(key)
140            )"#
141        })
142        .execute(&self.pool)
143        .await?;
144
145        sqlx::query(indoc! {r#"
146            CREATE INDEX IF NOT EXISTS "idx:api_keys.domains" ON api_keys USING GIN(domains jsonb_path_ops)
147        "#})
148        .execute(&self.pool)
149        .await?;
150
151        sqlx::query(indoc! {r#"
152            CREATE INDEX IF NOT EXISTS "idx:api_keys.user_id" ON api_keys USING BTREE(user_id)
153        "#})
154        .execute(&self.pool)
155        .await?;
156
157        sqlx::query(indoc! {r#"
158            create or replace function __unique_jsonb_array(jsonb) returns jsonb
159                AS $$
160                    select jsonb_agg(d::jsonb) from (
161                        select distinct jsonb_array_elements_text($1) as d
162                    ) t
163                $$ language sql;
164        "#})
165        .execute(&self.pool)
166        .await?;
167
168        sqlx::query(indoc! {r#"
169            create or replace function __filter_jsonb_array(jsonb, jsonb) returns jsonb
170                AS $$
171                    select jsonb_agg(d::jsonb) from (
172                        select distinct jsonb_array_elements_text($1) as d
173                    ) t where d<>$2::text
174                $$ language sql;
175        "#})
176        .execute(&self.pool)
177        .await?;
178
179        Ok(())
180    }
181}
182
/// Sleeps for a random number of milliseconds drawn from `1..50`.
///
/// Used as a back-off between retries of a failed transaction.
#[cfg(feature = "tokio-runtime")]
async fn random_sleep() {
    use rand::{rng, Rng};
    let millis: u64 = rng().random_range(1..50);
    let pause = tokio::time::Duration::from_millis(millis);
    tokio::time::sleep(pause).await;
}
189
190impl<D> KeyPoolStorage for PgKeyPoolStorage<D>
191where
192    D: PgKeyDomain,
193{
194    type Key = PgKey<D>;
195    type Domain = D;
196
197    type Error = PgKeyPoolError<D>;
198
199    async fn acquire_key<S>(&self, selector: S) -> Result<Self::Key, Self::Error>
200    where
201        S: IntoSelector<Self::Key, Self::Domain>,
202    {
203        let selector = selector.into_selector();
204        loop {
205            let attempt = async {
206                let mut tx = self.pool.begin().await?;
207
208                sqlx::query("set transaction isolation level repeatable read")
209                    .execute(&mut *tx)
210                    .await?;
211
212                let mut qb = QueryBuilder::new(indoc::indoc! {
213                    r#"
214                    with key as (
215                        select
216                            id,
217                            0::int2 as uses
218                        from api_keys where last_used < date_trunc('minute', now())
219                            and (cooldown is null or now() >= cooldown)
220                            and "#
221                });
222
223                build_predicate(&mut qb, &selector);
224
225                qb.push(indoc::indoc! {
226                    "
227                    \n    union (
228                            select id, uses from api_keys
229                            where last_used >= date_trunc('minute', now())
230                                and (cooldown is null or now() >= cooldown)
231                                and "
232                });
233
234                build_predicate(&mut qb, &selector);
235
236                qb.push(indoc::indoc! {
237                    "
238                    \n        order by uses asc limit 1
239                        )
240                        order by uses asc limit 1
241                    )
242                    update api_keys set
243                        uses = key.uses + 1,
244                        cooldown = null,
245                        flag = null,
246                        last_used = now()
247                    from key where
248                        api_keys.id=key.id and key.uses < "
249                });
250
251                qb.push_bind(self.limit);
252
253                qb.push(indoc::indoc! { "
254                    \nreturning
255                        api_keys.id,
256                        api_keys.user_id,
257                        api_keys.key,
258                        api_keys.uses,
259                        api_keys.domains"
260                });
261
262                let key = qb.build_query_as().fetch_optional(&mut *tx).await?;
263
264                tx.commit().await?;
265
266                Result::<Option<Self::Key>, sqlx::Error>::Ok(key)
267            }
268            .await;
269
270            match attempt {
271                Ok(Some(result)) => return Ok(result),
272                Ok(None) => {
273                    fn recurse<D>(
274                        storage: &PgKeyPoolStorage<D>,
275                        selector: KeySelector<PgKey<D>, D>,
276                    ) -> BoxFuture<Result<PgKey<D>, PgKeyPoolError<D>>>
277                    where
278                        D: PgKeyDomain,
279                    {
280                        Box::pin(storage.acquire_key(selector))
281                    }
282
283                    return recurse(
284                        self,
285                        selector
286                            .fallback()
287                            .ok_or_else(|| PgKeyPoolError::Unavailable(selector))?,
288                    )
289                    .await;
290                }
291                Err(error) => {
292                    if let Some(db_error) = error.as_database_error() {
293                        let pg_error: &sqlx::postgres::PgDatabaseError = db_error.downcast_ref();
294                        if pg_error.code() == "40001" {
295                            random_sleep().await;
296                        } else {
297                            return Err(error.into());
298                        }
299                    } else {
300                        return Err(error.into());
301                    }
302                }
303            }
304        }
305    }
306
307    async fn acquire_many_keys<S>(
308        &self,
309        selector: S,
310        number: i64,
311    ) -> Result<Vec<Self::Key>, Self::Error>
312    where
313        S: IntoSelector<Self::Key, Self::Domain>,
314    {
315        let selector = selector.into_selector();
316        loop {
317            let attempt = async {
318                let mut tx = self.pool.begin().await?;
319
320                sqlx::query("set transaction isolation level repeatable read")
321                    .execute(&mut *tx)
322                    .await?;
323
324                let mut qb = QueryBuilder::new(indoc::indoc! {
325                    r#"select
326                        id,
327                        user_id,
328                        key,
329                        0::int2 as uses,
330                        domains
331                    from api_keys where last_used < date_trunc('minute', now())
332                        and (cooldown is null or now() >= cooldown)
333                        and "#
334                });
335                build_predicate(&mut qb, &selector);
336                qb.push(indoc::indoc! {
337                    "
338                    \nunion
339                    select
340                        id,
341                        user_id,
342                        key,
343                        uses,
344                        domains
345                    from api_keys where last_used >= date_trunc('minute', now())
346                        and (cooldown is null or now() >= cooldown)
347                        and "
348                });
349                build_predicate(&mut qb, &selector);
350                qb.push("\norder by uses limit ");
351                qb.push_bind(self.limit);
352
353                let mut keys: Vec<Self::Key> = qb.build_query_as().fetch_all(&mut *tx).await?;
354
355                if keys.is_empty() {
356                    tx.commit().await?;
357                    return Ok(None);
358                }
359
360                keys.sort_unstable_by(|k1, k2| k1.uses.cmp(&k2.uses));
361
362                let mut result = Vec::with_capacity(number as usize);
363                let (max, rest) = keys.split_last_mut().unwrap();
364                for key in rest {
365                    let available = max.uses - key.uses;
366                    let using = std::cmp::min(available, (number as i16) - (result.len() as i16));
367                    key.uses += using;
368                    result.extend(std::iter::repeat_n(key.clone(), using as usize));
369
370                    if result.len() == number as usize {
371                        break;
372                    }
373                }
374
375                while result.len() < (number as usize) {
376                    if keys[0].uses == self.limit {
377                        break;
378                    }
379
380                    let take = std::cmp::min(keys.len(), (number as usize) - result.len());
381                    let slice = &mut keys[0..take];
382                    slice.iter_mut().for_each(|k| k.uses += 1);
383                    result.extend_from_slice(slice);
384                }
385
386                sqlx::query(indoc! {r#"
387                    update api_keys set
388                        uses = tmp.uses,
389                        cooldown = null,
390                        flag = null,
391                        last_used = now()
392                    from (select unnest($1::int4[]) as id, unnest($2::int2[]) as uses) as tmp
393                    where api_keys.id = tmp.id
394                "#})
395                .bind(keys.iter().map(|k| k.id).collect::<Vec<_>>())
396                .bind(keys.iter().map(|k| k.uses).collect::<Vec<_>>())
397                .execute(&mut *tx)
398                .await?;
399
400                tx.commit().await?;
401
402                Result::<Option<Vec<Self::Key>>, sqlx::Error>::Ok(Some(result))
403            }
404            .await;
405
406            match attempt {
407                Ok(Some(result)) => return Ok(result),
408                Ok(None) => {
409                    fn recurse<D>(
410                        storage: &PgKeyPoolStorage<D>,
411                        selector: KeySelector<PgKey<D>, D>,
412                        number: i64,
413                    ) -> BoxFuture<Result<Vec<PgKey<D>>, PgKeyPoolError<D>>>
414                    where
415                        D: PgKeyDomain,
416                    {
417                        Box::pin(storage.acquire_many_keys(selector, number))
418                    }
419
420                    return recurse(
421                        self,
422                        selector
423                            .fallback()
424                            .ok_or_else(|| Self::Error::Unavailable(selector))?,
425                        number,
426                    )
427                    .await;
428                }
429                Err(error) => {
430                    if let Some(db_error) = error.as_database_error() {
431                        let pg_error: &sqlx::postgres::PgDatabaseError = db_error.downcast_ref();
432                        if pg_error.code() == "40001" {
433                            random_sleep().await;
434                        } else {
435                            return Err(error.into());
436                        }
437                    } else {
438                        return Err(error.into());
439                    }
440                }
441            }
442        }
443    }
444
445    async fn timeout_key<S>(
446        &self,
447        selector: S,
448        duration: std::time::Duration,
449    ) -> Result<(), Self::Error>
450    where
451        S: IntoSelector<Self::Key, Self::Domain>,
452    {
453        let selector = selector.into_selector();
454
455        let mut qb = QueryBuilder::new("update api_keys set cooldown=now() + ");
456        qb.push_bind(duration);
457        qb.push(" where ");
458        build_predicate(&mut qb, &selector);
459
460        qb.build().fetch_optional(&self.pool).await?;
461
462        Ok(())
463    }
464
465    async fn store_key(
466        &self,
467        user_id: i32,
468        key: String,
469        domains: Vec<D>,
470    ) -> Result<Self::Key, Self::Error> {
471        sqlx::query_as(
472            "insert into api_keys(user_id, key, domains) values ($1, $2, $3) on conflict on \
473             constraint \"uq:api_keys.key\" do update set domains = \
474             __unique_jsonb_array(excluded.domains || api_keys.domains) returning *",
475        )
476        .bind(user_id)
477        .bind(&key)
478        .bind(sqlx::types::Json(domains))
479        .fetch_one(&self.pool)
480        .await
481        .map_err(Into::into)
482    }
483
484    async fn read_key<S>(&self, selector: S) -> Result<Option<Self::Key>, Self::Error>
485    where
486        S: IntoSelector<Self::Key, Self::Domain>,
487    {
488        let selector = selector.into_selector();
489
490        let mut qb = QueryBuilder::new("select * from api_keys where ");
491        build_predicate(&mut qb, &selector);
492
493        qb.build_query_as()
494            .fetch_optional(&self.pool)
495            .await
496            .map_err(Into::into)
497    }
498
499    async fn read_keys<S>(&self, selector: S) -> Result<Vec<Self::Key>, Self::Error>
500    where
501        S: IntoSelector<Self::Key, Self::Domain>,
502    {
503        let selector = selector.into_selector();
504
505        let mut qb = QueryBuilder::new("select * from api_keys where ");
506        build_predicate(&mut qb, &selector);
507
508        qb.build_query_as()
509            .fetch_all(&self.pool)
510            .await
511            .map_err(Into::into)
512    }
513
514    async fn remove_key<S>(&self, selector: S) -> Result<Self::Key, Self::Error>
515    where
516        S: IntoSelector<Self::Key, Self::Domain>,
517    {
518        let selector = selector.into_selector();
519
520        let mut qb = QueryBuilder::new("delete from api_keys where ");
521        build_predicate(&mut qb, &selector);
522        qb.push(" returning *");
523
524        qb.build_query_as()
525            .fetch_optional(&self.pool)
526            .await?
527            .ok_or_else(|| PgKeyPoolError::KeyNotFound(selector))
528    }
529
530    async fn add_domain_to_key<S>(&self, selector: S, domain: D) -> Result<Self::Key, Self::Error>
531    where
532        S: IntoSelector<Self::Key, Self::Domain>,
533    {
534        let selector = selector.into_selector();
535
536        let mut qb = QueryBuilder::new(
537            "update api_keys set domains = __unique_jsonb_array(domains || jsonb_build_array(",
538        );
539        qb.push_bind(sqlx::types::Json(domain));
540        qb.push(")) where ");
541        build_predicate(&mut qb, &selector);
542        qb.push(" returning *");
543
544        qb.build_query_as()
545            .fetch_optional(&self.pool)
546            .await?
547            .ok_or_else(|| PgKeyPoolError::KeyNotFound(selector))
548    }
549
550    async fn remove_domain_from_key<S>(
551        &self,
552        selector: S,
553        domain: D,
554    ) -> Result<Self::Key, Self::Error>
555    where
556        S: IntoSelector<Self::Key, Self::Domain>,
557    {
558        let selector = selector.into_selector();
559
560        let mut qb = QueryBuilder::new(
561            "update api_keys set domains = coalesce(__filter_jsonb_array(domains, ",
562        );
563        qb.push_bind(sqlx::types::Json(domain));
564        qb.push("), '[]'::jsonb) where ");
565        build_predicate(&mut qb, &selector);
566        qb.push(" returning *");
567
568        qb.build_query_as()
569            .fetch_optional(&self.pool)
570            .await?
571            .ok_or_else(|| PgKeyPoolError::KeyNotFound(selector))
572    }
573
574    async fn set_domains_for_key<S>(
575        &self,
576        selector: S,
577        domains: Vec<D>,
578    ) -> Result<Self::Key, Self::Error>
579    where
580        S: IntoSelector<Self::Key, Self::Domain>,
581    {
582        let selector = selector.into_selector();
583
584        let mut qb = QueryBuilder::new("update api_keys set domains = ");
585        qb.push_bind(sqlx::types::Json(domains));
586        qb.push(" where ");
587        build_predicate(&mut qb, &selector);
588        qb.push(" returning *");
589
590        qb.build_query_as()
591            .fetch_optional(&self.pool)
592            .await?
593            .ok_or_else(|| PgKeyPoolError::KeyNotFound(selector))
594    }
595}
596
#[cfg(test)]
pub(crate) mod test {
    use std::{sync::Arc, time::Duration};

    use sqlx::Row;

    use super::*;

    /// Domain type used by the tests; serialised as an internally tagged
    /// object with snake_case variant names.
    #[derive(Debug, PartialEq, Eq, Clone, serde::Serialize, serde::Deserialize)]
    #[serde(tag = "type", rename_all = "snake_case")]
    pub(crate) enum Domain {
        All,
        Guild { id: i64 },
        User { id: i32 },
        Faction { id: i32 },
    }

    impl KeyDomain for Domain {
        /// Guild domains fall back to `All`; every other domain has no fallback.
        fn fallback(&self) -> Option<Self> {
            matches!(self, Self::Guild { .. }).then_some(Self::All)
        }
    }

    /// Recreates the `api_keys` table and stores the key from the `API_KEY`
    /// environment variable for user 1 with the `All` domain.
    pub(crate) async fn setup(pool: PgPool) -> (PgKeyPoolStorage<Domain>, PgKey<Domain>) {
        sqlx::query("DROP TABLE IF EXISTS api_keys")
            .execute(&pool)
            .await
            .unwrap();

        let storage = PgKeyPoolStorage::new(pool.clone(), 1000);
        storage.initialise().await.unwrap();

        let api_key = std::env::var("API_KEY").unwrap();
        let key = storage
            .store_key(1, api_key, vec![Domain::All])
            .await
            .unwrap();

        (storage, key)
    }

    /// Initialising an already initialised storage must succeed.
    #[sqlx::test]
    async fn test_initialise(pool: PgPool) {
        let (storage, _) = setup(pool).await;

        storage
            .initialise()
            .await
            .unwrap_or_else(|e| panic!("Initialising key storage failed: {:?}", e));
    }

    /// Storing an existing key with a new domain merges the domain lists.
    #[sqlx::test]
    async fn test_store_duplicate_key(pool: PgPool) {
        let (storage, key) = setup(pool).await;
        let stored = storage
            .store_key(1, key.key, vec![Domain::User { id: 1 }])
            .await
            .unwrap();

        assert_eq!(stored.domains.0.len(), 2);
    }

    /// Storing an existing key with an existing domain adds nothing.
    #[sqlx::test]
    async fn test_store_duplicate_key_duplicate_domain(pool: PgPool) {
        let (storage, key) = setup(pool).await;
        let stored = storage
            .store_key(1, key.key, vec![Domain::All])
            .await
            .unwrap();

        assert_eq!(stored.domains.0.len(), 1);
    }

    /// A domain can be added to a key selected by its key string.
    #[sqlx::test]
    async fn test_add_domain(pool: PgPool) {
        let (storage, key) = setup(pool).await;
        let added = Domain::User { id: 12345 };
        let updated = storage
            .add_domain_to_key(KeySelector::Key(key.key), Domain::User { id: 12345 })
            .await
            .unwrap();

        assert!(updated.domains.0.contains(&added));
    }

    /// A domain can be added to a key selected by its id.
    #[sqlx::test]
    async fn test_add_domain_id(pool: PgPool) {
        let (storage, key) = setup(pool).await;
        let added = Domain::User { id: 12345 };
        let updated = storage
            .add_domain_to_key(KeySelector::Id(key.id), Domain::User { id: 12345 })
            .await
            .unwrap();

        assert!(updated.domains.0.contains(&added));
    }

    /// Adding a domain the key already has does not duplicate it.
    #[sqlx::test]
    async fn test_add_duplicate_domain(pool: PgPool) {
        let (storage, key) = setup(pool).await;
        let updated = storage
            .add_domain_to_key(KeySelector::Key(key.key), Domain::All)
            .await
            .unwrap();

        let occurrences = updated
            .domains
            .0
            .iter()
            .filter(|d| **d == Domain::All)
            .count();
        assert_eq!(occurrences, 1);
    }

    /// A domain can be removed from a key selected by its key string.
    #[sqlx::test]
    async fn test_remove_domain(pool: PgPool) {
        let (storage, key) = setup(pool).await;
        storage
            .add_domain_to_key(KeySelector::Key(key.key.clone()), Domain::User { id: 1 })
            .await
            .unwrap();
        let updated = storage
            .remove_domain_from_key(KeySelector::Key(key.key.clone()), Domain::User { id: 1 })
            .await
            .unwrap();

        assert_eq!(updated.domains.0, vec![Domain::All]);
    }

    /// A domain can be removed from a key selected by its id.
    #[sqlx::test]
    async fn test_remove_domain_id(pool: PgPool) {
        let (storage, key) = setup(pool).await;
        storage
            .add_domain_to_key(KeySelector::Id(key.id), Domain::User { id: 1 })
            .await
            .unwrap();
        let updated = storage
            .remove_domain_from_key(KeySelector::Id(key.id), Domain::User { id: 1 })
            .await
            .unwrap();

        assert_eq!(updated.domains.0, vec![Domain::All]);
    }

    /// Removing the only domain leaves an empty domain list.
    #[sqlx::test]
    async fn test_remove_last_domain(pool: PgPool) {
        let (storage, key) = setup(pool).await;
        let updated = storage
            .remove_domain_from_key(KeySelector::Key(key.key), Domain::All)
            .await
            .unwrap();

        assert!(updated.domains.0.is_empty());
    }

    /// A stored key is returned with the key string it was stored with.
    #[sqlx::test]
    async fn test_store_key(pool: PgPool) {
        let (storage, _) = setup(pool).await;
        let stored = storage
            .store_key(1, "ABCDABCDABCDABCD".to_owned(), vec![])
            .await
            .unwrap();
        assert_eq!(stored.value(), "ABCDABCDABCDABCD");
    }

    /// Keys can be listed by user id.
    #[sqlx::test]
    async fn test_read_user_keys(pool: PgPool) {
        let (storage, _) = setup(pool).await;

        let user_keys = storage.read_keys(KeySelector::UserId(1)).await.unwrap();
        assert_eq!(user_keys.len(), 1);
    }

    /// A single key can be acquired by domain.
    #[sqlx::test]
    async fn acquire_one(pool: PgPool) {
        let (storage, _) = setup(pool).await;

        storage
            .acquire_key(Domain::All)
            .await
            .unwrap_or_else(|e| panic!("Acquiring key failed: {:?}", e));
    }

    /// Ten acquisitions over two keys end up as five uses per key.
    #[sqlx::test]
    async fn uses_spread(pool: PgPool) {
        let (storage, _) = setup(pool).await;
        storage
            .store_key(1, "ABC".to_owned(), vec![Domain::All])
            .await
            .unwrap();

        for _ in 0..10 {
            _ = storage.acquire_key(Domain::All).await.unwrap();
        }

        let user_keys = storage.read_keys(KeySelector::UserId(1)).await.unwrap();
        assert_eq!(user_keys.len(), 2);
        for key in user_keys {
            assert_eq!(key.uses, 5);
        }
    }

    /// Acquiring 30 uses at once returns 30 entries.
    #[sqlx::test]
    async fn acquire_many(pool: PgPool) {
        let (storage, _) = setup(pool).await;

        let acquired = storage
            .acquire_many_keys(Domain::All, 30)
            .await
            .unwrap_or_else(|e| panic!("Acquiring key failed: {:?}", e));
        assert_eq!(acquired.len(), 30);
    }

    // HACK: this test is time sensitive and will fail if runs at the top of the minute
    #[sqlx::test]
    async fn test_concurrent(pool: PgPool) {
        let storage = Arc::new(setup(pool).await.0);

        for _ in 0..10 {
            let mut tasks = tokio::task::JoinSet::new();

            for _ in 0..100 {
                let storage = Arc::clone(&storage);
                tasks.spawn(async move {
                    storage.acquire_key(Domain::All).await.unwrap();
                });
            }

            while let Some(outcome) = tasks.join_next().await {
                outcome.unwrap();
            }

            let row = sqlx::query("select uses from api_keys")
                .fetch_one(&storage.pool)
                .await
                .unwrap();
            let uses: i16 = row.get("uses");

            assert_eq!(uses, 100);

            sqlx::query("update api_keys set uses=0")
                .execute(&storage.pool)
                .await
                .unwrap();
        }
    }

    /// Fifty concurrent acquisitions over 25 keys end up as two uses per key.
    #[sqlx::test]
    async fn test_concurrent_spread(pool: PgPool) {
        let storage = Arc::new(setup(pool).await.0);

        for i in 0..24 {
            storage
                .store_key(1, i.to_string(), vec![Domain::All])
                .await
                .unwrap();
        }

        for _ in 0..10 {
            let mut tasks = tokio::task::JoinSet::new();

            for _ in 0..50 {
                let storage = Arc::clone(&storage);
                tasks.spawn(async move {
                    storage.acquire_key(Domain::All).await.unwrap();
                });
            }

            while let Some(outcome) = tasks.join_next().await {
                outcome.unwrap();
            }

            let user_keys = storage.read_keys(KeySelector::UserId(1)).await.unwrap();

            assert_eq!(user_keys.len(), 25);

            for key in user_keys {
                assert_eq!(key.uses, 2);
            }

            sqlx::query("update api_keys set uses=0")
                .execute(&storage.pool)
                .await
                .unwrap();
        }
    }

    // HACK: this test is time sensitive and will fail if runs at the top of the minute
    #[sqlx::test]
    async fn test_concurrent_many(pool: PgPool) {
        let storage = Arc::new(setup(pool).await.0);
        for _ in 0..10 {
            let mut tasks = tokio::task::JoinSet::new();

            for _ in 0..100 {
                let storage = Arc::clone(&storage);
                tasks.spawn(async move {
                    storage.acquire_many_keys(Domain::All, 5).await.unwrap();
                });
            }

            while let Some(outcome) = tasks.join_next().await {
                outcome.unwrap();
            }

            let row = sqlx::query("select uses from api_keys")
                .fetch_one(&storage.pool)
                .await
                .unwrap();
            let uses: i16 = row.get("uses");

            assert_eq!(uses, 500);

            sqlx::query("update api_keys set uses=0")
                .execute(&storage.pool)
                .await
                .unwrap();
        }
    }

    /// A key can be read back by its key string.
    #[sqlx::test]
    async fn read_key(pool: PgPool) {
        let (storage, key) = setup(pool).await;

        let found = storage.read_key(KeySelector::Key(key.key)).await.unwrap();
        assert!(found.is_some());
    }

    /// A key can be read back by its id.
    #[sqlx::test]
    async fn read_key_id(pool: PgPool) {
        let (storage, key) = setup(pool).await;

        let found = storage.read_key(KeySelector::Id(key.id)).await.unwrap();
        assert!(found.is_some());
    }

    /// Reading an id that does not exist yields `None`.
    #[sqlx::test]
    async fn read_nonexistent_key(pool: PgPool) {
        let (storage, _) = setup(pool).await;

        let found = storage.read_key(KeySelector::Id(-1)).await.unwrap();
        assert!(found.is_none());
    }

    /// A key can be looked up by a domain it has.
    #[sqlx::test]
    async fn query_key(pool: PgPool) {
        let (storage, _) = setup(pool).await;

        let found = storage.read_key(Domain::All).await.unwrap();
        assert!(found.is_some());
    }

    /// Looking up a domain no key has yields `None`.
    #[sqlx::test]
    async fn query_nonexistent_key(pool: PgPool) {
        let (storage, _) = setup(pool).await;

        let found = storage.read_key(Domain::Guild { id: 0 }).await.unwrap();
        assert!(found.is_none());
    }

    /// Listing by domain returns the single stored key.
    #[sqlx::test]
    async fn query_all(pool: PgPool) {
        let (storage, _) = setup(pool).await;

        let matching = storage.read_keys(Domain::All).await.unwrap();
        assert!(matching.len() == 1);
    }

    /// The key stored by `setup` is readable under id 1.
    #[sqlx::test]
    async fn query_by_id(pool: PgPool) {
        let (storage, _) = setup(pool).await;
        let found = storage.read_key(KeySelector::Id(1)).await.unwrap();

        assert!(found.is_some());
    }

    /// The key stored by `setup` is readable by its key string.
    #[sqlx::test]
    async fn query_by_key(pool: PgPool) {
        let (storage, key) = setup(pool).await;
        let found = storage.read_key(KeySelector::Key(key.key)).await.unwrap();

        assert!(found.is_some());
    }

    /// Putting a key on cooldown succeeds.
    #[sqlx::test]
    async fn timeout(pool: PgPool) {
        let (storage, key) = setup(pool).await;

        let cooldown = Duration::from_secs(60);
        storage
            .timeout_key(KeySelector::Id(key.id()), cooldown)
            .await
            .unwrap();
    }

    /// `OneOf` matches a key that has at least one of the listed domains.
    #[sqlx::test]
    async fn query_by_set(pool: PgPool) {
        let (storage, _key) = setup(pool).await;
        let alternatives = vec![
            Domain::All,
            Domain::Guild { id: 0 },
            Domain::Faction { id: 0 },
        ];
        let found = storage
            .read_key(KeySelector::OneOf(alternatives))
            .await
            .unwrap();

        assert!(found.is_some());
    }

    /// `Has` matches only keys that have every listed domain, in any order.
    #[sqlx::test]
    async fn all_selector(pool: PgPool) {
        let (storage, key) = setup(pool).await;

        storage
            .add_domain_to_key(key.selector(), Domain::Faction { id: 1 })
            .await
            .unwrap();

        let found = storage
            .read_key(KeySelector::Has(vec![
                Domain::Faction { id: 1 },
                Domain::All,
            ]))
            .await
            .unwrap();

        assert!(found.is_some());

        let found = storage
            .read_key(KeySelector::Has(vec![
                Domain::All,
                Domain::Faction { id: 1 },
            ]))
            .await
            .unwrap();

        assert!(found.is_some());

        // One of the required domains is missing from the key.
        let found = storage
            .read_key(KeySelector::Has(vec![
                Domain::All,
                Domain::Faction { id: 2 },
                Domain::Faction { id: 1 },
            ]))
            .await
            .unwrap();

        assert!(found.is_none());
    }
}