1use super::*;
2
3pub fn create_statement(
4 dialect: SqlDialect,
5 schema: &DbSchema,
6 query: &Create,
7) -> Result<SqlStatement, RustAuthError> {
8 let table = resolve_table(schema, &query.model)?;
9 let selection = selected_fields(table, &query.select)?;
10 let mut columns = Vec::new();
11 let mut placeholders = Vec::new();
12 let mut params = Vec::new();
13
14 for (field, value) in &query.data {
15 let (_, metadata) = resolve_field(table, field)?;
16 columns.push(dialect.quote_identifier(&metadata.name)?);
17 params.push(SqlParam::new(metadata, value.clone()));
18 placeholders.push(dialect.placeholder(params.len()));
19 }
20
21 let mut sql = if columns.is_empty() {
22 match dialect {
23 SqlDialect::Postgres | SqlDialect::Sqlite => format!(
24 "INSERT INTO {} DEFAULT VALUES",
25 dialect.quote_identifier(&table.name)?
26 ),
27 SqlDialect::MySql => format!(
28 "INSERT INTO {} () VALUES ()",
29 dialect.quote_identifier(&table.name)?
30 ),
31 }
32 } else {
33 format!(
34 "INSERT INTO {} ({}) VALUES ({})",
35 dialect.quote_identifier(&table.name)?,
36 columns.join(", "),
37 placeholders.join(", ")
38 )
39 };
40 if dialect.supports_insert_returning() && table_has_database_generated_id(table) {
41 sql.push_str(" RETURNING ");
42 sql.push_str(
43 &selection
44 .iter()
45 .map(|selected| dialect.quote_identifier(&selected.field.name))
46 .collect::<Result<Vec<_>, _>>()?
47 .join(", "),
48 );
49 }
50
51 Ok(SqlStatement { sql, params })
52}
53
54pub fn create_returning_selection(
55 schema: &DbSchema,
56 query: &Create,
57) -> Result<Vec<SqlSelectedField>, RustAuthError> {
58 selected_fields(resolve_table(schema, &query.model)?, &query.select)
59}
60
61pub fn find_one_statement(
62 dialect: SqlDialect,
63 schema: &DbSchema,
64 query: &FindOne,
65) -> Result<SqlReadStatement, RustAuthError> {
66 let mut find_many = FindMany::new(query.model.clone());
67 find_many.where_clauses = query.where_clauses.clone();
68 find_many.limit = Some(1);
69 find_many.select = query.select.clone();
70 find_many.joins = query.joins.clone();
71 find_many_statement(dialect, schema, &find_many)
72}
73
74pub fn find_many_statement(
75 dialect: SqlDialect,
76 schema: &DbSchema,
77 query: &FindMany,
78) -> Result<SqlReadStatement, RustAuthError> {
79 let table = resolve_table(schema, &query.model)?;
80 let selection = selected_fields(table, &query.select)?;
81 let where_sql = dialect.where_clause(table, &query.where_clauses)?;
82 let sql = format!(
83 "SELECT {} FROM {}{}{}",
84 selection
85 .iter()
86 .map(|selected| dialect.quote_identifier(&selected.field.name))
87 .collect::<Result<Vec<_>, _>>()?
88 .join(", "),
89 dialect.quote_identifier(&table.name)?,
90 where_sql.sql,
91 dialect.order_limit_offset(table, query.sort_by.as_ref(), query.limit, query.offset)?
92 );
93
94 Ok(SqlReadStatement {
95 statement: SqlStatement {
96 sql,
97 params: where_sql.params,
98 },
99 selection,
100 })
101}
102
103pub fn find_many_with_joins_statement<'a>(
104 dialect: SqlDialect,
105 schema: &'a DbSchema,
106 query: &FindMany,
107) -> Result<SqlJoinReadStatement<'a>, RustAuthError> {
108 let (base_logical, table) = resolve_table_with_logical(schema, &query.model)?;
109 let joins = resolve_native_joins(schema, base_logical, table, &query.joins, 100)?;
110 let base_selection = internal_base_selection(table, &query.select, &joins)?;
111 let where_sql = dialect.where_clause(table, &query.where_clauses)?;
112 let base_columns = base_selection
113 .iter()
114 .map(|(_, field)| dialect.quote_identifier(&field.name))
115 .collect::<Result<Vec<_>, _>>()?;
116 let base_sql = format!(
117 "SELECT {} FROM {}{}{}",
118 base_columns.join(", "),
119 dialect.quote_identifier(&table.name)?,
120 where_sql.sql,
121 dialect.order_limit_offset(table, query.sort_by.as_ref(), query.limit, query.offset)?
122 );
123
124 let mut selects = vec![format!(
125 "{}.{} AS {}",
126 dialect.quote_identifier("base")?,
127 dialect.quote_identifier(&resolve_field_from_selection(&base_selection, "id")?.name)?,
128 dialect.quote_identifier("__base_id")?
129 )];
130 for (index, (_, field)) in base_selection.iter().enumerate() {
131 selects.push(format!(
132 "{}.{} AS {}",
133 dialect.quote_identifier("base")?,
134 dialect.quote_identifier(&field.name)?,
135 dialect.quote_identifier(&base_alias(index))?
136 ));
137 }
138 for (join_index, join) in joins.iter().enumerate() {
139 for (field_index, (_, field)) in join.selection.iter().enumerate() {
140 selects.push(format!(
141 "{}.{} AS {}",
142 dialect.quote_identifier(&join_alias(join_index))?,
143 dialect.quote_identifier(&field.name)?,
144 dialect.quote_identifier(&join_field_alias(join_index, field_index))?
145 ));
146 }
147 }
148
149 let mut sql = format!(
150 "SELECT {} FROM ({}) AS {}",
151 selects.join(", "),
152 base_sql,
153 dialect.quote_identifier("base")?
154 );
155 for (index, join) in joins.iter().enumerate() {
156 sql.push_str(" LEFT JOIN ");
157 sql.push_str(&dialect.quote_identifier(&join.table.name)?);
158 sql.push_str(" AS ");
159 sql.push_str(&dialect.quote_identifier(&join_alias(index))?);
160 sql.push_str(" ON ");
161 sql.push_str(&dialect.quote_identifier(&join_alias(index))?);
162 sql.push('.');
163 sql.push_str(&dialect.quote_identifier(&join.to)?);
164 sql.push_str(" = ");
165 sql.push_str(&dialect.quote_identifier("base")?);
166 sql.push('.');
167 sql.push_str(&dialect.quote_identifier(&join.from)?);
168 }
169
170 Ok(SqlJoinReadStatement {
171 statement: SqlStatement {
172 sql,
173 params: where_sql.params,
174 },
175 base_selection,
176 joins,
177 })
178}
179
180pub fn count_statement(
181 dialect: SqlDialect,
182 schema: &DbSchema,
183 query: &Count,
184) -> Result<SqlStatement, RustAuthError> {
185 let table = resolve_table(schema, &query.model)?;
186 let where_sql = dialect.where_clause(table, &query.where_clauses)?;
187 Ok(SqlStatement {
188 sql: format!(
189 "SELECT COUNT(*) FROM {}{}",
190 dialect.quote_identifier(&table.name)?,
191 where_sql.sql
192 ),
193 params: where_sql.params,
194 })
195}
196
197pub fn update_one_plan(
198 dialect: SqlDialect,
199 schema: &DbSchema,
200 query: &Update,
201) -> Result<SqlUpdateOnePlan, RustAuthError> {
202 let table = resolve_table(schema, &query.model)?;
203 let selection = selected_fields(table, &[])?;
204
205 match dialect {
206 SqlDialect::Postgres | SqlDialect::Sqlite => {
207 let assignment = update_assignment(dialect, table, &query.data, 1)?;
208 let where_sql =
209 dialect.where_clause_starting_at(table, &query.where_clauses, assignment.next)?;
210 let row_id = match dialect {
211 SqlDialect::Postgres => "ctid",
212 SqlDialect::Sqlite => "rowid",
213 SqlDialect::MySql => {
214 return Err(RustAuthError::Adapter(
215 "mysql update-one uses a preselect plan".to_owned(),
216 ));
217 }
218 };
219 let mut params = assignment.params;
220 params.extend(where_sql.params);
221 Ok(SqlUpdateOnePlan::Returning(SqlReadStatement {
222 statement: SqlStatement {
223 sql: format!(
224 "UPDATE {} SET {} WHERE {row_id} IN (SELECT {row_id} FROM {}{} LIMIT 1) RETURNING {}",
225 dialect.quote_identifier(&table.name)?,
226 assignment.sql.join(", "),
227 dialect.quote_identifier(&table.name)?,
228 where_sql.sql,
229 selection
230 .iter()
231 .map(|selected| dialect.quote_identifier(&selected.field.name))
232 .collect::<Result<Vec<_>, _>>()?
233 .join(", ")
234 ),
235 params,
236 },
237 selection,
238 }))
239 }
240 SqlDialect::MySql => {
241 let mut select_query = FindMany::new(query.model.clone());
242 select_query.where_clauses = query.where_clauses.clone();
243 select_query.limit = Some(1);
244 let select = find_many_statement(dialect, schema, &select_query)?;
245 let assignment = update_assignment(dialect, table, &query.data, 1)?;
246 let where_sql =
247 dialect.where_clause_starting_at(table, &query.where_clauses, assignment.next)?;
248 let mut params = assignment.params;
249 params.extend(where_sql.params);
250 Ok(SqlUpdateOnePlan::PreselectThenUpdate {
251 select,
252 update: SqlStatement {
253 sql: format!(
254 "UPDATE {} SET {}{} LIMIT 1",
255 dialect.quote_identifier(&table.name)?,
256 assignment.sql.join(", "),
257 where_sql.sql
258 ),
259 params,
260 },
261 data: query.data.clone(),
262 })
263 }
264 }
265}
266
267pub fn update_many_statement(
268 dialect: SqlDialect,
269 schema: &DbSchema,
270 query: &UpdateMany,
271) -> Result<SqlStatement, RustAuthError> {
272 let table = resolve_table(schema, &query.model)?;
273 let assignment = update_assignment(dialect, table, &query.data, 1)?;
274 let where_sql =
275 dialect.where_clause_starting_at(table, &query.where_clauses, assignment.next)?;
276 let mut params = assignment.params;
277 params.extend(where_sql.params);
278 Ok(SqlStatement {
279 sql: format!(
280 "UPDATE {} SET {}{}",
281 dialect.quote_identifier(&table.name)?,
282 assignment.sql.join(", "),
283 where_sql.sql
284 ),
285 params,
286 })
287}
288
289pub fn delete_one_statement(
290 dialect: SqlDialect,
291 schema: &DbSchema,
292 query: &Delete,
293) -> Result<SqlDeleteOnePlan, RustAuthError> {
294 let table = resolve_table(schema, &query.model)?;
295 let where_sql = dialect.where_clause(table, &query.where_clauses)?;
296 let statement = match dialect {
297 SqlDialect::Postgres => SqlStatement {
298 sql: format!(
299 "DELETE FROM {} WHERE ctid IN (SELECT ctid FROM {}{} LIMIT 1)",
300 dialect.quote_identifier(&table.name)?,
301 dialect.quote_identifier(&table.name)?,
302 where_sql.sql
303 ),
304 params: where_sql.params,
305 },
306 SqlDialect::Sqlite => SqlStatement {
307 sql: format!(
308 "DELETE FROM {} WHERE rowid IN (SELECT rowid FROM {}{} LIMIT 1)",
309 dialect.quote_identifier(&table.name)?,
310 dialect.quote_identifier(&table.name)?,
311 where_sql.sql
312 ),
313 params: where_sql.params,
314 },
315 SqlDialect::MySql => SqlStatement {
316 sql: format!(
317 "DELETE FROM {}{} LIMIT 1",
318 dialect.quote_identifier(&table.name)?,
319 where_sql.sql
320 ),
321 params: where_sql.params,
322 },
323 };
324 let strategy = match dialect {
325 SqlDialect::Postgres | SqlDialect::Sqlite => DeleteOneStrategy::NestedId,
326 SqlDialect::MySql => DeleteOneStrategy::Limit,
327 };
328 Ok(SqlDeleteOnePlan {
329 statement,
330 strategy,
331 })
332}
333
334pub fn delete_many_statement(
335 dialect: SqlDialect,
336 schema: &DbSchema,
337 query: &DeleteMany,
338) -> Result<SqlStatement, RustAuthError> {
339 let table = resolve_table(schema, &query.model)?;
340 let where_sql = dialect.where_clause(table, &query.where_clauses)?;
341 Ok(SqlStatement {
342 sql: format!(
343 "DELETE FROM {}{}",
344 dialect.quote_identifier(&table.name)?,
345 where_sql.sql
346 ),
347 params: where_sql.params,
348 })
349}
350
351struct UpdateAssignment {
352 sql: Vec<String>,
353 params: Vec<SqlParam>,
354 next: usize,
355}
356
357fn update_assignment(
358 dialect: SqlDialect,
359 table: &DbTable,
360 data: &DbRecord,
361 first_placeholder: usize,
362) -> Result<UpdateAssignment, RustAuthError> {
363 let mut sql = Vec::new();
364 let mut params = Vec::new();
365 for (field, value) in data {
366 let (_, metadata) = resolve_field(table, field)?;
367 params.push(SqlParam::new(metadata, value.clone()));
368 sql.push(format!(
369 "{} = {}",
370 dialect.quote_identifier(&metadata.name)?,
371 dialect.placeholder(first_placeholder + params.len() - 1)
372 ));
373 }
374 Ok(UpdateAssignment {
375 sql,
376 next: first_placeholder + params.len(),
377 params,
378 })
379}
380
381pub fn table_has_database_generated_id(table: &DbTable) -> bool {
382 table
383 .field("id")
384 .and_then(|field| field.generated_id)
385 .is_some()
386}
387
388impl SqlDialect {
389 pub fn supports_insert_returning(self) -> bool {
390 matches!(self, Self::Postgres | Self::Sqlite)
391 }
392}