1use std::fmt::Write;
2use std::path::Path;
3
4use scythe_backend::manifest::{BackendManifest, load_manifest};
5use scythe_backend::naming::{
6 enum_type_name, enum_variant_name, fn_name, row_struct_name, to_camel_case, to_pascal_case,
7};
8
9use scythe_core::analyzer::{AnalyzedQuery, CompositeInfo, EnumInfo};
10use scythe_core::errors::{ErrorCode, ScytheError};
11use scythe_core::parser::QueryCommand;
12
13use crate::backend_trait::{CodegenBackend, ResolvedColumn, ResolvedParam};
14
15const DEFAULT_MANIFEST_PG: &str = include_str!("../../manifests/java-jdbc.toml");
16const DEFAULT_MANIFEST_MYSQL: &str = include_str!("../../manifests/java-jdbc.mysql.toml");
17const DEFAULT_MANIFEST_SQLITE: &str = include_str!("../../manifests/java-jdbc.sqlite.toml");
18const DEFAULT_MANIFEST_DUCKDB: &str = include_str!("../../manifests/java-jdbc.duckdb.toml");
19
20pub struct JavaJdbcBackend {
21 manifest: BackendManifest,
22}
23
24impl JavaJdbcBackend {
25 pub fn new(engine: &str) -> Result<Self, ScytheError> {
26 let default_toml = match engine {
27 "postgresql" | "postgres" | "pg" => DEFAULT_MANIFEST_PG,
28 "mysql" | "mariadb" => DEFAULT_MANIFEST_MYSQL,
29 "sqlite" | "sqlite3" => DEFAULT_MANIFEST_SQLITE,
30 "duckdb" => DEFAULT_MANIFEST_DUCKDB,
31 _ => {
32 return Err(ScytheError::new(
33 ErrorCode::InternalError,
34 format!("unsupported engine '{}' for java-jdbc backend", engine),
35 ));
36 }
37 };
38 let manifest_path = Path::new("backends/java-jdbc/manifest.toml");
39 let manifest = if manifest_path.exists() {
40 load_manifest(manifest_path)
41 .map_err(|e| ScytheError::new(ErrorCode::InternalError, format!("manifest: {e}")))?
42 } else {
43 toml::from_str(default_toml)
44 .map_err(|e| ScytheError::new(ErrorCode::InternalError, format!("manifest: {e}")))?
45 };
46 Ok(Self { manifest })
47 }
48}
49
/// Rewrites PostgreSQL positional placeholders (`$1`, `$2`, ...) into JDBC `?` markers.
///
/// Only placeholders in SQL code are rewritten. The following are copied through verbatim:
/// * text inside single-quoted string literals, including `E'...'` escape strings;
/// * double-quoted identifiers;
/// * `--` line comments and `/* ... */` block comments.
///
/// A `$` that continues an identifier (PostgreSQL allows `$` inside identifiers, e.g.
/// `price$1`) is not a placeholder. Any other `$`, such as a `$$` delimiter, is copied
/// through unchanged.
///
/// Limitation: the bodies of dollar-quoted strings are not recognised, so a `$N` inside one
/// is still rewritten.
fn pg_to_jdbc_params(sql: &str) -> String {
    /// Lexical context of the character currently being scanned.
    #[derive(Clone, Copy)]
    enum Ctx {
        Code,
        /// Inside `'...'`; `backslash_escapes` is set for `E'...'` strings.
        Literal { backslash_escapes: bool },
        /// Inside `"..."`.
        QuotedIdent,
        /// After `--`, up to the end of the line.
        LineComment,
        /// Inside `/* ... */`; PostgreSQL block comments nest, so the depth is tracked.
        BlockComment { depth: usize },
    }

    let is_ident_char = |c: char| c.is_alphanumeric() || c == '_' || c == '$';

    let mut result = String::with_capacity(sql.len());
    let mut chars = sql.chars().peekable();
    let mut ctx = Ctx::Code;
    let mut prev: Option<char> = None;

    while let Some(ch) = chars.next() {
        match ctx {
            Ctx::Code => {
                if ch == '$'
                    && !prev.is_some_and(is_ident_char)
                    && chars.peek().is_some_and(char::is_ascii_digit)
                {
                    // Consume the whole placeholder number and emit a single marker.
                    let mut last = ch;
                    while let Some(digit) = chars.next_if(char::is_ascii_digit) {
                        last = digit;
                    }
                    result.push('?');
                    prev = Some(last);
                    continue;
                }
                result.push(ch);
                match ch {
                    '\'' => {
                        ctx = Ctx::Literal {
                            backslash_escapes: matches!(prev, Some('E' | 'e')),
                        }
                    }
                    '"' => ctx = Ctx::QuotedIdent,
                    '-' if chars.peek() == Some(&'-') => ctx = Ctx::LineComment,
                    '/' if chars.peek() == Some(&'*') => {
                        // Consume the `*` so `/*/` is not mistaken for open + close.
                        result.extend(chars.next());
                        ctx = Ctx::BlockComment { depth: 1 };
                    }
                    _ => {}
                }
            }
            Ctx::Literal { backslash_escapes } => {
                result.push(ch);
                if backslash_escapes && ch == '\\' {
                    // The escaped character (possibly a quote) never ends the literal.
                    result.extend(chars.next());
                } else if ch == '\'' {
                    // A doubled `''` closes and immediately reopens, which is equivalent.
                    ctx = Ctx::Code;
                }
            }
            Ctx::QuotedIdent => {
                result.push(ch);
                if ch == '"' {
                    ctx = Ctx::Code;
                }
            }
            Ctx::LineComment => {
                result.push(ch);
                if ch == '\n' {
                    ctx = Ctx::Code;
                }
            }
            Ctx::BlockComment { depth } => {
                result.push(ch);
                if ch == '*' && chars.peek() == Some(&'/') {
                    result.extend(chars.next());
                    ctx = if depth == 1 {
                        Ctx::Code
                    } else {
                        Ctx::BlockComment { depth: depth - 1 }
                    };
                } else if ch == '/' && chars.peek() == Some(&'*') {
                    result.extend(chars.next());
                    ctx = Ctx::BlockComment { depth: depth + 1 };
                }
            }
        }
        prev = Some(ch);
    }
    result
}
72
/// Returns the wrapper class for a Java primitive type name (`int` -> `Integer`).
///
/// Any name that is not one of the eight primitives is returned unchanged.
fn box_primitive(java_type: &str) -> &str {
    const BOXED: [(&str, &str); 8] = [
        ("boolean", "Boolean"),
        ("byte", "Byte"),
        ("short", "Short"),
        ("int", "Integer"),
        ("long", "Long"),
        ("float", "Float"),
        ("double", "Double"),
        ("char", "Character"),
    ];
    BOXED
        .iter()
        .find(|(primitive, _)| *primitive == java_type)
        .map_or(java_type, |&(_, wrapper)| wrapper)
}
87
/// Name of the `java.sql.ResultSet` accessor used to read a value of `java_type`.
///
/// Primitives and their wrappers share a typed accessor. `String` and `byte[]` have their own
/// accessors, and any type mentioning `BigDecimal` uses `getBigDecimal`. Everything else
/// (temporal types, UUID, ...) falls back to `getObject`.
fn rs_getter(java_type: &str) -> &str {
    let exact = match java_type {
        "boolean" | "Boolean" => Some("getBoolean"),
        "byte" | "Byte" => Some("getByte"),
        "short" | "Short" => Some("getShort"),
        "int" | "Integer" => Some("getInt"),
        "long" | "Long" => Some("getLong"),
        "float" | "Float" => Some("getFloat"),
        "double" | "Double" => Some("getDouble"),
        "String" => Some("getString"),
        "byte[]" => Some("getBytes"),
        _ => None,
    };
    exact.unwrap_or_else(|| {
        if java_type.contains("BigDecimal") {
            "getBigDecimal"
        } else {
            "getObject"
        }
    })
}
110
/// Name of the `java.sql.PreparedStatement` setter used to bind a value of `java_type`.
///
/// Primitives and their wrappers share a typed setter. `String`, `byte[]` and `BigDecimal`
/// types have dedicated setters, and everything else is bound with `setObject`.
fn ps_setter(java_type: &str) -> &str {
    const NUMERIC: [(&str, &str, &str); 7] = [
        ("boolean", "Boolean", "setBoolean"),
        ("byte", "Byte", "setByte"),
        ("short", "Short", "setShort"),
        ("int", "Integer", "setInt"),
        ("long", "Long", "setLong"),
        ("float", "Float", "setFloat"),
        ("double", "Double", "setDouble"),
    ];
    if let Some(&(_, _, setter)) = NUMERIC
        .iter()
        .find(|(primitive, boxed, _)| *primitive == java_type || *boxed == java_type)
    {
        return setter;
    }
    match java_type {
        "String" => "setString",
        "byte[]" => "setBytes",
        other if other.contains("BigDecimal") => "setBigDecimal",
        _ => "setObject",
    }
}
127
128fn java_field_type(col: &ResolvedColumn) -> String {
130 if col.nullable {
131 box_primitive(&col.lang_type).to_string()
132 } else {
133 col.full_type.clone()
134 }
135}
136
137fn java_param_type(param: &ResolvedParam) -> String {
139 if param.nullable {
140 box_primitive(¶m.lang_type).to_string()
141 } else {
142 param.full_type.clone()
143 }
144}
145
/// Returns `true` when `java_type` names one of Java's eight primitive types.
fn is_java_primitive(java_type: &str) -> bool {
    const PRIMITIVES: [&str; 8] = [
        "boolean", "byte", "short", "int", "long", "float", "double", "char",
    ];
    PRIMITIVES.contains(&java_type)
}
153
154fn java_annotated_param(param: &ResolvedParam) -> String {
156 let param_type = java_param_type(param);
157 if param.nullable {
158 format!("@Nullable {} {}", param_type, param.field_name)
159 } else if !is_java_primitive(¶m.lang_type) {
160 format!("@Nonnull {} {}", param_type, param.field_name)
161 } else {
162 format!("{} {}", param_type, param.field_name)
163 }
164}
165
166impl CodegenBackend for JavaJdbcBackend {
167 fn name(&self) -> &str {
168 "java-jdbc"
169 }
170
171 fn manifest(&self) -> &scythe_backend::manifest::BackendManifest {
172 &self.manifest
173 }
174
175 fn supported_engines(&self) -> &[&str] {
176 &["postgresql", "mysql", "sqlite", "duckdb"]
177 }
178
179 fn file_header(&self) -> String {
180 "// Auto-generated by scythe. Do not edit.\n\
181 import java.math.BigDecimal;\n\
182 import java.sql.*;\n\
183 import java.time.OffsetDateTime;\n\
184 import java.util.ArrayList;\n\
185 import java.util.List;\n\
186 import javax.annotation.Nonnull;\n\
187 import javax.annotation.Nullable;"
188 .to_string()
189 }
190
191 fn generate_row_struct(
192 &self,
193 query_name: &str,
194 columns: &[ResolvedColumn],
195 ) -> Result<String, ScytheError> {
196 let struct_name = row_struct_name(query_name, &self.manifest.naming);
197 let mut out = String::new();
198
199 let fields = columns
201 .iter()
202 .map(|c| {
203 let field_type = java_field_type(c);
204 if c.nullable {
205 format!(" @Nullable {} {}", field_type, c.field_name)
206 } else {
207 format!(" {} {}", field_type, c.field_name)
208 }
209 })
210 .collect::<Vec<_>>()
211 .join(",\n");
212
213 let _ = writeln!(out, "public record {}(", struct_name);
214 let _ = writeln!(out, "{}", fields);
215 let _ = writeln!(out, ") {{");
216
217 let _ = writeln!(
219 out,
220 " public static {} fromResultSet(ResultSet rs) throws SQLException {{",
221 struct_name
222 );
223 let _ = writeln!(out, " return new {}(", struct_name);
224 for (i, col) in columns.iter().enumerate() {
225 let getter = rs_getter(&col.lang_type);
226 let sep = if i + 1 < columns.len() { "," } else { "" };
227 let _ = writeln!(out, " rs.{}(\"{}\"){}", getter, col.name, sep);
228 }
229 let _ = writeln!(out, " );");
230 let _ = writeln!(out, " }}");
231 let _ = write!(out, "}}");
232 Ok(out)
233 }
234
235 fn generate_model_struct(
236 &self,
237 table_name: &str,
238 columns: &[ResolvedColumn],
239 ) -> Result<String, ScytheError> {
240 let name = to_pascal_case(table_name);
241 self.generate_row_struct(&name, columns)
242 }
243
244 fn generate_query_fn(
245 &self,
246 analyzed: &AnalyzedQuery,
247 struct_name: &str,
248 _columns: &[ResolvedColumn],
249 params: &[ResolvedParam],
250 ) -> Result<String, ScytheError> {
251 let func_name = fn_name(&analyzed.name, &self.manifest.naming);
252 let sql = pg_to_jdbc_params(&super::clean_sql_oneline_with_optional(
253 &analyzed.sql,
254 &analyzed.optional_params,
255 &analyzed.params,
256 ));
257
258 let param_list = params
259 .iter()
260 .map(java_annotated_param)
261 .collect::<Vec<_>>()
262 .join(", ");
263 let sep = if param_list.is_empty() { "" } else { ", " };
264
265 let mut out = String::new();
266
267 match &analyzed.command {
268 QueryCommand::Exec => {
269 let _ = writeln!(
270 out,
271 "public static void {}(Connection conn{}{}) throws SQLException {{",
272 func_name, sep, param_list
273 );
274 let _ = writeln!(
275 out,
276 " try (var ps = conn.prepareStatement(\"{}\")) {{",
277 sql
278 );
279 for (i, param) in params.iter().enumerate() {
280 let setter = ps_setter(¶m.lang_type);
281 let _ = writeln!(
282 out,
283 " ps.{}({}, {});",
284 setter,
285 i + 1,
286 param.field_name
287 );
288 }
289 let _ = writeln!(out, " ps.executeUpdate();");
290 let _ = writeln!(out, " }}");
291 let _ = write!(out, "}}");
292 }
293 QueryCommand::ExecResult | QueryCommand::ExecRows => {
294 let _ = writeln!(
295 out,
296 "public static int {}(Connection conn{}{}) throws SQLException {{",
297 func_name, sep, param_list
298 );
299 let _ = writeln!(
300 out,
301 " try (var ps = conn.prepareStatement(\"{}\")) {{",
302 sql
303 );
304 for (i, param) in params.iter().enumerate() {
305 let setter = ps_setter(¶m.lang_type);
306 let _ = writeln!(
307 out,
308 " ps.{}({}, {});",
309 setter,
310 i + 1,
311 param.field_name
312 );
313 }
314 let _ = writeln!(out, " return ps.executeUpdate();");
315 let _ = writeln!(out, " }}");
316 let _ = write!(out, "}}");
317 }
318 QueryCommand::One => {
319 let _ = writeln!(
320 out,
321 "public static @Nullable {} {}(Connection conn{}{}) throws SQLException {{",
322 struct_name, func_name, sep, param_list
323 );
324 let _ = writeln!(
325 out,
326 " try (var ps = conn.prepareStatement(\"{}\")) {{",
327 sql
328 );
329 for (i, param) in params.iter().enumerate() {
330 let setter = ps_setter(¶m.lang_type);
331 let _ = writeln!(
332 out,
333 " ps.{}({}, {});",
334 setter,
335 i + 1,
336 param.field_name
337 );
338 }
339 let _ = writeln!(out, " try (ResultSet rs = ps.executeQuery()) {{");
340 let _ = writeln!(out, " if (rs.next()) {{");
341 let _ = writeln!(
342 out,
343 " return {}.fromResultSet(rs);",
344 struct_name
345 );
346 let _ = writeln!(out, " }}");
347 let _ = writeln!(out, " return null;");
348 let _ = writeln!(out, " }}");
349 let _ = writeln!(out, " }}");
350 let _ = write!(out, "}}");
351 }
352 QueryCommand::Many => {
353 let _ = writeln!(
354 out,
355 "public static java.util.List<{}> {}(Connection conn{}{}) throws SQLException {{",
356 struct_name, func_name, sep, param_list
357 );
358 let _ = writeln!(
359 out,
360 " try (var ps = conn.prepareStatement(\"{}\")) {{",
361 sql
362 );
363 for (i, param) in params.iter().enumerate() {
364 let setter = ps_setter(¶m.lang_type);
365 let _ = writeln!(
366 out,
367 " ps.{}({}, {});",
368 setter,
369 i + 1,
370 param.field_name
371 );
372 }
373 let _ = writeln!(out, " try (ResultSet rs = ps.executeQuery()) {{");
374 let _ = writeln!(
375 out,
376 " java.util.List<{}> result = new java.util.ArrayList<>();",
377 struct_name
378 );
379 let _ = writeln!(out, " while (rs.next()) {{");
380 let _ = writeln!(
381 out,
382 " result.add({}.fromResultSet(rs));",
383 struct_name
384 );
385 let _ = writeln!(out, " }}");
386 let _ = writeln!(out, " return result;");
387 let _ = writeln!(out, " }}");
388 let _ = writeln!(out, " }}");
389 let _ = write!(out, "}}");
390 }
391 QueryCommand::Batch => {
392 let batch_fn_name = format!("{}Batch", func_name);
393 if params.len() > 1 {
394 let params_record_name =
396 format!("{}BatchParams", to_pascal_case(&analyzed.name));
397 let record_fields = params
398 .iter()
399 .map(|p| format!("{} {}", java_param_type(p), p.field_name))
400 .collect::<Vec<_>>()
401 .join(", ");
402 let _ = writeln!(
403 out,
404 "public record {}({}) {{}}",
405 params_record_name, record_fields
406 );
407 let _ = writeln!(out);
408 let _ = writeln!(
409 out,
410 "public static void {}(Connection conn, java.util.List<{}> items) throws SQLException {{",
411 batch_fn_name, params_record_name
412 );
413 let _ = writeln!(out, " conn.setAutoCommit(false);");
414 let _ = writeln!(
415 out,
416 " try (var ps = conn.prepareStatement(\"{}\")) {{",
417 sql
418 );
419 let _ = writeln!(out, " for (var item : items) {{");
420 for (i, param) in params.iter().enumerate() {
421 let setter = ps_setter(¶m.lang_type);
422 let _ = writeln!(
423 out,
424 " ps.{}({}, item.{}());",
425 setter,
426 i + 1,
427 param.field_name
428 );
429 }
430 let _ = writeln!(out, " ps.addBatch();");
431 let _ = writeln!(out, " }}");
432 let _ = writeln!(out, " ps.executeBatch();");
433 let _ = writeln!(out, " conn.commit();");
434 let _ = writeln!(out, " }} catch (SQLException e) {{");
435 let _ = writeln!(out, " conn.rollback();");
436 let _ = writeln!(out, " throw e;");
437 let _ = writeln!(out, " }} finally {{");
438 let _ = writeln!(out, " conn.setAutoCommit(true);");
439 let _ = writeln!(out, " }}");
440 let _ = write!(out, "}}");
441 } else if params.len() == 1 {
442 let param = ¶ms[0];
443 let _ = writeln!(
444 out,
445 "public static void {}(Connection conn, java.util.List<{}> items) throws SQLException {{",
446 batch_fn_name,
447 java_param_type(param)
448 );
449 let _ = writeln!(out, " conn.setAutoCommit(false);");
450 let _ = writeln!(
451 out,
452 " try (var ps = conn.prepareStatement(\"{}\")) {{",
453 sql
454 );
455 let _ = writeln!(out, " for (var item : items) {{");
456 let setter = ps_setter(¶m.lang_type);
457 let _ = writeln!(out, " ps.{}(1, item);", setter);
458 let _ = writeln!(out, " ps.addBatch();");
459 let _ = writeln!(out, " }}");
460 let _ = writeln!(out, " ps.executeBatch();");
461 let _ = writeln!(out, " conn.commit();");
462 let _ = writeln!(out, " }} catch (SQLException e) {{");
463 let _ = writeln!(out, " conn.rollback();");
464 let _ = writeln!(out, " throw e;");
465 let _ = writeln!(out, " }} finally {{");
466 let _ = writeln!(out, " conn.setAutoCommit(true);");
467 let _ = writeln!(out, " }}");
468 let _ = write!(out, "}}");
469 } else {
470 let _ = writeln!(
471 out,
472 "public static void {}(Connection conn, int count) throws SQLException {{",
473 batch_fn_name
474 );
475 let _ = writeln!(out, " conn.setAutoCommit(false);");
476 let _ = writeln!(
477 out,
478 " try (var ps = conn.prepareStatement(\"{}\")) {{",
479 sql
480 );
481 let _ = writeln!(out, " for (int i = 0; i < count; i++) {{");
482 let _ = writeln!(out, " ps.addBatch();");
483 let _ = writeln!(out, " }}");
484 let _ = writeln!(out, " ps.executeBatch();");
485 let _ = writeln!(out, " conn.commit();");
486 let _ = writeln!(out, " }} catch (SQLException e) {{");
487 let _ = writeln!(out, " conn.rollback();");
488 let _ = writeln!(out, " throw e;");
489 let _ = writeln!(out, " }} finally {{");
490 let _ = writeln!(out, " conn.setAutoCommit(true);");
491 let _ = writeln!(out, " }}");
492 let _ = write!(out, "}}");
493 }
494 }
495 QueryCommand::Grouped => {
496 return Err(ScytheError::new(
497 ErrorCode::InternalError,
498 "grouped queries are not yet supported for java-jdbc".to_string(),
499 ));
500 }
501 }
502
503 Ok(out)
504 }
505
506 fn generate_enum_def(&self, enum_info: &EnumInfo) -> Result<String, ScytheError> {
507 let type_name = enum_type_name(&enum_info.sql_name, &self.manifest.naming);
508 let mut out = String::new();
509 let _ = writeln!(out, "public enum {} {{", type_name);
510 for (i, value) in enum_info.values.iter().enumerate() {
511 let variant = enum_variant_name(value, &self.manifest.naming);
512 let sep = if i + 1 < enum_info.values.len() {
513 ","
514 } else {
515 ";"
516 };
517 let _ = writeln!(out, " {}(\"{}\"){}", variant, value, sep);
518 }
519 let _ = writeln!(out);
520 let _ = writeln!(out, " private final String value;");
521 let _ = writeln!(
522 out,
523 " {}(String value) {{ this.value = value; }}",
524 type_name
525 );
526 let _ = writeln!(out, " public String getValue() {{ return value; }}");
527 let _ = write!(out, "}}");
528 Ok(out)
529 }
530
531 fn generate_composite_def(&self, composite: &CompositeInfo) -> Result<String, ScytheError> {
532 let name = to_pascal_case(&composite.sql_name);
533 let mut out = String::new();
534 if composite.fields.is_empty() {
535 let _ = writeln!(out, "public record {}() {{}}", name);
536 } else {
537 let fields = composite
538 .fields
539 .iter()
540 .map(|f| format!("Object {}", to_camel_case(&f.name)))
541 .collect::<Vec<_>>()
542 .join(", ");
543 let _ = writeln!(out, "public record {}({}) {{}}", name, fields);
544 }
545 Ok(out)
546 }
547}