Skip to main content

scythe_codegen/backends/java_jdbc.rs

1use std::fmt::Write;
2use std::path::Path;
3
4use scythe_backend::manifest::{BackendManifest, load_manifest};
5use scythe_backend::naming::{
6    enum_type_name, enum_variant_name, fn_name, row_struct_name, to_camel_case, to_pascal_case,
7};
8
9use scythe_core::analyzer::{AnalyzedQuery, CompositeInfo, EnumInfo};
10use scythe_core::errors::{ErrorCode, ScytheError};
11use scythe_core::parser::QueryCommand;
12
13use crate::backend_trait::{CodegenBackend, ResolvedColumn, ResolvedParam};
14
15const DEFAULT_MANIFEST_PG: &str = include_str!("../../manifests/java-jdbc.toml");
16const DEFAULT_MANIFEST_MYSQL: &str = include_str!("../../manifests/java-jdbc.mysql.toml");
17const DEFAULT_MANIFEST_SQLITE: &str = include_str!("../../manifests/java-jdbc.sqlite.toml");
18const DEFAULT_MANIFEST_DUCKDB: &str = include_str!("../../manifests/java-jdbc.duckdb.toml");
19
20pub struct JavaJdbcBackend {
21    manifest: BackendManifest,
22}
23
24impl JavaJdbcBackend {
25    pub fn new(engine: &str) -> Result<Self, ScytheError> {
26        let default_toml = match engine {
27            "postgresql" | "postgres" | "pg" => DEFAULT_MANIFEST_PG,
28            "mysql" | "mariadb" => DEFAULT_MANIFEST_MYSQL,
29            "sqlite" | "sqlite3" => DEFAULT_MANIFEST_SQLITE,
30            "duckdb" => DEFAULT_MANIFEST_DUCKDB,
31            _ => {
32                return Err(ScytheError::new(
33                    ErrorCode::InternalError,
34                    format!("unsupported engine '{}' for java-jdbc backend", engine),
35                ));
36            }
37        };
38        let manifest_path = Path::new("backends/java-jdbc/manifest.toml");
39        let manifest = if manifest_path.exists() {
40            load_manifest(manifest_path)
41                .map_err(|e| ScytheError::new(ErrorCode::InternalError, format!("manifest: {e}")))?
42        } else {
43            toml::from_str(default_toml)
44                .map_err(|e| ScytheError::new(ErrorCode::InternalError, format!("manifest: {e}")))?
45        };
46        Ok(Self { manifest })
47    }
48}
49
/// Convert PostgreSQL `$1`, `$2`, ... placeholders to JDBC `?` placeholders.
///
/// Only placeholders in plain SQL text are rewritten. Text in which PostgreSQL would not see
/// a parameter is copied verbatim:
/// - single-quoted string literals (`'$5 off'`, with `''` as an escaped quote);
/// - double-quoted identifiers (`"col$1"`, with `""` as an escaped quote);
/// - dollar-quoted strings (`$$ ... $1 ... $$`, `$tag$ ... $tag$`);
/// - identifiers that merely contain a `$` (`price$1`).
///
/// NOTE(review): backslash escapes inside literals (MySQL's default mode, PostgreSQL
/// `E'...'` strings) are not recognised, so a `\'` there is treated as closing the literal.
///
/// JDBC placeholders are positional, so the output only lines up with parameters bound in
/// declaration order when each `$N` occurs once and in ascending order.
fn pg_to_jdbc_params(sql: &str) -> String {
    let chars: Vec<char> = sql.chars().collect();
    let mut out = String::with_capacity(sql.len());
    let mut pos = 0;
    while pos < chars.len() {
        // `end` is the index just past the run of input copied verbatim at `pos`.
        let end = match chars[pos] {
            '\'' | '"' => quoted_span_end(&chars, pos),
            '$' if !continues_identifier(&chars, pos) => {
                if let Some(after_digits) = placeholder_end(&chars, pos) {
                    out.push('?');
                    pos = after_digits;
                    continue;
                }
                dollar_quote_end(&chars, pos).unwrap_or(pos + 1)
            }
            _ => pos + 1,
        };
        out.extend(&chars[pos..end]);
        pos = end;
    }
    out
}

