Skip to main content

scythe_codegen/backends/
ruby_pg.rs

1use std::fmt::Write;
2use std::path::Path;
3
4use scythe_backend::manifest::{BackendManifest, load_manifest};
5use scythe_backend::naming::{
6    enum_type_name, enum_variant_name, fn_name, row_struct_name, to_pascal_case,
7};
8
9use scythe_core::analyzer::{AnalyzedQuery, CompositeInfo, EnumInfo};
10use scythe_core::errors::{ErrorCode, ScytheError};
11use scythe_core::parser::QueryCommand;
12
13use crate::backend_trait::{CodegenBackend, ResolvedColumn, ResolvedParam};
14
15const DEFAULT_MANIFEST_TOML: &str = include_str!("../../manifests/ruby-pg.toml");
16
17pub struct RubyPgBackend {
18    manifest: BackendManifest,
19}
20
21impl RubyPgBackend {
22    pub fn new(engine: &str) -> Result<Self, ScytheError> {
23        match engine {
24            "postgresql" | "postgres" | "pg" => {}
25            _ => {
26                return Err(ScytheError::new(
27                    ErrorCode::InternalError,
28                    format!("ruby-pg only supports PostgreSQL, got engine '{}'", engine),
29                ));
30            }
31        }
32        let manifest_path = Path::new("backends/ruby-pg/manifest.toml");
33        let manifest = if manifest_path.exists() {
34            load_manifest(manifest_path)
35                .map_err(|e| ScytheError::new(ErrorCode::InternalError, format!("manifest: {e}")))?
36        } else {
37            toml::from_str(DEFAULT_MANIFEST_TOML)
38                .map_err(|e| ScytheError::new(ErrorCode::InternalError, format!("manifest: {e}")))?
39        };
40        Ok(Self { manifest })
41    }
42}
43
44/// Map a neutral type to a Ruby type coercion method.
45fn ruby_coercion(neutral_type: &str) -> &'static str {
46    match neutral_type {
47        "int16" | "int32" | "int64" => ".to_i",
48        "float32" | "float64" => ".to_f",
49        "bool" => " == \"t\"",
50        _ => "",
51    }
52}
53
54impl CodegenBackend for RubyPgBackend {
55    fn name(&self) -> &str {
56        "ruby-pg"
57    }
58
59    fn manifest(&self) -> &scythe_backend::manifest::BackendManifest {
60        &self.manifest
61    }
62
63    fn file_header(&self) -> String {
64        "# frozen_string_literal: true\n\n# Auto-generated by scythe. Do not edit.\n".to_string()
65    }
66
67    fn generate_row_struct(
68        &self,
69        query_name: &str,
70        columns: &[ResolvedColumn],
71    ) -> Result<String, ScytheError> {
72        let struct_name = row_struct_name(query_name, &self.manifest.naming);
73        let fields = columns
74            .iter()
75            .map(|c| format!(":{}", c.field_name))
76            .collect::<Vec<_>>()
77            .join(", ");
78        let mut out = String::new();
79        let _ = writeln!(out, "{} = Data.define({})", struct_name, fields);
80        Ok(out)
81    }
82
83    fn generate_model_struct(
84        &self,
85        table_name: &str,
86        columns: &[ResolvedColumn],
87    ) -> Result<String, ScytheError> {
88        let name = to_pascal_case(table_name);
89        self.generate_row_struct(&name, columns)
90    }
91
92    fn generate_query_fn(
93        &self,
94        analyzed: &AnalyzedQuery,
95        struct_name: &str,
96        columns: &[ResolvedColumn],
97        params: &[ResolvedParam],
98    ) -> Result<String, ScytheError> {
99        let func_name = fn_name(&analyzed.name, &self.manifest.naming);
100        let sql = super::clean_sql(&analyzed.sql);
101        let mut out = String::new();
102
103        // Parameter list
104        let param_list = params
105            .iter()
106            .map(|p| p.field_name.clone())
107            .collect::<Vec<_>>()
108            .join(", ");
109        let sep = if param_list.is_empty() { "" } else { ", " };
110
111        let _ = writeln!(out, "def self.{}(conn{}{})", func_name, sep, param_list);
112
113        // Build exec_params call
114        let param_array = if params.is_empty() {
115            "[]".to_string()
116        } else {
117            format!(
118                "[{}]",
119                params
120                    .iter()
121                    .map(|p| p.field_name.clone())
122                    .collect::<Vec<_>>()
123                    .join(", ")
124            )
125        };
126
127        match &analyzed.command {
128            QueryCommand::One => {
129                let _ = writeln!(
130                    out,
131                    "  result = conn.exec_params(\"{}\", {})",
132                    sql, param_array
133                );
134                let _ = writeln!(out, "  return nil if result.ntuples.zero?");
135                let _ = writeln!(out, "  row = result[0]");
136
137                // Build struct constructor
138                let fields = columns
139                    .iter()
140                    .map(|c| {
141                        let coercion = ruby_coercion(&c.neutral_type);
142                        if c.nullable {
143                            format!(
144                                "{}: row[\"{}\"]&.then {{ |v| v{} }}",
145                                c.field_name, c.name, coercion
146                            )
147                        } else {
148                            format!("{}: row[\"{}\"]{}", c.field_name, c.name, coercion)
149                        }
150                    })
151                    .collect::<Vec<_>>()
152                    .join(", ");
153                let _ = writeln!(out, "  {}.new({})", struct_name, fields);
154            }
155            QueryCommand::Many | QueryCommand::Batch => {
156                let _ = writeln!(
157                    out,
158                    "  result = conn.exec_params(\"{}\", {})",
159                    sql, param_array
160                );
161                let _ = writeln!(out, "  result.map do |row|");
162                let fields = columns
163                    .iter()
164                    .map(|c| {
165                        let coercion = ruby_coercion(&c.neutral_type);
166                        if c.nullable {
167                            format!(
168                                "{}: row[\"{}\"]&.then {{ |v| v{} }}",
169                                c.field_name, c.name, coercion
170                            )
171                        } else {
172                            format!("{}: row[\"{}\"]{}", c.field_name, c.name, coercion)
173                        }
174                    })
175                    .collect::<Vec<_>>()
176                    .join(", ");
177                let _ = writeln!(out, "    {}.new({})", struct_name, fields);
178                let _ = writeln!(out, "  end");
179            }
180            QueryCommand::Exec => {
181                let _ = writeln!(out, "  conn.exec_params(\"{}\", {})", sql, param_array);
182                let _ = writeln!(out, "  nil");
183            }
184            QueryCommand::ExecResult | QueryCommand::ExecRows => {
185                let _ = writeln!(
186                    out,
187                    "  result = conn.exec_params(\"{}\", {})",
188                    sql, param_array
189                );
190                let _ = writeln!(out, "  result.cmd_tuples.to_i");
191            }
192        }
193
194        let _ = write!(out, "end");
195        Ok(out)
196    }
197
198    fn generate_enum_def(&self, enum_info: &EnumInfo) -> Result<String, ScytheError> {
199        let type_name = enum_type_name(&enum_info.sql_name, &self.manifest.naming);
200        let mut out = String::new();
201        let _ = writeln!(out, "module {}", type_name);
202        for value in &enum_info.values {
203            let variant = enum_variant_name(value, &self.manifest.naming);
204            let _ = writeln!(out, "  {} = \"{}\"", variant, value);
205        }
206        // ALL constant
207        let all_values = enum_info
208            .values
209            .iter()
210            .map(|v| enum_variant_name(v, &self.manifest.naming))
211            .collect::<Vec<_>>()
212            .join(", ");
213        let _ = writeln!(out, "  ALL = [{}].freeze", all_values);
214        let _ = write!(out, "end");
215        Ok(out)
216    }
217
218    fn generate_composite_def(&self, composite: &CompositeInfo) -> Result<String, ScytheError> {
219        let name = to_pascal_case(&composite.sql_name);
220        let mut out = String::new();
221        if composite.fields.is_empty() {
222            let _ = writeln!(out, "{} = Data.define()", name);
223        } else {
224            let fields = composite
225                .fields
226                .iter()
227                .map(|f| format!(":{}", f.name))
228                .collect::<Vec<_>>()
229                .join(", ");
230            let _ = writeln!(out, "{} = Data.define({})", name, fields);
231        }
232        Ok(out)
233    }
234}