1use std::fmt::Write;
2use std::path::Path;
3
4use scythe_backend::manifest::{BackendManifest, load_manifest};
5use scythe_backend::naming::{
6 enum_type_name, enum_variant_name, fn_name, row_struct_name, to_pascal_case, to_snake_case,
7};
8use scythe_backend::types::resolve_type;
9
10use scythe_core::analyzer::{AnalyzedQuery, CompositeInfo, EnumInfo};
11use scythe_core::errors::{ErrorCode, ScytheError};
12use scythe_core::parser::QueryCommand;
13
14use crate::backend_trait::{CodegenBackend, ResolvedColumn, ResolvedParam};
15use crate::singularize;
16
/// Manifest embedded at compile time; used as a fallback by `PythonAsyncpgBackend::new`
/// when the on-disk copy (`backends/python-asyncpg/manifest.toml`) does not exist.
const DEFAULT_MANIFEST_TOML: &str = include_str!("../../manifests/python-asyncpg.toml");
18
/// Code-generation backend that emits asyncpg-based async Python for PostgreSQL.
pub struct PythonAsyncpgBackend {
    /// Type-mapping and naming rules; loaded from disk or the embedded default in `new`.
    manifest: BackendManifest,
}
22
23impl PythonAsyncpgBackend {
24 pub fn new(engine: &str) -> Result<Self, ScytheError> {
25 match engine {
26 "postgresql" | "postgres" | "pg" => {}
27 _ => {
28 return Err(ScytheError::new(
29 ErrorCode::InternalError,
30 format!(
31 "python-asyncpg only supports PostgreSQL, got engine '{}'",
32 engine
33 ),
34 ));
35 }
36 }
37 let manifest_path = Path::new("backends/python-asyncpg/manifest.toml");
38 let manifest = if manifest_path.exists() {
39 load_manifest(manifest_path)
40 .map_err(|e| ScytheError::new(ErrorCode::InternalError, format!("manifest: {e}")))?
41 } else {
42 toml::from_str(DEFAULT_MANIFEST_TOML)
43 .map_err(|e| ScytheError::new(ErrorCode::InternalError, format!("manifest: {e}")))?
44 };
45 Ok(Self { manifest })
46 }
47}
48
impl CodegenBackend for PythonAsyncpgBackend {
    /// Stable identifier for this backend.
    fn name(&self) -> &str {
        "python-asyncpg"
    }

    /// The type-mapping/naming manifest this backend was constructed with.
    fn manifest(&self) -> &scythe_backend::manifest::BackendManifest {
        &self.manifest
    }

    /// Static preamble written at the top of every generated Python module.
    ///
    /// `# noqa: F401` suppresses unused-import lint warnings, since not every
    /// generated module necessarily uses all of these imports.
    fn file_header(&self) -> String {
        "\"\"\"Auto-generated by scythe. Do not edit.\"\"\"\n\
         \n\
         import datetime  # noqa: F401\n\
         import decimal  # noqa: F401\n\
         from dataclasses import dataclass\n\
         from enum import Enum  # noqa: F401\n\
         \n\
         from asyncpg import Connection  # noqa: F401\n\
         \n"
        .to_string()
    }

    /// Emits a Python `@dataclass` describing one result row of `query_name`.
    ///
    /// An empty column list produces a `pass` body so the class stays valid
    /// Python. `writeln!` into a `String` cannot fail, so results are
    /// intentionally discarded with `let _ =` throughout this impl.
    fn generate_row_struct(
        &self,
        query_name: &str,
        columns: &[ResolvedColumn],
    ) -> Result<String, ScytheError> {
        let struct_name = row_struct_name(query_name, &self.manifest.naming);
        let mut out = String::new();
        let _ = writeln!(out, "@dataclass");
        let _ = writeln!(out, "class {}:", struct_name);
        let _ = writeln!(out, "    \"\"\"Row type for {} query.\"\"\"", query_name);
        if columns.is_empty() {
            let _ = writeln!(out, "    pass");
        } else {
            // Blank line between the docstring and the first field.
            let _ = writeln!(out);
            for col in columns {
                let _ = writeln!(out, "    {}: {}", col.field_name, col.full_type);
            }
        }
        Ok(out)
    }

