1use indexmap::IndexMap;
2
3use super::{
4 AdapterCapabilities, Count, Create, DbField, DbFieldType, DbRecord, DbSchema, DbTable, DbValue,
5 Delete, DeleteMany, FindMany, FindOne, JoinConfig, JoinOption, JoinRelation, JoinResolution,
6 Sort, Update, UpdateMany, Where,
7};
8use crate::error::RustAuthError;
9
10pub fn transform_create_query(schema: &DbSchema, query: Create) -> Result<Create, RustAuthError> {
11 transform_create_query_with_capabilities(schema, &AdapterCapabilities::new("core"), query)
12}
13
14pub fn transform_create_query_with_capabilities(
15 schema: &DbSchema,
16 capabilities: &AdapterCapabilities,
17 query: Create,
18) -> Result<Create, RustAuthError> {
19 let model = schema.table_name(&query.model)?.to_owned();
20 let data = transform_record(schema, capabilities, &query.model, query.data)?;
21 let select = transform_select(schema, &query.model, query.select)?;
22
23 Ok(Create {
24 model,
25 data,
26 select,
27 force_allow_id: query.force_allow_id,
28 })
29}
30
31pub fn transform_find_one_query(
32 schema: &DbSchema,
33 query: FindOne,
34) -> Result<FindOne, RustAuthError> {
35 transform_find_one_query_with_capabilities(schema, &AdapterCapabilities::new("core"), query)
36}
37
38pub fn transform_find_one_query_with_capabilities(
39 schema: &DbSchema,
40 capabilities: &AdapterCapabilities,
41 query: FindOne,
42) -> Result<FindOne, RustAuthError> {
43 let model = schema.table_name(&query.model)?.to_owned();
44 let where_clauses =
45 transform_where_clauses(schema, capabilities, &query.model, query.where_clauses)?;
46 let select = transform_select(schema, &query.model, query.select)?;
47
48 Ok(FindOne {
49 model,
50 where_clauses,
51 select,
52 joins: query.joins,
53 })
54}
55
56pub fn transform_find_many_query(
57 schema: &DbSchema,
58 query: FindMany,
59) -> Result<FindMany, RustAuthError> {
60 transform_find_many_query_with_capabilities(schema, &AdapterCapabilities::new("core"), query)
61}
62
63pub fn transform_find_many_query_with_capabilities(
64 schema: &DbSchema,
65 capabilities: &AdapterCapabilities,
66 query: FindMany,
67) -> Result<FindMany, RustAuthError> {
68 let model = schema.table_name(&query.model)?.to_owned();
69 let where_clauses =
70 transform_where_clauses(schema, capabilities, &query.model, query.where_clauses)?;
71 let sort_by = query
72 .sort_by
73 .map(|sort| transform_sort(schema, &query.model, sort))
74 .transpose()?;
75 let select = transform_select(schema, &query.model, query.select)?;
76
77 Ok(FindMany {
78 model,
79 where_clauses,
80 limit: query.limit,
81 offset: query.offset,
82 sort_by,
83 select,
84 joins: query.joins,
85 })
86}
87
88pub fn transform_count_query(schema: &DbSchema, query: Count) -> Result<Count, RustAuthError> {
89 transform_count_query_with_capabilities(schema, &AdapterCapabilities::new("core"), query)
90}
91
92pub fn transform_count_query_with_capabilities(
93 schema: &DbSchema,
94 capabilities: &AdapterCapabilities,
95 query: Count,
96) -> Result<Count, RustAuthError> {
97 let model = schema.table_name(&query.model)?.to_owned();
98 let where_clauses =
99 transform_where_clauses(schema, capabilities, &query.model, query.where_clauses)?;
100
101 Ok(Count {
102 model,
103 where_clauses,
104 })
105}
106
107pub fn transform_update_query(schema: &DbSchema, query: Update) -> Result<Update, RustAuthError> {
108 transform_update_query_with_capabilities(schema, &AdapterCapabilities::new("core"), query)
109}
110
111pub fn transform_update_query_with_capabilities(
112 schema: &DbSchema,
113 capabilities: &AdapterCapabilities,
114 query: Update,
115) -> Result<Update, RustAuthError> {
116 let model = schema.table_name(&query.model)?.to_owned();
117 let where_clauses =
118 transform_where_clauses(schema, capabilities, &query.model, query.where_clauses)?;
119 let data = transform_record(schema, capabilities, &query.model, query.data)?;
120
121 Ok(Update {
122 model,
123 where_clauses,
124 data,
125 })
126}
127
128pub fn transform_update_many_query(
129 schema: &DbSchema,
130 query: UpdateMany,
131) -> Result<UpdateMany, RustAuthError> {
132 transform_update_many_query_with_capabilities(schema, &AdapterCapabilities::new("core"), query)
133}
134
135pub fn transform_update_many_query_with_capabilities(
136 schema: &DbSchema,
137 capabilities: &AdapterCapabilities,
138 query: UpdateMany,
139) -> Result<UpdateMany, RustAuthError> {
140 let model = schema.table_name(&query.model)?.to_owned();
141 let where_clauses =
142 transform_where_clauses(schema, capabilities, &query.model, query.where_clauses)?;
143 let data = transform_record(schema, capabilities, &query.model, query.data)?;
144
145 Ok(UpdateMany {
146 model,
147 where_clauses,
148 data,
149 })
150}
151
152pub fn transform_delete_query(schema: &DbSchema, query: Delete) -> Result<Delete, RustAuthError> {
153 transform_delete_query_with_capabilities(schema, &AdapterCapabilities::new("core"), query)
154}
155
156pub fn transform_delete_query_with_capabilities(
157 schema: &DbSchema,
158 capabilities: &AdapterCapabilities,
159 query: Delete,
160) -> Result<Delete, RustAuthError> {
161 let model = schema.table_name(&query.model)?.to_owned();
162 let where_clauses =
163 transform_where_clauses(schema, capabilities, &query.model, query.where_clauses)?;
164
165 Ok(Delete {
166 model,
167 where_clauses,
168 })
169}
170
171pub fn transform_delete_many_query(
172 schema: &DbSchema,
173 query: DeleteMany,
174) -> Result<DeleteMany, RustAuthError> {
175 transform_delete_many_query_with_capabilities(schema, &AdapterCapabilities::new("core"), query)
176}
177
178pub fn transform_delete_many_query_with_capabilities(
179 schema: &DbSchema,
180 capabilities: &AdapterCapabilities,
181 query: DeleteMany,
182) -> Result<DeleteMany, RustAuthError> {
183 let model = schema.table_name(&query.model)?.to_owned();
184 let where_clauses =
185 transform_where_clauses(schema, capabilities, &query.model, query.where_clauses)?;
186
187 Ok(DeleteMany {
188 model,
189 where_clauses,
190 })
191}
192
193pub fn resolve_join_options(
194 schema: &DbSchema,
195 base_model: &str,
196 joins: IndexMap<String, JoinOption>,
197 select: Vec<String>,
198 default_limit: usize,
199) -> Result<JoinResolution, RustAuthError> {
200 let base_table = schema
201 .table(base_model)
202 .ok_or_else(|| RustAuthError::TableNotFound {
203 table: base_model.to_owned(),
204 })?;
205 let mut resolution = JoinResolution::new(select);
206
207 for (join_model, option) in joins {
208 if !option.enabled {
209 continue;
210 }
211
212 let join_table = schema
213 .table(&join_model)
214 .ok_or_else(|| RustAuthError::TableNotFound {
215 table: join_model.clone(),
216 })?;
217 let resolved = resolve_join_config(
218 schema,
219 base_model,
220 base_table,
221 &join_model,
222 join_table,
223 option,
224 default_limit,
225 )?;
226
227 if !resolution.select.is_empty() && !resolution.select.contains(&resolved.required_select) {
228 resolution.select.push(resolved.required_select);
229 }
230 resolution
231 .joins
232 .insert(join_table.name.clone(), resolved.config);
233 }
234
235 Ok(resolution)
236}
237
238struct ResolvedJoinConfig {
239 config: JoinConfig,
240 required_select: String,
241}
242
243fn resolve_join_config(
244 schema: &DbSchema,
245 base_model: &str,
246 base_table: &DbTable,
247 join_model: &str,
248 join_table: &DbTable,
249 option: JoinOption,
250 default_limit: usize,
251) -> Result<ResolvedJoinConfig, RustAuthError> {
252 let mut foreign_keys = foreign_keys_to_table(join_table, &base_table.name);
253 let is_forward_join = !foreign_keys.is_empty();
254
255 if foreign_keys.is_empty() {
256 foreign_keys = foreign_keys_to_table(base_table, &join_table.name);
257 }
258
259 let [(foreign_key, field)] =
260 foreign_keys
261 .as_slice()
262 .try_into()
263 .map_err(|_| match foreign_keys.len() {
264 0 => RustAuthError::JoinForeignKeyNotFound {
265 base_model: base_model.to_owned(),
266 join_model: join_model.to_owned(),
267 },
268 _ => RustAuthError::JoinForeignKeyAmbiguous {
269 base_model: base_model.to_owned(),
270 join_model: join_model.to_owned(),
271 },
272 })?;
273 let reference =
274 field
275 .foreign_key
276 .as_ref()
277 .ok_or_else(|| RustAuthError::JoinForeignKeyNotFound {
278 base_model: base_model.to_owned(),
279 join_model: join_model.to_owned(),
280 })?;
281
282 let (from, to, required_select, relation_field) = if is_forward_join {
283 let from = schema.field_name(base_model, &reference.field)?.to_owned();
284 let to = schema.field_name(join_model, foreign_key)?.to_owned();
285 let required_select = from.clone();
286 (from, to, required_select, field)
287 } else {
288 let from = schema.field_name(base_model, foreign_key)?.to_owned();
289 let to = schema.field_name(join_model, &reference.field)?.to_owned();
290 (from.clone(), to, from, field)
291 };
292
293 let is_unique = to == "id" || relation_field.unique;
294 let limit = if is_unique {
295 1
296 } else {
297 option.limit.unwrap_or(default_limit)
298 };
299 let relation = if is_unique {
300 JoinRelation::OneToOne
301 } else {
302 JoinRelation::OneToMany
303 };
304
305 Ok(ResolvedJoinConfig {
306 config: JoinConfig::new(from, to).limit(limit).relation(relation),
307 required_select,
308 })
309}
310
311fn foreign_keys_to_table<'a>(
312 table: &'a DbTable,
313 target_table: &str,
314) -> Vec<(&'a str, &'a DbField)> {
315 table
316 .fields
317 .iter()
318 .filter_map(|(logical_name, field)| {
319 field
320 .foreign_key
321 .as_ref()
322 .filter(|foreign_key| foreign_key.table == target_table)
323 .map(|_| (logical_name.as_str(), field))
324 })
325 .collect()
326}
327
328fn transform_record(
329 schema: &DbSchema,
330 capabilities: &AdapterCapabilities,
331 model: &str,
332 record: DbRecord,
333) -> Result<DbRecord, RustAuthError> {
334 record
335 .into_iter()
336 .map(|(field, value)| {
337 let field_metadata = schema.field(model, &field)?;
338 let value = transform_value(capabilities, field_metadata, value);
339 Ok((field_metadata.name.clone(), value))
340 })
341 .collect::<Result<IndexMap<_, _>, _>>()
342}
343
344fn transform_select(
345 schema: &DbSchema,
346 model: &str,
347 select: Vec<String>,
348) -> Result<Vec<String>, RustAuthError> {
349 select
350 .into_iter()
351 .map(|field| {
352 schema
353 .field_name(model, &field)
354 .map(|field_name| field_name.to_owned())
355 })
356 .collect()
357}
358
359fn transform_where_clauses(
360 schema: &DbSchema,
361 capabilities: &AdapterCapabilities,
362 model: &str,
363 where_clauses: Vec<Where>,
364) -> Result<Vec<Where>, RustAuthError> {
365 where_clauses
366 .into_iter()
367 .map(|where_clause| transform_where_clause(schema, capabilities, model, where_clause))
368 .collect()
369}
370
371fn transform_where_clause(
372 schema: &DbSchema,
373 capabilities: &AdapterCapabilities,
374 model: &str,
375 where_clause: Where,
376) -> Result<Where, RustAuthError> {
377 let field_metadata = schema.field(model, &where_clause.field)?;
378 let value = transform_value(capabilities, field_metadata, where_clause.value);
379
380 Ok(Where {
381 field: field_metadata.name.clone(),
382 value,
383 operator: where_clause.operator,
384 connector: where_clause.connector,
385 mode: where_clause.mode,
386 })
387}
388
389fn transform_sort(schema: &DbSchema, model: &str, sort: Sort) -> Result<Sort, RustAuthError> {
390 let field = schema.field_name(model, &sort.field)?.to_owned();
391
392 Ok(Sort {
393 field,
394 direction: sort.direction,
395 })
396}
397
398fn transform_value(capabilities: &AdapterCapabilities, field: &DbField, value: DbValue) -> DbValue {
399 match (&field.field_type, value) {
400 (DbFieldType::Boolean, DbValue::String(value)) => {
401 transform_value(capabilities, field, DbValue::Boolean(value == "true"))
402 }
403 (DbFieldType::Boolean, DbValue::Boolean(value)) if !capabilities.supports_booleans => {
404 DbValue::Number(i64::from(value))
405 }
406 (DbFieldType::Number, DbValue::String(value)) => value
407 .parse::<i64>()
408 .map(DbValue::Number)
409 .unwrap_or(DbValue::String(value)),
410 (DbFieldType::Timestamp, DbValue::Timestamp(value)) if !capabilities.supports_dates => {
411 DbValue::String(value.to_string())
412 }
413 (DbFieldType::Json, DbValue::Json(value)) if !capabilities.supports_json => {
414 DbValue::String(value.to_string())
415 }
416 (DbFieldType::StringArray, DbValue::StringArray(value))
417 if !capabilities.supports_arrays =>
418 {
419 let value = value.into_iter().map(serde_json::Value::String).collect();
420 DbValue::String(serde_json::Value::Array(value).to_string())
421 }
422 (DbFieldType::NumberArray, DbValue::NumberArray(value))
423 if !capabilities.supports_arrays =>
424 {
425 let value = value
426 .into_iter()
427 .map(|number| serde_json::Value::Number(number.into()))
428 .collect();
429 DbValue::String(serde_json::Value::Array(value).to_string())
430 }
431 (_, value) => value,
432 }
433}