1use std::fmt::Write;
2use std::path::Path;
3
4use scythe_backend::manifest::{BackendManifest, load_manifest};
5use scythe_backend::naming::{
6 enum_type_name, enum_variant_name, fn_name, row_struct_name, to_pascal_case,
7};
8use scythe_backend::types::resolve_type;
9
10use scythe_core::analyzer::{AnalyzedQuery, CompositeInfo, EnumInfo};
11use scythe_core::errors::{ErrorCode, ScytheError};
12use scythe_core::parser::QueryCommand;
13
14use crate::backend_trait::{CodegenBackend, ResolvedColumn, ResolvedParam};
15
16const DEFAULT_MANIFEST_TOML: &str = include_str!("../../manifests/csharp-mysqlconnector.toml");
17
18pub struct CsharpMysqlConnectorBackend {
19 manifest: BackendManifest,
20}
21
22impl CsharpMysqlConnectorBackend {
23 pub fn new(engine: &str) -> Result<Self, ScytheError> {
24 match engine {
25 "mysql" | "mariadb" => {}
26 _ => {
27 return Err(ScytheError::new(
28 ErrorCode::InternalError,
29 format!(
30 "csharp-mysqlconnector only supports MySQL, got engine '{}'",
31 engine
32 ),
33 ));
34 }
35 }
36 let manifest_path = Path::new("backends/csharp-mysqlconnector/manifest.toml");
37 let manifest = if manifest_path.exists() {
38 load_manifest(manifest_path)
39 .map_err(|e| ScytheError::new(ErrorCode::InternalError, format!("manifest: {e}")))?
40 } else {
41 toml::from_str(DEFAULT_MANIFEST_TOML)
42 .map_err(|e| ScytheError::new(ErrorCode::InternalError, format!("manifest: {e}")))?
43 };
44 Ok(Self { manifest })
45 }
46}
47
/// Returns the data-reader accessor (as emitted after `reader.`) used to read
/// a column of the given neutral type.
///
/// Neutral types that are not listed fall back to `GetValue`.
fn reader_method(neutral_type: &str) -> &'static str {
    /// Neutral type name paired with the accessor emitted for it.
    const ACCESSORS: &[(&str, &str)] = &[
        ("bool", "GetBoolean"),
        ("int16", "GetInt16"),
        ("int32", "GetInt32"),
        ("int64", "GetInt64"),
        ("float32", "GetFloat"),
        ("float64", "GetDouble"),
        ("string", "GetString"),
        ("json", "GetString"),
        ("inet", "GetString"),
        ("interval", "GetString"),
        ("uuid", "GetString"),
        ("decimal", "GetDecimal"),
        ("date", "GetFieldValue<DateOnly>"),
        ("time", "GetFieldValue<TimeOnly>"),
        ("time_tz", "GetFieldValue<TimeOnly>"),
        ("datetime", "GetDateTime"),
        ("datetime_tz", "GetFieldValue<DateTimeOffset>"),
    ];

    ACCESSORS
        .iter()
        .find(|(neutral, _)| *neutral == neutral_type)
        .map_or("GetValue", |&(_, accessor)| accessor)
}
66
/// Rewrites positional placeholders of the form `$N` into `@pN`.
///
/// Every `$` that is immediately followed by a digit `1`-`9` is replaced by
/// `@p`; the digits themselves are copied through unchanged, so placeholders
/// of any width (`$1`, `$42`, `$100`) are handled. A `$` followed by `0`, by a
/// non-digit, or by nothing is left as is.
///
/// This is done in one pass over the input with a single output allocation,
/// instead of one full scan and reallocation per candidate parameter index.
///
/// NOTE(review): no SQL quoting is tracked here, so a `$N` sequence inside a
/// quoted literal is rewritten as well.
fn rewrite_params(sql: &str) -> String {
    // Each rewritten placeholder grows the text by one byte; reserve a little slack.
    let mut rewritten = String::with_capacity(sql.len() + 8);
    let mut chars = sql.chars().peekable();

    while let Some(ch) = chars.next() {
        let starts_placeholder =
            ch == '$' && matches!(chars.peek().copied(), Some('1'..='9'));
        if starts_placeholder {
            rewritten.push_str("@p");
        } else {
            rewritten.push(ch);
        }
    }

    rewritten
}
77
78fn column_read_expr(col: &ResolvedColumn, ordinal: usize) -> String {
80 if col.neutral_type.starts_with("enum::") {
81 format!(
82 "(Enum.TryParse<{typ}>(reader.GetString({ord}), true, out var enumVal{ord}) ? enumVal{ord} : throw new InvalidOperationException($\"Invalid enum value '{{reader.GetString({ord})}}' for {typ}\"))",
83 typ = col.lang_type,
84 ord = ordinal
85 )
86 } else {
87 let method = reader_method(&col.neutral_type);
88 format!("reader.{}({})", method, ordinal)
89 }
90}
91
92impl CodegenBackend for CsharpMysqlConnectorBackend {
93 fn name(&self) -> &str {
94 "csharp-mysqlconnector"
95 }
96
97 fn manifest(&self) -> &scythe_backend::manifest::BackendManifest {
98 &self.manifest
99 }
100
101 fn supported_engines(&self) -> &[&str] {
102 &["mysql"]
103 }
104
105 fn file_header(&self) -> String {
106 "// Auto-generated by scythe. Do not edit.\nusing MySqlConnector;\n\npublic static class Queries {"
107 .to_string()
108 }
109
110 fn file_footer(&self) -> String {
111 "}".to_string()
112 }
113
114 fn generate_row_struct(
115 &self,
116 query_name: &str,
117 columns: &[ResolvedColumn],
118 ) -> Result<String, ScytheError> {
119 let struct_name = row_struct_name(query_name, &self.manifest.naming);
120 let mut out = String::new();
121 let _ = writeln!(out, "public record {}(", struct_name);
122 for (i, c) in columns.iter().enumerate() {
123 let field = to_pascal_case(&c.field_name);
124 let sep = if i + 1 < columns.len() { "," } else { "" };
125 let _ = writeln!(out, " {} {}{}", c.full_type, field, sep);
126 }
127 let _ = write!(out, ");");
128 Ok(out)
129 }
130
131 fn generate_model_struct(
132 &self,
133 table_name: &str,
134 columns: &[ResolvedColumn],
135 ) -> Result<String, ScytheError> {
136 let name = to_pascal_case(table_name);
137 self.generate_row_struct(&name, columns)
138 }
139
140 fn generate_query_fn(
141 &self,
142 analyzed: &AnalyzedQuery,
143 struct_name: &str,
144 columns: &[ResolvedColumn],
145 params: &[ResolvedParam],
146 ) -> Result<String, ScytheError> {
147 let func_name = fn_name(&analyzed.name, &self.manifest.naming);
148 let sql = rewrite_params(&super::clean_sql_oneline(&analyzed.sql));
149 let mut out = String::new();
150
151 let param_list = params
152 .iter()
153 .map(|p| format!("{} {}", p.full_type, p.field_name))
154 .collect::<Vec<_>>()
155 .join(", ");
156 let sep = if param_list.is_empty() { "" } else { ", " };
157
158 let return_type = match &analyzed.command {
159 QueryCommand::One => format!("{}?", struct_name),
160 QueryCommand::Many | QueryCommand::Batch => {
161 format!("List<{}>", struct_name)
162 }
163 QueryCommand::Exec => "void".to_string(),
164 QueryCommand::ExecResult | QueryCommand::ExecRows => "int".to_string(),
165 };
166
167 let is_async_void = return_type == "void";
168 let task_type = if is_async_void {
169 "Task".to_string()
170 } else {
171 format!("Task<{}>", return_type)
172 };
173
174 let _ = writeln!(
175 out,
176 "public static async {} {}(MySqlConnection conn{}{}) {{",
177 task_type, func_name, sep, param_list
178 );
179
180 let _ = writeln!(
181 out,
182 " await using var cmd = new MySqlCommand(\"{}\", conn);",
183 sql
184 );
185 for (i, p) in params.iter().enumerate() {
186 let value_expr = if p.neutral_type.starts_with("enum::") {
187 format!("{}.ToString().ToLower()", p.field_name)
188 } else {
189 p.field_name.clone()
190 };
191 let _ = writeln!(
192 out,
193 " cmd.Parameters.AddWithValue(\"@p{}\", {});",
194 i + 1,
195 value_expr
196 );
197 }
198
199 match &analyzed.command {
200 QueryCommand::One => {
201 let _ = writeln!(
202 out,
203 " await using var reader = await cmd.ExecuteReaderAsync();"
204 );
205 let _ = writeln!(out, " if (!await reader.ReadAsync()) return null;");
206 let _ = writeln!(out, " return new {}(", struct_name);
207 for (i, col) in columns.iter().enumerate() {
208 let expr = column_read_expr(col, i);
209 let sep = if i + 1 < columns.len() { "," } else { "" };
210 if col.nullable {
211 let _ = writeln!(out, " reader.IsDBNull({i}) ? null : {expr}{sep}");
212 } else {
213 let _ = writeln!(out, " {expr}{sep}");
214 }
215 }
216 let _ = writeln!(out, " );");
217 }
218 QueryCommand::Many | QueryCommand::Batch => {
219 let _ = writeln!(
220 out,
221 " await using var reader = await cmd.ExecuteReaderAsync();"
222 );
223 let _ = writeln!(out, " var results = new List<{}>();", struct_name);
224 let _ = writeln!(out, " while (await reader.ReadAsync()) {{");
225 let _ = writeln!(out, " results.Add(new {}(", struct_name);
226 for (i, col) in columns.iter().enumerate() {
227 let expr = column_read_expr(col, i);
228 let sep = if i + 1 < columns.len() { "," } else { "" };
229 if col.nullable {
230 let _ =
231 writeln!(out, " reader.IsDBNull({i}) ? null : {expr}{sep}");
232 } else {
233 let _ = writeln!(out, " {expr}{sep}");
234 }
235 }
236 let _ = writeln!(out, " ));");
237 let _ = writeln!(out, " }}");
238 let _ = writeln!(out, " return results;");
239 }
240 QueryCommand::Exec => {
241 let _ = writeln!(out, " await cmd.ExecuteNonQueryAsync();");
242 }
243 QueryCommand::ExecResult | QueryCommand::ExecRows => {
244 let _ = writeln!(out, " return await cmd.ExecuteNonQueryAsync();");
245 }
246 }
247
248 let _ = write!(out, "}}");
249 Ok(out)
250 }
251
252 fn generate_enum_def(&self, enum_info: &EnumInfo) -> Result<String, ScytheError> {
253 let type_name = enum_type_name(&enum_info.sql_name, &self.manifest.naming);
254 let mut out = String::new();
255 let _ = writeln!(out, "public enum {} {{", type_name);
256 for value in &enum_info.values {
257 let variant = enum_variant_name(value, &self.manifest.naming);
258 let _ = writeln!(out, " {},", variant);
259 }
260 let _ = write!(out, "}}");
261 Ok(out)
262 }
263
264 fn generate_composite_def(&self, composite: &CompositeInfo) -> Result<String, ScytheError> {
265 let name = to_pascal_case(&composite.sql_name);
266 let mut out = String::new();
267 if composite.fields.is_empty() {
268 let _ = writeln!(out, "public record {}();", name);
269 } else {
270 let _ = writeln!(out, "public record {}(", name);
271 for (i, field) in composite.fields.iter().enumerate() {
272 let cs_type = resolve_type(&field.neutral_type, &self.manifest, false)
273 .map(|t| t.into_owned())
274 .unwrap_or_else(|_| "object".to_string());
275 let field_name = to_pascal_case(&field.name);
276 let sep = if i + 1 < composite.fields.len() {
277 ","
278 } else {
279 ""
280 };
281 let _ = writeln!(out, " {} {}{}", cs_type, field_name, sep);
282 }
283 let _ = write!(out, ");");
284 }
285 Ok(out)
286 }
287}