1use crate::{
5 indent::{IndentConfig, IndentedWriter},
6 CodeGeneratorConfig, Encoding,
7};
8use serde_reflection::{ContainerFormat, Format, Named, Registry, VariantFormat};
9use std::{
10 collections::{BTreeMap, HashMap},
11 io::{Result, Write},
12 path::PathBuf,
13};
14
15pub struct CodeGenerator<'a> {
17 config: &'a CodeGeneratorConfig,
19 serde_package_name: Option<String>,
21 external_qualified_names: HashMap<String, String>,
25}
26
/// Writes the Python definitions for a registry, tracking where in the
/// namespace the item currently being generated lives.
struct PythonEmitter<'a, T> {
    /// Indentation-aware sink that all generated Python text goes through.
    out: IndentedWriter<T>,
    /// The generator whose configuration drives the output.
    generator: &'a CodeGenerator<'a>,
    /// Path of the current item: the module name split on `.`, followed by the
    /// names of the enclosing containers/variants. Used as the lookup key for
    /// configured comments and custom code.
    current_namespace: Vec<String>,
}
36
37impl<'a> CodeGenerator<'a> {
38 pub fn new(config: &'a CodeGeneratorConfig) -> Self {
40 if config.enums.c_style {
41 panic!("Python 3 does not support generating c-style enums");
42 }
43 let mut external_qualified_names = HashMap::new();
44 for (module_path, names) in &config.external_definitions {
45 let module = {
46 let mut path = module_path.split('.').collect::<Vec<_>>();
47 if path.len() < 2 {
48 module_path
49 } else {
50 path.pop().unwrap()
51 }
52 };
53 for name in names {
54 external_qualified_names.insert(name.to_string(), format!("{module}.{name}"));
55 }
56 }
57 Self {
58 config,
59 serde_package_name: None,
60 external_qualified_names,
61 }
62 }
63
64 pub fn with_serde_package_name(mut self, serde_package_name: Option<String>) -> Self {
66 self.serde_package_name = serde_package_name;
67 self
68 }
69
70 pub fn output(&self, out: &mut dyn Write, registry: &Registry) -> Result<()> {
72 let current_namespace = self
73 .config
74 .module_name
75 .split('.')
76 .map(String::from)
77 .collect();
78 let mut emitter = PythonEmitter {
79 out: IndentedWriter::new(out, IndentConfig::Space(4)),
80 generator: self,
81 current_namespace,
82 };
83 emitter.output_preamble()?;
84 for (name, format) in registry {
85 emitter.output_container(name, format)?;
86 }
87 Ok(())
88 }
89}
90
91impl<'a, T> PythonEmitter<'a, T>
92where
93 T: Write,
94{
95 fn quote_import(&self, module: &str) -> String {
96 let mut parts = module.split('.').collect::<Vec<_>>();
97 if parts.len() <= 1 {
98 format!("import {module}")
99 } else {
100 let module_name = parts.pop().unwrap();
101 format!("from {} import {}", parts.join("."), module_name)
102 }
103 }
104
105 fn output_preamble(&mut self) -> Result<()> {
106 let from_serde_package = match &self.generator.serde_package_name {
107 None => "".to_string(),
108 Some(name) => format!("from {name} "),
109 };
110 writeln!(
111 self.out,
112 r#"# pyre-strict
113from dataclasses import dataclass
114import typing
115{from_serde_package}import serde_types as st"#,
116 )?;
117 for encoding in &self.generator.config.encodings {
118 writeln!(self.out, "{}import {}", from_serde_package, encoding.name())?;
119 }
120 for module in self.generator.config.external_definitions.keys() {
121 writeln!(self.out, "{}\n", self.quote_import(module))?;
122 }
123 Ok(())
124 }
125
126 fn quote_qualified_name(&self, name: &str) -> String {
129 self.generator
130 .external_qualified_names
131 .get(name)
132 .cloned()
133 .unwrap_or_else(|| {
134 format!("\"{name}\"")
136 })
137 }
138
139 fn quote_type(&self, format: &Format) -> String {
140 use Format::*;
141 match format {
142 TypeName(x) => self.quote_qualified_name(x),
143 Unit => "st.unit".into(),
144 Bool => "bool".into(),
145 I8 => "st.int8".into(),
146 I16 => "st.int16".into(),
147 I32 => "st.int32".into(),
148 I64 => "st.int64".into(),
149 I128 => "st.int128".into(),
150 U8 => "st.uint8".into(),
151 U16 => "st.uint16".into(),
152 U32 => "st.uint32".into(),
153 U64 => "st.uint64".into(),
154 U128 => "st.uint128".into(),
155 F32 => "st.float32".into(),
156 F64 => "st.float64".into(),
157 Char => "st.char".into(),
158 Str => "str".into(),
159 Bytes => "bytes".into(),
160
161 Option(format) => format!("typing.Optional[{}]", self.quote_type(format)),
162 Seq(format) => format!("typing.Sequence[{}]", self.quote_type(format)),
163 Map { key, value } => format!(
164 "typing.Dict[{}, {}]",
165 self.quote_type(key),
166 self.quote_type(value)
167 ),
168 Tuple(formats) => {
169 if formats.is_empty() {
170 "typing.Tuple[()]".into()
171 } else {
172 format!("typing.Tuple[{}]", self.quote_types(formats))
173 }
174 }
175 TupleArray { content, size } => format!(
176 "typing.Tuple[{}]",
177 self.quote_types(&vec![content.as_ref().clone(); *size])
178 ), Variable(_) => panic!("unexpected value"),
181 }
182 }
183
184 fn quote_types(&self, formats: &[Format]) -> String {
185 formats
186 .iter()
187 .map(|x| self.quote_type(x))
188 .collect::<Vec<_>>()
189 .join(", ")
190 }
191
192 fn output_comment(&mut self, name: &str) -> std::io::Result<()> {
193 let mut path = self.current_namespace.clone();
194 path.push(name.to_string());
195 if let Some(doc) = self.generator.config.comments.get(&path) {
196 writeln!(self.out, "\"\"\"{doc}\"\"\"")?;
197 }
198 Ok(())
199 }
200
201 fn output_custom_code(&mut self) -> std::io::Result<bool> {
202 match self
203 .generator
204 .config
205 .custom_code
206 .get(&self.current_namespace)
207 {
208 Some(code) => {
209 writeln!(self.out, "\n{code}")?;
210 Ok(true)
211 }
212 None => Ok(false),
213 }
214 }
215
216 fn output_fields(&mut self, fields: &[Named<Format>]) -> Result<()> {
217 if fields.is_empty() {
218 writeln!(self.out, "pass")?;
219 return Ok(());
220 }
221 for field in fields {
222 writeln!(
223 self.out,
224 "{}: {}",
225 field.name,
226 self.quote_type(&field.value)
227 )?;
228 }
229 Ok(())
230 }
231
232 fn output_variant(
233 &mut self,
234 base: &str,
235 name: &str,
236 index: u32,
237 variant: &VariantFormat,
238 ) -> Result<()> {
239 use VariantFormat::*;
240 let fields = match variant {
241 Unit => Vec::new(),
242 NewType(format) => vec![Named {
243 name: "value".to_string(),
244 value: format.as_ref().clone(),
245 }],
246 Tuple(formats) => vec![Named {
247 name: "value".to_string(),
248 value: Format::Tuple(formats.clone()),
249 }],
250 Struct(fields) => fields.clone(),
251 Variable(_) => panic!("incorrect value"),
252 };
253
254 writeln!(
256 self.out,
257 "\n@dataclass(frozen=True)\nclass {base}__{name}({base}):"
258 )?;
259 self.out.indent();
260 self.output_comment(name)?;
261 if self.generator.config.serialization {
262 writeln!(self.out, "INDEX = {index} # type: int")?;
263 }
264 self.current_namespace.push(name.to_string());
265 self.output_fields(&fields)?;
266 self.output_custom_code()?;
267 self.current_namespace.pop();
268 self.out.unindent();
269 writeln!(self.out)
270 }
271
272 fn output_enum_container(
273 &mut self,
274 name: &str,
275 variants: &BTreeMap<u32, Named<VariantFormat>>,
276 ) -> Result<()> {
277 writeln!(self.out, "\nclass {name}:")?;
278 self.out.indent();
279 self.output_comment(name)?;
280 self.current_namespace.push(name.to_string());
281 if self.generator.config.serialization {
282 writeln!(
283 self.out,
284 "VARIANTS = [] # type: typing.Sequence[typing.Type[{name}]]"
285 )?;
286 for encoding in &self.generator.config.encodings {
287 self.output_serialize_method_for_encoding(name, *encoding)?;
288 self.output_deserialize_method_for_encoding(name, *encoding)?;
289 }
290 }
291 let wrote_custom_code = self.output_custom_code()?;
292 if !self.generator.config.serialization && !wrote_custom_code {
293 writeln!(self.out, "pass")?;
294 }
295 writeln!(self.out)?;
296 self.out.unindent();
297
298 for (index, variant) in variants {
299 self.output_variant(name, &variant.name, *index, &variant.value)?;
300 }
301 self.current_namespace.pop();
302
303 if self.generator.config.serialization {
304 writeln!(
305 self.out,
306 "{}.VARIANTS = [\n{}]\n",
307 name,
308 variants
309 .values()
310 .map(|v| format!(" {name}__{},\n", v.name))
311 .collect::<Vec<_>>()
312 .join("")
313 )?;
314 }
315 Ok(())
316 }
317
318 fn output_serialize_method_for_encoding(
319 &mut self,
320 name: &str,
321 encoding: Encoding,
322 ) -> Result<()> {
323 writeln!(
324 self.out,
325 r#"
326def {0}_serialize(self) -> bytes:
327 return {0}.serialize(self, {1})"#,
328 encoding.name(),
329 name
330 )
331 }
332
333 fn output_deserialize_method_for_encoding(
334 &mut self,
335 name: &str,
336 encoding: Encoding,
337 ) -> Result<()> {
338 writeln!(
339 self.out,
340 r#"
341@staticmethod
342def {0}_deserialize(input: bytes) -> '{1}':
343 v, buffer = {0}.deserialize(input, {1})
344 if buffer:
345 raise st.DeserializationError("Some input bytes were not read");
346 return v"#,
347 encoding.name(),
348 name
349 )
350 }
351
352 fn output_container(&mut self, name: &str, format: &ContainerFormat) -> Result<()> {
353 use ContainerFormat::*;
354 let fields = match format {
355 UnitStruct => Vec::new(),
356 NewTypeStruct(format) => vec![Named {
357 name: "value".to_string(),
358 value: format.as_ref().clone(),
359 }],
360 TupleStruct(formats) => vec![Named {
361 name: "value".to_string(),
362 value: Format::Tuple(formats.clone()),
363 }],
364 Struct(fields) => fields.clone(),
365 Enum(variants) => {
366 self.output_enum_container(name, variants)?;
368 return Ok(());
369 }
370 };
371 writeln!(self.out, "\n@dataclass(frozen=True)\nclass {name}:")?;
373 self.out.indent();
374 self.output_comment(name)?;
375 self.current_namespace.push(name.to_string());
376 self.output_fields(&fields)?;
377 for encoding in &self.generator.config.encodings {
378 self.output_serialize_method_for_encoding(name, *encoding)?;
379 self.output_deserialize_method_for_encoding(name, *encoding)?;
380 }
381 self.output_custom_code()?;
382 self.current_namespace.pop();
383 self.out.unindent();
384 writeln!(self.out)
385 }
386}
387
/// Installs generated Python modules and the Python serde runtime under a directory.
pub struct Installer {
    /// Root directory in which module directories are created.
    install_dir: PathBuf,
    /// Optional parent package through which the runtime modules are imported.
    serde_package_name: Option<String>,
}

