1#![warn(clippy::all, clippy::perf, clippy::style, clippy::suspicious)]
2
3#[cfg(feature = "postgres")]
4pub mod postgres;
5
6pub mod send;
8
9use std::sync::Arc;
10
11use async_trait::async_trait;
12use thiserror::Error;
13
14use torn_api::ResponseError;
15
16#[derive(Debug, Error)]
17pub enum KeyPoolError<S, C>
18where
19 S: std::error::Error + Clone,
20 C: std::error::Error,
21{
22 #[error("Key pool storage driver error: {0:?}")]
23 Storage(#[source] S),
24
25 #[error(transparent)]
26 Client(#[from] C),
27
28 #[error(transparent)]
29 Response(ResponseError),
30}
31
32impl<S, C> KeyPoolError<S, C>
33where
34 S: std::error::Error + Clone,
35 C: std::error::Error,
36{
37 #[inline(always)]
38 pub fn api_code(&self) -> Option<u8> {
39 match self {
40 Self::Response(why) => why.api_code(),
41 _ => None,
42 }
43 }
44}
45
46pub trait ApiKey: Sync + Send + std::fmt::Debug + Clone + 'static {
47 type IdType: PartialEq + Eq + std::hash::Hash + Send + Sync + std::fmt::Debug + Clone;
48
49 fn value(&self) -> &str;
50
51 fn id(&self) -> Self::IdType;
52
53 fn selector<D>(&self) -> KeySelector<Self, D>
54 where
55 D: KeyDomain,
56 {
57 KeySelector::Id(self.id())
58 }
59}
60
/// A domain used to group and select keys.
pub trait KeyDomain: Clone + std::fmt::Debug + Send + Sync + 'static {
    /// Domain to use in place of this one when building a fallback selector.
    ///
    /// The default implementation has no fallback and returns `None`.
    fn fallback(&self) -> Option<Self> {
        Option::None
    }
}
66
67#[derive(Debug, Clone)]
68pub enum KeySelector<K, D>
69where
70 K: ApiKey,
71 D: KeyDomain,
72{
73 Key(String),
74 Id(K::IdType),
75 UserId(i32),
76 Has(Vec<D>),
77 OneOf(Vec<D>),
78}
79
80impl<K, D> KeySelector<K, D>
81where
82 K: ApiKey,
83 D: KeyDomain,
84{
85 pub(crate) fn fallback(&self) -> Option<Self> {
86 match self {
87 Self::Key(_) | Self::UserId(_) | Self::Id(_) => None,
88 Self::Has(domains) => {
89 let fallbacks: Vec<_> = domains.iter().filter_map(|d| d.fallback()).collect();
90 if fallbacks.is_empty() {
91 None
92 } else {
93 Some(Self::Has(fallbacks))
94 }
95 }
96 Self::OneOf(domains) => {
97 let fallbacks: Vec<_> = domains.iter().filter_map(|d| d.fallback()).collect();
98 if fallbacks.is_empty() {
99 None
100 } else {
101 Some(Self::OneOf(fallbacks))
102 }
103 }
104 }
105 }
106}
107
108pub trait IntoSelector<K, D>: Send + Sync
109where
110 K: ApiKey,
111 D: KeyDomain,
112{
113 fn into_selector(self) -> KeySelector<K, D>;
114}
115
116impl<K, D> IntoSelector<K, D> for D
117where
118 K: ApiKey,
119 D: KeyDomain,
120{
121 fn into_selector(self) -> KeySelector<K, D> {
122 KeySelector::Has(vec![self])
123 }
124}
125
126impl<K, D> IntoSelector<K, D> for KeySelector<K, D>
127where
128 K: ApiKey,
129 D: KeyDomain,
130{
131 fn into_selector(self) -> KeySelector<K, D> {
132 self
133 }
134}
135
136pub enum KeyAction<D>
137where
138 D: KeyDomain,
139{
140 Delete,
141 RemoveDomain(D),
142 Timeout(chrono::Duration),
143}
144
145#[async_trait]
146pub trait KeyPoolStorage {
147 type Key: ApiKey;
148 type Domain: KeyDomain;
149 type Error: std::error::Error + Sync + Send + Clone;
150
151 async fn acquire_key<S>(&self, selector: S) -> Result<Self::Key, Self::Error>
152 where
153 S: IntoSelector<Self::Key, Self::Domain>;
154
155 async fn acquire_many_keys<S>(
156 &self,
157 selector: S,
158 number: i64,
159 ) -> Result<Vec<Self::Key>, Self::Error>
160 where
161 S: IntoSelector<Self::Key, Self::Domain>;
162
163 async fn flag_key(&self, key: Self::Key, code: u8) -> Result<bool, Self::Error>;
164
165 async fn store_key(
166 &self,
167 user_id: i32,
168 key: String,
169 domains: Vec<Self::Domain>,
170 ) -> Result<Self::Key, Self::Error>;
171
172 async fn read_key<S>(&self, selector: S) -> Result<Option<Self::Key>, Self::Error>
173 where
174 S: IntoSelector<Self::Key, Self::Domain>;
175
176 async fn read_keys<S>(&self, selector: S) -> Result<Vec<Self::Key>, Self::Error>
177 where
178 S: IntoSelector<Self::Key, Self::Domain>;
179
180 async fn remove_key<S>(&self, selector: S) -> Result<Self::Key, Self::Error>
181 where
182 S: IntoSelector<Self::Key, Self::Domain>;
183
184 async fn add_domain_to_key<S>(
185 &self,
186 selector: S,
187 domain: Self::Domain,
188 ) -> Result<Self::Key, Self::Error>
189 where
190 S: IntoSelector<Self::Key, Self::Domain>;
191
192 async fn remove_domain_from_key<S>(
193 &self,
194 selector: S,
195 domain: Self::Domain,
196 ) -> Result<Self::Key, Self::Error>
197 where
198 S: IntoSelector<Self::Key, Self::Domain>;
199
200 async fn set_domains_for_key<S>(
201 &self,
202 selector: S,
203 domains: Vec<Self::Domain>,
204 ) -> Result<Self::Key, Self::Error>
205 where
206 S: IntoSelector<Self::Key, Self::Domain>;
207}
208
/// Type-erased values keyed by the [`std::any::TypeId`] they are registered under.
type HookMap = std::collections::HashMap<std::any::TypeId, Box<dyn std::any::Any + Send + Sync>>;

