Skip to main content

rustauth_core/db/
transform.rs

1use indexmap::IndexMap;
2
3use super::{
4    AdapterCapabilities, Count, Create, DbField, DbFieldType, DbRecord, DbSchema, DbTable, DbValue,
5    Delete, DeleteMany, FindMany, FindOne, JoinConfig, JoinOption, JoinRelation, JoinResolution,
6    Sort, Update, UpdateMany, Where,
7};
8use crate::error::RustAuthError;
9
10pub fn transform_create_query(schema: &DbSchema, query: Create) -> Result<Create, RustAuthError> {
11    transform_create_query_with_capabilities(schema, &AdapterCapabilities::new("core"), query)
12}
13
14pub fn transform_create_query_with_capabilities(
15    schema: &DbSchema,
16    capabilities: &AdapterCapabilities,
17    query: Create,
18) -> Result<Create, RustAuthError> {
19    let model = schema.table_name(&query.model)?.to_owned();
20    let data = transform_record(schema, capabilities, &query.model, query.data)?;
21    let select = transform_select(schema, &query.model, query.select)?;
22
23    Ok(Create {
24        model,
25        data,
26        select,
27        force_allow_id: query.force_allow_id,
28    })
29}
30
31pub fn transform_find_one_query(
32    schema: &DbSchema,
33    query: FindOne,
34) -> Result<FindOne, RustAuthError> {
35    transform_find_one_query_with_capabilities(schema, &AdapterCapabilities::new("core"), query)
36}
37
38pub fn transform_find_one_query_with_capabilities(
39    schema: &DbSchema,
40    capabilities: &AdapterCapabilities,
41    query: FindOne,
42) -> Result<FindOne, RustAuthError> {
43    let model = schema.table_name(&query.model)?.to_owned();
44    let where_clauses =
45        transform_where_clauses(schema, capabilities, &query.model, query.where_clauses)?;
46    let select = transform_select(schema, &query.model, query.select)?;
47
48    Ok(FindOne {
49        model,
50        where_clauses,
51        select,
52        joins: query.joins,
53    })
54}
55
56pub fn transform_find_many_query(
57    schema: &DbSchema,
58    query: FindMany,
59) -> Result<FindMany, RustAuthError> {
60    transform_find_many_query_with_capabilities(schema, &AdapterCapabilities::new("core"), query)
61}
62
63pub fn transform_find_many_query_with_capabilities(
64    schema: &DbSchema,
65    capabilities: &AdapterCapabilities,
66    query: FindMany,
67) -> Result<FindMany, RustAuthError> {
68    let model = schema.table_name(&query.model)?.to_owned();
69    let where_clauses =
70        transform_where_clauses(schema, capabilities, &query.model, query.where_clauses)?;
71    let sort_by = query
72        .sort_by
73        .map(|sort| transform_sort(schema, &query.model, sort))
74        .transpose()?;
75    let select = transform_select(schema, &query.model, query.select)?;
76
77    Ok(FindMany {
78        model,
79        where_clauses,
80        limit: query.limit,
81        offset: query.offset,
82        sort_by,
83        select,
84        joins: query.joins,
85    })
86}
87
88pub fn transform_count_query(schema: &DbSchema, query: Count) -> Result<Count, RustAuthError> {
89    transform_count_query_with_capabilities(schema, &AdapterCapabilities::new("core"), query)
90}
91
92pub fn transform_count_query_with_capabilities(
93    schema: &DbSchema,
94    capabilities: &AdapterCapabilities,
95    query: Count,
96) -> Result<Count, RustAuthError> {
97    let model = schema.table_name(&query.model)?.to_owned();
98    let where_clauses =
99        transform_where_clauses(schema, capabilities, &query.model, query.where_clauses)?;
100
101    Ok(Count {
102        model,
103        where_clauses,
104    })
105}
106
107pub fn transform_update_query(schema: &DbSchema, query: Update) -> Result<Update, RustAuthError> {
108    transform_update_query_with_capabilities(schema, &AdapterCapabilities::new("core"), query)
109}
110
111pub fn transform_update_query_with_capabilities(
112    schema: &DbSchema,
113    capabilities: &AdapterCapabilities,
114    query: Update,
115) -> Result<Update, RustAuthError> {
116    let model = schema.table_name(&query.model)?.to_owned();
117    let where_clauses =
118        transform_where_clauses(schema, capabilities, &query.model, query.where_clauses)?;
119    let data = transform_record(schema, capabilities, &query.model, query.data)?;
120
121    Ok(Update {
122        model,
123        where_clauses,
124        data,
125    })
126}
127
128pub fn transform_update_many_query(
129    schema: &DbSchema,
130    query: UpdateMany,
131) -> Result<UpdateMany, RustAuthError> {
132    transform_update_many_query_with_capabilities(schema, &AdapterCapabilities::new("core"), query)
133}
134
135pub fn transform_update_many_query_with_capabilities(
136    schema: &DbSchema,
137    capabilities: &AdapterCapabilities,
138    query: UpdateMany,
139) -> Result<UpdateMany, RustAuthError> {
140    let model = schema.table_name(&query.model)?.to_owned();
141    let where_clauses =
142        transform_where_clauses(schema, capabilities, &query.model, query.where_clauses)?;
143    let data = transform_record(schema, capabilities, &query.model, query.data)?;
144
145    Ok(UpdateMany {
146        model,
147        where_clauses,
148        data,
149    })
150}
151
152pub fn transform_delete_query(schema: &DbSchema, query: Delete) -> Result<Delete, RustAuthError> {
153    transform_delete_query_with_capabilities(schema, &AdapterCapabilities::new("core"), query)
154}
155
156pub fn transform_delete_query_with_capabilities(
157    schema: &DbSchema,
158    capabilities: &AdapterCapabilities,
159    query: Delete,
160) -> Result<Delete, RustAuthError> {
161    let model = schema.table_name(&query.model)?.to_owned();
162    let where_clauses =
163        transform_where_clauses(schema, capabilities, &query.model, query.where_clauses)?;
164
165    Ok(Delete {
166        model,
167        where_clauses,
168    })
169}
170
171pub fn transform_delete_many_query(
172    schema: &DbSchema,
173    query: DeleteMany,
174) -> Result<DeleteMany, RustAuthError> {
175    transform_delete_many_query_with_capabilities(schema, &AdapterCapabilities::new("core"), query)
176}
177
178pub fn transform_delete_many_query_with_capabilities(
179    schema: &DbSchema,
180    capabilities: &AdapterCapabilities,
181    query: DeleteMany,
182) -> Result<DeleteMany, RustAuthError> {
183    let model = schema.table_name(&query.model)?.to_owned();
184    let where_clauses =
185        transform_where_clauses(schema, capabilities, &query.model, query.where_clauses)?;
186
187    Ok(DeleteMany {
188        model,
189        where_clauses,
190    })
191}
192
193pub fn resolve_join_options(
194    schema: &DbSchema,
195    base_model: &str,
196    joins: IndexMap<String, JoinOption>,
197    select: Vec<String>,
198    default_limit: usize,
199) -> Result<JoinResolution, RustAuthError> {
200    let base_table = schema
201        .table(base_model)
202        .ok_or_else(|| RustAuthError::TableNotFound {
203            table: base_model.to_owned(),
204        })?;
205    let mut resolution = JoinResolution::new(select);
206
207    for (join_model, option) in joins {
208        if !option.enabled {
209            continue;
210        }
211
212        let join_table = schema
213            .table(&join_model)
214            .ok_or_else(|| RustAuthError::TableNotFound {
215                table: join_model.clone(),
216            })?;
217        let resolved = resolve_join_config(
218            schema,
219            base_model,
220            base_table,
221            &join_model,
222            join_table,
223            option,
224            default_limit,
225        )?;
226
227        if !resolution.select.is_empty() && !resolution.select.contains(&resolved.required_select) {
228            resolution.select.push(resolved.required_select);
229        }
230        resolution
231            .joins
232            .insert(join_table.name.clone(), resolved.config);
233    }
234
235    Ok(resolution)
236}
237
238struct ResolvedJoinConfig {
239    config: JoinConfig,
240    required_select: String,
241}
242
243fn resolve_join_config(
244    schema: &DbSchema,
245    base_model: &str,
246    base_table: &DbTable,
247    join_model: &str,
248    join_table: &DbTable,
249    option: JoinOption,
250    default_limit: usize,
251) -> Result<ResolvedJoinConfig, RustAuthError> {
252    let mut foreign_keys = foreign_keys_to_table(join_table, &base_table.name);
253    let is_forward_join = !foreign_keys.is_empty();
254
255    if foreign_keys.is_empty() {
256        foreign_keys = foreign_keys_to_table(base_table, &join_table.name);
257    }
258
259    let [(foreign_key, field)] =
260        foreign_keys
261            .as_slice()
262            .try_into()
263            .map_err(|_| match foreign_keys.len() {
264                0 => RustAuthError::JoinForeignKeyNotFound {
265                    base_model: base_model.to_owned(),
266                    join_model: join_model.to_owned(),
267                },
268                _ => RustAuthError::JoinForeignKeyAmbiguous {
269                    base_model: base_model.to_owned(),
270                    join_model: join_model.to_owned(),
271                },
272            })?;
273    let reference =
274        field
275            .foreign_key
276            .as_ref()
277            .ok_or_else(|| RustAuthError::JoinForeignKeyNotFound {
278                base_model: base_model.to_owned(),
279                join_model: join_model.to_owned(),
280            })?;
281
282    let (from, to, required_select, relation_field) = if is_forward_join {
283        let from = schema.field_name(base_model, &reference.field)?.to_owned();
284        let to = schema.field_name(join_model, foreign_key)?.to_owned();
285        let required_select = from.clone();
286        (from, to, required_select, field)
287    } else {
288        let from = schema.field_name(base_model, foreign_key)?.to_owned();
289        let to = schema.field_name(join_model, &reference.field)?.to_owned();
290        (from.clone(), to, from, field)
291    };
292
293    let is_unique = to == "id" || relation_field.unique;
294    let limit = if is_unique {
295        1
296    } else {
297        option.limit.unwrap_or(default_limit)
298    };
299    let relation = if is_unique {
300        JoinRelation::OneToOne
301    } else {
302        JoinRelation::OneToMany
303    };
304
305    Ok(ResolvedJoinConfig {
306        config: JoinConfig::new(from, to).limit(limit).relation(relation),
307        required_select,
308    })
309}
310
311fn foreign_keys_to_table<'a>(
312    table: &'a DbTable,
313    target_table: &str,
314) -> Vec<(&'a str, &'a DbField)> {
315    table
316        .fields
317        .iter()
318        .filter_map(|(logical_name, field)| {
319            field
320                .foreign_key
321                .as_ref()
322                .filter(|foreign_key| foreign_key.table == target_table)
323                .map(|_| (logical_name.as_str(), field))
324        })
325        .collect()
326}
327
328fn transform_record(
329    schema: &DbSchema,
330    capabilities: &AdapterCapabilities,
331    model: &str,
332    record: DbRecord,
333) -> Result<DbRecord, RustAuthError> {
334    record
335        .into_iter()
336        .map(|(field, value)| {
337            let field_metadata = schema.field(model, &field)?;
338            let value = transform_value(capabilities, field_metadata, value);
339            Ok((field_metadata.name.clone(), value))
340        })
341        .collect::<Result<IndexMap<_, _>, _>>()
342}
343
344fn transform_select(
345    schema: &DbSchema,
346    model: &str,
347    select: Vec<String>,
348) -> Result<Vec<String>, RustAuthError> {
349    select
350        .into_iter()
351        .map(|field| {
352            schema
353                .field_name(model, &field)
354                .map(|field_name| field_name.to_owned())
355        })
356        .collect()
357}
358
359fn transform_where_clauses(
360    schema: &DbSchema,
361    capabilities: &AdapterCapabilities,
362    model: &str,
363    where_clauses: Vec<Where>,
364) -> Result<Vec<Where>, RustAuthError> {
365    where_clauses
366        .into_iter()
367        .map(|where_clause| transform_where_clause(schema, capabilities, model, where_clause))
368        .collect()
369}
370
371fn transform_where_clause(
372    schema: &DbSchema,
373    capabilities: &AdapterCapabilities,
374    model: &str,
375    where_clause: Where,
376) -> Result<Where, RustAuthError> {
377    let field_metadata = schema.field(model, &where_clause.field)?;
378    let value = transform_value(capabilities, field_metadata, where_clause.value);
379
380    Ok(Where {
381        field: field_metadata.name.clone(),
382        value,
383        operator: where_clause.operator,
384        connector: where_clause.connector,
385        mode: where_clause.mode,
386    })
387}
388
389fn transform_sort(schema: &DbSchema, model: &str, sort: Sort) -> Result<Sort, RustAuthError> {
390    let field = schema.field_name(model, &sort.field)?.to_owned();
391
392    Ok(Sort {
393        field,
394        direction: sort.direction,
395    })
396}
397
398fn transform_value(capabilities: &AdapterCapabilities, field: &DbField, value: DbValue) -> DbValue {
399    match (&field.field_type, value) {
400        (DbFieldType::Boolean, DbValue::String(value)) => {
401            transform_value(capabilities, field, DbValue::Boolean(value == "true"))
402        }
403        (DbFieldType::Boolean, DbValue::Boolean(value)) if !capabilities.supports_booleans => {
404            DbValue::Number(i64::from(value))
405        }
406        (DbFieldType::Number, DbValue::String(value)) => value
407            .parse::<i64>()
408            .map(DbValue::Number)
409            .unwrap_or(DbValue::String(value)),
410        (DbFieldType::Timestamp, DbValue::Timestamp(value)) if !capabilities.supports_dates => {
411            DbValue::String(value.to_string())
412        }
413        (DbFieldType::Json, DbValue::Json(value)) if !capabilities.supports_json => {
414            DbValue::String(value.to_string())
415        }
416        (DbFieldType::StringArray, DbValue::StringArray(value))
417            if !capabilities.supports_arrays =>
418        {
419            let value = value.into_iter().map(serde_json::Value::String).collect();
420            DbValue::String(serde_json::Value::Array(value).to_string())
421        }
422        (DbFieldType::NumberArray, DbValue::NumberArray(value))
423            if !capabilities.supports_arrays =>
424        {
425            let value = value
426                .into_iter()
427                .map(|number| serde_json::Value::Number(number.into()))
428                .collect();
429            DbValue::String(serde_json::Value::Array(value).to_string())
430        }
431        (_, value) => value,
432    }
433}