Skip to main content

scythe_codegen/backends/
csharp_npgsql.rs

1use std::fmt::Write;
2use std::path::Path;
3
4use scythe_backend::manifest::{BackendManifest, load_manifest};
5use scythe_backend::naming::{
6    enum_type_name, enum_variant_name, fn_name, row_struct_name, to_pascal_case,
7};
8use scythe_backend::types::resolve_type;
9
10use scythe_core::analyzer::{AnalyzedQuery, CompositeInfo, EnumInfo};
11use scythe_core::errors::{ErrorCode, ScytheError};
12use scythe_core::parser::QueryCommand;
13
14use crate::backend_trait::{CodegenBackend, ResolvedColumn, ResolvedParam};
15
16const DEFAULT_MANIFEST_TOML: &str =
17    include_str!("../../manifests/csharp-npgsql.toml");
18
19pub struct CsharpNpgsqlBackend {
20    manifest: BackendManifest,
21}
22
23impl CsharpNpgsqlBackend {
24    pub fn new() -> Result<Self, ScytheError> {
25        let manifest_path = Path::new("backends/csharp-npgsql/manifest.toml");
26        let manifest = if manifest_path.exists() {
27            load_manifest(manifest_path)
28                .map_err(|e| ScytheError::new(ErrorCode::InternalError, format!("manifest: {e}")))?
29        } else {
30            toml::from_str(DEFAULT_MANIFEST_TOML)
31                .map_err(|e| ScytheError::new(ErrorCode::InternalError, format!("manifest: {e}")))?
32        };
33        Ok(Self { manifest })
34    }
35
36    pub fn manifest(&self) -> &BackendManifest {
37        &self.manifest
38    }
39}
40
/// Choose the `NpgsqlDataReader` accessor used to read a column of the given
/// neutral type.
///
/// The returned text is spliced verbatim after `reader.`, so generic accessors
/// are returned with their type argument (e.g. `GetFieldValue<DateOnly>`).
/// Neutral types not listed below fall back to the untyped `GetValue`.
///
/// NOTE(review): Npgsql's default mappings read `inet` as `IPAddress`,
/// `interval` as `TimeSpan` and `timetz` as `DateTimeOffset`. Confirm that
/// `GetString` / `GetFieldValue<TimeOnly>` agree with the C# types the manifest
/// assigns to these neutral types.
fn reader_method(neutral_type: &str) -> &'static str {
    const ACCESSORS: &[(&str, &str)] = &[
        ("bool", "GetBoolean"),
        ("int16", "GetInt16"),
        ("int32", "GetInt32"),
        ("int64", "GetInt64"),
        ("float32", "GetFloat"),
        ("float64", "GetDouble"),
        ("string", "GetString"),
        ("json", "GetString"),
        ("inet", "GetString"),
        ("interval", "GetString"),
        ("uuid", "GetGuid"),
        ("decimal", "GetDecimal"),
        ("date", "GetFieldValue<DateOnly>"),
        ("time", "GetFieldValue<TimeOnly>"),
        ("time_tz", "GetFieldValue<TimeOnly>"),
        ("datetime", "GetDateTime"),
        ("datetime_tz", "GetFieldValue<DateTimeOffset>"),
    ];

    ACCESSORS
        .iter()
        .find(|&&(ty, _)| ty == neutral_type)
        .map_or("GetValue", |&(_, method)| method)
}
60
/// Rewrite PostgreSQL positional placeholders (`$1`, `$2`, ...) into the Npgsql
/// named placeholders `@p1`, `@p2`, ... that the generated
/// `cmd.Parameters.AddWithValue("pN", ...)` calls bind.
///
/// The SQL is scanned once, left to right. Only placeholders in ordinary SQL
/// text are rewritten. Anything PostgreSQL itself would not parse as a
/// parameter is copied verbatim:
///
/// * single-quoted string literals (with `''` doubling, and backslash escapes
///   inside `E'...'` strings),
/// * double-quoted identifiers,
/// * dollar-quoted strings (`$$...$$`, `$tag$...$tag$`),
/// * `--` line comments and (nestable) `/* ... */` block comments,
/// * a `$` that continues an identifier, e.g. `col$1`.
///
/// Placeholders of any width are supported: `$10` becomes `@p10`, never `@p1`
/// followed by `0`. An unterminated quote or comment extends to the end of the
/// input and is copied unchanged.
fn rewrite_params(sql: &str) -> String {
    let bytes = sql.as_bytes();
    let len = bytes.len();
    // Each rewrite grows the text by one byte (`$N` -> `@pN`); leave some slack.
    let mut out = String::with_capacity(len + 16);
    let mut i = 0;

    while i < len {
        let start = i;
        match bytes[i] {
            b'\'' => {
                // `E'...'` (escape string) allows `\'` inside the literal, but only
                // when the `E` is a standalone prefix, not the end of an identifier.
                let escapes = i > 0
                    && matches!(bytes[i - 1], b'E' | b'e')
                    && (i < 2 || !is_sql_ident_byte(bytes[i - 2]));
                i = skip_quoted(bytes, i + 1, b'\'', escapes);
            }
            b'"' => i = skip_quoted(bytes, i + 1, b'"', false),
            b'-' if bytes.get(i + 1) == Some(&b'-') => {
                i = find_from(sql, i + 2, "\n").map_or(len, |nl| nl + 1);
            }
            b'/' if bytes.get(i + 1) == Some(&b'*') => {
                i = skip_block_comment(bytes, i + 2);
            }
            b'$' if i == 0 || !is_sql_ident_byte(bytes[i - 1]) => {
                let digits = bytes[i + 1..]
                    .iter()
                    .take_while(|b| b.is_ascii_digit())
                    .count();
                if digits > 0 {
                    out.push_str("@p");
                    out.push_str(&sql[i + 1..i + 1 + digits]);
                    i += 1 + digits;
                    continue;
                }
                i = match dollar_tag_len(&bytes[i..]) {
                    Some(tag_len) => {
                        let tag = &sql[i..i + tag_len];
                        find_from(sql, i + tag_len, tag).map_or(len, |end| end + tag_len)
                    }
                    None => i + 1,
                };
            }
            _ => {
                // Copy a run of ordinary text up to the next byte that could start a
                // quote, comment or placeholder. Those bytes are all ASCII, so the run
                // always ends on a UTF-8 character boundary.
                i += 1;
                while i < len && !matches!(bytes[i], b'\'' | b'"' | b'-' | b'/' | b'$') {
                    i += 1;
                }
            }
        }
        out.push_str(&sql[start..i]);
    }
    out
}

