rustauth_tokio_postgres/
adapter.rs1use std::fmt;
2use std::sync::Arc;
3
4use rustauth_core::db::{
5 auth_schema, AdapterCapabilities, AdapterFuture, AuthSchemaOptions, Count, Create, DbAdapter,
6 DbRecord, DbSchema, Delete, DeleteMany, FindMany, FindOne, SchemaCreation, TransactionCallback,
7 Update, UpdateMany,
8};
9use rustauth_core::error::RustAuthError;
10use tokio_postgres::Client;
11
12use crate::connection::TokioPostgresConnection;
13use crate::driver::PostgresSqlState;
14use crate::errors::postgres_error;
15use crate::rate_limit::TokioPostgresRateLimitStore;
16use crate::schema::{
17 create_schema, execute_migration_plan, plan_migrations as plan_schema_migrations,
18};
19use crate::transaction::TokioPostgresTxAdapter;
20use crate::tx_guard::SharedClientRollbackGuard;
21use rustauth_core::db::SchemaMigrationPlan;
22
23#[derive(Clone)]
24pub struct TokioPostgresAdapter {
25 pub(crate) connection: TokioPostgresConnection,
26 pub(crate) schema: Arc<DbSchema>,
27}
28
29impl fmt::Debug for TokioPostgresAdapter {
30 fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
31 formatter
32 .debug_struct("TokioPostgresAdapter")
33 .field("schema", &self.schema)
34 .finish_non_exhaustive()
35 }
36}
37
38impl TokioPostgresAdapter {
39 pub fn new(client: Client) -> Self {
40 Self::with_schema(client, auth_schema(AuthSchemaOptions::default()))
41 }
42
43 pub fn with_schema(client: Client, schema: DbSchema) -> Self {
44 Self::with_connection(TokioPostgresConnection::from_client(client), schema)
45 }
46
47 pub fn with_connection(connection: TokioPostgresConnection, schema: DbSchema) -> Self {
48 Self {
49 connection,
50 schema: Arc::new(schema),
51 }
52 }
53
54 pub fn connection(&self) -> &TokioPostgresConnection {
56 &self.connection
57 }
58
59 pub fn rate_limit_store(&self) -> TokioPostgresRateLimitStore {
62 TokioPostgresRateLimitStore::from(self)
63 }
64
65 pub async fn connect(database_url: &str) -> Result<Self, RustAuthError> {
67 Self::connect_with_schema(database_url, auth_schema(AuthSchemaOptions::default())).await
68 }
69
70 pub async fn connect_with_schema(
75 database_url: &str,
76 schema: DbSchema,
77 ) -> Result<Self, RustAuthError> {
78 Ok(Self::with_connection(
79 TokioPostgresConnection::connect(database_url).await?,
80 schema,
81 ))
82 }
83
84 pub async fn plan_migrations(
85 &self,
86 schema: &DbSchema,
87 ) -> Result<SchemaMigrationPlan, RustAuthError> {
88 let _gate = self.connection.tx_gate.write().await;
89 plan_schema_migrations(self.connection.client.as_ref(), schema).await
90 }
91
92 pub async fn compile_migrations(&self, schema: &DbSchema) -> Result<String, RustAuthError> {
93 Ok(self.plan_migrations(schema).await?.compile())
94 }
95
96 async fn run_with_state<T>(
97 &self,
98 f: impl for<'a> FnOnce(PostgresSqlState<'a>) -> AdapterFuture<'a, T> + Send,
99 ) -> Result<T, RustAuthError>
100 where
101 T: Send + 'static,
102 {
103 let _gate = self.connection.tx_gate.read().await;
104 f(PostgresSqlState::new(
105 self.schema.as_ref(),
106 self.connection.client.as_ref(),
107 ))
108 .await
109 }
110}
111
112impl DbAdapter for TokioPostgresAdapter {
113 fn id(&self) -> &str {
114 "tokio-postgres"
115 }
116
117 fn capabilities(&self) -> AdapterCapabilities {
118 AdapterCapabilities::new(self.id())
119 .named("tokio-postgres")
120 .with_uuid_ids()
121 .with_json()
122 .with_arrays()
123 .with_native_joins()
124 .with_transactions()
125 }
126
127 fn create<'a>(&'a self, query: Create) -> AdapterFuture<'a, DbRecord> {
128 Box::pin(async move {
129 self.run_with_state(|state| Box::pin(state.create(query)))
130 .await
131 })
132 }
133
134 fn find_one<'a>(&'a self, query: FindOne) -> AdapterFuture<'a, Option<DbRecord>> {
135 Box::pin(async move {
136 self.run_with_state(|state| Box::pin(state.find_one(query)))
137 .await
138 })
139 }
140
141 fn find_many<'a>(&'a self, query: FindMany) -> AdapterFuture<'a, Vec<DbRecord>> {
142 Box::pin(async move {
143 self.run_with_state(|state| Box::pin(state.find_many(query)))
144 .await
145 })
146 }
147
148 fn count<'a>(&'a self, query: Count) -> AdapterFuture<'a, u64> {
149 Box::pin(async move {
150 self.run_with_state(|state| Box::pin(state.count(query)))
151 .await
152 })
153 }
154
155 fn update<'a>(&'a self, query: Update) -> AdapterFuture<'a, Option<DbRecord>> {
156 Box::pin(async move {
157 self.run_with_state(|state| Box::pin(state.update(query)))
158 .await
159 })
160 }
161
162 fn update_many<'a>(&'a self, query: UpdateMany) -> AdapterFuture<'a, u64> {
163 Box::pin(async move {
164 self.run_with_state(|state| Box::pin(state.update_many(query)))
165 .await
166 })
167 }
168
169 fn delete<'a>(&'a self, query: Delete) -> AdapterFuture<'a, ()> {
170 Box::pin(async move {
171 self.run_with_state(|state| Box::pin(state.delete(query)))
172 .await
173 })
174 }
175
176 fn delete_many<'a>(&'a self, query: DeleteMany) -> AdapterFuture<'a, u64> {
177 Box::pin(async move {
178 self.run_with_state(|state| Box::pin(state.delete_many(query)))
179 .await
180 })
181 }
182
183 fn transaction<'a>(&'a self, callback: TransactionCallback<'a>) -> AdapterFuture<'a, ()> {
184 Box::pin(async move {
185 let gate = Arc::clone(&self.connection.tx_gate).write_owned().await;
186 self.connection
187 .client
188 .batch_execute("BEGIN")
189 .await
190 .map_err(postgres_error)?;
191 let mut guard =
192 SharedClientRollbackGuard::new(Arc::clone(&self.connection.client), gate);
193
194 let adapter = TokioPostgresTxAdapter::new(
195 Arc::clone(&self.connection.client),
196 Arc::clone(&self.schema),
197 );
198 let result = callback(Box::new(adapter)).await;
199
200 match result {
201 Ok(()) => {
202 if let Err(error) = self.connection.client.batch_execute("COMMIT").await {
203 let _rollback_result =
204 self.connection.client.batch_execute("ROLLBACK").await;
205 guard.disarm();
206 return Err(postgres_error(error));
207 }
208 guard.disarm();
209 Ok(())
210 }
211 Err(error) => {
212 let _rollback_result = self.connection.client.batch_execute("ROLLBACK").await;
213 guard.disarm();
214 Err(error)
215 }
216 }
217 })
218 }
219
220 fn create_schema<'a>(
221 &'a self,
222 schema: &'a DbSchema,
223 _file: Option<&'a str>,
224 ) -> AdapterFuture<'a, Option<SchemaCreation>> {
225 Box::pin(async move {
226 let _gate = self.connection.tx_gate.write().await;
227 create_schema(self.connection.client.as_ref(), schema).await?;
228 Ok(None)
229 })
230 }
231
232 fn run_migrations<'a>(&'a self, schema: &'a DbSchema) -> AdapterFuture<'a, ()> {
233 Box::pin(async move {
234 let _gate = self.connection.tx_gate.write().await;
235 execute_migration_plan(self.connection.client.as_ref(), schema).await
236 })
237 }
238}