serde_generate/
python3.rs

1// Copyright (c) Facebook, Inc. and its affiliates
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
4use crate::{
5    indent::{IndentConfig, IndentedWriter},
6    CodeGeneratorConfig, Encoding,
7};
8use serde_reflection::{ContainerFormat, Format, Named, Registry, VariantFormat};
9use std::{
10    collections::{BTreeMap, HashMap},
11    io::{Result, Write},
12    path::PathBuf,
13};
14
15/// Main configuration object for code-generation in Python.
16pub struct CodeGenerator<'a> {
17    /// Language-independent configuration.
18    config: &'a CodeGeneratorConfig,
19    /// Whether the module providing Serde definitions is located within package.
20    serde_package_name: Option<String>,
21    /// Mapping from external type names to suitably qualified names (e.g. "MyClass" -> "my_module.MyClass").
22    /// Assumes suitable imports (e.g. "from my_package import my_module").
23    /// Derived from `config.external_definitions`.
24    external_qualified_names: HashMap<String, String>,
25}
26
27/// Shared state for the code generation of a Python source file.
28struct PythonEmitter<'a, T> {
29    /// Writer.
30    out: IndentedWriter<T>,
31    /// Generator.
32    generator: &'a CodeGenerator<'a>,
33    /// Current namespace (e.g. vec!["my_package", "my_module", "MyClass"])
34    current_namespace: Vec<String>,
35}
36
37impl<'a> CodeGenerator<'a> {
38    /// Create a Python code generator for the given config.
39    pub fn new(config: &'a CodeGeneratorConfig) -> Self {
40        if config.c_style_enums {
41            panic!("Python 3 does not support generating c-style enums");
42        }
43        let mut external_qualified_names = HashMap::new();
44        for (module_path, names) in &config.external_definitions {
45            let module = {
46                let mut path = module_path.split('.').collect::<Vec<_>>();
47                if path.len() < 2 {
48                    module_path
49                } else {
50                    path.pop().unwrap()
51                }
52            };
53            for name in names {
54                external_qualified_names.insert(name.to_string(), format!("{}.{}", module, name));
55            }
56        }
57        Self {
58            config,
59            serde_package_name: None,
60            external_qualified_names,
61        }
62    }
63
64    /// Whether the module providing Serde definitions is located within a package.
65    pub fn with_serde_package_name(mut self, serde_package_name: Option<String>) -> Self {
66        self.serde_package_name = serde_package_name;
67        self
68    }
69
70    /// Write container definitions in Python.
71    pub fn output(&self, out: &mut dyn Write, registry: &Registry) -> Result<()> {
72        let current_namespace = self
73            .config
74            .module_name
75            .split('.')
76            .map(String::from)
77            .collect();
78        let mut emitter = PythonEmitter {
79            out: IndentedWriter::new(out, IndentConfig::Space(4)),
80            generator: self,
81            current_namespace,
82        };
83        emitter.output_preamble()?;
84        for (name, format) in registry {
85            emitter.output_container(name, format)?;
86        }
87        Ok(())
88    }
89}
90
91impl<'a, T> PythonEmitter<'a, T>
92where
93    T: Write,
94{
95    fn quote_import(&self, module: &str) -> String {
96        let mut parts = module.split('.').collect::<Vec<_>>();
97        if parts.len() <= 1 {
98            format!("import {}", module)
99        } else {
100            let module_name = parts.pop().unwrap();
101            format!("from {} import {}", parts.join("."), module_name)
102        }
103    }
104
105    fn output_preamble(&mut self) -> Result<()> {
106        let from_serde_package = match &self.generator.serde_package_name {
107            None => "".to_string(),
108            Some(name) => format!("from {} ", name),
109        };
110        writeln!(
111            self.out,
112            r#"# pyre-strict
113from dataclasses import dataclass
114import typing
115{}import serde_types as st"#,
116            from_serde_package,
117        )?;
118        for encoding in &self.generator.config.encodings {
119            writeln!(self.out, "{}import {}", from_serde_package, encoding.name())?;
120        }
121        for module in self.generator.config.external_definitions.keys() {
122            writeln!(self.out, "{}\n", self.quote_import(module))?;
123        }
124        Ok(())
125    }
126
127    /// Compute a reference to the registry type `name`.
128    /// Use a qualified name in case of external definitions.
129    fn quote_qualified_name(&self, name: &str) -> String {
130        self.generator
131            .external_qualified_names
132            .get(name)
133            .cloned()
134            .unwrap_or_else(|| {
135                // Need quotes because of circular dependencies.
136                format!("\"{}\"", name)
137            })
138    }
139
140    fn quote_type(&self, format: &Format) -> String {
141        use Format::*;
142        match format {
143            TypeName(x) => self.quote_qualified_name(x),
144            Unit => "st.unit".into(),
145            Bool => "bool".into(),
146            I8 => "st.int8".into(),
147            I16 => "st.int16".into(),
148            I32 => "st.int32".into(),
149            I64 => "st.int64".into(),
150            I128 => "st.int128".into(),
151            U8 => "st.uint8".into(),
152            U16 => "st.uint16".into(),
153            U32 => "st.uint32".into(),
154            U64 => "st.uint64".into(),
155            U128 => "st.uint128".into(),
156            F32 => "st.float32".into(),
157            F64 => "st.float64".into(),
158            Char => "st.char".into(),
159            Str => "str".into(),
160            Bytes => "bytes".into(),
161
162            Option(format) => format!("typing.Optional[{}]", self.quote_type(format)),
163            Seq(format) => format!("typing.Sequence[{}]", self.quote_type(format)),
164            Map { key, value } => format!(
165                "typing.Dict[{}, {}]",
166                self.quote_type(key),
167                self.quote_type(value)
168            ),
169            Tuple(formats) => {
170                if formats.is_empty() {
171                    "typing.Tuple[()]".into()
172                } else {
173                    format!("typing.Tuple[{}]", self.quote_types(formats))
174                }
175            }
176            TupleArray { content, size } => format!(
177                "typing.Tuple[{}]",
178                self.quote_types(&vec![content.as_ref().clone(); *size])
179            ), // Sadly, there are no fixed-size arrays in python.
180
181            Variable(_) => panic!("unexpected value"),
182        }
183    }
184
185    fn quote_types(&self, formats: &[Format]) -> String {
186        formats
187            .iter()
188            .map(|x| self.quote_type(x))
189            .collect::<Vec<_>>()
190            .join(", ")
191    }
192
193    fn output_comment(&mut self, name: &str) -> std::io::Result<()> {
194        let mut path = self.current_namespace.clone();
195        path.push(name.to_string());
196        if let Some(doc) = self.generator.config.comments.get(&path) {
197            writeln!(self.out, "\"\"\"{}\"\"\"", doc)?;
198        }
199        Ok(())
200    }
201
202    fn output_custom_code(&mut self) -> std::io::Result<bool> {
203        match self
204            .generator
205            .config
206            .custom_code
207            .get(&self.current_namespace)
208        {
209            Some(code) => {
210                writeln!(self.out, "\n{}", code)?;
211                Ok(true)
212            }
213            None => Ok(false),
214        }
215    }
216
217    fn output_fields(&mut self, fields: &[Named<Format>]) -> Result<()> {
218        if fields.is_empty() {
219            writeln!(self.out, "pass")?;
220            return Ok(());
221        }
222        for field in fields {
223            writeln!(
224                self.out,
225                "{}: {}",
226                field.name,
227                self.quote_type(&field.value)
228            )?;
229        }
230        Ok(())
231    }
232
233    fn output_variant(
234        &mut self,
235        base: &str,
236        name: &str,
237        index: u32,
238        variant: &VariantFormat,
239    ) -> Result<()> {
240        use VariantFormat::*;
241        let fields = match variant {
242            Unit => Vec::new(),
243            NewType(format) => vec![Named {
244                name: "value".to_string(),
245                value: format.as_ref().clone(),
246            }],
247            Tuple(formats) => vec![Named {
248                name: "value".to_string(),
249                value: Format::Tuple(formats.clone()),
250            }],
251            Struct(fields) => fields.clone(),
252            Variable(_) => panic!("incorrect value"),
253        };
254
255        // Regarding comments, we pretend the namespace is `[module, base, name]`.
256        writeln!(
257            self.out,
258            "\n@dataclass(frozen=True)\nclass {0}__{1}({0}):",
259            base, name
260        )?;
261        self.out.indent();
262        self.output_comment(name)?;
263        if self.generator.config.serialization {
264            writeln!(self.out, "INDEX = {}  # type: int", index)?;
265        }
266        self.current_namespace.push(name.to_string());
267        self.output_fields(&fields)?;
268        self.output_custom_code()?;
269        self.current_namespace.pop();
270        self.out.unindent();
271        writeln!(self.out)
272    }
273
274    fn output_enum_container(
275        &mut self,
276        name: &str,
277        variants: &BTreeMap<u32, Named<VariantFormat>>,
278    ) -> Result<()> {
279        writeln!(self.out, "\nclass {}:", name)?;
280        self.out.indent();
281        self.output_comment(name)?;
282        self.current_namespace.push(name.to_string());
283        if self.generator.config.serialization {
284            writeln!(
285                self.out,
286                "VARIANTS = []  # type: typing.Sequence[typing.Type[{}]]",
287                name
288            )?;
289            for encoding in &self.generator.config.encodings {
290                self.output_serialize_method_for_encoding(name, *encoding)?;
291                self.output_deserialize_method_for_encoding(name, *encoding)?;
292            }
293        }
294        let wrote_custom_code = self.output_custom_code()?;
295        if !self.generator.config.serialization && !wrote_custom_code {
296            writeln!(self.out, "pass")?;
297        }
298        writeln!(self.out)?;
299        self.out.unindent();
300
301        for (index, variant) in variants {
302            self.output_variant(name, &variant.name, *index, &variant.value)?;
303        }
304        self.current_namespace.pop();
305
306        if self.generator.config.serialization {
307            writeln!(
308                self.out,
309                "{}.VARIANTS = [\n{}]\n",
310                name,
311                variants
312                    .iter()
313                    .map(|(_, v)| format!("    {}__{},\n", name, v.name))
314                    .collect::<Vec<_>>()
315                    .join("")
316            )?;
317        }
318        Ok(())
319    }
320
321    fn output_serialize_method_for_encoding(
322        &mut self,
323        name: &str,
324        encoding: Encoding,
325    ) -> Result<()> {
326        writeln!(
327            self.out,
328            r#"
329def {0}_serialize(self) -> bytes:
330    return {0}.serialize(self, {1})"#,
331            encoding.name(),
332            name
333        )
334    }
335
336    fn output_deserialize_method_for_encoding(
337        &mut self,
338        name: &str,
339        encoding: Encoding,
340    ) -> Result<()> {
341        writeln!(
342            self.out,
343            r#"
344@staticmethod
345def {0}_deserialize(input: bytes) -> '{1}':
346    v, buffer = {0}.deserialize(input, {1})
347    if buffer:
348        raise st.DeserializationError("Some input bytes were not read");
349    return v"#,
350            encoding.name(),
351            name
352        )
353    }
354
355    fn output_container(&mut self, name: &str, format: &ContainerFormat) -> Result<()> {
356        use ContainerFormat::*;
357        let fields = match format {
358            UnitStruct => Vec::new(),
359            NewTypeStruct(format) => vec![Named {
360                name: "value".to_string(),
361                value: format.as_ref().clone(),
362            }],
363            TupleStruct(formats) => vec![Named {
364                name: "value".to_string(),
365                value: Format::Tuple(formats.clone()),
366            }],
367            Struct(fields) => fields.clone(),
368            Enum(variants) => {
369                // Enum case.
370                self.output_enum_container(name, variants)?;
371                return Ok(());
372            }
373        };
374        // Struct case.
375        writeln!(self.out, "\n@dataclass(frozen=True)\nclass {}:", name)?;
376        self.out.indent();
377        self.output_comment(name)?;
378        self.current_namespace.push(name.to_string());
379        self.output_fields(&fields)?;
380        for encoding in &self.generator.config.encodings {
381            self.output_serialize_method_for_encoding(name, *encoding)?;
382            self.output_deserialize_method_for_encoding(name, *encoding)?;
383        }
384        self.output_custom_code()?;
385        self.current_namespace.pop();
386        self.out.unindent();
387        writeln!(self.out)
388    }
389}
390
/// Writes generated Python modules and the bundled runtime modules to disk.
pub struct Installer {
    /// Directory under which one sub-directory per module is created.
    install_dir: PathBuf,
    /// Package that contains the Serde runtime modules, if any.
    serde_package_name: Option<String>,
}