/// Whether `b` may appear inside an unquoted PostgreSQL identifier after its
/// first character. Bytes >= 0x80 count, as PostgreSQL accepts multi-byte
/// characters in identifiers.
fn is_sql_ident_byte(b: u8) -> bool {
    b.is_ascii_alphanumeric() || b == b'_' || b == b'$' || b >= 0x80
}

/// Index just past the closing `quote`, scanning from `i` (just after the
/// opening quote), or the input length if the quote is never closed. With
/// `backslash_escapes`, a backslash skips the following byte.
fn skip_quoted(bytes: &[u8], mut i: usize, quote: u8, backslash_escapes: bool) -> usize {
    while i < bytes.len() {
        match bytes[i] {
            b'\\' if backslash_escapes => i += 2,
            b if b == quote => return i + 1,
            _ => i += 1,
        }
    }
    bytes.len()
}

/// Index just past the `*/` closing a block comment whose opening `/*` ends at
/// `i`, or the input length if unterminated. PostgreSQL block comments nest.
fn skip_block_comment(bytes: &[u8], mut i: usize) -> usize {
    let mut depth = 1usize;
    while i < bytes.len() {
        match (bytes[i], bytes.get(i + 1).copied()) {
            (b'/', Some(b'*')) => {
                depth += 1;
                i += 2;
            }
            (b'*', Some(b'/')) => {
                depth -= 1;
                i += 2;
                if depth == 0 {
                    return i;
                }
            }
            _ => i += 1,
        }
    }
    bytes.len()
}

