1use std::fmt::Write;
2use std::path::Path;
3
4use scythe_backend::manifest::{BackendManifest, load_manifest};
5use scythe_backend::naming::{
6 enum_type_name, enum_variant_name, fn_name, row_struct_name, to_pascal_case, to_snake_case,
7};
8use scythe_backend::types::resolve_type;
9
10use scythe_core::analyzer::{AnalyzedQuery, CompositeInfo, EnumInfo};
11use scythe_core::errors::{ErrorCode, ScytheError};
12use scythe_core::parser::QueryCommand;
13
14use crate::backend_trait::{CodegenBackend, ResolvedColumn, ResolvedParam};
15use crate::singularize;
16
/// Built-in backend manifest embedded into the binary at compile time.
///
/// [`PythonAsyncpgBackend::new`] parses this text when no on-disk override manifest
/// exists at `backends/python-asyncpg/manifest.toml`.
const DEFAULT_MANIFEST_TOML: &str = include_str!("../../manifests/python-asyncpg.toml");
18
/// Code generator that emits Python modules built on `asyncpg` and `dataclasses`.
pub struct PythonAsyncpgBackend {
    // Naming rules (`manifest.naming`) and type mappings consulted by every generator method.
    manifest: BackendManifest,
}
22
23impl PythonAsyncpgBackend {
24 pub fn new() -> Result<Self, ScytheError> {
25 let manifest_path = Path::new("backends/python-asyncpg/manifest.toml");
26 let manifest = if manifest_path.exists() {
27 load_manifest(manifest_path)
28 .map_err(|e| ScytheError::new(ErrorCode::InternalError, format!("manifest: {e}")))?
29 } else {
30 toml::from_str(DEFAULT_MANIFEST_TOML)
31 .map_err(|e| ScytheError::new(ErrorCode::InternalError, format!("manifest: {e}")))?
32 };
33 Ok(Self { manifest })
34 }
35
36 pub fn manifest(&self) -> &BackendManifest {
37 &self.manifest
38 }
39}
40
41impl CodegenBackend for PythonAsyncpgBackend {
42 fn name(&self) -> &str {
43 "python-asyncpg"
44 }
45
46 fn file_header(&self) -> String {
47 "\"\"\"Auto-generated by scythe. Do not edit.\"\"\"\n\
48 \n\
49 import datetime # noqa: F401\n\
50 from dataclasses import dataclass\n\
51 from enum import Enum # noqa: F401\n\
52 \n\
53 from asyncpg import Connection # noqa: F401\n\
54 \n"
55 .to_string()
56 }
57
58 fn generate_row_struct(
59 &self,
60 query_name: &str,
61 columns: &[ResolvedColumn],
62 ) -> Result<String, ScytheError> {
63 let struct_name = row_struct_name(query_name, &self.manifest.naming);
64 let mut out = String::new();
65 let _ = writeln!(out, "@dataclass");
66 let _ = writeln!(out, "class {}:", struct_name);
67 let _ = writeln!(out, " \"\"\"Row type for {} query.\"\"\"", query_name);
68 if columns.is_empty() {
69 let _ = writeln!(out, " pass");
70 } else {
71 let _ = writeln!(out);
72 for col in columns {
73 let _ = writeln!(out, " {}: {}", col.field_name, col.full_type);
74 }
75 }
76 Ok(out)
77 }
78
79 fn generate_model_struct(
80 &self,
81 table_name: &str,
82 columns: &[ResolvedColumn],
83 ) -> Result<String, ScytheError> {
84 let singular = singularize(table_name);
85 let name = to_pascal_case(&singular);
86 self.generate_row_struct(&name, columns)
87 }
88
89 fn generate_query_fn(
90 &self,
91 analyzed: &AnalyzedQuery,
92 struct_name: &str,
93 columns: &[ResolvedColumn],
94 params: &[ResolvedParam],
95 ) -> Result<String, ScytheError> {
96 let func_name = fn_name(&analyzed.name, &self.manifest.naming);
97 let mut out = String::new();
98
99 let param_list = params
101 .iter()
102 .map(|p| format!("{}: {}", p.field_name, p.full_type))
103 .collect::<Vec<_>>()
104 .join(", ");
105 let kw_sep = if param_list.is_empty() { "" } else { ", *, " };
106
107 let sql = super::clean_sql(&analyzed.sql);
109
110 match &analyzed.command {
111 QueryCommand::One => {
112 let _ = writeln!(
113 out,
114 "async def {}(conn: Connection{}{}) -> {} | None:",
115 func_name, kw_sep, param_list, struct_name
116 );
117 let _ = writeln!(out, " \"\"\"Execute {} query.\"\"\"", analyzed.name);
118 let _ = writeln!(out, " row = await conn.fetchrow(");
119 let _ = writeln!(out, " \"{}\",", sql);
120 if !params.is_empty() {
121 let args: Vec<String> = params.iter().map(|p| p.field_name.clone()).collect();
122 let _ = writeln!(out, " {},", args.join(", "));
123 }
124 let _ = writeln!(out, " )");
125 let _ = writeln!(out, " if row is None:");
126 let _ = writeln!(out, " return None");
127 let field_assignments: Vec<String> = columns
128 .iter()
129 .map(|col| format!("{}=row[\"{}\"]", col.field_name, col.name))
130 .collect();
131 let oneliner = format!(
132 " return {}({})",
133 struct_name,
134 field_assignments.join(", ")
135 );
136 if oneliner.len() <= 88 {
137 let _ = writeln!(out, "{}", oneliner);
138 } else {
139 let _ = writeln!(out, " return {}(", struct_name);
140 for fa in &field_assignments {
141 let _ = writeln!(out, " {},", fa);
142 }
143 let _ = writeln!(out, " )");
144 }
145 }
146 QueryCommand::Many | QueryCommand::Batch => {
147 let _ = writeln!(
148 out,
149 "async def {}(conn: Connection{}{}) -> list[{}]:",
150 func_name, kw_sep, param_list, struct_name
151 );
152 let _ = writeln!(out, " \"\"\"Execute {} query.\"\"\"", analyzed.name);
153 let _ = writeln!(out, " rows = await conn.fetch(");
154 let _ = writeln!(out, " \"{}\",", sql);
155 if !params.is_empty() {
156 let args: Vec<String> = params.iter().map(|p| p.field_name.clone()).collect();
157 let _ = writeln!(out, " {},", args.join(", "));
158 }
159 let _ = writeln!(out, " )");
160 let field_assignments: Vec<String> = columns
161 .iter()
162 .map(|col| format!("{}=r[\"{}\"]", col.field_name, col.name))
163 .collect();
164 let oneliner = format!(
165 " return [{}({}) for r in rows]",
166 struct_name,
167 field_assignments.join(", ")
168 );
169 if oneliner.len() <= 88 {
170 let _ = writeln!(out, "{}", oneliner);
171 } else {
172 let _ = writeln!(out, " return [");
173 let _ = writeln!(out, " {}(", struct_name);
174 for fa in &field_assignments {
175 let _ = writeln!(out, " {},", fa);
176 }
177 let _ = writeln!(out, " )");
178 let _ = writeln!(out, " for r in rows");
179 let _ = writeln!(out, " ]");
180 }
181 }
182 QueryCommand::Exec | QueryCommand::ExecResult | QueryCommand::ExecRows => {
183 let _ = writeln!(
184 out,
185 "async def {}(conn: Connection{}{}) -> None:",
186 func_name, kw_sep, param_list
187 );
188 let _ = writeln!(out, " \"\"\"Execute {} query.\"\"\"", analyzed.name);
189 let _ = writeln!(out, " await conn.execute(");
190 let _ = writeln!(out, " \"{}\",", sql);
191 if !params.is_empty() {
192 let args: Vec<String> = params.iter().map(|p| p.field_name.clone()).collect();
193 let _ = writeln!(out, " {},", args.join(", "));
194 }
195 let _ = writeln!(out, " )");
196 }
197 }
198
199 Ok(out)
200 }
201
202 fn generate_enum_def(&self, enum_info: &EnumInfo) -> Result<String, ScytheError> {
203 let type_name = enum_type_name(&enum_info.sql_name, &self.manifest.naming);
204 let mut out = String::new();
205 let _ = writeln!(out, "class {}(str, Enum):", type_name);
206 let _ = writeln!(
207 out,
208 " \"\"\"Database enum type {}.\"\"\"",
209 enum_info.sql_name
210 );
211 if enum_info.values.is_empty() {
212 let _ = writeln!(out, " pass");
213 } else {
214 let _ = writeln!(out);
215 for value in &enum_info.values {
216 let variant = enum_variant_name(value, &self.manifest.naming);
217 let _ = writeln!(out, " {} = \"{}\"", variant, value);
218 }
219 }
220 Ok(out)
221 }
222
223 fn generate_composite_def(&self, composite: &CompositeInfo) -> Result<String, ScytheError> {
224 let name = to_pascal_case(&composite.sql_name);
225 let mut out = String::new();
226 let _ = writeln!(out, "@dataclass");
227 let _ = writeln!(out, "class {}:", name);
228 let _ = writeln!(
229 out,
230 " \"\"\"Composite type {}.\"\"\"",
231 composite.sql_name
232 );
233 if composite.fields.is_empty() {
234 let _ = writeln!(out, " pass");
235 } else {
236 let _ = writeln!(out);
237 for field in &composite.fields {
238 let py_type = resolve_type(&field.neutral_type, &self.manifest, false)
239 .map(|t| t.into_owned())
240 .map_err(|e| {
241 ScytheError::new(
242 ErrorCode::InternalError,
243 format!("composite field type error: {}", e),
244 )
245 })?;
246 let _ = writeln!(out, " {}: {}", to_snake_case(&field.name), py_type);
247 }
248 }
249 Ok(out)
250 }
251}