Skip to main content

scythe_codegen/backends/
php_pdo.rs

1use std::fmt::Write;
2use std::path::Path;
3
4use scythe_backend::manifest::{BackendManifest, load_manifest};
5use scythe_backend::naming::{
6    enum_type_name, enum_variant_name, fn_name, row_struct_name, to_pascal_case,
7};
8
9use scythe_core::analyzer::{AnalyzedQuery, CompositeInfo, EnumInfo};
10use scythe_core::errors::{ErrorCode, ScytheError};
11use scythe_core::parser::QueryCommand;
12
13use crate::backend_trait::{CodegenBackend, ResolvedColumn, ResolvedParam};
14
15const DEFAULT_MANIFEST_PG: &str = include_str!("../../manifests/php-pdo.toml");
16const DEFAULT_MANIFEST_MYSQL: &str = include_str!("../../manifests/php-pdo.mysql.toml");
17const DEFAULT_MANIFEST_SQLITE: &str = include_str!("../../manifests/php-pdo.sqlite.toml");
18
19pub struct PhpPdoBackend {
20    manifest: BackendManifest,
21}
22
23impl PhpPdoBackend {
24    pub fn new(engine: &str) -> Result<Self, ScytheError> {
25        let default_toml = match engine {
26            "postgresql" | "postgres" | "pg" => DEFAULT_MANIFEST_PG,
27            "mysql" | "mariadb" => DEFAULT_MANIFEST_MYSQL,
28            "sqlite" | "sqlite3" => DEFAULT_MANIFEST_SQLITE,
29            _ => {
30                return Err(ScytheError::new(
31                    ErrorCode::InternalError,
32                    format!("unsupported engine '{}' for php-pdo backend", engine),
33                ));
34            }
35        };
36        let manifest_path = Path::new("backends/php-pdo/manifest.toml");
37        let manifest = if manifest_path.exists() {
38            load_manifest(manifest_path)
39                .map_err(|e| ScytheError::new(ErrorCode::InternalError, format!("manifest: {e}")))?
40        } else {
41            toml::from_str(default_toml)
42                .map_err(|e| ScytheError::new(ErrorCode::InternalError, format!("manifest: {e}")))?
43        };
44        Ok(Self { manifest })
45    }
46}
47
/// Rewrite PostgreSQL positional parameters (`$1`, `$2`, ...) to PDO named
/// placeholders (`:p1`, `:p2`, ...).
///
/// The SQL is scanned once, so any parameter number is supported. A `$` followed
/// by ASCII digits is rewritten only when:
/// - it is outside a single-quoted string literal (`'...'`) and outside a
///   double-quoted identifier (`"..."`), and
/// - it does not continue an identifier. PostgreSQL allows `$` inside
///   identifiers, so `col$1` is left alone.
///
/// A `$` that is not followed by a digit (for example the `$$` dollar-quote
/// delimiter) is copied through unchanged.
fn rewrite_params(sql: &str) -> String {
    let mut out = String::with_capacity(sql.len() + 8);
    let mut chars = sql.chars().peekable();
    // Quote character of the literal/identifier we are currently inside, if any.
    let mut quote: Option<char> = None;
    // Last character copied to `out`; used to detect identifier continuation.
    let mut prev: Option<char> = None;

    while let Some(ch) = chars.next() {
        if let Some(q) = quote {
            // A doubled quote ('' or "") closes and immediately reopens the
            // literal, which matches SQL's escaping rule, so toggling is enough.
            if ch == q {
                quote = None;
            }
            out.push(ch);
            prev = Some(ch);
            continue;
        }

        match ch {
            '\'' | '"' => {
                quote = Some(ch);
                out.push(ch);
                prev = Some(ch);
            }
            '$' if !matches!(prev, Some(p) if p.is_alphanumeric() || p == '_' || p == '$')
                && matches!(chars.peek(), Some(d) if d.is_ascii_digit()) =>
            {
                out.push_str(":p");
                while let Some(d) = chars.next_if(char::is_ascii_digit) {
                    out.push(d);
                    prev = Some(d);
                }
            }
            _ => {
                out.push(ch);
                prev = Some(ch);
            }
        }
    }
    out
}
59
/// Return the PHP cast prefix (including its trailing space) for a neutral type.
///
/// Integer types map to `(int) `, floats to `(float) `, `bool` to `(bool) `, and
/// the string-like types (`string`, `json`, `inet`, `interval`, `uuid`,
/// `decimal`, `bytes`) to `(string) `. Any other type gets no cast (empty string).
fn php_cast(neutral_type: &str) -> &'static str {
    // Neutral type names grouped by the PHP cast they require.
    const CASTS: [(&[&str], &str); 4] = [
        (&["int16", "int32", "int64"], "(int) "),
        (&["float32", "float64"], "(float) "),
        (&["bool"], "(bool) "),
        (
            &["string", "json", "inet", "interval", "uuid", "decimal", "bytes"],
            "(string) ",
        ),
    ];

    CASTS
        .iter()
        .find(|(names, _)| names.contains(&neutral_type))
        .map_or("", |&(_, cast)| cast)
}
70
71impl CodegenBackend for PhpPdoBackend {
72    fn name(&self) -> &str {
73        "php-pdo"
74    }
75
76    fn manifest(&self) -> &scythe_backend::manifest::BackendManifest {
77        &self.manifest
78    }
79
80    fn supported_engines(&self) -> &[&str] {
81        &["postgresql", "mysql", "sqlite"]
82    }
83
84    fn file_header(&self) -> String {
85        "<?php\n\ndeclare(strict_types=1);\n\nnamespace App\\Generated;\n\n// Auto-generated by scythe. Do not edit.\n"
86            .to_string()
87    }
88
89    fn query_class_header(&self) -> String {
90        "final class Queries {".to_string()
91    }
92
93    fn file_footer(&self) -> String {
94        "}".to_string()
95    }
96
97    fn generate_row_struct(
98        &self,
99        query_name: &str,
100        columns: &[ResolvedColumn],
101    ) -> Result<String, ScytheError> {
102        let struct_name = row_struct_name(query_name, &self.manifest.naming);
103        let mut out = String::new();
104
105        // Readonly class with constructor
106        let _ = writeln!(out, "readonly class {} {{", struct_name);
107        let _ = writeln!(out, "    public function __construct(");
108        for c in columns.iter() {
109            let sep = ",";
110            let _ = writeln!(
111                out,
112                "        public {} ${}{}",
113                c.full_type, c.field_name, sep
114            );
115        }
116        let _ = writeln!(out, "    ) {{}}");
117        let _ = writeln!(out);
118
119        // fromRow factory method
120        let _ = writeln!(
121            out,
122            "    public static function fromRow(array $row): self {{"
123        );
124        let _ = writeln!(out, "        return new self(");
125        for c in columns.iter() {
126            let sep = ",";
127            let is_enum = c.neutral_type.starts_with("enum::");
128            let is_datetime = matches!(
129                c.neutral_type.as_str(),
130                "date" | "time" | "time_tz" | "datetime" | "datetime_tz"
131            );
132            if is_enum {
133                // Enum columns: convert DB string to PHP backed enum via ::from()
134                let enum_type = &c.lang_type;
135                if c.nullable {
136                    let _ = writeln!(
137                        out,
138                        "            {}: $row['{}'] !== null ? {}::from($row['{}']) : null{}",
139                        c.field_name, c.name, enum_type, c.name, sep
140                    );
141                } else {
142                    let _ = writeln!(
143                        out,
144                        "            {}: {}::from($row['{}']){}",
145                        c.field_name, enum_type, c.name, sep
146                    );
147                }
148            } else if is_datetime {
149                // DateTime columns: PDO returns strings, wrap in DateTimeImmutable
150                if c.nullable {
151                    let _ = writeln!(
152                        out,
153                        "            {}: $row['{}'] !== null ? new \\DateTimeImmutable($row['{}']) : null{}",
154                        c.field_name, c.name, c.name, sep
155                    );
156                } else {
157                    let _ = writeln!(
158                        out,
159                        "            {}: new \\DateTimeImmutable($row['{}']){}",
160                        c.field_name, c.name, sep
161                    );
162                }
163            } else {
164                let cast = php_cast(&c.neutral_type);
165                if c.nullable {
166                    let _ = writeln!(
167                        out,
168                        "            {}: $row['{}'] !== null ? {}{} : null{}",
169                        c.field_name,
170                        c.name,
171                        cast,
172                        format_args!("$row['{}']", c.name),
173                        sep
174                    );
175                } else {
176                    let _ = writeln!(
177                        out,
178                        "            {}: {}$row['{}']{}",
179                        c.field_name, cast, c.name, sep
180                    );
181                }
182            }
183        }
184        let _ = writeln!(out, "        );");
185        let _ = writeln!(out, "    }}");
186        let _ = write!(out, "}}");
187        Ok(out)
188    }
189
190    fn generate_model_struct(
191        &self,
192        table_name: &str,
193        columns: &[ResolvedColumn],
194    ) -> Result<String, ScytheError> {
195        let name = to_pascal_case(table_name);
196        self.generate_row_struct(&name, columns)
197    }
198
199    fn generate_query_fn(
200        &self,
201        analyzed: &AnalyzedQuery,
202        struct_name: &str,
203        _columns: &[ResolvedColumn],
204        params: &[ResolvedParam],
205    ) -> Result<String, ScytheError> {
206        let func_name = fn_name(&analyzed.name, &self.manifest.naming);
207        let sql = rewrite_params(&super::clean_sql_oneline(&analyzed.sql));
208        let mut out = String::new();
209
210        // Build PHP parameter list
211        let param_list = params
212            .iter()
213            .map(|p| format!("{} ${}", p.full_type, p.field_name))
214            .collect::<Vec<_>>()
215            .join(", ");
216        let sep = if param_list.is_empty() { "" } else { ", " };
217
218        // Return type depends on command
219        let return_type = match &analyzed.command {
220            QueryCommand::One => format!("?{}", struct_name),
221            QueryCommand::Many | QueryCommand::Batch => "\\Generator".to_string(),
222            QueryCommand::Exec => "void".to_string(),
223            QueryCommand::ExecResult | QueryCommand::ExecRows => "int".to_string(),
224        };
225
226        let _ = writeln!(
227            out,
228            "    public static function {}(\\PDO $pdo{}{}): {} {{",
229            func_name, sep, param_list, return_type
230        );
231
232        // Prepare statement
233        let _ = writeln!(out, "        $stmt = $pdo->prepare(\"{}\");", sql);
234
235        // Build execute params
236        // If the SQL contains `?` placeholders (MySQL/SQLite), use positional array.
237        // If it contains `:pN` placeholders (PostgreSQL), use named array.
238        if params.is_empty() {
239            let _ = writeln!(out, "        $stmt->execute();");
240        } else {
241            let use_positional = sql.contains('?');
242            let bindings = params
243                .iter()
244                .enumerate()
245                .map(|(i, p)| {
246                    let value = if p.neutral_type.starts_with("enum::") {
247                        format!("${}->value", p.field_name)
248                    } else {
249                        format!("${}", p.field_name)
250                    };
251                    if use_positional {
252                        value
253                    } else {
254                        format!("\"p{}\" => {}", i + 1, value)
255                    }
256                })
257                .collect::<Vec<_>>()
258                .join(", ");
259            let _ = writeln!(out, "        $stmt->execute([{}]);", bindings);
260        }
261
262        match &analyzed.command {
263            QueryCommand::One => {
264                let _ = writeln!(out, "        $row = $stmt->fetch(\\PDO::FETCH_ASSOC);");
265                let _ = writeln!(
266                    out,
267                    "        return $row ? {}::fromRow($row) : null;",
268                    struct_name
269                );
270            }
271            QueryCommand::Many | QueryCommand::Batch => {
272                let _ = writeln!(
273                    out,
274                    "        while ($row = $stmt->fetch(\\PDO::FETCH_ASSOC)) {{"
275                );
276                let _ = writeln!(out, "            yield {}::fromRow($row);", struct_name);
277                let _ = writeln!(out, "        }}");
278            }
279            QueryCommand::Exec => {
280                // nothing else needed
281            }
282            QueryCommand::ExecResult | QueryCommand::ExecRows => {
283                let _ = writeln!(out, "        return $stmt->rowCount();");
284            }
285        }
286
287        let _ = write!(out, "    }}");
288        Ok(out)
289    }
290
291    fn generate_enum_def(&self, enum_info: &EnumInfo) -> Result<String, ScytheError> {
292        let type_name = enum_type_name(&enum_info.sql_name, &self.manifest.naming);
293        let mut out = String::new();
294        let _ = writeln!(out, "enum {}: string {{", type_name);
295        for value in &enum_info.values {
296            let variant = enum_variant_name(value, &self.manifest.naming);
297            let _ = writeln!(out, "    case {} = \"{}\";", variant, value);
298        }
299        let _ = write!(out, "}}");
300        Ok(out)
301    }
302
303    fn generate_composite_def(&self, composite: &CompositeInfo) -> Result<String, ScytheError> {
304        let name = to_pascal_case(&composite.sql_name);
305        let mut out = String::new();
306        let _ = writeln!(out, "readonly class {} {{", name);
307        let _ = writeln!(out, "    public function __construct(");
308        if composite.fields.is_empty() {
309            // empty constructor
310        } else {
311            for field in &composite.fields {
312                let _ = writeln!(out, "        public mixed ${},", field.name);
313            }
314        }
315        let _ = writeln!(out, "    ) {{}}");
316        let _ = write!(out, "}}");
317        Ok(out)
318    }
319}