torn_key_pool/
lib.rs

1#![warn(clippy::all, clippy::perf, clippy::style, clippy::suspicious)]
2
3#[cfg(feature = "postgres")]
4pub mod postgres;
5
6// pub mod local;
7pub mod send;
8
9use std::sync::Arc;
10
11use async_trait::async_trait;
12use thiserror::Error;
13
14use torn_api::ResponseError;
15
16#[derive(Debug, Error)]
17pub enum KeyPoolError<S, C>
18where
19    S: std::error::Error + Clone,
20    C: std::error::Error,
21{
22    #[error("Key pool storage driver error: {0:?}")]
23    Storage(#[source] S),
24
25    #[error(transparent)]
26    Client(#[from] C),
27
28    #[error(transparent)]
29    Response(ResponseError),
30}
31
32impl<S, C> KeyPoolError<S, C>
33where
34    S: std::error::Error + Clone,
35    C: std::error::Error,
36{
37    #[inline(always)]
38    pub fn api_code(&self) -> Option<u8> {
39        match self {
40            Self::Response(why) => why.api_code(),
41            _ => None,
42        }
43    }
44}
45
46pub trait ApiKey: Sync + Send + std::fmt::Debug + Clone + 'static {
47    type IdType: PartialEq + Eq + std::hash::Hash + Send + Sync + std::fmt::Debug + Clone;
48
49    fn value(&self) -> &str;
50
51    fn id(&self) -> Self::IdType;
52
53    fn selector<D>(&self) -> KeySelector<Self, D>
54    where
55        D: KeyDomain,
56    {
57        KeySelector::Id(self.id())
58    }
59}
60
/// A domain that keys can be associated with and selected by.
pub trait KeyDomain: Send + Sync + Clone + std::fmt::Debug + 'static {
    /// Returns the domain to fall back to, if any.
    ///
    /// The default implementation has no fallback and returns `None`.
    fn fallback(&self) -> Option<Self> {
        Option::None
    }
}
66
67#[derive(Debug, Clone)]
68pub enum KeySelector<K, D>
69where
70    K: ApiKey,
71    D: KeyDomain,
72{
73    Key(String),
74    Id(K::IdType),
75    UserId(i32),
76    Has(Vec<D>),
77    OneOf(Vec<D>),
78}
79
80impl<K, D> KeySelector<K, D>
81where
82    K: ApiKey,
83    D: KeyDomain,
84{
85    pub(crate) fn fallback(&self) -> Option<Self> {
86        match self {
87            Self::Key(_) | Self::UserId(_) | Self::Id(_) => None,
88            Self::Has(domains) => {
89                let fallbacks: Vec<_> = domains.iter().filter_map(|d| d.fallback()).collect();
90                if fallbacks.is_empty() {
91                    None
92                } else {
93                    Some(Self::Has(fallbacks))
94                }
95            }
96            Self::OneOf(domains) => {
97                let fallbacks: Vec<_> = domains.iter().filter_map(|d| d.fallback()).collect();
98                if fallbacks.is_empty() {
99                    None
100                } else {
101                    Some(Self::OneOf(fallbacks))
102                }
103            }
104        }
105    }
106}
107
108pub trait IntoSelector<K, D>: Send + Sync
109where
110    K: ApiKey,
111    D: KeyDomain,
112{
113    fn into_selector(self) -> KeySelector<K, D>;
114}
115
116impl<K, D> IntoSelector<K, D> for D
117where
118    K: ApiKey,
119    D: KeyDomain,
120{
121    fn into_selector(self) -> KeySelector<K, D> {
122        KeySelector::Has(vec![self])
123    }
124}
125
126impl<K, D> IntoSelector<K, D> for KeySelector<K, D>
127where
128    K: ApiKey,
129    D: KeyDomain,
130{
131    fn into_selector(self) -> KeySelector<K, D> {
132        self
133    }
134}
135
136pub enum KeyAction<D>
137where
138    D: KeyDomain,
139{
140    Delete,
141    RemoveDomain(D),
142    Timeout(chrono::Duration),
143}
144
145#[async_trait]
146pub trait KeyPoolStorage {
147    type Key: ApiKey;
148    type Domain: KeyDomain;
149    type Error: std::error::Error + Sync + Send + Clone;
150
151    async fn acquire_key<S>(&self, selector: S) -> Result<Self::Key, Self::Error>
152    where
153        S: IntoSelector<Self::Key, Self::Domain>;
154
155    async fn acquire_many_keys<S>(
156        &self,
157        selector: S,
158        number: i64,
159    ) -> Result<Vec<Self::Key>, Self::Error>
160    where
161        S: IntoSelector<Self::Key, Self::Domain>;
162
163    async fn flag_key(&self, key: Self::Key, code: u8) -> Result<bool, Self::Error>;
164
165    async fn store_key(
166        &self,
167        user_id: i32,
168        key: String,
169        domains: Vec<Self::Domain>,
170    ) -> Result<Self::Key, Self::Error>;
171
172    async fn read_key<S>(&self, selector: S) -> Result<Option<Self::Key>, Self::Error>
173    where
174        S: IntoSelector<Self::Key, Self::Domain>;
175
176    async fn read_keys<S>(&self, selector: S) -> Result<Vec<Self::Key>, Self::Error>
177    where
178        S: IntoSelector<Self::Key, Self::Domain>;
179
180    async fn remove_key<S>(&self, selector: S) -> Result<Self::Key, Self::Error>
181    where
182        S: IntoSelector<Self::Key, Self::Domain>;
183
184    async fn add_domain_to_key<S>(
185        &self,
186        selector: S,
187        domain: Self::Domain,
188    ) -> Result<Self::Key, Self::Error>
189    where
190        S: IntoSelector<Self::Key, Self::Domain>;
191
192    async fn remove_domain_from_key<S>(
193        &self,
194        selector: S,
195        domain: Self::Domain,
196    ) -> Result<Self::Key, Self::Error>
197    where
198        S: IntoSelector<Self::Key, Self::Domain>;
199
200    async fn set_domains_for_key<S>(
201        &self,
202        selector: S,
203        domains: Vec<Self::Domain>,
204    ) -> Result<Self::Key, Self::Error>
205    where
206        S: IntoSelector<Self::Key, Self::Domain>;
207}
208
/// Type-erased values keyed by a `TypeId`.
type HookMap = std::collections::HashMap<std::any::TypeId, Box<dyn std::any::Any + Send + Sync>>;

