1use std::collections::HashMap;
2use std::fmt::Write;
3use std::path::Path;
4
5use scythe_backend::manifest::{BackendManifest, load_manifest};
6use scythe_backend::naming::{
7 enum_type_name, enum_variant_name, fn_name, row_struct_name, to_pascal_case, to_snake_case,
8};
9use scythe_backend::types::resolve_type;
10
11use scythe_core::analyzer::{AnalyzedQuery, CompositeInfo, EnumInfo};
12use scythe_core::errors::{ErrorCode, ScytheError};
13use scythe_core::parser::QueryCommand;
14
15use crate::backend_trait::{CodegenBackend, ResolvedColumn, ResolvedParam};
16use crate::singularize;
17
18use super::python_common::PythonRowType;
19
20const DEFAULT_MANIFEST_TOML: &str = include_str!("../../manifests/python-psycopg3.toml");
21
/// Code-generation backend that emits async Python targeting psycopg 3.
/// PostgreSQL-only (enforced in `PythonPsycopg3Backend::new`).
pub struct PythonPsycopg3Backend {
    /// Manifest driving type resolution and naming conventions for this backend.
    manifest: BackendManifest,
    /// Which Python row representation to generate (see `PythonRowType`);
    /// selectable via the `row_type` backend option.
    row_type: PythonRowType,
}
26
27impl PythonPsycopg3Backend {
28 pub fn new(engine: &str) -> Result<Self, ScytheError> {
29 match engine {
30 "postgresql" | "postgres" | "pg" => {}
31 _ => {
32 return Err(ScytheError::new(
33 ErrorCode::InternalError,
34 format!(
35 "python-psycopg3 only supports PostgreSQL, got engine '{}'",
36 engine
37 ),
38 ));
39 }
40 }
41 let manifest_path = Path::new("backends/python-psycopg3/manifest.toml");
42 let manifest = if manifest_path.exists() {
43 load_manifest(manifest_path)
44 .map_err(|e| ScytheError::new(ErrorCode::InternalError, format!("manifest: {e}")))?
45 } else {
46 toml::from_str(DEFAULT_MANIFEST_TOML)
47 .map_err(|e| ScytheError::new(ErrorCode::InternalError, format!("manifest: {e}")))?
48 };
49 Ok(Self {
50 manifest,
51 row_type: PythonRowType::default(),
52 })
53 }
54}
55
56fn rewrite_params_named(sql: &str, analyzed: &AnalyzedQuery) -> String {
58 let mut result = sql.to_string();
59 let mut params_sorted: Vec<_> = analyzed.params.iter().collect();
61 params_sorted.sort_by(|a, b| b.position.cmp(&a.position));
62 for param in params_sorted {
63 let placeholder = format!("${}", param.position);
64 let named = format!("%({})s", to_snake_case(¶m.name));
65 result = result.replace(&placeholder, &named);
66 }
67 result
68}
69
impl CodegenBackend for PythonPsycopg3Backend {
    /// Backend identifier as selected by CLI/config.
    fn name(&self) -> &str {
        "python-psycopg3"
    }

    /// The manifest loaded in `new` (type mappings and naming rules).
    fn manifest(&self) -> &scythe_backend::manifest::BackendManifest {
        &self.manifest
    }

    /// Apply user-supplied backend options.
    ///
    /// Only `row_type` is recognized; unknown keys are silently ignored
    /// (presumably validated elsewhere — TODO confirm against caller).
    fn apply_options(&mut self, options: &HashMap<String, String>) -> Result<(), ScytheError> {
        if let Some(rt) = options.get("row_type") {
            self.row_type = PythonRowType::from_option(rt)?;
        }
        Ok(())
    }

    /// Emit the module preamble for a generated Python file: module
    /// docstring, stdlib imports, the row-type import, and the psycopg
    /// import — all tagged `# noqa: F401` since any given file may not
    /// use every import.
    ///
    /// Two layouts keep the import groups ordered: a stdlib row type is
    /// merged into the stdlib import block; a third-party row type is
    /// grouped (and sorted) together with the psycopg import.
    fn file_header(&self) -> String {
        let import_line = self.row_type.import_line();
        if self.row_type.is_stdlib_import() {
            format!(
                "\"\"\"Auto-generated by scythe. Do not edit.\"\"\"\n\
                \n\
                import datetime # noqa: F401\n\
                import decimal # noqa: F401\n\
                {import_line}\n\
                from enum import Enum # noqa: F401\n\
                \n\
                from psycopg import AsyncConnection # noqa: F401\n\
                \n",
            )
        } else {
            // Third-party row type: let the row type sort its own import
            // lines together with the psycopg import so the group is ordered.
            let third_party = self
                .row_type
                .sorted_third_party_imports("from psycopg import AsyncConnection # noqa: F401");
            format!(
                "\"\"\"Auto-generated by scythe. Do not edit.\"\"\"\n\
                \n\
                import datetime # noqa: F401\n\
                import decimal # noqa: F401\n\
                from enum import Enum # noqa: F401\n\
                \n\
                {third_party}\n\
                \n",
            )
        }
    }

