1use std::fmt::Write;
2use std::path::Path;
3
4use scythe_backend::manifest::{BackendManifest, load_manifest};
5use scythe_backend::naming::{
6 enum_type_name, enum_variant_name, fn_name, row_struct_name, to_camel_case, to_pascal_case,
7};
8use scythe_backend::types::resolve_type;
9
10use scythe_core::analyzer::{AnalyzedQuery, CompositeInfo, EnumInfo};
11use scythe_core::errors::{ErrorCode, ScytheError};
12use scythe_core::parser::QueryCommand;
13
14use crate::backend_trait::{CodegenBackend, ResolvedColumn, ResolvedParam};
15use crate::backends::typescript_common::{TsRowType, generate_zod_enum, generate_zod_row_struct};
16use crate::singularize;
17
18const DEFAULT_MANIFEST_TOML: &str = include_str!("../../manifests/typescript-mysql2.toml");
19
20pub struct TypescriptMysql2Backend {
21 manifest: BackendManifest,
22 row_type: TsRowType,
23}
24
25impl TypescriptMysql2Backend {
26 pub fn new(engine: &str) -> Result<Self, ScytheError> {
27 match engine {
28 "mysql" | "mariadb" => {}
29 _ => {
30 return Err(ScytheError::new(
31 ErrorCode::InternalError,
32 format!(
33 "typescript-mysql2 only supports MySQL/MariaDB, got engine '{}'",
34 engine
35 ),
36 ));
37 }
38 }
39 let manifest_path = Path::new("backends/typescript-mysql2/manifest.toml");
40 let manifest = if manifest_path.exists() {
41 load_manifest(manifest_path)
42 .map_err(|e| ScytheError::new(ErrorCode::InternalError, format!("manifest: {e}")))?
43 } else {
44 toml::from_str(DEFAULT_MANIFEST_TOML)
45 .map_err(|e| ScytheError::new(ErrorCode::InternalError, format!("manifest: {e}")))?
46 };
47 Ok(Self {
48 manifest,
49 row_type: TsRowType::default(),
50 })
51 }
52}
53
54impl CodegenBackend for TypescriptMysql2Backend {
55 fn name(&self) -> &str {
56 "typescript-mysql2"
57 }
58
59 fn manifest(&self) -> &scythe_backend::manifest::BackendManifest {
60 &self.manifest
61 }
62
63 fn supported_engines(&self) -> &[&str] {
64 &["mysql"]
65 }
66
67 fn file_header(&self) -> String {
68 let mut header =
69 "/** Auto-generated by scythe. Do not edit. */\n\nimport type { Pool, RowDataPacket } from \"mysql2/promise\";\n"
70 .to_string();
71 if self.row_type == TsRowType::Zod {
72 header.push_str("import { z } from \"zod\";\n");
73 }
74 header
75 }
76
77 fn generate_row_struct(
78 &self,
79 query_name: &str,
80 columns: &[ResolvedColumn],
81 ) -> Result<String, ScytheError> {
82 let struct_name = row_struct_name(query_name, &self.manifest.naming);
83 if self.row_type == TsRowType::Zod {
84 let mut out = generate_zod_row_struct(&struct_name, query_name, columns);
85 let _ = writeln!(out);
87 let _ = writeln!(out);
88 let _ = write!(
89 out,
90 "export interface {struct_name}Packet extends RowDataPacket, {struct_name} {{}}"
91 );
92 return Ok(out);
93 }
94 let mut out = String::new();
95 let _ = writeln!(out, "/** Row type for {} queries. */", query_name);
96 let _ = writeln!(
97 out,
98 "export interface {} extends RowDataPacket {{",
99 struct_name
100 );
101 for col in columns {
102 let _ = writeln!(out, "\t{}: {};", col.field_name, col.full_type);
103 }
104 let _ = write!(out, "}}");
105 Ok(out)
106 }
107
108 fn generate_model_struct(
109 &self,
110 table_name: &str,
111 columns: &[ResolvedColumn],
112 ) -> Result<String, ScytheError> {
113 let singular = singularize(table_name);
114 let name = to_pascal_case(&singular);
115 self.generate_row_struct(&name, columns)
116 }
117
118 fn generate_query_fn(
119 &self,
120 analyzed: &AnalyzedQuery,
121 struct_name: &str,
122 _columns: &[ResolvedColumn],
123 params: &[ResolvedParam],
124 ) -> Result<String, ScytheError> {
125 let func_name = fn_name(&analyzed.name, &self.manifest.naming);
126 let mut out = String::new();
127
128 let param_list = params
129 .iter()
130 .map(|p| format!("{}: {}", p.field_name, p.full_type))
131 .collect::<Vec<_>>()
132 .join(", ");
133
134 let sql = super::clean_sql_with_optional(
135 &analyzed.sql,
136 &analyzed.optional_params,
137 &analyzed.params,
138 );
139
140 let inline_params = if params.is_empty() {
141 "pool: Pool".to_string()
142 } else {
143 format!("pool: Pool, {}", param_list)
144 };
145
146 let write_fn_sig = |out: &mut String, name: &str, params_inline: &str, ret: &str| {
148 let oneliner = format!(
149 "export async function {}({}): {} {{",
150 name, params_inline, ret
151 );
152 if oneliner.len() <= 80 {
153 let _ = writeln!(out, "{}", oneliner);
154 } else {
155 let mut parts = vec!["\tpool: Pool".to_string()];
156 for p in params {
157 parts.push(format!("\t{}: {}", p.field_name, p.full_type));
158 }
159 let _ = writeln!(out, "export async function {}(", name);
160 for part in &parts {
161 let _ = writeln!(out, "{},", part);
162 }
163 let _ = writeln!(out, "): {} {{", ret);
164 }
165 };
166
167 let param_array = if params.is_empty() {
169 String::new()
170 } else {
171 let args: Vec<String> = params.iter().map(|p| p.field_name.clone()).collect();
172 format!(", [{}]", args.join(", "))
173 };
174
175 let query_type = if self.row_type == TsRowType::Zod {
177 format!("{struct_name}Packet")
178 } else {
179 struct_name.to_string()
180 };
181
182 match &analyzed.command {
183 QueryCommand::One => {
184 let _ = writeln!(out, "/** Fetch a single {} or null. */", struct_name);
185 let ret = format!("Promise<{} | null>", struct_name);
186 write_fn_sig(&mut out, &func_name, &inline_params, &ret);
187 let _ = writeln!(
188 out,
189 "\tconst [rows] = await pool.execute<{}[]>(",
190 query_type
191 );
192 let _ = writeln!(out, "\t\t`{}`{},", sql, param_array);
193 let _ = writeln!(out, "\t);");
194 let _ = writeln!(out, "\treturn rows[0] ?? null;");
195 let _ = write!(out, "}}");
196 }
197 QueryCommand::Batch => {
198 let batch_fn_name = format!("{}Batch", func_name);
199 if params.len() > 1 {
200 let params_type_name = format!("{}BatchParams", struct_name);
201 let _ = writeln!(out, "/** Params for {} batch operation. */", struct_name);
202 let _ = writeln!(out, "export interface {} {{", params_type_name);
203 for p in params {
204 let _ = writeln!(out, "\t{}: {};", p.field_name, p.full_type);
205 }
206 let _ = writeln!(out, "}}");
207 let _ = writeln!(out);
208 let _ = writeln!(
209 out,
210 "/** Execute {} for each item in the batch. */",
211 analyzed.name
212 );
213 let batch_params = format!("pool: Pool, items: {}[]", params_type_name);
214 write_fn_sig(&mut out, &batch_fn_name, &batch_params, "Promise<void>");
215 let _ = writeln!(out, "\tconst conn = await pool.getConnection();");
216 let _ = writeln!(out, "\ttry {{");
217 let _ = writeln!(out, "\t\tawait conn.beginTransaction();");
218 let _ = writeln!(out, "\t\tfor (const item of items) {{");
219 let args: Vec<String> = params
220 .iter()
221 .map(|p| format!("item.{}", p.field_name))
222 .collect();
223 let _ = writeln!(out, "\t\t\tawait conn.execute(");
224 let args_str = args.join(", ");
225 let _ = writeln!(out, "\t\t\t\t`{}`, [{}],", sql, args_str);
226 let _ = writeln!(out, "\t\t\t);");
227 let _ = writeln!(out, "\t\t}}");
228 let _ = writeln!(out, "\t\tawait conn.commit();");
229 let _ = writeln!(out, "\t}} catch (error) {{");
230 let _ = writeln!(out, "\t\tawait conn.rollback();");
231 let _ = writeln!(out, "\t\tthrow error;");
232 let _ = writeln!(out, "\t}} finally {{");
233 let _ = writeln!(out, "\t\tconn.release();");
234 let _ = writeln!(out, "\t}}");
235 let _ = write!(out, "}}");
236 } else if params.len() == 1 {
237 let _ = writeln!(
238 out,
239 "/** Execute {} for each item in the batch. */",
240 analyzed.name
241 );
242 let batch_params = format!("pool: Pool, items: {}[]", params[0].full_type);
243 write_fn_sig(&mut out, &batch_fn_name, &batch_params, "Promise<void>");
244 let _ = writeln!(out, "\tconst conn = await pool.getConnection();");
245 let _ = writeln!(out, "\ttry {{");
246 let _ = writeln!(out, "\t\tawait conn.beginTransaction();");
247 let _ = writeln!(out, "\t\tfor (const item of items) {{");
248 let _ = writeln!(out, "\t\t\tawait conn.execute(`{}`, [item]);", sql);
249 let _ = writeln!(out, "\t\t}}");
250 let _ = writeln!(out, "\t\tawait conn.commit();");
251 let _ = writeln!(out, "\t}} catch (error) {{");
252 let _ = writeln!(out, "\t\tawait conn.rollback();");
253 let _ = writeln!(out, "\t\tthrow error;");
254 let _ = writeln!(out, "\t}} finally {{");
255 let _ = writeln!(out, "\t\tconn.release();");
256 let _ = writeln!(out, "\t}}");
257 let _ = write!(out, "}}");
258 } else {
259 let _ = writeln!(
260 out,
261 "/** Execute {} for each item in the batch. */",
262 analyzed.name
263 );
264 write_fn_sig(
265 &mut out,
266 &batch_fn_name,
267 "pool: Pool, count: number",
268 "Promise<void>",
269 );
270 let _ = writeln!(out, "\tconst conn = await pool.getConnection();");
271 let _ = writeln!(out, "\ttry {{");
272 let _ = writeln!(out, "\t\tawait conn.beginTransaction();");
273 let _ = writeln!(out, "\t\tfor (let i = 0; i < count; i++) {{");
274 let _ = writeln!(out, "\t\t\tawait conn.execute(`{}`);", sql);
275 let _ = writeln!(out, "\t\t}}");
276 let _ = writeln!(out, "\t\tawait conn.commit();");
277 let _ = writeln!(out, "\t}} catch (error) {{");
278 let _ = writeln!(out, "\t\tawait conn.rollback();");
279 let _ = writeln!(out, "\t\tthrow error;");
280 let _ = writeln!(out, "\t}} finally {{");
281 let _ = writeln!(out, "\t\tconn.release();");
282 let _ = writeln!(out, "\t}}");
283 let _ = write!(out, "}}");
284 }
285 }
286 QueryCommand::Many => {
287 let _ = writeln!(out, "/** Fetch all {} rows. */", struct_name);
288 let ret = format!("Promise<{}[]>", struct_name);
289 write_fn_sig(&mut out, &func_name, &inline_params, &ret);
290 let _ = writeln!(
291 out,
292 "\tconst [rows] = await pool.execute<{}[]>(",
293 query_type
294 );
295 let _ = writeln!(out, "\t\t`{}`{},", sql, param_array);
296 let _ = writeln!(out, "\t);");
297 let _ = writeln!(out, "\treturn rows;");
298 let _ = write!(out, "}}");
299 }
300 QueryCommand::Exec => {
301 let _ = writeln!(out, "/** Execute a query returning no rows. */");
302 write_fn_sig(&mut out, &func_name, &inline_params, "Promise<void>");
303 let _ = writeln!(out, "\tawait pool.execute(");
304 let _ = writeln!(out, "\t\t`{}`{},", sql, param_array);
305 let _ = writeln!(out, "\t);");
306 let _ = write!(out, "}}");
307 }
308 QueryCommand::ExecResult | QueryCommand::ExecRows => {
309 let _ = writeln!(
310 out,
311 "/** Execute a query and return the number of affected rows. */"
312 );
313 write_fn_sig(&mut out, &func_name, &inline_params, "Promise<number>");
314 let _ = writeln!(out, "\tconst [result] = await pool.execute(");
315 let _ = writeln!(out, "\t\t`{}`{},", sql, param_array);
316 let _ = writeln!(out, "\t);");
317 let _ = writeln!(
318 out,
319 "\treturn (result as {{ affectedRows: number }}).affectedRows;"
320 );
321 let _ = write!(out, "}}");
322 }
323 QueryCommand::Grouped => {
324 unreachable!("Grouped is rewritten to Many before codegen")
325 }
326 }
327
328 Ok(out)
329 }
330
331 fn generate_enum_def(&self, enum_info: &EnumInfo) -> Result<String, ScytheError> {
332 let type_name = enum_type_name(&enum_info.sql_name, &self.manifest.naming);
333 if self.row_type == TsRowType::Zod {
334 return Ok(generate_zod_enum(&type_name, &enum_info.values));
335 }
336 let mut out = String::new();
337 let _ = writeln!(out, "export enum {} {{", type_name);
338 for value in &enum_info.values {
339 let variant = enum_variant_name(value, &self.manifest.naming);
340 let _ = writeln!(out, "\t{} = \"{}\",", variant, value);
341 }
342 let _ = write!(out, "}}");
343 Ok(out)
344 }
345
346 fn generate_composite_def(&self, composite: &CompositeInfo) -> Result<String, ScytheError> {
347 let name = to_pascal_case(&composite.sql_name);
348 let mut out = String::new();
349 let _ = writeln!(out, "/** Composite type {}. */", composite.sql_name);
350 let _ = writeln!(out, "export interface {} {{", name);
351 for field in &composite.fields {
352 let ts_type = resolve_type(&field.neutral_type, &self.manifest, false)
353 .map(|t| t.into_owned())
354 .map_err(|e| {
355 ScytheError::new(
356 ErrorCode::InternalError,
357 format!("composite field type error: {}", e),
358 )
359 })?;
360 let _ = writeln!(out, "\t{}: {};", to_camel_case(&field.name), ts_type);
361 }
362 let _ = write!(out, "}}");
363 Ok(out)
364 }
365
366 fn apply_options(
367 &mut self,
368 options: &std::collections::HashMap<String, String>,
369 ) -> Result<(), ScytheError> {
370 if let Some(value) = options.get("row_type") {
371 self.row_type = TsRowType::from_option(value)?;
372 }
373 Ok(())
374 }
375}