Skip to main content

scythe_codegen/backends/
java_jdbc.rs

1use std::fmt::Write;
2use std::path::Path;
3
4use scythe_backend::manifest::{BackendManifest, load_manifest};
5use scythe_backend::naming::{
6    enum_type_name, enum_variant_name, fn_name, row_struct_name, to_camel_case, to_pascal_case,
7};
8
9use scythe_core::analyzer::{AnalyzedQuery, CompositeInfo, EnumInfo};
10use scythe_core::errors::{ErrorCode, ScytheError};
11use scythe_core::parser::QueryCommand;
12
13use crate::backend_trait::{CodegenBackend, ResolvedColumn, ResolvedParam};
14
15const DEFAULT_MANIFEST_TOML: &str = include_str!("../../manifests/java-jdbc.toml");
16
17pub struct JavaJdbcBackend {
18    manifest: BackendManifest,
19}
20
21impl JavaJdbcBackend {
22    pub fn new() -> Result<Self, ScytheError> {
23        let manifest_path = Path::new("backends/java-jdbc/manifest.toml");
24        let manifest = if manifest_path.exists() {
25            load_manifest(manifest_path)
26                .map_err(|e| ScytheError::new(ErrorCode::InternalError, format!("manifest: {e}")))?
27        } else {
28            toml::from_str(DEFAULT_MANIFEST_TOML)
29                .map_err(|e| ScytheError::new(ErrorCode::InternalError, format!("manifest: {e}")))?
30        };
31        Ok(Self { manifest })
32    }
33
34    pub fn manifest(&self) -> &BackendManifest {
35        &self.manifest
36    }
37}
38
/// Rewrite PostgreSQL positional placeholders (`$1`, `$2`, ...) as JDBC `?` markers.
///
/// Each `$` immediately followed by one or more ASCII digits is replaced, digits and all,
/// by a single `?`. A `$` with no digit after it is copied through unchanged.
///
/// NOTE(review): the scan is purely lexical, so a `$N` inside a string literal, comment or
/// dollar-quoted body is rewritten as well.
fn pg_to_jdbc_params(sql: &str) -> String {
    let mut converted = String::with_capacity(sql.len());
    let mut remaining = sql;
    while let Some(dollar) = remaining.find('$') {
        converted.push_str(&remaining[..dollar]);
        // `$` is a single byte, so `dollar + 1` is always a char boundary.
        let tail = &remaining[dollar + 1..];
        let digit_count = tail.bytes().take_while(u8::is_ascii_digit).count();
        converted.push(if digit_count == 0 { '$' } else { '?' });
        // Digits are ASCII, so skipping `digit_count` bytes stays on a char boundary.
        remaining = &tail[digit_count..];
    }
    converted.push_str(remaining);
    converted
}
61
/// Map a Java primitive type name to its wrapper class so the value can hold `null`.
///
/// Any name that is not one of the eight primitives (including already-boxed types)
/// is returned unchanged.
fn box_primitive(java_type: &str) -> &str {
    const WRAPPERS: [(&str, &str); 8] = [
        ("boolean", "Boolean"),
        ("byte", "Byte"),
        ("short", "Short"),
        ("int", "Integer"),
        ("long", "Long"),
        ("float", "Float"),
        ("double", "Double"),
        ("char", "Character"),
    ];
    WRAPPERS
        .iter()
        .find(|(primitive, _)| *primitive == java_type)
        .map_or(java_type, |&(_, wrapper)| wrapper)
}
76
/// Name of the `ResultSet` getter used to read a column of the given Java type.
///
/// A primitive and its wrapper share one typed getter. Any type mentioning `BigDecimal`
/// uses `getBigDecimal`. Everything else (java.time types, `UUID`, unknown types) falls
/// back to `getObject`.
fn rs_getter(java_type: &str) -> &str {
    // None of the exact names below contain "BigDecimal", so checking it first is equivalent.
    if java_type.contains("BigDecimal") {
        return "getBigDecimal";
    }
    match java_type {
        "boolean" | "Boolean" => "getBoolean",
        "byte" | "Byte" => "getByte",
        "short" | "Short" => "getShort",
        "int" | "Integer" => "getInt",
        "long" | "Long" => "getLong",
        "float" | "Float" => "getFloat",
        "double" | "Double" => "getDouble",
        "String" => "getString",
        "byte[]" => "getBytes",
        _ => "getObject",
    }
}
99
/// Name of the `PreparedStatement` setter used to bind a value of the given Java type.
///
/// A primitive and its wrapper share one typed setter. Any type mentioning `BigDecimal`
/// uses `setBigDecimal`. Everything else falls back to `setObject`.
fn ps_setter(java_type: &str) -> &str {
    // None of the exact names below contain "BigDecimal", so checking it first is equivalent.
    if java_type.contains("BigDecimal") {
        return "setBigDecimal";
    }
    match java_type {
        "boolean" | "Boolean" => "setBoolean",
        "byte" | "Byte" => "setByte",
        "short" | "Short" => "setShort",
        "int" | "Integer" => "setInt",
        "long" | "Long" => "setLong",
        "float" | "Float" => "setFloat",
        "double" | "Double" => "setDouble",
        "String" => "setString",
        "byte[]" => "setBytes",
        _ => "setObject",
    }
}
116
117/// Resolve the display type for a Java field, boxing primitives when nullable.
118fn java_field_type(col: &ResolvedColumn) -> String {
119    if col.nullable {
120        box_primitive(&col.lang_type).to_string()
121    } else {
122        col.full_type.clone()
123    }
124}
125
126/// Resolve the display type for a Java param, boxing primitives when nullable.
127fn java_param_type(param: &ResolvedParam) -> String {
128    if param.nullable {
129        box_primitive(&param.lang_type).to_string()
130    } else {
131        param.full_type.clone()
132    }
133}
134
135impl CodegenBackend for JavaJdbcBackend {
136    fn name(&self) -> &str {
137        "java-jdbc"
138    }
139
140    fn generate_row_struct(
141        &self,
142        query_name: &str,
143        columns: &[ResolvedColumn],
144    ) -> Result<String, ScytheError> {
145        let struct_name = row_struct_name(query_name, &self.manifest.naming);
146        let mut out = String::new();
147
148        // Record declaration with fields
149        let fields = columns
150            .iter()
151            .map(|c| {
152                let field_type = java_field_type(c);
153                if c.nullable {
154                    format!("    @Nullable {} {}", field_type, c.field_name)
155                } else {
156                    format!("    {} {}", field_type, c.field_name)
157                }
158            })
159            .collect::<Vec<_>>()
160            .join(",\n");
161
162        let _ = writeln!(out, "public record {}(", struct_name);
163        let _ = writeln!(out, "{}", fields);
164        let _ = writeln!(out, ") {{");
165
166        // fromResultSet static factory method
167        let _ = writeln!(
168            out,
169            "    public static {} fromResultSet(ResultSet rs) throws SQLException {{",
170            struct_name
171        );
172        let _ = writeln!(out, "        return new {}(", struct_name);
173        for (i, col) in columns.iter().enumerate() {
174            let getter = rs_getter(&col.lang_type);
175            let sep = if i + 1 < columns.len() { "," } else { "" };
176            let _ = writeln!(out, "            rs.{}(\"{}\"){}", getter, col.name, sep);
177        }
178        let _ = writeln!(out, "        );");
179        let _ = writeln!(out, "    }}");
180        let _ = write!(out, "}}");
181        Ok(out)
182    }
183
184    fn generate_model_struct(
185        &self,
186        table_name: &str,
187        columns: &[ResolvedColumn],
188    ) -> Result<String, ScytheError> {
189        let name = to_pascal_case(table_name);
190        self.generate_row_struct(&name, columns)
191    }
192
193    fn generate_query_fn(
194        &self,
195        analyzed: &AnalyzedQuery,
196        struct_name: &str,
197        _columns: &[ResolvedColumn],
198        params: &[ResolvedParam],
199    ) -> Result<String, ScytheError> {
200        let func_name = fn_name(&analyzed.name, &self.manifest.naming);
201        let sql = pg_to_jdbc_params(&super::clean_sql_oneline(&analyzed.sql));
202
203        let param_list = params
204            .iter()
205            .map(|p| {
206                let param_type = java_param_type(p);
207                format!("{} {}", param_type, p.field_name)
208            })
209            .collect::<Vec<_>>()
210            .join(", ");
211        let sep = if param_list.is_empty() { "" } else { ", " };
212
213        let mut out = String::new();
214
215        match &analyzed.command {
216            QueryCommand::Exec => {
217                let _ = writeln!(
218                    out,
219                    "public static void {}(Connection conn{}{}) throws SQLException {{",
220                    func_name, sep, param_list
221                );
222                let _ = writeln!(
223                    out,
224                    "    try (var ps = conn.prepareStatement(\"{}\")) {{",
225                    sql
226                );
227                for (i, param) in params.iter().enumerate() {
228                    let setter = ps_setter(&param.lang_type);
229                    let _ = writeln!(
230                        out,
231                        "        ps.{}({}, {});",
232                        setter,
233                        i + 1,
234                        param.field_name
235                    );
236                }
237                let _ = writeln!(out, "        ps.executeUpdate();");
238                let _ = writeln!(out, "    }}");
239                let _ = write!(out, "}}");
240            }
241            QueryCommand::ExecResult | QueryCommand::ExecRows => {
242                let _ = writeln!(
243                    out,
244                    "public static int {}(Connection conn{}{}) throws SQLException {{",
245                    func_name, sep, param_list
246                );
247                let _ = writeln!(
248                    out,
249                    "    try (var ps = conn.prepareStatement(\"{}\")) {{",
250                    sql
251                );
252                for (i, param) in params.iter().enumerate() {
253                    let setter = ps_setter(&param.lang_type);
254                    let _ = writeln!(
255                        out,
256                        "        ps.{}({}, {});",
257                        setter,
258                        i + 1,
259                        param.field_name
260                    );
261                }
262                let _ = writeln!(out, "        return ps.executeUpdate();");
263                let _ = writeln!(out, "    }}");
264                let _ = write!(out, "}}");
265            }
266            QueryCommand::One => {
267                let _ = writeln!(
268                    out,
269                    "public static {} {}(Connection conn{}{}) throws SQLException {{",
270                    struct_name, func_name, sep, param_list
271                );
272                let _ = writeln!(
273                    out,
274                    "    try (var ps = conn.prepareStatement(\"{}\")) {{",
275                    sql
276                );
277                for (i, param) in params.iter().enumerate() {
278                    let setter = ps_setter(&param.lang_type);
279                    let _ = writeln!(
280                        out,
281                        "        ps.{}({}, {});",
282                        setter,
283                        i + 1,
284                        param.field_name
285                    );
286                }
287                let _ = writeln!(out, "        try (ResultSet rs = ps.executeQuery()) {{");
288                let _ = writeln!(out, "            if (rs.next()) {{");
289                let _ = writeln!(
290                    out,
291                    "                return {}.fromResultSet(rs);",
292                    struct_name
293                );
294                let _ = writeln!(out, "            }}");
295                let _ = writeln!(out, "            return null;");
296                let _ = writeln!(out, "        }}");
297                let _ = writeln!(out, "    }}");
298                let _ = write!(out, "}}");
299            }
300            QueryCommand::Many | QueryCommand::Batch => {
301                let _ = writeln!(
302                    out,
303                    "public static java.util.List<{}> {}(Connection conn{}{}) throws SQLException {{",
304                    struct_name, func_name, sep, param_list
305                );
306                let _ = writeln!(
307                    out,
308                    "    try (var ps = conn.prepareStatement(\"{}\")) {{",
309                    sql
310                );
311                for (i, param) in params.iter().enumerate() {
312                    let setter = ps_setter(&param.lang_type);
313                    let _ = writeln!(
314                        out,
315                        "        ps.{}({}, {});",
316                        setter,
317                        i + 1,
318                        param.field_name
319                    );
320                }
321                let _ = writeln!(out, "        try (ResultSet rs = ps.executeQuery()) {{");
322                let _ = writeln!(
323                    out,
324                    "            java.util.List<{}> result = new java.util.ArrayList<>();",
325                    struct_name
326                );
327                let _ = writeln!(out, "            while (rs.next()) {{");
328                let _ = writeln!(
329                    out,
330                    "                result.add({}.fromResultSet(rs));",
331                    struct_name
332                );
333                let _ = writeln!(out, "            }}");
334                let _ = writeln!(out, "            return result;");
335                let _ = writeln!(out, "        }}");
336                let _ = writeln!(out, "    }}");
337                let _ = write!(out, "}}");
338            }
339        }
340
341        Ok(out)
342    }
343
344    fn generate_enum_def(&self, enum_info: &EnumInfo) -> Result<String, ScytheError> {
345        let type_name = enum_type_name(&enum_info.sql_name, &self.manifest.naming);
346        let mut out = String::new();
347        let _ = writeln!(out, "public enum {} {{", type_name);
348        for (i, value) in enum_info.values.iter().enumerate() {
349            let variant = enum_variant_name(value, &self.manifest.naming);
350            let sep = if i + 1 < enum_info.values.len() {
351                ","
352            } else {
353                ";"
354            };
355            let _ = writeln!(out, "    {}(\"{}\"){}", variant, value, sep);
356        }
357        let _ = writeln!(out);
358        let _ = writeln!(out, "    private final String value;");
359        let _ = writeln!(
360            out,
361            "    {}(String value) {{ this.value = value; }}",
362            type_name
363        );
364        let _ = writeln!(out, "    public String getValue() {{ return value; }}");
365        let _ = write!(out, "}}");
366        Ok(out)
367    }
368
369    fn generate_composite_def(&self, composite: &CompositeInfo) -> Result<String, ScytheError> {
370        let name = to_pascal_case(&composite.sql_name);
371        let mut out = String::new();
372        if composite.fields.is_empty() {
373            let _ = writeln!(out, "public record {}() {{}}", name);
374        } else {
375            let fields = composite
376                .fields
377                .iter()
378                .map(|f| format!("Object {}", to_camel_case(&f.name)))
379                .collect::<Vec<_>>()
380                .join(", ");
381            let _ = writeln!(out, "public record {}({}) {{}}", name, fields);
382        }
383        Ok(out)
384    }
385}