rustauth_deadpool_postgres/
adapter.rs1use std::sync::Arc;
2
3use deadpool_postgres::Pool;
4use rustauth_core::db::SchemaMigrationPlan;
5use rustauth_core::db::{
6 auth_schema, AdapterCapabilities, AdapterFuture, AuthSchemaOptions, Count, Create, DbAdapter,
7 DbRecord, DbSchema, Delete, DeleteMany, FindMany, FindOne, SchemaCreation, TransactionCallback,
8 Update, UpdateMany,
9};
10use rustauth_core::error::RustAuthError;
11use rustauth_tokio_postgres::driver::{postgres_error, PostgresSqlState};
12use tokio::sync::Mutex;
13
14use crate::builder::DeadpoolPostgresBuilder;
15use crate::config::{deadpool_error, pg_client};
16use crate::transaction::DeadpoolPostgresTxAdapter;
17use crate::tx_guard::PooledClientRollbackGuard;
18
19#[derive(Clone)]
21pub struct DeadpoolPostgresAdapter {
22 pub(crate) pool: Pool,
23 pub(crate) schema: Arc<DbSchema>,
24}
25
26impl std::fmt::Debug for DeadpoolPostgresAdapter {
27 fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
28 formatter
29 .debug_struct("DeadpoolPostgresAdapter")
30 .field("schema", &self.schema)
31 .finish_non_exhaustive()
32 }
33}
34
35impl DeadpoolPostgresAdapter {
36 pub fn new(pool: Pool) -> Self {
37 Self::with_schema(pool, auth_schema(AuthSchemaOptions::default()))
38 }
39
40 pub fn with_schema(pool: Pool, schema: DbSchema) -> Self {
41 Self {
42 pool,
43 schema: Arc::new(schema),
44 }
45 }
46
47 pub fn pool(&self) -> &Pool {
48 &self.pool
49 }
50
51 pub fn builder() -> DeadpoolPostgresBuilder {
52 DeadpoolPostgresBuilder::new()
53 }
54
55 pub async fn plan_migrations(
56 &self,
57 schema: &DbSchema,
58 ) -> Result<SchemaMigrationPlan, RustAuthError> {
59 let client = self.pool.get().await.map_err(deadpool_error)?;
60 rustauth_tokio_postgres::driver::plan_migrations(pg_client(&client), schema).await
61 }
62
63 pub async fn validate_connection(&self) -> Result<(), RustAuthError> {
64 let client = self.pool.get().await.map_err(deadpool_error)?;
65 client
66 .simple_query("SELECT 1")
67 .await
68 .map_err(postgres_error)?;
69 Ok(())
70 }
71
72 pub async fn compile_migrations(&self, schema: &DbSchema) -> Result<String, RustAuthError> {
73 Ok(self.plan_migrations(schema).await?.compile())
74 }
75
76 async fn run_with_state<T>(
77 &self,
78 f: impl for<'a> FnOnce(PostgresSqlState<'a>) -> AdapterFuture<'a, T> + Send,
79 ) -> Result<T, RustAuthError>
80 where
81 T: Send + 'static,
82 {
83 let client = self.pool.get().await.map_err(deadpool_error)?;
84 f(PostgresSqlState::new(
85 self.schema.as_ref(),
86 pg_client(&client),
87 ))
88 .await
89 }
90}
91
92impl DbAdapter for DeadpoolPostgresAdapter {
93 fn id(&self) -> &str {
94 "deadpool-postgres"
95 }
96
97 fn capabilities(&self) -> AdapterCapabilities {
98 AdapterCapabilities::new(self.id())
99 .named("deadpool-postgres")
100 .with_uuid_ids()
101 .with_json()
102 .with_arrays()
103 .with_native_joins()
104 .with_transactions()
105 }
106
107 fn create<'a>(&'a self, query: Create) -> AdapterFuture<'a, DbRecord> {
108 Box::pin(async move {
109 self.run_with_state(|state| Box::pin(state.create(query)))
110 .await
111 })
112 }
113
114 fn find_one<'a>(&'a self, query: FindOne) -> AdapterFuture<'a, Option<DbRecord>> {
115 Box::pin(async move {
116 self.run_with_state(|state| Box::pin(state.find_one(query)))
117 .await
118 })
119 }
120
121 fn find_many<'a>(&'a self, query: FindMany) -> AdapterFuture<'a, Vec<DbRecord>> {
122 Box::pin(async move {
123 self.run_with_state(|state| Box::pin(state.find_many(query)))
124 .await
125 })
126 }
127
128 fn count<'a>(&'a self, query: Count) -> AdapterFuture<'a, u64> {
129 Box::pin(async move {
130 self.run_with_state(|state| Box::pin(state.count(query)))
131 .await
132 })
133 }
134
135 fn update<'a>(&'a self, query: Update) -> AdapterFuture<'a, Option<DbRecord>> {
136 Box::pin(async move {
137 self.run_with_state(|state| Box::pin(state.update(query)))
138 .await
139 })
140 }
141
142 fn update_many<'a>(&'a self, query: UpdateMany) -> AdapterFuture<'a, u64> {
143 Box::pin(async move {
144 self.run_with_state(|state| Box::pin(state.update_many(query)))
145 .await
146 })
147 }
148
149 fn delete<'a>(&'a self, query: Delete) -> AdapterFuture<'a, ()> {
150 Box::pin(async move {
151 self.run_with_state(|state| Box::pin(state.delete(query)))
152 .await
153 })
154 }
155
156 fn delete_many<'a>(&'a self, query: DeleteMany) -> AdapterFuture<'a, u64> {
157 Box::pin(async move {
158 self.run_with_state(|state| Box::pin(state.delete_many(query)))
159 .await
160 })
161 }
162
163 fn transaction<'a>(&'a self, callback: TransactionCallback<'a>) -> AdapterFuture<'a, ()> {
164 Box::pin(async move {
165 let client = self.pool.get().await.map_err(deadpool_error)?;
166 client
167 .batch_execute("BEGIN")
168 .await
169 .map_err(postgres_error)?;
170 let client = Arc::new(Mutex::new(client));
171 let mut guard = PooledClientRollbackGuard::new(Arc::clone(&client));
172 let adapter = DeadpoolPostgresTxAdapter {
173 client: Arc::clone(&client),
174 schema: Arc::clone(&self.schema),
175 };
176 let result = callback(Box::new(adapter)).await;
177
178 let locked = client.lock().await;
179 match result {
180 Ok(()) => {
181 if let Err(error) = locked.batch_execute("COMMIT").await {
182 let _rollback_result = locked.batch_execute("ROLLBACK").await;
183 guard.disarm();
184 return Err(postgres_error(error));
185 }
186 guard.disarm();
187 Ok(())
188 }
189 Err(error) => {
190 let _rollback_result = locked.batch_execute("ROLLBACK").await;
191 guard.disarm();
192 Err(error)
193 }
194 }
195 })
196 }
197
198 fn create_schema<'a>(
199 &'a self,
200 schema: &'a DbSchema,
201 _file: Option<&'a str>,
202 ) -> AdapterFuture<'a, Option<SchemaCreation>> {
203 Box::pin(async move {
204 let client = self.pool.get().await.map_err(deadpool_error)?;
205 rustauth_tokio_postgres::driver::create_schema(pg_client(&client), schema).await?;
206 Ok(None)
207 })
208 }
209
210 fn run_migrations<'a>(&'a self, schema: &'a DbSchema) -> AdapterFuture<'a, ()> {
211 Box::pin(async move {
212 let client = self.pool.get().await.map_err(deadpool_error)?;
213 rustauth_tokio_postgres::driver::execute_migration_plan(pg_client(&client), schema)
214 .await
215 })
216 }
217}