rustauth_tokio_postgres/
rate_limit.rs1use std::fmt;
2use std::sync::Arc;
3
4use rustauth_core::db::{validate_rate_limit_rule, SqlRateLimitNames};
5use rustauth_core::error::RustAuthError;
6use rustauth_core::options::{
7 RateLimitConsumeInput, RateLimitDecision, RateLimitFuture, RateLimitStore,
8};
9
10use crate::adapter::TokioPostgresAdapter;
11use crate::connection::TokioPostgresConnection;
12use crate::driver::{consume_postgres_rate_limit_in_tx, postgres_error, postgres_rate_limit_plan};
13use crate::tx_guard::SharedClientRollbackGuard;
14
15#[derive(Clone)]
16pub struct TokioPostgresRateLimitStore {
17 connection: TokioPostgresConnection,
18 names: SqlRateLimitNames,
19}
20
21impl fmt::Debug for TokioPostgresRateLimitStore {
22 fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
23 formatter
24 .debug_struct("TokioPostgresRateLimitStore")
25 .field("names", &self.names)
26 .finish_non_exhaustive()
27 }
28}
29
30impl TokioPostgresRateLimitStore {
31 pub fn from_connection(connection: &TokioPostgresConnection, table: impl Into<String>) -> Self {
33 Self {
34 connection: connection.clone(),
35 names: SqlRateLimitNames::new(table),
36 }
37 }
38
39 pub async fn connect(
41 database_url: &str,
42 table: impl Into<String>,
43 ) -> Result<Self, RustAuthError> {
44 Ok(Self::from_connection(
45 &TokioPostgresConnection::connect(database_url).await?,
46 table,
47 ))
48 }
49
50 pub fn connection(&self) -> &TokioPostgresConnection {
52 &self.connection
53 }
54}
55
56impl From<&TokioPostgresAdapter> for TokioPostgresRateLimitStore {
57 fn from(adapter: &TokioPostgresAdapter) -> Self {
58 Self {
59 connection: adapter.connection.clone(),
60 names: SqlRateLimitNames::from_schema(&adapter.schema),
61 }
62 }
63}
64
65impl RateLimitStore for TokioPostgresRateLimitStore {
66 fn consume<'a>(&'a self, input: RateLimitConsumeInput) -> RateLimitFuture<'a> {
67 Box::pin(async move { consume_postgres_rate_limit(self, input).await })
68 }
69}
70
71async fn consume_postgres_rate_limit(
72 store: &TokioPostgresRateLimitStore,
73 input: RateLimitConsumeInput,
74) -> Result<RateLimitDecision, RustAuthError> {
75 validate_rate_limit_rule(&input.rule)?;
76 let plan = postgres_rate_limit_plan(
77 &store.names.table,
78 &store.names.key,
79 &store.names.count,
80 &store.names.last_request,
81 )?;
82 let gate = Arc::clone(&store.connection.tx_gate).write_owned().await;
83 store
84 .connection
85 .client
86 .batch_execute("BEGIN")
87 .await
88 .map_err(postgres_error)?;
89 let mut guard = SharedClientRollbackGuard::new(Arc::clone(&store.connection.client), gate);
90 let result =
91 consume_postgres_rate_limit_in_tx(store.connection.client.as_ref(), &plan, input).await;
92 match result {
93 Ok(decision) => {
94 if let Err(error) = store.connection.client.batch_execute("COMMIT").await {
95 let _rollback_result = store.connection.client.batch_execute("ROLLBACK").await;
96 guard.disarm();
97 return Err(postgres_error(error));
98 }
99 guard.disarm();
100 Ok(decision)
101 }
102 Err(error) => {
103 let _rollback_result = store.connection.client.batch_execute("ROLLBACK").await;
104 guard.disarm();
105 Err(error)
106 }
107 }
108}