1use crate::{
5 indent::{IndentConfig, IndentedWriter},
6 CodeGeneratorConfig, Encoding,
7};
8use serde_reflection::{ContainerFormat, Format, Named, Registry, VariantFormat};
9use std::{
10 collections::{BTreeMap, HashMap},
11 io::{Result, Write},
12 path::PathBuf,
13};
14
15pub struct CodeGenerator<'a> {
17 config: &'a CodeGeneratorConfig,
19 serde_package_name: Option<String>,
21 external_qualified_names: HashMap<String, String>,
25}
26
27struct PythonEmitter<'a, T> {
29 out: IndentedWriter<T>,
31 generator: &'a CodeGenerator<'a>,
33 current_namespace: Vec<String>,
35}
36
37impl<'a> CodeGenerator<'a> {
38 pub fn new(config: &'a CodeGeneratorConfig) -> Self {
40 if config.c_style_enums {
41 panic!("Python 3 does not support generating c-style enums");
42 }
43 let mut external_qualified_names = HashMap::new();
44 for (module_path, names) in &config.external_definitions {
45 let module = {
46 let mut path = module_path.split('.').collect::<Vec<_>>();
47 if path.len() < 2 {
48 module_path
49 } else {
50 path.pop().unwrap()
51 }
52 };
53 for name in names {
54 external_qualified_names.insert(name.to_string(), format!("{}.{}", module, name));
55 }
56 }
57 Self {
58 config,
59 serde_package_name: None,
60 external_qualified_names,
61 }
62 }
63
64 pub fn with_serde_package_name(mut self, serde_package_name: Option<String>) -> Self {
66 self.serde_package_name = serde_package_name;
67 self
68 }
69
70 pub fn output(&self, out: &mut dyn Write, registry: &Registry) -> Result<()> {
72 let current_namespace = self
73 .config
74 .module_name
75 .split('.')
76 .map(String::from)
77 .collect();
78 let mut emitter = PythonEmitter {
79 out: IndentedWriter::new(out, IndentConfig::Space(4)),
80 generator: self,
81 current_namespace,
82 };
83 emitter.output_preamble()?;
84 for (name, format) in registry {
85 emitter.output_container(name, format)?;
86 }
87 Ok(())
88 }
89}
90
91impl<'a, T> PythonEmitter<'a, T>
92where
93 T: Write,
94{
95 fn quote_import(&self, module: &str) -> String {
96 let mut parts = module.split('.').collect::<Vec<_>>();
97 if parts.len() <= 1 {
98 format!("import {}", module)
99 } else {
100 let module_name = parts.pop().unwrap();
101 format!("from {} import {}", parts.join("."), module_name)
102 }
103 }
104
105 fn output_preamble(&mut self) -> Result<()> {
106 let from_serde_package = match &self.generator.serde_package_name {
107 None => "".to_string(),
108 Some(name) => format!("from {} ", name),
109 };
110 writeln!(
111 self.out,
112 r#"# pyre-strict
113from dataclasses import dataclass
114import typing
115{}import serde_types as st"#,
116 from_serde_package,
117 )?;
118 for encoding in &self.generator.config.encodings {
119 writeln!(self.out, "{}import {}", from_serde_package, encoding.name())?;
120 }
121 for module in self.generator.config.external_definitions.keys() {
122 writeln!(self.out, "{}\n", self.quote_import(module))?;
123 }
124 Ok(())
125 }
126
127 fn quote_qualified_name(&self, name: &str) -> String {
130 self.generator
131 .external_qualified_names
132 .get(name)
133 .cloned()
134 .unwrap_or_else(|| {
135 format!("\"{}\"", name)
137 })
138 }
139
140 fn quote_type(&self, format: &Format) -> String {
141 use Format::*;
142 match format {
143 TypeName(x) => self.quote_qualified_name(x),
144 Unit => "st.unit".into(),
145 Bool => "bool".into(),
146 I8 => "st.int8".into(),
147 I16 => "st.int16".into(),
148 I32 => "st.int32".into(),
149 I64 => "st.int64".into(),
150 I128 => "st.int128".into(),
151 U8 => "st.uint8".into(),
152 U16 => "st.uint16".into(),
153 U32 => "st.uint32".into(),
154 U64 => "st.uint64".into(),
155 U128 => "st.uint128".into(),
156 F32 => "st.float32".into(),
157 F64 => "st.float64".into(),
158 Char => "st.char".into(),
159 Str => "str".into(),
160 Bytes => "bytes".into(),
161
162 Option(format) => format!("typing.Optional[{}]", self.quote_type(format)),
163 Seq(format) => format!("typing.Sequence[{}]", self.quote_type(format)),
164 Map { key, value } => format!(
165 "typing.Dict[{}, {}]",
166 self.quote_type(key),
167 self.quote_type(value)
168 ),
169 Tuple(formats) => {
170 if formats.is_empty() {
171 "typing.Tuple[()]".into()
172 } else {
173 format!("typing.Tuple[{}]", self.quote_types(formats))
174 }
175 }
176 TupleArray { content, size } => format!(
177 "typing.Tuple[{}]",
178 self.quote_types(&vec![content.as_ref().clone(); *size])
179 ), Variable(_) => panic!("unexpected value"),
182 }
183 }
184
185 fn quote_types(&self, formats: &[Format]) -> String {
186 formats
187 .iter()
188 .map(|x| self.quote_type(x))
189 .collect::<Vec<_>>()
190 .join(", ")
191 }
192
193 fn output_comment(&mut self, name: &str) -> std::io::Result<()> {
194 let mut path = self.current_namespace.clone();
195 path.push(name.to_string());
196 if let Some(doc) = self.generator.config.comments.get(&path) {
197 writeln!(self.out, "\"\"\"{}\"\"\"", doc)?;
198 }
199 Ok(())
200 }
201
202 fn output_custom_code(&mut self) -> std::io::Result<bool> {
203 match self
204 .generator
205 .config
206 .custom_code
207 .get(&self.current_namespace)
208 {
209 Some(code) => {
210 writeln!(self.out, "\n{}", code)?;
211 Ok(true)
212 }
213 None => Ok(false),
214 }
215 }
216
217 fn output_fields(&mut self, fields: &[Named<Format>]) -> Result<()> {
218 if fields.is_empty() {
219 writeln!(self.out, "pass")?;
220 return Ok(());
221 }
222 for field in fields {
223 writeln!(
224 self.out,
225 "{}: {}",
226 field.name,
227 self.quote_type(&field.value)
228 )?;
229 }
230 Ok(())
231 }
232
233 fn output_variant(
234 &mut self,
235 base: &str,
236 name: &str,
237 index: u32,
238 variant: &VariantFormat,
239 ) -> Result<()> {
240 use VariantFormat::*;
241 let fields = match variant {
242 Unit => Vec::new(),
243 NewType(format) => vec![Named {
244 name: "value".to_string(),
245 value: format.as_ref().clone(),
246 }],
247 Tuple(formats) => vec![Named {
248 name: "value".to_string(),
249 value: Format::Tuple(formats.clone()),
250 }],
251 Struct(fields) => fields.clone(),
252 Variable(_) => panic!("incorrect value"),
253 };
254
255 writeln!(
257 self.out,
258 "\n@dataclass(frozen=True)\nclass {0}__{1}({0}):",
259 base, name
260 )?;
261 self.out.indent();
262 self.output_comment(name)?;
263 if self.generator.config.serialization {
264 writeln!(self.out, "INDEX = {} # type: int", index)?;
265 }
266 self.current_namespace.push(name.to_string());
267 self.output_fields(&fields)?;
268 self.output_custom_code()?;
269 self.current_namespace.pop();
270 self.out.unindent();
271 writeln!(self.out)
272 }
273
274 fn output_enum_container(
275 &mut self,
276 name: &str,
277 variants: &BTreeMap<u32, Named<VariantFormat>>,
278 ) -> Result<()> {
279 writeln!(self.out, "\nclass {}:", name)?;
280 self.out.indent();
281 self.output_comment(name)?;
282 self.current_namespace.push(name.to_string());
283 if self.generator.config.serialization {
284 writeln!(
285 self.out,
286 "VARIANTS = [] # type: typing.Sequence[typing.Type[{}]]",
287 name
288 )?;
289 for encoding in &self.generator.config.encodings {
290 self.output_serialize_method_for_encoding(name, *encoding)?;
291 self.output_deserialize_method_for_encoding(name, *encoding)?;
292 }
293 }
294 let wrote_custom_code = self.output_custom_code()?;
295 if !self.generator.config.serialization && !wrote_custom_code {
296 writeln!(self.out, "pass")?;
297 }
298 writeln!(self.out)?;
299 self.out.unindent();
300
301 for (index, variant) in variants {
302 self.output_variant(name, &variant.name, *index, &variant.value)?;
303 }
304 self.current_namespace.pop();
305
306 if self.generator.config.serialization {
307 writeln!(
308 self.out,
309 "{}.VARIANTS = [\n{}]\n",
310 name,
311 variants
312 .iter()
313 .map(|(_, v)| format!(" {}__{},\n", name, v.name))
314 .collect::<Vec<_>>()
315 .join("")
316 )?;
317 }
318 Ok(())
319 }
320
321 fn output_serialize_method_for_encoding(
322 &mut self,
323 name: &str,
324 encoding: Encoding,
325 ) -> Result<()> {
326 writeln!(
327 self.out,
328 r#"
329def {0}_serialize(self) -> bytes:
330 return {0}.serialize(self, {1})"#,
331 encoding.name(),
332 name
333 )
334 }
335
336 fn output_deserialize_method_for_encoding(
337 &mut self,
338 name: &str,
339 encoding: Encoding,
340 ) -> Result<()> {
341 writeln!(
342 self.out,
343 r#"
344@staticmethod
345def {0}_deserialize(input: bytes) -> '{1}':
346 v, buffer = {0}.deserialize(input, {1})
347 if buffer:
348 raise st.DeserializationError("Some input bytes were not read");
349 return v"#,
350 encoding.name(),
351 name
352 )
353 }
354
355 fn output_container(&mut self, name: &str, format: &ContainerFormat) -> Result<()> {
356 use ContainerFormat::*;
357 let fields = match format {
358 UnitStruct => Vec::new(),
359 NewTypeStruct(format) => vec![Named {
360 name: "value".to_string(),
361 value: format.as_ref().clone(),
362 }],
363 TupleStruct(formats) => vec![Named {
364 name: "value".to_string(),
365 value: Format::Tuple(formats.clone()),
366 }],
367 Struct(fields) => fields.clone(),
368 Enum(variants) => {
369 self.output_enum_container(name, variants)?;
371 return Ok(());
372 }
373 };
374 writeln!(self.out, "\n@dataclass(frozen=True)\nclass {}:", name)?;
376 self.out.indent();
377 self.output_comment(name)?;
378 self.current_namespace.push(name.to_string());
379 self.output_fields(&fields)?;
380 for encoding in &self.generator.config.encodings {
381 self.output_serialize_method_for_encoding(name, *encoding)?;
382 self.output_deserialize_method_for_encoding(name, *encoding)?;
383 }
384 self.output_custom_code()?;
385 self.current_namespace.pop();
386 self.out.unindent();
387 writeln!(self.out)
388 }
389}
390
/// Installs generated Python modules and the Python serde runtime under a directory.
pub struct Installer {
    /// Root directory; each module is written as `<install_dir>/<name>/__init__.py`.
    install_dir: PathBuf,
    /// When set, runtime `import serde_types` / `import serde_binary` statements
    /// are rewritten to import from this package.
    serde_package_name: Option<String>,
}

