1use std::fmt::Write;
2use std::path::Path;
3
4use scythe_backend::manifest::{BackendManifest, load_manifest};
5use scythe_backend::naming::{
6 enum_type_name, enum_variant_name, fn_name, row_struct_name, to_camel_case, to_pascal_case,
7};
8
9use scythe_core::analyzer::{AnalyzedQuery, CompositeInfo, EnumInfo};
10use scythe_core::errors::{ErrorCode, ScytheError};
11use scythe_core::parser::QueryCommand;
12
13use crate::backend_trait::{CodegenBackend, ResolvedColumn, ResolvedParam};
14
/// Built-in manifest for PostgreSQL, embedded at compile time. Selected by
/// `JavaJdbcBackend::new` for the `postgresql` / `postgres` / `pg` engines when
/// no on-disk `backends/java-jdbc/manifest.toml` exists.
const DEFAULT_MANIFEST_PG: &str = include_str!("../../manifests/java-jdbc.toml");
/// Built-in manifest for MySQL / MariaDB, embedded at compile time.
const DEFAULT_MANIFEST_MYSQL: &str = include_str!("../../manifests/java-jdbc.mysql.toml");
/// Built-in manifest for SQLite, embedded at compile time.
const DEFAULT_MANIFEST_SQLITE: &str = include_str!("../../manifests/java-jdbc.sqlite.toml");
18
/// Code-generation backend that emits Java source using plain JDBC
/// (`Connection`, `PreparedStatement`, `ResultSet`) and Java records.
pub struct JavaJdbcBackend {
    /// Backend manifest chosen in `JavaJdbcBackend::new`; its `naming` rules
    /// drive the identifiers of generated records, methods and enums.
    manifest: BackendManifest,
}
22
23impl JavaJdbcBackend {
24 pub fn new(engine: &str) -> Result<Self, ScytheError> {
25 let default_toml = match engine {
26 "postgresql" | "postgres" | "pg" => DEFAULT_MANIFEST_PG,
27 "mysql" | "mariadb" => DEFAULT_MANIFEST_MYSQL,
28 "sqlite" | "sqlite3" => DEFAULT_MANIFEST_SQLITE,
29 _ => {
30 return Err(ScytheError::new(
31 ErrorCode::InternalError,
32 format!("unsupported engine '{}' for java-jdbc backend", engine),
33 ));
34 }
35 };
36 let manifest_path = Path::new("backends/java-jdbc/manifest.toml");
37 let manifest = if manifest_path.exists() {
38 load_manifest(manifest_path)
39 .map_err(|e| ScytheError::new(ErrorCode::InternalError, format!("manifest: {e}")))?
40 } else {
41 toml::from_str(default_toml)
42 .map_err(|e| ScytheError::new(ErrorCode::InternalError, format!("manifest: {e}")))?
43 };
44 Ok(Self { manifest })
45 }
46}
47
/// Rewrites PostgreSQL positional placeholders (`$1`, `$2`, …) into JDBC `?`
/// placeholders.
///
/// Only placeholders in SQL code are rewritten. These spans are copied
/// verbatim so their contents are never altered:
/// - single-quoted string literals (`'…'`, where `''` is an escaped quote),
///   including `E'…'` literals, in which a backslash escapes the next character;
/// - double-quoted identifiers (`"…"`, where `""` is an escaped quote);
/// - `-- …` line comments and `/* … */` block comments.
///
/// A `$` that is not followed by an ASCII digit (e.g. `$$`) is kept as-is.
/// Placeholder numbers are discarded, so JDBC parameters must be bound in the
/// order the placeholders appear in the statement.
fn pg_to_jdbc_params(sql: &str) -> String {
    let chars: Vec<char> = sql.chars().collect();
    let mut result = String::with_capacity(sql.len());
    let mut i = 0;
    while i < chars.len() {
        let ch = chars[i];
        let next = chars.get(i + 1).copied();
        i = match ch {
            '\'' => {
                let backslash_escapes = is_escape_string_start(&chars, i);
                copy_quoted_sql(&chars, i, '\'', backslash_escapes, &mut result)
            }
            '"' => copy_quoted_sql(&chars, i, '"', false, &mut result),
            '-' if next == Some('-') => {
                // Line comment: copy through the end of the line (inclusive).
                let end = chars[i..]
                    .iter()
                    .position(|&c| c == '\n')
                    .map_or(chars.len(), |offset| i + offset + 1);
                result.extend(&chars[i..end]);
                end
            }
            '/' if next == Some('*') => {
                // Block comment: copy through the closing `*/`, or to the end
                // of the input if it is unterminated.
                let end = chars[i + 2..]
                    .windows(2)
                    .position(|pair| pair[0] == '*' && pair[1] == '/')
                    .map_or(chars.len(), |offset| i + 2 + offset + 2);
                result.extend(&chars[i..end]);
                end
            }
            '$' if next.is_some_and(|c| c.is_ascii_digit()) => {
                let digits = chars[i + 1..]
                    .iter()
                    .take_while(|c| c.is_ascii_digit())
                    .count();
                result.push('?');
                i + 1 + digits
            }
            _ => {
                result.push(ch);
                i + 1
            }
        };
    }
    result
}

