1use std::fmt::Write;
2use std::path::Path;
3
4use scythe_backend::manifest::{BackendManifest, load_manifest};
5use scythe_backend::naming::{
6 enum_type_name, enum_variant_name, fn_name, row_struct_name, to_pascal_case, to_snake_case,
7};
8use scythe_backend::types::resolve_type;
9
10use scythe_core::analyzer::{AnalyzedQuery, CompositeInfo, EnumInfo};
11use scythe_core::errors::{ErrorCode, ScytheError};
12use scythe_core::parser::QueryCommand;
13
14use crate::backend_trait::{CodegenBackend, ResolvedColumn, ResolvedParam};
15use crate::singularize;
16
17const DEFAULT_MANIFEST_TOML: &str = include_str!("../../manifests/python-psycopg3.toml");
18
19pub struct PythonPsycopg3Backend {
20 manifest: BackendManifest,
21}
22
23impl PythonPsycopg3Backend {
24 pub fn new() -> Result<Self, ScytheError> {
25 let manifest_path = Path::new("backends/python-psycopg3/manifest.toml");
26 let manifest = if manifest_path.exists() {
27 load_manifest(manifest_path)
28 .map_err(|e| ScytheError::new(ErrorCode::InternalError, format!("manifest: {e}")))?
29 } else {
30 toml::from_str(DEFAULT_MANIFEST_TOML)
31 .map_err(|e| ScytheError::new(ErrorCode::InternalError, format!("manifest: {e}")))?
32 };
33 Ok(Self { manifest })
34 }
35
36 pub fn manifest(&self) -> &BackendManifest {
37 &self.manifest
38 }
39}
40
41fn rewrite_params_named(sql: &str, analyzed: &AnalyzedQuery) -> String {
43 let mut result = sql.to_string();
44 let mut params_sorted: Vec<_> = analyzed.params.iter().collect();
46 params_sorted.sort_by(|a, b| b.position.cmp(&a.position));
47 for param in params_sorted {
48 let placeholder = format!("${}", param.position);
49 let named = format!("%({})s", to_snake_case(¶m.name));
50 result = result.replace(&placeholder, &named);
51 }
52 result
53}
54
55impl CodegenBackend for PythonPsycopg3Backend {
56 fn name(&self) -> &str {
57 "python-psycopg3"
58 }
59
60 fn file_header(&self) -> String {
61 "\"\"\"Auto-generated by scythe. Do not edit.\"\"\"\n\
62 \n\
63 import datetime # noqa: F401\n\
64 from dataclasses import dataclass\n\
65 from enum import Enum # noqa: F401\n\
66 \n\
67 from psycopg import AsyncConnection # noqa: F401\n\
68 \n"
69 .to_string()
70 }
71
72 fn generate_row_struct(
73 &self,
74 query_name: &str,
75 columns: &[ResolvedColumn],
76 ) -> Result<String, ScytheError> {
77 let struct_name = row_struct_name(query_name, &self.manifest.naming);
78 let mut out = String::new();
79 let _ = writeln!(out, "@dataclass");
80 let _ = writeln!(out, "class {}:", struct_name);
81 let _ = writeln!(out, " \"\"\"Row type for {} query.\"\"\"", query_name);
82 if columns.is_empty() {
83 let _ = writeln!(out, " pass");
84 } else {
85 let _ = writeln!(out);
86 for col in columns {
87 let _ = writeln!(out, " {}: {}", col.field_name, col.full_type);
88 }
89 }
90 Ok(out)
91 }
92
93 fn generate_model_struct(
94 &self,
95 table_name: &str,
96 columns: &[ResolvedColumn],
97 ) -> Result<String, ScytheError> {
98 let singular = singularize(table_name);
99 let name = to_pascal_case(&singular);
100 self.generate_row_struct(&name, columns)
101 }
102
103 fn generate_query_fn(
104 &self,
105 analyzed: &AnalyzedQuery,
106 struct_name: &str,
107 columns: &[ResolvedColumn],
108 params: &[ResolvedParam],
109 ) -> Result<String, ScytheError> {
110 let func_name = fn_name(&analyzed.name, &self.manifest.naming);
111 let mut out = String::new();
112
113 let param_list = params
115 .iter()
116 .map(|p| format!("{}: {}", p.field_name, p.full_type))
117 .collect::<Vec<_>>()
118 .join(", ");
119 let kw_sep = if param_list.is_empty() { "" } else { ", *, " };
120
121 let sql_clean = super::clean_sql(&analyzed.sql);
123 let sql = rewrite_params_named(&sql_clean, analyzed);
124
125 match &analyzed.command {
126 QueryCommand::One => {
127 let _ = writeln!(
128 out,
129 "async def {}(conn: AsyncConnection{}{}) -> {} | None:",
130 func_name, kw_sep, param_list, struct_name
131 );
132 let _ = writeln!(out, " \"\"\"Execute {} query.\"\"\"", analyzed.name);
133 if params.is_empty() {
135 let _ = writeln!(out, " cur = await conn.execute(");
136 let _ = writeln!(out, " \"{}\",", sql);
137 let _ = writeln!(out, " )");
138 } else {
139 let dict_entries: Vec<String> = params
140 .iter()
141 .map(|p| format!("\"{}\": {}", p.field_name, p.field_name))
142 .collect();
143 let _ = writeln!(out, " cur = await conn.execute(");
144 let _ = writeln!(out, " \"{}\",", sql);
145 let _ = writeln!(out, " {{{}}},", dict_entries.join(", "));
146 let _ = writeln!(out, " )");
147 }
148 let _ = writeln!(out, " row = await cur.fetchone()");
149 let _ = writeln!(out, " if row is None:");
150 let _ = writeln!(out, " return None");
151 let field_assignments: Vec<String> = columns
153 .iter()
154 .enumerate()
155 .map(|(i, col)| format!("{}=row[{}]", col.field_name, i))
156 .collect();
157 let oneliner = format!(
158 " return {}({})",
159 struct_name,
160 field_assignments.join(", ")
161 );
162 if oneliner.len() <= 88 {
163 let _ = writeln!(out, "{}", oneliner);
164 } else {
165 let _ = writeln!(out, " return {}(", struct_name);
166 for fa in &field_assignments {
167 let _ = writeln!(out, " {},", fa);
168 }
169 let _ = writeln!(out, " )");
170 }
171 }
172 QueryCommand::Many | QueryCommand::Batch => {
173 let _ = writeln!(
174 out,
175 "async def {}(conn: AsyncConnection{}{}) -> list[{}]:",
176 func_name, kw_sep, param_list, struct_name
177 );
178 let _ = writeln!(out, " \"\"\"Execute {} query.\"\"\"", analyzed.name);
179 if params.is_empty() {
180 let _ = writeln!(out, " cur = await conn.execute(");
181 let _ = writeln!(out, " \"{}\",", sql);
182 let _ = writeln!(out, " )");
183 } else {
184 let dict_entries: Vec<String> = params
185 .iter()
186 .map(|p| format!("\"{}\": {}", p.field_name, p.field_name))
187 .collect();
188 let _ = writeln!(out, " cur = await conn.execute(");
189 let _ = writeln!(out, " \"{}\",", sql);
190 let _ = writeln!(out, " {{{}}},", dict_entries.join(", "));
191 let _ = writeln!(out, " )");
192 }
193 let _ = writeln!(out, " rows = await cur.fetchall()");
194 let field_assignments: Vec<String> = columns
195 .iter()
196 .enumerate()
197 .map(|(i, col)| format!("{}=r[{}]", col.field_name, i))
198 .collect();
199 let oneliner = format!(
200 " return [{}({}) for r in rows]",
201 struct_name,
202 field_assignments.join(", ")
203 );
204 if oneliner.len() <= 88 {
205 let _ = writeln!(out, "{}", oneliner);
206 } else {
207 let _ = writeln!(out, " return [");
208 let _ = writeln!(out, " {}(", struct_name);
209 for fa in &field_assignments {
210 let _ = writeln!(out, " {},", fa);
211 }
212 let _ = writeln!(out, " )");
213 let _ = writeln!(out, " for r in rows");
214 let _ = writeln!(out, " ]");
215 }
216 }
217 QueryCommand::Exec | QueryCommand::ExecResult | QueryCommand::ExecRows => {
218 let _ = writeln!(
219 out,
220 "async def {}(conn: AsyncConnection{}{}) -> None:",
221 func_name, kw_sep, param_list
222 );
223 let _ = writeln!(out, " \"\"\"Execute {} query.\"\"\"", analyzed.name);
224 if params.is_empty() {
225 let _ = writeln!(out, " await conn.execute(");
226 let _ = writeln!(out, " \"{}\",", sql);
227 let _ = writeln!(out, " )");
228 } else {
229 let dict_entries: Vec<String> = params
230 .iter()
231 .map(|p| format!("\"{}\": {}", p.field_name, p.field_name))
232 .collect();
233 let _ = writeln!(out, " await conn.execute(");
234 let _ = writeln!(out, " \"{}\",", sql);
235 let _ = writeln!(out, " {{{}}},", dict_entries.join(", "));
236 let _ = writeln!(out, " )");
237 }
238 }
239 }
240
241 Ok(out)
242 }
243
244 fn generate_enum_def(&self, enum_info: &EnumInfo) -> Result<String, ScytheError> {
245 let type_name = enum_type_name(&enum_info.sql_name, &self.manifest.naming);
246 let mut out = String::new();
247 let _ = writeln!(out, "class {}(str, Enum):", type_name);
248 let _ = writeln!(
249 out,
250 " \"\"\"Database enum type {}.\"\"\"",
251 enum_info.sql_name
252 );
253 if enum_info.values.is_empty() {
254 let _ = writeln!(out, " pass");
255 } else {
256 let _ = writeln!(out);
257 for value in &enum_info.values {
258 let variant = enum_variant_name(value, &self.manifest.naming);
259 let _ = writeln!(out, " {} = \"{}\"", variant, value);
260 }
261 }
262 Ok(out)
263 }
264
265 fn generate_composite_def(&self, composite: &CompositeInfo) -> Result<String, ScytheError> {
266 let name = to_pascal_case(&composite.sql_name);
267 let mut out = String::new();
268 let _ = writeln!(out, "@dataclass");
269 let _ = writeln!(out, "class {}:", name);
270 let _ = writeln!(
271 out,
272 " \"\"\"Composite type {}.\"\"\"",
273 composite.sql_name
274 );
275 if composite.fields.is_empty() {
276 let _ = writeln!(out, " pass");
277 } else {
278 let _ = writeln!(out);
279 for field in &composite.fields {
280 let py_type = resolve_type(&field.neutral_type, &self.manifest, false)
281 .map(|t| t.into_owned())
282 .map_err(|e| {
283 ScytheError::new(
284 ErrorCode::InternalError,
285 format!("composite field type error: {}", e),
286 )
287 })?;
288 let _ = writeln!(out, " {}: {}", to_snake_case(&field.name), py_type);
289 }
290 }
291 Ok(out)
292 }
293}