/// Options shared by the executors of a pool.
///
/// The default value has no comment and empty hook maps.
#[derive(Debug, Default)]
pub struct PoolOptions {
    /// Optional comment string. NOTE(review): how it is used is not visible in this file.
    comment: Option<String>,
    /// Type-erased entries; presumably hooks run before a request (TODO confirm).
    hooks_before: HookMap,
    /// Type-erased entries; presumably hooks run after a request (TODO confirm).
    hooks_after: HookMap,
}
215
216#[derive(Debug, Clone)]
217pub struct KeyPoolExecutor<'a, C, S>
218where
219 S: KeyPoolStorage,
220{
221 storage: &'a S,
222 options: Arc<PoolOptions>,
223 selector: KeySelector<S::Key, S::Domain>,
224 _marker: std::marker::PhantomData<C>,
225}
226
227impl<'a, C, S> KeyPoolExecutor<'a, C, S>
228where
229 S: KeyPoolStorage,
230{
231 pub fn new(
232 storage: &'a S,
233 selector: KeySelector<S::Key, S::Domain>,
234 options: Arc<PoolOptions>,
235 ) -> Self {
236 Self {
237 storage,
238 selector,
239 options,
240 _marker: std::marker::PhantomData,
241 }
242 }
243}
244
// Test module, compiled only for test builds with the `postgres` feature enabled.
// It currently contains no tests.
#[cfg(all(test, feature = "postgres"))]
mod test {}