1use std::fmt::Write;
2use std::path::Path;
3
4use scythe_backend::manifest::{BackendManifest, load_manifest};
5use scythe_backend::naming::{
6 enum_type_name, enum_variant_name, fn_name, row_struct_name, to_pascal_case, to_snake_case,
7};
8
9use scythe_core::analyzer::{AnalyzedColumn, AnalyzedQuery, CompositeInfo, EnumInfo};
10use scythe_core::errors::{ErrorCode, ScytheError};
11use scythe_core::parser::QueryCommand;
12
13use crate::backend_trait::{CodegenBackend, ResolvedColumn, ResolvedParam};
14use crate::singularize;
15
16const DEFAULT_MANIFEST_TOML: &str = include_str!("../../manifests/rust-sqlx.toml");
18
19pub struct SqlxBackend {
21 manifest: BackendManifest,
22}
23
24impl SqlxBackend {
25 pub fn new() -> Result<Self, ScytheError> {
26 let manifest = load_sqlx_manifest()?;
27 Ok(Self { manifest })
28 }
29
30 pub fn manifest(&self) -> &BackendManifest {
32 &self.manifest
33 }
34}
35
36fn load_sqlx_manifest() -> Result<BackendManifest, ScytheError> {
37 let manifest_path = Path::new("backends/rust-sqlx/manifest.toml");
38 if manifest_path.exists() {
39 load_manifest(manifest_path).map_err(|e| {
40 ScytheError::new(
41 ErrorCode::InternalError,
42 format!("failed to load manifest: {e}"),
43 )
44 })
45 } else {
46 toml::from_str(DEFAULT_MANIFEST_TOML).map_err(|e| {
47 ScytheError::new(
48 ErrorCode::InternalError,
49 format!("failed to parse embedded manifest: {e}"),
50 )
51 })
52 }
53}
54
55impl CodegenBackend for SqlxBackend {
56 fn name(&self) -> &str {
57 "rust-sqlx"
58 }
59
60 fn generate_row_struct(
61 &self,
62 query_name: &str,
63 columns: &[ResolvedColumn],
64 ) -> Result<String, ScytheError> {
65 let struct_name = row_struct_name(query_name, &self.manifest.naming);
66 let mut out = String::new();
67
68 let _ = writeln!(out, "#[derive(Debug, sqlx::FromRow)]");
69 let _ = writeln!(out, "pub struct {} {{", struct_name);
70
71 for col in columns {
72 let _ = writeln!(out, " pub {}: {},", col.field_name, col.full_type);
73 }
74
75 let _ = write!(out, "}}");
76 Ok(out)
77 }
78
79 fn generate_model_struct(
80 &self,
81 table_name: &str,
82 columns: &[ResolvedColumn],
83 ) -> Result<String, ScytheError> {
84 let singular = singularize(table_name);
85 let struct_name = to_pascal_case(&singular).into_owned();
86 let mut out = String::new();
87
88 let _ = writeln!(out, "#[derive(Debug, sqlx::FromRow)]");
89 let _ = writeln!(out, "pub struct {} {{", struct_name);
90
91 for col in columns {
92 let _ = writeln!(out, " pub {}: {},", col.field_name, col.full_type);
93 }
94
95 let _ = write!(out, "}}");
96 Ok(out)
97 }
98
99 fn generate_query_fn(
100 &self,
101 analyzed: &AnalyzedQuery,
102 struct_name: &str,
103 _columns: &[ResolvedColumn],
104 params: &[ResolvedParam],
105 ) -> Result<String, ScytheError> {
106 let func_name = fn_name(&analyzed.name, &self.manifest.naming);
107 let mut out = String::new();
108
109 if let Some(ref msg) = analyzed.deprecated {
111 let _ = writeln!(out, "#[deprecated(note = \"{}\")]", msg);
112 }
113
114 let mut param_parts: Vec<String> = vec!["pool: &sqlx::PgPool".to_string()];
116 for param in params {
117 param_parts.push(format!("{}: {}", param.field_name, param.borrowed_type));
118 }
119
120 let return_type = match &analyzed.command {
122 QueryCommand::One => struct_name.to_string(),
123 QueryCommand::Many => format!("Vec<{}>", struct_name),
124 QueryCommand::Exec => "()".to_string(),
125 QueryCommand::ExecResult => "sqlx::postgres::PgQueryResult".to_string(),
126 QueryCommand::ExecRows => "u64".to_string(),
127 QueryCommand::Batch => format!("Vec<{}>", struct_name),
128 };
129
130 let _ = writeln!(
132 out,
133 "pub async fn {}({}) -> Result<{}, sqlx::Error> {{",
134 func_name,
135 param_parts.join(", "),
136 return_type
137 );
138
139 let sql_raw = super::clean_sql(&analyzed.sql);
141 let sql = rewrite_sql_for_enums(&sql_raw, &analyzed.columns, &self.manifest);
142
143 let has_row_struct = matches!(
145 analyzed.command,
146 QueryCommand::One | QueryCommand::Many | QueryCommand::Batch
147 );
148
149 let bind_params: String = analyzed
151 .params
152 .iter()
153 .map(|p| {
154 let param_name = to_snake_case(&p.name);
155 if p.neutral_type.starts_with("enum::") {
156 let enum_name = p.neutral_type.strip_prefix("enum::").unwrap();
157 let rust_type = enum_type_name(enum_name, &self.manifest.naming);
158 format!(", {} as &{}", param_name, rust_type)
159 } else {
160 format!(", {}", param_name)
161 }
162 })
163 .collect();
164
165 let is_exec_rows = matches!(analyzed.command, QueryCommand::ExecRows);
166
167 if is_exec_rows {
168 if has_row_struct && !analyzed.columns.is_empty() {
169 let _ = write!(
170 out,
171 " let result = sqlx::query_as!({}, \"{}\"{})",
172 struct_name, sql, bind_params
173 );
174 } else {
175 let _ = write!(
176 out,
177 " let result = sqlx::query!(\"{}\"{})",
178 sql, bind_params
179 );
180 }
181 } else if has_row_struct && !analyzed.columns.is_empty() {
182 let _ = write!(
183 out,
184 " sqlx::query_as!({}, \"{}\"{})",
185 struct_name, sql, bind_params
186 );
187 } else {
188 let _ = write!(out, " sqlx::query!(\"{}\"{})", sql, bind_params);
189 }
190
191 let _ = writeln!(out);
192
193 let fetch_method = match &analyzed.command {
195 QueryCommand::One => ".fetch_one(pool)",
196 QueryCommand::Many => ".fetch_all(pool)",
197 QueryCommand::Exec => ".execute(pool)",
198 QueryCommand::ExecResult => ".execute(pool)",
199 QueryCommand::ExecRows => ".execute(pool)",
200 QueryCommand::Batch => ".fetch_all(pool)",
201 };
202
203 let _ = write!(out, " {}", fetch_method);
204 let _ = writeln!(out);
205
206 match &analyzed.command {
208 QueryCommand::Exec => {
209 let _ = writeln!(out, " .await?;");
210 let _ = writeln!(out, " Ok(())");
211 }
212 QueryCommand::ExecRows => {
213 let _ = writeln!(out, " .await?;");
214 let _ = writeln!(out, " Ok(result.rows_affected())");
215 }
216 _ => {
217 let _ = writeln!(out, " .await");
218 }
219 }
220
221 let _ = write!(out, "}}");
222 Ok(out)
223 }
224
225 fn generate_enum_def(&self, enum_info: &EnumInfo) -> Result<String, ScytheError> {
226 let mut out = String::with_capacity(256);
227 let type_name = enum_type_name(&enum_info.sql_name, &self.manifest.naming);
228
229 let _ = writeln!(out, "#[derive(Debug, Clone, PartialEq, Eq, sqlx::Type)]");
230 let _ = writeln!(
231 out,
232 "#[sqlx(type_name = \"{}\", rename_all = \"snake_case\")]",
233 enum_info.sql_name
234 );
235 let _ = writeln!(out, "pub enum {type_name} {{");
236
237 for value in &enum_info.values {
238 let variant = enum_variant_name(value, &self.manifest.naming);
239 let _ = writeln!(out, " {variant},");
240 }
241
242 let _ = write!(out, "}}");
243 Ok(out)
244 }
245
246 fn generate_composite_def(&self, composite: &CompositeInfo) -> Result<String, ScytheError> {
247 use scythe_backend::types::resolve_type;
248
249 let struct_name = to_pascal_case(&composite.sql_name).into_owned();
250 let mut out = String::new();
251
252 let _ = writeln!(out, "#[derive(Debug, Clone, sqlx::Type)]");
253 let _ = writeln!(out, "#[sqlx(type_name = \"{}\")]", composite.sql_name);
254 let _ = writeln!(out, "pub struct {} {{", struct_name);
255 for field in &composite.fields {
256 let rust_type = resolve_type(&field.neutral_type, &self.manifest, false)
257 .map(|t| t.into_owned())
258 .map_err(|e| {
259 ScytheError::new(
260 ErrorCode::InternalError,
261 format!("composite field type error: {}", e),
262 )
263 })?;
264 let _ = writeln!(
265 out,
266 " pub {}: {},",
267 to_snake_case(&field.name),
268 rust_type
269 );
270 }
271 let _ = write!(out, "}}");
272 Ok(out)
273 }
274}
275
276fn rewrite_sql_for_enums(
282 sql: &str,
283 columns: &[AnalyzedColumn],
284 manifest: &BackendManifest,
285) -> String {
286 let enum_cols: Vec<(&str, String)> = columns
287 .iter()
288 .filter_map(|col| {
289 if let Some(enum_name) = col.neutral_type.strip_prefix("enum::") {
290 let rust_type = enum_type_name(enum_name, &manifest.naming);
291 let annotation = if col.nullable {
292 format!("Option<{}>", rust_type)
293 } else {
294 rust_type
295 };
296 Some((col.name.as_str(), annotation))
297 } else {
298 None
299 }
300 })
301 .collect();
302
303 if enum_cols.is_empty() {
304 return sql.to_string();
305 }
306
307 let mut result = sql.to_string();
308 for (col_name, annotation) in &enum_cols {
309 let alias = format!("{} AS \\\"{}: {}\\\"", col_name, col_name, annotation);
310 if let Some(from_pos) = result.to_uppercase().find(" FROM ") {
311 let select_part = &result[..from_pos];
312 let rest = &result[from_pos..];
313 let new_select = replace_column_in_select(select_part, col_name, &alias);
314 result = format!("{}{}", new_select, rest);
315 }
316 }
317 result
318}
319
/// Replaces one stand-alone occurrence of `col_name` in `select` with
/// `replacement`.
///
/// A candidate is `col_name` preceded by `", "` (tried first) or by a single
/// space, and followed by end of input, a space, a comma, or a newline.
/// Occurrences are examined from right to left, so a longer identifier that
/// merely starts with `col_name` (e.g. `status_code` for `status`) is skipped
/// and an earlier exact match is still found.
///
/// Returns `select` unchanged when no candidate qualifies.
fn replace_column_in_select(select: &str, col_name: &str, replacement: &str) -> String {
    // The column name must end at a token boundary.
    let ends_token =
        |rest: &str| matches!(rest.chars().next(), None | Some(' ') | Some(',') | Some('\n'));

    let patterns = [format!(", {}", col_name), format!(" {}", col_name)];
    for pattern in &patterns {
        let hit = select
            .rmatch_indices(pattern.as_str())
            .map(|(pos, matched)| pos + matched.len())
            .find(|&end| ends_token(&select[end..]));

        if let Some(end) = hit {
            // Keep the separator that preceded the column; swap only the name.
            let start = end - col_name.len();
            return format!("{}{}{}", &select[..start], replacement, &select[end..]);
        }
    }

    select.to_string()
}