Skip to main content

scythe_codegen/backends/
php_amphp.rs

1use std::fmt::Write;
2use std::path::Path;
3
4use scythe_backend::manifest::{BackendManifest, load_manifest};
5use scythe_backend::naming::{
6    enum_type_name, enum_variant_name, fn_name, row_struct_name, to_pascal_case,
7};
8
9use scythe_core::analyzer::{AnalyzedQuery, CompositeInfo, EnumInfo};
10use scythe_core::errors::{ErrorCode, ScytheError};
11use scythe_core::parser::QueryCommand;
12
13use crate::backend_trait::{CodegenBackend, ResolvedColumn, ResolvedParam};
14
15const DEFAULT_MANIFEST_PG: &str = include_str!("../../manifests/php-amphp.toml");
16const DEFAULT_MANIFEST_MYSQL: &str = include_str!("../../manifests/php-amphp.mysql.toml");
17
18pub struct PhpAmphpBackend {
19    manifest: BackendManifest,
20}
21
22impl PhpAmphpBackend {
23    pub fn new(engine: &str) -> Result<Self, ScytheError> {
24        let default_toml = match engine {
25            "postgresql" | "postgres" | "pg" => DEFAULT_MANIFEST_PG,
26            "mysql" | "mariadb" => DEFAULT_MANIFEST_MYSQL,
27            _ => {
28                return Err(ScytheError::new(
29                    ErrorCode::InternalError,
30                    format!("unsupported engine '{}' for php-amphp backend", engine),
31                ));
32            }
33        };
34        let manifest_path = Path::new("backends/php-amphp/manifest.toml");
35        let manifest = if manifest_path.exists() {
36            load_manifest(manifest_path)
37                .map_err(|e| ScytheError::new(ErrorCode::InternalError, format!("manifest: {e}")))?
38        } else {
39            toml::from_str(default_toml)
40                .map_err(|e| ScytheError::new(ErrorCode::InternalError, format!("manifest: {e}")))?
41        };
42        Ok(Self { manifest })
43    }
44}
45
/// Rewrite PostgreSQL-style `$1`, `$2`, ... placeholders to positional `?`.
///
/// The SQL is scanned once, left to right. Every `$` that is immediately followed
/// by a run of ASCII digits not starting with `0` is replaced, together with the
/// whole digit run, by a single `?`. Any other `$` (e.g. `$$` dollar quoting or
/// `$0`) is copied through unchanged.
///
/// Consuming the whole digit run matters: substituting `$99` down to `$1` with
/// `str::replace` turned `$100` into `?0`. This version also runs in linear time
/// instead of making 99 passes over the string.
///
/// NOTE(review): `?` is purely positional, so placeholders become ordered by their
/// textual position. If a query references `$n` out of numeric order or repeats a
/// placeholder, callers that bind arguments in parameter order will misalign —
/// confirm this cannot happen for the queries this backend receives.
fn rewrite_params_positional(sql: &str) -> String {
    let mut out = String::with_capacity(sql.len());
    let mut rest = sql;
    while let Some(pos) = rest.find('$') {
        out.push_str(&rest[..pos]);
        // `$` is a single byte, so `pos + 1` is always a char boundary.
        let after = &rest[pos + 1..];
        let digits = after.bytes().take_while(u8::is_ascii_digit).count();
        if digits > 0 && !after.starts_with('0') {
            out.push('?');
            // The skipped digits are ASCII, so `digits` is a valid byte offset.
            rest = &after[digits..];
        } else {
            out.push('$');
            rest = after;
        }
    }
    out.push_str(rest);
    out
}
56
/// Return the PHP cast prefix used to coerce a raw value of `neutral_type`.
///
/// The prefix includes its trailing space (e.g. `"(int) "`) so it can be placed
/// directly before the value expression. An empty string is returned for
/// neutral types that have no cast in the table below.
fn php_cast(neutral_type: &str) -> &'static str {
    // Each entry pairs a cast prefix with the neutral types that use it.
    const CASTS: &[(&str, &[&str])] = &[
        ("(int) ", &["int16", "int32", "int64"]),
        ("(float) ", &["float32", "float64"]),
        ("(bool) ", &["bool"]),
        (
            "(string) ",
            &["string", "json", "inet", "interval", "uuid", "decimal", "bytes"],
        ),
    ];

    CASTS
        .iter()
        .find(|(_, types)| types.contains(&neutral_type))
        .map_or("", |&(cast, _)| cast)
}
67
68impl CodegenBackend for PhpAmphpBackend {
69    fn name(&self) -> &str {
70        "php-amphp"
71    }
72
73    fn manifest(&self) -> &scythe_backend::manifest::BackendManifest {
74        &self.manifest
75    }
76
77    fn supported_engines(&self) -> &[&str] {
78        &["postgresql", "mysql"]
79    }
80
81    fn file_header(&self) -> String {
82        "<?php\n\ndeclare(strict_types=1);\n\nnamespace App\\Generated;\n\n// Auto-generated by scythe. Do not edit.\n"
83            .to_string()
84    }
85
86    fn query_class_header(&self) -> String {
87        "final class Queries {".to_string()
88    }
89
90    fn file_footer(&self) -> String {
91        "}".to_string()
92    }
93
94    fn generate_row_struct(
95        &self,
96        query_name: &str,
97        columns: &[ResolvedColumn],
98    ) -> Result<String, ScytheError> {
99        let struct_name = row_struct_name(query_name, &self.manifest.naming);
100        let mut out = String::new();
101
102        // Readonly class with constructor
103        let _ = writeln!(out, "readonly class {} {{", struct_name);
104        let _ = writeln!(out, "    public function __construct(");
105        for c in columns.iter() {
106            let sep = ",";
107            let _ = writeln!(
108                out,
109                "        public {} ${}{}",
110                c.full_type, c.field_name, sep
111            );
112        }
113        let _ = writeln!(out, "    ) {{}}");
114        let _ = writeln!(out);
115
116        // fromRow factory method
117        let _ = writeln!(
118            out,
119            "    public static function fromRow(array $row): self {{"
120        );
121        let _ = writeln!(out, "        return new self(");
122        for c in columns.iter() {
123            let sep = ",";
124            let is_enum = c.neutral_type.starts_with("enum::");
125            let is_datetime = matches!(
126                c.neutral_type.as_str(),
127                "date" | "time" | "time_tz" | "datetime" | "datetime_tz"
128            );
129            if is_enum {
130                let enum_type = &c.lang_type;
131                if c.nullable {
132                    let _ = writeln!(
133                        out,
134                        "            {}: $row['{}'] !== null ? {}::from($row['{}']) : null{}",
135                        c.field_name, c.name, enum_type, c.name, sep
136                    );
137                } else {
138                    let _ = writeln!(
139                        out,
140                        "            {}: {}::from($row['{}']){}",
141                        c.field_name, enum_type, c.name, sep
142                    );
143                }
144            } else if is_datetime {
145                if c.nullable {
146                    let _ = writeln!(
147                        out,
148                        "            {}: $row['{}'] !== null ? new \\DateTimeImmutable($row['{}']) : null{}",
149                        c.field_name, c.name, c.name, sep
150                    );
151                } else {
152                    let _ = writeln!(
153                        out,
154                        "            {}: new \\DateTimeImmutable($row['{}']){}",
155                        c.field_name, c.name, sep
156                    );
157                }
158            } else {
159                let cast = php_cast(&c.neutral_type);
160                if c.nullable {
161                    let _ = writeln!(
162                        out,
163                        "            {}: $row['{}'] !== null ? {}{} : null{}",
164                        c.field_name,
165                        c.name,
166                        cast,
167                        format_args!("$row['{}']", c.name),
168                        sep
169                    );
170                } else {
171                    let _ = writeln!(
172                        out,
173                        "            {}: {}$row['{}']{}",
174                        c.field_name, cast, c.name, sep
175                    );
176                }
177            }
178        }
179        let _ = writeln!(out, "        );");
180        let _ = writeln!(out, "    }}");
181        let _ = write!(out, "}}");
182        Ok(out)
183    }
184
185    fn generate_model_struct(
186        &self,
187        table_name: &str,
188        columns: &[ResolvedColumn],
189    ) -> Result<String, ScytheError> {
190        let name = to_pascal_case(table_name);
191        self.generate_row_struct(&name, columns)
192    }
193
194    fn generate_query_fn(
195        &self,
196        analyzed: &AnalyzedQuery,
197        struct_name: &str,
198        _columns: &[ResolvedColumn],
199        params: &[ResolvedParam],
200    ) -> Result<String, ScytheError> {
201        let func_name = fn_name(&analyzed.name, &self.manifest.naming);
202        let sql = rewrite_params_positional(&super::clean_sql_oneline_with_optional(
203            &analyzed.sql,
204            &analyzed.optional_params,
205            &analyzed.params,
206        ));
207        let mut out = String::new();
208
209        // Build PHP parameter list
210        let param_list = params
211            .iter()
212            .map(|p| format!("{} ${}", p.full_type, p.field_name))
213            .collect::<Vec<_>>()
214            .join(", ");
215        let sep = if param_list.is_empty() { "" } else { ", " };
216
217        // Handle :batch separately
218        if matches!(analyzed.command, QueryCommand::Batch) {
219            let batch_fn_name = format!("{}Batch", func_name);
220            // PHPDoc for batch function
221            let _ = writeln!(out, "    /**");
222            let _ = writeln!(out, "     * @param \\Amp\\Sql\\SqlConnectionPool $pool");
223            let _ = writeln!(out, "     * @param array<int, array<int, mixed>> $items");
224            let _ = writeln!(out, "     * @return void");
225            let _ = writeln!(out, "     */");
226            let _ = writeln!(
227                out,
228                "    public static function {}(\\Amp\\Sql\\SqlConnectionPool $pool, array $items): void {{",
229                batch_fn_name
230            );
231            let _ = writeln!(out, "        $transaction = $pool->beginTransaction();");
232            let _ = writeln!(out, "        try {{");
233            let _ = writeln!(
234                out,
235                "            $stmt = $transaction->prepare(\"{}\");",
236                sql
237            );
238            let _ = writeln!(out, "            foreach ($items as $item) {{");
239            if params.is_empty() {
240                let _ = writeln!(out, "                $stmt->execute([]);");
241            } else {
242                let _ = writeln!(out, "                $stmt->execute($item);");
243            }
244            let _ = writeln!(out, "            }}");
245            let _ = writeln!(out, "            $transaction->commit();");
246            let _ = writeln!(out, "        }} catch (\\Throwable $e) {{");
247            let _ = writeln!(out, "            $transaction->rollback();");
248            let _ = writeln!(out, "            throw $e;");
249            let _ = writeln!(out, "        }}");
250            let _ = write!(out, "    }}");
251            return Ok(out);
252        }
253
254        // Return type depends on command
255        let return_type = match &analyzed.command {
256            QueryCommand::One => format!("?{}", struct_name),
257            QueryCommand::Many => "\\Generator".to_string(),
258            QueryCommand::Exec => "void".to_string(),
259            QueryCommand::ExecResult | QueryCommand::ExecRows => "int".to_string(),
260            QueryCommand::Batch | QueryCommand::Grouped => unreachable!(),
261        };
262
263        // PHPDoc block
264        let _ = writeln!(out, "    /**");
265        let _ = writeln!(out, "     * @param \\Amp\\Sql\\SqlConnectionPool $pool");
266        for p in params {
267            let _ = writeln!(out, "     * @param {} ${}", p.full_type, p.field_name);
268        }
269        match &analyzed.command {
270            QueryCommand::One => {
271                let _ = writeln!(out, "     * @return {}|null", struct_name);
272            }
273            QueryCommand::Many => {
274                let _ = writeln!(
275                    out,
276                    "     * @return \\Generator<int, {}, mixed, void>",
277                    struct_name
278                );
279            }
280            QueryCommand::Exec => {
281                let _ = writeln!(out, "     * @return void");
282            }
283            QueryCommand::ExecResult | QueryCommand::ExecRows => {
284                let _ = writeln!(out, "     * @return int");
285            }
286            QueryCommand::Batch | QueryCommand::Grouped => unreachable!(),
287        }
288        let _ = writeln!(out, "     */");
289
290        let _ = writeln!(
291            out,
292            "    public static function {}(\\Amp\\Sql\\SqlConnectionPool $pool{}{}): {} {{",
293            func_name, sep, param_list, return_type
294        );
295
296        // Build execute params
297        if params.is_empty() {
298            let _ = writeln!(
299                out,
300                "        $result = $pool->prepare(\"{}\")->execute([]);",
301                sql
302            );
303        } else {
304            let bindings = params
305                .iter()
306                .map(|p| {
307                    if p.neutral_type.starts_with("enum::") {
308                        format!("${}->value", p.field_name)
309                    } else {
310                        format!("${}", p.field_name)
311                    }
312                })
313                .collect::<Vec<_>>()
314                .join(", ");
315            let _ = writeln!(
316                out,
317                "        $result = $pool->prepare(\"{}\")->execute([{}]);",
318                sql, bindings
319            );
320        }
321
322        match &analyzed.command {
323            QueryCommand::One => {
324                let _ = writeln!(out, "        foreach ($result as $row) {{");
325                let _ = writeln!(out, "            return {}::fromRow($row);", struct_name);
326                let _ = writeln!(out, "        }}");
327                let _ = writeln!(out, "        return null;");
328            }
329            QueryCommand::Many => {
330                let _ = writeln!(out, "        foreach ($result as $row) {{");
331                let _ = writeln!(out, "            yield {}::fromRow($row);", struct_name);
332                let _ = writeln!(out, "        }}");
333            }
334            QueryCommand::Exec => {
335                // nothing else needed
336            }
337            QueryCommand::ExecResult | QueryCommand::ExecRows => {
338                let _ = writeln!(out, "        return $result->getRowCount();");
339            }
340            QueryCommand::Batch | QueryCommand::Grouped => unreachable!(),
341        }
342
343        let _ = write!(out, "    }}");
344        Ok(out)
345    }
346
347    fn generate_enum_def(&self, enum_info: &EnumInfo) -> Result<String, ScytheError> {
348        let type_name = enum_type_name(&enum_info.sql_name, &self.manifest.naming);
349        let mut out = String::new();
350        let _ = writeln!(out, "enum {}: string {{", type_name);
351        for value in &enum_info.values {
352            let variant = enum_variant_name(value, &self.manifest.naming);
353            let _ = writeln!(out, "    case {} = \"{}\";", variant, value);
354        }
355        let _ = write!(out, "}}");
356        Ok(out)
357    }
358
359    fn generate_composite_def(&self, composite: &CompositeInfo) -> Result<String, ScytheError> {
360        let name = to_pascal_case(&composite.sql_name);
361        let mut out = String::new();
362        let _ = writeln!(out, "readonly class {} {{", name);
363        let _ = writeln!(out, "    public function __construct(");
364        if composite.fields.is_empty() {
365            // empty constructor
366        } else {
367            for field in &composite.fields {
368                let _ = writeln!(out, "        public mixed ${},", field.name);
369            }
370        }
371        let _ = writeln!(out, "    ) {{}}");
372        let _ = write!(out, "}}");
373        Ok(out)
374    }
375}