Skip to main content

scythe_codegen/backends/
typescript_mysql2.rs

1use std::fmt::Write;
2use std::path::Path;
3
4use scythe_backend::manifest::{BackendManifest, load_manifest};
5use scythe_backend::naming::{
6    enum_type_name, enum_variant_name, fn_name, row_struct_name, to_camel_case, to_pascal_case,
7};
8use scythe_backend::types::resolve_type;
9
10use scythe_core::analyzer::{AnalyzedQuery, CompositeInfo, EnumInfo};
11use scythe_core::errors::{ErrorCode, ScytheError};
12use scythe_core::parser::QueryCommand;
13
14use crate::backend_trait::{CodegenBackend, ResolvedColumn, ResolvedParam};
15use crate::singularize;
16
17const DEFAULT_MANIFEST_TOML: &str = include_str!("../../manifests/typescript-mysql2.toml");
18
19pub struct TypescriptMysql2Backend {
20    manifest: BackendManifest,
21}
22
23impl TypescriptMysql2Backend {
24    pub fn new(engine: &str) -> Result<Self, ScytheError> {
25        match engine {
26            "mysql" | "mariadb" => {}
27            _ => {
28                return Err(ScytheError::new(
29                    ErrorCode::InternalError,
30                    format!(
31                        "typescript-mysql2 only supports MySQL/MariaDB, got engine '{}'",
32                        engine
33                    ),
34                ));
35            }
36        }
37        let manifest_path = Path::new("backends/typescript-mysql2/manifest.toml");
38        let manifest = if manifest_path.exists() {
39            load_manifest(manifest_path)
40                .map_err(|e| ScytheError::new(ErrorCode::InternalError, format!("manifest: {e}")))?
41        } else {
42            toml::from_str(DEFAULT_MANIFEST_TOML)
43                .map_err(|e| ScytheError::new(ErrorCode::InternalError, format!("manifest: {e}")))?
44        };
45        Ok(Self { manifest })
46    }
47}
48
49impl CodegenBackend for TypescriptMysql2Backend {
50    fn name(&self) -> &str {
51        "typescript-mysql2"
52    }
53
54    fn manifest(&self) -> &scythe_backend::manifest::BackendManifest {
55        &self.manifest
56    }
57
58    fn supported_engines(&self) -> &[&str] {
59        &["mysql"]
60    }
61
62    fn file_header(&self) -> String {
63        "/** Auto-generated by scythe. Do not edit. */\n\nimport type { Pool, RowDataPacket } from \"mysql2/promise\";\n"
64            .to_string()
65    }
66
67    fn generate_row_struct(
68        &self,
69        query_name: &str,
70        columns: &[ResolvedColumn],
71    ) -> Result<String, ScytheError> {
72        let struct_name = row_struct_name(query_name, &self.manifest.naming);
73        let mut out = String::new();
74        let _ = writeln!(out, "/** Row type for {} queries. */", query_name);
75        let _ = writeln!(
76            out,
77            "export interface {} extends RowDataPacket {{",
78            struct_name
79        );
80        for col in columns {
81            let _ = writeln!(out, "\t{}: {};", col.field_name, col.full_type);
82        }
83        let _ = write!(out, "}}");
84        Ok(out)
85    }
86
87    fn generate_model_struct(
88        &self,
89        table_name: &str,
90        columns: &[ResolvedColumn],
91    ) -> Result<String, ScytheError> {
92        let singular = singularize(table_name);
93        let name = to_pascal_case(&singular);
94        self.generate_row_struct(&name, columns)
95    }
96
97    fn generate_query_fn(
98        &self,
99        analyzed: &AnalyzedQuery,
100        struct_name: &str,
101        _columns: &[ResolvedColumn],
102        params: &[ResolvedParam],
103    ) -> Result<String, ScytheError> {
104        let func_name = fn_name(&analyzed.name, &self.manifest.naming);
105        let mut out = String::new();
106
107        let param_list = params
108            .iter()
109            .map(|p| format!("{}: {}", p.field_name, p.full_type))
110            .collect::<Vec<_>>()
111            .join(", ");
112
113        let sql = super::clean_sql(&analyzed.sql);
114
115        let inline_params = if params.is_empty() {
116            "pool: Pool".to_string()
117        } else {
118            format!("pool: Pool, {}", param_list)
119        };
120
121        // Helper: write function signature, inline or multi-line based on length
122        let write_fn_sig = |out: &mut String, name: &str, params_inline: &str, ret: &str| {
123            let oneliner = format!(
124                "export async function {}({}): {} {{",
125                name, params_inline, ret
126            );
127            if oneliner.len() <= 80 {
128                let _ = writeln!(out, "{}", oneliner);
129            } else {
130                let mut parts = vec!["\tpool: Pool".to_string()];
131                for p in params {
132                    parts.push(format!("\t{}: {}", p.field_name, p.full_type));
133                }
134                let _ = writeln!(out, "export async function {}(", name);
135                for part in &parts {
136                    let _ = writeln!(out, "{},", part);
137                }
138                let _ = writeln!(out, "): {} {{", ret);
139            }
140        };
141
142        // Helper: write param array for execute call
143        let param_array = if params.is_empty() {
144            String::new()
145        } else {
146            let args: Vec<String> = params.iter().map(|p| p.field_name.clone()).collect();
147            format!(", [{}]", args.join(", "))
148        };
149
150        match &analyzed.command {
151            QueryCommand::One => {
152                let _ = writeln!(out, "/** Fetch a single {} or null. */", struct_name);
153                let ret = format!("Promise<{} | null>", struct_name);
154                write_fn_sig(&mut out, &func_name, &inline_params, &ret);
155                let _ = writeln!(
156                    out,
157                    "\tconst [rows] = await pool.execute<{}[]>(",
158                    struct_name
159                );
160                let _ = writeln!(out, "\t\t`{}`{},", sql, param_array);
161                let _ = writeln!(out, "\t);");
162                let _ = writeln!(out, "\treturn rows[0] ?? null;");
163                let _ = write!(out, "}}");
164            }
165            QueryCommand::Many | QueryCommand::Batch => {
166                let _ = writeln!(out, "/** Fetch all {} rows. */", struct_name);
167                let ret = format!("Promise<{}[]>", struct_name);
168                write_fn_sig(&mut out, &func_name, &inline_params, &ret);
169                let _ = writeln!(
170                    out,
171                    "\tconst [rows] = await pool.execute<{}[]>(",
172                    struct_name
173                );
174                let _ = writeln!(out, "\t\t`{}`{},", sql, param_array);
175                let _ = writeln!(out, "\t);");
176                let _ = writeln!(out, "\treturn rows;");
177                let _ = write!(out, "}}");
178            }
179            QueryCommand::Exec => {
180                let _ = writeln!(out, "/** Execute a query returning no rows. */");
181                write_fn_sig(&mut out, &func_name, &inline_params, "Promise<void>");
182                let _ = writeln!(out, "\tawait pool.execute(");
183                let _ = writeln!(out, "\t\t`{}`{},", sql, param_array);
184                let _ = writeln!(out, "\t);");
185                let _ = write!(out, "}}");
186            }
187            QueryCommand::ExecResult | QueryCommand::ExecRows => {
188                let _ = writeln!(
189                    out,
190                    "/** Execute a query and return the number of affected rows. */"
191                );
192                write_fn_sig(&mut out, &func_name, &inline_params, "Promise<number>");
193                let _ = writeln!(out, "\tconst [result] = await pool.execute(");
194                let _ = writeln!(out, "\t\t`{}`{},", sql, param_array);
195                let _ = writeln!(out, "\t);");
196                let _ = writeln!(
197                    out,
198                    "\treturn (result as {{ affectedRows: number }}).affectedRows;"
199                );
200                let _ = write!(out, "}}");
201            }
202        }
203
204        Ok(out)
205    }
206
207    fn generate_enum_def(&self, enum_info: &EnumInfo) -> Result<String, ScytheError> {
208        let type_name = enum_type_name(&enum_info.sql_name, &self.manifest.naming);
209        let mut out = String::new();
210        let _ = writeln!(out, "export enum {} {{", type_name);
211        for value in &enum_info.values {
212            let variant = enum_variant_name(value, &self.manifest.naming);
213            let _ = writeln!(out, "\t{} = \"{}\",", variant, value);
214        }
215        let _ = write!(out, "}}");
216        Ok(out)
217    }
218
219    fn generate_composite_def(&self, composite: &CompositeInfo) -> Result<String, ScytheError> {
220        let name = to_pascal_case(&composite.sql_name);
221        let mut out = String::new();
222        let _ = writeln!(out, "/** Composite type {}. */", composite.sql_name);
223        let _ = writeln!(out, "export interface {} {{", name);
224        for field in &composite.fields {
225            let ts_type = resolve_type(&field.neutral_type, &self.manifest, false)
226                .map(|t| t.into_owned())
227                .map_err(|e| {
228                    ScytheError::new(
229                        ErrorCode::InternalError,
230                        format!("composite field type error: {}", e),
231                    )
232                })?;
233            let _ = writeln!(out, "\t{}: {};", to_camel_case(&field.name), ts_type);
234        }
235        let _ = write!(out, "}}");
236        Ok(out)
237    }
238}