//! Elixir/Ecto code-generation backend (`scythe_codegen/backends/elixir_ecto.rs`).

1use std::fmt::Write;
2use std::path::Path;
3
4use scythe_backend::manifest::{BackendManifest, load_manifest};
5use scythe_backend::naming::{
6    enum_type_name, enum_variant_name, fn_name, row_struct_name, to_pascal_case, to_snake_case,
7};
8
9use scythe_core::analyzer::{AnalyzedQuery, CompositeInfo, EnumInfo};
10use scythe_core::errors::{ErrorCode, ScytheError};
11use scythe_core::parser::QueryCommand;
12
13use crate::backend_trait::{CodegenBackend, ResolvedColumn, ResolvedParam};
14
/// Manifest embedded at compile time. `ElixirEctoBackend::new` parses this copy when
/// `backends/elixir-ecto/manifest.toml` does not exist on disk.
const DEFAULT_MANIFEST_TOML: &str = include_str!("../../manifests/elixir-ecto.toml");
16
/// Code-generation backend that emits Elixir modules calling `Ecto.Adapters.SQL`.
/// Construction (`new`) only accepts PostgreSQL engines.
pub struct ElixirEctoBackend {
    /// Backend manifest; its `naming` rules drive the generated module and function names.
    manifest: BackendManifest,
}
20
21impl ElixirEctoBackend {
22    pub fn new(engine: &str) -> Result<Self, ScytheError> {
23        match engine {
24            "postgresql" | "postgres" | "pg" => {}
25            _ => {
26                return Err(ScytheError::new(
27                    ErrorCode::InternalError,
28                    format!(
29                        "elixir-ecto only supports PostgreSQL, got engine '{}'",
30                        engine
31                    ),
32                ));
33            }
34        }
35        let manifest_path = Path::new("backends/elixir-ecto/manifest.toml");
36        let manifest = if manifest_path.exists() {
37            load_manifest(manifest_path)
38                .map_err(|e| ScytheError::new(ErrorCode::InternalError, format!("manifest: {e}")))?
39        } else {
40            toml::from_str(DEFAULT_MANIFEST_TOML)
41                .map_err(|e| ScytheError::new(ErrorCode::InternalError, format!("manifest: {e}")))?
42        };
43        Ok(Self { manifest })
44    }
45}
46
47impl CodegenBackend for ElixirEctoBackend {
48    fn name(&self) -> &str {
49        "elixir-ecto"
50    }
51
52    fn manifest(&self) -> &scythe_backend::manifest::BackendManifest {
53        &self.manifest
54    }
55
56    fn file_header(&self) -> String {
57        "# Auto-generated by scythe. Do not edit.\n\ndefmodule Scythe.Queries do".to_string()
58    }
59
60    fn file_footer(&self) -> String {
61        "end".to_string()
62    }
63
64    fn generate_row_struct(
65        &self,
66        query_name: &str,
67        columns: &[ResolvedColumn],
68    ) -> Result<String, ScytheError> {
69        let struct_name = row_struct_name(query_name, &self.manifest.naming);
70        let mut out = String::new();
71        let _ = writeln!(out, "defmodule {} do", struct_name);
72        let _ = writeln!(out, "  @moduledoc \"Row type for {} queries.\"", query_name);
73        let _ = writeln!(out);
74
75        // Generate typespec
76        let _ = writeln!(out, "  @type t :: %__MODULE__{{");
77        for (i, c) in columns.iter().enumerate() {
78            let sep = if i + 1 < columns.len() { "," } else { "" };
79            let _ = writeln!(out, "    {}: {}{}", c.field_name, c.full_type, sep);
80        }
81        let _ = writeln!(out, "  }}");
82
83        // Generate defstruct
84        let fields = columns
85            .iter()
86            .map(|c| format!(":{}", c.field_name))
87            .collect::<Vec<_>>()
88            .join(", ");
89        let _ = writeln!(out, "  defstruct [{}]", fields);
90        let _ = write!(out, "end");
91        Ok(out)
92    }
93
94    fn generate_model_struct(
95        &self,
96        table_name: &str,
97        columns: &[ResolvedColumn],
98    ) -> Result<String, ScytheError> {
99        let name = to_pascal_case(table_name);
100        self.generate_row_struct(&name, columns)
101    }
102
103    fn generate_query_fn(
104        &self,
105        analyzed: &AnalyzedQuery,
106        struct_name: &str,
107        columns: &[ResolvedColumn],
108        params: &[ResolvedParam],
109    ) -> Result<String, ScytheError> {
110        let func_name = fn_name(&analyzed.name, &self.manifest.naming);
111        let sql = super::clean_sql_with_optional(
112            &analyzed.sql,
113            &analyzed.optional_params,
114            &analyzed.params,
115        );
116        let mut out = String::new();
117
118        // Parameter list
119        let param_list = params
120            .iter()
121            .map(|p| p.field_name.clone())
122            .collect::<Vec<_>>()
123            .join(", ");
124        let sep = if param_list.is_empty() { "" } else { ", " };
125
126        // Build the params array for Ecto.Adapters.SQL.query
127        let param_args = if params.is_empty() {
128            "[]".to_string()
129        } else {
130            format!(
131                "[{}]",
132                params
133                    .iter()
134                    .map(|p| p.field_name.clone())
135                    .collect::<Vec<_>>()
136                    .join(", ")
137            )
138        };
139
140        // Build @spec
141        let param_specs = if params.is_empty() {
142            String::new()
143        } else {
144            let specs: Vec<String> = params.iter().map(|p| p.full_type.clone()).collect();
145            format!(", {}", specs.join(", "))
146        };
147        match &analyzed.command {
148            QueryCommand::One => {
149                let _ = writeln!(
150                    out,
151                    "@spec {}(Ecto.Repo.t(){}) :: {{:ok, %{}{{}} | nil}} | {{:error, term()}}",
152                    func_name, param_specs, struct_name
153                );
154            }
155            QueryCommand::Many => {
156                let _ = writeln!(
157                    out,
158                    "@spec {}(Ecto.Repo.t(){}) :: {{:ok, [%{}{{}}]}} | {{:error, term()}}",
159                    func_name, param_specs, struct_name
160                );
161            }
162            QueryCommand::Batch => {
163                let batch_fn_name = format!("{}_batch", func_name);
164                let _ = writeln!(
165                    out,
166                    "@spec {}(Ecto.Repo.t(), list()) :: :ok | {{:error, term()}}",
167                    batch_fn_name
168                );
169                let _ = writeln!(out, "def {}(repo, items) do", batch_fn_name);
170                let _ = writeln!(out, "  Ecto.Multi.new()");
171                let _ = writeln!(out, "  |> then(fn multi ->");
172                let _ = writeln!(out, "    items");
173                let _ = writeln!(out, "    |> Enum.with_index()");
174                let _ = writeln!(out, "    |> Enum.reduce(multi, fn {{item, idx}}, acc ->");
175                if params.len() > 1 {
176                    let _ = writeln!(
177                        out,
178                        "      Ecto.Multi.run(acc, {{:batch, idx}}, fn repo, _changes -> Ecto.Adapters.SQL.query(repo, \"{}\", Tuple.to_list(item)) end)",
179                        sql
180                    );
181                } else if params.len() == 1 {
182                    let _ = writeln!(
183                        out,
184                        "      Ecto.Multi.run(acc, {{:batch, idx}}, fn repo, _changes -> Ecto.Adapters.SQL.query(repo, \"{}\", [item]) end)",
185                        sql
186                    );
187                } else {
188                    let _ = writeln!(
189                        out,
190                        "      Ecto.Multi.run(acc, {{:batch, idx}}, fn repo, _changes -> Ecto.Adapters.SQL.query(repo, \"{}\", []) end)",
191                        sql
192                    );
193                }
194                let _ = writeln!(out, "    end)");
195                let _ = writeln!(out, "  end)");
196                let _ = writeln!(out, "  |> repo.transaction()");
197                let _ = writeln!(out, "  |> case do");
198                let _ = writeln!(out, "    {{:ok, _}} -> :ok");
199                let _ = writeln!(out, "    {{:error, _, reason, _}} -> {{:error, reason}}");
200                let _ = writeln!(out, "  end");
201                let _ = write!(out, "end");
202                return Ok(out);
203            }
204            QueryCommand::Exec => {
205                let _ = writeln!(
206                    out,
207                    "@spec {}(Ecto.Repo.t(){}) :: :ok | {{:error, term()}}",
208                    func_name, param_specs
209                );
210            }
211            QueryCommand::ExecResult | QueryCommand::ExecRows => {
212                let _ = writeln!(
213                    out,
214                    "@spec {}(Ecto.Repo.t(){}) :: {{:ok, non_neg_integer()}} | {{:error, term()}}",
215                    func_name, param_specs
216                );
217            }
218        }
219        let _ = writeln!(out, "def {}(repo{}{}) do", func_name, sep, param_list);
220
221        match &analyzed.command {
222            QueryCommand::One => {
223                let _ = writeln!(
224                    out,
225                    "  case Ecto.Adapters.SQL.query(repo, \"{}\", {}) do",
226                    sql, param_args
227                );
228                let _ = writeln!(out, "    {{:ok, %{{rows: [row | _]}}}} ->");
229
230                // Destructure row
231                let field_vars = columns
232                    .iter()
233                    .map(|c| c.field_name.clone())
234                    .collect::<Vec<_>>()
235                    .join(", ");
236                let _ = writeln!(out, "      [{}] = row", field_vars);
237
238                // Build struct
239                let struct_fields = columns
240                    .iter()
241                    .map(|c| format!("{}: {}", c.field_name, c.field_name))
242                    .collect::<Vec<_>>()
243                    .join(", ");
244                let _ = writeln!(out, "      {{:ok, %{}{{{}}}}}", struct_name, struct_fields);
245                let _ = writeln!(out, "    {{:ok, %{{rows: []}}}} -> {{:ok, nil}}");
246                let _ = writeln!(out, "    {{:error, err}} -> {{:error, err}}");
247                let _ = writeln!(out, "  end");
248            }
249            QueryCommand::Many => {
250                let _ = writeln!(
251                    out,
252                    "  case Ecto.Adapters.SQL.query(repo, \"{}\", {}) do",
253                    sql, param_args
254                );
255                let _ = writeln!(out, "    {{:ok, %{{rows: rows}}}} ->");
256
257                let field_vars = columns
258                    .iter()
259                    .map(|c| c.field_name.clone())
260                    .collect::<Vec<_>>()
261                    .join(", ");
262                let struct_fields = columns
263                    .iter()
264                    .map(|c| format!("{}: {}", c.field_name, c.field_name))
265                    .collect::<Vec<_>>()
266                    .join(", ");
267
268                let _ = writeln!(out, "      results = Enum.map(rows, fn row ->");
269                let _ = writeln!(out, "        [{}] = row", field_vars);
270                let _ = writeln!(out, "        %{}{{{}}}", struct_name, struct_fields);
271                let _ = writeln!(out, "      end)");
272                let _ = writeln!(out, "      {{:ok, results}}");
273                let _ = writeln!(out, "    {{:error, err}} -> {{:error, err}}");
274                let _ = writeln!(out, "  end");
275            }
276            QueryCommand::Exec => {
277                let _ = writeln!(
278                    out,
279                    "  case Ecto.Adapters.SQL.query(repo, \"{}\", {}) do",
280                    sql, param_args
281                );
282                let _ = writeln!(out, "    {{:ok, _}} -> :ok");
283                let _ = writeln!(out, "    {{:error, err}} -> {{:error, err}}");
284                let _ = writeln!(out, "  end");
285            }
286            QueryCommand::ExecResult | QueryCommand::ExecRows => {
287                let _ = writeln!(
288                    out,
289                    "  case Ecto.Adapters.SQL.query(repo, \"{}\", {}) do",
290                    sql, param_args
291                );
292                let _ = writeln!(out, "    {{:ok, %{{num_rows: n}}}} -> {{:ok, n}}");
293                let _ = writeln!(out, "    {{:error, err}} -> {{:error, err}}");
294                let _ = writeln!(out, "  end");
295            }
296            QueryCommand::Batch => unreachable!(),
297        }
298
299        let _ = write!(out, "end");
300        Ok(out)
301    }
302
303    fn generate_enum_def(&self, enum_info: &EnumInfo) -> Result<String, ScytheError> {
304        let type_name = enum_type_name(&enum_info.sql_name, &self.manifest.naming);
305        let mut out = String::new();
306        let _ = writeln!(out, "defmodule {} do", type_name);
307        let _ = writeln!(
308            out,
309            "  @moduledoc \"Enum type for {}.\"",
310            enum_info.sql_name
311        );
312        let _ = writeln!(out);
313        let _ = writeln!(out, "  @type t :: String.t()");
314        for value in &enum_info.values {
315            let variant = enum_variant_name(value, &self.manifest.naming);
316            let _ = writeln!(
317                out,
318                "  def {}(), do: \"{}\"",
319                to_snake_case(&variant),
320                value
321            );
322        }
323        // values/0 function
324        let values_list = enum_info
325            .values
326            .iter()
327            .map(|v| format!("\"{}\"", v))
328            .collect::<Vec<_>>()
329            .join(", ");
330        let _ = writeln!(out, "  def values, do: [{}]", values_list);
331        let _ = write!(out, "end");
332        Ok(out)
333    }
334
335    fn generate_composite_def(&self, composite: &CompositeInfo) -> Result<String, ScytheError> {
336        let name = to_pascal_case(&composite.sql_name);
337        let mut out = String::new();
338        let _ = writeln!(out, "defmodule {} do", name);
339        let _ = writeln!(
340            out,
341            "  @moduledoc \"Composite type for {}.\"",
342            composite.sql_name
343        );
344        let _ = writeln!(out);
345        if composite.fields.is_empty() {
346            let _ = writeln!(out, "  defstruct []");
347        } else {
348            let fields = composite
349                .fields
350                .iter()
351                .map(|f| format!(":{}", to_snake_case(&f.name)))
352                .collect::<Vec<_>>()
353                .join(", ");
354            let _ = writeln!(out, "  defstruct [{}]", fields);
355        }
356        let _ = write!(out, "end");
357        Ok(out)
358    }
359}