1use std::fmt::Write;
2use std::path::Path;
3
4use scythe_backend::manifest::{BackendManifest, load_manifest};
5use scythe_backend::naming::{
6 enum_type_name, enum_variant_name, fn_name, row_struct_name, to_camel_case, to_pascal_case,
7};
8use scythe_backend::types::resolve_type;
9
10use scythe_core::analyzer::{AnalyzedQuery, CompositeInfo, EnumInfo};
11use scythe_core::errors::{ErrorCode, ScytheError};
12use scythe_core::parser::QueryCommand;
13
14use crate::backend_trait::{CodegenBackend, ResolvedColumn, ResolvedParam};
15
16const DEFAULT_MANIFEST_TOML: &str = include_str!("../../manifests/kotlin-jdbc.toml");
17
18pub struct KotlinJdbcBackend {
19 manifest: BackendManifest,
20}
21
22impl KotlinJdbcBackend {
23 pub fn new() -> Result<Self, ScytheError> {
24 let manifest_path = Path::new("backends/kotlin-jdbc/manifest.toml");
25 let manifest = if manifest_path.exists() {
26 load_manifest(manifest_path)
27 .map_err(|e| ScytheError::new(ErrorCode::InternalError, format!("manifest: {e}")))?
28 } else {
29 toml::from_str(DEFAULT_MANIFEST_TOML)
30 .map_err(|e| ScytheError::new(ErrorCode::InternalError, format!("manifest: {e}")))?
31 };
32 Ok(Self { manifest })
33 }
34
35 pub fn manifest(&self) -> &BackendManifest {
36 &self.manifest
37 }
38}
39
/// Rewrites PostgreSQL positional placeholders (`$1`, `$2`, ...) as JDBC `?` markers.
///
/// The scan knows the lexical regions where `$N` is *not* a placeholder and copies them
/// verbatim:
/// - single-quoted string literals;
/// - double-quoted identifiers;
/// - dollar-quoted bodies (`$$...$$`, `$tag$...$tag$`);
/// - `--` line comments and `/* ... */` block comments;
/// - a `$` that continues an identifier (`foo$1`).
///
/// A bare `?` outside those regions (for example the jsonb `?`, `?|` and `?&` operators) is
/// emitted as `??`. The PostgreSQL JDBC driver reads `??` as a literal question mark rather
/// than a bind marker.
///
/// Limitations: backslash escapes inside `E'...'` strings and nested block comments are not
/// recognised.
fn pg_to_jdbc_params(sql: &str) -> String {
    // Bytes that may continue a PostgreSQL identifier or dollar-quote tag. Non-ASCII letters
    // are legal in identifiers, so every byte >= 0x80 is accepted.
    fn is_ident_byte(b: u8) -> bool {
        b.is_ascii_alphanumeric() || b == b'_' || b >= 0x80
    }

    let bytes = sql.as_bytes();
    let len = bytes.len();
    let mut out = String::with_capacity(len);
    let mut i = 0;
    while i < len {
        // Every branch slices `sql` only at ASCII delimiters, which are always char
        // boundaries, so multi-byte UTF-8 text is copied intact.
        let end = match bytes[i] {
            quote @ (b'\'' | b'"') => {
                // A doubled quote (`''`) just closes and reopens the literal; copying both
                // halves verbatim reproduces it unchanged.
                bytes[i + 1..]
                    .iter()
                    .position(|&b| b == quote)
                    .map_or(len, |p| i + p + 2)
            }
            b'-' if bytes.get(i + 1) == Some(&b'-') => bytes[i..]
                .iter()
                .position(|&b| b == b'\n')
                .map_or(len, |p| i + p),
            b'/' if bytes.get(i + 1) == Some(&b'*') => {
                sql[i + 2..].find("*/").map_or(len, |p| i + p + 4)
            }
            b'$' if i == 0 || !is_ident_byte(bytes[i - 1]) => {
                let digits = bytes[i + 1..]
                    .iter()
                    .take_while(|b| b.is_ascii_digit())
                    .count();
                if digits > 0 {
                    out.push('?');
                    i += 1 + digits;
                    continue;
                }
                let tag = bytes[i + 1..]
                    .iter()
                    .take_while(|&&b| is_ident_byte(b))
                    .count();
                if bytes.get(i + 1 + tag) == Some(&b'$') {
                    // Dollar-quoted body: copy through the matching closing delimiter.
                    let delim = &sql[i..i + tag + 2];
                    let body = i + delim.len();
                    sql[body..]
                        .find(delim)
                        .map_or(len, |p| body + p + delim.len())
                } else {
                    i + 1
                }
            }
            b'?' => {
                out.push_str("??");
                i += 1;
                continue;
            }
            _ => {
                // Copy a run of ordinary text up to the next byte that could start one of
                // the cases above.
                bytes[i + 1..]
                    .iter()
                    .position(|b| b"'\"-/$?".contains(b))
                    .map_or(len, |p| i + 1 + p)
            }
        };
        out.push_str(&sql[i..end]);
        i = end;
    }
    out
}
60
/// Returns the `java.sql.ResultSet` getter used to read a column of the given Kotlin type.
///
/// Exact matches:
/// - Kotlin primitive types map to their dedicated getters (`Int` -> `getInt`, ...);
/// - `String` maps to `getString`;
/// - `ByteArray` maps to `getBytes`.
///
/// Any type whose name mentions `BigDecimal` uses `getBigDecimal`. Everything else falls back
/// to `getObject`, including `java.time` types, `UUID`, and nullable spellings such as `Int?`.
///
/// The result is always a string literal, so it is `'static` and independent of the input's
/// lifetime.
fn rs_getter(kotlin_type: &str) -> &'static str {
    match kotlin_type {
        "Boolean" => "getBoolean",
        "Byte" => "getByte",
        "Short" => "getShort",
        "Int" => "getInt",
        "Long" => "getLong",
        "Float" => "getFloat",
        "Double" => "getDouble",
        "String" => "getString",
        "ByteArray" => "getBytes",
        t if t.contains("BigDecimal") => "getBigDecimal",
        // java.time.*, UUID and every other type are read generically.
        _ => "getObject",
    }
}
83
/// Returns the `java.sql.PreparedStatement` setter used to bind a value of the given Kotlin
/// type.
///
/// Exact matches:
/// - Kotlin primitive types map to their dedicated setters (`Int` -> `setInt`, ...);
/// - `String` maps to `setString`;
/// - `ByteArray` maps to `setBytes`.
///
/// Any type whose name mentions `BigDecimal` uses `setBigDecimal`. Every other type, including
/// nullable spellings such as `Int?`, falls back to `setObject`.
///
/// The result is always a string literal, so it is `'static` and independent of the input's
/// lifetime.
fn ps_setter(kotlin_type: &str) -> &'static str {
    match kotlin_type {
        "Boolean" => "setBoolean",
        "Byte" => "setByte",
        "Short" => "setShort",
        "Int" => "setInt",
        "Long" => "setLong",
        "Float" => "setFloat",
        "Double" => "setDouble",
        "String" => "setString",
        "ByteArray" => "setBytes",
        t if t.contains("BigDecimal") => "setBigDecimal",
        _ => "setObject",
    }
}
100
101impl CodegenBackend for KotlinJdbcBackend {
102 fn name(&self) -> &str {
103 "kotlin-jdbc"
104 }
105
106 fn file_header(&self) -> String {
107 "import java.sql.Connection\n".to_string()
108 }
109
110 fn generate_row_struct(
111 &self,
112 query_name: &str,
113 columns: &[ResolvedColumn],
114 ) -> Result<String, ScytheError> {
115 let struct_name = row_struct_name(query_name, &self.manifest.naming);
116 let mut out = String::new();
117 let _ = writeln!(out, "data class {}(", struct_name);
118 for col in columns.iter() {
119 let _ = writeln!(out, " val {}: {},", col.field_name, col.full_type);
120 }
121 let _ = writeln!(out, ")");
122 Ok(out)
123 }
124
125 fn generate_model_struct(
126 &self,
127 table_name: &str,
128 columns: &[ResolvedColumn],
129 ) -> Result<String, ScytheError> {
130 let name = to_pascal_case(table_name);
131 self.generate_row_struct(&name, columns)
132 }
133
134 fn generate_query_fn(
135 &self,
136 analyzed: &AnalyzedQuery,
137 struct_name: &str,
138 columns: &[ResolvedColumn],
139 params: &[ResolvedParam],
140 ) -> Result<String, ScytheError> {
141 let func_name = fn_name(&analyzed.name, &self.manifest.naming);
142 let sql = pg_to_jdbc_params(&super::clean_sql_oneline(&analyzed.sql));
143
144 let use_multiline_params = !params.is_empty();
146
147 let mut out = String::new();
148
149 let write_setters = |out: &mut String, params: &[ResolvedParam]| {
151 for (i, param) in params.iter().enumerate() {
152 let setter = ps_setter(¶m.lang_type);
153 let _ = writeln!(
154 out,
155 " ps.{}({}, {})",
156 setter,
157 i + 1,
158 param.field_name
159 );
160 }
161 };
162
163 let write_fn_sig =
165 |out: &mut String, name: &str, ret: &str, multiline: bool, params: &[ResolvedParam]| {
166 if multiline {
167 let _ = writeln!(out, "fun {}(", name);
168 let _ = writeln!(out, " conn: Connection,");
169 for p in params {
170 let _ = writeln!(out, " {}: {},", p.field_name, p.full_type);
171 }
172 let _ = writeln!(out, "){} {{", ret);
173 } else {
174 let _ = writeln!(out, "fun {}(conn: Connection){} {{", name, ret);
175 }
176 };
177
178 match &analyzed.command {
179 QueryCommand::Exec => {
180 write_fn_sig(&mut out, &func_name, "", use_multiline_params, params);
181 let _ = writeln!(out, " conn.prepareStatement(\"{}\").use {{ ps ->", sql);
182 write_setters(&mut out, params);
183 let _ = writeln!(out, " ps.executeUpdate()");
184 let _ = writeln!(out, " }}");
185 let _ = writeln!(out, "}}");
186 }
187 QueryCommand::ExecResult | QueryCommand::ExecRows => {
188 write_fn_sig(&mut out, &func_name, ": Int", use_multiline_params, params);
189 let _ = writeln!(
190 out,
191 " return conn.prepareStatement(\"{}\").use {{ ps ->",
192 sql
193 );
194 write_setters(&mut out, params);
195 let _ = writeln!(out, " ps.executeUpdate()");
196 let _ = writeln!(out, " }}");
197 let _ = writeln!(out, "}}");
198 }
199 QueryCommand::One => {
200 let ret = format!(": {}?", struct_name);
201 write_fn_sig(&mut out, &func_name, &ret, use_multiline_params, params);
202 let _ = writeln!(out, " conn.prepareStatement(\"{}\").use {{ ps ->", sql);
203 write_setters(&mut out, params);
204 let _ = writeln!(out, " ps.executeQuery().use {{ rs ->");
205 let _ = writeln!(out, " return if (rs.next()) {{");
206 let _ = writeln!(out, " {}(", struct_name);
207 for col in columns.iter() {
208 let getter = rs_getter(&col.lang_type);
209 let _ = writeln!(
210 out,
211 " {} = rs.{}(\"{}\"),",
212 col.field_name, getter, col.name
213 );
214 }
215 let _ = writeln!(out, " )");
216 let _ = writeln!(out, " }} else {{");
217 let _ = writeln!(out, " null");
218 let _ = writeln!(out, " }}");
219 let _ = writeln!(out, " }}");
220 let _ = writeln!(out, " }}");
221 let _ = writeln!(out, "}}");
222 }
223 QueryCommand::Many | QueryCommand::Batch => {
224 let ret = format!(": List<{}>", struct_name);
225 write_fn_sig(&mut out, &func_name, &ret, use_multiline_params, params);
226 let _ = writeln!(out, " conn.prepareStatement(\"{}\").use {{ ps ->", sql);
227 write_setters(&mut out, params);
228 let _ = writeln!(out, " ps.executeQuery().use {{ rs ->");
229 let _ = writeln!(
230 out,
231 " val result = mutableListOf<{}>()",
232 struct_name
233 );
234 let _ = writeln!(out, " while (rs.next()) {{");
235 let _ = writeln!(out, " result.add(");
236 let _ = writeln!(out, " {}(", struct_name);
237 for col in columns.iter() {
238 let getter = rs_getter(&col.lang_type);
239 let _ = writeln!(
240 out,
241 " {} = rs.{}(\"{}\"),",
242 col.field_name, getter, col.name
243 );
244 }
245 let _ = writeln!(out, " ),");
246 let _ = writeln!(out, " )");
247 let _ = writeln!(out, " }}");
248 let _ = writeln!(out, " return result");
249 let _ = writeln!(out, " }}");
250 let _ = writeln!(out, " }}");
251 let _ = writeln!(out, "}}");
252 }
253 }
254
255 Ok(out)
256 }
257
258 fn generate_enum_def(&self, enum_info: &EnumInfo) -> Result<String, ScytheError> {
259 let type_name = enum_type_name(&enum_info.sql_name, &self.manifest.naming);
260 let mut out = String::new();
261 let _ = writeln!(out, "enum class {}(val value: String) {{", type_name);
262 for (i, value) in enum_info.values.iter().enumerate() {
263 let variant = enum_variant_name(value, &self.manifest.naming);
264 let sep = if i + 1 < enum_info.values.len() {
265 ","
266 } else {
267 ";"
268 };
269 let _ = writeln!(out, " {}(\"{}\"){}", variant, value, sep);
270 }
271 let _ = writeln!(out, "}}");
272 Ok(out)
273 }
274
275 fn generate_composite_def(&self, composite: &CompositeInfo) -> Result<String, ScytheError> {
276 let name = to_pascal_case(&composite.sql_name);
277 let mut out = String::new();
278 let _ = writeln!(out, "data class {}(", name);
279 for field in composite.fields.iter() {
280 let field_name = to_camel_case(&field.name);
281 let field_type = resolve_type(&field.neutral_type, &self.manifest, false)
282 .map(|t| t.into_owned())
283 .unwrap_or_else(|_| "Any".to_string());
284 let _ = writeln!(out, " val {}: {},", field_name, field_type);
285 }
286 let _ = writeln!(out, ")");
287 Ok(out)
288 }
289}