    /// Emits a `@dataclass` for a whole table, named from the singularized,
    /// PascalCased table name (e.g. `users` -> `User`); delegates the body to
    /// `generate_row_struct`.
    fn generate_model_struct(
        &self,
        table_name: &str,
        columns: &[ResolvedColumn],
    ) -> Result<String, ScytheError> {
        let singular = singularize(table_name);
        let name = to_pascal_case(&singular);
        self.generate_row_struct(&name, columns)
    }

    /// Emits one `async def` wrapper for an analyzed query.
    ///
    /// Query parameters are made keyword-only (inserted after `*`). The shape
    /// of the generated function depends on the command:
    /// - `One`: `conn.fetchrow`, returns `Struct | None`
    /// - `Many`/`Batch`: `conn.fetch`, returns `list[Struct]`
    /// - `Exec`/`ExecResult`/`ExecRows`: `conn.execute`, returns `None`
    ///
    /// NOTE(review): asyncpg's `conn.execute` returns a status string, which
    /// is discarded here even for `ExecResult`/`ExecRows` — confirm those
    /// commands are not expected to surface a row count.
    fn generate_query_fn(
        &self,
        analyzed: &AnalyzedQuery,
        struct_name: &str,
        columns: &[ResolvedColumn],
        params: &[ResolvedParam],
    ) -> Result<String, ScytheError> {
        let func_name = fn_name(&analyzed.name, &self.manifest.naming);
        let mut out = String::new();

        // "name: type" pairs for the generated Python signature.
        let param_list = params
            .iter()
            .map(|p| format!("{}: {}", p.field_name, p.full_type))
            .collect::<Vec<_>>()
            .join(", ");
        // The bare `*` forces all query parameters to be keyword-only.
        let kw_sep = if param_list.is_empty() { "" } else { ", *, " };

        // The SQL is embedded in a Python triple-quoted string below;
        // presumably clean_sql normalizes it for that context — TODO confirm
        // it cannot emit a literal `\"\"\"`.
        let sql = super::clean_sql(&analyzed.sql);

        match &analyzed.command {
            QueryCommand::One => {
                let _ = writeln!(
                    out,
                    "async def {}(conn: Connection{}{}) -> {} | None:",
                    func_name, kw_sep, param_list, struct_name
                );
                let _ = writeln!(out, "    \"\"\"Execute {} query.\"\"\"", analyzed.name);
                let _ = writeln!(out, "    row = await conn.fetchrow(");
                let _ = writeln!(out, "        \"\"\"{}\"\"\",", sql);
                if !params.is_empty() {
                    // Positional arguments to fetchrow, in declared order.
                    let args: Vec<String> = params.iter().map(|p| p.field_name.clone()).collect();
                    let _ = writeln!(out, "        {},", args.join(", "));
                }
                let _ = writeln!(out, "    )");
                let _ = writeln!(out, "    if row is None:");
                let _ = writeln!(out, "        return None");
                let field_assignments: Vec<String> = columns
                    .iter()
                    .map(|col| format!("{}=row[\"{}\"]", col.field_name, col.name))
                    .collect();
                let oneliner = format!(
                    "    return {}({})",
                    struct_name,
                    field_assignments.join(", ")
                );
                // 88 is Black's default max line length: keep the single-line
                // form only if it fits, otherwise break one field per line.
                if oneliner.len() <= 88 {
                    let _ = writeln!(out, "{}", oneliner);
                } else {
                    let _ = writeln!(out, "    return {}(", struct_name);
                    for fa in &field_assignments {
                        let _ = writeln!(out, "        {},", fa);
                    }
                    let _ = writeln!(out, "    )");
                }
            }
            QueryCommand::Many | QueryCommand::Batch => {
                let _ = writeln!(
                    out,
                    "async def {}(conn: Connection{}{}) -> list[{}]:",
                    func_name, kw_sep, param_list, struct_name
                );
                let _ = writeln!(out, "    \"\"\"Execute {} query.\"\"\"", analyzed.name);
                let _ = writeln!(out, "    rows = await conn.fetch(");
                let _ = writeln!(out, "        \"\"\"{}\"\"\",", sql);
                if !params.is_empty() {
                    let args: Vec<String> = params.iter().map(|p| p.field_name.clone()).collect();
                    let _ = writeln!(out, "        {},", args.join(", "));
                }
                let _ = writeln!(out, "    )");
                // Each row `r` is mapped into the dataclass by column name.
                let field_assignments: Vec<String> = columns
                    .iter()
                    .map(|col| format!("{}=r[\"{}\"]", col.field_name, col.name))
                    .collect();
                let oneliner = format!(
                    "    return [{}({}) for r in rows]",
                    struct_name,
                    field_assignments.join(", ")
                );
                // Same Black line-length split as the One branch, but for a
                // list comprehension.
                if oneliner.len() <= 88 {
                    let _ = writeln!(out, "{}", oneliner);
                } else {
                    let _ = writeln!(out, "    return [");
                    let _ = writeln!(out, "        {}(", struct_name);
                    for fa in &field_assignments {
                        let _ = writeln!(out, "            {},", fa);
                    }
                    let _ = writeln!(out, "        )");
                    let _ = writeln!(out, "        for r in rows");
                    let _ = writeln!(out, "    ]");
                }
            }
            QueryCommand::Exec | QueryCommand::ExecResult | QueryCommand::ExecRows => {
                let _ = writeln!(
                    out,
                    "async def {}(conn: Connection{}{}) -> None:",
                    func_name, kw_sep, param_list
                );
                let _ = writeln!(out, "    \"\"\"Execute {} query.\"\"\"", analyzed.name);
                let _ = writeln!(out, "    await conn.execute(");
                let _ = writeln!(out, "        \"\"\"{}\"\"\",", sql);
                if !params.is_empty() {
                    let args: Vec<String> = params.iter().map(|p| p.field_name.clone()).collect();
                    let _ = writeln!(out, "        {},", args.join(", "));
                }
                let _ = writeln!(out, "    )");
            }
        }

        Ok(out)
    }

