Skip to main content

scythe_codegen/backends/
java_r2dbc.rs

1use std::fmt::Write;
2use std::path::Path;
3
4use scythe_backend::manifest::{BackendManifest, load_manifest};
5use scythe_backend::naming::{
6    enum_type_name, enum_variant_name, fn_name, row_struct_name, to_camel_case, to_pascal_case,
7};
8
9use scythe_core::analyzer::{AnalyzedQuery, CompositeInfo, EnumInfo};
10use scythe_core::errors::{ErrorCode, ScytheError};
11use scythe_core::parser::QueryCommand;
12
13use crate::backend_trait::{CodegenBackend, ResolvedColumn, ResolvedParam};
14
15const DEFAULT_MANIFEST_PG: &str = include_str!("../../manifests/java-r2dbc.toml");
16const DEFAULT_MANIFEST_MYSQL: &str = include_str!("../../manifests/java-r2dbc.mysql.toml");
17const DEFAULT_MANIFEST_SQLITE: &str = include_str!("../../manifests/java-r2dbc.sqlite.toml");
18
/// Code generation backend registered under the name `java-r2dbc`; it emits Java
/// source that uses the R2DBC SPI together with Reactor `Mono`/`Flux`.
pub struct JavaR2dbcBackend {
    /// Backend manifest: read from `backends/java-r2dbc/manifest.toml` when that
    /// file exists, otherwise parsed from the embedded per-engine default.
    manifest: BackendManifest,
    /// `true` when the backend was created for a PostgreSQL engine alias; in that
    /// case `$N` placeholders are passed through to the generated SQL unchanged.
    is_pg: bool,
}
23
24impl JavaR2dbcBackend {
25    pub fn new(engine: &str) -> Result<Self, ScytheError> {
26        let default_toml = match engine {
27            "postgresql" | "postgres" | "pg" => DEFAULT_MANIFEST_PG,
28            "mysql" | "mariadb" => DEFAULT_MANIFEST_MYSQL,
29            "sqlite" | "sqlite3" => DEFAULT_MANIFEST_SQLITE,
30            _ => {
31                return Err(ScytheError::new(
32                    ErrorCode::InternalError,
33                    format!("unsupported engine '{}' for java-r2dbc backend", engine),
34                ));
35            }
36        };
37        let manifest_path = Path::new("backends/java-r2dbc/manifest.toml");
38        let manifest = if manifest_path.exists() {
39            load_manifest(manifest_path)
40                .map_err(|e| ScytheError::new(ErrorCode::InternalError, format!("manifest: {e}")))?
41        } else {
42            toml::from_str(default_toml)
43                .map_err(|e| ScytheError::new(ErrorCode::InternalError, format!("manifest: {e}")))?
44        };
45        let is_pg = matches!(engine, "postgresql" | "postgres" | "pg");
46        Ok(Self { manifest, is_pg })
47    }
48}
49
/// Rewrite PostgreSQL-style `$1, $2, ...` placeholders for the target R2DBC driver.
///
/// When `is_pg` is true the SQL is returned unchanged, because the PostgreSQL
/// R2DBC driver understands `$N` natively. Otherwise every `$` that is directly
/// followed by one or more ASCII digits is replaced, digits included, by a single
/// `?`. A `$` that is not followed by a digit is kept as-is.
///
/// The scan is purely textual: a `$N` sequence inside a quoted literal or
/// identifier is rewritten as well.
fn pg_to_r2dbc_params(sql: &str, is_pg: bool) -> String {
    if is_pg {
        return sql.to_owned();
    }

    let mut converted = String::with_capacity(sql.len());
    let mut rest = sql;
    while let Some(dollar_at) = rest.find('$') {
        // Copy everything before the `$` verbatim.
        converted.push_str(&rest[..dollar_at]);
        let tail = &rest[dollar_at + 1..];
        // ASCII digits are single bytes, so the count is also a valid char boundary.
        let digit_count = tail.bytes().take_while(u8::is_ascii_digit).count();
        converted.push(if digit_count == 0 { '$' } else { '?' });
        rest = &tail[digit_count..];
    }
    converted.push_str(rest);
    converted
}
75
/// Map a Java primitive type name to its wrapper class name.
///
/// Any name that is not one of the eight Java primitives is returned unchanged.
fn box_primitive(java_type: &str) -> &str {
    // (primitive, wrapper) pairs for every Java primitive type.
    const WRAPPERS: [(&str, &str); 8] = [
        ("boolean", "Boolean"),
        ("byte", "Byte"),
        ("short", "Short"),
        ("int", "Integer"),
        ("long", "Long"),
        ("float", "Float"),
        ("double", "Double"),
        ("char", "Character"),
    ];

    WRAPPERS
        .iter()
        .find(|entry| entry.0 == java_type)
        .map_or(java_type, |entry| entry.1)
}
90
/// Get the Java class literal passed to R2DBC `row.get(name, Class)` for a Java type.
///
/// Primitive and wrapper names, `String` and `byte[]` are matched exactly. The
/// remaining well-known types are matched by substring so that both simple and
/// fully-qualified spellings resolve to the same class. Anything unrecognised
/// falls back to `Object.class`.
fn r2dbc_row_class(java_type: &str) -> &str {
    match java_type {
        "boolean" | "Boolean" => "Boolean.class",
        "byte" | "Byte" => "Byte.class",
        "short" | "Short" => "Short.class",
        "int" | "Integer" => "Integer.class",
        "long" | "Long" => "Long.class",
        "float" | "Float" => "Float.class",
        "double" | "Double" => "Double.class",
        "String" => "String.class",
        "byte[]" => "byte[].class",
        _ if java_type.contains("BigDecimal") => "java.math.BigDecimal.class",
        // `LocalDateTime` contains the substring `LocalDate`, so the longer name
        // must be tested first or every LocalDateTime would map to LocalDate.
        _ if java_type.contains("LocalDateTime") => "java.time.LocalDateTime.class",
        _ if java_type.contains("OffsetDateTime") => "java.time.OffsetDateTime.class",
        _ if java_type.contains("LocalDate") => "java.time.LocalDate.class",
        _ if java_type.contains("LocalTime") => "java.time.LocalTime.class",
        _ if java_type.contains("OffsetTime") => "java.time.OffsetTime.class",
        _ if java_type.contains("UUID") => "java.util.UUID.class",
        _ => "Object.class",
    }
}
113
114/// Resolve the display type for a Java field, boxing primitives when nullable.
115fn java_field_type(col: &ResolvedColumn) -> String {
116    if col.nullable {
117        box_primitive(&col.lang_type).to_string()
118    } else {
119        col.full_type.clone()
120    }
121}
122
123/// Resolve the display type for a Java param, boxing primitives when nullable.
124fn java_param_type(param: &ResolvedParam) -> String {
125    if param.nullable {
126        box_primitive(&param.lang_type).to_string()
127    } else {
128        param.full_type.clone()
129    }
130}
131
/// Report whether `java_type` names one of the eight Java primitive types.
fn is_java_primitive(java_type: &str) -> bool {
    const PRIMITIVES: [&str; 8] = [
        "boolean", "byte", "short", "int", "long", "float", "double", "char",
    ];
    PRIMITIVES.contains(&java_type)
}
139
140/// Format a Java parameter with nullability annotation.
141fn java_annotated_param(param: &ResolvedParam) -> String {
142    let param_type = java_param_type(param);
143    if param.nullable {
144        format!("@Nullable {} {}", param_type, param.field_name)
145    } else if !is_java_primitive(&param.lang_type) {
146        format!("@Nonnull {} {}", param_type, param.field_name)
147    } else {
148        format!("{} {}", param_type, param.field_name)
149    }
150}
151
152impl CodegenBackend for JavaR2dbcBackend {
    /// Identifier of this backend: always the literal `"java-r2dbc"`.
    fn name(&self) -> &str {
        "java-r2dbc"
    }
156
    /// Borrow the manifest this backend was constructed with.
    fn manifest(&self) -> &scythe_backend::manifest::BackendManifest {
        &self.manifest
    }
160
    /// Engine names reported as supported by this backend.
    ///
    /// Only the three canonical names are listed here; `JavaR2dbcBackend::new`
    /// additionally accepts the aliases `postgres`, `pg`, `mariadb` and `sqlite3`.
    fn supported_engines(&self) -> &[&str] {
        &["postgresql", "mysql", "sqlite"]
    }
164
165    fn file_header(&self) -> String {
166        "// Auto-generated by scythe. Do not edit.\n\
167         import io.r2dbc.spi.ConnectionFactory;\n\
168         import io.r2dbc.spi.Row;\n\
169         import io.r2dbc.spi.RowMetadata;\n\
170         import java.math.BigDecimal;\n\
171         import java.time.LocalDate;\n\
172         import java.time.LocalTime;\n\
173         import java.time.OffsetDateTime;\n\
174         import java.util.UUID;\n\
175         import javax.annotation.Nonnull;\n\
176         import javax.annotation.Nullable;\n\
177         import reactor.core.publisher.Flux;\n\
178         import reactor.core.publisher.Mono;"
179            .to_string()
180    }
181
182    fn generate_row_struct(
183        &self,
184        query_name: &str,
185        columns: &[ResolvedColumn],
186    ) -> Result<String, ScytheError> {
187        let struct_name = row_struct_name(query_name, &self.manifest.naming);
188        let mut out = String::new();
189
190        let fields = columns
191            .iter()
192            .map(|c| {
193                let field_type = java_field_type(c);
194                if c.nullable {
195                    format!("    @Nullable {} {}", field_type, c.field_name)
196                } else {
197                    format!("    {} {}", field_type, c.field_name)
198                }
199            })
200            .collect::<Vec<_>>()
201            .join(",\n");
202
203        let _ = writeln!(out, "public record {}(", struct_name);
204        let _ = writeln!(out, "{}", fields);
205        let _ = write!(out, ") {{}}");
206        Ok(out)
207    }
208
209    fn generate_model_struct(
210        &self,
211        table_name: &str,
212        columns: &[ResolvedColumn],
213    ) -> Result<String, ScytheError> {
214        let name = to_pascal_case(table_name);
215        self.generate_row_struct(&name, columns)
216    }
217
218    fn generate_query_fn(
219        &self,
220        analyzed: &AnalyzedQuery,
221        struct_name: &str,
222        columns: &[ResolvedColumn],
223        params: &[ResolvedParam],
224    ) -> Result<String, ScytheError> {
225        let func_name = fn_name(&analyzed.name, &self.manifest.naming);
226        let sql = pg_to_r2dbc_params(
227            &super::clean_sql_oneline_with_optional(
228                &analyzed.sql,
229                &analyzed.optional_params,
230                &analyzed.params,
231            ),
232            self.is_pg,
233        );
234
235        let param_list = params
236            .iter()
237            .map(java_annotated_param)
238            .collect::<Vec<_>>()
239            .join(", ");
240        let sep = if param_list.is_empty() { "" } else { ", " };
241
242        let mut out = String::new();
243
244        // Helper: write .bind() calls for R2DBC (0-based indexing)
245        let write_binds = |out: &mut String, indent: &str| {
246            for (i, param) in params.iter().enumerate() {
247                let _ = writeln!(out, "{}.bind({}, {})", indent, i, param.field_name);
248            }
249        };
250
251        // Helper: write row mapping expression
252        let write_row_map = |out: &mut String, indent: &str| {
253            let _ = writeln!(out, "{}new {}(", indent, struct_name);
254            for (i, col) in columns.iter().enumerate() {
255                let class = r2dbc_row_class(&col.lang_type);
256                let sep = if i + 1 < columns.len() { "," } else { "" };
257                let _ = writeln!(
258                    out,
259                    "{}    row.get(\"{}\", {}){}",
260                    indent, col.name, class, sep
261                );
262            }
263            let _ = write!(out, "{})", indent);
264        };
265
266        match &analyzed.command {
267            QueryCommand::Exec => {
268                let _ = writeln!(
269                    out,
270                    "public static Mono<Void> {}(ConnectionFactory cf{}{}) {{",
271                    func_name, sep, param_list
272                );
273                let _ = writeln!(out, "    return Mono.usingWhen(");
274                let _ = writeln!(out, "        Mono.from(cf.create()),");
275                let _ = writeln!(out, "        conn -> {{");
276                let _ = writeln!(
277                    out,
278                    "            var stmt = conn.createStatement(\"{}\");",
279                    sql
280                );
281                write_binds(&mut out, "            stmt");
282                let _ = writeln!(out, "            return Mono.from(stmt.execute())");
283                let _ = writeln!(
284                    out,
285                    "                .flatMap(result -> Mono.from(result.getRowsUpdated()))"
286                );
287                let _ = writeln!(out, "                .then();");
288                let _ = writeln!(out, "        }},");
289                let _ = writeln!(out, "        conn -> Mono.from(conn.close())");
290                let _ = writeln!(out, "    );");
291                let _ = write!(out, "}}");
292            }
293            QueryCommand::ExecResult | QueryCommand::ExecRows => {
294                let _ = writeln!(
295                    out,
296                    "public static Mono<Long> {}(ConnectionFactory cf{}{}) {{",
297                    func_name, sep, param_list
298                );
299                let _ = writeln!(out, "    return Mono.usingWhen(");
300                let _ = writeln!(out, "        Mono.from(cf.create()),");
301                let _ = writeln!(out, "        conn -> {{");
302                let _ = writeln!(
303                    out,
304                    "            var stmt = conn.createStatement(\"{}\");",
305                    sql
306                );
307                write_binds(&mut out, "            stmt");
308                let _ = writeln!(out, "            return Mono.from(stmt.execute())");
309                let _ = writeln!(
310                    out,
311                    "                .flatMap(result -> Mono.from(result.getRowsUpdated()));"
312                );
313                let _ = writeln!(out, "        }},");
314                let _ = writeln!(out, "        conn -> Mono.from(conn.close())");
315                let _ = writeln!(out, "    );");
316                let _ = write!(out, "}}");
317            }
318            QueryCommand::One => {
319                let _ = writeln!(
320                    out,
321                    "public static Mono<{}> {}(ConnectionFactory cf{}{}) {{",
322                    struct_name, func_name, sep, param_list
323                );
324                let _ = writeln!(out, "    return Mono.usingWhen(");
325                let _ = writeln!(out, "        Mono.from(cf.create()),");
326                let _ = writeln!(out, "        conn -> {{");
327                let _ = writeln!(
328                    out,
329                    "            var stmt = conn.createStatement(\"{}\");",
330                    sql
331                );
332                write_binds(&mut out, "            stmt");
333                let _ = writeln!(out, "            return Mono.from(stmt.execute())");
334                let _ = writeln!(
335                    out,
336                    "                .flatMap(result -> Mono.from(result.map((row, meta) ->"
337                );
338                write_row_map(&mut out, "                    ");
339                let _ = writeln!(out, ")));");
340                let _ = writeln!(out, "        }},");
341                let _ = writeln!(out, "        conn -> Mono.from(conn.close())");
342                let _ = writeln!(out, "    );");
343                let _ = write!(out, "}}");
344            }
345            QueryCommand::Many => {
346                let _ = writeln!(
347                    out,
348                    "public static Flux<{}> {}(ConnectionFactory cf{}{}) {{",
349                    struct_name, func_name, sep, param_list
350                );
351                let _ = writeln!(out, "    return Flux.usingWhen(");
352                let _ = writeln!(out, "        cf.create(),");
353                let _ = writeln!(out, "        conn -> {{");
354                let _ = writeln!(
355                    out,
356                    "            var stmt = conn.createStatement(\"{}\");",
357                    sql
358                );
359                write_binds(&mut out, "            stmt");
360                let _ = writeln!(out, "            return Flux.from(stmt.execute())");
361                let _ = writeln!(
362                    out,
363                    "                .flatMap(result -> result.map((row, meta) ->"
364                );
365                write_row_map(&mut out, "                    ");
366                let _ = writeln!(out, "));");
367                let _ = writeln!(out, "        }},");
368                let _ = writeln!(out, "        conn -> Mono.from(conn.close())");
369                let _ = writeln!(out, "    );");
370                let _ = write!(out, "}}");
371            }
372            QueryCommand::Batch => {
373                let batch_fn_name = format!("{}Batch", func_name);
374                if params.len() > 1 {
375                    let params_record_name =
376                        format!("{}BatchParams", to_pascal_case(&analyzed.name));
377                    let record_fields = params
378                        .iter()
379                        .map(|p| format!("{} {}", java_param_type(p), p.field_name))
380                        .collect::<Vec<_>>()
381                        .join(", ");
382                    let _ = writeln!(
383                        out,
384                        "public record {}({}) {{}}",
385                        params_record_name, record_fields
386                    );
387                    let _ = writeln!(out);
388                    let _ = writeln!(
389                        out,
390                        "public static Mono<Void> {}(ConnectionFactory cf, java.util.List<{}> items) {{",
391                        batch_fn_name, params_record_name
392                    );
393                    let _ = writeln!(out, "    return Mono.from(cf.create())");
394                    let _ = writeln!(out, "        .flatMap(conn -> {{");
395                    let _ = writeln!(out, "            return Mono.from(conn.beginTransaction())");
396                    let _ = writeln!(out, "                .then(Mono.defer(() -> {{");
397                    let _ = writeln!(
398                        out,
399                        "                    var stmt = conn.createStatement(\"{}\");",
400                        sql
401                    );
402                    let _ = writeln!(out, "                    boolean first = true;");
403                    let _ = writeln!(out, "                    for (var item : items) {{");
404                    let _ = writeln!(out, "                        if (!first) stmt.add();");
405                    for (i, param) in params.iter().enumerate() {
406                        let _ = writeln!(
407                            out,
408                            "                        stmt.bind({}, item.{}());",
409                            i, param.field_name
410                        );
411                    }
412                    let _ = writeln!(out, "                        first = false;");
413                    let _ = writeln!(out, "                    }}");
414                    let _ = writeln!(
415                        out,
416                        "                    return Flux.from(stmt.execute()).then();"
417                    );
418                    let _ = writeln!(out, "                }}))");
419                    let _ = writeln!(
420                        out,
421                        "                .then(Mono.from(conn.commitTransaction()))"
422                    );
423                    let _ = writeln!(
424                        out,
425                        "                .onErrorResume(e -> Mono.from(conn.rollbackTransaction()).then(Mono.error(e)))"
426                    );
427                    let _ = writeln!(
428                        out,
429                        "                .doFinally(s -> Mono.from(conn.close()).subscribe());"
430                    );
431                    let _ = writeln!(out, "        }});");
432                    let _ = write!(out, "}}");
433                } else if params.len() == 1 {
434                    let param = &params[0];
435                    let _ = writeln!(
436                        out,
437                        "public static Mono<Void> {}(ConnectionFactory cf, java.util.List<{}> items) {{",
438                        batch_fn_name,
439                        java_param_type(param)
440                    );
441                    let _ = writeln!(out, "    return Mono.from(cf.create())");
442                    let _ = writeln!(out, "        .flatMap(conn -> {{");
443                    let _ = writeln!(out, "            return Mono.from(conn.beginTransaction())");
444                    let _ = writeln!(out, "                .then(Mono.defer(() -> {{");
445                    let _ = writeln!(
446                        out,
447                        "                    var stmt = conn.createStatement(\"{}\");",
448                        sql
449                    );
450                    let _ = writeln!(out, "                    boolean first = true;");
451                    let _ = writeln!(out, "                    for (var item : items) {{");
452                    let _ = writeln!(out, "                        if (!first) stmt.add();");
453                    let _ = writeln!(out, "                        stmt.bind(0, item);");
454                    let _ = writeln!(out, "                        first = false;");
455                    let _ = writeln!(out, "                    }}");
456                    let _ = writeln!(
457                        out,
458                        "                    return Flux.from(stmt.execute()).then();"
459                    );
460                    let _ = writeln!(out, "                }}))");
461                    let _ = writeln!(
462                        out,
463                        "                .then(Mono.from(conn.commitTransaction()))"
464                    );
465                    let _ = writeln!(
466                        out,
467                        "                .onErrorResume(e -> Mono.from(conn.rollbackTransaction()).then(Mono.error(e)))"
468                    );
469                    let _ = writeln!(
470                        out,
471                        "                .doFinally(s -> Mono.from(conn.close()).subscribe());"
472                    );
473                    let _ = writeln!(out, "        }});");
474                    let _ = write!(out, "}}");
475                } else {
476                    let _ = writeln!(
477                        out,
478                        "public static Mono<Void> {}(ConnectionFactory cf, int count) {{",
479                        batch_fn_name
480                    );
481                    let _ = writeln!(out, "    return Mono.from(cf.create())");
482                    let _ = writeln!(out, "        .flatMap(conn -> {{");
483                    let _ = writeln!(out, "            return Mono.from(conn.beginTransaction())");
484                    let _ = writeln!(out, "                .then(Mono.defer(() -> {{");
485                    let _ = writeln!(
486                        out,
487                        "                    var stmt = conn.createStatement(\"{}\");",
488                        sql
489                    );
490                    let _ = writeln!(
491                        out,
492                        "                    for (int i = 1; i < count; i++) {{"
493                    );
494                    let _ = writeln!(out, "                        stmt.add();");
495                    let _ = writeln!(out, "                    }}");
496                    let _ = writeln!(
497                        out,
498                        "                    return Flux.from(stmt.execute()).then();"
499                    );
500                    let _ = writeln!(out, "                }}))");
501                    let _ = writeln!(
502                        out,
503                        "                .then(Mono.from(conn.commitTransaction()))"
504                    );
505                    let _ = writeln!(
506                        out,
507                        "                .onErrorResume(e -> Mono.from(conn.rollbackTransaction()).then(Mono.error(e)))"
508                    );
509                    let _ = writeln!(
510                        out,
511                        "                .doFinally(s -> Mono.from(conn.close()).subscribe());"
512                    );
513                    let _ = writeln!(out, "        }});");
514                    let _ = write!(out, "}}");
515                }
516            }
517            QueryCommand::Grouped => {
518                // Grouped queries are not yet supported for java-r2dbc
519                return Err(ScytheError::new(
520                    ErrorCode::InternalError,
521                    "grouped queries are not yet supported for java-r2dbc".to_string(),
522                ));
523            }
524        }
525
526        Ok(out)
527    }
528
529    fn generate_enum_def(&self, enum_info: &EnumInfo) -> Result<String, ScytheError> {
530        let type_name = enum_type_name(&enum_info.sql_name, &self.manifest.naming);
531        let mut out = String::new();
532        let _ = writeln!(out, "public enum {} {{", type_name);
533        for (i, value) in enum_info.values.iter().enumerate() {
534            let variant = enum_variant_name(value, &self.manifest.naming);
535            let sep = if i + 1 < enum_info.values.len() {
536                ","
537            } else {
538                ";"
539            };
540            let _ = writeln!(out, "    {}(\"{}\"){}", variant, value, sep);
541        }
542        let _ = writeln!(out);
543        let _ = writeln!(out, "    private final String value;");
544        let _ = writeln!(
545            out,
546            "    {}(String value) {{ this.value = value; }}",
547            type_name
548        );
549        let _ = writeln!(out, "    public String getValue() {{ return value; }}");
550        let _ = write!(out, "}}");
551        Ok(out)
552    }
553
554    fn generate_composite_def(&self, composite: &CompositeInfo) -> Result<String, ScytheError> {
555        let name = to_pascal_case(&composite.sql_name);
556        let mut out = String::new();
557        if composite.fields.is_empty() {
558            let _ = writeln!(out, "public record {}() {{}}", name);
559        } else {
560            let fields = composite
561                .fields
562                .iter()
563                .map(|f| format!("Object {}", to_camel_case(&f.name)))
564                .collect::<Vec<_>>()
565                .join(", ");
566            let _ = writeln!(out, "public record {}({}) {{}}", name, fields);
567        }
568        Ok(out)
569    }
570}