impl Installer {
    /// Create an installer that writes below `install_dir`.
    ///
    /// When `serde_package_name` is set, imports of the runtime modules are
    /// rewritten to be relative to that package.
    pub fn new(install_dir: PathBuf, serde_package_name: Option<String>) -> Self {
        Self {
            install_dir,
            serde_package_name,
        }
    }

    /// Create the directory `<install_dir>/<name>` (and any missing parents) and
    /// open a fresh `__init__.py` inside it for writing.
    fn create_module_init_file(&self, name: &str) -> Result<std::fs::File> {
        let package_dir = self.install_dir.join(name);
        std::fs::create_dir_all(&package_dir)?;
        let init_path = package_dir.join("__init__.py");
        std::fs::File::create(init_path)
    }

    /// Rewrite imports of the runtime modules in `content` so that they are taken
    /// from the configured package.
    ///
    /// Every `import serde_types` / `import serde_binary` becomes
    /// `from <package> import ...`. Without a configured package, `content` is
    /// returned unchanged.
    fn fix_serde_package(&self, content: &str) -> String {
        let package = match self.serde_package_name.as_deref() {
            Some(package) => package,
            None => return content.to_string(),
        };
        // Apply the rewrites one after the other, `serde_types` first.
        ["serde_types", "serde_binary"]
            .iter()
            .fold(content.to_string(), |text, module| {
                text.replace(
                    &format!("import {}", module),
                    &format!("from {} import {}", package, module),
                )
            })
    }
}
426
427impl crate::SourceInstaller for Installer {
428    type Error = Box<dyn std::error::Error>;
429
430    fn install_module(
431        &self,
432        config: &crate::CodeGeneratorConfig,
433        registry: &Registry,
434    ) -> std::result::Result<(), Self::Error> {
435        let mut file = self.create_module_init_file(&config.module_name)?;
436        let generator =
437            CodeGenerator::new(config).with_serde_package_name(self.serde_package_name.clone());
438        generator.output(&mut file, registry)?;
439        Ok(())
440    }
441
442    fn install_serde_runtime(&self) -> std::result::Result<(), Self::Error> {
443        let mut file = self.create_module_init_file("serde_types")?;
444        write!(
445            file,
446            "{}",
447            self.fix_serde_package(include_str!("../runtime/python/serde_types/__init__.py"))
448        )?;
449        let mut file = self.create_module_init_file("serde_binary")?;
450        write!(
451            file,
452            "{}",
453            self.fix_serde_package(include_str!("../runtime/python/serde_binary/__init__.py"))
454        )?;
455        Ok(())
456    }
457
458    fn install_bincode_runtime(&self) -> std::result::Result<(), Self::Error> {
459        let mut file = self.create_module_init_file("bincode")?;
460        write!(
461            file,
462            "{}",
463            self.fix_serde_package(include_str!("../runtime/python/bincode/__init__.py"))
464        )?;
465        Ok(())
466    }
467
468    fn install_bcs_runtime(&self) -> std::result::Result<(), Self::Error> {
469        let mut file = self.create_module_init_file("bcs")?;
470        write!(
471            file,
472            "{}",
473            self.fix_serde_package(include_str!("../runtime/python/bcs/__init__.py"))
474        )?;
475        Ok(())
476    }
477}