1use std::fmt::Write;
2use std::path::Path;
3
4use scythe_backend::manifest::{BackendManifest, load_manifest};
5use scythe_backend::naming::{
6 enum_type_name, enum_variant_name, fn_name, row_struct_name, to_camel_case, to_pascal_case,
7};
8
9use scythe_core::analyzer::{AnalyzedQuery, CompositeInfo, EnumInfo};
10use scythe_core::errors::{ErrorCode, ScytheError};
11use scythe_core::parser::QueryCommand;
12
13use crate::backend_trait::{CodegenBackend, ResolvedColumn, ResolvedParam};
14
// Default backend manifests, embedded into the binary at compile time via `include_str!`.
// `JavaR2dbcBackend::new` parses the one matching the requested engine when no on-disk
// manifest override is found.

/// Embedded default manifest used for the PostgreSQL engine aliases.
const DEFAULT_MANIFEST_PG: &str = include_str!("../../manifests/java-r2dbc.toml");
/// Embedded default manifest used for the MySQL / MariaDB engine aliases.
const DEFAULT_MANIFEST_MYSQL: &str = include_str!("../../manifests/java-r2dbc.mysql.toml");
/// Embedded default manifest used for the SQLite engine aliases.
const DEFAULT_MANIFEST_SQLITE: &str = include_str!("../../manifests/java-r2dbc.sqlite.toml");
18
/// Code generation backend that emits Java source built on R2DBC and Reactor
/// (`ConnectionFactory`, `Mono`, `Flux`).
pub struct JavaR2dbcBackend {
    /// Backend manifest, either loaded from an on-disk override or parsed from one of the
    /// embedded defaults.
    manifest: BackendManifest,
    /// `true` when the backend was created for a PostgreSQL engine alias; in that case `$N`
    /// placeholders in the SQL are left untouched instead of being rewritten to `?`.
    is_pg: bool,
}
23
24impl JavaR2dbcBackend {
25 pub fn new(engine: &str) -> Result<Self, ScytheError> {
26 let default_toml = match engine {
27 "postgresql" | "postgres" | "pg" => DEFAULT_MANIFEST_PG,
28 "mysql" | "mariadb" => DEFAULT_MANIFEST_MYSQL,
29 "sqlite" | "sqlite3" => DEFAULT_MANIFEST_SQLITE,
30 _ => {
31 return Err(ScytheError::new(
32 ErrorCode::InternalError,
33 format!("unsupported engine '{}' for java-r2dbc backend", engine),
34 ));
35 }
36 };
37 let manifest_path = Path::new("backends/java-r2dbc/manifest.toml");
38 let manifest = if manifest_path.exists() {
39 load_manifest(manifest_path)
40 .map_err(|e| ScytheError::new(ErrorCode::InternalError, format!("manifest: {e}")))?
41 } else {
42 toml::from_str(default_toml)
43 .map_err(|e| ScytheError::new(ErrorCode::InternalError, format!("manifest: {e}")))?
44 };
45 let is_pg = matches!(engine, "postgresql" | "postgres" | "pg");
46 Ok(Self { manifest, is_pg })
47 }
48}
49
/// Rewrites PostgreSQL-style numbered placeholders (`$1`, `$2`, ...) to `?`.
///
/// When `is_pg` is `true` the SQL is returned unchanged. Otherwise every `$` that is
/// immediately followed by one or more ASCII digits is replaced, together with those digits,
/// by a single `?`; a `$` not followed by a digit is kept as-is.
///
/// NOTE(review): the scan does not skip quoted string literals or identifiers, so a `$1`
/// inside quotes is rewritten as well — confirm that such input cannot reach this function.
fn pg_to_r2dbc_params(sql: &str, is_pg: bool) -> String {
    if is_pg {
        return sql.to_owned();
    }

    let mut rewritten = String::with_capacity(sql.len());
    let mut pieces = sql.split('$');
    // Everything before the first `$` is copied verbatim (`split` always yields one piece).
    rewritten.push_str(pieces.next().unwrap_or_default());

    // Each remaining piece is the text that followed one `$` in the input.
    for piece in pieces {
        let after_digits = piece.trim_start_matches(|c: char| c.is_ascii_digit());
        if after_digits.len() < piece.len() {
            // `$` + digits: a placeholder.
            rewritten.push('?');
            rewritten.push_str(after_digits);
        } else {
            // A lone `$`: keep it.
            rewritten.push('$');
            rewritten.push_str(piece);
        }
    }
    rewritten
}
75
/// Returns the boxed wrapper class name for a Java primitive type name
/// (`int` -> `Integer`, `char` -> `Character`, ...).
///
/// Any other type name is returned unchanged.
fn box_primitive(java_type: &str) -> &str {
    const BOXED: [(&str, &str); 8] = [
        ("boolean", "Boolean"),
        ("byte", "Byte"),
        ("short", "Short"),
        ("int", "Integer"),
        ("long", "Long"),
        ("float", "Float"),
        ("double", "Double"),
        ("char", "Character"),
    ];

    BOXED
        .iter()
        .find(|&&(primitive, _)| primitive == java_type)
        .map_or(java_type, |&(_, wrapper)| wrapper)
}
90
/// Maps a Java type name to the class literal passed to `row.get(name, Class)` in the
/// generated row-mapping code.
///
/// Primitive and boxed names map to the boxed class. `String` and `byte[]` map to themselves.
/// Well-known `java.math` / `java.time` / `java.util` types are recognised by substring so
/// that both simple and fully qualified names match. Anything else falls back to
/// `Object.class`.
fn r2dbc_row_class(java_type: &str) -> &str {
    match java_type {
        "boolean" | "Boolean" => "Boolean.class",
        "byte" | "Byte" => "Byte.class",
        "short" | "Short" => "Short.class",
        "int" | "Integer" => "Integer.class",
        "long" | "Long" => "Long.class",
        "float" | "Float" => "Float.class",
        "double" | "Double" => "Double.class",
        "String" => "String.class",
        "byte[]" => "byte[].class",
        _ if java_type.contains("BigDecimal") => "java.math.BigDecimal.class",
        // Must precede the `LocalDate` guard: "LocalDateTime" contains "LocalDate", so the
        // shorter name would otherwise win and timestamps would be read as dates.
        _ if java_type.contains("LocalDateTime") => "java.time.LocalDateTime.class",
        _ if java_type.contains("LocalDate") => "java.time.LocalDate.class",
        _ if java_type.contains("LocalTime") => "java.time.LocalTime.class",
        _ if java_type.contains("OffsetTime") => "java.time.OffsetTime.class",
        _ if java_type.contains("OffsetDateTime") => "java.time.OffsetDateTime.class",
        _ if java_type.contains("UUID") => "java.util.UUID.class",
        _ => "Object.class",
    }
}
113
114fn java_field_type(col: &ResolvedColumn) -> String {
116 if col.nullable {
117 box_primitive(&col.lang_type).to_string()
118 } else {
119 col.full_type.clone()
120 }
121}
122
123fn java_param_type(param: &ResolvedParam) -> String {
125 if param.nullable {
126 box_primitive(¶m.lang_type).to_string()
127 } else {
128 param.full_type.clone()
129 }
130}
131
/// Returns `true` when `java_type` is exactly one of the eight Java primitive type names.
fn is_java_primitive(java_type: &str) -> bool {
    const PRIMITIVES: [&str; 8] = [
        "boolean", "byte", "short", "int", "long", "float", "double", "char",
    ];
    PRIMITIVES.contains(&java_type)
}
139
140fn java_annotated_param(param: &ResolvedParam) -> String {
142 let param_type = java_param_type(param);
143 if param.nullable {
144 format!("@Nullable {} {}", param_type, param.field_name)
145 } else if !is_java_primitive(¶m.lang_type) {
146 format!("@Nonnull {} {}", param_type, param.field_name)
147 } else {
148 format!("{} {}", param_type, param.field_name)
149 }
150}
151
152impl CodegenBackend for JavaR2dbcBackend {
153 fn name(&self) -> &str {
154 "java-r2dbc"
155 }
156
157 fn manifest(&self) -> &scythe_backend::manifest::BackendManifest {
158 &self.manifest
159 }
160
161 fn supported_engines(&self) -> &[&str] {
162 &["postgresql", "mysql", "sqlite"]
163 }
164
165 fn file_header(&self) -> String {
166 "// Auto-generated by scythe. Do not edit.\n\
167 import io.r2dbc.spi.ConnectionFactory;\n\
168 import io.r2dbc.spi.Row;\n\
169 import io.r2dbc.spi.RowMetadata;\n\
170 import java.math.BigDecimal;\n\
171 import java.time.LocalDate;\n\
172 import java.time.LocalTime;\n\
173 import java.time.OffsetDateTime;\n\
174 import java.util.UUID;\n\
175 import javax.annotation.Nonnull;\n\
176 import javax.annotation.Nullable;\n\
177 import reactor.core.publisher.Flux;\n\
178 import reactor.core.publisher.Mono;"
179 .to_string()
180 }
181
182 fn generate_row_struct(
183 &self,
184 query_name: &str,
185 columns: &[ResolvedColumn],
186 ) -> Result<String, ScytheError> {
187 let struct_name = row_struct_name(query_name, &self.manifest.naming);
188 let mut out = String::new();
189
190 let fields = columns
191 .iter()
192 .map(|c| {
193 let field_type = java_field_type(c);
194 if c.nullable {
195 format!(" @Nullable {} {}", field_type, c.field_name)
196 } else {
197 format!(" {} {}", field_type, c.field_name)
198 }
199 })
200 .collect::<Vec<_>>()
201 .join(",\n");
202
203 let _ = writeln!(out, "public record {}(", struct_name);
204 let _ = writeln!(out, "{}", fields);
205 let _ = write!(out, ") {{}}");
206 Ok(out)
207 }
208
209 fn generate_model_struct(
210 &self,
211 table_name: &str,
212 columns: &[ResolvedColumn],
213 ) -> Result<String, ScytheError> {
214 let name = to_pascal_case(table_name);
215 self.generate_row_struct(&name, columns)
216 }
217
218 fn generate_query_fn(
219 &self,
220 analyzed: &AnalyzedQuery,
221 struct_name: &str,
222 columns: &[ResolvedColumn],
223 params: &[ResolvedParam],
224 ) -> Result<String, ScytheError> {
225 let func_name = fn_name(&analyzed.name, &self.manifest.naming);
226 let sql = pg_to_r2dbc_params(
227 &super::clean_sql_oneline_with_optional(
228 &analyzed.sql,
229 &analyzed.optional_params,
230 &analyzed.params,
231 ),
232 self.is_pg,
233 );
234
235 let param_list = params
236 .iter()
237 .map(java_annotated_param)
238 .collect::<Vec<_>>()
239 .join(", ");
240 let sep = if param_list.is_empty() { "" } else { ", " };
241
242 let mut out = String::new();
243
244 let write_binds = |out: &mut String, indent: &str| {
246 for (i, param) in params.iter().enumerate() {
247 let _ = writeln!(out, "{}.bind({}, {})", indent, i, param.field_name);
248 }
249 };
250
251 let write_row_map = |out: &mut String, indent: &str| {
253 let _ = writeln!(out, "{}new {}(", indent, struct_name);
254 for (i, col) in columns.iter().enumerate() {
255 let class = r2dbc_row_class(&col.lang_type);
256 let sep = if i + 1 < columns.len() { "," } else { "" };
257 let _ = writeln!(
258 out,
259 "{} row.get(\"{}\", {}){}",
260 indent, col.name, class, sep
261 );
262 }
263 let _ = write!(out, "{})", indent);
264 };
265
266 match &analyzed.command {
267 QueryCommand::Exec => {
268 let _ = writeln!(
269 out,
270 "public static Mono<Void> {}(ConnectionFactory cf{}{}) {{",
271 func_name, sep, param_list
272 );
273 let _ = writeln!(out, " return Mono.usingWhen(");
274 let _ = writeln!(out, " Mono.from(cf.create()),");
275 let _ = writeln!(out, " conn -> {{");
276 let _ = writeln!(
277 out,
278 " var stmt = conn.createStatement(\"{}\");",
279 sql
280 );
281 write_binds(&mut out, " stmt");
282 let _ = writeln!(out, " return Mono.from(stmt.execute())");
283 let _ = writeln!(
284 out,
285 " .flatMap(result -> Mono.from(result.getRowsUpdated()))"
286 );
287 let _ = writeln!(out, " .then();");
288 let _ = writeln!(out, " }},");
289 let _ = writeln!(out, " conn -> Mono.from(conn.close())");
290 let _ = writeln!(out, " );");
291 let _ = write!(out, "}}");
292 }
293 QueryCommand::ExecResult | QueryCommand::ExecRows => {
294 let _ = writeln!(
295 out,
296 "public static Mono<Long> {}(ConnectionFactory cf{}{}) {{",
297 func_name, sep, param_list
298 );
299 let _ = writeln!(out, " return Mono.usingWhen(");
300 let _ = writeln!(out, " Mono.from(cf.create()),");
301 let _ = writeln!(out, " conn -> {{");
302 let _ = writeln!(
303 out,
304 " var stmt = conn.createStatement(\"{}\");",
305 sql
306 );
307 write_binds(&mut out, " stmt");
308 let _ = writeln!(out, " return Mono.from(stmt.execute())");
309 let _ = writeln!(
310 out,
311 " .flatMap(result -> Mono.from(result.getRowsUpdated()));"
312 );
313 let _ = writeln!(out, " }},");
314 let _ = writeln!(out, " conn -> Mono.from(conn.close())");
315 let _ = writeln!(out, " );");
316 let _ = write!(out, "}}");
317 }
318 QueryCommand::One => {
319 let _ = writeln!(
320 out,
321 "public static Mono<{}> {}(ConnectionFactory cf{}{}) {{",
322 struct_name, func_name, sep, param_list
323 );
324 let _ = writeln!(out, " return Mono.usingWhen(");
325 let _ = writeln!(out, " Mono.from(cf.create()),");
326 let _ = writeln!(out, " conn -> {{");
327 let _ = writeln!(
328 out,
329 " var stmt = conn.createStatement(\"{}\");",
330 sql
331 );
332 write_binds(&mut out, " stmt");
333 let _ = writeln!(out, " return Mono.from(stmt.execute())");
334 let _ = writeln!(
335 out,
336 " .flatMap(result -> Mono.from(result.map((row, meta) ->"
337 );
338 write_row_map(&mut out, " ");
339 let _ = writeln!(out, ")));");
340 let _ = writeln!(out, " }},");
341 let _ = writeln!(out, " conn -> Mono.from(conn.close())");
342 let _ = writeln!(out, " );");
343 let _ = write!(out, "}}");
344 }
345 QueryCommand::Many => {
346 let _ = writeln!(
347 out,
348 "public static Flux<{}> {}(ConnectionFactory cf{}{}) {{",
349 struct_name, func_name, sep, param_list
350 );
351 let _ = writeln!(out, " return Flux.usingWhen(");
352 let _ = writeln!(out, " cf.create(),");
353 let _ = writeln!(out, " conn -> {{");
354 let _ = writeln!(
355 out,
356 " var stmt = conn.createStatement(\"{}\");",
357 sql
358 );
359 write_binds(&mut out, " stmt");
360 let _ = writeln!(out, " return Flux.from(stmt.execute())");
361 let _ = writeln!(
362 out,
363 " .flatMap(result -> result.map((row, meta) ->"
364 );
365 write_row_map(&mut out, " ");
366 let _ = writeln!(out, "));");
367 let _ = writeln!(out, " }},");
368 let _ = writeln!(out, " conn -> Mono.from(conn.close())");
369 let _ = writeln!(out, " );");
370 let _ = write!(out, "}}");
371 }
372 QueryCommand::Batch => {
373 let batch_fn_name = format!("{}Batch", func_name);
374 if params.len() > 1 {
375 let params_record_name =
376 format!("{}BatchParams", to_pascal_case(&analyzed.name));
377 let record_fields = params
378 .iter()
379 .map(|p| format!("{} {}", java_param_type(p), p.field_name))
380 .collect::<Vec<_>>()
381 .join(", ");
382 let _ = writeln!(
383 out,
384 "public record {}({}) {{}}",
385 params_record_name, record_fields
386 );
387 let _ = writeln!(out);
388 let _ = writeln!(
389 out,
390 "public static Mono<Void> {}(ConnectionFactory cf, java.util.List<{}> items) {{",
391 batch_fn_name, params_record_name
392 );
393 let _ = writeln!(out, " return Mono.from(cf.create())");
394 let _ = writeln!(out, " .flatMap(conn -> {{");
395 let _ = writeln!(out, " return Mono.from(conn.beginTransaction())");
396 let _ = writeln!(out, " .then(Mono.defer(() -> {{");
397 let _ = writeln!(
398 out,
399 " var stmt = conn.createStatement(\"{}\");",
400 sql
401 );
402 let _ = writeln!(out, " boolean first = true;");
403 let _ = writeln!(out, " for (var item : items) {{");
404 let _ = writeln!(out, " if (!first) stmt.add();");
405 for (i, param) in params.iter().enumerate() {
406 let _ = writeln!(
407 out,
408 " stmt.bind({}, item.{}());",
409 i, param.field_name
410 );
411 }
412 let _ = writeln!(out, " first = false;");
413 let _ = writeln!(out, " }}");
414 let _ = writeln!(
415 out,
416 " return Flux.from(stmt.execute()).then();"
417 );
418 let _ = writeln!(out, " }}))");
419 let _ = writeln!(
420 out,
421 " .then(Mono.from(conn.commitTransaction()))"
422 );
423 let _ = writeln!(
424 out,
425 " .onErrorResume(e -> Mono.from(conn.rollbackTransaction()).then(Mono.error(e)))"
426 );
427 let _ = writeln!(
428 out,
429 " .doFinally(s -> Mono.from(conn.close()).subscribe());"
430 );
431 let _ = writeln!(out, " }});");
432 let _ = write!(out, "}}");
433 } else if params.len() == 1 {
434 let param = ¶ms[0];
435 let _ = writeln!(
436 out,
437 "public static Mono<Void> {}(ConnectionFactory cf, java.util.List<{}> items) {{",
438 batch_fn_name,
439 java_param_type(param)
440 );
441 let _ = writeln!(out, " return Mono.from(cf.create())");
442 let _ = writeln!(out, " .flatMap(conn -> {{");
443 let _ = writeln!(out, " return Mono.from(conn.beginTransaction())");
444 let _ = writeln!(out, " .then(Mono.defer(() -> {{");
445 let _ = writeln!(
446 out,
447 " var stmt = conn.createStatement(\"{}\");",
448 sql
449 );
450 let _ = writeln!(out, " boolean first = true;");
451 let _ = writeln!(out, " for (var item : items) {{");
452 let _ = writeln!(out, " if (!first) stmt.add();");
453 let _ = writeln!(out, " stmt.bind(0, item);");
454 let _ = writeln!(out, " first = false;");
455 let _ = writeln!(out, " }}");
456 let _ = writeln!(
457 out,
458 " return Flux.from(stmt.execute()).then();"
459 );
460 let _ = writeln!(out, " }}))");
461 let _ = writeln!(
462 out,
463 " .then(Mono.from(conn.commitTransaction()))"
464 );
465 let _ = writeln!(
466 out,
467 " .onErrorResume(e -> Mono.from(conn.rollbackTransaction()).then(Mono.error(e)))"
468 );
469 let _ = writeln!(
470 out,
471 " .doFinally(s -> Mono.from(conn.close()).subscribe());"
472 );
473 let _ = writeln!(out, " }});");
474 let _ = write!(out, "}}");
475 } else {
476 let _ = writeln!(
477 out,
478 "public static Mono<Void> {}(ConnectionFactory cf, int count) {{",
479 batch_fn_name
480 );
481 let _ = writeln!(out, " return Mono.from(cf.create())");
482 let _ = writeln!(out, " .flatMap(conn -> {{");
483 let _ = writeln!(out, " return Mono.from(conn.beginTransaction())");
484 let _ = writeln!(out, " .then(Mono.defer(() -> {{");
485 let _ = writeln!(
486 out,
487 " var stmt = conn.createStatement(\"{}\");",
488 sql
489 );
490 let _ = writeln!(
491 out,
492 " for (int i = 1; i < count; i++) {{"
493 );
494 let _ = writeln!(out, " stmt.add();");
495 let _ = writeln!(out, " }}");
496 let _ = writeln!(
497 out,
498 " return Flux.from(stmt.execute()).then();"
499 );
500 let _ = writeln!(out, " }}))");
501 let _ = writeln!(
502 out,
503 " .then(Mono.from(conn.commitTransaction()))"
504 );
505 let _ = writeln!(
506 out,
507 " .onErrorResume(e -> Mono.from(conn.rollbackTransaction()).then(Mono.error(e)))"
508 );
509 let _ = writeln!(
510 out,
511 " .doFinally(s -> Mono.from(conn.close()).subscribe());"
512 );
513 let _ = writeln!(out, " }});");
514 let _ = write!(out, "}}");
515 }
516 }
517 QueryCommand::Grouped => {
518 return Err(ScytheError::new(
520 ErrorCode::InternalError,
521 "grouped queries are not yet supported for java-r2dbc".to_string(),
522 ));
523 }
524 }
525
526 Ok(out)
527 }
528
529 fn generate_enum_def(&self, enum_info: &EnumInfo) -> Result<String, ScytheError> {
530 let type_name = enum_type_name(&enum_info.sql_name, &self.manifest.naming);
531 let mut out = String::new();
532 let _ = writeln!(out, "public enum {} {{", type_name);
533 for (i, value) in enum_info.values.iter().enumerate() {
534 let variant = enum_variant_name(value, &self.manifest.naming);
535 let sep = if i + 1 < enum_info.values.len() {
536 ","
537 } else {
538 ";"
539 };
540 let _ = writeln!(out, " {}(\"{}\"){}", variant, value, sep);
541 }
542 let _ = writeln!(out);
543 let _ = writeln!(out, " private final String value;");
544 let _ = writeln!(
545 out,
546 " {}(String value) {{ this.value = value; }}",
547 type_name
548 );
549 let _ = writeln!(out, " public String getValue() {{ return value; }}");
550 let _ = write!(out, "}}");
551 Ok(out)
552 }
553
554 fn generate_composite_def(&self, composite: &CompositeInfo) -> Result<String, ScytheError> {
555 let name = to_pascal_case(&composite.sql_name);
556 let mut out = String::new();
557 if composite.fields.is_empty() {
558 let _ = writeln!(out, "public record {}() {{}}", name);
559 } else {
560 let fields = composite
561 .fields
562 .iter()
563 .map(|f| format!("Object {}", to_camel_case(&f.name)))
564 .collect::<Vec<_>>()
565 .join(", ");
566 let _ = writeln!(out, "public record {}({}) {{}}", name, fields);
567 }
568 Ok(out)
569 }
570}