1use std::fmt::Write;
2use std::path::Path;
3
4use scythe_backend::manifest::{BackendManifest, load_manifest};
5use scythe_backend::naming::{
6 enum_type_name, enum_variant_name, fn_name, row_struct_name, to_pascal_case, to_snake_case,
7};
8use scythe_backend::types::resolve_type;
9
10use scythe_core::analyzer::{AnalyzedQuery, CompositeInfo, EnumInfo};
11use scythe_core::errors::{ErrorCode, ScytheError};
12use scythe_core::parser::QueryCommand;
13
14use crate::backend_trait::{CodegenBackend, ResolvedColumn, ResolvedParam};
15use crate::singularize;
16
/// Manifest text embedded into the binary at compile time.
///
/// `PythonAsyncpgBackend::new` parses this when no on-disk manifest exists at
/// `backends/python-asyncpg/manifest.toml`.
const DEFAULT_MANIFEST_TOML: &str =
    include_str!("../../manifests/python-asyncpg.toml");
19
/// Code generator that emits Python source using the `asyncpg` driver
/// (`Connection.fetchrow` / `fetch` / `execute`).
pub struct PythonAsyncpgBackend {
    /// Backend manifest: supplies the `naming` rules used for generated
    /// identifiers and is passed to `resolve_type` for composite field types.
    manifest: BackendManifest,
}
23
24impl PythonAsyncpgBackend {
25 pub fn new() -> Result<Self, ScytheError> {
26 let manifest_path = Path::new("backends/python-asyncpg/manifest.toml");
27 let manifest = if manifest_path.exists() {
28 load_manifest(manifest_path)
29 .map_err(|e| ScytheError::new(ErrorCode::InternalError, format!("manifest: {e}")))?
30 } else {
31 toml::from_str(DEFAULT_MANIFEST_TOML)
32 .map_err(|e| ScytheError::new(ErrorCode::InternalError, format!("manifest: {e}")))?
33 };
34 Ok(Self { manifest })
35 }
36
37 pub fn manifest(&self) -> &BackendManifest {
38 &self.manifest
39 }
40}
41
42impl CodegenBackend for PythonAsyncpgBackend {
43 fn name(&self) -> &str {
44 "python-asyncpg"
45 }
46
47 fn file_header(&self) -> String {
48 "\"\"\"Auto-generated by scythe. Do not edit.\"\"\"\n\
49 \n\
50 import datetime # noqa: F401\n\
51 from dataclasses import dataclass\n\
52 from enum import Enum # noqa: F401\n\
53 \n\
54 from asyncpg import Connection # noqa: F401\n\
55 \n"
56 .to_string()
57 }
58
59 fn generate_row_struct(
60 &self,
61 query_name: &str,
62 columns: &[ResolvedColumn],
63 ) -> Result<String, ScytheError> {
64 let struct_name = row_struct_name(query_name, &self.manifest.naming);
65 let mut out = String::new();
66 let _ = writeln!(out, "@dataclass");
67 let _ = writeln!(out, "class {}:", struct_name);
68 let _ = writeln!(out, " \"\"\"Row type for {} query.\"\"\"", query_name);
69 if columns.is_empty() {
70 let _ = writeln!(out, " pass");
71 } else {
72 let _ = writeln!(out);
73 for col in columns {
74 let _ = writeln!(out, " {}: {}", col.field_name, col.full_type);
75 }
76 }
77 Ok(out)
78 }
79
80 fn generate_model_struct(
81 &self,
82 table_name: &str,
83 columns: &[ResolvedColumn],
84 ) -> Result<String, ScytheError> {
85 let singular = singularize(table_name);
86 let name = to_pascal_case(&singular);
87 self.generate_row_struct(&name, columns)
88 }
89
90 fn generate_query_fn(
91 &self,
92 analyzed: &AnalyzedQuery,
93 struct_name: &str,
94 columns: &[ResolvedColumn],
95 params: &[ResolvedParam],
96 ) -> Result<String, ScytheError> {
97 let func_name = fn_name(&analyzed.name, &self.manifest.naming);
98 let mut out = String::new();
99
100 let param_list = params
102 .iter()
103 .map(|p| format!("{}: {}", p.field_name, p.full_type))
104 .collect::<Vec<_>>()
105 .join(", ");
106 let kw_sep = if param_list.is_empty() { "" } else { ", *, " };
107
108 let sql = super::clean_sql(&analyzed.sql);
110
111 match &analyzed.command {
112 QueryCommand::One => {
113 let _ = writeln!(
114 out,
115 "async def {}(conn: Connection{}{}) -> {} | None:",
116 func_name, kw_sep, param_list, struct_name
117 );
118 let _ = writeln!(out, " \"\"\"Execute {} query.\"\"\"", analyzed.name);
119 let _ = writeln!(out, " row = await conn.fetchrow(");
120 let _ = writeln!(out, " \"{}\",", sql);
121 if !params.is_empty() {
122 let args: Vec<String> = params.iter().map(|p| p.field_name.clone()).collect();
123 let _ = writeln!(out, " {},", args.join(", "));
124 }
125 let _ = writeln!(out, " )");
126 let _ = writeln!(out, " if row is None:");
127 let _ = writeln!(out, " return None");
128 let field_assignments: Vec<String> = columns
129 .iter()
130 .map(|col| format!("{}=row[\"{}\"]", col.field_name, col.name))
131 .collect();
132 let oneliner = format!(
133 " return {}({})",
134 struct_name,
135 field_assignments.join(", ")
136 );
137 if oneliner.len() <= 88 {
138 let _ = writeln!(out, "{}", oneliner);
139 } else {
140 let _ = writeln!(out, " return {}(", struct_name);
141 for fa in &field_assignments {
142 let _ = writeln!(out, " {},", fa);
143 }
144 let _ = writeln!(out, " )");
145 }
146 }
147 QueryCommand::Many | QueryCommand::Batch => {
148 let _ = writeln!(
149 out,
150 "async def {}(conn: Connection{}{}) -> list[{}]:",
151 func_name, kw_sep, param_list, struct_name
152 );
153 let _ = writeln!(out, " \"\"\"Execute {} query.\"\"\"", analyzed.name);
154 let _ = writeln!(out, " rows = await conn.fetch(");
155 let _ = writeln!(out, " \"{}\",", sql);
156 if !params.is_empty() {
157 let args: Vec<String> = params.iter().map(|p| p.field_name.clone()).collect();
158 let _ = writeln!(out, " {},", args.join(", "));
159 }
160 let _ = writeln!(out, " )");
161 let field_assignments: Vec<String> = columns
162 .iter()
163 .map(|col| format!("{}=r[\"{}\"]", col.field_name, col.name))
164 .collect();
165 let oneliner = format!(
166 " return [{}({}) for r in rows]",
167 struct_name,
168 field_assignments.join(", ")
169 );
170 if oneliner.len() <= 88 {
171 let _ = writeln!(out, "{}", oneliner);
172 } else {
173 let _ = writeln!(out, " return [");
174 let _ = writeln!(out, " {}(", struct_name);
175 for fa in &field_assignments {
176 let _ = writeln!(out, " {},", fa);
177 }
178 let _ = writeln!(out, " )");
179 let _ = writeln!(out, " for r in rows");
180 let _ = writeln!(out, " ]");
181 }
182 }
183 QueryCommand::Exec | QueryCommand::ExecResult | QueryCommand::ExecRows => {
184 let _ = writeln!(
185 out,
186 "async def {}(conn: Connection{}{}) -> None:",
187 func_name, kw_sep, param_list
188 );
189 let _ = writeln!(out, " \"\"\"Execute {} query.\"\"\"", analyzed.name);
190 let _ = writeln!(out, " await conn.execute(");
191 let _ = writeln!(out, " \"{}\",", sql);
192 if !params.is_empty() {
193 let args: Vec<String> = params.iter().map(|p| p.field_name.clone()).collect();
194 let _ = writeln!(out, " {},", args.join(", "));
195 }
196 let _ = writeln!(out, " )");
197 }
198 }
199
200 Ok(out)
201 }
202
203 fn generate_enum_def(&self, enum_info: &EnumInfo) -> Result<String, ScytheError> {
204 let type_name = enum_type_name(&enum_info.sql_name, &self.manifest.naming);
205 let mut out = String::new();
206 let _ = writeln!(out, "class {}(str, Enum):", type_name);
207 let _ = writeln!(
208 out,
209 " \"\"\"Database enum type {}.\"\"\"",
210 enum_info.sql_name
211 );
212 if enum_info.values.is_empty() {
213 let _ = writeln!(out, " pass");
214 } else {
215 let _ = writeln!(out);
216 for value in &enum_info.values {
217 let variant = enum_variant_name(value, &self.manifest.naming);
218 let _ = writeln!(out, " {} = \"{}\"", variant, value);
219 }
220 }
221 Ok(out)
222 }
223
224 fn generate_composite_def(&self, composite: &CompositeInfo) -> Result<String, ScytheError> {
225 let name = to_pascal_case(&composite.sql_name);
226 let mut out = String::new();
227 let _ = writeln!(out, "@dataclass");
228 let _ = writeln!(out, "class {}:", name);
229 let _ = writeln!(
230 out,
231 " \"\"\"Composite type {}.\"\"\"",
232 composite.sql_name
233 );
234 if composite.fields.is_empty() {
235 let _ = writeln!(out, " pass");
236 } else {
237 let _ = writeln!(out);
238 for field in &composite.fields {
239 let py_type = resolve_type(&field.neutral_type, &self.manifest, false)
240 .map(|t| t.into_owned())
241 .map_err(|e| {
242 ScytheError::new(
243 ErrorCode::InternalError,
244 format!("composite field type error: {}", e),
245 )
246 })?;
247 let _ = writeln!(out, " {}: {}", to_snake_case(&field.name), py_type);
248 }
249 }
250 Ok(out)
251 }
252}