/// Copies the quoted span opening at `chars[start]` (which is `quote`) into
/// `out` and returns the index just past its closing quote, or `chars.len()`
/// when the span is unterminated.
///
/// A doubled quote inside the span is an escaped quote. When
/// `backslash_escapes` is set, a backslash also escapes the next character.
fn copy_quoted_sql(
    chars: &[char],
    start: usize,
    quote: char,
    backslash_escapes: bool,
    out: &mut String,
) -> usize {
    out.push(quote);
    let mut i = start + 1;
    while let Some(&c) = chars.get(i) {
        out.push(c);
        i += 1;
        if backslash_escapes && c == '\\' {
            if let Some(&escaped) = chars.get(i) {
                out.push(escaped);
                i += 1;
            }
        } else if c == quote {
            if chars.get(i) == Some(&quote) {
                // Doubled quote: an escaped quote, the span continues.
                out.push(quote);
                i += 1;
            } else {
                return i;
            }
        }
    }
    i
}

/// Returns `true` when the `'` at `chars[quote_at]` opens a PostgreSQL
/// `E'…'` escape-string literal, i.e. it directly follows an `E`/`e` that is
/// not itself the tail of a longer identifier.
fn is_escape_string_start(chars: &[char], quote_at: usize) -> bool {
    let is_ident_char = |c: char| c.is_alphanumeric() || c == '_' || c == '$';
    match quote_at {
        0 => false,
        1 => matches!(chars[0], 'E' | 'e'),
        _ => matches!(chars[quote_at - 1], 'E' | 'e') && !is_ident_char(chars[quote_at - 2]),
    }
}
70
/// Maps a Java primitive type name to its wrapper class (`int` → `Integer`).
///
/// Any other name, including names that are already wrapper classes, is
/// returned unchanged.
fn box_primitive(java_type: &str) -> &str {
    const WRAPPERS: [(&str, &str); 8] = [
        ("boolean", "Boolean"),
        ("byte", "Byte"),
        ("short", "Short"),
        ("int", "Integer"),
        ("long", "Long"),
        ("float", "Float"),
        ("double", "Double"),
        ("char", "Character"),
    ];
    WRAPPERS
        .iter()
        .find(|(primitive, _)| *primitive == java_type)
        .map_or(java_type, |&(_, wrapper)| wrapper)
}
85
/// Picks the `ResultSet` getter used to read a column of the given Java type.
///
/// Primitives and their wrappers map to the matching typed getter
/// (`int` / `Integer` → `getInt`), `String` → `getString`, `byte[]` →
/// `getBytes`, and any type name containing `BigDecimal` → `getBigDecimal`.
/// Everything else (temporal types, UUIDs, …) falls back to `getObject`.
fn rs_getter(java_type: &str) -> &str {
    let typed = match java_type {
        "boolean" | "Boolean" => Some("getBoolean"),
        "byte" | "Byte" => Some("getByte"),
        "short" | "Short" => Some("getShort"),
        "int" | "Integer" => Some("getInt"),
        "long" | "Long" => Some("getLong"),
        "float" | "Float" => Some("getFloat"),
        "double" | "Double" => Some("getDouble"),
        "String" => Some("getString"),
        "byte[]" => Some("getBytes"),
        _ => None,
    };
    typed.unwrap_or(if java_type.contains("BigDecimal") {
        "getBigDecimal"
    } else {
        "getObject"
    })
}
108
/// Picks the `PreparedStatement` setter used to bind a value of the given
/// Java type.
///
/// Primitives and their wrappers map to the matching typed setter
/// (`long` / `Long` → `setLong`), `String` → `setString`, `byte[]` →
/// `setBytes`, and any type name containing `BigDecimal` → `setBigDecimal`.
/// Everything else falls back to `setObject`.
fn ps_setter(java_type: &str) -> &str {
    let typed = match java_type {
        "boolean" | "Boolean" => Some("setBoolean"),
        "byte" | "Byte" => Some("setByte"),
        "short" | "Short" => Some("setShort"),
        "int" | "Integer" => Some("setInt"),
        "long" | "Long" => Some("setLong"),
        "float" | "Float" => Some("setFloat"),
        "double" | "Double" => Some("setDouble"),
        "String" => Some("setString"),
        "byte[]" => Some("setBytes"),
        _ => None,
    };
    match typed {
        Some(setter) => setter,
        None if java_type.contains("BigDecimal") => "setBigDecimal",
        None => "setObject",
    }
}
125
126fn java_field_type(col: &ResolvedColumn) -> String {
128 if col.nullable {
129 box_primitive(&col.lang_type).to_string()
130 } else {
131 col.full_type.clone()
132 }
133}
134
135fn java_param_type(param: &ResolvedParam) -> String {
137 if param.nullable {
138 box_primitive(¶m.lang_type).to_string()
139 } else {
140 param.full_type.clone()
141 }
142}
143
144impl CodegenBackend for JavaJdbcBackend {
145 fn name(&self) -> &str {
146 "java-jdbc"
147 }
148
149 fn manifest(&self) -> &scythe_backend::manifest::BackendManifest {
150 &self.manifest
151 }
152
153 fn supported_engines(&self) -> &[&str] {
154 &["postgresql", "mysql", "sqlite"]
155 }
156
157 fn file_header(&self) -> String {
158 "// Auto-generated by scythe. Do not edit.\n\
159 import java.math.BigDecimal;\n\
160 import java.sql.*;\n\
161 import java.time.OffsetDateTime;\n\
162 import java.util.ArrayList;\n\
163 import java.util.List;\n\
164 import javax.annotation.Nullable;"
165 .to_string()
166 }
167
168 fn generate_row_struct(
169 &self,
170 query_name: &str,
171 columns: &[ResolvedColumn],
172 ) -> Result<String, ScytheError> {
173 let struct_name = row_struct_name(query_name, &self.manifest.naming);
174 let mut out = String::new();
175
176 let fields = columns
178 .iter()
179 .map(|c| {
180 let field_type = java_field_type(c);
181 if c.nullable {
182 format!(" @Nullable {} {}", field_type, c.field_name)
183 } else {
184 format!(" {} {}", field_type, c.field_name)
185 }
186 })
187 .collect::<Vec<_>>()
188 .join(",\n");
189
190 let _ = writeln!(out, "public record {}(", struct_name);
191 let _ = writeln!(out, "{}", fields);
192 let _ = writeln!(out, ") {{");
193
194 let _ = writeln!(
196 out,
197 " public static {} fromResultSet(ResultSet rs) throws SQLException {{",
198 struct_name
199 );
200 let _ = writeln!(out, " return new {}(", struct_name);
201 for (i, col) in columns.iter().enumerate() {
202 let getter = rs_getter(&col.lang_type);
203 let sep = if i + 1 < columns.len() { "," } else { "" };
204 let _ = writeln!(out, " rs.{}(\"{}\"){}", getter, col.name, sep);
205 }
206 let _ = writeln!(out, " );");
207 let _ = writeln!(out, " }}");
208 let _ = write!(out, "}}");
209 Ok(out)
210 }
211
212 fn generate_model_struct(
213 &self,
214 table_name: &str,
215 columns: &[ResolvedColumn],
216 ) -> Result<String, ScytheError> {
217 let name = to_pascal_case(table_name);
218 self.generate_row_struct(&name, columns)
219 }
220
221 fn generate_query_fn(
222 &self,
223 analyzed: &AnalyzedQuery,
224 struct_name: &str,
225 _columns: &[ResolvedColumn],
226 params: &[ResolvedParam],
227 ) -> Result<String, ScytheError> {
228 let func_name = fn_name(&analyzed.name, &self.manifest.naming);
229 let sql = pg_to_jdbc_params(&super::clean_sql_oneline_with_optional(
230 &analyzed.sql,
231 &analyzed.optional_params,
232 &analyzed.params,
233 ));
234
235 let param_list = params
236 .iter()
237 .map(|p| {
238 let param_type = java_param_type(p);
239 format!("{} {}", param_type, p.field_name)
240 })
241 .collect::<Vec<_>>()
242 .join(", ");
243 let sep = if param_list.is_empty() { "" } else { ", " };
244
245 let mut out = String::new();
246
247 match &analyzed.command {
248 QueryCommand::Exec => {
249 let _ = writeln!(
250 out,
251 "public static void {}(Connection conn{}{}) throws SQLException {{",
252 func_name, sep, param_list
253 );
254 let _ = writeln!(
255 out,
256 " try (var ps = conn.prepareStatement(\"{}\")) {{",
257 sql
258 );
259 for (i, param) in params.iter().enumerate() {
260 let setter = ps_setter(¶m.lang_type);
261 let _ = writeln!(
262 out,
263 " ps.{}({}, {});",
264 setter,
265 i + 1,
266 param.field_name
267 );
268 }
269 let _ = writeln!(out, " ps.executeUpdate();");
270 let _ = writeln!(out, " }}");
271 let _ = write!(out, "}}");
272 }
273 QueryCommand::ExecResult | QueryCommand::ExecRows => {
274 let _ = writeln!(
275 out,
276 "public static int {}(Connection conn{}{}) throws SQLException {{",
277 func_name, sep, param_list
278 );
279 let _ = writeln!(
280 out,
281 " try (var ps = conn.prepareStatement(\"{}\")) {{",
282 sql
283 );
284 for (i, param) in params.iter().enumerate() {
285 let setter = ps_setter(¶m.lang_type);
286 let _ = writeln!(
287 out,
288 " ps.{}({}, {});",
289 setter,
290 i + 1,
291 param.field_name
292 );
293 }
294 let _ = writeln!(out, " return ps.executeUpdate();");
295 let _ = writeln!(out, " }}");
296 let _ = write!(out, "}}");
297 }
298 QueryCommand::One => {
299 let _ = writeln!(
300 out,
301 "public static {} {}(Connection conn{}{}) throws SQLException {{",
302 struct_name, func_name, sep, param_list
303 );
304 let _ = writeln!(
305 out,
306 " try (var ps = conn.prepareStatement(\"{}\")) {{",
307 sql
308 );
309 for (i, param) in params.iter().enumerate() {
310 let setter = ps_setter(¶m.lang_type);
311 let _ = writeln!(
312 out,
313 " ps.{}({}, {});",
314 setter,
315 i + 1,
316 param.field_name
317 );
318 }
319 let _ = writeln!(out, " try (ResultSet rs = ps.executeQuery()) {{");
320 let _ = writeln!(out, " if (rs.next()) {{");
321 let _ = writeln!(
322 out,
323 " return {}.fromResultSet(rs);",
324 struct_name
325 );
326 let _ = writeln!(out, " }}");
327 let _ = writeln!(out, " return null;");
328 let _ = writeln!(out, " }}");
329 let _ = writeln!(out, " }}");
330 let _ = write!(out, "}}");
331 }
332 QueryCommand::Many => {
333 let _ = writeln!(
334 out,
335 "public static java.util.List<{}> {}(Connection conn{}{}) throws SQLException {{",
336 struct_name, func_name, sep, param_list
337 );
338 let _ = writeln!(
339 out,
340 " try (var ps = conn.prepareStatement(\"{}\")) {{",
341 sql
342 );
343 for (i, param) in params.iter().enumerate() {
344 let setter = ps_setter(¶m.lang_type);
345 let _ = writeln!(
346 out,
347 " ps.{}({}, {});",
348 setter,
349 i + 1,
350 param.field_name
351 );
352 }
353 let _ = writeln!(out, " try (ResultSet rs = ps.executeQuery()) {{");
354 let _ = writeln!(
355 out,
356 " java.util.List<{}> result = new java.util.ArrayList<>();",
357 struct_name
358 );
359 let _ = writeln!(out, " while (rs.next()) {{");
360 let _ = writeln!(
361 out,
362 " result.add({}.fromResultSet(rs));",
363 struct_name
364 );
365 let _ = writeln!(out, " }}");
366 let _ = writeln!(out, " return result;");
367 let _ = writeln!(out, " }}");
368 let _ = writeln!(out, " }}");
369 let _ = write!(out, "}}");
370 }
371 QueryCommand::Batch => {
372 let batch_fn_name = format!("{}Batch", func_name);
373 if params.len() > 1 {
374 let params_record_name =
376 format!("{}BatchParams", to_pascal_case(&analyzed.name));
377 let record_fields = params
378 .iter()
379 .map(|p| format!("{} {}", java_param_type(p), p.field_name))
380 .collect::<Vec<_>>()
381 .join(", ");
382 let _ = writeln!(
383 out,
384 "public record {}({}) {{}}",
385 params_record_name, record_fields
386 );
387 let _ = writeln!(out);
388 let _ = writeln!(
389 out,
390 "public static void {}(Connection conn, java.util.List<{}> items) throws SQLException {{",
391 batch_fn_name, params_record_name
392 );
393 let _ = writeln!(out, " conn.setAutoCommit(false);");
394 let _ = writeln!(
395 out,
396 " try (var ps = conn.prepareStatement(\"{}\")) {{",
397 sql
398 );
399 let _ = writeln!(out, " for (var item : items) {{");
400 for (i, param) in params.iter().enumerate() {
401 let setter = ps_setter(¶m.lang_type);
402 let _ = writeln!(
403 out,
404 " ps.{}({}, item.{}());",
405 setter,
406 i + 1,
407 param.field_name
408 );
409 }
410 let _ = writeln!(out, " ps.addBatch();");
411 let _ = writeln!(out, " }}");
412 let _ = writeln!(out, " ps.executeBatch();");
413 let _ = writeln!(out, " conn.commit();");
414 let _ = writeln!(out, " }} catch (SQLException e) {{");
415 let _ = writeln!(out, " conn.rollback();");
416 let _ = writeln!(out, " throw e;");
417 let _ = writeln!(out, " }} finally {{");
418 let _ = writeln!(out, " conn.setAutoCommit(true);");
419 let _ = writeln!(out, " }}");
420 let _ = write!(out, "}}");
421 } else if params.len() == 1 {
422 let param = ¶ms[0];
423 let _ = writeln!(
424 out,
425 "public static void {}(Connection conn, java.util.List<{}> items) throws SQLException {{",
426 batch_fn_name,
427 java_param_type(param)
428 );
429 let _ = writeln!(out, " conn.setAutoCommit(false);");
430 let _ = writeln!(
431 out,
432 " try (var ps = conn.prepareStatement(\"{}\")) {{",
433 sql
434 );
435 let _ = writeln!(out, " for (var item : items) {{");
436 let setter = ps_setter(¶m.lang_type);
437 let _ = writeln!(out, " ps.{}(1, item);", setter);
438 let _ = writeln!(out, " ps.addBatch();");
439 let _ = writeln!(out, " }}");
440 let _ = writeln!(out, " ps.executeBatch();");
441 let _ = writeln!(out, " conn.commit();");
442 let _ = writeln!(out, " }} catch (SQLException e) {{");
443 let _ = writeln!(out, " conn.rollback();");
444 let _ = writeln!(out, " throw e;");
445 let _ = writeln!(out, " }} finally {{");
446 let _ = writeln!(out, " conn.setAutoCommit(true);");
447 let _ = writeln!(out, " }}");
448 let _ = write!(out, "}}");
449 } else {
450 let _ = writeln!(
451 out,
452 "public static void {}(Connection conn, int count) throws SQLException {{",
453 batch_fn_name
454 );
455 let _ = writeln!(out, " conn.setAutoCommit(false);");
456 let _ = writeln!(
457 out,
458 " try (var ps = conn.prepareStatement(\"{}\")) {{",
459 sql
460 );
461 let _ = writeln!(out, " for (int i = 0; i < count; i++) {{");
462 let _ = writeln!(out, " ps.addBatch();");
463 let _ = writeln!(out, " }}");
464 let _ = writeln!(out, " ps.executeBatch();");
465 let _ = writeln!(out, " conn.commit();");
466 let _ = writeln!(out, " }} catch (SQLException e) {{");
467 let _ = writeln!(out, " conn.rollback();");
468 let _ = writeln!(out, " throw e;");
469 let _ = writeln!(out, " }} finally {{");
470 let _ = writeln!(out, " conn.setAutoCommit(true);");
471 let _ = writeln!(out, " }}");
472 let _ = write!(out, "}}");
473 }
474 }
475 }
476
477 Ok(out)
478 }
479
480 fn generate_enum_def(&self, enum_info: &EnumInfo) -> Result<String, ScytheError> {
481 let type_name = enum_type_name(&enum_info.sql_name, &self.manifest.naming);
482 let mut out = String::new();
483 let _ = writeln!(out, "public enum {} {{", type_name);
484 for (i, value) in enum_info.values.iter().enumerate() {
485 let variant = enum_variant_name(value, &self.manifest.naming);
486 let sep = if i + 1 < enum_info.values.len() {
487 ","
488 } else {
489 ";"
490 };
491 let _ = writeln!(out, " {}(\"{}\"){}", variant, value, sep);
492 }
493 let _ = writeln!(out);
494 let _ = writeln!(out, " private final String value;");
495 let _ = writeln!(
496 out,
497 " {}(String value) {{ this.value = value; }}",
498 type_name
499 );
500 let _ = writeln!(out, " public String getValue() {{ return value; }}");
501 let _ = write!(out, "}}");
502 Ok(out)
503 }
504
505 fn generate_composite_def(&self, composite: &CompositeInfo) -> Result<String, ScytheError> {
506 let name = to_pascal_case(&composite.sql_name);
507 let mut out = String::new();
508 if composite.fields.is_empty() {
509 let _ = writeln!(out, "public record {}() {{}}", name);
510 } else {
511 let fields = composite
512 .fields
513 .iter()
514 .map(|f| format!("Object {}", to_camel_case(&f.name)))
515 .collect::<Vec<_>>()
516 .join(", ");
517 let _ = writeln!(out, "public record {}({}) {{}}", name, fields);
518 }
519 Ok(out)
520 }
521}