    /// Generate the Python row class for a query's result columns.
    ///
    /// The class flavor (decorator + base) comes from `self.row_type`.
    /// An empty column list yields a `pass` body so the class still parses.
    fn generate_row_struct(
        &self,
        query_name: &str,
        columns: &[ResolvedColumn],
    ) -> Result<String, ScytheError> {
        let struct_name = row_struct_name(query_name, &self.manifest.naming);
        let mut out = String::new();
        let _ = write!(out, "{}", self.row_type.decorator());
        let _ = writeln!(out, "{}", self.row_type.class_def(&struct_name));
        let _ = writeln!(out, "    \"\"\"Row type for {} query.\"\"\"", query_name);
        if columns.is_empty() {
            let _ = writeln!(out, "    pass");
        } else {
            // Blank line between the docstring and the field annotations.
            let _ = writeln!(out);
            for col in columns {
                let _ = writeln!(out, "    {}: {}", col.field_name, col.full_type);
            }
        }
        Ok(out)
    }

    /// Generate the row class for a whole table: singularize the table
    /// name, PascalCase it, then delegate to `generate_row_struct`.
    fn generate_model_struct(
        &self,
        table_name: &str,
        columns: &[ResolvedColumn],
    ) -> Result<String, ScytheError> {
        let singular = singularize(table_name);
        let name = to_pascal_case(&singular);
        self.generate_row_struct(&name, columns)
    }

