Skip to main content

scythe_codegen/backends/
php_pdo.rs

1use std::fmt::Write;
2use std::path::Path;
3
4use scythe_backend::manifest::{BackendManifest, load_manifest};
5use scythe_backend::naming::{
6    enum_type_name, enum_variant_name, fn_name, row_struct_name, to_pascal_case,
7};
8
9use scythe_core::analyzer::{AnalyzedQuery, CompositeInfo, EnumInfo};
10use scythe_core::errors::{ErrorCode, ScytheError};
11use scythe_core::parser::QueryCommand;
12
13use crate::backend_trait::{CodegenBackend, ResolvedColumn, ResolvedParam};
14
// Manifests embedded into the binary at compile time, one per database engine.
// `PhpPdoBackend::new` parses the one matching the requested engine when no
// on-disk manifest is found.
const DEFAULT_MANIFEST_PG: &str = include_str!("../../manifests/php-pdo.toml");
const DEFAULT_MANIFEST_MYSQL: &str = include_str!("../../manifests/php-pdo.mysql.toml");
const DEFAULT_MANIFEST_SQLITE: &str = include_str!("../../manifests/php-pdo.sqlite.toml");
18
/// Code generation backend that emits PHP source using PDO.
pub struct PhpPdoBackend {
    /// Manifest selected in `new`; exposed through `CodegenBackend::manifest`
    /// and used for its naming rules when generating identifiers.
    manifest: BackendManifest,
}
22
23impl PhpPdoBackend {
24    pub fn new(engine: &str) -> Result<Self, ScytheError> {
25        let default_toml = match engine {
26            "postgresql" | "postgres" | "pg" => DEFAULT_MANIFEST_PG,
27            "mysql" | "mariadb" => DEFAULT_MANIFEST_MYSQL,
28            "sqlite" | "sqlite3" => DEFAULT_MANIFEST_SQLITE,
29            _ => {
30                return Err(ScytheError::new(
31                    ErrorCode::InternalError,
32                    format!("unsupported engine '{}' for php-pdo backend", engine),
33                ));
34            }
35        };
36        let manifest_path = Path::new("backends/php-pdo/manifest.toml");
37        let manifest = if manifest_path.exists() {
38            load_manifest(manifest_path)
39                .map_err(|e| ScytheError::new(ErrorCode::InternalError, format!("manifest: {e}")))?
40        } else {
41            toml::from_str(default_toml)
42                .map_err(|e| ScytheError::new(ErrorCode::InternalError, format!("manifest: {e}")))?
43        };
44        Ok(Self { manifest })
45    }
46}
47
/// Rewrite PostgreSQL-style numbered placeholders (`$1`, `$2`, ...) into the
/// named form `:p1`, `:p2`, ...
///
/// Every `$` that is immediately followed by a digit `1`-`9` is replaced by
/// `:p`; the digits themselves are copied through unchanged, so multi-digit
/// indices such as `$10` or `$123` become `:p10` / `:p123`. A `$` followed by
/// `0`, by a non-digit, or at the end of the input is left as-is.
///
/// The rewrite is purely textual: it does not parse SQL, so a `$N` sequence
/// inside a string literal or comment is rewritten as well.
fn rewrite_params(sql: &str) -> String {
    // Single pass; each rewritten placeholder grows the text by one byte.
    let mut result = String::with_capacity(sql.len() + 8);
    let mut chars = sql.chars().peekable();
    while let Some(ch) = chars.next() {
        let starts_placeholder =
            ch == '$' && matches!(chars.peek().copied(), Some('1'..='9'));
        if starts_placeholder {
            result.push_str(":p");
        } else {
            result.push(ch);
        }
    }
    result
}
59
/// Return the PHP cast prefix (including its trailing space) to apply to a raw
/// row value of the given neutral type, or `""` when no cast is applied.
fn php_cast(neutral_type: &str) -> &'static str {
    // Each entry pairs a group of neutral type names with their shared cast.
    const CASTS: &[(&[&str], &str)] = &[
        (&["int16", "int32", "int64"], "(int) "),
        (&["float32", "float64"], "(float) "),
        (&["bool"], "(bool) "),
        (
            &["string", "json", "inet", "interval", "uuid", "decimal", "bytes"],
            "(string) ",
        ),
    ];

    CASTS
        .iter()
        .find(|(names, _)| names.iter().any(|name| *name == neutral_type))
        .map_or("", |&(_, cast)| cast)
}
70
71impl CodegenBackend for PhpPdoBackend {
72    fn name(&self) -> &str {
73        "php-pdo"
74    }
75
76    fn manifest(&self) -> &scythe_backend::manifest::BackendManifest {
77        &self.manifest
78    }
79
80    fn supported_engines(&self) -> &[&str] {
81        &["postgresql", "mysql", "sqlite"]
82    }
83
84    fn file_header(&self) -> String {
85        "<?php\n\ndeclare(strict_types=1);\n\nnamespace App\\Generated;\n\n// Auto-generated by scythe. Do not edit.\n"
86            .to_string()
87    }
88
89    fn query_class_header(&self) -> String {
90        "final class Queries {".to_string()
91    }
92
93    fn file_footer(&self) -> String {
94        "}".to_string()
95    }
96
97    fn generate_row_struct(
98        &self,
99        query_name: &str,
100        columns: &[ResolvedColumn],
101    ) -> Result<String, ScytheError> {
102        let struct_name = row_struct_name(query_name, &self.manifest.naming);
103        let mut out = String::new();
104
105        // Readonly class with constructor
106        let _ = writeln!(out, "readonly class {} {{", struct_name);
107        let _ = writeln!(out, "    public function __construct(");
108        for c in columns.iter() {
109            let sep = ",";
110            let _ = writeln!(
111                out,
112                "        public {} ${}{}",
113                c.full_type, c.field_name, sep
114            );
115        }
116        let _ = writeln!(out, "    ) {{}}");
117        let _ = writeln!(out);
118
119        // fromRow factory method
120        let _ = writeln!(
121            out,
122            "    public static function fromRow(array $row): self {{"
123        );
124        let _ = writeln!(out, "        return new self(");
125        for c in columns.iter() {
126            let sep = ",";
127            let is_enum = c.neutral_type.starts_with("enum::");
128            let is_datetime = matches!(
129                c.neutral_type.as_str(),
130                "date" | "time" | "time_tz" | "datetime" | "datetime_tz"
131            );
132            if is_enum {
133                // Enum columns: convert DB string to PHP backed enum via ::from()
134                let enum_type = &c.lang_type;
135                if c.nullable {
136                    let _ = writeln!(
137                        out,
138                        "            {}: $row['{}'] !== null ? {}::from($row['{}']) : null{}",
139                        c.field_name, c.name, enum_type, c.name, sep
140                    );
141                } else {
142                    let _ = writeln!(
143                        out,
144                        "            {}: {}::from($row['{}']){}",
145                        c.field_name, enum_type, c.name, sep
146                    );
147                }
148            } else if is_datetime {
149                // DateTime columns: PDO returns strings, wrap in DateTimeImmutable
150                if c.nullable {
151                    let _ = writeln!(
152                        out,
153                        "            {}: $row['{}'] !== null ? new \\DateTimeImmutable($row['{}']) : null{}",
154                        c.field_name, c.name, c.name, sep
155                    );
156                } else {
157                    let _ = writeln!(
158                        out,
159                        "            {}: new \\DateTimeImmutable($row['{}']){}",
160                        c.field_name, c.name, sep
161                    );
162                }
163            } else {
164                let cast = php_cast(&c.neutral_type);
165                if c.nullable {
166                    let _ = writeln!(
167                        out,
168                        "            {}: $row['{}'] !== null ? {}{} : null{}",
169                        c.field_name,
170                        c.name,
171                        cast,
172                        format_args!("$row['{}']", c.name),
173                        sep
174                    );
175                } else {
176                    let _ = writeln!(
177                        out,
178                        "            {}: {}$row['{}']{}",
179                        c.field_name, cast, c.name, sep
180                    );
181                }
182            }
183        }
184        let _ = writeln!(out, "        );");
185        let _ = writeln!(out, "    }}");
186        let _ = write!(out, "}}");
187        Ok(out)
188    }
189
190    fn generate_model_struct(
191        &self,
192        table_name: &str,
193        columns: &[ResolvedColumn],
194    ) -> Result<String, ScytheError> {
195        let name = to_pascal_case(table_name);
196        self.generate_row_struct(&name, columns)
197    }
198
199    fn generate_query_fn(
200        &self,
201        analyzed: &AnalyzedQuery,
202        struct_name: &str,
203        _columns: &[ResolvedColumn],
204        params: &[ResolvedParam],
205    ) -> Result<String, ScytheError> {
206        let func_name = fn_name(&analyzed.name, &self.manifest.naming);
207        let sql = rewrite_params(&super::clean_sql_oneline_with_optional(
208            &analyzed.sql,
209            &analyzed.optional_params,
210            &analyzed.params,
211        ));
212        let mut out = String::new();
213
214        // Handle :batch separately
215        if matches!(analyzed.command, QueryCommand::Batch) {
216            let batch_fn_name = format!("{}Batch", func_name);
217            let _ = writeln!(
218                out,
219                "    public static function {}(\\PDO $pdo, array $items): void {{",
220                batch_fn_name
221            );
222            let _ = writeln!(out, "        $stmt = $pdo->prepare(\"{}\");", sql);
223            let _ = writeln!(out, "        $pdo->beginTransaction();");
224            let _ = writeln!(out, "        try {{");
225            let _ = writeln!(out, "            foreach ($items as $item) {{");
226            if params.is_empty() {
227                let _ = writeln!(out, "                $stmt->execute();");
228            } else {
229                let use_positional = sql.contains('?');
230                if use_positional {
231                    let _ = writeln!(out, "                $stmt->execute($item);");
232                } else {
233                    // Named params — build mapping from item array
234                    let bindings = params
235                        .iter()
236                        .enumerate()
237                        .map(|(i, _p)| format!("\"p{}\" => $item[{}]", i + 1, i))
238                        .collect::<Vec<_>>()
239                        .join(", ");
240                    let _ = writeln!(out, "                $stmt->execute([{}]);", bindings);
241                }
242            }
243            let _ = writeln!(out, "            }}");
244            let _ = writeln!(out, "            $pdo->commit();");
245            let _ = writeln!(out, "        }} catch (\\Throwable $e) {{");
246            let _ = writeln!(out, "            $pdo->rollBack();");
247            let _ = writeln!(out, "            throw $e;");
248            let _ = writeln!(out, "        }}");
249            let _ = write!(out, "    }}");
250            return Ok(out);
251        }
252
253        // Build PHP parameter list
254        let param_list = params
255            .iter()
256            .map(|p| format!("{} ${}", p.full_type, p.field_name))
257            .collect::<Vec<_>>()
258            .join(", ");
259        let sep = if param_list.is_empty() { "" } else { ", " };
260
261        // Return type depends on command
262        let return_type = match &analyzed.command {
263            QueryCommand::One => format!("?{}", struct_name),
264            QueryCommand::Many => "\\Generator".to_string(),
265            QueryCommand::Exec => "void".to_string(),
266            QueryCommand::ExecResult | QueryCommand::ExecRows => "int".to_string(),
267            QueryCommand::Batch => unreachable!(),
268        };
269
270        let _ = writeln!(
271            out,
272            "    public static function {}(\\PDO $pdo{}{}): {} {{",
273            func_name, sep, param_list, return_type
274        );
275
276        // Prepare statement
277        let _ = writeln!(out, "        $stmt = $pdo->prepare(\"{}\");", sql);
278
279        // Build execute params
280        // If the SQL contains `?` placeholders (MySQL/SQLite), use positional array.
281        // If it contains `:pN` placeholders (PostgreSQL), use named array.
282        if params.is_empty() {
283            let _ = writeln!(out, "        $stmt->execute();");
284        } else {
285            let use_positional = sql.contains('?');
286            let bindings = params
287                .iter()
288                .enumerate()
289                .map(|(i, p)| {
290                    let value = if p.neutral_type.starts_with("enum::") {
291                        format!("${}->value", p.field_name)
292                    } else {
293                        format!("${}", p.field_name)
294                    };
295                    if use_positional {
296                        value
297                    } else {
298                        format!("\"p{}\" => {}", i + 1, value)
299                    }
300                })
301                .collect::<Vec<_>>()
302                .join(", ");
303            let _ = writeln!(out, "        $stmt->execute([{}]);", bindings);
304        }
305
306        match &analyzed.command {
307            QueryCommand::One => {
308                let _ = writeln!(out, "        $row = $stmt->fetch(\\PDO::FETCH_ASSOC);");
309                let _ = writeln!(
310                    out,
311                    "        return $row ? {}::fromRow($row) : null;",
312                    struct_name
313                );
314            }
315            QueryCommand::Many => {
316                let _ = writeln!(
317                    out,
318                    "        while ($row = $stmt->fetch(\\PDO::FETCH_ASSOC)) {{"
319                );
320                let _ = writeln!(out, "            yield {}::fromRow($row);", struct_name);
321                let _ = writeln!(out, "        }}");
322            }
323            QueryCommand::Exec => {
324                // nothing else needed
325            }
326            QueryCommand::ExecResult | QueryCommand::ExecRows => {
327                let _ = writeln!(out, "        return $stmt->rowCount();");
328            }
329            QueryCommand::Batch => unreachable!(),
330        }
331
332        let _ = write!(out, "    }}");
333        Ok(out)
334    }
335
336    fn generate_enum_def(&self, enum_info: &EnumInfo) -> Result<String, ScytheError> {
337        let type_name = enum_type_name(&enum_info.sql_name, &self.manifest.naming);
338        let mut out = String::new();
339        let _ = writeln!(out, "enum {}: string {{", type_name);
340        for value in &enum_info.values {
341            let variant = enum_variant_name(value, &self.manifest.naming);
342            let _ = writeln!(out, "    case {} = \"{}\";", variant, value);
343        }
344        let _ = write!(out, "}}");
345        Ok(out)
346    }
347
348    fn generate_composite_def(&self, composite: &CompositeInfo) -> Result<String, ScytheError> {
349        let name = to_pascal_case(&composite.sql_name);
350        let mut out = String::new();
351        let _ = writeln!(out, "readonly class {} {{", name);
352        let _ = writeln!(out, "    public function __construct(");
353        if composite.fields.is_empty() {
354            // empty constructor
355        } else {
356            for field in &composite.fields {
357                let _ = writeln!(out, "        public mixed ${},", field.name);
358            }
359        }
360        let _ = writeln!(out, "    ) {{}}");
361        let _ = write!(out, "}}");
362        Ok(out)
363    }
364}