1use scythe_backend::manifest::BackendManifest;
2use scythe_backend::naming::{
3 enum_type_name, enum_variant_name, fn_name, row_struct_name, to_pascal_case, to_snake_case,
4};
5use std::fmt::Write;
6
7use scythe_core::analyzer::{AnalyzedColumn, AnalyzedQuery, CompositeInfo, EnumInfo};
8use scythe_core::errors::{ErrorCode, ScytheError};
9use scythe_core::parser::QueryCommand;
10
11use crate::backend_trait::{CodegenBackend, ResolvedColumn, ResolvedParam};
12use crate::singularize;
13
/// Embedded default manifest for every engine without a dedicated one
/// (PostgreSQL, MySQL, SQLite and their aliases); passed as the default to
/// `load_or_default_manifest` in `SqlxBackend::new`.
const DEFAULT_MANIFEST_TOML: &str = include_str!("../../manifests/rust-sqlx.toml");
/// Embedded default manifest used by `SqlxBackend::new` for the `mariadb` engine.
const DEFAULT_MANIFEST_MARIADB: &str = include_str!("../../manifests/rust-sqlx.mariadb.toml");
/// Embedded default manifest used by `SqlxBackend::new` for the `redshift` engine.
const DEFAULT_MANIFEST_REDSHIFT: &str = include_str!("../../manifests/rust-sqlx.redshift.toml");
18
/// Code-generation backend that emits Rust code targeting the `sqlx` crate.
pub struct SqlxBackend {
    /// Manifest resolved in `SqlxBackend::new` via `load_or_default_manifest`.
    manifest: BackendManifest,
    /// Engine name exactly as passed to `SqlxBackend::new`; aliases such as `pg`
    /// or `sqlite3` are stored as given, so engine checks must list them too.
    engine: String,
    /// When `true`, `generate_query_fn` returns an empty string; enabled through
    /// the `structs_only = "true"` option in `apply_options`.
    structs_only: bool,
}
27
28impl SqlxBackend {
29 pub fn new(engine: &str) -> Result<Self, ScytheError> {
30 match engine {
33 "postgresql" | "postgres" | "pg" | "mysql" | "mariadb" | "sqlite" | "sqlite3"
34 | "redshift" => {}
35 _ => {
36 return Err(ScytheError::new(
37 ErrorCode::InternalError,
38 format!("unsupported engine '{}' for rust-sqlx backend", engine),
39 ));
40 }
41 }
42 let manifest = match engine {
43 "mariadb" => super::load_or_default_manifest(
44 "backends/rust-sqlx/manifest.toml",
45 DEFAULT_MANIFEST_MARIADB,
46 )?,
47 "redshift" => super::load_or_default_manifest(
48 "backends/rust-sqlx/manifest.toml",
49 DEFAULT_MANIFEST_REDSHIFT,
50 )?,
51 _ => super::load_or_default_manifest(
52 "backends/rust-sqlx/manifest.toml",
53 DEFAULT_MANIFEST_TOML,
54 )?,
55 };
56 Ok(Self {
57 manifest,
58 engine: engine.to_string(),
59 structs_only: false,
60 })
61 }
62}
63
64impl SqlxBackend {
65 fn uses_inline_enums(&self) -> bool {
77 matches!(
78 self.engine.as_str(),
79 "mysql" | "mariadb" | "sqlite" | "sqlite3"
80 )
81 }
82
83 fn row_field_type<'a>(&self, col: &'a ResolvedColumn) -> &'a str {
89 if self.uses_inline_enums() && col.neutral_type.starts_with("enum::") {
90 if col.nullable {
91 "Option<String>"
92 } else {
93 "String"
94 }
95 } else {
96 &col.full_type
97 }
98 }
99
100 fn pool_type(&self) -> &str {
102 match self.engine.as_str() {
103 "mysql" | "mariadb" => "sqlx::MySqlPool",
104 "sqlite" | "sqlite3" => "sqlx::SqlitePool",
105 _ => "sqlx::PgPool",
106 }
107 }
108
109 fn query_result_type(&self) -> &str {
111 match self.engine.as_str() {
112 "mysql" | "mariadb" => "sqlx::mysql::MySqlQueryResult",
113 "sqlite" | "sqlite3" => "sqlx::sqlite::SqliteQueryResult",
114 _ => "sqlx::postgres::PgQueryResult",
115 }
116 }
117}
118
119impl CodegenBackend for SqlxBackend {
120 fn name(&self) -> &str {
121 "rust-sqlx"
122 }
123
124 fn manifest(&self) -> &scythe_backend::manifest::BackendManifest {
125 &self.manifest
126 }
127
128 fn supported_engines(&self) -> &[&str] {
129 &["postgresql", "mysql", "mariadb", "sqlite", "redshift"]
130 }
131
132 fn file_header(&self) -> String {
133 "// Auto-generated by scythe. Do not edit.\n#![allow(dead_code, unused_imports, clippy::needless_question_mark, clippy::redundant_closure)]"
134 .to_string()
135 }
136
137 fn apply_options(
138 &mut self,
139 options: &std::collections::HashMap<String, String>,
140 ) -> Result<(), ScytheError> {
141 if options.get("structs_only").is_some_and(|v| v == "true") {
142 self.structs_only = true;
143 }
144 Ok(())
145 }
146
147 fn generate_row_struct(
148 &self,
149 query_name: &str,
150 columns: &[ResolvedColumn],
151 ) -> Result<String, ScytheError> {
152 let struct_name = row_struct_name(query_name, &self.manifest.naming);
153 let mut out = String::new();
154
155 let _ = writeln!(out, "#[derive(Debug, Clone, sqlx::FromRow)]");
156 let _ = writeln!(out, "pub struct {} {{", struct_name);
157
158 for col in columns {
159 let field_type = self.row_field_type(col);
160 let _ = writeln!(out, " pub {}: {},", col.field_name, field_type);
161 }
162
163 let _ = write!(out, "}}");
164 Ok(out)
165 }
166
167 fn generate_model_struct(
168 &self,
169 table_name: &str,
170 columns: &[ResolvedColumn],
171 ) -> Result<String, ScytheError> {
172 let singular = singularize(table_name);
173 let struct_name = to_pascal_case(&singular).into_owned();
174 let mut out = String::new();
175
176 let _ = writeln!(out, "#[derive(Debug, Clone, sqlx::FromRow)]");
177 let _ = writeln!(out, "pub struct {} {{", struct_name);
178
179 for col in columns {
180 let field_type = self.row_field_type(col);
181 let _ = writeln!(out, " pub {}: {},", col.field_name, field_type);
182 }
183
184 let _ = write!(out, "}}");
185 Ok(out)
186 }
187
188 fn generate_query_fn(
189 &self,
190 analyzed: &AnalyzedQuery,
191 struct_name: &str,
192 _columns: &[ResolvedColumn],
193 params: &[ResolvedParam],
194 ) -> Result<String, ScytheError> {
195 if self.structs_only {
198 return Ok(String::new());
199 }
200
201 let func_name = fn_name(&analyzed.name, &self.manifest.naming);
202 let mut out = String::new();
203
204 if let Some(ref msg) = analyzed.deprecated {
206 let _ = writeln!(out, "#[deprecated(note = \"{}\")]", msg);
207 }
208
209 let pool_type = self.pool_type();
211 let mut param_parts: Vec<String> = vec![format!("pool: &{}", pool_type)];
212 for param in params {
213 param_parts.push(format!("{}: {}", param.field_name, param.borrowed_type));
214 }
215
216 let sql_raw = super::clean_sql_with_optional(
218 &analyzed.sql,
219 &analyzed.optional_params,
220 &analyzed.params,
221 );
222 let sql = rewrite_sql_for_enums(&sql_raw, &analyzed.columns, &self.manifest);
223
224 let bind_params: String = analyzed
226 .params
227 .iter()
228 .map(|p| {
229 let param_name = to_snake_case(&p.name);
230 if p.neutral_type.starts_with("enum::") {
231 let enum_name = p.neutral_type.strip_prefix("enum::").unwrap();
232 let rust_type = enum_type_name(enum_name, &self.manifest.naming);
233 format!(", {} as &{}", param_name, rust_type)
234 } else {
235 format!(", {}", param_name)
236 }
237 })
238 .collect();
239
240 if matches!(analyzed.command, QueryCommand::Batch) {
242 let batch_fn_name = format!("{}_batch", func_name);
243
244 if params.len() > 1 {
246 let params_struct_name = format!("{}BatchParams", struct_name);
247 let _ = writeln!(out, "#[derive(Debug, Clone)]");
248 let _ = writeln!(out, "pub struct {} {{", params_struct_name);
249 for param in params {
250 let _ = writeln!(out, " pub {}: {},", param.field_name, param.full_type);
251 }
252 let _ = writeln!(out, "}}");
253 let _ = writeln!(out);
254
255 let _ = writeln!(
257 out,
258 "pub async fn {}(pool: &{}, items: &[{}]) -> Result<(), sqlx::Error> {{",
259 batch_fn_name, pool_type, params_struct_name
260 );
261 let _ = writeln!(out, " let mut tx = pool.begin().await?;");
262 let _ = writeln!(out, " for item in items {{");
263
264 let struct_bind_params: String = params
266 .iter()
267 .map(|p| {
268 if p.neutral_type.starts_with("enum::") {
269 let enum_name = p.neutral_type.strip_prefix("enum::").unwrap();
270 let rust_type = enum_type_name(enum_name, &self.manifest.naming);
271 format!(", item.{} as &{}", p.field_name, rust_type)
272 } else {
273 format!(", item.{}", p.field_name)
274 }
275 })
276 .collect();
277
278 let _ = writeln!(
279 out,
280 " sqlx::query!(\"{}\"{})",
281 sql, struct_bind_params
282 );
283 let _ = writeln!(out, " .execute(&mut *tx)");
284 let _ = writeln!(out, " .await?;");
285 let _ = writeln!(out, " }}");
286 let _ = writeln!(out, " tx.commit().await?;");
287 let _ = writeln!(out, " Ok(())");
288 } else if params.len() == 1 {
289 let param = ¶ms[0];
291 let _ = writeln!(
292 out,
293 "pub async fn {}(pool: &{}, items: &[{}]) -> Result<(), sqlx::Error> {{",
294 batch_fn_name, pool_type, param.full_type
295 );
296 let _ = writeln!(out, " let mut tx = pool.begin().await?;");
297 let _ = writeln!(out, " for item in items {{");
298 let _ = writeln!(out, " sqlx::query!(\"{}\", item)", sql);
299 let _ = writeln!(out, " .execute(&mut *tx)");
300 let _ = writeln!(out, " .await?;");
301 let _ = writeln!(out, " }}");
302 let _ = writeln!(out, " tx.commit().await?;");
303 let _ = writeln!(out, " Ok(())");
304 } else {
305 let _ = writeln!(
307 out,
308 "pub async fn {}(pool: &{}, count: usize) -> Result<(), sqlx::Error> {{",
309 batch_fn_name, pool_type
310 );
311 let _ = writeln!(out, " let mut tx = pool.begin().await?;");
312 let _ = writeln!(out, " for _ in 0..count {{");
313 let _ = writeln!(out, " sqlx::query!(\"{}\")", sql);
314 let _ = writeln!(out, " .execute(&mut *tx)");
315 let _ = writeln!(out, " .await?;");
316 let _ = writeln!(out, " }}");
317 let _ = writeln!(out, " tx.commit().await?;");
318 let _ = writeln!(out, " Ok(())");
319 }
320
321 let _ = write!(out, "}}");
322 return Ok(out);
323 }
324
325 let return_type = match &analyzed.command {
327 QueryCommand::One | QueryCommand::Opt => struct_name.to_string(),
328 QueryCommand::Many => format!("Vec<{}>", struct_name),
329 QueryCommand::Exec => "()".to_string(),
330 QueryCommand::ExecResult => self.query_result_type().to_string(),
331 QueryCommand::ExecRows => "u64".to_string(),
332 QueryCommand::Batch => unreachable!(),
333 QueryCommand::Grouped => {
334 return Err(ScytheError::new(
335 ErrorCode::InternalError,
336 "Grouped queries should be rewritten before codegen".to_string(),
337 ));
338 }
339 };
340
341 let _ = writeln!(
343 out,
344 "pub async fn {}({}) -> Result<{}, sqlx::Error> {{",
345 func_name,
346 param_parts.join(", "),
347 return_type
348 );
349
350 let has_row_struct = matches!(analyzed.command, QueryCommand::One | QueryCommand::Many);
352
353 let is_exec_rows = matches!(analyzed.command, QueryCommand::ExecRows);
354
355 if is_exec_rows {
356 if has_row_struct && !analyzed.columns.is_empty() {
357 let _ = write!(
358 out,
359 " let result = sqlx::query_as!({}, \"{}\"{})",
360 struct_name, sql, bind_params
361 );
362 } else {
363 let _ = write!(
364 out,
365 " let result = sqlx::query!(\"{}\"{})",
366 sql, bind_params
367 );
368 }
369 } else if has_row_struct && !analyzed.columns.is_empty() {
370 let _ = write!(
371 out,
372 " sqlx::query_as!({}, \"{}\"{})",
373 struct_name, sql, bind_params
374 );
375 } else {
376 let _ = write!(out, " sqlx::query!(\"{}\"{})", sql, bind_params);
377 }
378
379 let _ = writeln!(out);
380
381 let fetch_method = match &analyzed.command {
383 QueryCommand::One | QueryCommand::Opt => ".fetch_one(pool)",
384 QueryCommand::Many => ".fetch_all(pool)",
385 QueryCommand::Exec => ".execute(pool)",
386 QueryCommand::ExecResult => ".execute(pool)",
387 QueryCommand::ExecRows => ".execute(pool)",
388 QueryCommand::Batch => unreachable!(),
389 QueryCommand::Grouped => {
390 return Err(ScytheError::new(
391 ErrorCode::InternalError,
392 "Grouped queries should be rewritten before codegen".to_string(),
393 ));
394 }
395 };
396
397 let _ = write!(out, " {}", fetch_method);
398 let _ = writeln!(out);
399
400 match &analyzed.command {
402 QueryCommand::Exec => {
403 let _ = writeln!(out, " .await?;");
404 let _ = writeln!(out, " Ok(())");
405 }
406 QueryCommand::ExecRows => {
407 let _ = writeln!(out, " .await?;");
408 let _ = writeln!(out, " Ok(result.rows_affected())");
409 }
410 _ => {
411 let _ = writeln!(out, " .await");
412 }
413 }
414
415 let _ = write!(out, "}}");
416 Ok(out)
417 }
418
419 fn generate_enum_def(&self, enum_info: &EnumInfo) -> Result<String, ScytheError> {
420 let mut out = String::with_capacity(256);
421 let type_name = enum_type_name(&enum_info.sql_name, &self.manifest.naming);
422
423 let _ = writeln!(out, "#[derive(Debug, Clone, PartialEq, Eq, sqlx::Type)]");
424 match self.engine.as_str() {
428 "mysql" | "mariadb" | "sqlite" | "sqlite3" => {
429 let _ = writeln!(out, "#[sqlx(rename_all = \"snake_case\")]");
430 }
431 _ => {
432 let _ = writeln!(
433 out,
434 "#[sqlx(type_name = \"{}\", rename_all = \"snake_case\")]",
435 enum_info.sql_name
436 );
437 }
438 }
439 let _ = writeln!(out, "pub enum {type_name} {{");
440
441 for value in &enum_info.values {
442 let variant = enum_variant_name(value, &self.manifest.naming);
443 let _ = writeln!(out, " {variant},");
444 }
445
446 let _ = write!(out, "}}");
447 Ok(out)
448 }
449
450 fn generate_composite_def(&self, composite: &CompositeInfo) -> Result<String, ScytheError> {
451 use scythe_backend::types::resolve_type;
452
453 let struct_name = to_pascal_case(&composite.sql_name).into_owned();
454 let mut out = String::new();
455
456 let _ = writeln!(out, "#[derive(Debug, Clone, sqlx::Type)]");
457 let _ = writeln!(out, "#[sqlx(type_name = \"{}\")]", composite.sql_name);
458 let _ = writeln!(out, "pub struct {} {{", struct_name);
459 for field in &composite.fields {
460 let rust_type = resolve_type(&field.neutral_type, &self.manifest, false)
461 .map(|t| t.into_owned())
462 .map_err(|e| {
463 ScytheError::new(
464 ErrorCode::InternalError,
465 format!("composite field type error: {}", e),
466 )
467 })?;
468 let _ = writeln!(
469 out,
470 " pub {}: {},",
471 to_snake_case(&field.name),
472 rust_type
473 );
474 }
475 let _ = write!(out, "}}");
476 Ok(out)
477 }
478}
479
480fn rewrite_sql_for_enums(
486 sql: &str,
487 columns: &[AnalyzedColumn],
488 manifest: &BackendManifest,
489) -> String {
490 let enum_cols: Vec<(&str, String)> = columns
491 .iter()
492 .filter_map(|col| {
493 if let Some(enum_name) = col.neutral_type.strip_prefix("enum::") {
494 let rust_type = enum_type_name(enum_name, &manifest.naming);
495 let annotation = if col.nullable {
496 format!("Option<{}>", rust_type)
497 } else {
498 rust_type
499 };
500 Some((col.name.as_str(), annotation))
501 } else {
502 None
503 }
504 })
505 .collect();
506
507 if enum_cols.is_empty() {
508 return sql.to_string();
509 }
510
511 let mut result = sql.to_string();
512 for (col_name, annotation) in &enum_cols {
513 let alias = format!("{} AS \\\"{}: {}\\\"", col_name, col_name, annotation);
514 if let Some(from_pos) = result.to_uppercase().find(" FROM ") {
515 let select_part = &result[..from_pos];
516 let rest = &result[from_pos..];
517 let new_select = replace_column_in_select(select_part, col_name, &alias);
518 result = format!("{}{}", new_select, rest);
519 }
520 }
521 result
522}
523
/// Replaces one bare reference to `col_name` in a select list with `replacement`.
///
/// A reference is an occurrence of `col_name` preceded by `", "` or a single
/// space and followed by end-of-input, `,`, or whitespace. Comma-preceded
/// references are preferred; within each pattern the last valid occurrence
/// wins. Occurrences that are only a prefix of a longer identifier (e.g.
/// `status` in `status_code`) are skipped and earlier occurrences are still
/// examined. Returns `select` unchanged when nothing matches or `col_name` is
/// empty.
fn replace_column_in_select(select: &str, col_name: &str, replacement: &str) -> String {
    if col_name.is_empty() {
        return select.to_string();
    }

    // True when `rest` (the text after a candidate) ends the identifier.
    let ends_reference = |rest: &str| match rest.chars().next() {
        None => true,
        Some(c) => c == ',' || c.is_whitespace(),
    };

    for pattern in [format!(", {}", col_name), format!(" {}", col_name)] {
        let hit = select
            .rmatch_indices(pattern.as_str())
            .map(|(pos, _)| pos + pattern.len())
            .find(|&end| ends_reference(&select[end..]));
        if let Some(end) = hit {
            let start = end - col_name.len();
            return format!("{}{}{}", &select[..start], replacement, &select[end..]);
        }
    }
    select.to_string()
}