Skip to main content

scythe_codegen/backends/
elixir_postgrex.rs

1use std::fmt::Write;
2use std::path::Path;
3
4use scythe_backend::manifest::{BackendManifest, load_manifest};
5use scythe_backend::naming::{
6    enum_type_name, enum_variant_name, fn_name, row_struct_name, to_pascal_case, to_snake_case,
7};
8
9use scythe_core::analyzer::{AnalyzedQuery, CompositeInfo, EnumInfo};
10use scythe_core::errors::{ErrorCode, ScytheError};
11use scythe_core::parser::QueryCommand;
12
13use crate::backend_trait::{CodegenBackend, ResolvedColumn, ResolvedParam};
14
15const DEFAULT_MANIFEST_TOML: &str = include_str!("../../manifests/elixir-postgrex.toml");
16
17pub struct ElixirPostgrexBackend {
18    manifest: BackendManifest,
19}
20
21impl ElixirPostgrexBackend {
22    pub fn new() -> Result<Self, ScytheError> {
23        let manifest_path = Path::new("backends/elixir-postgrex/manifest.toml");
24        let manifest = if manifest_path.exists() {
25            load_manifest(manifest_path)
26                .map_err(|e| ScytheError::new(ErrorCode::InternalError, format!("manifest: {e}")))?
27        } else {
28            toml::from_str(DEFAULT_MANIFEST_TOML)
29                .map_err(|e| ScytheError::new(ErrorCode::InternalError, format!("manifest: {e}")))?
30        };
31        Ok(Self { manifest })
32    }
33
34    pub fn manifest(&self) -> &BackendManifest {
35        &self.manifest
36    }
37}
38
39impl CodegenBackend for ElixirPostgrexBackend {
40    fn name(&self) -> &str {
41        "elixir-postgrex"
42    }
43
44    fn generate_row_struct(
45        &self,
46        query_name: &str,
47        columns: &[ResolvedColumn],
48    ) -> Result<String, ScytheError> {
49        let struct_name = row_struct_name(query_name, &self.manifest.naming);
50        let mut out = String::new();
51        let _ = writeln!(out, "defmodule {} do", struct_name);
52        let _ = writeln!(out, "  @moduledoc \"Row type for {} queries.\"", query_name);
53        let _ = writeln!(out);
54
55        // Generate typespec
56        let _ = writeln!(out, "  @type t :: %__MODULE__{{");
57        for (i, c) in columns.iter().enumerate() {
58            let sep = if i + 1 < columns.len() { "," } else { "" };
59            let _ = writeln!(out, "    {}: {}{}", c.field_name, c.full_type, sep);
60        }
61        let _ = writeln!(out, "  }}");
62
63        // Generate defstruct
64        let fields = columns
65            .iter()
66            .map(|c| format!(":{}", c.field_name))
67            .collect::<Vec<_>>()
68            .join(", ");
69        let _ = writeln!(out, "  defstruct [{}]", fields);
70        let _ = write!(out, "end");
71        Ok(out)
72    }
73
74    fn generate_model_struct(
75        &self,
76        table_name: &str,
77        columns: &[ResolvedColumn],
78    ) -> Result<String, ScytheError> {
79        let name = to_pascal_case(table_name);
80        self.generate_row_struct(&name, columns)
81    }
82
83    fn generate_query_fn(
84        &self,
85        analyzed: &AnalyzedQuery,
86        struct_name: &str,
87        columns: &[ResolvedColumn],
88        params: &[ResolvedParam],
89    ) -> Result<String, ScytheError> {
90        let func_name = fn_name(&analyzed.name, &self.manifest.naming);
91        let sql = super::clean_sql(&analyzed.sql);
92        let mut out = String::new();
93
94        // Parameter list
95        let param_list = params
96            .iter()
97            .map(|p| p.field_name.clone())
98            .collect::<Vec<_>>()
99            .join(", ");
100        let sep = if param_list.is_empty() { "" } else { ", " };
101
102        // Build the params array for Postgrex.query
103        let param_args = if params.is_empty() {
104            "[]".to_string()
105        } else {
106            format!(
107                "[{}]",
108                params
109                    .iter()
110                    .map(|p| p.field_name.clone())
111                    .collect::<Vec<_>>()
112                    .join(", ")
113            )
114        };
115
116        // Build @spec
117        let param_specs = if params.is_empty() {
118            String::new()
119        } else {
120            let specs: Vec<String> = params.iter().map(|p| p.full_type.clone()).collect();
121            format!(", {}", specs.join(", "))
122        };
123        match &analyzed.command {
124            QueryCommand::One => {
125                let _ = writeln!(
126                    out,
127                    "@spec {}(pid(){}) :: {{:ok, %{}{{}}}} | {{:error, term()}}",
128                    func_name, param_specs, struct_name
129                );
130            }
131            QueryCommand::Many | QueryCommand::Batch => {
132                let _ = writeln!(
133                    out,
134                    "@spec {}(pid(){}) :: {{:ok, [%{}{{}}]}} | {{:error, term()}}",
135                    func_name, param_specs, struct_name
136                );
137            }
138            QueryCommand::Exec => {
139                let _ = writeln!(
140                    out,
141                    "@spec {}(pid(){}) :: :ok | {{:error, term()}}",
142                    func_name, param_specs
143                );
144            }
145            QueryCommand::ExecResult | QueryCommand::ExecRows => {
146                let _ = writeln!(
147                    out,
148                    "@spec {}(pid(){}) :: {{:ok, non_neg_integer()}} | {{:error, term()}}",
149                    func_name, param_specs
150                );
151            }
152        }
153        let _ = writeln!(out, "def {}(conn{}{}) do", func_name, sep, param_list);
154
155        match &analyzed.command {
156            QueryCommand::One => {
157                let _ = writeln!(
158                    out,
159                    "  case Postgrex.query(conn, \"{}\", {}) do",
160                    sql, param_args
161                );
162                let _ = writeln!(out, "    {{:ok, %{{rows: [row]}}}} ->");
163
164                // Destructure row
165                let field_vars = columns
166                    .iter()
167                    .map(|c| c.field_name.clone())
168                    .collect::<Vec<_>>()
169                    .join(", ");
170                let _ = writeln!(out, "      [{}] = row", field_vars);
171
172                // Build struct
173                let struct_fields = columns
174                    .iter()
175                    .map(|c| format!("{}: {}", c.field_name, c.field_name))
176                    .collect::<Vec<_>>()
177                    .join(", ");
178                let _ = writeln!(out, "      {{:ok, %{}{{{}}}}}", struct_name, struct_fields);
179                let _ = writeln!(out, "    {{:ok, %{{rows: []}}}} -> {{:error, :not_found}}");
180                let _ = writeln!(out, "    {{:error, err}} -> {{:error, err}}");
181                let _ = writeln!(out, "  end");
182            }
183            QueryCommand::Many | QueryCommand::Batch => {
184                let _ = writeln!(
185                    out,
186                    "  case Postgrex.query(conn, \"{}\", {}) do",
187                    sql, param_args
188                );
189                let _ = writeln!(out, "    {{:ok, %{{rows: rows}}}} ->");
190
191                let field_vars = columns
192                    .iter()
193                    .map(|c| c.field_name.clone())
194                    .collect::<Vec<_>>()
195                    .join(", ");
196                let struct_fields = columns
197                    .iter()
198                    .map(|c| format!("{}: {}", c.field_name, c.field_name))
199                    .collect::<Vec<_>>()
200                    .join(", ");
201
202                let _ = writeln!(out, "      results = Enum.map(rows, fn row ->");
203                let _ = writeln!(out, "        [{}] = row", field_vars);
204                let _ = writeln!(out, "        %{}{{{}}}", struct_name, struct_fields);
205                let _ = writeln!(out, "      end)");
206                let _ = writeln!(out, "      {{:ok, results}}");
207                let _ = writeln!(out, "    {{:error, err}} -> {{:error, err}}");
208                let _ = writeln!(out, "  end");
209            }
210            QueryCommand::Exec => {
211                let _ = writeln!(
212                    out,
213                    "  case Postgrex.query(conn, \"{}\", {}) do",
214                    sql, param_args
215                );
216                let _ = writeln!(out, "    {{:ok, _}} -> :ok");
217                let _ = writeln!(out, "    {{:error, err}} -> {{:error, err}}");
218                let _ = writeln!(out, "  end");
219            }
220            QueryCommand::ExecResult | QueryCommand::ExecRows => {
221                let _ = writeln!(
222                    out,
223                    "  case Postgrex.query(conn, \"{}\", {}) do",
224                    sql, param_args
225                );
226                let _ = writeln!(out, "    {{:ok, %{{num_rows: n}}}} -> {{:ok, n}}");
227                let _ = writeln!(out, "    {{:error, err}} -> {{:error, err}}");
228                let _ = writeln!(out, "  end");
229            }
230        }
231
232        let _ = write!(out, "end");
233        Ok(out)
234    }
235
236    fn generate_enum_def(&self, enum_info: &EnumInfo) -> Result<String, ScytheError> {
237        let type_name = enum_type_name(&enum_info.sql_name, &self.manifest.naming);
238        let mut out = String::new();
239        let _ = writeln!(out, "defmodule {} do", type_name);
240        let _ = writeln!(
241            out,
242            "  @moduledoc \"Enum type for {}.\"",
243            enum_info.sql_name
244        );
245        let _ = writeln!(out);
246        let _ = writeln!(out, "  @type t :: String.t()");
247        for value in &enum_info.values {
248            let variant = enum_variant_name(value, &self.manifest.naming);
249            let _ = writeln!(
250                out,
251                "  def {}(), do: \"{}\"",
252                to_snake_case(&variant),
253                value
254            );
255        }
256        // values/0 function
257        let values_list = enum_info
258            .values
259            .iter()
260            .map(|v| format!("\"{}\"", v))
261            .collect::<Vec<_>>()
262            .join(", ");
263        let _ = writeln!(out, "  def values, do: [{}]", values_list);
264        let _ = write!(out, "end");
265        Ok(out)
266    }
267
268    fn generate_composite_def(&self, composite: &CompositeInfo) -> Result<String, ScytheError> {
269        let name = to_pascal_case(&composite.sql_name);
270        let mut out = String::new();
271        let _ = writeln!(out, "defmodule {} do", name);
272        let _ = writeln!(
273            out,
274            "  @moduledoc \"Composite type for {}.\"",
275            composite.sql_name
276        );
277        let _ = writeln!(out);
278        if composite.fields.is_empty() {
279            let _ = writeln!(out, "  defstruct []");
280        } else {
281            let fields = composite
282                .fields
283                .iter()
284                .map(|f| format!(":{}", to_snake_case(&f.name)))
285                .collect::<Vec<_>>()
286                .join(", ");
287            let _ = writeln!(out, "  defstruct [{}]", fields);
288        }
289        let _ = write!(out, "end");
290        Ok(out)
291    }
292}