1use std::fmt::Write;
2use std::path::Path;
3
4use scythe_backend::manifest::{BackendManifest, load_manifest};
5use scythe_backend::naming::{
6 enum_type_name, enum_variant_name, fn_name, row_struct_name, to_pascal_case,
7};
8use scythe_backend::types::resolve_type;
9
10use scythe_core::analyzer::{AnalyzedQuery, CompositeInfo, EnumInfo};
11use scythe_core::errors::{ErrorCode, ScytheError};
12use scythe_core::parser::QueryCommand;
13
14use crate::backend_trait::{CodegenBackend, ResolvedColumn, ResolvedParam};
15
16pub struct GoDatabaseSqlBackend {
17 manifest: BackendManifest,
18 engine: String,
19}
20
21impl GoDatabaseSqlBackend {
22 pub fn new(engine: &str) -> Result<Self, ScytheError> {
23 let manifest_toml = match engine {
24 "mysql" | "mariadb" => include_str!("../../manifests/go-database-sql.mysql.toml"),
25 "sqlite" | "sqlite3" => include_str!("../../manifests/go-database-sql.sqlite.toml"),
26 "duckdb" => include_str!("../../manifests/go-database-sql.duckdb.toml"),
27 _ => {
28 return Err(ScytheError::new(
29 ErrorCode::InternalError,
30 format!(
31 "go-database-sql supports MySQL, SQLite, and DuckDB, got engine '{}'",
32 engine
33 ),
34 ));
35 }
36 };
37 let manifest_path = Path::new("backends/go-database-sql/manifest.toml");
38 let manifest = if manifest_path.exists() {
39 load_manifest(manifest_path)
40 .map_err(|e| ScytheError::new(ErrorCode::InternalError, format!("manifest: {e}")))?
41 } else {
42 toml::from_str(manifest_toml)
43 .map_err(|e| ScytheError::new(ErrorCode::InternalError, format!("manifest: {e}")))?
44 };
45 Ok(Self {
46 manifest,
47 engine: engine.to_string(),
48 })
49 }
50}
51
52impl CodegenBackend for GoDatabaseSqlBackend {
53 fn name(&self) -> &str {
54 "go-database-sql"
55 }
56
57 fn manifest(&self) -> &scythe_backend::manifest::BackendManifest {
58 &self.manifest
59 }
60
61 fn supported_engines(&self) -> &[&str] {
62 &["mysql", "sqlite", "duckdb"]
63 }
64
65 fn file_header(&self) -> String {
66 let uses_time = matches!(self.engine.as_str(), "mysql" | "mariadb" | "duckdb");
67 let mut header =
68 String::from("package queries\n\nimport (\n\t\"context\"\n\t\"database/sql\"");
69 if uses_time {
70 header.push_str("\n\t\"time\"");
71 }
72 header.push_str("\n)\n");
73 header
74 }
75
76 fn generate_row_struct(
77 &self,
78 query_name: &str,
79 columns: &[ResolvedColumn],
80 ) -> Result<String, ScytheError> {
81 let struct_name = row_struct_name(query_name, &self.manifest.naming);
82 let mut out = String::new();
83 let _ = writeln!(out, "type {} struct {{", struct_name);
84 for col in columns {
85 let field = to_pascal_case(&col.field_name);
86 let json_tag = &col.field_name;
87 let _ = writeln!(out, "\t{} {} `json:\"{}\"`", field, col.full_type, json_tag);
88 }
89 let _ = write!(out, "}}");
90 Ok(out)
91 }
92
93 fn generate_model_struct(
94 &self,
95 table_name: &str,
96 columns: &[ResolvedColumn],
97 ) -> Result<String, ScytheError> {
98 let name = to_pascal_case(table_name);
99 self.generate_row_struct(&name, columns)
100 }
101
102 fn generate_query_fn(
103 &self,
104 analyzed: &AnalyzedQuery,
105 struct_name: &str,
106 columns: &[ResolvedColumn],
107 params: &[ResolvedParam],
108 ) -> Result<String, ScytheError> {
109 let func_name = fn_name(&analyzed.name, &self.manifest.naming);
110 let sql = super::clean_sql_oneline_with_optional(
111 &analyzed.sql,
112 &analyzed.optional_params,
113 &analyzed.params,
114 );
115
116 let param_list = params
117 .iter()
118 .map(|p| {
119 let field = to_pascal_case(&p.field_name);
120 format!("{} {}", field, p.full_type)
121 })
122 .collect::<Vec<_>>()
123 .join(", ");
124 let sep = if param_list.is_empty() { "" } else { ", " };
125
126 let args = params
127 .iter()
128 .map(|p| to_pascal_case(&p.field_name).into_owned())
129 .collect::<Vec<_>>();
130
131 let mut out = String::new();
132
133 match &analyzed.command {
134 QueryCommand::Exec => {
135 let _ = writeln!(
136 out,
137 "func {}(ctx context.Context, db *sql.DB{}{}) error {{",
138 func_name, sep, param_list
139 );
140 let args_str = if args.is_empty() {
141 String::new()
142 } else {
143 format!(", {}", args.join(", "))
144 };
145 let _ = writeln!(
146 out,
147 "\t_, err := db.ExecContext(ctx, \"{}\"{})",
148 sql, args_str
149 );
150 let _ = writeln!(out, "\treturn err");
151 let _ = write!(out, "}}");
152 }
153 QueryCommand::ExecResult | QueryCommand::ExecRows => {
154 let _ = writeln!(
155 out,
156 "func {}(ctx context.Context, db *sql.DB{}{}) (int64, error) {{",
157 func_name, sep, param_list
158 );
159 let args_str = if args.is_empty() {
160 String::new()
161 } else {
162 format!(", {}", args.join(", "))
163 };
164 let _ = writeln!(
165 out,
166 "\tresult, err := db.ExecContext(ctx, \"{}\"{})",
167 sql, args_str
168 );
169 let _ = writeln!(out, "\tif err != nil {{");
170 let _ = writeln!(out, "\t\treturn 0, err");
171 let _ = writeln!(out, "\t}}");
172 let _ = writeln!(out, "\treturn result.RowsAffected()");
173 let _ = write!(out, "}}");
174 }
175 QueryCommand::One => {
176 let _ = writeln!(
177 out,
178 "func {}(ctx context.Context, db *sql.DB{}{}) ({}, error) {{",
179 func_name, sep, param_list, struct_name
180 );
181 let args_str = if args.is_empty() {
182 String::new()
183 } else {
184 format!(", {}", args.join(", "))
185 };
186 let _ = writeln!(
187 out,
188 "\trow := db.QueryRowContext(ctx, \"{}\"{})",
189 sql, args_str
190 );
191 let _ = writeln!(out, "\tvar r {}", struct_name);
192 let scan_fields: Vec<String> = columns
193 .iter()
194 .map(|c| format!("&r.{}", to_pascal_case(&c.field_name)))
195 .collect();
196 let _ = writeln!(out, "\terr := row.Scan({})", scan_fields.join(", "));
197 let _ = writeln!(out, "\treturn r, err");
198 let _ = write!(out, "}}");
199 }
200 QueryCommand::Batch => {
201 let batch_fn_name = format!("{}Batch", func_name);
202 if params.len() > 1 {
203 let params_struct_name = format!("{}BatchParams", func_name);
204 let _ = writeln!(out, "type {} struct {{", params_struct_name);
205 for p in params {
206 let field = to_pascal_case(&p.field_name);
207 let _ = writeln!(out, "\t{} {}", field, p.full_type);
208 }
209 let _ = writeln!(out, "}}");
210 let _ = writeln!(out);
211 let _ = writeln!(
212 out,
213 "func {}(ctx context.Context, db *sql.DB, items []{}) error {{",
214 batch_fn_name, params_struct_name
215 );
216 } else if params.len() == 1 {
217 let _ = writeln!(
218 out,
219 "func {}(ctx context.Context, db *sql.DB, items []{}) error {{",
220 batch_fn_name, params[0].full_type
221 );
222 } else {
223 let _ = writeln!(
224 out,
225 "func {}(ctx context.Context, db *sql.DB, count int) error {{",
226 batch_fn_name
227 );
228 }
229 let _ = writeln!(out, "\ttx, err := db.BeginTx(ctx, nil)");
230 let _ = writeln!(out, "\tif err != nil {{");
231 let _ = writeln!(out, "\t\treturn err");
232 let _ = writeln!(out, "\t}}");
233 let _ = writeln!(out, "\tdefer tx.Rollback()");
234 if params.is_empty() {
235 let _ = writeln!(out, "\tfor i := 0; i < count; i++ {{");
236 let _ = writeln!(out, "\t\t_, err := tx.ExecContext(ctx, \"{}\")", sql);
237 } else {
238 let _ = writeln!(out, "\tfor _, item := range items {{");
239 if params.len() > 1 {
240 let item_args: Vec<String> = params
241 .iter()
242 .map(|p| format!("item.{}", to_pascal_case(&p.field_name)))
243 .collect();
244 let _ = writeln!(
245 out,
246 "\t\t_, err := tx.ExecContext(ctx, \"{}\", {})",
247 sql,
248 item_args.join(", ")
249 );
250 } else {
251 let _ =
252 writeln!(out, "\t\t_, err := tx.ExecContext(ctx, \"{}\", item)", sql);
253 }
254 }
255 let _ = writeln!(out, "\t\tif err != nil {{");
256 let _ = writeln!(out, "\t\t\treturn err");
257 let _ = writeln!(out, "\t\t}}");
258 let _ = writeln!(out, "\t}}");
259 let _ = writeln!(out, "\treturn tx.Commit()");
260 let _ = write!(out, "}}");
261 }
262 QueryCommand::Many => {
263 let _ = writeln!(
264 out,
265 "func {}(ctx context.Context, db *sql.DB{}{}) ([]{}, error) {{",
266 func_name, sep, param_list, struct_name
267 );
268 let args_str = if args.is_empty() {
269 String::new()
270 } else {
271 format!(", {}", args.join(", "))
272 };
273 let _ = writeln!(
274 out,
275 "\trows, err := db.QueryContext(ctx, \"{}\"{})",
276 sql, args_str
277 );
278 let _ = writeln!(out, "\tif err != nil {{");
279 let _ = writeln!(out, "\t\treturn nil, err");
280 let _ = writeln!(out, "\t}}");
281 let _ = writeln!(out, "\tdefer rows.Close()");
282 let _ = writeln!(out, "\tvar result []{}", struct_name);
283 let _ = writeln!(out, "\tfor rows.Next() {{");
284 let _ = writeln!(out, "\t\tvar r {}", struct_name);
285 let scan_fields: Vec<String> = columns
286 .iter()
287 .map(|c| format!("&r.{}", to_pascal_case(&c.field_name)))
288 .collect();
289 let _ = writeln!(
290 out,
291 "\t\tif err := rows.Scan({}); err != nil {{",
292 scan_fields.join(", ")
293 );
294 let _ = writeln!(out, "\t\t\treturn nil, err");
295 let _ = writeln!(out, "\t\t}}");
296 let _ = writeln!(out, "\t\tresult = append(result, r)");
297 let _ = writeln!(out, "\t}}");
298 let _ = writeln!(out, "\treturn result, rows.Err()");
299 let _ = write!(out, "}}");
300 }
301 QueryCommand::Grouped => {
302 unreachable!("Grouped is rewritten to Many before codegen")
303 }
304 }
305
306 Ok(out)
307 }
308
309 fn generate_enum_def(&self, enum_info: &EnumInfo) -> Result<String, ScytheError> {
310 let type_name = enum_type_name(&enum_info.sql_name, &self.manifest.naming);
311 let mut out = String::new();
312 let _ = writeln!(out, "type {} string", type_name);
313 let _ = writeln!(out);
314 let _ = writeln!(out, "const (");
315 for value in &enum_info.values {
316 let variant = enum_variant_name(value, &self.manifest.naming);
317 let _ = writeln!(
318 out,
319 "\t{}{} {} = \"{}\"",
320 type_name, variant, type_name, value
321 );
322 }
323 let _ = write!(out, ")");
324 Ok(out)
325 }
326
327 fn generate_composite_def(&self, composite: &CompositeInfo) -> Result<String, ScytheError> {
328 let name = to_pascal_case(&composite.sql_name);
329 let mut out = String::new();
330 let _ = writeln!(out, "type {} struct {{", name);
331 if !composite.fields.is_empty() {
332 for field in &composite.fields {
333 let field_name = to_pascal_case(&field.name);
334 let go_type = resolve_type(&field.neutral_type, &self.manifest, false)
335 .map(|t| t.into_owned())
336 .unwrap_or_else(|_| "interface{}".to_string());
337 let json_tag = &field.name;
338 let _ = writeln!(out, "\t{} {} `json:\"{}\"`", field_name, go_type, json_tag);
339 }
340 }
341 let _ = write!(out, "}}");
342 Ok(out)
343 }
344}