Skip to main content

scythe_codegen/backends/
java_jdbc.rs

1use std::fmt::Write;
2use std::path::Path;
3
4use scythe_backend::manifest::{BackendManifest, load_manifest};
5use scythe_backend::naming::{
6    enum_type_name, enum_variant_name, fn_name, row_struct_name, to_camel_case, to_pascal_case,
7};
8
9use scythe_core::analyzer::{AnalyzedQuery, CompositeInfo, EnumInfo};
10use scythe_core::errors::{ErrorCode, ScytheError};
11use scythe_core::parser::QueryCommand;
12
13use crate::backend_trait::{CodegenBackend, ResolvedColumn, ResolvedParam};
14
15const DEFAULT_MANIFEST_PG: &str = include_str!("../../manifests/java-jdbc.toml");
16const DEFAULT_MANIFEST_MYSQL: &str = include_str!("../../manifests/java-jdbc.mysql.toml");
17const DEFAULT_MANIFEST_SQLITE: &str = include_str!("../../manifests/java-jdbc.sqlite.toml");
18
19pub struct JavaJdbcBackend {
20    manifest: BackendManifest,
21}
22
23impl JavaJdbcBackend {
24    pub fn new(engine: &str) -> Result<Self, ScytheError> {
25        let default_toml = match engine {
26            "postgresql" | "postgres" | "pg" => DEFAULT_MANIFEST_PG,
27            "mysql" | "mariadb" => DEFAULT_MANIFEST_MYSQL,
28            "sqlite" | "sqlite3" => DEFAULT_MANIFEST_SQLITE,
29            _ => {
30                return Err(ScytheError::new(
31                    ErrorCode::InternalError,
32                    format!("unsupported engine '{}' for java-jdbc backend", engine),
33                ));
34            }
35        };
36        let manifest_path = Path::new("backends/java-jdbc/manifest.toml");
37        let manifest = if manifest_path.exists() {
38            load_manifest(manifest_path)
39                .map_err(|e| ScytheError::new(ErrorCode::InternalError, format!("manifest: {e}")))?
40        } else {
41            toml::from_str(default_toml)
42                .map_err(|e| ScytheError::new(ErrorCode::InternalError, format!("manifest: {e}")))?
43        };
44        Ok(Self { manifest })
45    }
46}
47
/// Convert PostgreSQL `$1`, `$2`, ... placeholders to JDBC `?` placeholders.
///
/// Each `$N` becomes a single `?`, so the rewritten statement expects its
/// parameters bound in the order the placeholders appear in the text.
///
/// Text inside single-quoted string literals (`'...'`, including `''`
/// escapes) and double-quoted identifiers (`"..."`) is copied verbatim. A `$`
/// that directly follows an identifier character (e.g. `col$1`, which
/// PostgreSQL lexes as a plain identifier) is also left untouched.
/// Dollar-quoted bodies and backslash escapes inside `E'...'` strings are not
/// recognised.
fn pg_to_jdbc_params(sql: &str) -> String {
    // Characters that may continue an unquoted PostgreSQL identifier.
    fn continues_identifier(c: char) -> bool {
        c.is_alphanumeric() || c == '_' || c == '$'
    }

    let mut result = String::with_capacity(sql.len());
    let mut chars = sql.chars().peekable();
    // Quote character that will close the literal/identifier we are inside.
    let mut open_quote: Option<char> = None;
    // Last character written to `result`.
    let mut prev: Option<char> = None;

    while let Some(ch) = chars.next() {
        let emitted = match open_quote {
            Some(quote) => {
                if ch == quote {
                    open_quote = None;
                }
                ch
            }
            None if ch == '\'' || ch == '"' => {
                open_quote = Some(ch);
                ch
            }
            None if ch == '$'
                && !prev.is_some_and(continues_identifier)
                && chars.peek().is_some_and(char::is_ascii_digit) =>
            {
                // Swallow the whole parameter number (`$12` is one placeholder).
                while chars.next_if(char::is_ascii_digit).is_some() {}
                '?'
            }
            None => ch,
        };
        result.push(emitted);
        prev = Some(emitted);
    }
    result
}
70
/// Map a Java primitive type name to its wrapper class so it can hold `null`.
///
/// Any name that is not one of the eight primitives is returned as-is.
fn box_primitive(java_type: &str) -> &str {
    const WRAPPERS: [(&str, &str); 8] = [
        ("boolean", "Boolean"),
        ("byte", "Byte"),
        ("short", "Short"),
        ("int", "Integer"),
        ("long", "Long"),
        ("float", "Float"),
        ("double", "Double"),
        ("char", "Character"),
    ];
    WRAPPERS
        .iter()
        .find_map(|&(primitive, wrapper)| (primitive == java_type).then_some(wrapper))
        .unwrap_or(java_type)
}
85
/// Pick the `ResultSet` getter used to read a column of the given Java type.
///
/// Primitive and boxed numeric/boolean types, `String` and `byte[]` have
/// dedicated getters. Any type whose name mentions `BigDecimal` uses
/// `getBigDecimal`. Everything else (java.time types, `UUID`, enums, ...) is
/// read through `getObject`.
fn rs_getter(java_type: &str) -> &str {
    const NUMERIC_AND_BOOL: [(&str, &str, &str); 7] = [
        ("boolean", "Boolean", "getBoolean"),
        ("byte", "Byte", "getByte"),
        ("short", "Short", "getShort"),
        ("int", "Integer", "getInt"),
        ("long", "Long", "getLong"),
        ("float", "Float", "getFloat"),
        ("double", "Double", "getDouble"),
    ];
    if let Some(&(_, _, getter)) = NUMERIC_AND_BOOL
        .iter()
        .find(|&&(primitive, boxed, _)| java_type == primitive || java_type == boxed)
    {
        return getter;
    }
    match java_type {
        "String" => "getString",
        "byte[]" => "getBytes",
        other if other.contains("BigDecimal") => "getBigDecimal",
        _ => "getObject",
    }
}
108
/// Pick the `PreparedStatement` setter used to bind a value of the given Java type.
///
/// Primitive and boxed numeric/boolean types, `String` and `byte[]` have
/// dedicated setters. Types mentioning `BigDecimal` use `setBigDecimal`, and
/// anything else falls back to `setObject`.
fn ps_setter(java_type: &str) -> &str {
    const BY_TYPE: &[(&[&str], &str)] = &[
        (&["boolean", "Boolean"], "setBoolean"),
        (&["byte", "Byte"], "setByte"),
        (&["short", "Short"], "setShort"),
        (&["int", "Integer"], "setInt"),
        (&["long", "Long"], "setLong"),
        (&["float", "Float"], "setFloat"),
        (&["double", "Double"], "setDouble"),
        (&["String"], "setString"),
        (&["byte[]"], "setBytes"),
    ];
    BY_TYPE
        .iter()
        .find(|(names, _)| names.contains(&java_type))
        .map(|&(_, setter)| setter)
        .unwrap_or_else(|| {
            if java_type.contains("BigDecimal") {
                "setBigDecimal"
            } else {
                "setObject"
            }
        })
}
125
126/// Resolve the display type for a Java field, boxing primitives when nullable.
127fn java_field_type(col: &ResolvedColumn) -> String {
128    if col.nullable {
129        box_primitive(&col.lang_type).to_string()
130    } else {
131        col.full_type.clone()
132    }
133}
134
135/// Resolve the display type for a Java param, boxing primitives when nullable.
136fn java_param_type(param: &ResolvedParam) -> String {
137    if param.nullable {
138        box_primitive(&param.lang_type).to_string()
139    } else {
140        param.full_type.clone()
141    }
142}
143
144impl CodegenBackend for JavaJdbcBackend {
145    fn name(&self) -> &str {
146        "java-jdbc"
147    }
148
149    fn manifest(&self) -> &scythe_backend::manifest::BackendManifest {
150        &self.manifest
151    }
152
153    fn supported_engines(&self) -> &[&str] {
154        &["postgresql", "mysql", "sqlite"]
155    }
156
157    fn file_header(&self) -> String {
158        "// Auto-generated by scythe. Do not edit.\n\
159         import java.math.BigDecimal;\n\
160         import java.sql.*;\n\
161         import java.time.OffsetDateTime;\n\
162         import java.util.ArrayList;\n\
163         import java.util.List;\n\
164         import javax.annotation.Nullable;"
165            .to_string()
166    }
167
168    fn generate_row_struct(
169        &self,
170        query_name: &str,
171        columns: &[ResolvedColumn],
172    ) -> Result<String, ScytheError> {
173        let struct_name = row_struct_name(query_name, &self.manifest.naming);
174        let mut out = String::new();
175
176        // Record declaration with fields
177        let fields = columns
178            .iter()
179            .map(|c| {
180                let field_type = java_field_type(c);
181                if c.nullable {
182                    format!("    @Nullable {} {}", field_type, c.field_name)
183                } else {
184                    format!("    {} {}", field_type, c.field_name)
185                }
186            })
187            .collect::<Vec<_>>()
188            .join(",\n");
189
190        let _ = writeln!(out, "public record {}(", struct_name);
191        let _ = writeln!(out, "{}", fields);
192        let _ = writeln!(out, ") {{");
193
194        // fromResultSet static factory method
195        let _ = writeln!(
196            out,
197            "    public static {} fromResultSet(ResultSet rs) throws SQLException {{",
198            struct_name
199        );
200        let _ = writeln!(out, "        return new {}(", struct_name);
201        for (i, col) in columns.iter().enumerate() {
202            let getter = rs_getter(&col.lang_type);
203            let sep = if i + 1 < columns.len() { "," } else { "" };
204            let _ = writeln!(out, "            rs.{}(\"{}\"){}", getter, col.name, sep);
205        }
206        let _ = writeln!(out, "        );");
207        let _ = writeln!(out, "    }}");
208        let _ = write!(out, "}}");
209        Ok(out)
210    }
211
212    fn generate_model_struct(
213        &self,
214        table_name: &str,
215        columns: &[ResolvedColumn],
216    ) -> Result<String, ScytheError> {
217        let name = to_pascal_case(table_name);
218        self.generate_row_struct(&name, columns)
219    }
220
221    fn generate_query_fn(
222        &self,
223        analyzed: &AnalyzedQuery,
224        struct_name: &str,
225        _columns: &[ResolvedColumn],
226        params: &[ResolvedParam],
227    ) -> Result<String, ScytheError> {
228        let func_name = fn_name(&analyzed.name, &self.manifest.naming);
229        let sql = pg_to_jdbc_params(&super::clean_sql_oneline_with_optional(
230            &analyzed.sql,
231            &analyzed.optional_params,
232            &analyzed.params,
233        ));
234
235        let param_list = params
236            .iter()
237            .map(|p| {
238                let param_type = java_param_type(p);
239                format!("{} {}", param_type, p.field_name)
240            })
241            .collect::<Vec<_>>()
242            .join(", ");
243        let sep = if param_list.is_empty() { "" } else { ", " };
244
245        let mut out = String::new();
246
247        match &analyzed.command {
248            QueryCommand::Exec => {
249                let _ = writeln!(
250                    out,
251                    "public static void {}(Connection conn{}{}) throws SQLException {{",
252                    func_name, sep, param_list
253                );
254                let _ = writeln!(
255                    out,
256                    "    try (var ps = conn.prepareStatement(\"{}\")) {{",
257                    sql
258                );
259                for (i, param) in params.iter().enumerate() {
260                    let setter = ps_setter(&param.lang_type);
261                    let _ = writeln!(
262                        out,
263                        "        ps.{}({}, {});",
264                        setter,
265                        i + 1,
266                        param.field_name
267                    );
268                }
269                let _ = writeln!(out, "        ps.executeUpdate();");
270                let _ = writeln!(out, "    }}");
271                let _ = write!(out, "}}");
272            }
273            QueryCommand::ExecResult | QueryCommand::ExecRows => {
274                let _ = writeln!(
275                    out,
276                    "public static int {}(Connection conn{}{}) throws SQLException {{",
277                    func_name, sep, param_list
278                );
279                let _ = writeln!(
280                    out,
281                    "    try (var ps = conn.prepareStatement(\"{}\")) {{",
282                    sql
283                );
284                for (i, param) in params.iter().enumerate() {
285                    let setter = ps_setter(&param.lang_type);
286                    let _ = writeln!(
287                        out,
288                        "        ps.{}({}, {});",
289                        setter,
290                        i + 1,
291                        param.field_name
292                    );
293                }
294                let _ = writeln!(out, "        return ps.executeUpdate();");
295                let _ = writeln!(out, "    }}");
296                let _ = write!(out, "}}");
297            }
298            QueryCommand::One => {
299                let _ = writeln!(
300                    out,
301                    "public static {} {}(Connection conn{}{}) throws SQLException {{",
302                    struct_name, func_name, sep, param_list
303                );
304                let _ = writeln!(
305                    out,
306                    "    try (var ps = conn.prepareStatement(\"{}\")) {{",
307                    sql
308                );
309                for (i, param) in params.iter().enumerate() {
310                    let setter = ps_setter(&param.lang_type);
311                    let _ = writeln!(
312                        out,
313                        "        ps.{}({}, {});",
314                        setter,
315                        i + 1,
316                        param.field_name
317                    );
318                }
319                let _ = writeln!(out, "        try (ResultSet rs = ps.executeQuery()) {{");
320                let _ = writeln!(out, "            if (rs.next()) {{");
321                let _ = writeln!(
322                    out,
323                    "                return {}.fromResultSet(rs);",
324                    struct_name
325                );
326                let _ = writeln!(out, "            }}");
327                let _ = writeln!(out, "            return null;");
328                let _ = writeln!(out, "        }}");
329                let _ = writeln!(out, "    }}");
330                let _ = write!(out, "}}");
331            }
332            QueryCommand::Many => {
333                let _ = writeln!(
334                    out,
335                    "public static java.util.List<{}> {}(Connection conn{}{}) throws SQLException {{",
336                    struct_name, func_name, sep, param_list
337                );
338                let _ = writeln!(
339                    out,
340                    "    try (var ps = conn.prepareStatement(\"{}\")) {{",
341                    sql
342                );
343                for (i, param) in params.iter().enumerate() {
344                    let setter = ps_setter(&param.lang_type);
345                    let _ = writeln!(
346                        out,
347                        "        ps.{}({}, {});",
348                        setter,
349                        i + 1,
350                        param.field_name
351                    );
352                }
353                let _ = writeln!(out, "        try (ResultSet rs = ps.executeQuery()) {{");
354                let _ = writeln!(
355                    out,
356                    "            java.util.List<{}> result = new java.util.ArrayList<>();",
357                    struct_name
358                );
359                let _ = writeln!(out, "            while (rs.next()) {{");
360                let _ = writeln!(
361                    out,
362                    "                result.add({}.fromResultSet(rs));",
363                    struct_name
364                );
365                let _ = writeln!(out, "            }}");
366                let _ = writeln!(out, "            return result;");
367                let _ = writeln!(out, "        }}");
368                let _ = writeln!(out, "    }}");
369                let _ = write!(out, "}}");
370            }
371            QueryCommand::Batch => {
372                let batch_fn_name = format!("{}Batch", func_name);
373                if params.len() > 1 {
374                    // Generate params record
375                    let params_record_name =
376                        format!("{}BatchParams", to_pascal_case(&analyzed.name));
377                    let record_fields = params
378                        .iter()
379                        .map(|p| format!("{} {}", java_param_type(p), p.field_name))
380                        .collect::<Vec<_>>()
381                        .join(", ");
382                    let _ = writeln!(
383                        out,
384                        "public record {}({}) {{}}",
385                        params_record_name, record_fields
386                    );
387                    let _ = writeln!(out);
388                    let _ = writeln!(
389                        out,
390                        "public static void {}(Connection conn, java.util.List<{}> items) throws SQLException {{",
391                        batch_fn_name, params_record_name
392                    );
393                    let _ = writeln!(out, "    conn.setAutoCommit(false);");
394                    let _ = writeln!(
395                        out,
396                        "    try (var ps = conn.prepareStatement(\"{}\")) {{",
397                        sql
398                    );
399                    let _ = writeln!(out, "        for (var item : items) {{");
400                    for (i, param) in params.iter().enumerate() {
401                        let setter = ps_setter(&param.lang_type);
402                        let _ = writeln!(
403                            out,
404                            "            ps.{}({}, item.{}());",
405                            setter,
406                            i + 1,
407                            param.field_name
408                        );
409                    }
410                    let _ = writeln!(out, "            ps.addBatch();");
411                    let _ = writeln!(out, "        }}");
412                    let _ = writeln!(out, "        ps.executeBatch();");
413                    let _ = writeln!(out, "        conn.commit();");
414                    let _ = writeln!(out, "    }} catch (SQLException e) {{");
415                    let _ = writeln!(out, "        conn.rollback();");
416                    let _ = writeln!(out, "        throw e;");
417                    let _ = writeln!(out, "    }} finally {{");
418                    let _ = writeln!(out, "        conn.setAutoCommit(true);");
419                    let _ = writeln!(out, "    }}");
420                    let _ = write!(out, "}}");
421                } else if params.len() == 1 {
422                    let param = &params[0];
423                    let _ = writeln!(
424                        out,
425                        "public static void {}(Connection conn, java.util.List<{}> items) throws SQLException {{",
426                        batch_fn_name,
427                        java_param_type(param)
428                    );
429                    let _ = writeln!(out, "    conn.setAutoCommit(false);");
430                    let _ = writeln!(
431                        out,
432                        "    try (var ps = conn.prepareStatement(\"{}\")) {{",
433                        sql
434                    );
435                    let _ = writeln!(out, "        for (var item : items) {{");
436                    let setter = ps_setter(&param.lang_type);
437                    let _ = writeln!(out, "            ps.{}(1, item);", setter);
438                    let _ = writeln!(out, "            ps.addBatch();");
439                    let _ = writeln!(out, "        }}");
440                    let _ = writeln!(out, "        ps.executeBatch();");
441                    let _ = writeln!(out, "        conn.commit();");
442                    let _ = writeln!(out, "    }} catch (SQLException e) {{");
443                    let _ = writeln!(out, "        conn.rollback();");
444                    let _ = writeln!(out, "        throw e;");
445                    let _ = writeln!(out, "    }} finally {{");
446                    let _ = writeln!(out, "        conn.setAutoCommit(true);");
447                    let _ = writeln!(out, "    }}");
448                    let _ = write!(out, "}}");
449                } else {
450                    let _ = writeln!(
451                        out,
452                        "public static void {}(Connection conn, int count) throws SQLException {{",
453                        batch_fn_name
454                    );
455                    let _ = writeln!(out, "    conn.setAutoCommit(false);");
456                    let _ = writeln!(
457                        out,
458                        "    try (var ps = conn.prepareStatement(\"{}\")) {{",
459                        sql
460                    );
461                    let _ = writeln!(out, "        for (int i = 0; i < count; i++) {{");
462                    let _ = writeln!(out, "            ps.addBatch();");
463                    let _ = writeln!(out, "        }}");
464                    let _ = writeln!(out, "        ps.executeBatch();");
465                    let _ = writeln!(out, "        conn.commit();");
466                    let _ = writeln!(out, "    }} catch (SQLException e) {{");
467                    let _ = writeln!(out, "        conn.rollback();");
468                    let _ = writeln!(out, "        throw e;");
469                    let _ = writeln!(out, "    }} finally {{");
470                    let _ = writeln!(out, "        conn.setAutoCommit(true);");
471                    let _ = writeln!(out, "    }}");
472                    let _ = write!(out, "}}");
473                }
474            }
475        }
476
477        Ok(out)
478    }
479
480    fn generate_enum_def(&self, enum_info: &EnumInfo) -> Result<String, ScytheError> {
481        let type_name = enum_type_name(&enum_info.sql_name, &self.manifest.naming);
482        let mut out = String::new();
483        let _ = writeln!(out, "public enum {} {{", type_name);
484        for (i, value) in enum_info.values.iter().enumerate() {
485            let variant = enum_variant_name(value, &self.manifest.naming);
486            let sep = if i + 1 < enum_info.values.len() {
487                ","
488            } else {
489                ";"
490            };
491            let _ = writeln!(out, "    {}(\"{}\"){}", variant, value, sep);
492        }
493        let _ = writeln!(out);
494        let _ = writeln!(out, "    private final String value;");
495        let _ = writeln!(
496            out,
497            "    {}(String value) {{ this.value = value; }}",
498            type_name
499        );
500        let _ = writeln!(out, "    public String getValue() {{ return value; }}");
501        let _ = write!(out, "}}");
502        Ok(out)
503    }
504
505    fn generate_composite_def(&self, composite: &CompositeInfo) -> Result<String, ScytheError> {
506        let name = to_pascal_case(&composite.sql_name);
507        let mut out = String::new();
508        if composite.fields.is_empty() {
509            let _ = writeln!(out, "public record {}() {{}}", name);
510        } else {
511            let fields = composite
512                .fields
513                .iter()
514                .map(|f| format!("Object {}", to_camel_case(&f.name)))
515                .collect::<Vec<_>>()
516                .join(", ");
517            let _ = writeln!(out, "public record {}({}) {{}}", name, fields);
518        }
519        Ok(out)
520    }
521}