/// Length of the dollar-quote delimiter (`$$` or `$tag$`) at the start of
/// `rest`, which begins with `$`, or `None` if there is none. Callers have
/// already ruled out a digit directly after the `$` (that is a placeholder).
fn dollar_tag_len(rest: &[u8]) -> Option<usize> {
    let tag = rest[1..]
        .iter()
        .take_while(|&&b| b.is_ascii_alphanumeric() || b == b'_' || b >= 0x80)
        .count();
    (rest.get(1 + tag) == Some(&b'$')).then_some(tag + 2)
}

/// Byte offset of the first `needle` in `haystack` at or after `from`.
fn find_from(haystack: &str, from: usize, needle: &str) -> Option<usize> {
    haystack.get(from..)?.find(needle).map(|pos| from + pos)
}
72
73impl CodegenBackend for CsharpNpgsqlBackend {
74    fn name(&self) -> &str {
75        "csharp-npgsql"
76    }
77
78    fn generate_row_struct(
79        &self,
80        query_name: &str,
81        columns: &[ResolvedColumn],
82    ) -> Result<String, ScytheError> {
83        let struct_name = row_struct_name(query_name, &self.manifest.naming);
84        let mut out = String::new();
85        let _ = writeln!(out, "public record {}(", struct_name);
86        for (i, c) in columns.iter().enumerate() {
87            let field = to_pascal_case(&c.field_name);
88            let sep = if i + 1 < columns.len() { "," } else { "" };
89            let _ = writeln!(out, "    {} {}{}", c.full_type, field, sep);
90        }
91        let _ = write!(out, ");");
92        Ok(out)
93    }
94
95    fn generate_model_struct(
96        &self,
97        table_name: &str,
98        columns: &[ResolvedColumn],
99    ) -> Result<String, ScytheError> {
100        let name = to_pascal_case(table_name);
101        self.generate_row_struct(&name, columns)
102    }
103
104    fn generate_query_fn(
105        &self,
106        analyzed: &AnalyzedQuery,
107        struct_name: &str,
108        columns: &[ResolvedColumn],
109        params: &[ResolvedParam],
110    ) -> Result<String, ScytheError> {
111        let func_name = fn_name(&analyzed.name, &self.manifest.naming);
112        let sql = rewrite_params(&super::clean_sql(&analyzed.sql));
113        let mut out = String::new();
114
115        // Build C# parameter list
116        let param_list = params
117            .iter()
118            .map(|p| format!("{} {}", p.full_type, p.field_name))
119            .collect::<Vec<_>>()
120            .join(", ");
121        let sep = if param_list.is_empty() { "" } else { ", " };
122
123        // Return type depends on command
124        let return_type = match &analyzed.command {
125            QueryCommand::One => format!("{}?", struct_name),
126            QueryCommand::Many | QueryCommand::Batch => {
127                format!("List<{}>", struct_name)
128            }
129            QueryCommand::Exec => "void".to_string(),
130            QueryCommand::ExecResult | QueryCommand::ExecRows => "int".to_string(),
131        };
132
133        let is_async_void = return_type == "void";
134        let task_type = if is_async_void {
135            "Task".to_string()
136        } else {
137            format!("Task<{}>", return_type)
138        };
139
140        let _ = writeln!(
141            out,
142            "public static async {} {}(NpgsqlConnection conn{}{}) {{",
143            task_type, func_name, sep, param_list
144        );
145
146        // Command setup
147        let _ = writeln!(
148            out,
149            "    await using var cmd = new NpgsqlCommand(\"{}\", conn);",
150            sql
151        );
152        for (i, p) in params.iter().enumerate() {
153            let _ = writeln!(
154                out,
155                "    cmd.Parameters.AddWithValue(\"p{}\", {});",
156                i + 1,
157                p.field_name
158            );
159        }
160
161        match &analyzed.command {
162            QueryCommand::One => {
163                let _ = writeln!(
164                    out,
165                    "    await using var reader = await cmd.ExecuteReaderAsync();"
166                );
167                let _ = writeln!(out, "    if (!await reader.ReadAsync()) return null;");
168                let _ = writeln!(out, "    return new {}(", struct_name);
169                for (i, col) in columns.iter().enumerate() {
170                    let method = reader_method(&col.neutral_type);
171                    let sep = if i + 1 < columns.len() { "," } else { "" };
172                    if col.nullable {
173                        let _ = writeln!(
174                            out,
175                            "        reader.IsDBNull({i}) ? null : reader.{method}({i}){sep}"
176                        );
177                    } else {
178                        let _ = writeln!(out, "        reader.{method}({i}){sep}");
179                    }
180                }
181                let _ = writeln!(out, "    );");
182            }
183            QueryCommand::Many | QueryCommand::Batch => {
184                let _ = writeln!(
185                    out,
186                    "    await using var reader = await cmd.ExecuteReaderAsync();"
187                );
188                let _ = writeln!(out, "    var results = new List<{}>();", struct_name);
189                let _ = writeln!(out, "    while (await reader.ReadAsync()) {{");
190                let _ = writeln!(out, "        results.Add(new {}(", struct_name);
191                for (i, col) in columns.iter().enumerate() {
192                    let method = reader_method(&col.neutral_type);
193                    let sep = if i + 1 < columns.len() { "," } else { "" };
194                    if col.nullable {
195                        let _ = writeln!(
196                            out,
197                            "            reader.IsDBNull({i}) ? null : reader.{method}({i}){sep}"
198                        );
199                    } else {
200                        let _ = writeln!(out, "            reader.{method}({i}){sep}");
201                    }
202                }
203                let _ = writeln!(out, "        ));");
204                let _ = writeln!(out, "    }}");
205                let _ = writeln!(out, "    return results;");
206            }
207            QueryCommand::Exec => {
208                let _ = writeln!(out, "    await cmd.ExecuteNonQueryAsync();");
209            }
210            QueryCommand::ExecResult | QueryCommand::ExecRows => {
211                let _ = writeln!(out, "    return await cmd.ExecuteNonQueryAsync();");
212            }
213        }
214
215        let _ = write!(out, "}}");
216        Ok(out)
217    }
218
219    fn generate_enum_def(&self, enum_info: &EnumInfo) -> Result<String, ScytheError> {
220        let type_name = enum_type_name(&enum_info.sql_name, &self.manifest.naming);
221        let mut out = String::new();
222        let _ = writeln!(out, "public enum {} {{", type_name);
223        for value in &enum_info.values {
224            let variant = enum_variant_name(value, &self.manifest.naming);
225            let _ = writeln!(out, "    {},", variant);
226        }
227        let _ = write!(out, "}}");
228        Ok(out)
229    }
230
231    fn generate_composite_def(&self, composite: &CompositeInfo) -> Result<String, ScytheError> {
232        let name = to_pascal_case(&composite.sql_name);
233        let mut out = String::new();
234        if composite.fields.is_empty() {
235            let _ = writeln!(out, "public record {}();", name);
236        } else {
237            let _ = writeln!(out, "public record {}(", name);
238            for (i, field) in composite.fields.iter().enumerate() {
239                let cs_type = resolve_type(&field.neutral_type, &self.manifest, false)
240                    .map(|t| t.into_owned())
241                    .unwrap_or_else(|_| "object".to_string());
242                let field_name = to_pascal_case(&field.name);
243                let sep = if i + 1 < composite.fields.len() {
244                    ","
245                } else {
246                    ""
247                };
248                let _ = writeln!(out, "    {} {}{}", cs_type, field_name, sep);
249            }
250            let _ = write!(out, ");");
251        }
252        Ok(out)
253    }
254}