impl Installer {
    /// Creates an installer rooted at `install_dir`.
    pub fn new(install_dir: PathBuf, serde_package_name: Option<String>) -> Self {
        Self {
            install_dir,
            serde_package_name,
        }
    }

    /// Creates `<install_dir>/<name>/` (including parents) and returns a freshly
    /// created `__init__.py` inside it.
    fn create_module_init_file(&self, name: &str) -> Result<std::fs::File> {
        let module_dir = self.install_dir.join(name);
        std::fs::create_dir_all(&module_dir)?;
        let init_path = module_dir.join("__init__.py");
        std::fs::File::create(init_path)
    }

    /// Rewrites `import serde_types` / `import serde_binary` into
    /// `from <package> import ...` when a serde package name is configured.
    /// Otherwise returns `content` unchanged.
    fn fix_serde_package(&self, content: &str) -> String {
        match &self.serde_package_name {
            None => content.to_owned(),
            Some(package) => ["serde_types", "serde_binary"].iter().fold(
                content.to_owned(),
                |text, module| {
                    text.replace(
                        &format!("import {module}"),
                        &format!("from {package} import {module}"),
                    )
                },
            ),
        }
    }
}
423
424impl crate::SourceInstaller for Installer {
425 type Error = Box<dyn std::error::Error>;
426
427 fn install_module(
428 &self,
429 config: &crate::CodeGeneratorConfig,
430 registry: &Registry,
431 ) -> std::result::Result<(), Self::Error> {
432 let mut file = self.create_module_init_file(&config.module_name)?;
433 let generator =
434 CodeGenerator::new(config).with_serde_package_name(self.serde_package_name.clone());
435 generator.output(&mut file, registry)?;
436 Ok(())
437 }
438
439 fn install_serde_runtime(&self) -> std::result::Result<(), Self::Error> {
440 let mut file = self.create_module_init_file("serde_types")?;
441 write!(
442 file,
443 "{}",
444 self.fix_serde_package(include_str!("../runtime/python/serde_types/__init__.py"))
445 )?;
446 let mut file = self.create_module_init_file("serde_binary")?;
447 write!(
448 file,
449 "{}",
450 self.fix_serde_package(include_str!("../runtime/python/serde_binary/__init__.py"))
451 )?;
452 Ok(())
453 }
454
455 fn install_bincode_runtime(&self) -> std::result::Result<(), Self::Error> {
456 let mut file = self.create_module_init_file("bincode")?;
457 write!(
458 file,
459 "{}",
460 self.fix_serde_package(include_str!("../runtime/python/bincode/__init__.py"))
461 )?;
462 Ok(())
463 }
464
465 fn install_bcs_runtime(&self) -> std::result::Result<(), Self::Error> {
466 let mut file = self.create_module_init_file("bcs")?;
467 write!(
468 file,
469 "{}",
470 self.fix_serde_package(include_str!("../runtime/python/bcs/__init__.py"))
471 )?;
472 Ok(())
473 }
474}