torn_key_pool/
postgres.rs

1use std::sync::Arc;
2
3use futures::future::BoxFuture;
4use indoc::formatdoc;
5use sqlx::{FromRow, PgPool, Postgres, QueryBuilder};
6use thiserror::Error;
7
8use crate::{ApiKey, IntoSelector, KeyDomain, KeyPoolStorage, KeySelector};
9
10pub trait PgKeyDomain:
11    KeyDomain + serde::Serialize + serde::de::DeserializeOwned + Eq + Unpin
12{
13}
14
15impl<T> PgKeyDomain for T where
16    T: KeyDomain + serde::Serialize + serde::de::DeserializeOwned + Eq + Unpin
17{
18}
19
20#[derive(Debug, Error)]
21pub enum PgKeyPoolError<D>
22where
23    D: PgKeyDomain,
24{
25    #[error("Databank: {0}")]
26    Pg(#[from] sqlx::Error),
27
28    #[error("Network: {0}")]
29    Network(#[from] reqwest::Error),
30
31    #[error("Parsing: {0}")]
32    Parsing(#[from] serde_json::Error),
33
34    #[error("Api: {0}")]
35    Api(#[from] torn_api::ApiError),
36
37    #[error("No key avalaible for domain {0:?}")]
38    Unavailable(KeySelector<PgKey<D>, D>),
39
40    #[error("Key not found: '{0:?}'")]
41    KeyNotFound(KeySelector<PgKey<D>, D>),
42
43    #[error("Failed to acquire keys in bulk: {0}")]
44    Bulk(#[from] Arc<Self>),
45}
46
47#[derive(Debug, Clone, FromRow)]
48pub struct PgKey<D>
49where
50    D: PgKeyDomain,
51{
52    pub id: i32,
53    pub user_id: i32,
54    pub key: String,
55    pub uses: i16,
56    pub domains: sqlx::types::Json<Vec<D>>,
57}
58
59#[inline(always)]
60fn build_predicate<'b, D>(
61    builder: &mut QueryBuilder<'b, Postgres>,
62    selector: &'b KeySelector<PgKey<D>, D>,
63) where
64    D: PgKeyDomain,
65{
66    match selector {
67        KeySelector::Id(id) => builder.push("id=").push_bind(id),
68        KeySelector::UserId(user_id) => builder.push("user_id=").push_bind(user_id),
69        KeySelector::Key(key) => builder.push("key=").push_bind(key),
70        KeySelector::Has(domains) => builder
71            .push("domains @> ")
72            .push_bind(sqlx::types::Json(domains)),
73        KeySelector::OneOf(domains) => {
74            if domains.is_empty() {
75                builder.push("false");
76                return;
77            }
78
79            for (idx, domain) in domains.iter().enumerate() {
80                if idx == 0 {
81                    builder.push("(");
82                } else {
83                    builder.push(" or ");
84                }
85                builder
86                    .push("domains @> ")
87                    .push_bind(sqlx::types::Json(vec![domain]));
88            }
89            builder.push(")")
90        }
91    };
92}
93
94#[derive(Debug, Clone, FromRow)]
95pub struct PgKeyPoolStorage<D>
96where
97    D: serde::Serialize + serde::de::DeserializeOwned + Send + Sync + 'static,
98{
99    pool: PgPool,
100    limit: i16,
101    schema: Option<String>,
102    _phantom: std::marker::PhantomData<D>,
103}
104
105impl<D> ApiKey for PgKey<D>
106where
107    D: PgKeyDomain,
108{
109    type IdType = i32;
110
111    #[inline(always)]
112    fn value(&self) -> &str {
113        &self.key
114    }
115
116    #[inline(always)]
117    fn id(&self) -> Self::IdType {
118        self.id
119    }
120}
121
122impl<D> PgKeyPoolStorage<D>
123where
124    D: PgKeyDomain,
125{
126    pub fn new(pool: PgPool, limit: i16, schema: Option<String>) -> Self {
127        Self {
128            pool,
129            limit,
130            schema,
131            _phantom: Default::default(),
132        }
133    }
134
135    fn table_name(&self) -> String {
136        match self.schema.as_ref() {
137            Some(schema) => format!("{schema}.api_keys"),
138            None => "api_keys".to_owned(),
139        }
140    }
141
142    fn unique_array_fn(&self) -> String {
143        match self.schema.as_ref() {
144            Some(schema) => format!("{schema}.__unique_jsonb_array"),
145            None => "__unique_jsonb_array".to_owned(),
146        }
147    }
148
149    fn filter_array_fn(&self) -> String {
150        match self.schema.as_ref() {
151            Some(schema) => format!("{schema}.__filter_jsonb_array"),
152            None => "__filter_jsonb_array".to_owned(),
153        }
154    }
155
156    pub async fn initialise(&self) -> Result<(), PgKeyPoolError<D>> {
157        if let Some(schema) = self.schema.as_ref() {
158            sqlx::query(&format!("create schema if not exists {schema}"))
159                .execute(&self.pool)
160                .await?;
161        }
162
163        sqlx::query(&formatdoc! {r#"
164            CREATE TABLE IF NOT EXISTS {} (
165                id serial primary key,
166                user_id int4 not null,
167                key char(16) not null,
168                uses int2 not null default 0,
169                domains jsonb not null default '{{}}'::jsonb,
170                last_used timestamptz not null default now(),
171                flag int2,
172                cooldown timestamptz,
173                constraint "uq:api_keys.key" UNIQUE(key)
174            )"#,
175            self.table_name()
176        })
177        .execute(&self.pool)
178        .await?;
179
180        sqlx::query(&formatdoc! {r#"
181            CREATE INDEX IF NOT EXISTS "idx:api_keys.domains" ON {} USING GIN(domains jsonb_path_ops)
182        "#, self.table_name()})
183        .execute(&self.pool)
184        .await?;
185
186        sqlx::query(&formatdoc! {r#"
187            CREATE INDEX IF NOT EXISTS "idx:api_keys.user_id" ON {} USING BTREE(user_id)
188        "#, self.table_name()})
189        .execute(&self.pool)
190        .await?;
191
192        sqlx::query(&formatdoc! {r#"
193            create or replace function {}(jsonb) returns jsonb
194                AS $$
195                    select jsonb_agg(d::jsonb) from (
196                        select distinct jsonb_array_elements_text($1) as d
197                    ) t
198                $$ language sql;
199        "#, self.unique_array_fn()})
200        .execute(&self.pool)
201        .await?;
202
203        sqlx::query(&formatdoc! {r#"
204            create or replace function {}(jsonb, jsonb) returns jsonb
205                AS $$
206                    select jsonb_agg(d::jsonb) from (
207                        select distinct jsonb_array_elements_text($1) as d
208                    ) t where d<>$2::text
209                $$ language sql;
210        "#, self.filter_array_fn()})
211        .execute(&self.pool)
212        .await?;
213
214        Ok(())
215    }
216}
217
/// Sleeps for a random 1–49 ms so transactions that just hit a serialization
/// failure do not all retry at the same instant.
#[cfg(feature = "tokio-runtime")]
async fn random_sleep() {
    use rand::{rng, Rng};
    let dur = tokio::time::Duration::from_millis(rng().random_range(1..50));
    tokio::time::sleep(dur).await;
}

/// Fallback used when the `tokio-runtime` feature is disabled: retry at once.
///
/// The acquire methods call `random_sleep` without any `cfg` gate. Without
/// this definition, the module would fail to compile if it were ever built
/// without `tokio-runtime`.
#[cfg(not(feature = "tokio-runtime"))]
async fn random_sleep() {}
224
225impl<D> KeyPoolStorage for PgKeyPoolStorage<D>
226where
227    D: PgKeyDomain,
228{
229    type Key = PgKey<D>;
230    type Domain = D;
231
232    type Error = PgKeyPoolError<D>;
233
234    async fn acquire_key<S>(&self, selector: S) -> Result<Self::Key, Self::Error>
235    where
236        S: IntoSelector<Self::Key, Self::Domain>,
237    {
238        let selector = selector.into_selector();
239        loop {
240            let attempt = async {
241                let mut tx = self.pool.begin().await?;
242
243                sqlx::query("set transaction isolation level repeatable read")
244                    .execute(&mut *tx)
245                    .await?;
246
247                let mut qb = QueryBuilder::new(&formatdoc! {
248                    r#"
249                    with key as (
250                        select
251                            id,
252                            0::int2 as uses
253                        from {} where last_used < date_trunc('minute', now())
254                            and (cooldown is null or now() >= cooldown)
255                            and "#,
256                    self.table_name()
257                });
258
259                build_predicate(&mut qb, &selector);
260
261                qb.push(formatdoc! {
262                    "
263                    \n    union (
264                            select id, uses from {}
265                            where last_used >= date_trunc('minute', now())
266                                and (cooldown is null or now() >= cooldown)
267                                and ",
268                    self.table_name()
269                });
270
271                build_predicate(&mut qb, &selector);
272
273                qb.push(formatdoc! {
274                    "
275                    \n        order by uses asc limit 1
276                        )
277                        order by uses asc limit 1
278                    )
279                    update {} as keys set
280                        uses = key.uses + 1,
281                        cooldown = null,
282                        flag = null,
283                        last_used = now()
284                    from key where
285                        keys.id=key.id and key.uses < ",
286                    self.table_name()
287                });
288
289                qb.push_bind(self.limit);
290
291                qb.push(indoc::indoc! { "
292                    \nreturning keys.id, keys.user_id, keys.key, keys.uses, keys.domains"
293                });
294
295                let key = qb.build_query_as().fetch_optional(&mut *tx).await?;
296
297                tx.commit().await?;
298
299                Result::<Option<Self::Key>, sqlx::Error>::Ok(key)
300            }
301            .await;
302
303            match attempt {
304                Ok(Some(result)) => return Ok(result),
305                Ok(None) => {
306                    fn recurse<D>(
307                        storage: &PgKeyPoolStorage<D>,
308                        selector: KeySelector<PgKey<D>, D>,
309                    ) -> BoxFuture<'_, Result<PgKey<D>, PgKeyPoolError<D>>>
310                    where
311                        D: PgKeyDomain,
312                    {
313                        Box::pin(storage.acquire_key(selector))
314                    }
315
316                    return recurse(
317                        self,
318                        selector
319                            .fallback()
320                            .ok_or_else(|| PgKeyPoolError::Unavailable(selector))?,
321                    )
322                    .await;
323                }
324                Err(error) => {
325                    if let Some(db_error) = error.as_database_error() {
326                        let pg_error: &sqlx::postgres::PgDatabaseError = db_error.downcast_ref();
327                        if pg_error.code() == "40001" {
328                            random_sleep().await;
329                        } else {
330                            return Err(error.into());
331                        }
332                    } else {
333                        return Err(error.into());
334                    }
335                }
336            }
337        }
338    }
339
340    async fn acquire_many_keys<S>(
341        &self,
342        selector: S,
343        number: i64,
344    ) -> Result<Vec<Self::Key>, Self::Error>
345    where
346        S: IntoSelector<Self::Key, Self::Domain>,
347    {
348        let selector = selector.into_selector();
349        loop {
350            let attempt = async {
351                let mut tx = self.pool.begin().await?;
352
353                sqlx::query("set transaction isolation level repeatable read")
354                    .execute(&mut *tx)
355                    .await?;
356
357                let mut qb = QueryBuilder::new(&formatdoc! {
358                    r#"select
359                        id,
360                        user_id,
361                        key,
362                        0::int2 as uses,
363                        domains
364                    from {} where last_used < date_trunc('minute', now())
365                        and (cooldown is null or now() >= cooldown)
366                        and "#,
367                    self.table_name()
368                });
369                build_predicate(&mut qb, &selector);
370                qb.push(formatdoc! {
371                    "
372                    \nunion
373                    select
374                        id,
375                        user_id,
376                        key,
377                        uses,
378                        domains
379                    from {} where last_used >= date_trunc('minute', now())
380                        and (cooldown is null or now() >= cooldown)
381                        and ",
382                    self.table_name()
383                });
384                build_predicate(&mut qb, &selector);
385                qb.push("\norder by uses limit ");
386                qb.push_bind(self.limit);
387
388                let mut keys: Vec<Self::Key> = qb.build_query_as().fetch_all(&mut *tx).await?;
389
390                if keys.is_empty() {
391                    tx.commit().await?;
392                    return Ok(None);
393                }
394
395                keys.sort_unstable_by(|k1, k2| k1.uses.cmp(&k2.uses));
396
397                let mut result = Vec::with_capacity(number as usize);
398                let (max, rest) = keys.split_last_mut().unwrap();
399                for key in rest {
400                    let available = max.uses - key.uses;
401                    let using = std::cmp::min(available, (number as i16) - (result.len() as i16));
402                    key.uses += using;
403                    result.extend(std::iter::repeat_n(key.clone(), using as usize));
404
405                    if result.len() == number as usize {
406                        break;
407                    }
408                }
409
410                while result.len() < (number as usize) {
411                    if keys[0].uses == self.limit {
412                        break;
413                    }
414
415                    let take = std::cmp::min(keys.len(), (number as usize) - result.len());
416                    let slice = &mut keys[0..take];
417                    slice.iter_mut().for_each(|k| k.uses += 1);
418                    result.extend_from_slice(slice);
419                }
420
421                sqlx::query(&formatdoc! {r#"
422                    update {} keys set
423                        uses = tmp.uses,
424                        cooldown = null,
425                        flag = null,
426                        last_used = now()
427                    from (select unnest($1::int4[]) as id, unnest($2::int2[]) as uses) as tmp
428                    where keys.id = tmp.id
429                "#, self.table_name()})
430                .bind(keys.iter().map(|k| k.id).collect::<Vec<_>>())
431                .bind(keys.iter().map(|k| k.uses).collect::<Vec<_>>())
432                .execute(&mut *tx)
433                .await?;
434
435                tx.commit().await?;
436
437                Result::<Option<Vec<Self::Key>>, sqlx::Error>::Ok(Some(result))
438            }
439            .await;
440
441            match attempt {
442                Ok(Some(result)) => return Ok(result),
443                Ok(None) => {
444                    fn recurse<D>(
445                        storage: &PgKeyPoolStorage<D>,
446                        selector: KeySelector<PgKey<D>, D>,
447                        number: i64,
448                    ) -> BoxFuture<'_, Result<Vec<PgKey<D>>, PgKeyPoolError<D>>>
449                    where
450                        D: PgKeyDomain,
451                    {
452                        Box::pin(storage.acquire_many_keys(selector, number))
453                    }
454
455                    return recurse(
456                        self,
457                        selector
458                            .fallback()
459                            .ok_or_else(|| Self::Error::Unavailable(selector))?,
460                        number,
461                    )
462                    .await;
463                }
464                Err(error) => {
465                    if let Some(db_error) = error.as_database_error() {
466                        let pg_error: &sqlx::postgres::PgDatabaseError = db_error.downcast_ref();
467                        if pg_error.code() == "40001" {
468                            random_sleep().await;
469                        } else {
470                            return Err(error.into());
471                        }
472                    } else {
473                        return Err(error.into());
474                    }
475                }
476            }
477        }
478    }
479
480    async fn timeout_key<S>(
481        &self,
482        selector: S,
483        duration: std::time::Duration,
484    ) -> Result<(), Self::Error>
485    where
486        S: IntoSelector<Self::Key, Self::Domain>,
487    {
488        let selector = selector.into_selector();
489
490        let mut qb = QueryBuilder::new(format!(
491            "update {} set cooldown=now() + ",
492            self.table_name()
493        ));
494        qb.push_bind(duration);
495        qb.push(" where ");
496        build_predicate(&mut qb, &selector);
497
498        qb.build().fetch_optional(&self.pool).await?;
499
500        Ok(())
501    }
502
503    async fn store_key(
504        &self,
505        user_id: i32,
506        key: String,
507        domains: Vec<D>,
508    ) -> Result<Self::Key, Self::Error> {
509        sqlx::query_as(&dbg!(formatdoc!(
510            "insert into {} as api_keys(user_id, key, domains) values ($1, $2, $3) 
511            on conflict(key) do update 
512            set domains = {}(excluded.domains || api_keys.domains) returning *",
513            self.table_name(),
514            self.unique_array_fn()
515        )))
516        .bind(user_id)
517        .bind(&key)
518        .bind(sqlx::types::Json(domains))
519        .fetch_one(&self.pool)
520        .await
521        .map_err(Into::into)
522    }
523
524    async fn read_key<S>(&self, selector: S) -> Result<Option<Self::Key>, Self::Error>
525    where
526        S: IntoSelector<Self::Key, Self::Domain>,
527    {
528        let selector = selector.into_selector();
529
530        let mut qb = QueryBuilder::new(format!("select * from {} where ", self.table_name()));
531        build_predicate(&mut qb, &selector);
532
533        qb.build_query_as()
534            .fetch_optional(&self.pool)
535            .await
536            .map_err(Into::into)
537    }
538
539    async fn read_keys<S>(&self, selector: S) -> Result<Vec<Self::Key>, Self::Error>
540    where
541        S: IntoSelector<Self::Key, Self::Domain>,
542    {
543        let selector = selector.into_selector();
544
545        let mut qb = QueryBuilder::new(format!("select * from {} where ", self.table_name()));
546        build_predicate(&mut qb, &selector);
547
548        qb.build_query_as()
549            .fetch_all(&self.pool)
550            .await
551            .map_err(Into::into)
552    }
553
554    async fn remove_key<S>(&self, selector: S) -> Result<Self::Key, Self::Error>
555    where
556        S: IntoSelector<Self::Key, Self::Domain>,
557    {
558        let selector = selector.into_selector();
559
560        let mut qb = QueryBuilder::new(format!("delete from {} where ", self.table_name()));
561        build_predicate(&mut qb, &selector);
562        qb.push(" returning *");
563
564        qb.build_query_as()
565            .fetch_optional(&self.pool)
566            .await?
567            .ok_or_else(|| PgKeyPoolError::KeyNotFound(selector))
568    }
569
570    async fn add_domain_to_key<S>(&self, selector: S, domain: D) -> Result<Self::Key, Self::Error>
571    where
572        S: IntoSelector<Self::Key, Self::Domain>,
573    {
574        let selector = selector.into_selector();
575
576        let mut qb = QueryBuilder::new(format!(
577            "update {} set domains = {}(domains || jsonb_build_array(",
578            self.table_name(),
579            self.unique_array_fn()
580        ));
581        qb.push_bind(sqlx::types::Json(domain));
582        qb.push(")) where ");
583        build_predicate(&mut qb, &selector);
584        qb.push(" returning *");
585
586        qb.build_query_as()
587            .fetch_optional(&self.pool)
588            .await?
589            .ok_or_else(|| PgKeyPoolError::KeyNotFound(selector))
590    }
591
592    async fn remove_domain_from_key<S>(
593        &self,
594        selector: S,
595        domain: D,
596    ) -> Result<Self::Key, Self::Error>
597    where
598        S: IntoSelector<Self::Key, Self::Domain>,
599    {
600        let selector = selector.into_selector();
601
602        let mut qb = QueryBuilder::new(format!(
603            "update {} set domains = coalesce({}(domains, ",
604            self.table_name(),
605            self.filter_array_fn()
606        ));
607        qb.push_bind(sqlx::types::Json(domain));
608        qb.push("), '[]'::jsonb) where ");
609        build_predicate(&mut qb, &selector);
610        qb.push(" returning *");
611
612        qb.build_query_as()
613            .fetch_optional(&self.pool)
614            .await?
615            .ok_or_else(|| PgKeyPoolError::KeyNotFound(selector))
616    }
617
618    async fn set_domains_for_key<S>(
619        &self,
620        selector: S,
621        domains: Vec<D>,
622    ) -> Result<Self::Key, Self::Error>
623    where
624        S: IntoSelector<Self::Key, Self::Domain>,
625    {
626        let selector = selector.into_selector();
627
628        let mut qb = QueryBuilder::new("update api_keys set domains = ");
629        qb.push_bind(sqlx::types::Json(domains));
630        qb.push(" where ");
631        build_predicate(&mut qb, &selector);
632        qb.push(" returning *");
633
634        qb.build_query_as()
635            .fetch_optional(&self.pool)
636            .await?
637            .ok_or_else(|| PgKeyPoolError::KeyNotFound(selector))
638    }
639}
640
#[cfg(test)]
pub(crate) mod test {
    use std::{sync::Arc, time::Duration};

    use sqlx::Row;

    use super::*;

    /// Domain type used by the tests. `Guild` falls back to `All`.
    #[derive(Debug, PartialEq, Eq, Clone, serde::Serialize, serde::Deserialize)]
    #[serde(tag = "type", rename_all = "snake_case")]
    pub(crate) enum Domain {
        All,
        Guild { id: i64 },
        User { id: i32 },
        Faction { id: i32 },
    }

    impl KeyDomain for Domain {
        fn fallback(&self) -> Option<Self> {
            match self {
                Self::Guild { id: _ } => Some(Self::All),
                _ => None,
            }
        }
    }

    /// Creates a fresh storage in schema `test` (limit 1000). It then stores the
    /// key from the `API_KEY` environment variable for user 1 with domain `All`.
    ///
    /// # Panics
    /// If `API_KEY` is unset or any database operation fails.
    pub(crate) async fn setup(pool: PgPool) -> (PgKeyPoolStorage<Domain>, PgKey<Domain>) {
        let storage: PgKeyPoolStorage<Domain> =
            PgKeyPoolStorage::new(pool.clone(), 1000, Some("test".to_owned()));

        // Drop the table the storage really uses (`test.api_keys`). Dropping the
        // unqualified `api_keys` would leave a stale schema-qualified table behind.
        sqlx::query(&format!("DROP TABLE IF EXISTS {}", storage.table_name()))
            .execute(&pool)
            .await
            .unwrap();

        storage.initialise().await.unwrap();

        let key = storage
            .store_key(1, std::env::var("API_KEY").unwrap(), vec![Domain::All])
            .await
            .unwrap();

        (storage, key)
    }

    #[sqlx::test]
    async fn test_initialise(pool: PgPool) {
        let (storage, _) = setup(pool).await;

        if let Err(e) = storage.initialise().await {
            panic!("Initialising key storage failed: {e:?}");
        }
    }

    #[sqlx::test]
    async fn test_store_duplicate_key(pool: PgPool) {
        let (storage, key) = setup(pool).await;
        let key = storage
            .store_key(1, key.key, vec![Domain::User { id: 1 }])
            .await
            .unwrap();

        assert_eq!(key.domains.0.len(), 2);
    }

    #[sqlx::test]
    async fn test_store_duplicate_key_duplicate_domain(pool: PgPool) {
        let (storage, key) = setup(pool).await;
        let key = storage
            .store_key(1, key.key, vec![Domain::All])
            .await
            .unwrap();

        assert_eq!(key.domains.0.len(), 1);
    }

    #[sqlx::test]
    async fn test_add_domain(pool: PgPool) {
        let (storage, key) = setup(pool).await;
        let key = storage
            .add_domain_to_key(KeySelector::Key(key.key), Domain::User { id: 12345 })
            .await
            .unwrap();

        assert!(key.domains.0.contains(&Domain::User { id: 12345 }));
    }

    #[sqlx::test]
    async fn test_add_domain_id(pool: PgPool) {
        let (storage, key) = setup(pool).await;
        let key = storage
            .add_domain_to_key(KeySelector::Id(key.id), Domain::User { id: 12345 })
            .await
            .unwrap();

        assert!(key.domains.0.contains(&Domain::User { id: 12345 }));
    }

    #[sqlx::test]
    async fn test_add_duplicate_domain(pool: PgPool) {
        let (storage, key) = setup(pool).await;
        let key = storage
            .add_domain_to_key(KeySelector::Key(key.key), Domain::All)
            .await
            .unwrap();
        assert_eq!(
            key.domains
                .0
                .into_iter()
                .filter(|d| *d == Domain::All)
                .count(),
            1
        );
    }

    #[sqlx::test]
    async fn test_remove_domain(pool: PgPool) {
        let (storage, key) = setup(pool).await;
        storage
            .add_domain_to_key(KeySelector::Key(key.key.clone()), Domain::User { id: 1 })
            .await
            .unwrap();
        let key = storage
            .remove_domain_from_key(KeySelector::Key(key.key.clone()), Domain::User { id: 1 })
            .await
            .unwrap();

        assert_eq!(key.domains.0, vec![Domain::All]);
    }

    #[sqlx::test]
    async fn test_remove_domain_id(pool: PgPool) {
        let (storage, key) = setup(pool).await;
        storage
            .add_domain_to_key(KeySelector::Id(key.id), Domain::User { id: 1 })
            .await
            .unwrap();
        let key = storage
            .remove_domain_from_key(KeySelector::Id(key.id), Domain::User { id: 1 })
            .await
            .unwrap();

        assert_eq!(key.domains.0, vec![Domain::All]);
    }

    #[sqlx::test]
    async fn test_remove_last_domain(pool: PgPool) {
        let (storage, key) = setup(pool).await;
        let key = storage
            .remove_domain_from_key(KeySelector::Key(key.key), Domain::All)
            .await
            .unwrap();

        assert!(key.domains.0.is_empty());
    }

    #[sqlx::test]
    async fn test_store_key(pool: PgPool) {
        let (storage, _) = setup(pool).await;
        let key = storage
            .store_key(1, "ABCDABCDABCDABCD".to_owned(), vec![])
            .await
            .unwrap();
        assert_eq!(key.value(), "ABCDABCDABCDABCD");
    }

    #[sqlx::test]
    async fn test_read_user_keys(pool: PgPool) {
        let (storage, _) = setup(pool).await;

        let keys = storage.read_keys(KeySelector::UserId(1)).await.unwrap();
        assert_eq!(keys.len(), 1);
    }

    #[sqlx::test]
    async fn acquire_one(pool: PgPool) {
        let (storage, _) = setup(pool).await;

        if let Err(e) = storage.acquire_key(Domain::All).await {
            panic!("Acquiring key failed: {e:?}");
        }
    }

    #[sqlx::test]
    async fn uses_spread(pool: PgPool) {
        let (storage, _) = setup(pool).await;
        storage
            .store_key(1, "ABC".to_owned(), vec![Domain::All])
            .await
            .unwrap();

        for _ in 0..10 {
            _ = storage.acquire_key(Domain::All).await.unwrap();
        }

        let keys = storage.read_keys(KeySelector::UserId(1)).await.unwrap();
        assert_eq!(keys.len(), 2);
        for key in keys {
            assert_eq!(key.uses, 5);
        }
    }

    #[sqlx::test]
    async fn acquire_many(pool: PgPool) {
        let (storage, _) = setup(pool).await;

        match storage.acquire_many_keys(Domain::All, 30).await {
            Err(e) => panic!("Acquiring key failed: {e:?}"),
            Ok(keys) => assert_eq!(keys.len(), 30),
        }
    }

    #[sqlx::test]
    async fn acquire_many_zero(pool: PgPool) {
        let (storage, _) = setup(pool).await;

        let keys = storage.acquire_many_keys(Domain::All, 0).await.unwrap();
        assert!(keys.is_empty());
    }

    // HACK: this test is time sensitive and will fail if it runs at the top of the minute
    #[sqlx::test]
    async fn test_concurrent(pool: PgPool) {
        let storage = Arc::new(setup(pool).await.0);

        for _ in 0..10 {
            let mut set = tokio::task::JoinSet::new();

            for _ in 0..100 {
                let storage = storage.clone();
                set.spawn(async move {
                    storage.acquire_key(Domain::All).await.unwrap();
                });
            }

            for _ in 0..100 {
                set.join_next().await.unwrap().unwrap();
            }

            let uses: i16 = sqlx::query(&format!("select uses from {}", storage.table_name()))
                .fetch_one(&storage.pool)
                .await
                .unwrap()
                .get("uses");

            assert_eq!(uses, 100);

            sqlx::query(&format!("update {} set uses=0", storage.table_name()))
                .execute(&storage.pool)
                .await
                .unwrap();
        }
    }

    #[sqlx::test]
    async fn test_concurrent_spread(pool: PgPool) {
        let storage = Arc::new(setup(pool).await.0);

        for i in 0..24 {
            storage
                .store_key(1, format!("{i}"), vec![Domain::All])
                .await
                .unwrap();
        }

        for _ in 0..10 {
            let mut set = tokio::task::JoinSet::new();

            for _ in 0..50 {
                let storage = storage.clone();
                set.spawn(async move {
                    storage.acquire_key(Domain::All).await.unwrap();
                });
            }

            for _ in 0..50 {
                set.join_next().await.unwrap().unwrap();
            }

            let keys = storage.read_keys(KeySelector::UserId(1)).await.unwrap();

            assert_eq!(keys.len(), 25);

            for key in keys {
                assert_eq!(key.uses, 2);
            }

            sqlx::query(&format!("update {} set uses=0", storage.table_name()))
                .execute(&storage.pool)
                .await
                .unwrap();
        }
    }

    // HACK: this test is time sensitive and will fail if it runs at the top of the minute
    #[sqlx::test]
    async fn test_concurrent_many(pool: PgPool) {
        let storage = Arc::new(setup(pool).await.0);
        for _ in 0..10 {
            let mut set = tokio::task::JoinSet::new();

            for _ in 0..100 {
                let storage = storage.clone();
                set.spawn(async move {
                    storage.acquire_many_keys(Domain::All, 5).await.unwrap();
                });
            }

            for _ in 0..100 {
                set.join_next().await.unwrap().unwrap();
            }

            let uses: i16 = sqlx::query(&format!("select uses from {}", storage.table_name()))
                .fetch_one(&storage.pool)
                .await
                .unwrap()
                .get("uses");

            assert_eq!(uses, 500);

            sqlx::query(&format!("update {} set uses=0", storage.table_name()))
                .execute(&storage.pool)
                .await
                .unwrap();
        }
    }

    #[sqlx::test]
    async fn read_key(pool: PgPool) {
        let (storage, key) = setup(pool).await;

        let key = storage.read_key(KeySelector::Key(key.key)).await.unwrap();
        assert!(key.is_some());
    }

    #[sqlx::test]
    async fn read_key_id(pool: PgPool) {
        let (storage, key) = setup(pool).await;

        let key = storage.read_key(KeySelector::Id(key.id)).await.unwrap();
        assert!(key.is_some());
    }

    #[sqlx::test]
    async fn read_nonexistent_key(pool: PgPool) {
        let (storage, _) = setup(pool).await;

        let key = storage.read_key(KeySelector::Id(-1)).await.unwrap();
        assert!(key.is_none());
    }

    #[sqlx::test]
    async fn query_key(pool: PgPool) {
        let (storage, _) = setup(pool).await;

        let key = storage.read_key(Domain::All).await.unwrap();
        assert!(key.is_some());
    }

    #[sqlx::test]
    async fn query_nonexistent_key(pool: PgPool) {
        let (storage, _) = setup(pool).await;

        let key = storage.read_key(Domain::Guild { id: 0 }).await.unwrap();
        assert!(key.is_none());
    }

    #[sqlx::test]
    async fn query_all(pool: PgPool) {
        let (storage, _) = setup(pool).await;

        let keys = storage.read_keys(Domain::All).await.unwrap();
        assert_eq!(keys.len(), 1);
    }

    #[sqlx::test]
    async fn query_by_id(pool: PgPool) {
        let (storage, _) = setup(pool).await;
        let key = storage.read_key(KeySelector::Id(1)).await.unwrap();

        assert!(key.is_some());
    }

    #[sqlx::test]
    async fn query_by_key(pool: PgPool) {
        let (storage, key) = setup(pool).await;
        let key = storage.read_key(KeySelector::Key(key.key)).await.unwrap();

        assert!(key.is_some());
    }

    #[sqlx::test]
    async fn timeout(pool: PgPool) {
        let (storage, key) = setup(pool).await;

        storage
            .timeout_key(KeySelector::Id(key.id()), Duration::from_secs(60))
            .await
            .unwrap();
    }

    #[sqlx::test]
    async fn query_by_set(pool: PgPool) {
        let (storage, _key) = setup(pool).await;
        let key = storage
            .read_key(KeySelector::OneOf(vec![
                Domain::All,
                Domain::Guild { id: 0 },
                Domain::Faction { id: 0 },
            ]))
            .await
            .unwrap();

        assert!(key.is_some());
    }

    #[sqlx::test]
    async fn all_selector(pool: PgPool) {
        let (storage, key) = setup(pool).await;

        storage
            .add_domain_to_key(key.selector(), Domain::Faction { id: 1 })
            .await
            .unwrap();

        let key = storage
            .read_key(KeySelector::Has(vec![
                Domain::Faction { id: 1 },
                Domain::All,
            ]))
            .await
            .unwrap();

        assert!(key.is_some());

        let key = storage
            .read_key(KeySelector::Has(vec![
                Domain::All,
                Domain::Faction { id: 1 },
            ]))
            .await
            .unwrap();

        assert!(key.is_some());

        let key = storage
            .read_key(KeySelector::Has(vec![
                Domain::All,
                Domain::Faction { id: 2 },
                Domain::Faction { id: 1 },
            ]))
            .await
            .unwrap();

        assert!(key.is_none());
    }
}