1use std::fmt::Write;
2use std::path::Path;
3
4use scythe_backend::manifest::{BackendManifest, load_manifest};
5use scythe_backend::naming::{
6 enum_type_name, enum_variant_name, fn_name, row_struct_name, to_pascal_case, to_snake_case,
7};
8
9use scythe_core::analyzer::{AnalyzedQuery, CompositeInfo, EnumInfo};
10use scythe_core::errors::{ErrorCode, ScytheError};
11use scythe_core::parser::QueryCommand;
12
13use crate::backend_trait::{CodegenBackend, ResolvedColumn, ResolvedParam};
14
/// Built-in backend manifest, embedded into the binary at compile time.
///
/// Parsed by `ElixirMyxqlBackend::new` when no on-disk manifest is found at
/// `backends/elixir-myxql/manifest.toml`.
const DEFAULT_MANIFEST_TOML: &str = include_str!("../../manifests/elixir-myxql.toml");

/// Code generator that emits Elixir modules and functions calling
/// `MyXQL.query/3`, for the MySQL-family engines accepted by `new`.
pub struct ElixirMyxqlBackend {
    /// Parsed backend manifest; its `naming` section drives the generated
    /// module and function names.
    manifest: BackendManifest,
}
20
21impl ElixirMyxqlBackend {
22 pub fn new(engine: &str) -> Result<Self, ScytheError> {
23 match engine {
24 "mysql" | "mariadb" => {}
25 _ => {
26 return Err(ScytheError::new(
27 ErrorCode::InternalError,
28 format!("elixir-myxql only supports MySQL, got engine '{}'", engine),
29 ));
30 }
31 }
32 let manifest_path = Path::new("backends/elixir-myxql/manifest.toml");
33 let manifest = if manifest_path.exists() {
34 load_manifest(manifest_path)
35 .map_err(|e| ScytheError::new(ErrorCode::InternalError, format!("manifest: {e}")))?
36 } else {
37 toml::from_str(DEFAULT_MANIFEST_TOML)
38 .map_err(|e| ScytheError::new(ErrorCode::InternalError, format!("manifest: {e}")))?
39 };
40 Ok(Self { manifest })
41 }
42}
43
44impl CodegenBackend for ElixirMyxqlBackend {
45 fn name(&self) -> &str {
46 "elixir-myxql"
47 }
48
49 fn manifest(&self) -> &scythe_backend::manifest::BackendManifest {
50 &self.manifest
51 }
52
53 fn supported_engines(&self) -> &[&str] {
54 &["mysql"]
55 }
56
57 fn generate_row_struct(
58 &self,
59 query_name: &str,
60 columns: &[ResolvedColumn],
61 ) -> Result<String, ScytheError> {
62 let struct_name = row_struct_name(query_name, &self.manifest.naming);
63 let mut out = String::new();
64 let _ = writeln!(out, "defmodule {} do", struct_name);
65 let _ = writeln!(out, " @moduledoc \"Row type for {} queries.\"", query_name);
66 let _ = writeln!(out);
67
68 let _ = writeln!(out, " @type t :: %__MODULE__{{");
70 for (i, c) in columns.iter().enumerate() {
71 let sep = if i + 1 < columns.len() { "," } else { "" };
72 let type_ref = if c.neutral_type.starts_with("enum::") {
73 format!("{}.t()", c.full_type)
74 } else {
75 c.full_type.clone()
76 };
77 let _ = writeln!(out, " {}: {}{}", c.field_name, type_ref, sep);
78 }
79 let _ = writeln!(out, " }}");
80
81 let fields = columns
83 .iter()
84 .map(|c| format!(":{}", c.field_name))
85 .collect::<Vec<_>>()
86 .join(", ");
87 let _ = writeln!(out, " defstruct [{}]", fields);
88 let _ = write!(out, "end");
89 Ok(out)
90 }
91
92 fn generate_model_struct(
93 &self,
94 table_name: &str,
95 columns: &[ResolvedColumn],
96 ) -> Result<String, ScytheError> {
97 let name = to_pascal_case(table_name);
98 self.generate_row_struct(&name, columns)
99 }
100
101 fn generate_query_fn(
102 &self,
103 analyzed: &AnalyzedQuery,
104 struct_name: &str,
105 columns: &[ResolvedColumn],
106 params: &[ResolvedParam],
107 ) -> Result<String, ScytheError> {
108 let func_name = fn_name(&analyzed.name, &self.manifest.naming);
109 let sql = super::clean_sql_with_optional(
110 &analyzed.sql,
111 &analyzed.optional_params,
112 &analyzed.params,
113 );
114 let mut out = String::new();
115
116 let param_list = params
118 .iter()
119 .map(|p| p.field_name.clone())
120 .collect::<Vec<_>>()
121 .join(", ");
122 let sep = if param_list.is_empty() { "" } else { ", " };
123
124 let param_args = if params.is_empty() {
126 "[]".to_string()
127 } else {
128 format!(
129 "[{}]",
130 params
131 .iter()
132 .map(|p| p.field_name.clone())
133 .collect::<Vec<_>>()
134 .join(", ")
135 )
136 };
137
138 let param_specs = if params.is_empty() {
140 String::new()
141 } else {
142 let specs: Vec<String> = params.iter().map(|p| p.full_type.clone()).collect();
143 format!(", {}", specs.join(", "))
144 };
145 match &analyzed.command {
146 QueryCommand::One => {
147 let _ = writeln!(
148 out,
149 "@spec {}(MyXQL.conn(){}) :: {{:ok, %{}{{}}}} | {{:error, term()}}",
150 func_name, param_specs, struct_name
151 );
152 }
153 QueryCommand::Many => {
154 let _ = writeln!(
155 out,
156 "@spec {}(MyXQL.conn(){}) :: {{:ok, [%{}{{}}]}} | {{:error, term()}}",
157 func_name, param_specs, struct_name
158 );
159 }
160 QueryCommand::Batch => {
161 let batch_fn_name = format!("{}_batch", func_name);
162 let _ = writeln!(
163 out,
164 "@spec {}(MyXQL.conn(), list()) :: :ok | {{:error, term()}}",
165 batch_fn_name
166 );
167 let _ = writeln!(out, "def {}(conn, items) do", batch_fn_name);
168 let _ = writeln!(out, " Enum.reduce_while(items, :ok, fn item, :ok ->");
169 if params.len() > 1 {
170 let _ = writeln!(
171 out,
172 " case MyXQL.query(conn, \"{}\", Tuple.to_list(item)) do",
173 sql
174 );
175 } else if params.len() == 1 {
176 let _ = writeln!(out, " case MyXQL.query(conn, \"{}\", [item]) do", sql);
177 } else {
178 let _ = writeln!(out, " case MyXQL.query(conn, \"{}\", []) do", sql);
179 }
180 let _ = writeln!(out, " {{:ok, _}} -> {{:cont, :ok}}");
181 let _ = writeln!(out, " {{:error, err}} -> {{:halt, {{:error, err}}}}");
182 let _ = writeln!(out, " end");
183 let _ = writeln!(out, " end)");
184 let _ = write!(out, "end");
185 return Ok(out);
186 }
187 QueryCommand::Exec => {
188 let _ = writeln!(
189 out,
190 "@spec {}(MyXQL.conn(){}) :: :ok | {{:error, term()}}",
191 func_name, param_specs
192 );
193 }
194 QueryCommand::ExecResult | QueryCommand::ExecRows => {
195 let _ = writeln!(
196 out,
197 "@spec {}(MyXQL.conn(){}) :: {{:ok, non_neg_integer()}} | {{:error, term()}}",
198 func_name, param_specs
199 );
200 }
201 QueryCommand::Grouped => {
202 unreachable!("Grouped is rewritten to Many before codegen")
203 }
204 }
205 let _ = writeln!(out, "def {}(conn{}{}) do", func_name, sep, param_list);
206
207 match &analyzed.command {
208 QueryCommand::One => {
209 let _ = writeln!(
210 out,
211 " case MyXQL.query(conn, \"{}\", {}) do",
212 sql, param_args
213 );
214 let _ = writeln!(out, " {{:ok, %MyXQL.Result{{rows: [row]}}}} ->");
215
216 let field_vars = columns
217 .iter()
218 .map(|c| c.field_name.clone())
219 .collect::<Vec<_>>()
220 .join(", ");
221 let _ = writeln!(out, " [{}] = row", field_vars);
222
223 let struct_fields = columns
224 .iter()
225 .map(|c| format!("{}: {}", c.field_name, c.field_name))
226 .collect::<Vec<_>>()
227 .join(", ");
228 let _ = writeln!(out, " {{:ok, %{}{{{}}}}}", struct_name, struct_fields);
229 let _ = writeln!(
230 out,
231 " {{:ok, %MyXQL.Result{{rows: []}}}} -> {{:error, :not_found}}"
232 );
233 let _ = writeln!(out, " {{:error, err}} -> {{:error, err}}");
234 let _ = writeln!(out, " end");
235 }
236 QueryCommand::Many => {
237 let _ = writeln!(
238 out,
239 " case MyXQL.query(conn, \"{}\", {}) do",
240 sql, param_args
241 );
242 let _ = writeln!(out, " {{:ok, %MyXQL.Result{{rows: rows}}}} ->");
243
244 let field_vars = columns
245 .iter()
246 .map(|c| c.field_name.clone())
247 .collect::<Vec<_>>()
248 .join(", ");
249 let struct_fields = columns
250 .iter()
251 .map(|c| format!("{}: {}", c.field_name, c.field_name))
252 .collect::<Vec<_>>()
253 .join(", ");
254
255 let _ = writeln!(out, " results = Enum.map(rows, fn row ->");
256 let _ = writeln!(out, " [{}] = row", field_vars);
257 let _ = writeln!(out, " %{}{{{}}}", struct_name, struct_fields);
258 let _ = writeln!(out, " end)");
259 let _ = writeln!(out, " {{:ok, results}}");
260 let _ = writeln!(out, " {{:error, err}} -> {{:error, err}}");
261 let _ = writeln!(out, " end");
262 }
263 QueryCommand::Exec => {
264 let _ = writeln!(
265 out,
266 " case MyXQL.query(conn, \"{}\", {}) do",
267 sql, param_args
268 );
269 let _ = writeln!(out, " {{:ok, _}} -> :ok");
270 let _ = writeln!(out, " {{:error, err}} -> {{:error, err}}");
271 let _ = writeln!(out, " end");
272 }
273 QueryCommand::ExecResult | QueryCommand::ExecRows => {
274 let _ = writeln!(
275 out,
276 " case MyXQL.query(conn, \"{}\", {}) do",
277 sql, param_args
278 );
279 let _ = writeln!(
280 out,
281 " {{:ok, %MyXQL.Result{{num_rows: n}}}} -> {{:ok, n}}"
282 );
283 let _ = writeln!(out, " {{:error, err}} -> {{:error, err}}");
284 let _ = writeln!(out, " end");
285 }
286 QueryCommand::Batch | QueryCommand::Grouped => unreachable!(),
287 }
288
289 let _ = write!(out, "end");
290 Ok(out)
291 }
292
293 fn generate_enum_def(&self, enum_info: &EnumInfo) -> Result<String, ScytheError> {
294 let type_name = enum_type_name(&enum_info.sql_name, &self.manifest.naming);
295 let mut out = String::new();
296 let _ = writeln!(out, "defmodule {} do", type_name);
297 let _ = writeln!(
298 out,
299 " @moduledoc \"Enum type for {}.\"",
300 enum_info.sql_name
301 );
302 let _ = writeln!(out);
303 let _ = writeln!(out, " @type t :: String.t()");
304 let _ = writeln!(out);
305 for value in &enum_info.values {
306 let variant = enum_variant_name(value, &self.manifest.naming);
307 let _ = writeln!(out, " @spec {}() :: String.t()", to_snake_case(&variant));
308 let _ = writeln!(
309 out,
310 " def {}(), do: \"{}\"",
311 to_snake_case(&variant),
312 value
313 );
314 }
315 let values_list = enum_info
317 .values
318 .iter()
319 .map(|v| format!("\"{}\"", v))
320 .collect::<Vec<_>>()
321 .join(", ");
322 let _ = writeln!(out, " @spec values() :: [String.t()]");
323 let _ = writeln!(out, " def values, do: [{}]", values_list);
324 let _ = write!(out, "end");
325 Ok(out)
326 }
327
328 fn generate_composite_def(&self, composite: &CompositeInfo) -> Result<String, ScytheError> {
329 let name = to_pascal_case(&composite.sql_name);
330 let mut out = String::new();
331 let _ = writeln!(out, "defmodule {} do", name);
332 let _ = writeln!(
333 out,
334 " @moduledoc \"Composite type for {}.\"",
335 composite.sql_name
336 );
337 let _ = writeln!(out);
338 if composite.fields.is_empty() {
340 let _ = writeln!(out, " @type t :: %__MODULE__{{}}");
341 } else {
342 let _ = writeln!(out, " @type t :: %__MODULE__{{");
343 for (i, f) in composite.fields.iter().enumerate() {
344 let sep = if i + 1 < composite.fields.len() {
345 ","
346 } else {
347 ""
348 };
349 let _ = writeln!(out, " {}: term(){}", to_snake_case(&f.name), sep);
350 }
351 let _ = writeln!(out, " }}");
352 }
353 let _ = writeln!(out);
354 if composite.fields.is_empty() {
355 let _ = writeln!(out, " defstruct []");
356 } else {
357 let fields = composite
358 .fields
359 .iter()
360 .map(|f| format!(":{}", to_snake_case(&f.name)))
361 .collect::<Vec<_>>()
362 .join(", ");
363 let _ = writeln!(out, " defstruct [{}]", fields);
364 }
365 let _ = write!(out, "end");
366 Ok(out)
367 }
368}