1use std::fmt::Write;
2use std::path::Path;
3
4use scythe_backend::manifest::{BackendManifest, load_manifest};
5use scythe_backend::naming::{
6 enum_type_name, enum_variant_name, fn_name, row_struct_name, to_camel_case, to_pascal_case,
7};
8
9use scythe_core::analyzer::{AnalyzedQuery, CompositeInfo, EnumInfo};
10use scythe_core::errors::{ErrorCode, ScytheError};
11use scythe_core::parser::QueryCommand;
12
13use crate::backend_trait::{CodegenBackend, ResolvedColumn, ResolvedParam};
14
/// Manifest TOML embedded at compile time from `../../manifests/java-jdbc.toml`.
///
/// `JavaJdbcBackend::new` parses this text when the on-disk manifest it looks
/// for does not exist.
const DEFAULT_MANIFEST_TOML: &str = include_str!("../../manifests/java-jdbc.toml");
16
/// Code generation backend that reports its name as `java-jdbc` through its
/// `CodegenBackend` implementation.
pub struct JavaJdbcBackend {
    /// Manifest resolved in `new`; its `naming` section is passed to the naming
    /// helpers when generating identifiers.
    manifest: BackendManifest,
}
20
21impl JavaJdbcBackend {
22 pub fn new() -> Result<Self, ScytheError> {
23 let manifest_path = Path::new("backends/java-jdbc/manifest.toml");
24 let manifest = if manifest_path.exists() {
25 load_manifest(manifest_path)
26 .map_err(|e| ScytheError::new(ErrorCode::InternalError, format!("manifest: {e}")))?
27 } else {
28 toml::from_str(DEFAULT_MANIFEST_TOML)
29 .map_err(|e| ScytheError::new(ErrorCode::InternalError, format!("manifest: {e}")))?
30 };
31 Ok(Self { manifest })
32 }
33
34 pub fn manifest(&self) -> &BackendManifest {
35 &self.manifest
36 }
37}
38
/// Rewrites PostgreSQL positional placeholders (`$1`, `$2`, ...) as JDBC `?` markers.
///
/// A `$` immediately followed by one or more ASCII digits is replaced by a single
/// `?`; any other `$` is copied through unchanged. Text inside single-quoted string
/// literals and double-quoted identifiers is copied verbatim, so a literal such as
/// `'$1'` is not mistaken for a placeholder. A doubled quote (`''` / `""`) inside a
/// quoted span closes and immediately reopens it, which keeps the span intact.
///
/// NOTE(review): the placeholder number is discarded, so the generated bindings only
/// line up when placeholders appear in ascending order and each is used once —
/// confirm against the analyzer's parameter ordering. Backslash-escaped quotes
/// (`E'...\'...'`) and dollar-quoted strings are not recognised.
fn pg_to_jdbc_params(sql: &str) -> String {
    let mut result = String::with_capacity(sql.len());
    let mut chars = sql.chars().peekable();
    // Quote character of the literal/identifier currently being copied, if any.
    let mut open_quote: Option<char> = None;

    while let Some(ch) = chars.next() {
        if let Some(quote) = open_quote {
            // Inside a quoted span: copy everything, only watching for the closer.
            if ch == quote {
                open_quote = None;
            }
            result.push(ch);
        } else if ch == '\'' || ch == '"' {
            open_quote = Some(ch);
            result.push(ch);
        } else if ch == '$' && chars.peek().is_some_and(|c| c.is_ascii_digit()) {
            // Swallow the whole placeholder number before emitting the marker.
            while chars.next_if(|c| c.is_ascii_digit()).is_some() {}
            result.push('?');
        } else {
            result.push(ch);
        }
    }

    result
}
61
/// Maps a Java primitive type name to its wrapper class name.
///
/// Names that are not one of the eight Java primitives are returned unchanged.
fn box_primitive(java_type: &str) -> &str {
    // (primitive, wrapper class)
    const WRAPPERS: [(&str, &str); 8] = [
        ("boolean", "Boolean"),
        ("byte", "Byte"),
        ("short", "Short"),
        ("int", "Integer"),
        ("long", "Long"),
        ("float", "Float"),
        ("double", "Double"),
        ("char", "Character"),
    ];

    WRAPPERS
        .iter()
        .find(|(primitive, _)| *primitive == java_type)
        .map_or(java_type, |(_, wrapper)| *wrapper)
}
76
/// Picks the `ResultSet` accessor method name used to read a column of the given
/// Java type.
///
/// Primitive and boxed numeric/boolean types, `String` and `byte[]` map to their
/// dedicated getters; any type whose name contains `BigDecimal` maps to
/// `getBigDecimal`. Every other type — including the `java.time` types and `UUID` —
/// falls back to `getObject`.
fn rs_getter(java_type: &str) -> &str {
    match java_type {
        "boolean" | "Boolean" => "getBoolean",
        "byte" | "Byte" => "getByte",
        "short" | "Short" => "getShort",
        "int" | "Integer" => "getInt",
        "long" | "Long" => "getLong",
        "float" | "Float" => "getFloat",
        "double" | "Double" => "getDouble",
        "String" => "getString",
        "byte[]" => "getBytes",
        other if other.contains("BigDecimal") => "getBigDecimal",
        // The former LocalDate/LocalTime/OffsetTime/LocalDateTime/OffsetDateTime/UUID
        // arms all produced this same value, so they are covered by the fallback.
        _ => "getObject",
    }
}
99
/// Picks the `PreparedStatement` binder method name used for a parameter of the
/// given Java type.
///
/// Exact primitive/boxed names, `String` and `byte[]` map to their dedicated
/// setters; a type whose name contains `BigDecimal` maps to `setBigDecimal`; all
/// remaining types use `setObject`.
fn ps_setter(java_type: &str) -> &str {
    // Exact type name -> setter method.
    const EXACT: [(&str, &str); 16] = [
        ("boolean", "setBoolean"),
        ("Boolean", "setBoolean"),
        ("byte", "setByte"),
        ("Byte", "setByte"),
        ("short", "setShort"),
        ("Short", "setShort"),
        ("int", "setInt"),
        ("Integer", "setInt"),
        ("long", "setLong"),
        ("Long", "setLong"),
        ("float", "setFloat"),
        ("Float", "setFloat"),
        ("double", "setDouble"),
        ("Double", "setDouble"),
        ("String", "setString"),
        ("byte[]", "setBytes"),
    ];

    if let Some((_, setter)) = EXACT.iter().find(|(name, _)| *name == java_type) {
        return *setter;
    }
    if java_type.contains("BigDecimal") {
        return "setBigDecimal";
    }
    "setObject"
}
116
117fn java_field_type(col: &ResolvedColumn) -> String {
119 if col.nullable {
120 box_primitive(&col.lang_type).to_string()
121 } else {
122 col.full_type.clone()
123 }
124}
125
126fn java_param_type(param: &ResolvedParam) -> String {
128 if param.nullable {
129 box_primitive(¶m.lang_type).to_string()
130 } else {
131 param.full_type.clone()
132 }
133}
134
135impl CodegenBackend for JavaJdbcBackend {
136 fn name(&self) -> &str {
137 "java-jdbc"
138 }
139
140 fn generate_row_struct(
141 &self,
142 query_name: &str,
143 columns: &[ResolvedColumn],
144 ) -> Result<String, ScytheError> {
145 let struct_name = row_struct_name(query_name, &self.manifest.naming);
146 let mut out = String::new();
147
148 let fields = columns
150 .iter()
151 .map(|c| {
152 let field_type = java_field_type(c);
153 if c.nullable {
154 format!(" @Nullable {} {}", field_type, c.field_name)
155 } else {
156 format!(" {} {}", field_type, c.field_name)
157 }
158 })
159 .collect::<Vec<_>>()
160 .join(",\n");
161
162 let _ = writeln!(out, "public record {}(", struct_name);
163 let _ = writeln!(out, "{}", fields);
164 let _ = writeln!(out, ") {{");
165
166 let _ = writeln!(
168 out,
169 " public static {} fromResultSet(ResultSet rs) throws SQLException {{",
170 struct_name
171 );
172 let _ = writeln!(out, " return new {}(", struct_name);
173 for (i, col) in columns.iter().enumerate() {
174 let getter = rs_getter(&col.lang_type);
175 let sep = if i + 1 < columns.len() { "," } else { "" };
176 let _ = writeln!(out, " rs.{}(\"{}\"){}", getter, col.name, sep);
177 }
178 let _ = writeln!(out, " );");
179 let _ = writeln!(out, " }}");
180 let _ = write!(out, "}}");
181 Ok(out)
182 }
183
184 fn generate_model_struct(
185 &self,
186 table_name: &str,
187 columns: &[ResolvedColumn],
188 ) -> Result<String, ScytheError> {
189 let name = to_pascal_case(table_name);
190 self.generate_row_struct(&name, columns)
191 }
192
193 fn generate_query_fn(
194 &self,
195 analyzed: &AnalyzedQuery,
196 struct_name: &str,
197 _columns: &[ResolvedColumn],
198 params: &[ResolvedParam],
199 ) -> Result<String, ScytheError> {
200 let func_name = fn_name(&analyzed.name, &self.manifest.naming);
201 let sql = pg_to_jdbc_params(&super::clean_sql_oneline(&analyzed.sql));
202
203 let param_list = params
204 .iter()
205 .map(|p| {
206 let param_type = java_param_type(p);
207 format!("{} {}", param_type, p.field_name)
208 })
209 .collect::<Vec<_>>()
210 .join(", ");
211 let sep = if param_list.is_empty() { "" } else { ", " };
212
213 let mut out = String::new();
214
215 match &analyzed.command {
216 QueryCommand::Exec => {
217 let _ = writeln!(
218 out,
219 "public static void {}(Connection conn{}{}) throws SQLException {{",
220 func_name, sep, param_list
221 );
222 let _ = writeln!(
223 out,
224 " try (var ps = conn.prepareStatement(\"{}\")) {{",
225 sql
226 );
227 for (i, param) in params.iter().enumerate() {
228 let setter = ps_setter(¶m.lang_type);
229 let _ = writeln!(
230 out,
231 " ps.{}({}, {});",
232 setter,
233 i + 1,
234 param.field_name
235 );
236 }
237 let _ = writeln!(out, " ps.executeUpdate();");
238 let _ = writeln!(out, " }}");
239 let _ = write!(out, "}}");
240 }
241 QueryCommand::ExecResult | QueryCommand::ExecRows => {
242 let _ = writeln!(
243 out,
244 "public static int {}(Connection conn{}{}) throws SQLException {{",
245 func_name, sep, param_list
246 );
247 let _ = writeln!(
248 out,
249 " try (var ps = conn.prepareStatement(\"{}\")) {{",
250 sql
251 );
252 for (i, param) in params.iter().enumerate() {
253 let setter = ps_setter(¶m.lang_type);
254 let _ = writeln!(
255 out,
256 " ps.{}({}, {});",
257 setter,
258 i + 1,
259 param.field_name
260 );
261 }
262 let _ = writeln!(out, " return ps.executeUpdate();");
263 let _ = writeln!(out, " }}");
264 let _ = write!(out, "}}");
265 }
266 QueryCommand::One => {
267 let _ = writeln!(
268 out,
269 "public static {} {}(Connection conn{}{}) throws SQLException {{",
270 struct_name, func_name, sep, param_list
271 );
272 let _ = writeln!(
273 out,
274 " try (var ps = conn.prepareStatement(\"{}\")) {{",
275 sql
276 );
277 for (i, param) in params.iter().enumerate() {
278 let setter = ps_setter(¶m.lang_type);
279 let _ = writeln!(
280 out,
281 " ps.{}({}, {});",
282 setter,
283 i + 1,
284 param.field_name
285 );
286 }
287 let _ = writeln!(out, " try (ResultSet rs = ps.executeQuery()) {{");
288 let _ = writeln!(out, " if (rs.next()) {{");
289 let _ = writeln!(
290 out,
291 " return {}.fromResultSet(rs);",
292 struct_name
293 );
294 let _ = writeln!(out, " }}");
295 let _ = writeln!(out, " return null;");
296 let _ = writeln!(out, " }}");
297 let _ = writeln!(out, " }}");
298 let _ = write!(out, "}}");
299 }
300 QueryCommand::Many | QueryCommand::Batch => {
301 let _ = writeln!(
302 out,
303 "public static java.util.List<{}> {}(Connection conn{}{}) throws SQLException {{",
304 struct_name, func_name, sep, param_list
305 );
306 let _ = writeln!(
307 out,
308 " try (var ps = conn.prepareStatement(\"{}\")) {{",
309 sql
310 );
311 for (i, param) in params.iter().enumerate() {
312 let setter = ps_setter(¶m.lang_type);
313 let _ = writeln!(
314 out,
315 " ps.{}({}, {});",
316 setter,
317 i + 1,
318 param.field_name
319 );
320 }
321 let _ = writeln!(out, " try (ResultSet rs = ps.executeQuery()) {{");
322 let _ = writeln!(
323 out,
324 " java.util.List<{}> result = new java.util.ArrayList<>();",
325 struct_name
326 );
327 let _ = writeln!(out, " while (rs.next()) {{");
328 let _ = writeln!(
329 out,
330 " result.add({}.fromResultSet(rs));",
331 struct_name
332 );
333 let _ = writeln!(out, " }}");
334 let _ = writeln!(out, " return result;");
335 let _ = writeln!(out, " }}");
336 let _ = writeln!(out, " }}");
337 let _ = write!(out, "}}");
338 }
339 }
340
341 Ok(out)
342 }
343
344 fn generate_enum_def(&self, enum_info: &EnumInfo) -> Result<String, ScytheError> {
345 let type_name = enum_type_name(&enum_info.sql_name, &self.manifest.naming);
346 let mut out = String::new();
347 let _ = writeln!(out, "public enum {} {{", type_name);
348 for (i, value) in enum_info.values.iter().enumerate() {
349 let variant = enum_variant_name(value, &self.manifest.naming);
350 let sep = if i + 1 < enum_info.values.len() {
351 ","
352 } else {
353 ";"
354 };
355 let _ = writeln!(out, " {}(\"{}\"){}", variant, value, sep);
356 }
357 let _ = writeln!(out);
358 let _ = writeln!(out, " private final String value;");
359 let _ = writeln!(
360 out,
361 " {}(String value) {{ this.value = value; }}",
362 type_name
363 );
364 let _ = writeln!(out, " public String getValue() {{ return value; }}");
365 let _ = write!(out, "}}");
366 Ok(out)
367 }
368
369 fn generate_composite_def(&self, composite: &CompositeInfo) -> Result<String, ScytheError> {
370 let name = to_pascal_case(&composite.sql_name);
371 let mut out = String::new();
372 if composite.fields.is_empty() {
373 let _ = writeln!(out, "public record {}() {{}}", name);
374 } else {
375 let fields = composite
376 .fields
377 .iter()
378 .map(|f| format!("Object {}", to_camel_case(&f.name)))
379 .collect::<Vec<_>>()
380 .join(", ");
381 let _ = writeln!(out, "public record {}({}) {{}}", name, fields);
382 }
383 Ok(out)
384 }
385}