impl Installer {
    /// Creates an installer rooted at `install_dir`, optionally nesting the
    /// serde runtime modules under the Python package `serde_package_name`.
    pub fn new(install_dir: PathBuf, serde_package_name: Option<String>) -> Self {
        Self {
            install_dir,
            serde_package_name,
        }
    }

    /// Creates `<install_dir>/<name>/` (including missing parents) and returns a
    /// newly created, truncated `__init__.py` inside it.
    fn create_module_init_file(&self, name: &str) -> Result<std::fs::File> {
        let package_dir = self.install_dir.join(name);
        std::fs::create_dir_all(&package_dir)?;
        let init_path = package_dir.join("__init__.py");
        std::fs::File::create(init_path)
    }

    /// Rewrites `import serde_types` and then `import serde_binary` into
    /// `from <package> import ...` when a serde package is configured.
    ///
    /// Without a configured package, `content` is returned unchanged.
    fn fix_serde_package(&self, content: &str) -> String {
        let package = match &self.serde_package_name {
            Some(package) => package,
            None => return content.to_string(),
        };
        ["serde_types", "serde_binary"]
            .iter()
            .fold(content.to_string(), |text, module| {
                text.replace(
                    &format!("import {}", module),
                    &format!("from {} import {}", package, module),
                )
            })
    }
}
426
427impl crate::SourceInstaller for Installer {
428 type Error = Box<dyn std::error::Error>;
429
430 fn install_module(
431 &self,
432 config: &crate::CodeGeneratorConfig,
433 registry: &Registry,
434 ) -> std::result::Result<(), Self::Error> {
435 let mut file = self.create_module_init_file(&config.module_name)?;
436 let generator =
437 CodeGenerator::new(config).with_serde_package_name(self.serde_package_name.clone());
438 generator.output(&mut file, registry)?;
439 Ok(())
440 }
441
442 fn install_serde_runtime(&self) -> std::result::Result<(), Self::Error> {
443 let mut file = self.create_module_init_file("serde_types")?;
444 write!(
445 file,
446 "{}",
447 self.fix_serde_package(include_str!("../runtime/python/serde_types/__init__.py"))
448 )?;
449 let mut file = self.create_module_init_file("serde_binary")?;
450 write!(
451 file,
452 "{}",
453 self.fix_serde_package(include_str!("../runtime/python/serde_binary/__init__.py"))
454 )?;
455 Ok(())
456 }
457
458 fn install_bincode_runtime(&self) -> std::result::Result<(), Self::Error> {
459 let mut file = self.create_module_init_file("bincode")?;
460 write!(
461 file,
462 "{}",
463 self.fix_serde_package(include_str!("../runtime/python/bincode/__init__.py"))
464 )?;
465 Ok(())
466 }
467
468 fn install_bcs_runtime(&self) -> std::result::Result<(), Self::Error> {
469 let mut file = self.create_module_init_file("bcs")?;
470 write!(
471 file,
472 "{}",
473 self.fix_serde_package(include_str!("../runtime/python/bcs/__init__.py"))
474 )?;
475 Ok(())
476 }
477}