1use std::fmt::Write;
2use std::path::Path;
3
4use scythe_backend::manifest::{BackendManifest, load_manifest};
5use scythe_backend::naming::{
6 enum_type_name, enum_variant_name, fn_name, row_struct_name, to_pascal_case,
7};
8use scythe_backend::types::resolve_type;
9
10use scythe_core::analyzer::{AnalyzedQuery, CompositeInfo, EnumInfo};
11use scythe_core::errors::{ErrorCode, ScytheError};
12use scythe_core::parser::QueryCommand;
13
14use crate::backend_trait::{CodegenBackend, ResolvedColumn, ResolvedParam};
15
16const DEFAULT_MANIFEST_TOML: &str = include_str!("../../manifests/csharp-mysqlconnector.toml");
17
18pub struct CsharpMysqlConnectorBackend {
19 manifest: BackendManifest,
20}
21
22impl CsharpMysqlConnectorBackend {
23 pub fn new(engine: &str) -> Result<Self, ScytheError> {
24 match engine {
25 "mysql" | "mariadb" => {}
26 _ => {
27 return Err(ScytheError::new(
28 ErrorCode::InternalError,
29 format!(
30 "csharp-mysqlconnector only supports MySQL, got engine '{}'",
31 engine
32 ),
33 ));
34 }
35 }
36 let manifest_path = Path::new("backends/csharp-mysqlconnector/manifest.toml");
37 let manifest = if manifest_path.exists() {
38 load_manifest(manifest_path)
39 .map_err(|e| ScytheError::new(ErrorCode::InternalError, format!("manifest: {e}")))?
40 } else {
41 toml::from_str(DEFAULT_MANIFEST_TOML)
42 .map_err(|e| ScytheError::new(ErrorCode::InternalError, format!("manifest: {e}")))?
43 };
44 Ok(Self { manifest })
45 }
46}
47
/// Returns the `MySqlDataReader` accessor (method name, possibly generic) used to
/// read a column whose scythe neutral type is `neutral_type`.
///
/// Unrecognised neutral types fall back to the untyped `GetValue`.
fn reader_method(neutral_type: &str) -> &'static str {
    // (neutral type names, C# reader accessor) — first matching row wins.
    const ACCESSORS: &[(&[&str], &str)] = &[
        (&["bool"], "GetBoolean"),
        (&["int16"], "GetInt16"),
        (&["int32"], "GetInt32"),
        (&["int64"], "GetInt64"),
        (&["float32"], "GetFloat"),
        (&["float64"], "GetDouble"),
        (&["string", "json", "inet", "interval", "uuid"], "GetString"),
        (&["decimal"], "GetDecimal"),
        (&["date"], "GetFieldValue<DateOnly>"),
        (&["time", "time_tz"], "GetFieldValue<TimeOnly>"),
        (&["datetime"], "GetDateTime"),
        (&["datetime_tz"], "GetFieldValue<DateTimeOffset>"),
    ];

    ACCESSORS
        .iter()
        .find(|(names, _)| names.contains(&neutral_type))
        .map_or("GetValue", |&(_, accessor)| accessor)
}
66
/// Rewrites positional `$N` placeholders into MySqlConnector-style named
/// parameters `@pN`.
///
/// A `$` is rewritten to `@p` only when it is immediately followed by a digit
/// `1`–`9`; the digits themselves are kept, so `$1` → `@p1` and `$12` → `@p12`.
/// Everything else (a lone `$`, `$0`, `$05`, …) is copied through unchanged.
fn rewrite_params(sql: &str) -> String {
    let mut rewritten = String::with_capacity(sql.len());
    let mut chars = sql.chars().peekable();
    while let Some(ch) = chars.next() {
        let opens_placeholder = ch == '$' && matches!(chars.peek(), Some('1'..='9'));
        if opens_placeholder {
            rewritten.push_str("@p");
        } else {
            rewritten.push(ch);
        }
    }
    rewritten
}
77
78fn column_read_expr(col: &ResolvedColumn, ordinal: usize) -> String {
80 if col.neutral_type.starts_with("enum::") {
81 format!(
82 "(Enum.TryParse<{typ}>(reader.GetString({ord}), true, out var enumVal{ord}) ? enumVal{ord} : throw new InvalidOperationException($\"Invalid enum value '{{reader.GetString({ord})}}' for {typ}\"))",
83 typ = col.lang_type,
84 ord = ordinal
85 )
86 } else {
87 let method = reader_method(&col.neutral_type);
88 format!("reader.{}({})", method, ordinal)
89 }
90}
91
92impl CodegenBackend for CsharpMysqlConnectorBackend {
93 fn name(&self) -> &str {
94 "csharp-mysqlconnector"
95 }
96
97 fn manifest(&self) -> &scythe_backend::manifest::BackendManifest {
98 &self.manifest
99 }
100
101 fn supported_engines(&self) -> &[&str] {
102 &["mysql"]
103 }
104
105 fn file_header(&self) -> String {
106 "// Auto-generated by scythe. Do not edit.\n#nullable enable\n\nusing MySqlConnector;\n\npublic static class Queries {"
107 .to_string()
108 }
109
110 fn file_footer(&self) -> String {
111 "}".to_string()
112 }
113
114 fn generate_row_struct(
115 &self,
116 query_name: &str,
117 columns: &[ResolvedColumn],
118 ) -> Result<String, ScytheError> {
119 let struct_name = row_struct_name(query_name, &self.manifest.naming);
120 let mut out = String::new();
121 let _ = writeln!(out, "public record {}(", struct_name);
122 for (i, c) in columns.iter().enumerate() {
123 let field = to_pascal_case(&c.field_name);
124 let sep = if i + 1 < columns.len() { "," } else { "" };
125 let _ = writeln!(out, " {} {}{}", c.full_type, field, sep);
126 }
127 let _ = write!(out, ");");
128 Ok(out)
129 }
130
131 fn generate_model_struct(
132 &self,
133 table_name: &str,
134 columns: &[ResolvedColumn],
135 ) -> Result<String, ScytheError> {
136 let name = to_pascal_case(table_name);
137 self.generate_row_struct(&name, columns)
138 }
139
140 fn generate_query_fn(
141 &self,
142 analyzed: &AnalyzedQuery,
143 struct_name: &str,
144 columns: &[ResolvedColumn],
145 params: &[ResolvedParam],
146 ) -> Result<String, ScytheError> {
147 let func_name = fn_name(&analyzed.name, &self.manifest.naming);
148 let sql = rewrite_params(&super::clean_sql_oneline_with_optional(
149 &analyzed.sql,
150 &analyzed.optional_params,
151 &analyzed.params,
152 ));
153 let mut out = String::new();
154
155 let param_list = params
156 .iter()
157 .map(|p| format!("{} {}", p.full_type, p.field_name))
158 .collect::<Vec<_>>()
159 .join(", ");
160 let sep = if param_list.is_empty() { "" } else { ", " };
161
162 if matches!(analyzed.command, QueryCommand::Batch) {
164 let batch_fn_name = format!("{}Batch", func_name);
165 if params.len() > 1 {
166 let params_record_name = format!("{}BatchParams", to_pascal_case(&analyzed.name));
167 let _ = writeln!(out, "public record {}(", params_record_name);
168 for (i, p) in params.iter().enumerate() {
169 let field = to_pascal_case(&p.field_name);
170 let sep = if i + 1 < params.len() { "," } else { "" };
171 let _ = writeln!(out, " {} {}{}", p.full_type, field, sep);
172 }
173 let _ = writeln!(out, ");");
174 let _ = writeln!(out);
175 let _ = writeln!(
176 out,
177 "public static async Task {}(MySqlConnection conn, List<{}> items) {{",
178 batch_fn_name, params_record_name
179 );
180 } else if params.len() == 1 {
181 let _ = writeln!(
182 out,
183 "public static async Task {}(MySqlConnection conn, List<{}> items) {{",
184 batch_fn_name, params[0].full_type
185 );
186 } else {
187 let _ = writeln!(
188 out,
189 "public static async Task {}(MySqlConnection conn, int count) {{",
190 batch_fn_name
191 );
192 }
193 let _ = writeln!(
194 out,
195 " await using var tx = await conn.BeginTransactionAsync();"
196 );
197 let _ = writeln!(out, " try {{");
198 if params.is_empty() {
199 let _ = writeln!(out, " for (int i = 0; i < count; i++) {{");
200 } else {
201 let _ = writeln!(out, " foreach (var item in items) {{");
202 }
203 let _ = writeln!(
204 out,
205 " await using var cmd = new MySqlCommand(\"{}\", conn, (MySqlTransaction)tx);",
206 sql
207 );
208 for (i, p) in params.iter().enumerate() {
209 let value_expr = if params.len() > 1 {
210 let field = to_pascal_case(&p.field_name);
211 format!("item.{}", field)
212 } else {
213 "item".to_string()
214 };
215 let _ = writeln!(
216 out,
217 " cmd.Parameters.AddWithValue(\"@p{}\", {});",
218 i + 1,
219 value_expr
220 );
221 }
222 let _ = writeln!(out, " await cmd.ExecuteNonQueryAsync();");
223 let _ = writeln!(out, " }}");
224 let _ = writeln!(out, " await tx.CommitAsync();");
225 let _ = writeln!(out, " }} catch {{");
226 let _ = writeln!(out, " await tx.RollbackAsync();");
227 let _ = writeln!(out, " throw;");
228 let _ = writeln!(out, " }}");
229 let _ = write!(out, "}}");
230 return Ok(out);
231 }
232
233 let return_type = match &analyzed.command {
234 QueryCommand::One => format!("{}?", struct_name),
235 QueryCommand::Many => {
236 format!("List<{}>", struct_name)
237 }
238 QueryCommand::Exec => "void".to_string(),
239 QueryCommand::ExecResult | QueryCommand::ExecRows => "int".to_string(),
240 QueryCommand::Batch | QueryCommand::Grouped => unreachable!(),
241 };
242
243 let is_async_void = return_type == "void";
244 let task_type = if is_async_void {
245 "Task".to_string()
246 } else {
247 format!("Task<{}>", return_type)
248 };
249
250 let _ = writeln!(
251 out,
252 "public static async {} {}(MySqlConnection conn{}{}) {{",
253 task_type, func_name, sep, param_list
254 );
255
256 let _ = writeln!(
257 out,
258 " await using var cmd = new MySqlCommand(\"{}\", conn);",
259 sql
260 );
261 for (i, p) in params.iter().enumerate() {
262 let value_expr = if p.neutral_type.starts_with("enum::") {
263 format!("{}.ToString().ToLower()", p.field_name)
264 } else {
265 p.field_name.clone()
266 };
267 let _ = writeln!(
268 out,
269 " cmd.Parameters.AddWithValue(\"@p{}\", {});",
270 i + 1,
271 value_expr
272 );
273 }
274
275 match &analyzed.command {
276 QueryCommand::One => {
277 let _ = writeln!(
278 out,
279 " await using var reader = await cmd.ExecuteReaderAsync();"
280 );
281 let _ = writeln!(out, " if (!await reader.ReadAsync()) return null;");
282 let _ = writeln!(out, " return new {}(", struct_name);
283 for (i, col) in columns.iter().enumerate() {
284 let expr = column_read_expr(col, i);
285 let sep = if i + 1 < columns.len() { "," } else { "" };
286 if col.nullable {
287 let _ = writeln!(out, " reader.IsDBNull({i}) ? null : {expr}{sep}");
288 } else {
289 let _ = writeln!(out, " {expr}{sep}");
290 }
291 }
292 let _ = writeln!(out, " );");
293 }
294 QueryCommand::Many => {
295 let _ = writeln!(
296 out,
297 " await using var reader = await cmd.ExecuteReaderAsync();"
298 );
299 let _ = writeln!(out, " var results = new List<{}>();", struct_name);
300 let _ = writeln!(out, " while (await reader.ReadAsync()) {{");
301 let _ = writeln!(out, " results.Add(new {}(", struct_name);
302 for (i, col) in columns.iter().enumerate() {
303 let expr = column_read_expr(col, i);
304 let sep = if i + 1 < columns.len() { "," } else { "" };
305 if col.nullable {
306 let _ =
307 writeln!(out, " reader.IsDBNull({i}) ? null : {expr}{sep}");
308 } else {
309 let _ = writeln!(out, " {expr}{sep}");
310 }
311 }
312 let _ = writeln!(out, " ));");
313 let _ = writeln!(out, " }}");
314 let _ = writeln!(out, " return results;");
315 }
316 QueryCommand::Exec => {
317 let _ = writeln!(out, " await cmd.ExecuteNonQueryAsync();");
318 }
319 QueryCommand::ExecResult | QueryCommand::ExecRows => {
320 let _ = writeln!(out, " return await cmd.ExecuteNonQueryAsync();");
321 }
322 QueryCommand::Batch | QueryCommand::Grouped => unreachable!(),
323 }
324
325 let _ = write!(out, "}}");
326 Ok(out)
327 }
328
329 fn generate_enum_def(&self, enum_info: &EnumInfo) -> Result<String, ScytheError> {
330 let type_name = enum_type_name(&enum_info.sql_name, &self.manifest.naming);
331 let mut out = String::new();
332 let _ = writeln!(out, "public enum {} {{", type_name);
333 for value in &enum_info.values {
334 let variant = enum_variant_name(value, &self.manifest.naming);
335 let _ = writeln!(out, " {},", variant);
336 }
337 let _ = write!(out, "}}");
338 Ok(out)
339 }
340
341 fn generate_composite_def(&self, composite: &CompositeInfo) -> Result<String, ScytheError> {
342 let name = to_pascal_case(&composite.sql_name);
343 let mut out = String::new();
344 if composite.fields.is_empty() {
345 let _ = writeln!(out, "public record {}();", name);
346 } else {
347 let _ = writeln!(out, "public record {}(", name);
348 for (i, field) in composite.fields.iter().enumerate() {
349 let cs_type = resolve_type(&field.neutral_type, &self.manifest, false)
350 .map(|t| t.into_owned())
351 .unwrap_or_else(|_| "object".to_string());
352 let field_name = to_pascal_case(&field.name);
353 let sep = if i + 1 < composite.fields.len() {
354 ","
355 } else {
356 ""
357 };
358 let _ = writeln!(out, " {} {}{}", cs_type, field_name, sep);
359 }
360 let _ = write!(out, ");");
361 }
362 Ok(out)
363 }
364}