1use std::sync::Arc;
2
3use futures::future::BoxFuture;
4use indoc::formatdoc;
5use sqlx::{FromRow, PgPool, Postgres, QueryBuilder};
6use thiserror::Error;
7
8use crate::{ApiKey, IntoSelector, KeyDomain, KeyPoolStorage, KeySelector};
9
10pub trait PgKeyDomain:
11 KeyDomain + serde::Serialize + serde::de::DeserializeOwned + Eq + Unpin
12{
13}
14
15impl<T> PgKeyDomain for T where
16 T: KeyDomain + serde::Serialize + serde::de::DeserializeOwned + Eq + Unpin
17{
18}
19
20#[derive(Debug, Error)]
21pub enum PgKeyPoolError<D>
22where
23 D: PgKeyDomain,
24{
25 #[error("Databank: {0}")]
26 Pg(#[from] sqlx::Error),
27
28 #[error("Network: {0}")]
29 Network(#[from] reqwest::Error),
30
31 #[error("Parsing: {0}")]
32 Parsing(#[from] serde_json::Error),
33
34 #[error("Api: {0}")]
35 Api(#[from] torn_api::ApiError),
36
37 #[error("No key avalaible for domain {0:?}")]
38 Unavailable(KeySelector<PgKey<D>, D>),
39
40 #[error("Key not found: '{0:?}'")]
41 KeyNotFound(KeySelector<PgKey<D>, D>),
42
43 #[error("Failed to acquire keys in bulk: {0}")]
44 Bulk(#[from] Arc<Self>),
45}
46
/// A single API key row as read from the `api_keys` table.
#[derive(Debug, Clone, FromRow)]
pub struct PgKey<D>
where
    D: PgKeyDomain,
{
    /// Primary key of the row.
    pub id: i32,
    /// Identifier of the user the key belongs to.
    pub user_id: i32,
    /// The API key value itself.
    pub key: String,
    /// Usage counter maintained by the acquire queries.
    pub uses: i16,
    /// Domains attached to the key, stored as a JSON array.
    pub domains: sqlx::types::Json<Vec<D>>,
}
58
59#[inline(always)]
60fn build_predicate<'b, D>(
61 builder: &mut QueryBuilder<'b, Postgres>,
62 selector: &'b KeySelector<PgKey<D>, D>,
63) where
64 D: PgKeyDomain,
65{
66 match selector {
67 KeySelector::Id(id) => builder.push("id=").push_bind(id),
68 KeySelector::UserId(user_id) => builder.push("user_id=").push_bind(user_id),
69 KeySelector::Key(key) => builder.push("key=").push_bind(key),
70 KeySelector::Has(domains) => builder
71 .push("domains @> ")
72 .push_bind(sqlx::types::Json(domains)),
73 KeySelector::OneOf(domains) => {
74 if domains.is_empty() {
75 builder.push("false");
76 return;
77 }
78
79 for (idx, domain) in domains.iter().enumerate() {
80 if idx == 0 {
81 builder.push("(");
82 } else {
83 builder.push(" or ");
84 }
85 builder
86 .push("domains @> ")
87 .push_bind(sqlx::types::Json(vec![domain]));
88 }
89 builder.push(")")
90 }
91 };
92}
93
/// Key pool storage backed by a Postgres connection pool.
// NOTE(review): `FromRow` is derived although this struct holds a connection
// pool rather than row data — confirm whether the derive is intentional.
#[derive(Debug, Clone, FromRow)]
pub struct PgKeyPoolStorage<D>
where
    D: serde::Serialize + serde::de::DeserializeOwned + Send + Sync + 'static,
{
    /// Connection pool all queries are executed on.
    pool: PgPool,
    /// Upper bound for a key's `uses` counter when acquiring keys.
    limit: i16,
    /// Optional schema used to qualify the table and helper function names.
    schema: Option<String>,
    /// Ties the storage to its domain type without storing a value of it.
    _phantom: std::marker::PhantomData<D>,
}
104
105impl<D> ApiKey for PgKey<D>
106where
107 D: PgKeyDomain,
108{
109 type IdType = i32;
110
111 #[inline(always)]
112 fn value(&self) -> &str {
113 &self.key
114 }
115
116 #[inline(always)]
117 fn id(&self) -> Self::IdType {
118 self.id
119 }
120}
121
122impl<D> PgKeyPoolStorage<D>
123where
124 D: PgKeyDomain,
125{
126 pub fn new(pool: PgPool, limit: i16, schema: Option<String>) -> Self {
127 Self {
128 pool,
129 limit,
130 schema,
131 _phantom: Default::default(),
132 }
133 }
134
135 fn table_name(&self) -> String {
136 match self.schema.as_ref() {
137 Some(schema) => format!("{schema}.api_keys"),
138 None => "api_keys".to_owned(),
139 }
140 }
141
142 fn unique_array_fn(&self) -> String {
143 match self.schema.as_ref() {
144 Some(schema) => format!("{schema}.__unique_jsonb_array"),
145 None => "__unique_jsonb_array".to_owned(),
146 }
147 }
148
149 fn filter_array_fn(&self) -> String {
150 match self.schema.as_ref() {
151 Some(schema) => format!("{schema}.__filter_jsonb_array"),
152 None => "__filter_jsonb_array".to_owned(),
153 }
154 }
155
156 pub async fn initialise(&self) -> Result<(), PgKeyPoolError<D>> {
157 if let Some(schema) = self.schema.as_ref() {
158 sqlx::query(&format!("create schema if not exists {schema}"))
159 .execute(&self.pool)
160 .await?;
161 }
162
163 sqlx::query(&formatdoc! {r#"
164 CREATE TABLE IF NOT EXISTS {} (
165 id serial primary key,
166 user_id int4 not null,
167 key char(16) not null,
168 uses int2 not null default 0,
169 domains jsonb not null default '{{}}'::jsonb,
170 last_used timestamptz not null default now(),
171 flag int2,
172 cooldown timestamptz,
173 constraint "uq:api_keys.key" UNIQUE(key)
174 )"#,
175 self.table_name()
176 })
177 .execute(&self.pool)
178 .await?;
179
180 sqlx::query(&formatdoc! {r#"
181 CREATE INDEX IF NOT EXISTS "idx:api_keys.domains" ON {} USING GIN(domains jsonb_path_ops)
182 "#, self.table_name()})
183 .execute(&self.pool)
184 .await?;
185
186 sqlx::query(&formatdoc! {r#"
187 CREATE INDEX IF NOT EXISTS "idx:api_keys.user_id" ON {} USING BTREE(user_id)
188 "#, self.table_name()})
189 .execute(&self.pool)
190 .await?;
191
192 sqlx::query(&formatdoc! {r#"
193 create or replace function {}(jsonb) returns jsonb
194 AS $$
195 select jsonb_agg(d::jsonb) from (
196 select distinct jsonb_array_elements_text($1) as d
197 ) t
198 $$ language sql;
199 "#, self.unique_array_fn()})
200 .execute(&self.pool)
201 .await?;
202
203 sqlx::query(&formatdoc! {r#"
204 create or replace function {}(jsonb, jsonb) returns jsonb
205 AS $$
206 select jsonb_agg(d::jsonb) from (
207 select distinct jsonb_array_elements_text($1) as d
208 ) t where d<>$2::text
209 $$ language sql;
210 "#, self.filter_array_fn()})
211 .execute(&self.pool)
212 .await?;
213
214 Ok(())
215 }
216}
217
/// Sleeps for a random duration between 1 and 49 milliseconds.
///
/// Used as jitter between retries of a failed transaction.
#[cfg(feature = "tokio-runtime")]
async fn random_sleep() {
    use rand::{rng, Rng};
    let jitter_ms = rng().random_range(1..50);
    let pause = tokio::time::Duration::from_millis(jitter_ms);
    tokio::time::sleep(pause).await;
}
224
225impl<D> KeyPoolStorage for PgKeyPoolStorage<D>
226where
227 D: PgKeyDomain,
228{
229 type Key = PgKey<D>;
230 type Domain = D;
231
232 type Error = PgKeyPoolError<D>;
233
    /// Acquires a single key matching `selector` and increments its use counter.
    ///
    /// Each attempt runs in its own `repeatable read` transaction. Keys whose
    /// `last_used` lies before the current minute are treated as having zero
    /// uses; the candidate with the fewest uses is picked and only updated if
    /// its use count is still below the configured limit.
    ///
    /// If no key is returned the call is repeated with `selector.fallback()`;
    /// when there is no fallback, [`PgKeyPoolError::Unavailable`] is returned.
    /// A failure with SQLSTATE `40001` (serialization failure) is retried after
    /// a short random sleep; every other error is returned to the caller.
    async fn acquire_key<S>(&self, selector: S) -> Result<Self::Key, Self::Error>
    where
        S: IntoSelector<Self::Key, Self::Domain>,
    {
        let selector = selector.into_selector();
        loop {
            // One attempt = one transaction; any `?` inside drops (and thereby
            // rolls back) the transaction and surfaces the error below.
            let attempt = async {
                let mut tx = self.pool.begin().await?;

                sqlx::query("set transaction isolation level repeatable read")
                    .execute(&mut *tx)
                    .await?;

                // First half of the CTE: keys not used in the current minute,
                // reported with a use count of zero.
                let mut qb = QueryBuilder::new(&formatdoc! {
                    r#"
                    with key as (
                        select
                            id,
                            0::int2 as uses
                        from {} where last_used < date_trunc('minute', now())
                            and (cooldown is null or now() >= cooldown)
                            and "#,
                    self.table_name()
                });

                build_predicate(&mut qb, &selector);

                // Second half: keys already used this minute, with their real count.
                qb.push(formatdoc! {
                    "
                        \n    union (
                            select id, uses from {}
                            where last_used >= date_trunc('minute', now())
                            and (cooldown is null or now() >= cooldown)
                            and ",
                    self.table_name()
                });

                build_predicate(&mut qb, &selector);

                // Pick the least-used candidate and bump it, but only while it
                // is still below the limit bound next.
                qb.push(formatdoc! {
                    "
                            \n        order by uses asc limit 1
                        )
                        order by uses asc limit 1
                    )
                    update {} as keys set
                        uses = key.uses + 1,
                        cooldown = null,
                        flag = null,
                        last_used = now()
                    from key where
                        keys.id=key.id and key.uses < ",
                    self.table_name()
                });

                qb.push_bind(self.limit);

                qb.push(indoc::indoc! { "
                    \nreturning keys.id, keys.user_id, keys.key, keys.uses, keys.domains"
                });

                let key = qb.build_query_as().fetch_optional(&mut *tx).await?;

                tx.commit().await?;

                Result::<Option<Self::Key>, sqlx::Error>::Ok(key)
            }
            .await;

            match attempt {
                Ok(Some(result)) => return Ok(result),
                Ok(None) => {
                    // Boxing helper: lets this async fn call itself for the fallback selector.
                    fn recurse<D>(
                        storage: &PgKeyPoolStorage<D>,
                        selector: KeySelector<PgKey<D>, D>,
                    ) -> BoxFuture<'_, Result<PgKey<D>, PgKeyPoolError<D>>>
                    where
                        D: PgKeyDomain,
                    {
                        Box::pin(storage.acquire_key(selector))
                    }

                    return recurse(
                        self,
                        selector
                            .fallback()
                            .ok_or_else(|| PgKeyPoolError::Unavailable(selector))?,
                    )
                    .await;
                }
                Err(error) => {
                    if let Some(db_error) = error.as_database_error() {
                        let pg_error: &sqlx::postgres::PgDatabaseError = db_error.downcast_ref();
                        // 40001 = serialization failure: retry the loop after some jitter.
                        if pg_error.code() == "40001" {
                            random_sleep().await;
                        } else {
                            return Err(error.into());
                        }
                    } else {
                        return Err(error.into());
                    }
                }
            }
        }
    }
339
340 async fn acquire_many_keys<S>(
341 &self,
342 selector: S,
343 number: i64,
344 ) -> Result<Vec<Self::Key>, Self::Error>
345 where
346 S: IntoSelector<Self::Key, Self::Domain>,
347 {
348 let selector = selector.into_selector();
349 loop {
350 let attempt = async {
351 let mut tx = self.pool.begin().await?;
352
353 sqlx::query("set transaction isolation level repeatable read")
354 .execute(&mut *tx)
355 .await?;
356
357 let mut qb = QueryBuilder::new(&formatdoc! {
358 r#"select
359 id,
360 user_id,
361 key,
362 0::int2 as uses,
363 domains
364 from {} where last_used < date_trunc('minute', now())
365 and (cooldown is null or now() >= cooldown)
366 and "#,
367 self.table_name()
368 });
369 build_predicate(&mut qb, &selector);
370 qb.push(formatdoc! {
371 "
372 \nunion
373 select
374 id,
375 user_id,
376 key,
377 uses,
378 domains
379 from {} where last_used >= date_trunc('minute', now())
380 and (cooldown is null or now() >= cooldown)
381 and ",
382 self.table_name()
383 });
384 build_predicate(&mut qb, &selector);
385 qb.push("\norder by uses limit ");
386 qb.push_bind(self.limit);
387
388 let mut keys: Vec<Self::Key> = qb.build_query_as().fetch_all(&mut *tx).await?;
389
390 if keys.is_empty() {
391 tx.commit().await?;
392 return Ok(None);
393 }
394
395 keys.sort_unstable_by(|k1, k2| k1.uses.cmp(&k2.uses));
396
397 let mut result = Vec::with_capacity(number as usize);
398 let (max, rest) = keys.split_last_mut().unwrap();
399 for key in rest {
400 let available = max.uses - key.uses;
401 let using = std::cmp::min(available, (number as i16) - (result.len() as i16));
402 key.uses += using;
403 result.extend(std::iter::repeat_n(key.clone(), using as usize));
404
405 if result.len() == number as usize {
406 break;
407 }
408 }
409
410 while result.len() < (number as usize) {
411 if keys[0].uses == self.limit {
412 break;
413 }
414
415 let take = std::cmp::min(keys.len(), (number as usize) - result.len());
416 let slice = &mut keys[0..take];
417 slice.iter_mut().for_each(|k| k.uses += 1);
418 result.extend_from_slice(slice);
419 }
420
421 sqlx::query(&formatdoc! {r#"
422 update {} keys set
423 uses = tmp.uses,
424 cooldown = null,
425 flag = null,
426 last_used = now()
427 from (select unnest($1::int4[]) as id, unnest($2::int2[]) as uses) as tmp
428 where keys.id = tmp.id
429 "#, self.table_name()})
430 .bind(keys.iter().map(|k| k.id).collect::<Vec<_>>())
431 .bind(keys.iter().map(|k| k.uses).collect::<Vec<_>>())
432 .execute(&mut *tx)
433 .await?;
434
435 tx.commit().await?;
436
437 Result::<Option<Vec<Self::Key>>, sqlx::Error>::Ok(Some(result))
438 }
439 .await;
440
441 match attempt {
442 Ok(Some(result)) => return Ok(result),
443 Ok(None) => {
444 fn recurse<D>(
445 storage: &PgKeyPoolStorage<D>,
446 selector: KeySelector<PgKey<D>, D>,
447 number: i64,
448 ) -> BoxFuture<'_, Result<Vec<PgKey<D>>, PgKeyPoolError<D>>>
449 where
450 D: PgKeyDomain,
451 {
452 Box::pin(storage.acquire_many_keys(selector, number))
453 }
454
455 return recurse(
456 self,
457 selector
458 .fallback()
459 .ok_or_else(|| Self::Error::Unavailable(selector))?,
460 number,
461 )
462 .await;
463 }
464 Err(error) => {
465 if let Some(db_error) = error.as_database_error() {
466 let pg_error: &sqlx::postgres::PgDatabaseError = db_error.downcast_ref();
467 if pg_error.code() == "40001" {
468 random_sleep().await;
469 } else {
470 return Err(error.into());
471 }
472 } else {
473 return Err(error.into());
474 }
475 }
476 }
477 }
478 }
479
480 async fn timeout_key<S>(
481 &self,
482 selector: S,
483 duration: std::time::Duration,
484 ) -> Result<(), Self::Error>
485 where
486 S: IntoSelector<Self::Key, Self::Domain>,
487 {
488 let selector = selector.into_selector();
489
490 let mut qb = QueryBuilder::new(format!(
491 "update {} set cooldown=now() + ",
492 self.table_name()
493 ));
494 qb.push_bind(duration);
495 qb.push(" where ");
496 build_predicate(&mut qb, &selector);
497
498 qb.build().fetch_optional(&self.pool).await?;
499
500 Ok(())
501 }
502
503 async fn store_key(
504 &self,
505 user_id: i32,
506 key: String,
507 domains: Vec<D>,
508 ) -> Result<Self::Key, Self::Error> {
509 sqlx::query_as(&dbg!(formatdoc!(
510 "insert into {} as api_keys(user_id, key, domains) values ($1, $2, $3)
511 on conflict(key) do update
512 set domains = {}(excluded.domains || api_keys.domains) returning *",
513 self.table_name(),
514 self.unique_array_fn()
515 )))
516 .bind(user_id)
517 .bind(&key)
518 .bind(sqlx::types::Json(domains))
519 .fetch_one(&self.pool)
520 .await
521 .map_err(Into::into)
522 }
523
524 async fn read_key<S>(&self, selector: S) -> Result<Option<Self::Key>, Self::Error>
525 where
526 S: IntoSelector<Self::Key, Self::Domain>,
527 {
528 let selector = selector.into_selector();
529
530 let mut qb = QueryBuilder::new(format!("select * from {} where ", self.table_name()));
531 build_predicate(&mut qb, &selector);
532
533 qb.build_query_as()
534 .fetch_optional(&self.pool)
535 .await
536 .map_err(Into::into)
537 }
538
539 async fn read_keys<S>(&self, selector: S) -> Result<Vec<Self::Key>, Self::Error>
540 where
541 S: IntoSelector<Self::Key, Self::Domain>,
542 {
543 let selector = selector.into_selector();
544
545 let mut qb = QueryBuilder::new(format!("select * from {} where ", self.table_name()));
546 build_predicate(&mut qb, &selector);
547
548 qb.build_query_as()
549 .fetch_all(&self.pool)
550 .await
551 .map_err(Into::into)
552 }
553
554 async fn remove_key<S>(&self, selector: S) -> Result<Self::Key, Self::Error>
555 where
556 S: IntoSelector<Self::Key, Self::Domain>,
557 {
558 let selector = selector.into_selector();
559
560 let mut qb = QueryBuilder::new(format!("delete from {} where ", self.table_name()));
561 build_predicate(&mut qb, &selector);
562 qb.push(" returning *");
563
564 qb.build_query_as()
565 .fetch_optional(&self.pool)
566 .await?
567 .ok_or_else(|| PgKeyPoolError::KeyNotFound(selector))
568 }
569
570 async fn add_domain_to_key<S>(&self, selector: S, domain: D) -> Result<Self::Key, Self::Error>
571 where
572 S: IntoSelector<Self::Key, Self::Domain>,
573 {
574 let selector = selector.into_selector();
575
576 let mut qb = QueryBuilder::new(format!(
577 "update {} set domains = {}(domains || jsonb_build_array(",
578 self.table_name(),
579 self.unique_array_fn()
580 ));
581 qb.push_bind(sqlx::types::Json(domain));
582 qb.push(")) where ");
583 build_predicate(&mut qb, &selector);
584 qb.push(" returning *");
585
586 qb.build_query_as()
587 .fetch_optional(&self.pool)
588 .await?
589 .ok_or_else(|| PgKeyPoolError::KeyNotFound(selector))
590 }
591
592 async fn remove_domain_from_key<S>(
593 &self,
594 selector: S,
595 domain: D,
596 ) -> Result<Self::Key, Self::Error>
597 where
598 S: IntoSelector<Self::Key, Self::Domain>,
599 {
600 let selector = selector.into_selector();
601
602 let mut qb = QueryBuilder::new(format!(
603 "update {} set domains = coalesce({}(domains, ",
604 self.table_name(),
605 self.filter_array_fn()
606 ));
607 qb.push_bind(sqlx::types::Json(domain));
608 qb.push("), '[]'::jsonb) where ");
609 build_predicate(&mut qb, &selector);
610 qb.push(" returning *");
611
612 qb.build_query_as()
613 .fetch_optional(&self.pool)
614 .await?
615 .ok_or_else(|| PgKeyPoolError::KeyNotFound(selector))
616 }
617
618 async fn set_domains_for_key<S>(
619 &self,
620 selector: S,
621 domains: Vec<D>,
622 ) -> Result<Self::Key, Self::Error>
623 where
624 S: IntoSelector<Self::Key, Self::Domain>,
625 {
626 let selector = selector.into_selector();
627
628 let mut qb = QueryBuilder::new("update api_keys set domains = ");
629 qb.push_bind(sqlx::types::Json(domains));
630 qb.push(" where ");
631 build_predicate(&mut qb, &selector);
632 qb.push(" returning *");
633
634 qb.build_query_as()
635 .fetch_optional(&self.pool)
636 .await?
637 .ok_or_else(|| PgKeyPoolError::KeyNotFound(selector))
638 }
639}
640
/// Integration tests for the Postgres storage; each test receives its own
/// pool from `#[sqlx::test]` and seeds it through [`setup`].
#[cfg(test)]
pub(crate) mod test {
    use std::{sync::Arc, time::Duration};

    use sqlx::Row;

    use super::*;

    /// Domain type used by the tests; serialised as a tagged JSON object.
    #[derive(Debug, PartialEq, Eq, Clone, serde::Serialize, serde::Deserialize)]
    #[serde(tag = "type", rename_all = "snake_case")]
    pub(crate) enum Domain {
        All,
        Guild { id: i64 },
        User { id: i32 },
        Faction { id: i32 },
    }

    impl KeyDomain for Domain {
        /// Guild domains fall back to `All`; every other domain has no fallback.
        fn fallback(&self) -> Option<Self> {
            match self {
                Self::Guild { id: _ } => Some(Self::All),
                _ => None,
            }
        }
    }

    /// Creates a fresh storage in the `test` schema and stores the key from the
    /// `API_KEY` environment variable for user 1 with the `All` domain.
    pub(crate) async fn setup(pool: PgPool) -> (PgKeyPoolStorage<Domain>, PgKey<Domain>) {
        let storage = PgKeyPoolStorage::new(pool.clone(), 1000, Some("test".to_owned()));

        // Drop the table the storage actually uses (schema-qualified), not an
        // unqualified `api_keys` that the tests never touch.
        sqlx::query(&format!("DROP TABLE IF EXISTS {}", storage.table_name()))
            .execute(&pool)
            .await
            .unwrap();

        storage.initialise().await.unwrap();

        let key = storage
            .store_key(1, std::env::var("API_KEY").unwrap(), vec![Domain::All])
            .await
            .unwrap();

        (storage, key)
    }

    /// Initialising an already initialised storage must succeed.
    #[sqlx::test]
    async fn test_initialise(pool: PgPool) {
        let (storage, _) = setup(pool).await;

        if let Err(e) = storage.initialise().await {
            panic!("Initialising key storage failed: {e:?}");
        }
    }

    /// Storing an existing key with a new domain merges the domains.
    #[sqlx::test]
    async fn test_store_duplicate_key(pool: PgPool) {
        let (storage, key) = setup(pool).await;
        let key = storage
            .store_key(1, key.key, vec![Domain::User { id: 1 }])
            .await
            .unwrap();

        assert_eq!(key.domains.0.len(), 2);
    }

    /// Storing an existing key with an existing domain does not duplicate it.
    #[sqlx::test]
    async fn test_store_duplicate_key_duplicate_domain(pool: PgPool) {
        let (storage, key) = setup(pool).await;
        let key = storage
            .store_key(1, key.key, vec![Domain::All])
            .await
            .unwrap();

        assert_eq!(key.domains.0.len(), 1);
    }

    /// A domain can be added to a key selected by its key string.
    #[sqlx::test]
    async fn test_add_domain(pool: PgPool) {
        let (storage, key) = setup(pool).await;
        let key = storage
            .add_domain_to_key(KeySelector::Key(key.key), Domain::User { id: 12345 })
            .await
            .unwrap();

        assert!(key.domains.0.contains(&Domain::User { id: 12345 }));
    }

    /// A domain can be added to a key selected by its id.
    #[sqlx::test]
    async fn test_add_domain_id(pool: PgPool) {
        let (storage, key) = setup(pool).await;
        let key = storage
            .add_domain_to_key(KeySelector::Id(key.id), Domain::User { id: 12345 })
            .await
            .unwrap();

        assert!(key.domains.0.contains(&Domain::User { id: 12345 }));
    }

    /// Adding a domain the key already has leaves exactly one copy.
    #[sqlx::test]
    async fn test_add_duplicate_domain(pool: PgPool) {
        let (storage, key) = setup(pool).await;
        let key = storage
            .add_domain_to_key(KeySelector::Key(key.key), Domain::All)
            .await
            .unwrap();
        assert_eq!(
            key.domains
                .0
                .into_iter()
                .filter(|d| *d == Domain::All)
                .count(),
            1
        );
    }

    /// Removing a domain (key selected by key string) leaves the other domains.
    #[sqlx::test]
    async fn test_remove_domain(pool: PgPool) {
        let (storage, key) = setup(pool).await;
        storage
            .add_domain_to_key(KeySelector::Key(key.key.clone()), Domain::User { id: 1 })
            .await
            .unwrap();
        let key = storage
            .remove_domain_from_key(KeySelector::Key(key.key.clone()), Domain::User { id: 1 })
            .await
            .unwrap();

        assert_eq!(key.domains.0, vec![Domain::All]);
    }

    /// Removing a domain (key selected by id) leaves the other domains.
    #[sqlx::test]
    async fn test_remove_domain_id(pool: PgPool) {
        let (storage, key) = setup(pool).await;
        storage
            .add_domain_to_key(KeySelector::Id(key.id), Domain::User { id: 1 })
            .await
            .unwrap();
        let key = storage
            .remove_domain_from_key(KeySelector::Id(key.id), Domain::User { id: 1 })
            .await
            .unwrap();

        assert_eq!(key.domains.0, vec![Domain::All]);
    }

    /// Removing the only domain yields an empty domain list.
    #[sqlx::test]
    async fn test_remove_last_domain(pool: PgPool) {
        let (storage, key) = setup(pool).await;
        let key = storage
            .remove_domain_from_key(KeySelector::Key(key.key), Domain::All)
            .await
            .unwrap();

        assert!(key.domains.0.is_empty());
    }

    /// A freshly stored key is returned with the same key string.
    #[sqlx::test]
    async fn test_store_key(pool: PgPool) {
        let (storage, _) = setup(pool).await;
        let key = storage
            .store_key(1, "ABCDABCDABCDABCD".to_owned(), vec![])
            .await
            .unwrap();
        assert_eq!(key.value(), "ABCDABCDABCDABCD");
    }

    /// Keys can be listed by user id.
    #[sqlx::test]
    async fn test_read_user_keys(pool: PgPool) {
        let (storage, _) = setup(pool).await;

        let keys = storage.read_keys(KeySelector::UserId(1)).await.unwrap();
        assert_eq!(keys.len(), 1);
    }

    /// A single key can be acquired by domain.
    #[sqlx::test]
    async fn acquire_one(pool: PgPool) {
        let (storage, _) = setup(pool).await;

        if let Err(e) = storage.acquire_key(Domain::All).await {
            panic!("Acquiring key failed: {e:?}");
        }
    }

    /// Ten acquisitions over two keys use each key five times.
    #[sqlx::test]
    async fn uses_spread(pool: PgPool) {
        let (storage, _) = setup(pool).await;
        storage
            .store_key(1, "ABC".to_owned(), vec![Domain::All])
            .await
            .unwrap();

        for _ in 0..10 {
            _ = storage.acquire_key(Domain::All).await.unwrap();
        }

        let keys = storage.read_keys(KeySelector::UserId(1)).await.unwrap();
        assert_eq!(keys.len(), 2);
        for key in keys {
            assert_eq!(key.uses, 5);
        }
    }

    /// Bulk acquisition returns the requested number of entries.
    #[sqlx::test]
    async fn acquire_many(pool: PgPool) {
        let (storage, _) = setup(pool).await;

        match storage.acquire_many_keys(Domain::All, 30).await {
            Err(e) => panic!("Acquiring key failed: {e:?}"),
            Ok(keys) => assert_eq!(keys.len(), 30),
        }
    }

    /// 100 concurrent acquisitions of the single key are all counted.
    #[sqlx::test]
    async fn test_concurrent(pool: PgPool) {
        let storage = Arc::new(setup(pool).await.0);

        for _ in 0..10 {
            let mut set = tokio::task::JoinSet::new();

            for _ in 0..100 {
                let storage = storage.clone();
                set.spawn(async move {
                    storage.acquire_key(Domain::All).await.unwrap();
                });
            }

            for _ in 0..100 {
                set.join_next().await.unwrap().unwrap();
            }

            let uses: i16 = sqlx::query(&format!("select uses from {}", storage.table_name()))
                .fetch_one(&storage.pool)
                .await
                .unwrap()
                .get("uses");

            assert_eq!(uses, 100);

            // Reset the counter for the next round.
            sqlx::query(&format!("update {} set uses=0", storage.table_name()))
                .execute(&storage.pool)
                .await
                .unwrap();
        }
    }

    /// 50 concurrent acquisitions over 25 keys use every key exactly twice.
    #[sqlx::test]
    async fn test_concurrent_spread(pool: PgPool) {
        let storage = Arc::new(setup(pool).await.0);

        for i in 0..24 {
            storage
                .store_key(1, format!("{i}"), vec![Domain::All])
                .await
                .unwrap();
        }

        for _ in 0..10 {
            let mut set = tokio::task::JoinSet::new();

            for _ in 0..50 {
                let storage = storage.clone();
                set.spawn(async move {
                    storage.acquire_key(Domain::All).await.unwrap();
                });
            }

            for _ in 0..50 {
                set.join_next().await.unwrap().unwrap();
            }

            let keys = storage.read_keys(KeySelector::UserId(1)).await.unwrap();

            assert_eq!(keys.len(), 25);

            for key in keys {
                assert_eq!(key.uses, 2);
            }

            // Reset the counters for the next round.
            sqlx::query(&format!("update {} set uses=0", storage.table_name()))
                .execute(&storage.pool)
                .await
                .unwrap();
        }
    }

    /// 100 concurrent bulk acquisitions of 5 uses each add up to 500 uses.
    #[sqlx::test]
    async fn test_concurrent_many(pool: PgPool) {
        let storage = Arc::new(setup(pool).await.0);
        for _ in 0..10 {
            let mut set = tokio::task::JoinSet::new();

            for _ in 0..100 {
                let storage = storage.clone();
                set.spawn(async move {
                    storage.acquire_many_keys(Domain::All, 5).await.unwrap();
                });
            }

            for _ in 0..100 {
                set.join_next().await.unwrap().unwrap();
            }

            let uses: i16 = sqlx::query(&format!("select uses from {}", storage.table_name()))
                .fetch_one(&storage.pool)
                .await
                .unwrap()
                .get("uses");

            assert_eq!(uses, 500);

            // Reset the counter for the next round.
            sqlx::query(&format!("update {} set uses=0", storage.table_name()))
                .execute(&storage.pool)
                .await
                .unwrap();
        }
    }

    /// A key can be read back by its key string.
    #[sqlx::test]
    async fn read_key(pool: PgPool) {
        let (storage, key) = setup(pool).await;

        let key = storage.read_key(KeySelector::Key(key.key)).await.unwrap();
        assert!(key.is_some());
    }

    /// A key can be read back by its id.
    #[sqlx::test]
    async fn read_key_id(pool: PgPool) {
        let (storage, key) = setup(pool).await;

        let key = storage.read_key(KeySelector::Id(key.id)).await.unwrap();
        assert!(key.is_some());
    }

    /// Reading an unknown id yields `None`.
    #[sqlx::test]
    async fn read_nonexistent_key(pool: PgPool) {
        let (storage, _) = setup(pool).await;

        let key = storage.read_key(KeySelector::Id(-1)).await.unwrap();
        assert!(key.is_none());
    }

    /// A key can be found by a domain it has.
    #[sqlx::test]
    async fn query_key(pool: PgPool) {
        let (storage, _) = setup(pool).await;

        let key = storage.read_key(Domain::All).await.unwrap();
        assert!(key.is_some());
    }

    /// Reading by a domain no key has yields `None` (reads do not use fallbacks).
    #[sqlx::test]
    async fn query_nonexistent_key(pool: PgPool) {
        let (storage, _) = setup(pool).await;

        let key = storage.read_key(Domain::Guild { id: 0 }).await.unwrap();
        assert!(key.is_none());
    }

    /// Listing by domain returns the seeded key.
    #[sqlx::test]
    async fn query_all(pool: PgPool) {
        let (storage, _) = setup(pool).await;

        let keys = storage.read_keys(Domain::All).await.unwrap();
        assert_eq!(keys.len(), 1);
    }

    /// The seeded key is expected to have id 1 in a fresh database.
    #[sqlx::test]
    async fn query_by_id(pool: PgPool) {
        let (storage, _) = setup(pool).await;
        let key = storage.read_key(KeySelector::Id(1)).await.unwrap();

        assert!(key.is_some());
    }

    /// The seeded key can be found by its key string.
    #[sqlx::test]
    async fn query_by_key(pool: PgPool) {
        let (storage, key) = setup(pool).await;
        let key = storage.read_key(KeySelector::Key(key.key)).await.unwrap();

        assert!(key.is_some());
    }

    /// Putting a key on cooldown succeeds.
    #[sqlx::test]
    async fn timeout(pool: PgPool) {
        let (storage, key) = setup(pool).await;

        storage
            .timeout_key(KeySelector::Id(key.id()), Duration::from_secs(60))
            .await
            .unwrap();
    }

    /// `OneOf` matches a key that has at least one of the listed domains.
    #[sqlx::test]
    async fn query_by_set(pool: PgPool) {
        let (storage, _key) = setup(pool).await;
        let key = storage
            .read_key(KeySelector::OneOf(vec![
                Domain::All,
                Domain::Guild { id: 0 },
                Domain::Faction { id: 0 },
            ]))
            .await
            .unwrap();

        assert!(key.is_some());
    }

    /// `Has` matches only keys that carry every listed domain, in any order.
    #[sqlx::test]
    async fn all_selector(pool: PgPool) {
        let (storage, key) = setup(pool).await;

        storage
            .add_domain_to_key(key.selector(), Domain::Faction { id: 1 })
            .await
            .unwrap();

        let key = storage
            .read_key(KeySelector::Has(vec![
                Domain::Faction { id: 1 },
                Domain::All,
            ]))
            .await
            .unwrap();

        assert!(key.is_some());

        let key = storage
            .read_key(KeySelector::Has(vec![
                Domain::All,
                Domain::Faction { id: 1 },
            ]))
            .await
            .unwrap();

        assert!(key.is_some());

        // One domain the key does not have is enough to make `Has` miss.
        let key = storage
            .read_key(KeySelector::Has(vec![
                Domain::All,
                Domain::Faction { id: 2 },
                Domain::Faction { id: 1 },
            ]))
            .await
            .unwrap();

        assert!(key.is_none());
    }
}