Skip to main content

rustauth_deadpool_postgres/
rate_limit.rs

1use std::sync::Arc;
2
3use deadpool_postgres::Pool;
4use rustauth_core::db::{validate_rate_limit_rule, SqlRateLimitNames};
5use rustauth_core::error::RustAuthError;
6use rustauth_core::options::{
7    RateLimitConsumeInput, RateLimitDecision, RateLimitFuture, RateLimitStore,
8};
9use rustauth_tokio_postgres::driver::{
10    consume_postgres_rate_limit_in_tx, postgres_error, postgres_rate_limit_plan,
11};
12use tokio::sync::Mutex;
13
14use crate::adapter::DeadpoolPostgresAdapter;
15use crate::config::{deadpool_error, pg_client};
16use crate::tx_guard::PooledClientRollbackGuard;
17
18/// Database-backed rate-limit store backed by a `deadpool-postgres` pool.
19#[derive(Clone)]
20pub struct DeadpoolPostgresRateLimitStore {
21    pub(crate) pool: Pool,
22    pub(crate) names: SqlRateLimitNames,
23}
24
25impl std::fmt::Debug for DeadpoolPostgresRateLimitStore {
26    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
27        formatter
28            .debug_struct("DeadpoolPostgresRateLimitStore")
29            .field("names", &self.names)
30            .finish_non_exhaustive()
31    }
32}
33
34impl DeadpoolPostgresRateLimitStore {
35    pub fn new(pool: Pool) -> Self {
36        Self::with_table(pool, "rate_limits")
37    }
38
39    pub fn with_table(pool: Pool, table: impl Into<String>) -> Self {
40        Self::with_names(pool, SqlRateLimitNames::new(table))
41    }
42
43    pub fn with_names(pool: Pool, names: SqlRateLimitNames) -> Self {
44        Self { pool, names }
45    }
46}
47
48impl From<&DeadpoolPostgresAdapter> for DeadpoolPostgresRateLimitStore {
49    fn from(adapter: &DeadpoolPostgresAdapter) -> Self {
50        Self {
51            pool: adapter.pool.clone(),
52            names: SqlRateLimitNames::from_schema(&adapter.schema),
53        }
54    }
55}
56
57impl RateLimitStore for DeadpoolPostgresRateLimitStore {
58    fn consume<'a>(&'a self, input: RateLimitConsumeInput) -> RateLimitFuture<'a> {
59        Box::pin(async move { consume_deadpool_rate_limit(self, input).await })
60    }
61}
62
63async fn consume_deadpool_rate_limit(
64    store: &DeadpoolPostgresRateLimitStore,
65    input: RateLimitConsumeInput,
66) -> Result<RateLimitDecision, RustAuthError> {
67    validate_rate_limit_rule(&input.rule)?;
68    let plan = postgres_rate_limit_plan(
69        &store.names.table,
70        &store.names.key,
71        &store.names.count,
72        &store.names.last_request,
73    )?;
74    let client = store.pool.get().await.map_err(deadpool_error)?;
75    client
76        .batch_execute("BEGIN")
77        .await
78        .map_err(postgres_error)?;
79    let client = Arc::new(Mutex::new(client));
80    let mut guard = PooledClientRollbackGuard::new(Arc::clone(&client));
81    let locked = client.lock().await;
82    let result = consume_postgres_rate_limit_in_tx(pg_client(&locked), &plan, input).await;
83    match result {
84        Ok(decision) => {
85            if let Err(error) = locked.batch_execute("COMMIT").await {
86                let _rollback_result = locked.batch_execute("ROLLBACK").await;
87                guard.disarm();
88                return Err(postgres_error(error));
89            }
90            guard.disarm();
91            Ok(decision)
92        }
93        Err(error) => {
94            let _rollback_result = locked.batch_execute("ROLLBACK").await;
95            guard.disarm();
96            Err(error)
97        }
98    }
99}