1use std::fmt::Write;
2use std::path::Path;
3
4use scythe_backend::manifest::{BackendManifest, load_manifest};
5use scythe_backend::naming::{
6 enum_type_name, enum_variant_name, fn_name, row_struct_name, to_pascal_case, to_snake_case,
7};
8use scythe_backend::types::resolve_type;
9
10use scythe_core::analyzer::{AnalyzedQuery, CompositeInfo, EnumInfo};
11use scythe_core::errors::{ErrorCode, ScytheError};
12use scythe_core::parser::QueryCommand;
13
14use crate::backend_trait::{CodegenBackend, ResolvedColumn, ResolvedParam};
15use crate::singularize;
16
17const DEFAULT_MANIFEST_TOML: &str = include_str!("../../manifests/python-psycopg3.toml");
18
19pub struct PythonPsycopg3Backend {
20 manifest: BackendManifest,
21}
22
23impl PythonPsycopg3Backend {
24 pub fn new(engine: &str) -> Result<Self, ScytheError> {
25 match engine {
26 "postgresql" | "postgres" | "pg" => {}
27 _ => {
28 return Err(ScytheError::new(
29 ErrorCode::InternalError,
30 format!(
31 "python-psycopg3 only supports PostgreSQL, got engine '{}'",
32 engine
33 ),
34 ));
35 }
36 }
37 let manifest_path = Path::new("backends/python-psycopg3/manifest.toml");
38 let manifest = if manifest_path.exists() {
39 load_manifest(manifest_path)
40 .map_err(|e| ScytheError::new(ErrorCode::InternalError, format!("manifest: {e}")))?
41 } else {
42 toml::from_str(DEFAULT_MANIFEST_TOML)
43 .map_err(|e| ScytheError::new(ErrorCode::InternalError, format!("manifest: {e}")))?
44 };
45 Ok(Self { manifest })
46 }
47}
48
49fn rewrite_params_named(sql: &str, analyzed: &AnalyzedQuery) -> String {
51 let mut result = sql.to_string();
52 let mut params_sorted: Vec<_> = analyzed.params.iter().collect();
54 params_sorted.sort_by(|a, b| b.position.cmp(&a.position));
55 for param in params_sorted {
56 let placeholder = format!("${}", param.position);
57 let named = format!("%({})s", to_snake_case(¶m.name));
58 result = result.replace(&placeholder, &named);
59 }
60 result
61}
62
63impl CodegenBackend for PythonPsycopg3Backend {
64 fn name(&self) -> &str {
65 "python-psycopg3"
66 }
67
68 fn manifest(&self) -> &scythe_backend::manifest::BackendManifest {
69 &self.manifest
70 }
71
72 fn file_header(&self) -> String {
73 "\"\"\"Auto-generated by scythe. Do not edit.\"\"\"\n\
74 \n\
75 import datetime # noqa: F401\n\
76 import decimal # noqa: F401\n\
77 from dataclasses import dataclass\n\
78 from enum import Enum # noqa: F401\n\
79 \n\
80 from psycopg import AsyncConnection # noqa: F401\n\
81 \n"
82 .to_string()
83 }
84
85 fn generate_row_struct(
86 &self,
87 query_name: &str,
88 columns: &[ResolvedColumn],
89 ) -> Result<String, ScytheError> {
90 let struct_name = row_struct_name(query_name, &self.manifest.naming);
91 let mut out = String::new();
92 let _ = writeln!(out, "@dataclass");
93 let _ = writeln!(out, "class {}:", struct_name);
94 let _ = writeln!(out, " \"\"\"Row type for {} query.\"\"\"", query_name);
95 if columns.is_empty() {
96 let _ = writeln!(out, " pass");
97 } else {
98 let _ = writeln!(out);
99 for col in columns {
100 let _ = writeln!(out, " {}: {}", col.field_name, col.full_type);
101 }
102 }
103 Ok(out)
104 }
105
106 fn generate_model_struct(
107 &self,
108 table_name: &str,
109 columns: &[ResolvedColumn],
110 ) -> Result<String, ScytheError> {
111 let singular = singularize(table_name);
112 let name = to_pascal_case(&singular);
113 self.generate_row_struct(&name, columns)
114 }
115
116 fn generate_query_fn(
117 &self,
118 analyzed: &AnalyzedQuery,
119 struct_name: &str,
120 columns: &[ResolvedColumn],
121 params: &[ResolvedParam],
122 ) -> Result<String, ScytheError> {
123 let func_name = fn_name(&analyzed.name, &self.manifest.naming);
124 let mut out = String::new();
125
126 let param_list = params
128 .iter()
129 .map(|p| format!("{}: {}", p.field_name, p.full_type))
130 .collect::<Vec<_>>()
131 .join(", ");
132 let kw_sep = if param_list.is_empty() { "" } else { ", *, " };
133
134 let sql_clean = super::clean_sql(&analyzed.sql);
136 let sql = rewrite_params_named(&sql_clean, analyzed);
137
138 match &analyzed.command {
139 QueryCommand::One => {
140 let _ = writeln!(
141 out,
142 "async def {}(conn: AsyncConnection{}{}) -> {} | None:",
143 func_name, kw_sep, param_list, struct_name
144 );
145 let _ = writeln!(out, " \"\"\"Execute {} query.\"\"\"", analyzed.name);
146 if params.is_empty() {
148 let _ = writeln!(out, " cur = await conn.execute(");
149 let _ = writeln!(out, " \"\"\"{}\"\"\",", sql);
150 let _ = writeln!(out, " )");
151 } else {
152 let dict_entries: Vec<String> = params
153 .iter()
154 .map(|p| format!("\"{}\": {}", p.field_name, p.field_name))
155 .collect();
156 let _ = writeln!(out, " cur = await conn.execute(");
157 let _ = writeln!(out, " \"\"\"{}\"\"\",", sql);
158 let _ = writeln!(out, " {{{}}},", dict_entries.join(", "));
159 let _ = writeln!(out, " )");
160 }
161 let _ = writeln!(out, " row = await cur.fetchone()");
162 let _ = writeln!(out, " if row is None:");
163 let _ = writeln!(out, " return None");
164 let field_assignments: Vec<String> = columns
166 .iter()
167 .enumerate()
168 .map(|(i, col)| format!("{}=row[{}]", col.field_name, i))
169 .collect();
170 let oneliner = format!(
171 " return {}({})",
172 struct_name,
173 field_assignments.join(", ")
174 );
175 if oneliner.len() <= 88 {
176 let _ = writeln!(out, "{}", oneliner);
177 } else {
178 let _ = writeln!(out, " return {}(", struct_name);
179 for fa in &field_assignments {
180 let _ = writeln!(out, " {},", fa);
181 }
182 let _ = writeln!(out, " )");
183 }
184 }
185 QueryCommand::Many | QueryCommand::Batch => {
186 let _ = writeln!(
187 out,
188 "async def {}(conn: AsyncConnection{}{}) -> list[{}]:",
189 func_name, kw_sep, param_list, struct_name
190 );
191 let _ = writeln!(out, " \"\"\"Execute {} query.\"\"\"", analyzed.name);
192 if params.is_empty() {
193 let _ = writeln!(out, " cur = await conn.execute(");
194 let _ = writeln!(out, " \"\"\"{}\"\"\",", sql);
195 let _ = writeln!(out, " )");
196 } else {
197 let dict_entries: Vec<String> = params
198 .iter()
199 .map(|p| format!("\"{}\": {}", p.field_name, p.field_name))
200 .collect();
201 let _ = writeln!(out, " cur = await conn.execute(");
202 let _ = writeln!(out, " \"\"\"{}\"\"\",", sql);
203 let _ = writeln!(out, " {{{}}},", dict_entries.join(", "));
204 let _ = writeln!(out, " )");
205 }
206 let _ = writeln!(out, " rows = await cur.fetchall()");
207 let field_assignments: Vec<String> = columns
208 .iter()
209 .enumerate()
210 .map(|(i, col)| format!("{}=r[{}]", col.field_name, i))
211 .collect();
212 let oneliner = format!(
213 " return [{}({}) for r in rows]",
214 struct_name,
215 field_assignments.join(", ")
216 );
217 if oneliner.len() <= 88 {
218 let _ = writeln!(out, "{}", oneliner);
219 } else {
220 let _ = writeln!(out, " return [");
221 let _ = writeln!(out, " {}(", struct_name);
222 for fa in &field_assignments {
223 let _ = writeln!(out, " {},", fa);
224 }
225 let _ = writeln!(out, " )");
226 let _ = writeln!(out, " for r in rows");
227 let _ = writeln!(out, " ]");
228 }
229 }
230 QueryCommand::Exec | QueryCommand::ExecResult | QueryCommand::ExecRows => {
231 let _ = writeln!(
232 out,
233 "async def {}(conn: AsyncConnection{}{}) -> None:",
234 func_name, kw_sep, param_list
235 );
236 let _ = writeln!(out, " \"\"\"Execute {} query.\"\"\"", analyzed.name);
237 if params.is_empty() {
238 let _ = writeln!(out, " await conn.execute(");
239 let _ = writeln!(out, " \"\"\"{}\"\"\",", sql);
240 let _ = writeln!(out, " )");
241 } else {
242 let dict_entries: Vec<String> = params
243 .iter()
244 .map(|p| format!("\"{}\": {}", p.field_name, p.field_name))
245 .collect();
246 let _ = writeln!(out, " await conn.execute(");
247 let _ = writeln!(out, " \"\"\"{}\"\"\",", sql);
248 let _ = writeln!(out, " {{{}}},", dict_entries.join(", "));
249 let _ = writeln!(out, " )");
250 }
251 }
252 }
253
254 Ok(out)
255 }
256
257 fn generate_enum_def(&self, enum_info: &EnumInfo) -> Result<String, ScytheError> {
258 let type_name = enum_type_name(&enum_info.sql_name, &self.manifest.naming);
259 let mut out = String::new();
260 let _ = writeln!(out, "class {}(str, Enum):", type_name);
261 let _ = writeln!(
262 out,
263 " \"\"\"Database enum type {}.\"\"\"",
264 enum_info.sql_name
265 );
266 if enum_info.values.is_empty() {
267 let _ = writeln!(out, " pass");
268 } else {
269 let _ = writeln!(out);
270 for value in &enum_info.values {
271 let variant = enum_variant_name(value, &self.manifest.naming);
272 let _ = writeln!(out, " {} = \"{}\"", variant, value);
273 }
274 }
275 Ok(out)
276 }
277
278 fn generate_composite_def(&self, composite: &CompositeInfo) -> Result<String, ScytheError> {
279 let name = to_pascal_case(&composite.sql_name);
280 let mut out = String::new();
281 let _ = writeln!(out, "@dataclass");
282 let _ = writeln!(out, "class {}:", name);
283 let _ = writeln!(
284 out,
285 " \"\"\"Composite type {}.\"\"\"",
286 composite.sql_name
287 );
288 if composite.fields.is_empty() {
289 let _ = writeln!(out, " pass");
290 } else {
291 let _ = writeln!(out);
292 for field in &composite.fields {
293 let py_type = resolve_type(&field.neutral_type, &self.manifest, false)
294 .map(|t| t.into_owned())
295 .map_err(|e| {
296 ScytheError::new(
297 ErrorCode::InternalError,
298 format!("composite field type error: {}", e),
299 )
300 })?;
301 let _ = writeln!(out, " {}: {}", to_snake_case(&field.name), py_type);
302 }
303 }
304 Ok(out)
305 }
306}