rustauth_deadpool_postgres/
rate_limit.rs1use std::sync::Arc;
2
3use deadpool_postgres::Pool;
4use rustauth_core::db::{validate_rate_limit_rule, SqlRateLimitNames};
5use rustauth_core::error::RustAuthError;
6use rustauth_core::options::{
7 RateLimitConsumeInput, RateLimitDecision, RateLimitFuture, RateLimitStore,
8};
9use rustauth_tokio_postgres::driver::{
10 consume_postgres_rate_limit_in_tx, postgres_error, postgres_rate_limit_plan,
11};
12use tokio::sync::Mutex;
13
14use crate::adapter::DeadpoolPostgresAdapter;
15use crate::config::{deadpool_error, pg_client};
16use crate::tx_guard::PooledClientRollbackGuard;
17
18#[derive(Clone)]
20pub struct DeadpoolPostgresRateLimitStore {
21 pub(crate) pool: Pool,
22 pub(crate) names: SqlRateLimitNames,
23}
24
25impl std::fmt::Debug for DeadpoolPostgresRateLimitStore {
26 fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
27 formatter
28 .debug_struct("DeadpoolPostgresRateLimitStore")
29 .field("names", &self.names)
30 .finish_non_exhaustive()
31 }
32}
33
34impl DeadpoolPostgresRateLimitStore {
35 pub fn new(pool: Pool) -> Self {
36 Self::with_table(pool, "rate_limits")
37 }
38
39 pub fn with_table(pool: Pool, table: impl Into<String>) -> Self {
40 Self::with_names(pool, SqlRateLimitNames::new(table))
41 }
42
43 pub fn with_names(pool: Pool, names: SqlRateLimitNames) -> Self {
44 Self { pool, names }
45 }
46}
47
48impl From<&DeadpoolPostgresAdapter> for DeadpoolPostgresRateLimitStore {
49 fn from(adapter: &DeadpoolPostgresAdapter) -> Self {
50 Self {
51 pool: adapter.pool.clone(),
52 names: SqlRateLimitNames::from_schema(&adapter.schema),
53 }
54 }
55}
56
57impl RateLimitStore for DeadpoolPostgresRateLimitStore {
58 fn consume<'a>(&'a self, input: RateLimitConsumeInput) -> RateLimitFuture<'a> {
59 Box::pin(async move { consume_deadpool_rate_limit(self, input).await })
60 }
61}
62
63async fn consume_deadpool_rate_limit(
64 store: &DeadpoolPostgresRateLimitStore,
65 input: RateLimitConsumeInput,
66) -> Result<RateLimitDecision, RustAuthError> {
67 validate_rate_limit_rule(&input.rule)?;
68 let plan = postgres_rate_limit_plan(
69 &store.names.table,
70 &store.names.key,
71 &store.names.count,
72 &store.names.last_request,
73 )?;
74 let client = store.pool.get().await.map_err(deadpool_error)?;
75 client
76 .batch_execute("BEGIN")
77 .await
78 .map_err(postgres_error)?;
79 let client = Arc::new(Mutex::new(client));
80 let mut guard = PooledClientRollbackGuard::new(Arc::clone(&client));
81 let locked = client.lock().await;
82 let result = consume_postgres_rate_limit_in_tx(pg_client(&locked), &plan, input).await;
83 match result {
84 Ok(decision) => {
85 if let Err(error) = locked.batch_execute("COMMIT").await {
86 let _rollback_result = locked.batch_execute("ROLLBACK").await;
87 guard.disarm();
88 return Err(postgres_error(error));
89 }
90 guard.disarm();
91 Ok(decision)
92 }
93 Err(error) => {
94 let _rollback_result = locked.batch_execute("ROLLBACK").await;
95 guard.disarm();
96 Err(error)
97 }
98 }
99}