1use std::fmt::Write;
2use std::path::Path;
3
4use scythe_backend::manifest::{BackendManifest, load_manifest};
5use scythe_backend::naming::{
6 enum_type_name, enum_variant_name, fn_name, row_struct_name, to_pascal_case, to_snake_case,
7};
8use scythe_backend::types::resolve_type;
9
10use scythe_core::analyzer::{AnalyzedQuery, CompositeInfo, EnumInfo};
11use scythe_core::errors::{ErrorCode, ScytheError};
12use scythe_core::parser::QueryCommand;
13
14use crate::backend_trait::{CodegenBackend, ResolvedColumn, ResolvedParam};
15use crate::singularize;
16
17const DEFAULT_MANIFEST_TOML: &str =
18 include_str!("../../manifests/python-psycopg3.toml");
19
20pub struct PythonPsycopg3Backend {
21 manifest: BackendManifest,
22}
23
24impl PythonPsycopg3Backend {
25 pub fn new() -> Result<Self, ScytheError> {
26 let manifest_path = Path::new("backends/python-psycopg3/manifest.toml");
27 let manifest = if manifest_path.exists() {
28 load_manifest(manifest_path)
29 .map_err(|e| ScytheError::new(ErrorCode::InternalError, format!("manifest: {e}")))?
30 } else {
31 toml::from_str(DEFAULT_MANIFEST_TOML)
32 .map_err(|e| ScytheError::new(ErrorCode::InternalError, format!("manifest: {e}")))?
33 };
34 Ok(Self { manifest })
35 }
36
37 pub fn manifest(&self) -> &BackendManifest {
38 &self.manifest
39 }
40}
41
42fn rewrite_params_named(sql: &str, analyzed: &AnalyzedQuery) -> String {
44 let mut result = sql.to_string();
45 let mut params_sorted: Vec<_> = analyzed.params.iter().collect();
47 params_sorted.sort_by(|a, b| b.position.cmp(&a.position));
48 for param in params_sorted {
49 let placeholder = format!("${}", param.position);
50 let named = format!("%({})s", to_snake_case(¶m.name));
51 result = result.replace(&placeholder, &named);
52 }
53 result
54}
55
56impl CodegenBackend for PythonPsycopg3Backend {
57 fn name(&self) -> &str {
58 "python-psycopg3"
59 }
60
61 fn file_header(&self) -> String {
62 "\"\"\"Auto-generated by scythe. Do not edit.\"\"\"\n\
63 \n\
64 import datetime # noqa: F401\n\
65 from dataclasses import dataclass\n\
66 from enum import Enum # noqa: F401\n\
67 \n\
68 from psycopg import AsyncConnection # noqa: F401\n\
69 \n"
70 .to_string()
71 }
72
73 fn generate_row_struct(
74 &self,
75 query_name: &str,
76 columns: &[ResolvedColumn],
77 ) -> Result<String, ScytheError> {
78 let struct_name = row_struct_name(query_name, &self.manifest.naming);
79 let mut out = String::new();
80 let _ = writeln!(out, "@dataclass");
81 let _ = writeln!(out, "class {}:", struct_name);
82 let _ = writeln!(out, " \"\"\"Row type for {} query.\"\"\"", query_name);
83 if columns.is_empty() {
84 let _ = writeln!(out, " pass");
85 } else {
86 let _ = writeln!(out);
87 for col in columns {
88 let _ = writeln!(out, " {}: {}", col.field_name, col.full_type);
89 }
90 }
91 Ok(out)
92 }
93
94 fn generate_model_struct(
95 &self,
96 table_name: &str,
97 columns: &[ResolvedColumn],
98 ) -> Result<String, ScytheError> {
99 let singular = singularize(table_name);
100 let name = to_pascal_case(&singular);
101 self.generate_row_struct(&name, columns)
102 }
103
104 fn generate_query_fn(
105 &self,
106 analyzed: &AnalyzedQuery,
107 struct_name: &str,
108 columns: &[ResolvedColumn],
109 params: &[ResolvedParam],
110 ) -> Result<String, ScytheError> {
111 let func_name = fn_name(&analyzed.name, &self.manifest.naming);
112 let mut out = String::new();
113
114 let param_list = params
116 .iter()
117 .map(|p| format!("{}: {}", p.field_name, p.full_type))
118 .collect::<Vec<_>>()
119 .join(", ");
120 let kw_sep = if param_list.is_empty() { "" } else { ", *, " };
121
122 let sql_clean = super::clean_sql(&analyzed.sql);
124 let sql = rewrite_params_named(&sql_clean, analyzed);
125
126 match &analyzed.command {
127 QueryCommand::One => {
128 let _ = writeln!(
129 out,
130 "async def {}(conn: AsyncConnection{}{}) -> {} | None:",
131 func_name, kw_sep, param_list, struct_name
132 );
133 let _ = writeln!(out, " \"\"\"Execute {} query.\"\"\"", analyzed.name);
134 if params.is_empty() {
136 let _ = writeln!(out, " cur = await conn.execute(");
137 let _ = writeln!(out, " \"{}\",", sql);
138 let _ = writeln!(out, " )");
139 } else {
140 let dict_entries: Vec<String> = params
141 .iter()
142 .map(|p| format!("\"{}\": {}", p.field_name, p.field_name))
143 .collect();
144 let _ = writeln!(out, " cur = await conn.execute(");
145 let _ = writeln!(out, " \"{}\",", sql);
146 let _ = writeln!(out, " {{{}}},", dict_entries.join(", "));
147 let _ = writeln!(out, " )");
148 }
149 let _ = writeln!(out, " row = await cur.fetchone()");
150 let _ = writeln!(out, " if row is None:");
151 let _ = writeln!(out, " return None");
152 let field_assignments: Vec<String> = columns
154 .iter()
155 .enumerate()
156 .map(|(i, col)| format!("{}=row[{}]", col.field_name, i))
157 .collect();
158 let oneliner = format!(
159 " return {}({})",
160 struct_name,
161 field_assignments.join(", ")
162 );
163 if oneliner.len() <= 88 {
164 let _ = writeln!(out, "{}", oneliner);
165 } else {
166 let _ = writeln!(out, " return {}(", struct_name);
167 for fa in &field_assignments {
168 let _ = writeln!(out, " {},", fa);
169 }
170 let _ = writeln!(out, " )");
171 }
172 }
173 QueryCommand::Many | QueryCommand::Batch => {
174 let _ = writeln!(
175 out,
176 "async def {}(conn: AsyncConnection{}{}) -> list[{}]:",
177 func_name, kw_sep, param_list, struct_name
178 );
179 let _ = writeln!(out, " \"\"\"Execute {} query.\"\"\"", analyzed.name);
180 if params.is_empty() {
181 let _ = writeln!(out, " cur = await conn.execute(");
182 let _ = writeln!(out, " \"{}\",", sql);
183 let _ = writeln!(out, " )");
184 } else {
185 let dict_entries: Vec<String> = params
186 .iter()
187 .map(|p| format!("\"{}\": {}", p.field_name, p.field_name))
188 .collect();
189 let _ = writeln!(out, " cur = await conn.execute(");
190 let _ = writeln!(out, " \"{}\",", sql);
191 let _ = writeln!(out, " {{{}}},", dict_entries.join(", "));
192 let _ = writeln!(out, " )");
193 }
194 let _ = writeln!(out, " rows = await cur.fetchall()");
195 let field_assignments: Vec<String> = columns
196 .iter()
197 .enumerate()
198 .map(|(i, col)| format!("{}=r[{}]", col.field_name, i))
199 .collect();
200 let oneliner = format!(
201 " return [{}({}) for r in rows]",
202 struct_name,
203 field_assignments.join(", ")
204 );
205 if oneliner.len() <= 88 {
206 let _ = writeln!(out, "{}", oneliner);
207 } else {
208 let _ = writeln!(out, " return [");
209 let _ = writeln!(out, " {}(", struct_name);
210 for fa in &field_assignments {
211 let _ = writeln!(out, " {},", fa);
212 }
213 let _ = writeln!(out, " )");
214 let _ = writeln!(out, " for r in rows");
215 let _ = writeln!(out, " ]");
216 }
217 }
218 QueryCommand::Exec | QueryCommand::ExecResult | QueryCommand::ExecRows => {
219 let _ = writeln!(
220 out,
221 "async def {}(conn: AsyncConnection{}{}) -> None:",
222 func_name, kw_sep, param_list
223 );
224 let _ = writeln!(out, " \"\"\"Execute {} query.\"\"\"", analyzed.name);
225 if params.is_empty() {
226 let _ = writeln!(out, " await conn.execute(");
227 let _ = writeln!(out, " \"{}\",", sql);
228 let _ = writeln!(out, " )");
229 } else {
230 let dict_entries: Vec<String> = params
231 .iter()
232 .map(|p| format!("\"{}\": {}", p.field_name, p.field_name))
233 .collect();
234 let _ = writeln!(out, " await conn.execute(");
235 let _ = writeln!(out, " \"{}\",", sql);
236 let _ = writeln!(out, " {{{}}},", dict_entries.join(", "));
237 let _ = writeln!(out, " )");
238 }
239 }
240 }
241
242 Ok(out)
243 }
244
245 fn generate_enum_def(&self, enum_info: &EnumInfo) -> Result<String, ScytheError> {
246 let type_name = enum_type_name(&enum_info.sql_name, &self.manifest.naming);
247 let mut out = String::new();
248 let _ = writeln!(out, "class {}(str, Enum):", type_name);
249 let _ = writeln!(
250 out,
251 " \"\"\"Database enum type {}.\"\"\"",
252 enum_info.sql_name
253 );
254 if enum_info.values.is_empty() {
255 let _ = writeln!(out, " pass");
256 } else {
257 let _ = writeln!(out);
258 for value in &enum_info.values {
259 let variant = enum_variant_name(value, &self.manifest.naming);
260 let _ = writeln!(out, " {} = \"{}\"", variant, value);
261 }
262 }
263 Ok(out)
264 }
265
266 fn generate_composite_def(&self, composite: &CompositeInfo) -> Result<String, ScytheError> {
267 let name = to_pascal_case(&composite.sql_name);
268 let mut out = String::new();
269 let _ = writeln!(out, "@dataclass");
270 let _ = writeln!(out, "class {}:", name);
271 let _ = writeln!(
272 out,
273 " \"\"\"Composite type {}.\"\"\"",
274 composite.sql_name
275 );
276 if composite.fields.is_empty() {
277 let _ = writeln!(out, " pass");
278 } else {
279 let _ = writeln!(out);
280 for field in &composite.fields {
281 let py_type = resolve_type(&field.neutral_type, &self.manifest, false)
282 .map(|t| t.into_owned())
283 .map_err(|e| {
284 ScytheError::new(
285 ErrorCode::InternalError,
286 format!("composite field type error: {}", e),
287 )
288 })?;
289 let _ = writeln!(out, " {}: {}", to_snake_case(&field.name), py_type);
290 }
291 }
292 Ok(out)
293 }
294}