Skip to main content

scythe_codegen/backends/
ruby_pg.rs

1use std::fmt::Write;
2use std::path::Path;
3
4use scythe_backend::manifest::{BackendManifest, load_manifest};
5use scythe_backend::naming::{
6    enum_type_name, enum_variant_name, fn_name, row_struct_name, to_pascal_case,
7};
8
9use scythe_core::analyzer::{AnalyzedQuery, CompositeInfo, EnumInfo};
10use scythe_core::errors::{ErrorCode, ScytheError};
11use scythe_core::parser::QueryCommand;
12
13use crate::backend_trait::{CodegenBackend, ResolvedColumn, ResolvedParam};
14
15const DEFAULT_MANIFEST_TOML: &str = include_str!("../../manifests/ruby-pg.toml");
16
17pub struct RubyPgBackend {
18    manifest: BackendManifest,
19}
20
21impl RubyPgBackend {
22    pub fn new() -> Result<Self, ScytheError> {
23        let manifest_path = Path::new("backends/ruby-pg/manifest.toml");
24        let manifest = if manifest_path.exists() {
25            load_manifest(manifest_path)
26                .map_err(|e| ScytheError::new(ErrorCode::InternalError, format!("manifest: {e}")))?
27        } else {
28            toml::from_str(DEFAULT_MANIFEST_TOML)
29                .map_err(|e| ScytheError::new(ErrorCode::InternalError, format!("manifest: {e}")))?
30        };
31        Ok(Self { manifest })
32    }
33
34    pub fn manifest(&self) -> &BackendManifest {
35        &self.manifest
36    }
37}
38
/// Returns the Ruby suffix that converts a column's text value to a native Ruby value.
///
/// - Integer neutral types yield `.to_i`.
/// - Float types yield `.to_f`.
/// - `bool` yields a comparison against `"t"`, PostgreSQL's text form of `true`.
/// - Every other type yields `""`, which leaves the value untouched.
fn ruby_coercion(neutral_type: &str) -> &'static str {
    const INTEGER_TYPES: [&str; 3] = ["int16", "int32", "int64"];
    const FLOAT_TYPES: [&str; 2] = ["float32", "float64"];

    if INTEGER_TYPES.contains(&neutral_type) {
        ".to_i"
    } else if FLOAT_TYPES.contains(&neutral_type) {
        ".to_f"
    } else if neutral_type == "bool" {
        " == \"t\""
    } else {
        ""
    }
}
48
49impl CodegenBackend for RubyPgBackend {
50    fn name(&self) -> &str {
51        "ruby-pg"
52    }
53
54    fn file_header(&self) -> String {
55        "# frozen_string_literal: true\n\n# Auto-generated by scythe. Do not edit.\n".to_string()
56    }
57
58    fn generate_row_struct(
59        &self,
60        query_name: &str,
61        columns: &[ResolvedColumn],
62    ) -> Result<String, ScytheError> {
63        let struct_name = row_struct_name(query_name, &self.manifest.naming);
64        let fields = columns
65            .iter()
66            .map(|c| format!(":{}", c.field_name))
67            .collect::<Vec<_>>()
68            .join(", ");
69        let mut out = String::new();
70        let _ = writeln!(out, "{} = Data.define({})", struct_name, fields);
71        Ok(out)
72    }
73
74    fn generate_model_struct(
75        &self,
76        table_name: &str,
77        columns: &[ResolvedColumn],
78    ) -> Result<String, ScytheError> {
79        let name = to_pascal_case(table_name);
80        self.generate_row_struct(&name, columns)
81    }
82
83    fn generate_query_fn(
84        &self,
85        analyzed: &AnalyzedQuery,
86        struct_name: &str,
87        columns: &[ResolvedColumn],
88        params: &[ResolvedParam],
89    ) -> Result<String, ScytheError> {
90        let func_name = fn_name(&analyzed.name, &self.manifest.naming);
91        let sql = super::clean_sql(&analyzed.sql);
92        let mut out = String::new();
93
94        // Parameter list
95        let param_list = params
96            .iter()
97            .map(|p| p.field_name.clone())
98            .collect::<Vec<_>>()
99            .join(", ");
100        let sep = if param_list.is_empty() { "" } else { ", " };
101
102        let _ = writeln!(out, "def self.{}(conn{}{})", func_name, sep, param_list);
103
104        // Build exec_params call
105        let param_array = if params.is_empty() {
106            "[]".to_string()
107        } else {
108            format!(
109                "[{}]",
110                params
111                    .iter()
112                    .map(|p| p.field_name.clone())
113                    .collect::<Vec<_>>()
114                    .join(", ")
115            )
116        };
117
118        match &analyzed.command {
119            QueryCommand::One => {
120                let _ = writeln!(
121                    out,
122                    "  result = conn.exec_params(\"{}\", {})",
123                    sql, param_array
124                );
125                let _ = writeln!(out, "  return nil if result.ntuples.zero?");
126                let _ = writeln!(out, "  row = result[0]");
127
128                // Build struct constructor
129                let fields = columns
130                    .iter()
131                    .map(|c| {
132                        let coercion = ruby_coercion(&c.neutral_type);
133                        if c.nullable {
134                            format!(
135                                "{}: row[\"{}\"]&.then {{ |v| v{} }}",
136                                c.field_name, c.name, coercion
137                            )
138                        } else {
139                            format!("{}: row[\"{}\"]{}", c.field_name, c.name, coercion)
140                        }
141                    })
142                    .collect::<Vec<_>>()
143                    .join(", ");
144                let _ = writeln!(out, "  {}.new({})", struct_name, fields);
145            }
146            QueryCommand::Many | QueryCommand::Batch => {
147                let _ = writeln!(
148                    out,
149                    "  result = conn.exec_params(\"{}\", {})",
150                    sql, param_array
151                );
152                let _ = writeln!(out, "  result.map do |row|");
153                let fields = columns
154                    .iter()
155                    .map(|c| {
156                        let coercion = ruby_coercion(&c.neutral_type);
157                        if c.nullable {
158                            format!(
159                                "{}: row[\"{}\"]&.then {{ |v| v{} }}",
160                                c.field_name, c.name, coercion
161                            )
162                        } else {
163                            format!("{}: row[\"{}\"]{}", c.field_name, c.name, coercion)
164                        }
165                    })
166                    .collect::<Vec<_>>()
167                    .join(", ");
168                let _ = writeln!(out, "    {}.new({})", struct_name, fields);
169                let _ = writeln!(out, "  end");
170            }
171            QueryCommand::Exec => {
172                let _ = writeln!(out, "  conn.exec_params(\"{}\", {})", sql, param_array);
173                let _ = writeln!(out, "  nil");
174            }
175            QueryCommand::ExecResult | QueryCommand::ExecRows => {
176                let _ = writeln!(
177                    out,
178                    "  result = conn.exec_params(\"{}\", {})",
179                    sql, param_array
180                );
181                let _ = writeln!(out, "  result.cmd_tuples.to_i");
182            }
183        }
184
185        let _ = write!(out, "end");
186        Ok(out)
187    }
188
189    fn generate_enum_def(&self, enum_info: &EnumInfo) -> Result<String, ScytheError> {
190        let type_name = enum_type_name(&enum_info.sql_name, &self.manifest.naming);
191        let mut out = String::new();
192        let _ = writeln!(out, "module {}", type_name);
193        for value in &enum_info.values {
194            let variant = enum_variant_name(value, &self.manifest.naming);
195            let _ = writeln!(out, "  {} = \"{}\"", variant, value);
196        }
197        // ALL constant
198        let all_values = enum_info
199            .values
200            .iter()
201            .map(|v| enum_variant_name(v, &self.manifest.naming))
202            .collect::<Vec<_>>()
203            .join(", ");
204        let _ = writeln!(out, "  ALL = [{}].freeze", all_values);
205        let _ = write!(out, "end");
206        Ok(out)
207    }
208
209    fn generate_composite_def(&self, composite: &CompositeInfo) -> Result<String, ScytheError> {
210        let name = to_pascal_case(&composite.sql_name);
211        let mut out = String::new();
212        if composite.fields.is_empty() {
213            let _ = writeln!(out, "{} = Data.define()", name);
214        } else {
215            let fields = composite
216                .fields
217                .iter()
218                .map(|f| format!(":{}", f.name))
219                .collect::<Vec<_>>()
220                .join(", ");
221            let _ = writeln!(out, "{} = Data.define({})", name, fields);
222        }
223        Ok(out)
224    }
225}