Skip to main content

rustauth_tokio_postgres/
rate_limit.rs

1use std::fmt;
2use std::sync::Arc;
3
4use rustauth_core::db::{validate_rate_limit_rule, SqlRateLimitNames};
5use rustauth_core::error::RustAuthError;
6use rustauth_core::options::{
7    RateLimitConsumeInput, RateLimitDecision, RateLimitFuture, RateLimitStore,
8};
9
10use crate::adapter::TokioPostgresAdapter;
11use crate::connection::TokioPostgresConnection;
12use crate::driver::{consume_postgres_rate_limit_in_tx, postgres_error, postgres_rate_limit_plan};
13use crate::tx_guard::SharedClientRollbackGuard;
14
15#[derive(Clone)]
16pub struct TokioPostgresRateLimitStore {
17    connection: TokioPostgresConnection,
18    names: SqlRateLimitNames,
19}
20
21impl fmt::Debug for TokioPostgresRateLimitStore {
22    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
23        formatter
24            .debug_struct("TokioPostgresRateLimitStore")
25            .field("names", &self.names)
26            .finish_non_exhaustive()
27    }
28}
29
30impl TokioPostgresRateLimitStore {
31    /// Builds a rate-limit store from a shared connection bundle.
32    pub fn from_connection(connection: &TokioPostgresConnection, table: impl Into<String>) -> Self {
33        Self {
34            connection: connection.clone(),
35            names: SqlRateLimitNames::new(table),
36        }
37    }
38
39    /// Connects for rate-limit-only usage when no [`TokioPostgresAdapter`] is needed.
40    pub async fn connect(
41        database_url: &str,
42        table: impl Into<String>,
43    ) -> Result<Self, RustAuthError> {
44        Ok(Self::from_connection(
45            &TokioPostgresConnection::connect(database_url).await?,
46            table,
47        ))
48    }
49
50    /// Returns the shared connection used by this store.
51    pub fn connection(&self) -> &TokioPostgresConnection {
52        &self.connection
53    }
54}
55
56impl From<&TokioPostgresAdapter> for TokioPostgresRateLimitStore {
57    fn from(adapter: &TokioPostgresAdapter) -> Self {
58        Self {
59            connection: adapter.connection.clone(),
60            names: SqlRateLimitNames::from_schema(&adapter.schema),
61        }
62    }
63}
64
65impl RateLimitStore for TokioPostgresRateLimitStore {
66    fn consume<'a>(&'a self, input: RateLimitConsumeInput) -> RateLimitFuture<'a> {
67        Box::pin(async move { consume_postgres_rate_limit(self, input).await })
68    }
69}
70
71async fn consume_postgres_rate_limit(
72    store: &TokioPostgresRateLimitStore,
73    input: RateLimitConsumeInput,
74) -> Result<RateLimitDecision, RustAuthError> {
75    validate_rate_limit_rule(&input.rule)?;
76    let plan = postgres_rate_limit_plan(
77        &store.names.table,
78        &store.names.key,
79        &store.names.count,
80        &store.names.last_request,
81    )?;
82    let gate = Arc::clone(&store.connection.tx_gate).write_owned().await;
83    store
84        .connection
85        .client
86        .batch_execute("BEGIN")
87        .await
88        .map_err(postgres_error)?;
89    let mut guard = SharedClientRollbackGuard::new(Arc::clone(&store.connection.client), gate);
90    let result =
91        consume_postgres_rate_limit_in_tx(store.connection.client.as_ref(), &plan, input).await;
92    match result {
93        Ok(decision) => {
94            if let Err(error) = store.connection.client.batch_execute("COMMIT").await {
95                let _rollback_result = store.connection.client.batch_execute("ROLLBACK").await;
96                guard.disarm();
97                return Err(postgres_error(error));
98            }
99            guard.disarm();
100            Ok(decision)
101        }
102        Err(error) => {
103            let _rollback_result = store.connection.client.batch_execute("ROLLBACK").await;
104            guard.disarm();
105            Err(error)
106        }
107    }
108}