/// Index just past the quoted run that starts at `start` (a `'` or `"`). A doubled quote
/// character is an escaped quote, not a terminator. An unterminated run extends to the end.
fn quoted_span_end(chars: &[char], start: usize) -> usize {
    let quote = chars[start];
    let mut j = start + 1;
    while j < chars.len() {
        if chars[j] != quote {
            j += 1;
        } else if chars.get(j + 1) == Some(&quote) {
            j += 2;
        } else {
            return j + 1;
        }
    }
    chars.len()
}

/// Whether the `$` at `pos` directly follows an identifier character, in which case it is
/// part of that identifier (PostgreSQL allows `$` after an identifier's first character).
fn continues_identifier(chars: &[char], pos: usize) -> bool {
    pos.checked_sub(1)
        .and_then(|prev| chars.get(prev))
        .is_some_and(|&c| c.is_alphanumeric() || c == '_')
}

/// If a positional parameter (`$` followed by one or more ASCII digits) starts at `start`,
/// returns the index just past its last digit.
fn placeholder_end(chars: &[char], start: usize) -> Option<usize> {
    let digits = chars[start + 1..]
        .iter()
        .take_while(|c| c.is_ascii_digit())
        .count();
    (digits > 0).then_some(start + 1 + digits)
}

/// If a dollar-quote opener (`$$`, or `$tag$` where the tag starts with a letter or `_` and
/// continues with letters, ASCII digits or `_`) starts at `start`, returns the index just past
/// the matching closing delimiter. An unterminated string extends to the end of the input.
fn dollar_quote_end(chars: &[char], start: usize) -> Option<usize> {
    let tag_len = chars[start + 1..]
        .iter()
        .enumerate()
        .take_while(|&(k, &c)| c == '_' || c.is_alphabetic() || (k > 0 && c.is_ascii_digit()))
        .count();
    let close = start + 1 + tag_len;
    if chars.get(close) != Some(&'$') {
        return None;
    }
    let delimiter = &chars[start..=close];
    let body = close + 1;
    let end = chars[body..]
        .windows(delimiter.len())
        .position(|w| w == delimiter)
        .map_or(chars.len(), |p| body + p + delimiter.len());
    Some(end)
}
72
/// Map a Java primitive type name to its wrapper class (`int` -> `Integer`).
///
/// Applied wherever a value may be `null`. Any name that is not one of the eight primitives
/// is returned unchanged.
fn box_primitive(java_type: &str) -> &str {
    const WRAPPERS: [(&str, &str); 8] = [
        ("boolean", "Boolean"),
        ("byte", "Byte"),
        ("short", "Short"),
        ("int", "Integer"),
        ("long", "Long"),
        ("float", "Float"),
        ("double", "Double"),
        ("char", "Character"),
    ];
    WRAPPERS
        .iter()
        .find(|(primitive, _)| *primitive == java_type)
        .map_or(java_type, |&(_, wrapper)| wrapper)
}
87
/// Choose the `ResultSet` accessor used to read a column of the given Java type.
///
/// Primitive types and their wrapper classes share one accessor. `String` and `byte[]` have
/// their own. Any type whose name contains `BigDecimal` uses `getBigDecimal`. Every other
/// type (`char`, `java.time` types, `UUID`, ...) falls back to `getObject`.
fn rs_getter(java_type: &str) -> &str {
    const BY_NAME: [(&str, &str); 16] = [
        ("boolean", "getBoolean"),
        ("Boolean", "getBoolean"),
        ("byte", "getByte"),
        ("Byte", "getByte"),
        ("short", "getShort"),
        ("Short", "getShort"),
        ("int", "getInt"),
        ("Integer", "getInt"),
        ("long", "getLong"),
        ("Long", "getLong"),
        ("float", "getFloat"),
        ("Float", "getFloat"),
        ("double", "getDouble"),
        ("Double", "getDouble"),
        ("String", "getString"),
        ("byte[]", "getBytes"),
    ];
    if let Some(&(_, getter)) = BY_NAME.iter().find(|(name, _)| *name == java_type) {
        getter
    } else if java_type.contains("BigDecimal") {
        "getBigDecimal"
    } else {
        "getObject"
    }
}
110
/// Choose the `PreparedStatement` setter used to bind a parameter of the given Java type.
///
/// Wrapper classes are first normalised to their primitive spelling, so each setter is listed
/// once. Types containing `BigDecimal` use `setBigDecimal`. Anything unrecognised (including
/// `char` / `Character`) uses `setObject`.
fn ps_setter(java_type: &str) -> &str {
    let primitive = match java_type {
        "Boolean" => "boolean",
        "Byte" => "byte",
        "Short" => "short",
        "Integer" => "int",
        "Long" => "long",
        "Float" => "float",
        "Double" => "double",
        other => other,
    };
    match primitive {
        "boolean" => "setBoolean",
        "byte" => "setByte",
        "short" => "setShort",
        "int" => "setInt",
        "long" => "setLong",
        "float" => "setFloat",
        "double" => "setDouble",
        "String" => "setString",
        "byte[]" => "setBytes",
        other if other.contains("BigDecimal") => "setBigDecimal",
        _ => "setObject",
    }
}
127
128/// Resolve the display type for a Java field, boxing primitives when nullable.
129fn java_field_type(col: &ResolvedColumn) -> String {
130    if col.nullable {
131        box_primitive(&col.lang_type).to_string()
132    } else {
133        col.full_type.clone()
134    }
135}
136
137/// Resolve the display type for a Java param, boxing primitives when nullable.
138fn java_param_type(param: &ResolvedParam) -> String {
139    if param.nullable {
140        box_primitive(&param.lang_type).to_string()
141    } else {
142        param.full_type.clone()
143    }
144}
145
/// Whether `java_type` is one of Java's eight primitive types (as opposed to a reference
/// type such as `Integer` or `String`).
fn is_java_primitive(java_type: &str) -> bool {
    const PRIMITIVES: [&str; 8] = [
        "boolean", "byte", "short", "int", "long", "float", "double", "char",
    ];
    PRIMITIVES.contains(&java_type)
}
153
154/// Format a Java parameter with nullability annotation.
155fn java_annotated_param(param: &ResolvedParam) -> String {
156    let param_type = java_param_type(param);
157    if param.nullable {
158        format!("@Nullable {} {}", param_type, param.field_name)
159    } else if !is_java_primitive(&param.lang_type) {
160        format!("@Nonnull {} {}", param_type, param.field_name)
161    } else {
162        format!("{} {}", param_type, param.field_name)
163    }
164}
165
166impl CodegenBackend for JavaJdbcBackend {
167    fn name(&self) -> &str {
168        "java-jdbc"
169    }
170
171    fn manifest(&self) -> &scythe_backend::manifest::BackendManifest {
172        &self.manifest
173    }
174
175    fn supported_engines(&self) -> &[&str] {
176        &["postgresql", "mysql", "sqlite", "duckdb"]
177    }
178
179    fn file_header(&self) -> String {
180        "// Auto-generated by scythe. Do not edit.\n\
181         import java.math.BigDecimal;\n\
182         import java.sql.*;\n\
183         import java.time.OffsetDateTime;\n\
184         import java.util.ArrayList;\n\
185         import java.util.List;\n\
186         import javax.annotation.Nonnull;\n\
187         import javax.annotation.Nullable;"
188            .to_string()
189    }
190
191    fn generate_row_struct(
192        &self,
193        query_name: &str,
194        columns: &[ResolvedColumn],
195    ) -> Result<String, ScytheError> {
196        let struct_name = row_struct_name(query_name, &self.manifest.naming);
197        let mut out = String::new();
198
199        // Record declaration with fields
200        let fields = columns
201            .iter()
202            .map(|c| {
203                let field_type = java_field_type(c);
204                if c.nullable {
205                    format!("    @Nullable {} {}", field_type, c.field_name)
206                } else {
207                    format!("    {} {}", field_type, c.field_name)
208                }
209            })
210            .collect::<Vec<_>>()
211            .join(",\n");
212
213        let _ = writeln!(out, "public record {}(", struct_name);
214        let _ = writeln!(out, "{}", fields);
215        let _ = writeln!(out, ") {{");
216
217        // fromResultSet static factory method
218        let _ = writeln!(
219            out,
220            "    public static {} fromResultSet(ResultSet rs) throws SQLException {{",
221            struct_name
222        );
223        let _ = writeln!(out, "        return new {}(", struct_name);
224        for (i, col) in columns.iter().enumerate() {
225            let getter = rs_getter(&col.lang_type);
226            let sep = if i + 1 < columns.len() { "," } else { "" };
227            let _ = writeln!(out, "            rs.{}(\"{}\"){}", getter, col.name, sep);
228        }
229        let _ = writeln!(out, "        );");
230        let _ = writeln!(out, "    }}");
231        let _ = write!(out, "}}");
232        Ok(out)
233    }
234
235    fn generate_model_struct(
236        &self,
237        table_name: &str,
238        columns: &[ResolvedColumn],
239    ) -> Result<String, ScytheError> {
240        let name = to_pascal_case(table_name);
241        self.generate_row_struct(&name, columns)
242    }
243
244    fn generate_query_fn(
245        &self,
246        analyzed: &AnalyzedQuery,
247        struct_name: &str,
248        _columns: &[ResolvedColumn],
249        params: &[ResolvedParam],
250    ) -> Result<String, ScytheError> {
251        let func_name = fn_name(&analyzed.name, &self.manifest.naming);
252        let sql = pg_to_jdbc_params(&super::clean_sql_oneline_with_optional(
253            &analyzed.sql,
254            &analyzed.optional_params,
255            &analyzed.params,
256        ));
257
258        let param_list = params
259            .iter()
260            .map(java_annotated_param)
261            .collect::<Vec<_>>()
262            .join(", ");
263        let sep = if param_list.is_empty() { "" } else { ", " };
264
265        let mut out = String::new();
266
267        match &analyzed.command {
268            QueryCommand::Exec => {
269                let _ = writeln!(
270                    out,
271                    "public static void {}(Connection conn{}{}) throws SQLException {{",
272                    func_name, sep, param_list
273                );
274                let _ = writeln!(
275                    out,
276                    "    try (var ps = conn.prepareStatement(\"{}\")) {{",
277                    sql
278                );
279                for (i, param) in params.iter().enumerate() {
280                    let setter = ps_setter(&param.lang_type);
281                    let _ = writeln!(
282                        out,
283                        "        ps.{}({}, {});",
284                        setter,
285                        i + 1,
286                        param.field_name
287                    );
288                }
289                let _ = writeln!(out, "        ps.executeUpdate();");
290                let _ = writeln!(out, "    }}");
291                let _ = write!(out, "}}");
292            }
293            QueryCommand::ExecResult | QueryCommand::ExecRows => {
294                let _ = writeln!(
295                    out,
296                    "public static int {}(Connection conn{}{}) throws SQLException {{",
297                    func_name, sep, param_list
298                );
299                let _ = writeln!(
300                    out,
301                    "    try (var ps = conn.prepareStatement(\"{}\")) {{",
302                    sql
303                );
304                for (i, param) in params.iter().enumerate() {
305                    let setter = ps_setter(&param.lang_type);
306                    let _ = writeln!(
307                        out,
308                        "        ps.{}({}, {});",
309                        setter,
310                        i + 1,
311                        param.field_name
312                    );
313                }
314                let _ = writeln!(out, "        return ps.executeUpdate();");
315                let _ = writeln!(out, "    }}");
316                let _ = write!(out, "}}");
317            }
318            QueryCommand::One => {
319                let _ = writeln!(
320                    out,
321                    "public static @Nullable {} {}(Connection conn{}{}) throws SQLException {{",
322                    struct_name, func_name, sep, param_list
323                );
324                let _ = writeln!(
325                    out,
326                    "    try (var ps = conn.prepareStatement(\"{}\")) {{",
327                    sql
328                );
329                for (i, param) in params.iter().enumerate() {
330                    let setter = ps_setter(&param.lang_type);
331                    let _ = writeln!(
332                        out,
333                        "        ps.{}({}, {});",
334                        setter,
335                        i + 1,
336                        param.field_name
337                    );
338                }
339                let _ = writeln!(out, "        try (ResultSet rs = ps.executeQuery()) {{");
340                let _ = writeln!(out, "            if (rs.next()) {{");
341                let _ = writeln!(
342                    out,
343                    "                return {}.fromResultSet(rs);",
344                    struct_name
345                );
346                let _ = writeln!(out, "            }}");
347                let _ = writeln!(out, "            return null;");
348                let _ = writeln!(out, "        }}");
349                let _ = writeln!(out, "    }}");
350                let _ = write!(out, "}}");
351            }
352            QueryCommand::Many => {
353                let _ = writeln!(
354                    out,
355                    "public static java.util.List<{}> {}(Connection conn{}{}) throws SQLException {{",
356                    struct_name, func_name, sep, param_list
357                );
358                let _ = writeln!(
359                    out,
360                    "    try (var ps = conn.prepareStatement(\"{}\")) {{",
361                    sql
362                );
363                for (i, param) in params.iter().enumerate() {
364                    let setter = ps_setter(&param.lang_type);
365                    let _ = writeln!(
366                        out,
367                        "        ps.{}({}, {});",
368                        setter,
369                        i + 1,
370                        param.field_name
371                    );
372                }
373                let _ = writeln!(out, "        try (ResultSet rs = ps.executeQuery()) {{");
374                let _ = writeln!(
375                    out,
376                    "            java.util.List<{}> result = new java.util.ArrayList<>();",
377                    struct_name
378                );
379                let _ = writeln!(out, "            while (rs.next()) {{");
380                let _ = writeln!(
381                    out,
382                    "                result.add({}.fromResultSet(rs));",
383                    struct_name
384                );
385                let _ = writeln!(out, "            }}");
386                let _ = writeln!(out, "            return result;");
387                let _ = writeln!(out, "        }}");
388                let _ = writeln!(out, "    }}");
389                let _ = write!(out, "}}");
390            }
391            QueryCommand::Batch => {
392                let batch_fn_name = format!("{}Batch", func_name);
393                if params.len() > 1 {
394                    // Generate params record
395                    let params_record_name =
396                        format!("{}BatchParams", to_pascal_case(&analyzed.name));
397                    let record_fields = params
398                        .iter()
399                        .map(|p| format!("{} {}", java_param_type(p), p.field_name))
400                        .collect::<Vec<_>>()
401                        .join(", ");
402                    let _ = writeln!(
403                        out,
404                        "public record {}({}) {{}}",
405                        params_record_name, record_fields
406                    );
407                    let _ = writeln!(out);
408                    let _ = writeln!(
409                        out,
410                        "public static void {}(Connection conn, java.util.List<{}> items) throws SQLException {{",
411                        batch_fn_name, params_record_name
412                    );
413                    let _ = writeln!(out, "    conn.setAutoCommit(false);");
414                    let _ = writeln!(
415                        out,
416                        "    try (var ps = conn.prepareStatement(\"{}\")) {{",
417                        sql
418                    );
419                    let _ = writeln!(out, "        for (var item : items) {{");
420                    for (i, param) in params.iter().enumerate() {
421                        let setter = ps_setter(&param.lang_type);
422                        let _ = writeln!(
423                            out,
424                            "            ps.{}({}, item.{}());",
425                            setter,
426                            i + 1,
427                            param.field_name
428                        );
429                    }
430                    let _ = writeln!(out, "            ps.addBatch();");
431                    let _ = writeln!(out, "        }}");
432                    let _ = writeln!(out, "        ps.executeBatch();");
433                    let _ = writeln!(out, "        conn.commit();");
434                    let _ = writeln!(out, "    }} catch (SQLException e) {{");
435                    let _ = writeln!(out, "        conn.rollback();");
436                    let _ = writeln!(out, "        throw e;");
437                    let _ = writeln!(out, "    }} finally {{");
438                    let _ = writeln!(out, "        conn.setAutoCommit(true);");
439                    let _ = writeln!(out, "    }}");
440                    let _ = write!(out, "}}");
441                } else if params.len() == 1 {
442                    let param = &params[0];
443                    let _ = writeln!(
444                        out,
445                        "public static void {}(Connection conn, java.util.List<{}> items) throws SQLException {{",
446                        batch_fn_name,
447                        java_param_type(param)
448                    );
449                    let _ = writeln!(out, "    conn.setAutoCommit(false);");
450                    let _ = writeln!(
451                        out,
452                        "    try (var ps = conn.prepareStatement(\"{}\")) {{",
453                        sql
454                    );
455                    let _ = writeln!(out, "        for (var item : items) {{");
456                    let setter = ps_setter(&param.lang_type);
457                    let _ = writeln!(out, "            ps.{}(1, item);", setter);
458                    let _ = writeln!(out, "            ps.addBatch();");
459                    let _ = writeln!(out, "        }}");
460                    let _ = writeln!(out, "        ps.executeBatch();");
461                    let _ = writeln!(out, "        conn.commit();");
462                    let _ = writeln!(out, "    }} catch (SQLException e) {{");
463                    let _ = writeln!(out, "        conn.rollback();");
464                    let _ = writeln!(out, "        throw e;");
465                    let _ = writeln!(out, "    }} finally {{");
466                    let _ = writeln!(out, "        conn.setAutoCommit(true);");
467                    let _ = writeln!(out, "    }}");
468                    let _ = write!(out, "}}");
469                } else {
470                    let _ = writeln!(
471                        out,
472                        "public static void {}(Connection conn, int count) throws SQLException {{",
473                        batch_fn_name
474                    );
475                    let _ = writeln!(out, "    conn.setAutoCommit(false);");
476                    let _ = writeln!(
477                        out,
478                        "    try (var ps = conn.prepareStatement(\"{}\")) {{",
479                        sql
480                    );
481                    let _ = writeln!(out, "        for (int i = 0; i < count; i++) {{");
482                    let _ = writeln!(out, "            ps.addBatch();");
483                    let _ = writeln!(out, "        }}");
484                    let _ = writeln!(out, "        ps.executeBatch();");
485                    let _ = writeln!(out, "        conn.commit();");
486                    let _ = writeln!(out, "    }} catch (SQLException e) {{");
487                    let _ = writeln!(out, "        conn.rollback();");
488                    let _ = writeln!(out, "        throw e;");
489                    let _ = writeln!(out, "    }} finally {{");
490                    let _ = writeln!(out, "        conn.setAutoCommit(true);");
491                    let _ = writeln!(out, "    }}");
492                    let _ = write!(out, "}}");
493                }
494            }
495            QueryCommand::Grouped => {
496                return Err(ScytheError::new(
497                    ErrorCode::InternalError,
498                    "grouped queries are not yet supported for java-jdbc".to_string(),
499                ));
500            }
501        }
502
503        Ok(out)
504    }
505
506    fn generate_enum_def(&self, enum_info: &EnumInfo) -> Result<String, ScytheError> {
507        let type_name = enum_type_name(&enum_info.sql_name, &self.manifest.naming);
508        let mut out = String::new();
509        let _ = writeln!(out, "public enum {} {{", type_name);
510        for (i, value) in enum_info.values.iter().enumerate() {
511            let variant = enum_variant_name(value, &self.manifest.naming);
512            let sep = if i + 1 < enum_info.values.len() {
513                ","
514            } else {
515                ";"
516            };
517            let _ = writeln!(out, "    {}(\"{}\"){}", variant, value, sep);
518        }
519        let _ = writeln!(out);
520        let _ = writeln!(out, "    private final String value;");
521        let _ = writeln!(
522            out,
523            "    {}(String value) {{ this.value = value; }}",
524            type_name
525        );
526        let _ = writeln!(out, "    public String getValue() {{ return value; }}");
527        let _ = write!(out, "}}");
528        Ok(out)
529    }
530
531    fn generate_composite_def(&self, composite: &CompositeInfo) -> Result<String, ScytheError> {
532        let name = to_pascal_case(&composite.sql_name);
533        let mut out = String::new();
534        if composite.fields.is_empty() {
535            let _ = writeln!(out, "public record {}() {{}}", name);
536        } else {
537            let fields = composite
538                .fields
539                .iter()
540                .map(|f| format!("Object {}", to_camel_case(&f.name)))
541                .collect::<Vec<_>>()
542                .join(", ");
543            let _ = writeln!(out, "public record {}({}) {{}}", name, fields);
544        }
545        Ok(out)
546    }
547}