Skip to main content

scythe_codegen/backends/
php_amphp.rs

1use std::fmt::Write;
2use std::path::Path;
3
4use scythe_backend::manifest::{BackendManifest, load_manifest};
5use scythe_backend::naming::{
6    enum_type_name, enum_variant_name, fn_name, row_struct_name, to_pascal_case,
7};
8
9use scythe_core::analyzer::{AnalyzedQuery, CompositeInfo, EnumInfo};
10use scythe_core::errors::{ErrorCode, ScytheError};
11use scythe_core::parser::QueryCommand;
12
13use crate::backend_trait::{CodegenBackend, ResolvedColumn, ResolvedParam};
14
15const DEFAULT_MANIFEST_PG: &str = include_str!("../../manifests/php-amphp.toml");
16const DEFAULT_MANIFEST_MYSQL: &str = include_str!("../../manifests/php-amphp.mysql.toml");
17
/// Code generation backend registered under the name `php-amphp`; it emits PHP source
/// whose query functions take an `\Amp\Sql\SqlConnectionPool`.
pub struct PhpAmphpBackend {
    /// Backend manifest; its `naming` section is handed to the naming helpers
    /// (`row_struct_name`, `fn_name`, `enum_type_name`, `enum_variant_name`).
    manifest: BackendManifest,
}
21
22impl PhpAmphpBackend {
23    pub fn new(engine: &str) -> Result<Self, ScytheError> {
24        let default_toml = match engine {
25            "postgresql" | "postgres" | "pg" => DEFAULT_MANIFEST_PG,
26            "mysql" | "mariadb" => DEFAULT_MANIFEST_MYSQL,
27            _ => {
28                return Err(ScytheError::new(
29                    ErrorCode::InternalError,
30                    format!("unsupported engine '{}' for php-amphp backend", engine),
31                ));
32            }
33        };
34        let manifest_path = Path::new("backends/php-amphp/manifest.toml");
35        let manifest = if manifest_path.exists() {
36            load_manifest(manifest_path)
37                .map_err(|e| ScytheError::new(ErrorCode::InternalError, format!("manifest: {e}")))?
38        } else {
39            toml::from_str(default_toml)
40                .map_err(|e| ScytheError::new(ErrorCode::InternalError, format!("manifest: {e}")))?
41        };
42        Ok(Self { manifest })
43    }
44}
45
/// Rewrite numbered placeholders (`$1`, `$2`, ...) to positional `?` placeholders.
///
/// A placeholder is a `$` immediately followed by a run of ASCII digits whose first digit
/// is non-zero; the whole run is consumed, so `$10` and `$100` each become a single `?`
/// (there is no upper bound on the parameter number). Every other character, including a
/// lone `$` or `$0`, is copied through unchanged.
///
/// The input is scanned once, left to right.
///
/// NOTE(review): the rewrite is purely textual, so a `$n` inside an SQL string literal or
/// comment is rewritten too, and the numeric order of the placeholders is discarded.
fn rewrite_params_positional(sql: &str) -> String {
    let mut result = String::with_capacity(sql.len());
    let mut chars = sql.chars().peekable();

    while let Some(ch) = chars.next() {
        let starts_placeholder =
            ch == '$' && chars.peek().map_or(false, |next| ('1'..='9').contains(next));

        if starts_placeholder {
            // Swallow the entire parameter number so `$10` is never read as `$1` + `0`.
            while chars.next_if(char::is_ascii_digit).is_some() {}
            result.push('?');
        } else {
            result.push(ch);
        }
    }

    result
}
56
/// Look up the PHP cast prefix (including its trailing space) for a neutral type name.
///
/// Returns the empty string for any neutral type that has no entry in the table, meaning
/// the value is used without a cast.
fn php_cast(neutral_type: &str) -> &'static str {
    // Each entry pairs a cast prefix with the neutral types that use it.
    const CASTS: &[(&str, &[&str])] = &[
        ("(int) ", &["int16", "int32", "int64"]),
        ("(float) ", &["float32", "float64"]),
        ("(bool) ", &["bool"]),
        (
            "(string) ",
            &["string", "json", "inet", "interval", "uuid", "decimal", "bytes"],
        ),
    ];

    CASTS
        .iter()
        .find(|entry| entry.1.contains(&neutral_type))
        .map_or("", |&(cast, _)| cast)
}
67
68impl CodegenBackend for PhpAmphpBackend {
    /// Returns this backend's identifier, the fixed string `"php-amphp"`.
    fn name(&self) -> &str {
        "php-amphp"
    }
72
    /// Borrows the manifest stored on this backend.
    fn manifest(&self) -> &scythe_backend::manifest::BackendManifest {
        &self.manifest
    }
76
    /// Lists the engine names this backend reports as supported.
    ///
    /// NOTE(review): `PhpAmphpBackend::new` additionally accepts the aliases `postgres`,
    /// `pg` and `mariadb`; only the two canonical names are reported here — confirm that
    /// callers normalise aliases before checking this list.
    fn supported_engines(&self) -> &[&str] {
        &["postgresql", "mysql"]
    }
80
81    fn file_header(&self) -> String {
82        "<?php\n\ndeclare(strict_types=1);\n\nnamespace App\\Generated;\n\n// Auto-generated by scythe. Do not edit.\n"
83            .to_string()
84    }
85
86    fn query_class_header(&self) -> String {
87        "final class Queries {".to_string()
88    }
89
90    fn file_footer(&self) -> String {
91        "}".to_string()
92    }
93
94    fn generate_row_struct(
95        &self,
96        query_name: &str,
97        columns: &[ResolvedColumn],
98    ) -> Result<String, ScytheError> {
99        let struct_name = row_struct_name(query_name, &self.manifest.naming);
100        let mut out = String::new();
101
102        // Readonly class with constructor
103        let _ = writeln!(out, "readonly class {} {{", struct_name);
104        let _ = writeln!(out, "    public function __construct(");
105        for c in columns.iter() {
106            let sep = ",";
107            let _ = writeln!(
108                out,
109                "        public {} ${}{}",
110                c.full_type, c.field_name, sep
111            );
112        }
113        let _ = writeln!(out, "    ) {{}}");
114        let _ = writeln!(out);
115
116        // fromRow factory method
117        let _ = writeln!(
118            out,
119            "    public static function fromRow(array $row): self {{"
120        );
121        let _ = writeln!(out, "        return new self(");
122        for c in columns.iter() {
123            let sep = ",";
124            let is_enum = c.neutral_type.starts_with("enum::");
125            let is_datetime = matches!(
126                c.neutral_type.as_str(),
127                "date" | "time" | "time_tz" | "datetime" | "datetime_tz"
128            );
129            if is_enum {
130                let enum_type = &c.lang_type;
131                if c.nullable {
132                    let _ = writeln!(
133                        out,
134                        "            {}: $row['{}'] !== null ? {}::from($row['{}']) : null{}",
135                        c.field_name, c.name, enum_type, c.name, sep
136                    );
137                } else {
138                    let _ = writeln!(
139                        out,
140                        "            {}: {}::from($row['{}']){}",
141                        c.field_name, enum_type, c.name, sep
142                    );
143                }
144            } else if is_datetime {
145                if c.nullable {
146                    let _ = writeln!(
147                        out,
148                        "            {}: $row['{}'] !== null ? new \\DateTimeImmutable($row['{}']) : null{}",
149                        c.field_name, c.name, c.name, sep
150                    );
151                } else {
152                    let _ = writeln!(
153                        out,
154                        "            {}: new \\DateTimeImmutable($row['{}']){}",
155                        c.field_name, c.name, sep
156                    );
157                }
158            } else {
159                let cast = php_cast(&c.neutral_type);
160                if c.nullable {
161                    let _ = writeln!(
162                        out,
163                        "            {}: $row['{}'] !== null ? {}{} : null{}",
164                        c.field_name,
165                        c.name,
166                        cast,
167                        format_args!("$row['{}']", c.name),
168                        sep
169                    );
170                } else {
171                    let _ = writeln!(
172                        out,
173                        "            {}: {}$row['{}']{}",
174                        c.field_name, cast, c.name, sep
175                    );
176                }
177            }
178        }
179        let _ = writeln!(out, "        );");
180        let _ = writeln!(out, "    }}");
181        let _ = write!(out, "}}");
182        Ok(out)
183    }
184
185    fn generate_model_struct(
186        &self,
187        table_name: &str,
188        columns: &[ResolvedColumn],
189    ) -> Result<String, ScytheError> {
190        let name = to_pascal_case(table_name);
191        self.generate_row_struct(&name, columns)
192    }
193
194    fn generate_query_fn(
195        &self,
196        analyzed: &AnalyzedQuery,
197        struct_name: &str,
198        _columns: &[ResolvedColumn],
199        params: &[ResolvedParam],
200    ) -> Result<String, ScytheError> {
201        let func_name = fn_name(&analyzed.name, &self.manifest.naming);
202        let sql = rewrite_params_positional(&super::clean_sql_oneline_with_optional(
203            &analyzed.sql,
204            &analyzed.optional_params,
205            &analyzed.params,
206        ));
207        let mut out = String::new();
208
209        // Build PHP parameter list
210        let param_list = params
211            .iter()
212            .map(|p| format!("{} ${}", p.full_type, p.field_name))
213            .collect::<Vec<_>>()
214            .join(", ");
215        let sep = if param_list.is_empty() { "" } else { ", " };
216
217        // Handle :batch separately
218        if matches!(analyzed.command, QueryCommand::Batch) {
219            let batch_fn_name = format!("{}Batch", func_name);
220            let _ = writeln!(
221                out,
222                "    public static function {}(\\Amp\\Sql\\SqlConnectionPool $pool, array $items): void {{",
223                batch_fn_name
224            );
225            let _ = writeln!(out, "        $transaction = $pool->beginTransaction();");
226            let _ = writeln!(out, "        try {{");
227            let _ = writeln!(
228                out,
229                "            $stmt = $transaction->prepare(\"{}\");",
230                sql
231            );
232            let _ = writeln!(out, "            foreach ($items as $item) {{");
233            if params.is_empty() {
234                let _ = writeln!(out, "                $stmt->execute([]);");
235            } else {
236                let _ = writeln!(out, "                $stmt->execute($item);");
237            }
238            let _ = writeln!(out, "            }}");
239            let _ = writeln!(out, "            $transaction->commit();");
240            let _ = writeln!(out, "        }} catch (\\Throwable $e) {{");
241            let _ = writeln!(out, "            $transaction->rollback();");
242            let _ = writeln!(out, "            throw $e;");
243            let _ = writeln!(out, "        }}");
244            let _ = write!(out, "    }}");
245            return Ok(out);
246        }
247
248        // Return type depends on command
249        let return_type = match &analyzed.command {
250            QueryCommand::One => format!("?{}", struct_name),
251            QueryCommand::Many => "\\Generator".to_string(),
252            QueryCommand::Exec => "void".to_string(),
253            QueryCommand::ExecResult | QueryCommand::ExecRows => "int".to_string(),
254            QueryCommand::Batch => unreachable!(),
255        };
256
257        let _ = writeln!(
258            out,
259            "    public static function {}(\\Amp\\Sql\\SqlConnectionPool $pool{}{}): {} {{",
260            func_name, sep, param_list, return_type
261        );
262
263        // Build execute params
264        if params.is_empty() {
265            let _ = writeln!(
266                out,
267                "        $result = $pool->prepare(\"{}\")->execute([]);",
268                sql
269            );
270        } else {
271            let bindings = params
272                .iter()
273                .map(|p| {
274                    if p.neutral_type.starts_with("enum::") {
275                        format!("${}->value", p.field_name)
276                    } else {
277                        format!("${}", p.field_name)
278                    }
279                })
280                .collect::<Vec<_>>()
281                .join(", ");
282            let _ = writeln!(
283                out,
284                "        $result = $pool->prepare(\"{}\")->execute([{}]);",
285                sql, bindings
286            );
287        }
288
289        match &analyzed.command {
290            QueryCommand::One => {
291                let _ = writeln!(out, "        foreach ($result as $row) {{");
292                let _ = writeln!(out, "            return {}::fromRow($row);", struct_name);
293                let _ = writeln!(out, "        }}");
294                let _ = writeln!(out, "        return null;");
295            }
296            QueryCommand::Many => {
297                let _ = writeln!(out, "        foreach ($result as $row) {{");
298                let _ = writeln!(out, "            yield {}::fromRow($row);", struct_name);
299                let _ = writeln!(out, "        }}");
300            }
301            QueryCommand::Exec => {
302                // nothing else needed
303            }
304            QueryCommand::ExecResult | QueryCommand::ExecRows => {
305                let _ = writeln!(out, "        return $result->getRowCount();");
306            }
307            QueryCommand::Batch => unreachable!(),
308        }
309
310        let _ = write!(out, "    }}");
311        Ok(out)
312    }
313
314    fn generate_enum_def(&self, enum_info: &EnumInfo) -> Result<String, ScytheError> {
315        let type_name = enum_type_name(&enum_info.sql_name, &self.manifest.naming);
316        let mut out = String::new();
317        let _ = writeln!(out, "enum {}: string {{", type_name);
318        for value in &enum_info.values {
319            let variant = enum_variant_name(value, &self.manifest.naming);
320            let _ = writeln!(out, "    case {} = \"{}\";", variant, value);
321        }
322        let _ = write!(out, "}}");
323        Ok(out)
324    }
325
326    fn generate_composite_def(&self, composite: &CompositeInfo) -> Result<String, ScytheError> {
327        let name = to_pascal_case(&composite.sql_name);
328        let mut out = String::new();
329        let _ = writeln!(out, "readonly class {} {{", name);
330        let _ = writeln!(out, "    public function __construct(");
331        if composite.fields.is_empty() {
332            // empty constructor
333        } else {
334            for field in &composite.fields {
335                let _ = writeln!(out, "        public mixed ${},", field.name);
336            }
337        }
338        let _ = writeln!(out, "    ) {{}}");
339        let _ = write!(out, "}}");
340        Ok(out)
341    }
342}