1use std::fmt::Write;
2use std::path::Path;
3
4use scythe_backend::manifest::{BackendManifest, load_manifest};
5use scythe_backend::naming::{
6 enum_type_name, enum_variant_name, fn_name, row_struct_name, to_pascal_case,
7};
8use scythe_backend::types::resolve_type;
9
10use scythe_core::analyzer::{AnalyzedQuery, CompositeInfo, EnumInfo};
11use scythe_core::errors::{ErrorCode, ScytheError};
12use scythe_core::parser::QueryCommand;
13
14use crate::backend_trait::{CodegenBackend, ResolvedColumn, ResolvedParam};
15
/// Manifest text bundled into the binary at compile time.
///
/// `CsharpMicrosoftSqliteBackend::new` parses this only when no
/// `backends/csharp-microsoft-sqlite/manifest.toml` exists relative to the
/// current working directory.
const DEFAULT_MANIFEST_TOML: &str = include_str!("../../manifests/csharp-microsoft-sqlite.toml");
17
18pub struct CsharpMicrosoftSqliteBackend {
19 manifest: BackendManifest,
20}
21
22impl CsharpMicrosoftSqliteBackend {
23 pub fn new(engine: &str) -> Result<Self, ScytheError> {
24 match engine {
25 "sqlite" | "sqlite3" => {}
26 _ => {
27 return Err(ScytheError::new(
28 ErrorCode::InternalError,
29 format!(
30 "csharp-microsoft-sqlite only supports SQLite, got engine '{}'",
31 engine
32 ),
33 ));
34 }
35 }
36 let manifest_path = Path::new("backends/csharp-microsoft-sqlite/manifest.toml");
37 let manifest = if manifest_path.exists() {
38 load_manifest(manifest_path)
39 .map_err(|e| ScytheError::new(ErrorCode::InternalError, format!("manifest: {e}")))?
40 } else {
41 toml::from_str(DEFAULT_MANIFEST_TOML)
42 .map_err(|e| ScytheError::new(ErrorCode::InternalError, format!("manifest: {e}")))?
43 };
44 Ok(Self { manifest })
45 }
46}
47
/// Returns the `SqliteDataReader` accessor used to read a column whose scythe
/// neutral type is `neutral_type`.
///
/// The accessors are chosen as follows:
/// - Text-backed neutral types (strings, JSON, network addresses, intervals,
///   UUIDs, and every date/time flavour except plain `datetime`) use `GetString`.
/// - Plain `datetime` uses `GetDateTime`.
/// - Unrecognised types fall back to the untyped `GetValue`.
///
/// NOTE(review): `decimal` is read with `GetDouble`. Confirm the manifest maps
/// `decimal` to a C# type that a `double` can be assigned to.
fn reader_method(neutral_type: &str) -> &'static str {
    const TEXT_TYPES: [&str; 9] = [
        "string",
        "json",
        "inet",
        "interval",
        "uuid",
        "date",
        "time",
        "time_tz",
        "datetime_tz",
    ];

    if TEXT_TYPES.contains(&neutral_type) {
        return "GetString";
    }

    match neutral_type {
        "bool" => "GetBoolean",
        "int16" => "GetInt16",
        "int32" => "GetInt32",
        "int64" => "GetInt64",
        "float32" => "GetFloat",
        "float64" | "decimal" => "GetDouble",
        "datetime" => "GetDateTime",
        _ => "GetValue",
    }
}
64
/// Rewrites positional placeholders `$1`, `$2`, ... into the named form `$p1`,
/// `$p2`, ..., which the generated C# binds with
/// `cmd.Parameters.AddWithValue("$pN", ...)`.
///
/// Only a `$` immediately followed by a digit `1`-`9` is rewritten. So `$0` and
/// named parameters such as `$name` are left as-is, and placeholder numbers of
/// any width (`$100`, ...) are handled.
///
/// Text inside single-quoted string literals (including doubled `''` escapes)
/// and inside quoted identifiers (`"..."`, `` `...` ``, `[...]`) is copied
/// verbatim. This keeps literal data such as `'$5 off'` from being corrupted.
/// The scan is a single pass over the input.
fn rewrite_params(sql: &str) -> String {
    let bytes = sql.as_bytes();
    let mut out = String::with_capacity(sql.len() + 8);
    // Closing delimiter of the quoted region currently being copied, if any.
    let mut closing: Option<char> = None;

    for (i, c) in sql.char_indices() {
        out.push(c);

        if let Some(end) = closing {
            // A doubled quote ('') closes and immediately reopens, which keeps
            // the escaped quote inside the literal.
            if c == end {
                closing = None;
            }
            continue;
        }

        match c {
            '\'' | '"' | '`' => closing = Some(c),
            '[' => closing = Some(']'),
            // `$` is a single ASCII byte, so byte i + 1 starts the next char.
            '$' if matches!(bytes.get(i + 1), Some(b'1'..=b'9')) => out.push('p'),
            _ => {}
        }
    }

    out
}
75
76fn column_read_expr(col: &ResolvedColumn, ordinal: usize) -> String {
78 if col.neutral_type.starts_with("enum::") {
79 format!(
80 "Enum.Parse<{}>(reader.GetString({}), true)",
81 col.lang_type, ordinal
82 )
83 } else {
84 let method = reader_method(&col.neutral_type);
85 format!("reader.{}({})", method, ordinal)
86 }
87}
88
89impl CodegenBackend for CsharpMicrosoftSqliteBackend {
90 fn name(&self) -> &str {
91 "csharp-microsoft-sqlite"
92 }
93
94 fn manifest(&self) -> &scythe_backend::manifest::BackendManifest {
95 &self.manifest
96 }
97
98 fn supported_engines(&self) -> &[&str] {
99 &["sqlite"]
100 }
101
102 fn file_header(&self) -> String {
103 "// Auto-generated by scythe. Do not edit.\nusing Microsoft.Data.Sqlite;\n\npublic static class Queries {"
104 .to_string()
105 }
106
107 fn file_footer(&self) -> String {
108 "}".to_string()
109 }
110
111 fn generate_row_struct(
112 &self,
113 query_name: &str,
114 columns: &[ResolvedColumn],
115 ) -> Result<String, ScytheError> {
116 let struct_name = row_struct_name(query_name, &self.manifest.naming);
117 let mut out = String::new();
118 let _ = writeln!(out, "public record {}(", struct_name);
119 for (i, c) in columns.iter().enumerate() {
120 let field = to_pascal_case(&c.field_name);
121 let sep = if i + 1 < columns.len() { "," } else { "" };
122 let _ = writeln!(out, " {} {}{}", c.full_type, field, sep);
123 }
124 let _ = write!(out, ");");
125 Ok(out)
126 }
127
128 fn generate_model_struct(
129 &self,
130 table_name: &str,
131 columns: &[ResolvedColumn],
132 ) -> Result<String, ScytheError> {
133 let name = to_pascal_case(table_name);
134 self.generate_row_struct(&name, columns)
135 }
136
137 fn generate_query_fn(
138 &self,
139 analyzed: &AnalyzedQuery,
140 struct_name: &str,
141 columns: &[ResolvedColumn],
142 params: &[ResolvedParam],
143 ) -> Result<String, ScytheError> {
144 let func_name = fn_name(&analyzed.name, &self.manifest.naming);
145 let sql = rewrite_params(&super::clean_sql_oneline(&analyzed.sql));
146 let mut out = String::new();
147
148 let param_list = params
149 .iter()
150 .map(|p| format!("{} {}", p.full_type, p.field_name))
151 .collect::<Vec<_>>()
152 .join(", ");
153 let sep = if param_list.is_empty() { "" } else { ", " };
154
155 let return_type = match &analyzed.command {
156 QueryCommand::One => format!("{}?", struct_name),
157 QueryCommand::Many | QueryCommand::Batch => {
158 format!("List<{}>", struct_name)
159 }
160 QueryCommand::Exec => "void".to_string(),
161 QueryCommand::ExecResult | QueryCommand::ExecRows => "int".to_string(),
162 };
163
164 let _ = writeln!(
166 out,
167 "public static {} {}(SqliteConnection conn{}{}) {{",
168 return_type, func_name, sep, param_list
169 );
170
171 let _ = writeln!(
172 out,
173 " using var cmd = new SqliteCommand(\"{}\", conn);",
174 sql
175 );
176 for (i, p) in params.iter().enumerate() {
177 let value_expr = if p.neutral_type.starts_with("enum::") {
178 format!("{}.ToString().ToLower()", p.field_name)
179 } else {
180 p.field_name.clone()
181 };
182 let _ = writeln!(
183 out,
184 " cmd.Parameters.AddWithValue(\"$p{}\", {});",
185 i + 1,
186 value_expr
187 );
188 }
189
190 match &analyzed.command {
191 QueryCommand::One => {
192 let _ = writeln!(out, " using var reader = cmd.ExecuteReader();");
193 let _ = writeln!(out, " if (!reader.Read()) return null;");
194 let _ = writeln!(out, " return new {}(", struct_name);
195 for (i, col) in columns.iter().enumerate() {
196 let expr = column_read_expr(col, i);
197 let sep = if i + 1 < columns.len() { "," } else { "" };
198 if col.nullable {
199 let _ = writeln!(out, " reader.IsDBNull({i}) ? null : {expr}{sep}");
200 } else {
201 let _ = writeln!(out, " {expr}{sep}");
202 }
203 }
204 let _ = writeln!(out, " );");
205 }
206 QueryCommand::Many | QueryCommand::Batch => {
207 let _ = writeln!(out, " using var reader = cmd.ExecuteReader();");
208 let _ = writeln!(out, " var results = new List<{}>();", struct_name);
209 let _ = writeln!(out, " while (reader.Read()) {{");
210 let _ = writeln!(out, " results.Add(new {}(", struct_name);
211 for (i, col) in columns.iter().enumerate() {
212 let expr = column_read_expr(col, i);
213 let sep = if i + 1 < columns.len() { "," } else { "" };
214 if col.nullable {
215 let _ =
216 writeln!(out, " reader.IsDBNull({i}) ? null : {expr}{sep}");
217 } else {
218 let _ = writeln!(out, " {expr}{sep}");
219 }
220 }
221 let _ = writeln!(out, " ));");
222 let _ = writeln!(out, " }}");
223 let _ = writeln!(out, " return results;");
224 }
225 QueryCommand::Exec => {
226 let _ = writeln!(out, " cmd.ExecuteNonQuery();");
227 }
228 QueryCommand::ExecResult | QueryCommand::ExecRows => {
229 let _ = writeln!(out, " return cmd.ExecuteNonQuery();");
230 }
231 }
232
233 let _ = write!(out, "}}");
234 Ok(out)
235 }
236
237 fn generate_enum_def(&self, enum_info: &EnumInfo) -> Result<String, ScytheError> {
238 let type_name = enum_type_name(&enum_info.sql_name, &self.manifest.naming);
239 let mut out = String::new();
240 let _ = writeln!(out, "public enum {} {{", type_name);
241 for value in &enum_info.values {
242 let variant = enum_variant_name(value, &self.manifest.naming);
243 let _ = writeln!(out, " {},", variant);
244 }
245 let _ = write!(out, "}}");
246 Ok(out)
247 }
248
249 fn generate_composite_def(&self, composite: &CompositeInfo) -> Result<String, ScytheError> {
250 let name = to_pascal_case(&composite.sql_name);
251 let mut out = String::new();
252 if composite.fields.is_empty() {
253 let _ = writeln!(out, "public record {}();", name);
254 } else {
255 let _ = writeln!(out, "public record {}(", name);
256 for (i, field) in composite.fields.iter().enumerate() {
257 let cs_type = resolve_type(&field.neutral_type, &self.manifest, false)
258 .map(|t| t.into_owned())
259 .unwrap_or_else(|_| "object".to_string());
260 let field_name = to_pascal_case(&field.name);
261 let sep = if i + 1 < composite.fields.len() {
262 ","
263 } else {
264 ""
265 };
266 let _ = writeln!(out, " {} {}{}", cs_type, field_name, sep);
267 }
268 let _ = write!(out, ");");
269 }
270 Ok(out)
271 }
272}