Skip to main content

scythe_codegen/backends/
java_jdbc.rs

1use std::fmt::Write;
2use std::path::Path;
3
4use scythe_backend::manifest::{BackendManifest, load_manifest};
5use scythe_backend::naming::{
6    enum_type_name, enum_variant_name, fn_name, row_struct_name, to_camel_case, to_pascal_case,
7};
8
9use scythe_core::analyzer::{AnalyzedQuery, CompositeInfo, EnumInfo};
10use scythe_core::errors::{ErrorCode, ScytheError};
11use scythe_core::parser::QueryCommand;
12
13use crate::backend_trait::{CodegenBackend, ResolvedColumn, ResolvedParam};
14
// Built-in manifests, embedded at compile time. `JavaJdbcBackend::new` picks one
// by engine name and uses it when no `backends/java-jdbc/manifest.toml` override
// exists on disk.

/// Built-in manifest for the PostgreSQL engine.
const DEFAULT_MANIFEST_PG: &str = include_str!("../../manifests/java-jdbc.toml");
/// Built-in manifest for the MySQL / MariaDB engines.
const DEFAULT_MANIFEST_MYSQL: &str = include_str!("../../manifests/java-jdbc.mysql.toml");
/// Built-in manifest for the SQLite engine.
const DEFAULT_MANIFEST_SQLITE: &str = include_str!("../../manifests/java-jdbc.sqlite.toml");
18
/// Code generation backend that emits Java source using plain JDBC
/// (`Connection`, `PreparedStatement`, `ResultSet`).
///
/// Construct it with [`JavaJdbcBackend::new`], which selects the manifest for
/// the requested database engine.
pub struct JavaJdbcBackend {
    /// Backend manifest; its `naming` section drives the generated identifiers.
    manifest: BackendManifest,
}
22
23impl JavaJdbcBackend {
24    pub fn new(engine: &str) -> Result<Self, ScytheError> {
25        let default_toml = match engine {
26            "postgresql" | "postgres" | "pg" => DEFAULT_MANIFEST_PG,
27            "mysql" | "mariadb" => DEFAULT_MANIFEST_MYSQL,
28            "sqlite" | "sqlite3" => DEFAULT_MANIFEST_SQLITE,
29            _ => {
30                return Err(ScytheError::new(
31                    ErrorCode::InternalError,
32                    format!("unsupported engine '{}' for java-jdbc backend", engine),
33                ));
34            }
35        };
36        let manifest_path = Path::new("backends/java-jdbc/manifest.toml");
37        let manifest = if manifest_path.exists() {
38            load_manifest(manifest_path)
39                .map_err(|e| ScytheError::new(ErrorCode::InternalError, format!("manifest: {e}")))?
40        } else {
41            toml::from_str(default_toml)
42                .map_err(|e| ScytheError::new(ErrorCode::InternalError, format!("manifest: {e}")))?
43        };
44        Ok(Self { manifest })
45    }
46}
47
/// Convert PostgreSQL `$1`, `$2`, ... placeholders to JDBC `?` placeholders.
///
/// A `$` followed by one or more ASCII digits becomes a single `?`. Any other
/// `$` (and every other character) is copied through, so SQL that already uses
/// `?` placeholders is returned unchanged.
///
/// Text inside single-quoted string literals (`'...'`, where `''` is an escaped
/// quote) and double-quoted identifiers (`"..."`) is copied verbatim. A `$1`
/// there is literal text, not a bind parameter. For `E'...'` strings a
/// backslash escapes the following character, so `E'it\'s'` does not end early.
/// An unterminated quote copies the rest of the input verbatim.
///
/// The placeholder numbers are discarded, so parameters must be bound in the
/// order the placeholders appear in the text.
fn pg_to_jdbc_params(sql: &str) -> String {
    let mut result = String::with_capacity(sql.len());
    let mut chars = sql.chars().peekable();
    // Last character consumed at the top level; used to recognise E'...' strings.
    let mut prev: Option<char> = None;
    while let Some(ch) = chars.next() {
        match ch {
            '\'' | '"' => {
                let backslash_escapes = ch == '\'' && matches!(prev, Some('E' | 'e'));
                result.push(ch);
                // Copy up to and including the matching closing quote.
                while let Some(inner) = chars.next() {
                    result.push(inner);
                    if backslash_escapes && inner == '\\' {
                        if let Some(escaped) = chars.next() {
                            result.push(escaped);
                        }
                    } else if inner == ch {
                        break;
                    }
                }
            }
            '$' if chars.peek().is_some_and(|c| c.is_ascii_digit()) => {
                // Consume all digits of the parameter number.
                while chars.peek().is_some_and(|c| c.is_ascii_digit()) {
                    chars.next();
                }
                result.push('?');
            }
            _ => result.push(ch),
        }
        prev = Some(ch);
    }
    result
}
70
/// Map a Java primitive type name to its wrapper class (`int` -> `Integer`,
/// `char` -> `Character`, ...) so the value can hold `null`.
///
/// Any name that is not one of the eight primitives is returned unchanged.
fn box_primitive(java_type: &str) -> &str {
    const WRAPPERS: [(&str, &str); 8] = [
        ("boolean", "Boolean"),
        ("byte", "Byte"),
        ("short", "Short"),
        ("int", "Integer"),
        ("long", "Long"),
        ("float", "Float"),
        ("double", "Double"),
        ("char", "Character"),
    ];
    WRAPPERS
        .iter()
        .find(|&&(primitive, _)| primitive == java_type)
        .map_or(java_type, |&(_, wrapper)| wrapper)
}
85
/// Name of the `java.sql.ResultSet` getter used to read a column of `java_type`.
///
/// Primitives and their wrappers share a typed getter, `String` and `byte[]`
/// have dedicated getters, and any type mentioning `BigDecimal` uses
/// `getBigDecimal`. Everything else, including the java.time types and UUID,
/// falls back to `getObject`.
fn rs_getter(java_type: &str) -> &str {
    // No exact-match name below contains "BigDecimal", so checking it first is safe.
    if java_type.contains("BigDecimal") {
        return "getBigDecimal";
    }
    match java_type {
        "boolean" | "Boolean" => "getBoolean",
        "byte" | "Byte" => "getByte",
        "short" | "Short" => "getShort",
        "int" | "Integer" => "getInt",
        "long" | "Long" => "getLong",
        "float" | "Float" => "getFloat",
        "double" | "Double" => "getDouble",
        "String" => "getString",
        "byte[]" => "getBytes",
        _ => "getObject",
    }
}
108
/// Name of the `java.sql.PreparedStatement` setter used to bind a value of `java_type`.
///
/// Primitives and their wrappers share a typed setter, `String` and `byte[]`
/// have dedicated setters, and any type mentioning `BigDecimal` uses
/// `setBigDecimal`. Everything else binds through `setObject`.
fn ps_setter(java_type: &str) -> &str {
    // (primitive spelling, wrapper spelling, setter)
    const TYPED_SETTERS: [(&str, &str, &str); 7] = [
        ("boolean", "Boolean", "setBoolean"),
        ("byte", "Byte", "setByte"),
        ("short", "Short", "setShort"),
        ("int", "Integer", "setInt"),
        ("long", "Long", "setLong"),
        ("float", "Float", "setFloat"),
        ("double", "Double", "setDouble"),
    ];
    if let Some(&(_, _, setter)) = TYPED_SETTERS
        .iter()
        .find(|&&(primitive, wrapper, _)| java_type == primitive || java_type == wrapper)
    {
        return setter;
    }
    match java_type {
        "String" => "setString",
        "byte[]" => "setBytes",
        _ if java_type.contains("BigDecimal") => "setBigDecimal",
        _ => "setObject",
    }
}
125
126/// Resolve the display type for a Java field, boxing primitives when nullable.
127fn java_field_type(col: &ResolvedColumn) -> String {
128    if col.nullable {
129        box_primitive(&col.lang_type).to_string()
130    } else {
131        col.full_type.clone()
132    }
133}
134
135/// Resolve the display type for a Java param, boxing primitives when nullable.
136fn java_param_type(param: &ResolvedParam) -> String {
137    if param.nullable {
138        box_primitive(&param.lang_type).to_string()
139    } else {
140        param.full_type.clone()
141    }
142}
143
144impl CodegenBackend for JavaJdbcBackend {
145    fn name(&self) -> &str {
146        "java-jdbc"
147    }
148
149    fn manifest(&self) -> &scythe_backend::manifest::BackendManifest {
150        &self.manifest
151    }
152
153    fn supported_engines(&self) -> &[&str] {
154        &["postgresql", "mysql", "sqlite"]
155    }
156
157    fn file_header(&self) -> String {
158        "// Auto-generated by scythe. Do not edit.\n\
159         import java.math.BigDecimal;\n\
160         import java.sql.*;\n\
161         import java.time.OffsetDateTime;\n\
162         import java.util.ArrayList;\n\
163         import java.util.List;\n\
164         import javax.annotation.Nullable;"
165            .to_string()
166    }
167
168    fn generate_row_struct(
169        &self,
170        query_name: &str,
171        columns: &[ResolvedColumn],
172    ) -> Result<String, ScytheError> {
173        let struct_name = row_struct_name(query_name, &self.manifest.naming);
174        let mut out = String::new();
175
176        // Record declaration with fields
177        let fields = columns
178            .iter()
179            .map(|c| {
180                let field_type = java_field_type(c);
181                if c.nullable {
182                    format!("    @Nullable {} {}", field_type, c.field_name)
183                } else {
184                    format!("    {} {}", field_type, c.field_name)
185                }
186            })
187            .collect::<Vec<_>>()
188            .join(",\n");
189
190        let _ = writeln!(out, "public record {}(", struct_name);
191        let _ = writeln!(out, "{}", fields);
192        let _ = writeln!(out, ") {{");
193
194        // fromResultSet static factory method
195        let _ = writeln!(
196            out,
197            "    public static {} fromResultSet(ResultSet rs) throws SQLException {{",
198            struct_name
199        );
200        let _ = writeln!(out, "        return new {}(", struct_name);
201        for (i, col) in columns.iter().enumerate() {
202            let getter = rs_getter(&col.lang_type);
203            let sep = if i + 1 < columns.len() { "," } else { "" };
204            let _ = writeln!(out, "            rs.{}(\"{}\"){}", getter, col.name, sep);
205        }
206        let _ = writeln!(out, "        );");
207        let _ = writeln!(out, "    }}");
208        let _ = write!(out, "}}");
209        Ok(out)
210    }
211
212    fn generate_model_struct(
213        &self,
214        table_name: &str,
215        columns: &[ResolvedColumn],
216    ) -> Result<String, ScytheError> {
217        let name = to_pascal_case(table_name);
218        self.generate_row_struct(&name, columns)
219    }
220
221    fn generate_query_fn(
222        &self,
223        analyzed: &AnalyzedQuery,
224        struct_name: &str,
225        _columns: &[ResolvedColumn],
226        params: &[ResolvedParam],
227    ) -> Result<String, ScytheError> {
228        let func_name = fn_name(&analyzed.name, &self.manifest.naming);
229        let sql = pg_to_jdbc_params(&super::clean_sql_oneline(&analyzed.sql));
230
231        let param_list = params
232            .iter()
233            .map(|p| {
234                let param_type = java_param_type(p);
235                format!("{} {}", param_type, p.field_name)
236            })
237            .collect::<Vec<_>>()
238            .join(", ");
239        let sep = if param_list.is_empty() { "" } else { ", " };
240
241        let mut out = String::new();
242
243        match &analyzed.command {
244            QueryCommand::Exec => {
245                let _ = writeln!(
246                    out,
247                    "public static void {}(Connection conn{}{}) throws SQLException {{",
248                    func_name, sep, param_list
249                );
250                let _ = writeln!(
251                    out,
252                    "    try (var ps = conn.prepareStatement(\"{}\")) {{",
253                    sql
254                );
255                for (i, param) in params.iter().enumerate() {
256                    let setter = ps_setter(&param.lang_type);
257                    let _ = writeln!(
258                        out,
259                        "        ps.{}({}, {});",
260                        setter,
261                        i + 1,
262                        param.field_name
263                    );
264                }
265                let _ = writeln!(out, "        ps.executeUpdate();");
266                let _ = writeln!(out, "    }}");
267                let _ = write!(out, "}}");
268            }
269            QueryCommand::ExecResult | QueryCommand::ExecRows => {
270                let _ = writeln!(
271                    out,
272                    "public static int {}(Connection conn{}{}) throws SQLException {{",
273                    func_name, sep, param_list
274                );
275                let _ = writeln!(
276                    out,
277                    "    try (var ps = conn.prepareStatement(\"{}\")) {{",
278                    sql
279                );
280                for (i, param) in params.iter().enumerate() {
281                    let setter = ps_setter(&param.lang_type);
282                    let _ = writeln!(
283                        out,
284                        "        ps.{}({}, {});",
285                        setter,
286                        i + 1,
287                        param.field_name
288                    );
289                }
290                let _ = writeln!(out, "        return ps.executeUpdate();");
291                let _ = writeln!(out, "    }}");
292                let _ = write!(out, "}}");
293            }
294            QueryCommand::One => {
295                let _ = writeln!(
296                    out,
297                    "public static {} {}(Connection conn{}{}) throws SQLException {{",
298                    struct_name, func_name, sep, param_list
299                );
300                let _ = writeln!(
301                    out,
302                    "    try (var ps = conn.prepareStatement(\"{}\")) {{",
303                    sql
304                );
305                for (i, param) in params.iter().enumerate() {
306                    let setter = ps_setter(&param.lang_type);
307                    let _ = writeln!(
308                        out,
309                        "        ps.{}({}, {});",
310                        setter,
311                        i + 1,
312                        param.field_name
313                    );
314                }
315                let _ = writeln!(out, "        try (ResultSet rs = ps.executeQuery()) {{");
316                let _ = writeln!(out, "            if (rs.next()) {{");
317                let _ = writeln!(
318                    out,
319                    "                return {}.fromResultSet(rs);",
320                    struct_name
321                );
322                let _ = writeln!(out, "            }}");
323                let _ = writeln!(out, "            return null;");
324                let _ = writeln!(out, "        }}");
325                let _ = writeln!(out, "    }}");
326                let _ = write!(out, "}}");
327            }
328            QueryCommand::Many | QueryCommand::Batch => {
329                let _ = writeln!(
330                    out,
331                    "public static java.util.List<{}> {}(Connection conn{}{}) throws SQLException {{",
332                    struct_name, func_name, sep, param_list
333                );
334                let _ = writeln!(
335                    out,
336                    "    try (var ps = conn.prepareStatement(\"{}\")) {{",
337                    sql
338                );
339                for (i, param) in params.iter().enumerate() {
340                    let setter = ps_setter(&param.lang_type);
341                    let _ = writeln!(
342                        out,
343                        "        ps.{}({}, {});",
344                        setter,
345                        i + 1,
346                        param.field_name
347                    );
348                }
349                let _ = writeln!(out, "        try (ResultSet rs = ps.executeQuery()) {{");
350                let _ = writeln!(
351                    out,
352                    "            java.util.List<{}> result = new java.util.ArrayList<>();",
353                    struct_name
354                );
355                let _ = writeln!(out, "            while (rs.next()) {{");
356                let _ = writeln!(
357                    out,
358                    "                result.add({}.fromResultSet(rs));",
359                    struct_name
360                );
361                let _ = writeln!(out, "            }}");
362                let _ = writeln!(out, "            return result;");
363                let _ = writeln!(out, "        }}");
364                let _ = writeln!(out, "    }}");
365                let _ = write!(out, "}}");
366            }
367        }
368
369        Ok(out)
370    }
371
372    fn generate_enum_def(&self, enum_info: &EnumInfo) -> Result<String, ScytheError> {
373        let type_name = enum_type_name(&enum_info.sql_name, &self.manifest.naming);
374        let mut out = String::new();
375        let _ = writeln!(out, "public enum {} {{", type_name);
376        for (i, value) in enum_info.values.iter().enumerate() {
377            let variant = enum_variant_name(value, &self.manifest.naming);
378            let sep = if i + 1 < enum_info.values.len() {
379                ","
380            } else {
381                ";"
382            };
383            let _ = writeln!(out, "    {}(\"{}\"){}", variant, value, sep);
384        }
385        let _ = writeln!(out);
386        let _ = writeln!(out, "    private final String value;");
387        let _ = writeln!(
388            out,
389            "    {}(String value) {{ this.value = value; }}",
390            type_name
391        );
392        let _ = writeln!(out, "    public String getValue() {{ return value; }}");
393        let _ = write!(out, "}}");
394        Ok(out)
395    }
396
397    fn generate_composite_def(&self, composite: &CompositeInfo) -> Result<String, ScytheError> {
398        let name = to_pascal_case(&composite.sql_name);
399        let mut out = String::new();
400        if composite.fields.is_empty() {
401            let _ = writeln!(out, "public record {}() {{}}", name);
402        } else {
403            let fields = composite
404                .fields
405                .iter()
406                .map(|f| format!("Object {}", to_camel_case(&f.name)))
407                .collect::<Vec<_>>()
408                .join(", ");
409            let _ = writeln!(out, "public record {}({}) {{}}", name, fields);
410        }
411        Ok(out)
412    }
413}