    /// Generate the async Python function for one analyzed query.
    ///
    /// Emitted shape depends on the query command:
    /// * `One`        — fetch a single row, returning `Struct | None`;
    /// * `Many`       — fetch all rows into `list[Struct]`;
    /// * `Batch`      — a `<name>_batch` function using `executemany`
    ///                  (or a plain repeat loop when there are no params,
    ///                  in which case `items` degrades to an `int` count);
    /// * `Exec`       — execute, return `None`;
    /// * `ExecResult`/`ExecRows` — execute, return `cur.rowcount`;
    /// * `Grouped`    — unreachable here (rewritten to `Many` upstream).
    ///
    /// Parameters become keyword-only arguments and are passed to psycopg
    /// as a dict keyed by the `%(name)s` placeholders produced by
    /// `rewrite_params_named`. Generated `return` statements stay on one
    /// line when they fit in 88 columns (Black's default line length),
    /// otherwise they are wrapped.
    fn generate_query_fn(
        &self,
        analyzed: &AnalyzedQuery,
        struct_name: &str,
        columns: &[ResolvedColumn],
        params: &[ResolvedParam],
    ) -> Result<String, ScytheError> {
        let func_name = fn_name(&analyzed.name, &self.manifest.naming);
        let mut out = String::new();

        // Python signature fragment: "name: type, name: type, ...".
        let param_list = params
            .iter()
            .map(|p| format!("{}: {}", p.field_name, p.full_type))
            .collect::<Vec<_>>()
            .join(", ");
        // ", *, " makes every query parameter keyword-only after `conn`.
        let kw_sep = if param_list.is_empty() { "" } else { ", *, " };

        // Strip optional-parameter scaffolding, then convert $N positional
        // placeholders into %(name)s named placeholders.
        let sql_clean = super::clean_sql_with_optional(
            &analyzed.sql,
            &analyzed.optional_params,
            &analyzed.params,
        );
        let sql = rewrite_params_named(&sql_clean, analyzed);

        match &analyzed.command {
            QueryCommand::One => {
                let _ = writeln!(
                    out,
                    "async def {}(conn: AsyncConnection{}{}) -> {} | None:",
                    func_name, kw_sep, param_list, struct_name
                );
                let _ = writeln!(out, "    \"\"\"Execute {} query.\"\"\"", analyzed.name);
                if params.is_empty() {
                    let _ = writeln!(out, "    cur = await conn.execute(");
                    let _ = writeln!(out, "        \"\"\"{}\"\"\",", sql);
                    let _ = writeln!(out, "    )");
                } else {
                    // "name": name dict entries matching the named placeholders.
                    let dict_entries: Vec<String> = params
                        .iter()
                        .map(|p| format!("\"{}\": {}", p.field_name, p.field_name))
                        .collect();
                    let _ = writeln!(out, "    cur = await conn.execute(");
                    let _ = writeln!(out, "        \"\"\"{}\"\"\",", sql);
                    let _ = writeln!(out, "        {{{}}},", dict_entries.join(", "));
                    let _ = writeln!(out, "    )");
                }
                let _ = writeln!(out, "    row = await cur.fetchone()");
                let _ = writeln!(out, "    if row is None:");
                let _ = writeln!(out, "        return None");
                // Positional tuple row -> keyword constructor arguments.
                let field_assignments: Vec<String> = columns
                    .iter()
                    .enumerate()
                    .map(|(i, col)| format!("{}=row[{}]", col.field_name, i))
                    .collect();
                let oneliner = format!(
                    "    return {}({})",
                    struct_name,
                    field_assignments.join(", ")
                );
                if oneliner.len() <= 88 {
                    let _ = writeln!(out, "{}", oneliner);
                } else {
                    let _ = writeln!(out, "    return {}(", struct_name);
                    for fa in &field_assignments {
                        let _ = writeln!(out, "        {},", fa);
                    }
                    let _ = writeln!(out, "    )");
                }
            }
            QueryCommand::Many => {
                let _ = writeln!(
                    out,
                    "async def {}(conn: AsyncConnection{}{}) -> list[{}]:",
                    func_name, kw_sep, param_list, struct_name
                );
                let _ = writeln!(out, "    \"\"\"Execute {} query.\"\"\"", analyzed.name);
                if params.is_empty() {
                    let _ = writeln!(out, "    cur = await conn.execute(");
                    let _ = writeln!(out, "        \"\"\"{}\"\"\",", sql);
                    let _ = writeln!(out, "    )");
                } else {
                    let dict_entries: Vec<String> = params
                        .iter()
                        .map(|p| format!("\"{}\": {}", p.field_name, p.field_name))
                        .collect();
                    let _ = writeln!(out, "    cur = await conn.execute(");
                    let _ = writeln!(out, "        \"\"\"{}\"\"\",", sql);
                    let _ = writeln!(out, "        {{{}}},", dict_entries.join(", "));
                    let _ = writeln!(out, "    )");
                }
                let _ = writeln!(out, "    rows = await cur.fetchall()");
                let field_assignments: Vec<String> = columns
                    .iter()
                    .enumerate()
                    .map(|(i, col)| format!("{}=r[{}]", col.field_name, i))
                    .collect();
                // List comprehension on one line if it fits, else wrapped.
                let oneliner = format!(
                    "    return [{}({}) for r in rows]",
                    struct_name,
                    field_assignments.join(", ")
                );
                if oneliner.len() <= 88 {
                    let _ = writeln!(out, "{}", oneliner);
                } else {
                    let _ = writeln!(out, "    return [");
                    let _ = writeln!(out, "        {}(", struct_name);
                    for fa in &field_assignments {
                        let _ = writeln!(out, "            {},", fa);
                    }
                    let _ = writeln!(out, "        )");
                    let _ = writeln!(out, "        for r in rows");
                    let _ = writeln!(out, "    ]");
                }
            }
            QueryCommand::Batch => {
                let batch_fn_name = format!("{}_batch", func_name);
                // Type of the `items` argument: list of tuples for multiple
                // params, flat list for one, plain int (repeat count) for none.
                let items_type = if params.len() > 1 {
                    let tuple_types: Vec<String> =
                        params.iter().map(|p| p.full_type.clone()).collect();
                    format!("list[tuple[{}]]", tuple_types.join(", "))
                } else if params.len() == 1 {
                    format!("list[{}]", params[0].full_type)
                } else {
                    "int".to_string()
                };
                let _ = writeln!(
                    out,
                    "async def {}(conn: AsyncConnection, *, items: {}) -> None:",
                    batch_fn_name, items_type
                );
                let _ = writeln!(
                    out,
                    "    \"\"\"Execute {} query for each item in the batch.\"\"\"",
                    analyzed.name
                );
                if params.is_empty() {
                    // No params: just run the statement `items` times.
                    let _ = writeln!(out, "    for _ in range(items):");
                    let _ = writeln!(out, "        await conn.execute(");
                    let _ = writeln!(out, "            \"\"\"{}\"\"\",", sql);
                    let _ = writeln!(out, "        )");
                } else {
                    // One dict per item; single-param batches index the item
                    // directly, multi-param batches index into the tuple.
                    let dict_entries: Vec<String> = params
                        .iter()
                        .enumerate()
                        .map(|(i, p)| {
                            if params.len() == 1 {
                                format!("\"{}\": item", p.field_name)
                            } else {
                                format!("\"{}\": item[{}]", p.field_name, i)
                            }
                        })
                        .collect();
                    let _ = writeln!(
                        out,
                        "    params_list = [{{{dict}}} for item in items]",
                        dict = dict_entries.join(", ")
                    );
                    let _ = writeln!(out, "    cur = conn.cursor()");
                    let _ = writeln!(out, "    await cur.executemany(");
                    let _ = writeln!(out, "        \"\"\"{}\"\"\",", sql);
                    let _ = writeln!(out, "        params_list,");
                    let _ = writeln!(out, "    )");
                }
            }
            QueryCommand::Exec => {
                let _ = writeln!(
                    out,
                    "async def {}(conn: AsyncConnection{}{}) -> None:",
                    func_name, kw_sep, param_list
                );
                let _ = writeln!(out, "    \"\"\"Execute {} query.\"\"\"", analyzed.name);
                if params.is_empty() {
                    let _ = writeln!(out, "    await conn.execute(");
                    let _ = writeln!(out, "        \"\"\"{}\"\"\",", sql);
                    let _ = writeln!(out, "    )");
                } else {
                    let dict_entries: Vec<String> = params
                        .iter()
                        .map(|p| format!("\"{}\": {}", p.field_name, p.field_name))
                        .collect();
                    let _ = writeln!(out, "    await conn.execute(");
                    let _ = writeln!(out, "        \"\"\"{}\"\"\",", sql);
                    let _ = writeln!(out, "        {{{}}},", dict_entries.join(", "));
                    let _ = writeln!(out, "    )");
                }
            }
            QueryCommand::ExecResult | QueryCommand::ExecRows => {
                let _ = writeln!(
                    out,
                    "async def {}(conn: AsyncConnection{}{}) -> int:",
                    func_name, kw_sep, param_list
                );
                let _ = writeln!(out, "    \"\"\"Execute {} query.\"\"\"", analyzed.name);
                if params.is_empty() {
                    let _ = writeln!(out, "    cur = await conn.execute(");
                    let _ = writeln!(out, "        \"\"\"{}\"\"\",", sql);
                    let _ = writeln!(out, "    )");
                } else {
                    let dict_entries: Vec<String> = params
                        .iter()
                        .map(|p| format!("\"{}\": {}", p.field_name, p.field_name))
                        .collect();
                    let _ = writeln!(out, "    cur = await conn.execute(");
                    let _ = writeln!(out, "        \"\"\"{}\"\"\",", sql);
                    let _ = writeln!(out, "        {{{}}},", dict_entries.join(", "));
                    let _ = writeln!(out, "    )");
                }
                // psycopg exposes rowcount as a plain attribute (not awaited).
                let _ = writeln!(out, "    return cur.rowcount");
            }
            QueryCommand::Grouped => {
                unreachable!("Grouped is rewritten to Many before codegen")
            }
        }

        Ok(out)
    }