/// Options of a key pool.
///
/// The default value has no comment and empty hook maps.
#[derive(Debug, Default)]
pub struct PoolOptions {
    /// Optional comment string.
    comment: Option<String>,
    /// Type-erased hook values keyed by `TypeId`; presumably hooks that run
    /// before a request (TODO confirm against the code that reads this map).
    hooks_before: HookMap,
    /// Type-erased hook values keyed by `TypeId`; presumably hooks that run
    /// after a request (TODO confirm against the code that reads this map).
    hooks_after: HookMap,
}
215
216#[derive(Debug, Clone)]
217pub struct KeyPoolExecutor<'a, C, S>
218where
219    S: KeyPoolStorage,
220{
221    storage: &'a S,
222    options: Arc<PoolOptions>,
223    selector: KeySelector<S::Key, S::Domain>,
224    _marker: std::marker::PhantomData<C>,
225}
226
227impl<'a, C, S> KeyPoolExecutor<'a, C, S>
228where
229    S: KeyPoolStorage,
230{
231    pub fn new(
232        storage: &'a S,
233        selector: KeySelector<S::Key, S::Domain>,
234        options: Arc<PoolOptions>,
235    ) -> Self {
236        Self {
237            storage,
238            selector,
239            options,
240            _marker: std::marker::PhantomData,
241        }
242    }
243}
244
// Currently empty test module; it is only compiled for test builds that have
// the `postgres` feature enabled.
#[cfg(all(feature = "postgres", test))]
mod test {}