Skip to main content

scythe_codegen/backends/
elixir_ecto.rs

1use std::fmt::Write;
2use std::path::Path;
3
4use scythe_backend::manifest::{BackendManifest, load_manifest};
5use scythe_backend::naming::{
6    enum_type_name, enum_variant_name, fn_name, row_struct_name, to_pascal_case, to_snake_case,
7};
8
9use scythe_core::analyzer::{AnalyzedQuery, CompositeInfo, EnumInfo};
10use scythe_core::errors::{ErrorCode, ScytheError};
11use scythe_core::parser::QueryCommand;
12
13use crate::backend_trait::{CodegenBackend, ResolvedColumn, ResolvedParam};
14
15const DEFAULT_MANIFEST_TOML: &str = include_str!("../../manifests/elixir-ecto.toml");
16
17pub struct ElixirEctoBackend {
18    manifest: BackendManifest,
19}
20
21impl ElixirEctoBackend {
22    pub fn new(engine: &str) -> Result<Self, ScytheError> {
23        match engine {
24            "postgresql" | "postgres" | "pg" => {}
25            _ => {
26                return Err(ScytheError::new(
27                    ErrorCode::InternalError,
28                    format!(
29                        "elixir-ecto only supports PostgreSQL, got engine '{}'",
30                        engine
31                    ),
32                ));
33            }
34        }
35        let manifest_path = Path::new("backends/elixir-ecto/manifest.toml");
36        let manifest = if manifest_path.exists() {
37            load_manifest(manifest_path)
38                .map_err(|e| ScytheError::new(ErrorCode::InternalError, format!("manifest: {e}")))?
39        } else {
40            toml::from_str(DEFAULT_MANIFEST_TOML)
41                .map_err(|e| ScytheError::new(ErrorCode::InternalError, format!("manifest: {e}")))?
42        };
43        Ok(Self { manifest })
44    }
45}
46
47impl CodegenBackend for ElixirEctoBackend {
48    fn name(&self) -> &str {
49        "elixir-ecto"
50    }
51
52    fn manifest(&self) -> &scythe_backend::manifest::BackendManifest {
53        &self.manifest
54    }
55
56    fn file_header(&self) -> String {
57        "# Auto-generated by scythe. Do not edit.\n\ndefmodule Scythe.Queries do".to_string()
58    }
59
60    fn file_footer(&self) -> String {
61        "end".to_string()
62    }
63
64    fn generate_row_struct(
65        &self,
66        query_name: &str,
67        columns: &[ResolvedColumn],
68    ) -> Result<String, ScytheError> {
69        let struct_name = row_struct_name(query_name, &self.manifest.naming);
70        let mut out = String::new();
71        let _ = writeln!(out, "defmodule {} do", struct_name);
72        let _ = writeln!(out, "  @moduledoc \"Row type for {} queries.\"", query_name);
73        let _ = writeln!(out);
74
75        // Generate typespec
76        let _ = writeln!(out, "  @type t :: %__MODULE__{{");
77        for (i, c) in columns.iter().enumerate() {
78            let sep = if i + 1 < columns.len() { "," } else { "" };
79            let type_ref = if c.neutral_type.starts_with("enum::") {
80                format!("{}.t()", c.full_type)
81            } else {
82                c.full_type.clone()
83            };
84            let _ = writeln!(out, "    {}: {}{}", c.field_name, type_ref, sep);
85        }
86        let _ = writeln!(out, "  }}");
87
88        // Generate defstruct
89        let fields = columns
90            .iter()
91            .map(|c| format!(":{}", c.field_name))
92            .collect::<Vec<_>>()
93            .join(", ");
94        let _ = writeln!(out, "  defstruct [{}]", fields);
95        let _ = write!(out, "end");
96        Ok(out)
97    }
98
99    fn generate_model_struct(
100        &self,
101        table_name: &str,
102        columns: &[ResolvedColumn],
103    ) -> Result<String, ScytheError> {
104        let name = to_pascal_case(table_name);
105        self.generate_row_struct(&name, columns)
106    }
107
108    fn generate_query_fn(
109        &self,
110        analyzed: &AnalyzedQuery,
111        struct_name: &str,
112        columns: &[ResolvedColumn],
113        params: &[ResolvedParam],
114    ) -> Result<String, ScytheError> {
115        let func_name = fn_name(&analyzed.name, &self.manifest.naming);
116        let sql = super::clean_sql_with_optional(
117            &analyzed.sql,
118            &analyzed.optional_params,
119            &analyzed.params,
120        );
121        let mut out = String::new();
122
123        // Parameter list
124        let param_list = params
125            .iter()
126            .map(|p| p.field_name.clone())
127            .collect::<Vec<_>>()
128            .join(", ");
129        let sep = if param_list.is_empty() { "" } else { ", " };
130
131        // Build the params array for Ecto.Adapters.SQL.query
132        let param_args = if params.is_empty() {
133            "[]".to_string()
134        } else {
135            format!(
136                "[{}]",
137                params
138                    .iter()
139                    .map(|p| p.field_name.clone())
140                    .collect::<Vec<_>>()
141                    .join(", ")
142            )
143        };
144
145        // Build @spec
146        let param_specs = if params.is_empty() {
147            String::new()
148        } else {
149            let specs: Vec<String> = params.iter().map(|p| p.full_type.clone()).collect();
150            format!(", {}", specs.join(", "))
151        };
152        match &analyzed.command {
153            QueryCommand::One => {
154                let _ = writeln!(
155                    out,
156                    "@spec {}(Ecto.Repo.t(){}) :: {{:ok, %{}{{}} | nil}} | {{:error, term()}}",
157                    func_name, param_specs, struct_name
158                );
159            }
160            QueryCommand::Many => {
161                let _ = writeln!(
162                    out,
163                    "@spec {}(Ecto.Repo.t(){}) :: {{:ok, [%{}{{}}]}} | {{:error, term()}}",
164                    func_name, param_specs, struct_name
165                );
166            }
167            QueryCommand::Batch => {
168                let batch_fn_name = format!("{}_batch", func_name);
169                let _ = writeln!(
170                    out,
171                    "@spec {}(Ecto.Repo.t(), list()) :: :ok | {{:error, term()}}",
172                    batch_fn_name
173                );
174                let _ = writeln!(out, "def {}(repo, items) do", batch_fn_name);
175                let _ = writeln!(out, "  Ecto.Multi.new()");
176                let _ = writeln!(out, "  |> then(fn multi ->");
177                let _ = writeln!(out, "    items");
178                let _ = writeln!(out, "    |> Enum.with_index()");
179                let _ = writeln!(out, "    |> Enum.reduce(multi, fn {{item, idx}}, acc ->");
180                if params.len() > 1 {
181                    let _ = writeln!(
182                        out,
183                        "      Ecto.Multi.run(acc, {{:batch, idx}}, fn repo, _changes -> Ecto.Adapters.SQL.query(repo, \"{}\", Tuple.to_list(item)) end)",
184                        sql
185                    );
186                } else if params.len() == 1 {
187                    let _ = writeln!(
188                        out,
189                        "      Ecto.Multi.run(acc, {{:batch, idx}}, fn repo, _changes -> Ecto.Adapters.SQL.query(repo, \"{}\", [item]) end)",
190                        sql
191                    );
192                } else {
193                    let _ = writeln!(
194                        out,
195                        "      Ecto.Multi.run(acc, {{:batch, idx}}, fn repo, _changes -> Ecto.Adapters.SQL.query(repo, \"{}\", []) end)",
196                        sql
197                    );
198                }
199                let _ = writeln!(out, "    end)");
200                let _ = writeln!(out, "  end)");
201                let _ = writeln!(out, "  |> repo.transaction()");
202                let _ = writeln!(out, "  |> case do");
203                let _ = writeln!(out, "    {{:ok, _}} -> :ok");
204                let _ = writeln!(out, "    {{:error, _, reason, _}} -> {{:error, reason}}");
205                let _ = writeln!(out, "  end");
206                let _ = write!(out, "end");
207                return Ok(out);
208            }
209            QueryCommand::Exec => {
210                let _ = writeln!(
211                    out,
212                    "@spec {}(Ecto.Repo.t(){}) :: :ok | {{:error, term()}}",
213                    func_name, param_specs
214                );
215            }
216            QueryCommand::ExecResult | QueryCommand::ExecRows => {
217                let _ = writeln!(
218                    out,
219                    "@spec {}(Ecto.Repo.t(){}) :: {{:ok, non_neg_integer()}} | {{:error, term()}}",
220                    func_name, param_specs
221                );
222            }
223            QueryCommand::Grouped => {
224                return Err(ScytheError::new(
225                    ErrorCode::InternalError,
226                    "grouped queries are not yet supported for elixir-ecto".to_string(),
227                ));
228            }
229        }
230        let _ = writeln!(out, "def {}(repo{}{}) do", func_name, sep, param_list);
231
232        match &analyzed.command {
233            QueryCommand::One => {
234                let _ = writeln!(
235                    out,
236                    "  case Ecto.Adapters.SQL.query(repo, \"{}\", {}) do",
237                    sql, param_args
238                );
239                let _ = writeln!(out, "    {{:ok, %{{rows: [row | _]}}}} ->");
240
241                // Destructure row
242                let field_vars = columns
243                    .iter()
244                    .map(|c| c.field_name.clone())
245                    .collect::<Vec<_>>()
246                    .join(", ");
247                let _ = writeln!(out, "      [{}] = row", field_vars);
248
249                // Build struct
250                let struct_fields = columns
251                    .iter()
252                    .map(|c| format!("{}: {}", c.field_name, c.field_name))
253                    .collect::<Vec<_>>()
254                    .join(", ");
255                let _ = writeln!(out, "      {{:ok, %{}{{{}}}}}", struct_name, struct_fields);
256                let _ = writeln!(out, "    {{:ok, %{{rows: []}}}} -> {{:ok, nil}}");
257                let _ = writeln!(out, "    {{:error, err}} -> {{:error, err}}");
258                let _ = writeln!(out, "  end");
259            }
260            QueryCommand::Many => {
261                let _ = writeln!(
262                    out,
263                    "  case Ecto.Adapters.SQL.query(repo, \"{}\", {}) do",
264                    sql, param_args
265                );
266                let _ = writeln!(out, "    {{:ok, %{{rows: rows}}}} ->");
267
268                let field_vars = columns
269                    .iter()
270                    .map(|c| c.field_name.clone())
271                    .collect::<Vec<_>>()
272                    .join(", ");
273                let struct_fields = columns
274                    .iter()
275                    .map(|c| format!("{}: {}", c.field_name, c.field_name))
276                    .collect::<Vec<_>>()
277                    .join(", ");
278
279                let _ = writeln!(out, "      results = Enum.map(rows, fn row ->");
280                let _ = writeln!(out, "        [{}] = row", field_vars);
281                let _ = writeln!(out, "        %{}{{{}}}", struct_name, struct_fields);
282                let _ = writeln!(out, "      end)");
283                let _ = writeln!(out, "      {{:ok, results}}");
284                let _ = writeln!(out, "    {{:error, err}} -> {{:error, err}}");
285                let _ = writeln!(out, "  end");
286            }
287            QueryCommand::Exec => {
288                let _ = writeln!(
289                    out,
290                    "  case Ecto.Adapters.SQL.query(repo, \"{}\", {}) do",
291                    sql, param_args
292                );
293                let _ = writeln!(out, "    {{:ok, _}} -> :ok");
294                let _ = writeln!(out, "    {{:error, err}} -> {{:error, err}}");
295                let _ = writeln!(out, "  end");
296            }
297            QueryCommand::ExecResult | QueryCommand::ExecRows => {
298                let _ = writeln!(
299                    out,
300                    "  case Ecto.Adapters.SQL.query(repo, \"{}\", {}) do",
301                    sql, param_args
302                );
303                let _ = writeln!(out, "    {{:ok, %{{num_rows: n}}}} -> {{:ok, n}}");
304                let _ = writeln!(out, "    {{:error, err}} -> {{:error, err}}");
305                let _ = writeln!(out, "  end");
306            }
307            QueryCommand::Batch | QueryCommand::Grouped => unreachable!(),
308        }
309
310        let _ = write!(out, "end");
311        Ok(out)
312    }
313
314    fn generate_enum_def(&self, enum_info: &EnumInfo) -> Result<String, ScytheError> {
315        let type_name = enum_type_name(&enum_info.sql_name, &self.manifest.naming);
316        let mut out = String::new();
317        let _ = writeln!(out, "defmodule {} do", type_name);
318        let _ = writeln!(
319            out,
320            "  @moduledoc \"Enum type for {}.\"",
321            enum_info.sql_name
322        );
323        let _ = writeln!(out);
324        let _ = writeln!(out, "  @type t :: String.t()");
325        let _ = writeln!(out);
326        for value in &enum_info.values {
327            let variant = enum_variant_name(value, &self.manifest.naming);
328            let _ = writeln!(out, "  @spec {}() :: String.t()", to_snake_case(&variant));
329            let _ = writeln!(
330                out,
331                "  def {}(), do: \"{}\"",
332                to_snake_case(&variant),
333                value
334            );
335        }
336        // values/0 function
337        let values_list = enum_info
338            .values
339            .iter()
340            .map(|v| format!("\"{}\"", v))
341            .collect::<Vec<_>>()
342            .join(", ");
343        let _ = writeln!(out, "  @spec values() :: [String.t()]");
344        let _ = writeln!(out, "  def values, do: [{}]", values_list);
345        let _ = write!(out, "end");
346        Ok(out)
347    }
348
349    fn generate_composite_def(&self, composite: &CompositeInfo) -> Result<String, ScytheError> {
350        let name = to_pascal_case(&composite.sql_name);
351        let mut out = String::new();
352        let _ = writeln!(out, "defmodule {} do", name);
353        let _ = writeln!(
354            out,
355            "  @moduledoc \"Composite type for {}.\"",
356            composite.sql_name
357        );
358        let _ = writeln!(out);
359        // Generate @type definition
360        if composite.fields.is_empty() {
361            let _ = writeln!(out, "  @type t :: %__MODULE__{{}}");
362        } else {
363            let _ = writeln!(out, "  @type t :: %__MODULE__{{");
364            for (i, f) in composite.fields.iter().enumerate() {
365                let sep = if i + 1 < composite.fields.len() {
366                    ","
367                } else {
368                    ""
369                };
370                let _ = writeln!(out, "    {}: term(){}", to_snake_case(&f.name), sep);
371            }
372            let _ = writeln!(out, "  }}");
373        }
374        let _ = writeln!(out);
375        if composite.fields.is_empty() {
376            let _ = writeln!(out, "  defstruct []");
377        } else {
378            let fields = composite
379                .fields
380                .iter()
381                .map(|f| format!(":{}", to_snake_case(&f.name)))
382                .collect::<Vec<_>>()
383                .join(", ");
384            let _ = writeln!(out, "  defstruct [{}]", fields);
385        }
386        let _ = write!(out, "end");
387        Ok(out)
388    }
389}