Skip to main content

scythe_codegen/backends/
typescript_mysql2.rs

1use std::fmt::Write;
2use std::path::Path;
3
4use scythe_backend::manifest::{BackendManifest, load_manifest};
5use scythe_backend::naming::{
6    enum_type_name, enum_variant_name, fn_name, row_struct_name, to_camel_case, to_pascal_case,
7};
8use scythe_backend::types::resolve_type;
9
10use scythe_core::analyzer::{AnalyzedQuery, CompositeInfo, EnumInfo};
11use scythe_core::errors::{ErrorCode, ScytheError};
12use scythe_core::parser::QueryCommand;
13
14use crate::backend_trait::{CodegenBackend, ResolvedColumn, ResolvedParam};
15use crate::backends::typescript_common::{TsRowType, generate_zod_enum, generate_zod_row_struct};
16use crate::singularize;
17
18const DEFAULT_MANIFEST_TOML: &str = include_str!("../../manifests/typescript-mysql2.toml");
19
20pub struct TypescriptMysql2Backend {
21    manifest: BackendManifest,
22    row_type: TsRowType,
23}
24
25impl TypescriptMysql2Backend {
26    pub fn new(engine: &str) -> Result<Self, ScytheError> {
27        match engine {
28            "mysql" | "mariadb" => {}
29            _ => {
30                return Err(ScytheError::new(
31                    ErrorCode::InternalError,
32                    format!(
33                        "typescript-mysql2 only supports MySQL/MariaDB, got engine '{}'",
34                        engine
35                    ),
36                ));
37            }
38        }
39        let manifest_path = Path::new("backends/typescript-mysql2/manifest.toml");
40        let manifest = if manifest_path.exists() {
41            load_manifest(manifest_path)
42                .map_err(|e| ScytheError::new(ErrorCode::InternalError, format!("manifest: {e}")))?
43        } else {
44            toml::from_str(DEFAULT_MANIFEST_TOML)
45                .map_err(|e| ScytheError::new(ErrorCode::InternalError, format!("manifest: {e}")))?
46        };
47        Ok(Self {
48            manifest,
49            row_type: TsRowType::default(),
50        })
51    }
52}
53
54impl CodegenBackend for TypescriptMysql2Backend {
55    fn name(&self) -> &str {
56        "typescript-mysql2"
57    }
58
59    fn manifest(&self) -> &scythe_backend::manifest::BackendManifest {
60        &self.manifest
61    }
62
63    fn supported_engines(&self) -> &[&str] {
64        &["mysql"]
65    }
66
67    fn file_header(&self) -> String {
68        let mut header =
69            "/** Auto-generated by scythe. Do not edit. */\n\nimport type { Pool, RowDataPacket } from \"mysql2/promise\";\n"
70                .to_string();
71        if self.row_type == TsRowType::Zod {
72            header.push_str("import { z } from \"zod\";\n");
73        }
74        header
75    }
76
77    fn generate_row_struct(
78        &self,
79        query_name: &str,
80        columns: &[ResolvedColumn],
81    ) -> Result<String, ScytheError> {
82        let struct_name = row_struct_name(query_name, &self.manifest.naming);
83        if self.row_type == TsRowType::Zod {
84            let mut out = generate_zod_row_struct(&struct_name, query_name, columns);
85            // mysql2 needs a RowDataPacket-compatible interface for query generics
86            let _ = writeln!(out);
87            let _ = writeln!(out);
88            let _ = write!(
89                out,
90                "export interface {struct_name}Packet extends RowDataPacket, {struct_name} {{}}"
91            );
92            return Ok(out);
93        }
94        let mut out = String::new();
95        let _ = writeln!(out, "/** Row type for {} queries. */", query_name);
96        let _ = writeln!(
97            out,
98            "export interface {} extends RowDataPacket {{",
99            struct_name
100        );
101        for col in columns {
102            let _ = writeln!(out, "\t{}: {};", col.field_name, col.full_type);
103        }
104        let _ = write!(out, "}}");
105        Ok(out)
106    }
107
108    fn generate_model_struct(
109        &self,
110        table_name: &str,
111        columns: &[ResolvedColumn],
112    ) -> Result<String, ScytheError> {
113        let singular = singularize(table_name);
114        let name = to_pascal_case(&singular);
115        self.generate_row_struct(&name, columns)
116    }
117
118    fn generate_query_fn(
119        &self,
120        analyzed: &AnalyzedQuery,
121        struct_name: &str,
122        _columns: &[ResolvedColumn],
123        params: &[ResolvedParam],
124    ) -> Result<String, ScytheError> {
125        let func_name = fn_name(&analyzed.name, &self.manifest.naming);
126        let mut out = String::new();
127
128        let param_list = params
129            .iter()
130            .map(|p| format!("{}: {}", p.field_name, p.full_type))
131            .collect::<Vec<_>>()
132            .join(", ");
133
134        let sql = super::clean_sql_with_optional(
135            &analyzed.sql,
136            &analyzed.optional_params,
137            &analyzed.params,
138        );
139
140        let inline_params = if params.is_empty() {
141            "pool: Pool".to_string()
142        } else {
143            format!("pool: Pool, {}", param_list)
144        };
145
146        // Helper: write function signature, inline or multi-line based on length
147        let write_fn_sig = |out: &mut String, name: &str, params_inline: &str, ret: &str| {
148            let oneliner = format!(
149                "export async function {}({}): {} {{",
150                name, params_inline, ret
151            );
152            if oneliner.len() <= 80 {
153                let _ = writeln!(out, "{}", oneliner);
154            } else {
155                let mut parts = vec!["\tpool: Pool".to_string()];
156                for p in params {
157                    parts.push(format!("\t{}: {}", p.field_name, p.full_type));
158                }
159                let _ = writeln!(out, "export async function {}(", name);
160                for part in &parts {
161                    let _ = writeln!(out, "{},", part);
162                }
163                let _ = writeln!(out, "): {} {{", ret);
164            }
165        };
166
167        // Helper: write param array for execute call
168        let param_array = if params.is_empty() {
169            String::new()
170        } else {
171            let args: Vec<String> = params.iter().map(|p| p.field_name.clone()).collect();
172            format!(", [{}]", args.join(", "))
173        };
174
175        // In Zod mode, mysql2 query generics use the Packet interface
176        let query_type = if self.row_type == TsRowType::Zod {
177            format!("{struct_name}Packet")
178        } else {
179            struct_name.to_string()
180        };
181
182        match &analyzed.command {
183            QueryCommand::One => {
184                let _ = writeln!(out, "/** Fetch a single {} or null. */", struct_name);
185                let ret = format!("Promise<{} | null>", struct_name);
186                write_fn_sig(&mut out, &func_name, &inline_params, &ret);
187                let _ = writeln!(
188                    out,
189                    "\tconst [rows] = await pool.execute<{}[]>(",
190                    query_type
191                );
192                let _ = writeln!(out, "\t\t`{}`{},", sql, param_array);
193                let _ = writeln!(out, "\t);");
194                let _ = writeln!(out, "\treturn rows[0] ?? null;");
195                let _ = write!(out, "}}");
196            }
197            QueryCommand::Batch => {
198                let batch_fn_name = format!("{}Batch", func_name);
199                if params.len() > 1 {
200                    let params_type_name = format!("{}BatchParams", struct_name);
201                    let _ = writeln!(out, "/** Params for {} batch operation. */", struct_name);
202                    let _ = writeln!(out, "export interface {} {{", params_type_name);
203                    for p in params {
204                        let _ = writeln!(out, "\t{}: {};", p.field_name, p.full_type);
205                    }
206                    let _ = writeln!(out, "}}");
207                    let _ = writeln!(out);
208                    let _ = writeln!(
209                        out,
210                        "/** Execute {} for each item in the batch. */",
211                        analyzed.name
212                    );
213                    let batch_params = format!("pool: Pool, items: {}[]", params_type_name);
214                    write_fn_sig(&mut out, &batch_fn_name, &batch_params, "Promise<void>");
215                    let _ = writeln!(out, "\tconst conn = await pool.getConnection();");
216                    let _ = writeln!(out, "\ttry {{");
217                    let _ = writeln!(out, "\t\tawait conn.beginTransaction();");
218                    let _ = writeln!(out, "\t\tfor (const item of items) {{");
219                    let args: Vec<String> = params
220                        .iter()
221                        .map(|p| format!("item.{}", p.field_name))
222                        .collect();
223                    let _ = writeln!(out, "\t\t\tawait conn.execute(");
224                    let args_str = args.join(", ");
225                    let _ = writeln!(out, "\t\t\t\t`{}`, [{}],", sql, args_str);
226                    let _ = writeln!(out, "\t\t\t);");
227                    let _ = writeln!(out, "\t\t}}");
228                    let _ = writeln!(out, "\t\tawait conn.commit();");
229                    let _ = writeln!(out, "\t}} catch (error) {{");
230                    let _ = writeln!(out, "\t\tawait conn.rollback();");
231                    let _ = writeln!(out, "\t\tthrow error;");
232                    let _ = writeln!(out, "\t}} finally {{");
233                    let _ = writeln!(out, "\t\tconn.release();");
234                    let _ = writeln!(out, "\t}}");
235                    let _ = write!(out, "}}");
236                } else if params.len() == 1 {
237                    let _ = writeln!(
238                        out,
239                        "/** Execute {} for each item in the batch. */",
240                        analyzed.name
241                    );
242                    let batch_params = format!("pool: Pool, items: {}[]", params[0].full_type);
243                    write_fn_sig(&mut out, &batch_fn_name, &batch_params, "Promise<void>");
244                    let _ = writeln!(out, "\tconst conn = await pool.getConnection();");
245                    let _ = writeln!(out, "\ttry {{");
246                    let _ = writeln!(out, "\t\tawait conn.beginTransaction();");
247                    let _ = writeln!(out, "\t\tfor (const item of items) {{");
248                    let _ = writeln!(out, "\t\t\tawait conn.execute(`{}`, [item]);", sql);
249                    let _ = writeln!(out, "\t\t}}");
250                    let _ = writeln!(out, "\t\tawait conn.commit();");
251                    let _ = writeln!(out, "\t}} catch (error) {{");
252                    let _ = writeln!(out, "\t\tawait conn.rollback();");
253                    let _ = writeln!(out, "\t\tthrow error;");
254                    let _ = writeln!(out, "\t}} finally {{");
255                    let _ = writeln!(out, "\t\tconn.release();");
256                    let _ = writeln!(out, "\t}}");
257                    let _ = write!(out, "}}");
258                } else {
259                    let _ = writeln!(
260                        out,
261                        "/** Execute {} for each item in the batch. */",
262                        analyzed.name
263                    );
264                    write_fn_sig(
265                        &mut out,
266                        &batch_fn_name,
267                        "pool: Pool, count: number",
268                        "Promise<void>",
269                    );
270                    let _ = writeln!(out, "\tconst conn = await pool.getConnection();");
271                    let _ = writeln!(out, "\ttry {{");
272                    let _ = writeln!(out, "\t\tawait conn.beginTransaction();");
273                    let _ = writeln!(out, "\t\tfor (let i = 0; i < count; i++) {{");
274                    let _ = writeln!(out, "\t\t\tawait conn.execute(`{}`);", sql);
275                    let _ = writeln!(out, "\t\t}}");
276                    let _ = writeln!(out, "\t\tawait conn.commit();");
277                    let _ = writeln!(out, "\t}} catch (error) {{");
278                    let _ = writeln!(out, "\t\tawait conn.rollback();");
279                    let _ = writeln!(out, "\t\tthrow error;");
280                    let _ = writeln!(out, "\t}} finally {{");
281                    let _ = writeln!(out, "\t\tconn.release();");
282                    let _ = writeln!(out, "\t}}");
283                    let _ = write!(out, "}}");
284                }
285            }
286            QueryCommand::Many => {
287                let _ = writeln!(out, "/** Fetch all {} rows. */", struct_name);
288                let ret = format!("Promise<{}[]>", struct_name);
289                write_fn_sig(&mut out, &func_name, &inline_params, &ret);
290                let _ = writeln!(
291                    out,
292                    "\tconst [rows] = await pool.execute<{}[]>(",
293                    query_type
294                );
295                let _ = writeln!(out, "\t\t`{}`{},", sql, param_array);
296                let _ = writeln!(out, "\t);");
297                let _ = writeln!(out, "\treturn rows;");
298                let _ = write!(out, "}}");
299            }
300            QueryCommand::Exec => {
301                let _ = writeln!(out, "/** Execute a query returning no rows. */");
302                write_fn_sig(&mut out, &func_name, &inline_params, "Promise<void>");
303                let _ = writeln!(out, "\tawait pool.execute(");
304                let _ = writeln!(out, "\t\t`{}`{},", sql, param_array);
305                let _ = writeln!(out, "\t);");
306                let _ = write!(out, "}}");
307            }
308            QueryCommand::ExecResult | QueryCommand::ExecRows => {
309                let _ = writeln!(
310                    out,
311                    "/** Execute a query and return the number of affected rows. */"
312                );
313                write_fn_sig(&mut out, &func_name, &inline_params, "Promise<number>");
314                let _ = writeln!(out, "\tconst [result] = await pool.execute(");
315                let _ = writeln!(out, "\t\t`{}`{},", sql, param_array);
316                let _ = writeln!(out, "\t);");
317                let _ = writeln!(
318                    out,
319                    "\treturn (result as {{ affectedRows: number }}).affectedRows;"
320                );
321                let _ = write!(out, "}}");
322            }
323            QueryCommand::Grouped => {
324                unreachable!("Grouped is rewritten to Many before codegen")
325            }
326        }
327
328        Ok(out)
329    }
330
331    fn generate_enum_def(&self, enum_info: &EnumInfo) -> Result<String, ScytheError> {
332        let type_name = enum_type_name(&enum_info.sql_name, &self.manifest.naming);
333        if self.row_type == TsRowType::Zod {
334            return Ok(generate_zod_enum(&type_name, &enum_info.values));
335        }
336        let mut out = String::new();
337        let _ = writeln!(out, "export enum {} {{", type_name);
338        for value in &enum_info.values {
339            let variant = enum_variant_name(value, &self.manifest.naming);
340            let _ = writeln!(out, "\t{} = \"{}\",", variant, value);
341        }
342        let _ = write!(out, "}}");
343        Ok(out)
344    }
345
346    fn generate_composite_def(&self, composite: &CompositeInfo) -> Result<String, ScytheError> {
347        let name = to_pascal_case(&composite.sql_name);
348        let mut out = String::new();
349        let _ = writeln!(out, "/** Composite type {}. */", composite.sql_name);
350        let _ = writeln!(out, "export interface {} {{", name);
351        for field in &composite.fields {
352            let ts_type = resolve_type(&field.neutral_type, &self.manifest, false)
353                .map(|t| t.into_owned())
354                .map_err(|e| {
355                    ScytheError::new(
356                        ErrorCode::InternalError,
357                        format!("composite field type error: {}", e),
358                    )
359                })?;
360            let _ = writeln!(out, "\t{}: {};", to_camel_case(&field.name), ts_type);
361        }
362        let _ = write!(out, "}}");
363        Ok(out)
364    }
365
366    fn apply_options(
367        &mut self,
368        options: &std::collections::HashMap<String, String>,
369    ) -> Result<(), ScytheError> {
370        if let Some(value) = options.get("row_type") {
371            self.row_type = TsRowType::from_option(value)?;
372        }
373        Ok(())
374    }
375}