Skip to main content

rustauth_core/db/
factory.rs

1mod join_support;
2
3use super::{
4    transform_count_query_with_capabilities, transform_create_query_with_capabilities,
5    transform_delete_many_query_with_capabilities, transform_delete_query_with_capabilities,
6    transform_find_many_query_with_capabilities, transform_find_one_query_with_capabilities,
7    transform_update_many_query_with_capabilities, transform_update_query_with_capabilities,
8    AdapterCapabilities, AdapterFuture, Count, Create, DbAdapter, DbRecord, DbSchema, Delete,
9    DeleteMany, FindMany, FindOne, SchemaCreation, TransactionCallback, Update, UpdateMany,
10};
11use crate::error::RustAuthError;
12use join_support::{
13    attach_joins, extend_select_for_joins, resolve_fallback_joins, trim_joined_record,
14};
15use std::sync::Arc;
16
17/// Adapter wrapper that maps RustAuth logical schema names to database names.
18#[derive(Debug, Clone)]
19pub struct SchemaAdapter<A> {
20    schema: DbSchema,
21    inner: A,
22}
23
24impl<A> SchemaAdapter<A> {
25    pub fn new(schema: DbSchema, inner: A) -> Self {
26        Self { schema, inner }
27    }
28
29    pub fn schema(&self) -> &DbSchema {
30        &self.schema
31    }
32
33    pub fn inner(&self) -> &A {
34        &self.inner
35    }
36}
37
38impl<A> DbAdapter for SchemaAdapter<A>
39where
40    A: DbAdapter,
41{
42    fn id(&self) -> &str {
43        self.inner.id()
44    }
45
46    fn capabilities(&self) -> AdapterCapabilities {
47        self.inner.capabilities()
48    }
49
50    fn create<'a>(&'a self, query: Create) -> AdapterFuture<'a, DbRecord> {
51        Box::pin(async move {
52            let capabilities = self.inner.capabilities();
53            let query =
54                transform_create_query_with_capabilities(&self.schema, &capabilities, query)?;
55            self.inner.create(query).await
56        })
57    }
58
59    fn find_one<'a>(&'a self, query: FindOne) -> AdapterFuture<'a, Option<DbRecord>> {
60        Box::pin(async move {
61            let capabilities = self.inner.capabilities();
62            let query =
63                transform_find_one_query_with_capabilities(&self.schema, &capabilities, query)?;
64            self.inner.find_one(query).await
65        })
66    }
67
68    fn find_many<'a>(&'a self, query: FindMany) -> AdapterFuture<'a, Vec<DbRecord>> {
69        Box::pin(async move {
70            let capabilities = self.inner.capabilities();
71            let query =
72                transform_find_many_query_with_capabilities(&self.schema, &capabilities, query)?;
73            self.inner.find_many(query).await
74        })
75    }
76
77    fn count<'a>(&'a self, query: Count) -> AdapterFuture<'a, u64> {
78        Box::pin(async move {
79            let capabilities = self.inner.capabilities();
80            let query =
81                transform_count_query_with_capabilities(&self.schema, &capabilities, query)?;
82            self.inner.count(query).await
83        })
84    }
85
86    fn update<'a>(&'a self, query: Update) -> AdapterFuture<'a, Option<DbRecord>> {
87        Box::pin(async move {
88            let capabilities = self.inner.capabilities();
89            let query =
90                transform_update_query_with_capabilities(&self.schema, &capabilities, query)?;
91            self.inner.update(query).await
92        })
93    }
94
95    fn update_many<'a>(&'a self, query: UpdateMany) -> AdapterFuture<'a, u64> {
96        Box::pin(async move {
97            let capabilities = self.inner.capabilities();
98            let query =
99                transform_update_many_query_with_capabilities(&self.schema, &capabilities, query)?;
100            self.inner.update_many(query).await
101        })
102    }
103
104    fn delete<'a>(&'a self, query: Delete) -> AdapterFuture<'a, ()> {
105        Box::pin(async move {
106            let capabilities = self.inner.capabilities();
107            let query =
108                transform_delete_query_with_capabilities(&self.schema, &capabilities, query)?;
109            self.inner.delete(query).await
110        })
111    }
112
113    fn delete_many<'a>(&'a self, query: DeleteMany) -> AdapterFuture<'a, u64> {
114        Box::pin(async move {
115            let capabilities = self.inner.capabilities();
116            let query =
117                transform_delete_many_query_with_capabilities(&self.schema, &capabilities, query)?;
118            self.inner.delete_many(query).await
119        })
120    }
121
122    fn transaction<'a>(&'a self, callback: TransactionCallback<'a>) -> AdapterFuture<'a, ()> {
123        let schema = self.schema.clone();
124        self.inner.transaction(Box::new(move |transaction| {
125            let adapter = SchemaAdapter::new(schema, transaction);
126            callback(Box::new(adapter))
127        }))
128    }
129
130    fn create_schema<'a>(
131        &'a self,
132        _schema: &'a DbSchema,
133        file: Option<&'a str>,
134    ) -> AdapterFuture<'a, Option<SchemaCreation>> {
135        self.inner.create_schema(&self.schema, file)
136    }
137
138    fn run_migrations<'a>(&'a self, _schema: &'a DbSchema) -> AdapterFuture<'a, ()> {
139        self.inner.run_migrations(&self.schema)
140    }
141}
142
143/// Adapter wrapper that resolves RustAuth join options at runtime.
144#[derive(Clone)]
145pub struct JoinAdapter<A = Arc<dyn DbAdapter>> {
146    schema: DbSchema,
147    inner: A,
148    experimental_joins: bool,
149}
150
151impl<A> JoinAdapter<A> {
152    pub fn new(schema: DbSchema, inner: A, experimental_joins: bool) -> Self {
153        Self {
154            schema,
155            inner,
156            experimental_joins,
157        }
158    }
159}
160
161impl<A> JoinAdapter<A>
162where
163    A: DbAdapter,
164{
165    fn should_delegate_joins(&self) -> bool {
166        let caps = self.inner.capabilities();
167        if caps.supports_native_joins {
168            return true;
169        }
170        self.experimental_joins && caps.supports_joins
171    }
172
173    async fn fallback_find_one(
174        &self,
175        mut query: FindOne,
176    ) -> Result<Option<DbRecord>, RustAuthError> {
177        let joins = resolve_fallback_joins(&self.schema, &query.model, &query.joins, 100)?;
178        let original_select = query.select.clone();
179        extend_select_for_joins(&mut query.select, &joins);
180        query.joins.clear();
181
182        let Some(mut record) = self.inner.find_one(query).await? else {
183            return Ok(None);
184        };
185        attach_joins(&self.inner, &mut [&mut record], &joins).await?;
186        trim_joined_record(&mut record, &original_select, &joins);
187        Ok(Some(record))
188    }
189
190    async fn fallback_find_many(
191        &self,
192        mut query: FindMany,
193    ) -> Result<Vec<DbRecord>, RustAuthError> {
194        let joins = resolve_fallback_joins(&self.schema, &query.model, &query.joins, 100)?;
195        let original_select = query.select.clone();
196        extend_select_for_joins(&mut query.select, &joins);
197        query.joins.clear();
198
199        let mut records = self.inner.find_many(query).await?;
200        let mut record_refs = records.iter_mut().collect::<Vec<_>>();
201        attach_joins(&self.inner, &mut record_refs, &joins).await?;
202        for record in &mut records {
203            trim_joined_record(record, &original_select, &joins);
204        }
205        Ok(records)
206    }
207}
208
209impl<A> std::fmt::Debug for JoinAdapter<A>
210where
211    A: DbAdapter,
212{
213    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
214        formatter
215            .debug_struct("JoinAdapter")
216            .field("schema", &self.schema)
217            .field("inner", &self.inner.id())
218            .field("experimental_joins", &self.experimental_joins)
219            .finish()
220    }
221}
222
223impl<A> DbAdapter for JoinAdapter<A>
224where
225    A: DbAdapter,
226{
227    fn id(&self) -> &str {
228        self.inner.id()
229    }
230
231    fn capabilities(&self) -> AdapterCapabilities {
232        self.inner.capabilities()
233    }
234
235    fn create<'a>(&'a self, query: Create) -> AdapterFuture<'a, DbRecord> {
236        self.inner.create(query)
237    }
238
239    fn find_one<'a>(&'a self, query: FindOne) -> AdapterFuture<'a, Option<DbRecord>> {
240        Box::pin(async move {
241            if query.joins.is_empty() || self.should_delegate_joins() {
242                self.inner.find_one(query).await
243            } else {
244                self.fallback_find_one(query).await
245            }
246        })
247    }
248
249    fn find_many<'a>(&'a self, query: FindMany) -> AdapterFuture<'a, Vec<DbRecord>> {
250        Box::pin(async move {
251            if query.joins.is_empty() || self.should_delegate_joins() {
252                self.inner.find_many(query).await
253            } else {
254                self.fallback_find_many(query).await
255            }
256        })
257    }
258
259    fn count<'a>(&'a self, query: Count) -> AdapterFuture<'a, u64> {
260        self.inner.count(query)
261    }
262
263    fn update<'a>(&'a self, query: Update) -> AdapterFuture<'a, Option<DbRecord>> {
264        self.inner.update(query)
265    }
266
267    fn update_many<'a>(&'a self, query: UpdateMany) -> AdapterFuture<'a, u64> {
268        self.inner.update_many(query)
269    }
270
271    fn delete<'a>(&'a self, query: Delete) -> AdapterFuture<'a, ()> {
272        self.inner.delete(query)
273    }
274
275    fn delete_many<'a>(&'a self, query: DeleteMany) -> AdapterFuture<'a, u64> {
276        self.inner.delete_many(query)
277    }
278
279    fn transaction<'a>(&'a self, callback: TransactionCallback<'a>) -> AdapterFuture<'a, ()> {
280        let schema = self.schema.clone();
281        let experimental_joins = self.experimental_joins;
282        self.inner.transaction(Box::new(move |transaction| {
283            let adapter = JoinAdapter::new(schema, transaction, experimental_joins);
284            callback(Box::new(adapter))
285        }))
286    }
287
288    fn create_schema<'a>(
289        &'a self,
290        schema: &'a DbSchema,
291        file: Option<&'a str>,
292    ) -> AdapterFuture<'a, Option<SchemaCreation>> {
293        self.inner.create_schema(schema, file)
294    }
295
296    fn run_migrations<'a>(&'a self, schema: &'a DbSchema) -> AdapterFuture<'a, ()> {
297        self.inner.run_migrations(schema)
298    }
299}