1use std::collections::HashMap;
2use std::fmt::Write;
3use std::path::Path;
4
5use scythe_backend::manifest::{BackendManifest, load_manifest};
6use scythe_backend::naming::{
7 enum_type_name, enum_variant_name, fn_name, row_struct_name, to_pascal_case, to_snake_case,
8};
9use scythe_backend::types::resolve_type;
10
11use scythe_core::analyzer::{AnalyzedQuery, CompositeInfo, EnumInfo};
12use scythe_core::errors::{ErrorCode, ScytheError};
13use scythe_core::parser::QueryCommand;
14
15use crate::backend_trait::{CodegenBackend, ResolvedColumn, ResolvedParam};
16use crate::singularize;
17
18use super::python_common::PythonRowType;
19
/// Built-in python-asyncpg manifest, embedded into the binary at compile time.
///
/// `PythonAsyncpgBackend::new` parses this text when no
/// `backends/python-asyncpg/manifest.toml` file exists on disk.
const DEFAULT_MANIFEST_TOML: &str = include_str!("../../manifests/python-asyncpg.toml");
21
/// Code generation backend that emits Python source targeting `asyncpg`.
pub struct PythonAsyncpgBackend {
    /// Backend manifest; supplies the naming rules and the type mapping used
    /// when rendering identifiers and field types.
    manifest: BackendManifest,
    /// Flavour of Python class emitted for rows and composites; supplies the
    /// decorator, the class definition line and the matching import.
    /// Overridable through the `row_type` option.
    row_type: PythonRowType,
}
26
27impl PythonAsyncpgBackend {
28 pub fn new(engine: &str) -> Result<Self, ScytheError> {
29 match engine {
30 "postgresql" | "postgres" | "pg" => {}
31 _ => {
32 return Err(ScytheError::new(
33 ErrorCode::InternalError,
34 format!(
35 "python-asyncpg only supports PostgreSQL, got engine '{}'",
36 engine
37 ),
38 ));
39 }
40 }
41 let manifest_path = Path::new("backends/python-asyncpg/manifest.toml");
42 let manifest = if manifest_path.exists() {
43 load_manifest(manifest_path)
44 .map_err(|e| ScytheError::new(ErrorCode::InternalError, format!("manifest: {e}")))?
45 } else {
46 toml::from_str(DEFAULT_MANIFEST_TOML)
47 .map_err(|e| ScytheError::new(ErrorCode::InternalError, format!("manifest: {e}")))?
48 };
49 Ok(Self {
50 manifest,
51 row_type: PythonRowType::default(),
52 })
53 }
54}
55
56impl CodegenBackend for PythonAsyncpgBackend {
    /// Returns the identifier of this backend, `"python-asyncpg"`.
    fn name(&self) -> &str {
        "python-asyncpg"
    }
60
61 fn manifest(&self) -> &scythe_backend::manifest::BackendManifest {
62 &self.manifest
63 }
64
65 fn apply_options(&mut self, options: &HashMap<String, String>) -> Result<(), ScytheError> {
66 if let Some(rt) = options.get("row_type") {
67 self.row_type = PythonRowType::from_option(rt)?;
68 }
69 Ok(())
70 }
71
72 fn file_header(&self) -> String {
73 let import_line = self.row_type.import_line();
74 if self.row_type.is_stdlib_import() {
75 format!(
76 "\"\"\"Auto-generated by scythe. Do not edit.\"\"\"\n\
77 \n\
78 import datetime # noqa: F401\n\
79 import decimal # noqa: F401\n\
80 {import_line}\n\
81 from enum import Enum # noqa: F401\n\
82 \n\
83 from asyncpg import Connection # noqa: F401\n\
84 \n",
85 )
86 } else {
87 let third_party = self
88 .row_type
89 .sorted_third_party_imports("from asyncpg import Connection # noqa: F401");
90 format!(
91 "\"\"\"Auto-generated by scythe. Do not edit.\"\"\"\n\
92 \n\
93 import datetime # noqa: F401\n\
94 import decimal # noqa: F401\n\
95 from enum import Enum # noqa: F401\n\
96 \n\
97 {third_party}\n\
98 \n",
99 )
100 }
101 }
102
103 fn generate_row_struct(
104 &self,
105 query_name: &str,
106 columns: &[ResolvedColumn],
107 ) -> Result<String, ScytheError> {
108 let struct_name = row_struct_name(query_name, &self.manifest.naming);
109 let mut out = String::new();
110 let _ = write!(out, "{}", self.row_type.decorator());
111 let _ = writeln!(out, "{}", self.row_type.class_def(&struct_name));
112 let _ = writeln!(out, " \"\"\"Row type for {} query.\"\"\"", query_name);
113 if columns.is_empty() {
114 let _ = writeln!(out, " pass");
115 } else {
116 let _ = writeln!(out);
117 for col in columns {
118 let _ = writeln!(out, " {}: {}", col.field_name, col.full_type);
119 }
120 }
121 Ok(out)
122 }
123
124 fn generate_model_struct(
125 &self,
126 table_name: &str,
127 columns: &[ResolvedColumn],
128 ) -> Result<String, ScytheError> {
129 let singular = singularize(table_name);
130 let name = to_pascal_case(&singular);
131 self.generate_row_struct(&name, columns)
132 }
133
134 fn generate_query_fn(
135 &self,
136 analyzed: &AnalyzedQuery,
137 struct_name: &str,
138 columns: &[ResolvedColumn],
139 params: &[ResolvedParam],
140 ) -> Result<String, ScytheError> {
141 let func_name = fn_name(&analyzed.name, &self.manifest.naming);
142 let mut out = String::new();
143
144 let param_list = params
146 .iter()
147 .map(|p| format!("{}: {}", p.field_name, p.full_type))
148 .collect::<Vec<_>>()
149 .join(", ");
150 let kw_sep = if param_list.is_empty() { "" } else { ", *, " };
151
152 let sql = super::clean_sql_with_optional(
154 &analyzed.sql,
155 &analyzed.optional_params,
156 &analyzed.params,
157 );
158
159 match &analyzed.command {
160 QueryCommand::One => {
161 let _ = writeln!(
162 out,
163 "async def {}(conn: Connection{}{}) -> {} | None:",
164 func_name, kw_sep, param_list, struct_name
165 );
166 let _ = writeln!(out, " \"\"\"Execute {} query.\"\"\"", analyzed.name);
167 let _ = writeln!(out, " row = await conn.fetchrow(");
168 let _ = writeln!(out, " \"\"\"{}\"\"\",", sql);
169 if !params.is_empty() {
170 let args: Vec<String> = params.iter().map(|p| p.field_name.clone()).collect();
171 let _ = writeln!(out, " {},", args.join(", "));
172 }
173 let _ = writeln!(out, " )");
174 let _ = writeln!(out, " if row is None:");
175 let _ = writeln!(out, " return None");
176 let field_assignments: Vec<String> = columns
177 .iter()
178 .map(|col| format!("{}=row[\"{}\"]", col.field_name, col.name))
179 .collect();
180 let oneliner = format!(
181 " return {}({})",
182 struct_name,
183 field_assignments.join(", ")
184 );
185 if oneliner.len() <= 88 {
186 let _ = writeln!(out, "{}", oneliner);
187 } else {
188 let _ = writeln!(out, " return {}(", struct_name);
189 for fa in &field_assignments {
190 let _ = writeln!(out, " {},", fa);
191 }
192 let _ = writeln!(out, " )");
193 }
194 }
195 QueryCommand::Many => {
196 let _ = writeln!(
197 out,
198 "async def {}(conn: Connection{}{}) -> list[{}]:",
199 func_name, kw_sep, param_list, struct_name
200 );
201 let _ = writeln!(out, " \"\"\"Execute {} query.\"\"\"", analyzed.name);
202 let _ = writeln!(out, " rows = await conn.fetch(");
203 let _ = writeln!(out, " \"\"\"{}\"\"\",", sql);
204 if !params.is_empty() {
205 let args: Vec<String> = params.iter().map(|p| p.field_name.clone()).collect();
206 let _ = writeln!(out, " {},", args.join(", "));
207 }
208 let _ = writeln!(out, " )");
209 let field_assignments: Vec<String> = columns
210 .iter()
211 .map(|col| format!("{}=r[\"{}\"]", col.field_name, col.name))
212 .collect();
213 let oneliner = format!(
214 " return [{}({}) for r in rows]",
215 struct_name,
216 field_assignments.join(", ")
217 );
218 if oneliner.len() <= 88 {
219 let _ = writeln!(out, "{}", oneliner);
220 } else {
221 let _ = writeln!(out, " return [");
222 let _ = writeln!(out, " {}(", struct_name);
223 for fa in &field_assignments {
224 let _ = writeln!(out, " {},", fa);
225 }
226 let _ = writeln!(out, " )");
227 let _ = writeln!(out, " for r in rows");
228 let _ = writeln!(out, " ]");
229 }
230 }
231 QueryCommand::Batch => {
232 let batch_fn_name = format!("{}_batch", func_name);
233 let items_type = if params.len() > 1 {
234 let tuple_types: Vec<String> =
235 params.iter().map(|p| p.full_type.clone()).collect();
236 format!("list[tuple[{}]]", tuple_types.join(", "))
237 } else if params.len() == 1 {
238 format!("list[{}]", params[0].full_type)
239 } else {
240 "int".to_string()
241 };
242 let _ = writeln!(
243 out,
244 "async def {}(conn: Connection, *, items: {}) -> None:",
245 batch_fn_name, items_type
246 );
247 let _ = writeln!(
248 out,
249 " \"\"\"Execute {} query for each item in the batch.\"\"\"",
250 analyzed.name
251 );
252 if params.is_empty() {
253 let _ = writeln!(out, " for _ in range(items):");
254 let _ = writeln!(out, " await conn.execute(");
255 let _ = writeln!(out, " \"\"\"{}\"\"\",", sql);
256 let _ = writeln!(out, " )");
257 } else {
258 if params.len() == 1 {
260 let _ = writeln!(out, " args = [(item,) for item in items]");
261 } else {
262 let _ = writeln!(out, " args = items");
263 }
264 let _ = writeln!(out, " await conn.executemany(");
265 let _ = writeln!(out, " \"\"\"{}\"\"\",", sql);
266 let _ = writeln!(out, " args,");
267 let _ = writeln!(out, " )");
268 }
269 }
270 QueryCommand::Exec => {
271 let _ = writeln!(
272 out,
273 "async def {}(conn: Connection{}{}) -> None:",
274 func_name, kw_sep, param_list
275 );
276 let _ = writeln!(out, " \"\"\"Execute {} query.\"\"\"", analyzed.name);
277 let _ = writeln!(out, " await conn.execute(");
278 let _ = writeln!(out, " \"\"\"{}\"\"\",", sql);
279 if !params.is_empty() {
280 let args: Vec<String> = params.iter().map(|p| p.field_name.clone()).collect();
281 let _ = writeln!(out, " {},", args.join(", "));
282 }
283 let _ = writeln!(out, " )");
284 }
285 QueryCommand::ExecResult | QueryCommand::ExecRows => {
286 let _ = writeln!(
287 out,
288 "async def {}(conn: Connection{}{}) -> int:",
289 func_name, kw_sep, param_list
290 );
291 let _ = writeln!(out, " \"\"\"Execute {} query.\"\"\"", analyzed.name);
292 let _ = writeln!(out, " result = await conn.execute(");
293 let _ = writeln!(out, " \"\"\"{}\"\"\",", sql);
294 if !params.is_empty() {
295 let args: Vec<String> = params.iter().map(|p| p.field_name.clone()).collect();
296 let _ = writeln!(out, " {},", args.join(", "));
297 }
298 let _ = writeln!(out, " )");
299 let _ = writeln!(out, " return int(result.split()[-1])");
300 }
301 QueryCommand::Grouped => {
302 unreachable!("Grouped is rewritten to Many before codegen")
303 }
304 }
305
306 Ok(out)
307 }
308
309 fn generate_enum_def(&self, enum_info: &EnumInfo) -> Result<String, ScytheError> {
310 let type_name = enum_type_name(&enum_info.sql_name, &self.manifest.naming);
311 let mut out = String::new();
312 let _ = writeln!(out, "class {}(str, Enum):", type_name);
313 let _ = writeln!(
314 out,
315 " \"\"\"Database enum type {}.\"\"\"",
316 enum_info.sql_name
317 );
318 if enum_info.values.is_empty() {
319 let _ = writeln!(out, " pass");
320 } else {
321 let _ = writeln!(out);
322 for value in &enum_info.values {
323 let variant = enum_variant_name(value, &self.manifest.naming);
324 let _ = writeln!(out, " {} = \"{}\"", variant, value);
325 }
326 }
327 Ok(out)
328 }
329
330 fn generate_composite_def(&self, composite: &CompositeInfo) -> Result<String, ScytheError> {
331 let name = to_pascal_case(&composite.sql_name);
332 let mut out = String::new();
333 let _ = write!(out, "{}", self.row_type.decorator());
334 let _ = writeln!(out, "{}", self.row_type.class_def(&name));
335 let _ = writeln!(
336 out,
337 " \"\"\"Composite type {}.\"\"\"",
338 composite.sql_name
339 );
340 if composite.fields.is_empty() {
341 let _ = writeln!(out, " pass");
342 } else {
343 let _ = writeln!(out);
344 for field in &composite.fields {
345 let py_type = resolve_type(&field.neutral_type, &self.manifest, false)
346 .map(|t| t.into_owned())
347 .map_err(|e| {
348 ScytheError::new(
349 ErrorCode::InternalError,
350 format!("composite field type error: {}", e),
351 )
352 })?;
353 let _ = writeln!(out, " {}: {}", to_snake_case(&field.name), py_type);
354 }
355 }
356 Ok(out)
357 }
358}