    /// Generate a Python `(str, Enum)` class mirroring a database enum.
    /// Variant names go through the manifest's naming rules; values keep
    /// the raw database spelling.
    fn generate_enum_def(&self, enum_info: &EnumInfo) -> Result<String, ScytheError> {
        let type_name = enum_type_name(&enum_info.sql_name, &self.manifest.naming);
        let mut out = String::new();
        let _ = writeln!(out, "class {}(str, Enum):", type_name);
        let _ = writeln!(
            out,
            "    \"\"\"Database enum type {}.\"\"\"",
            enum_info.sql_name
        );
        if enum_info.values.is_empty() {
            let _ = writeln!(out, "    pass");
        } else {
            let _ = writeln!(out);
            for value in &enum_info.values {
                let variant = enum_variant_name(value, &self.manifest.naming);
                let _ = writeln!(out, "    {} = \"{}\"", variant, value);
            }
        }
        Ok(out)
    }

    /// Generate a Python row-type class for a database composite type.
    /// Field types are resolved through the backend manifest; a failed
    /// resolution becomes an `InternalError`.
    fn generate_composite_def(&self, composite: &CompositeInfo) -> Result<String, ScytheError> {
        let name = to_pascal_case(&composite.sql_name);
        let mut out = String::new();
        let _ = write!(out, "{}", self.row_type.decorator());
        let _ = writeln!(out, "{}", self.row_type.class_def(&name));
        let _ = writeln!(
            out,
            "    \"\"\"Composite type {}.\"\"\"",
            composite.sql_name
        );
        if composite.fields.is_empty() {
            let _ = writeln!(out, "    pass");
        } else {
            let _ = writeln!(out);
            for field in &composite.fields {
                let py_type = resolve_type(&field.neutral_type, &self.manifest, false)
                    .map(|t| t.into_owned())
                    .map_err(|e| {
                        ScytheError::new(
                            ErrorCode::InternalError,
                            format!("composite field type error: {}", e),
                        )
                    })?;
                let _ = writeln!(out, "    {}: {}", to_snake_case(&field.name), py_type);
            }
        }
        Ok(out)
    }
}