    /// Emits a Python enum mirroring a database enum type.
    ///
    /// The class subclasses `(str, Enum)` so members compare equal to their
    /// string values (convenient when binding them as query parameters).
    fn generate_enum_def(&self, enum_info: &EnumInfo) -> Result<String, ScytheError> {
        let type_name = enum_type_name(&enum_info.sql_name, &self.manifest.naming);
        let mut out = String::new();
        let _ = writeln!(out, "class {}(str, Enum):", type_name);
        let _ = writeln!(
            out,
            "    \"\"\"Database enum type {}.\"\"\"",
            enum_info.sql_name
        );
        if enum_info.values.is_empty() {
            let _ = writeln!(out, "    pass");
        } else {
            let _ = writeln!(out);
            for value in &enum_info.values {
                // Variant name comes from the naming rules; the value is the
                // raw SQL enum label.
                let variant = enum_variant_name(value, &self.manifest.naming);
                let _ = writeln!(out, "    {} = \"{}\"", variant, value);
            }
        }
        Ok(out)
    }

    /// Emits a `@dataclass` mirroring a database composite type, with field
    /// names snake_cased and field types resolved through the manifest.
    fn generate_composite_def(&self, composite: &CompositeInfo) -> Result<String, ScytheError> {
        let name = to_pascal_case(&composite.sql_name);
        let mut out = String::new();
        let _ = writeln!(out, "@dataclass");
        let _ = writeln!(out, "class {}:", name);
        let _ = writeln!(
            out,
            "    \"\"\"Composite type {}.\"\"\"",
            composite.sql_name
        );
        if composite.fields.is_empty() {
            let _ = writeln!(out, "    pass");
        } else {
            let _ = writeln!(out);
            for field in &composite.fields {
                // NOTE(review): the trailing `false` flag's meaning is not
                // visible here — confirm against resolve_type's signature.
                let py_type = resolve_type(&field.neutral_type, &self.manifest, false)
                    .map(|t| t.into_owned())
                    .map_err(|e| {
                        ScytheError::new(
                            ErrorCode::InternalError,
                            format!("composite field type error: {}", e),
                        )
                    })?;
                let _ = writeln!(out, "    {}: {}", to_snake_case(&field.name), py_type);
            }
        }
        Ok(out)
    }
}