Skip to main content

serde_generate/
python3.rs

1// Copyright (c) Facebook, Inc. and its affiliates
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
4use crate::{
5    indent::{IndentConfig, IndentedWriter},
6    CodeGeneratorConfig, Encoding,
7};
8use serde_reflection::{ContainerFormat, Format, Named, Registry, VariantFormat};
9use std::{
10    collections::{BTreeMap, HashMap},
11    io::{Result, Write},
12    path::PathBuf,
13};
14
15/// Main configuration object for code-generation in Python.
16pub struct CodeGenerator<'a> {
17    /// Language-independent configuration.
18    config: &'a CodeGeneratorConfig,
19    /// Whether the module providing Serde definitions is located within package.
20    serde_package_name: Option<String>,
21    /// Mapping from external type names to suitably qualified names (e.g. "MyClass" -> "my_module.MyClass").
22    /// Assumes suitable imports (e.g. "from my_package import my_module").
23    /// Derived from `config.external_definitions`.
24    external_qualified_names: HashMap<String, String>,
25}
26
27/// Shared state for the code generation of a Python source file.
28struct PythonEmitter<'a, T> {
29    /// Writer.
30    out: IndentedWriter<T>,
31    /// Generator.
32    generator: &'a CodeGenerator<'a>,
33    /// Current namespace (e.g. vec!["my_package", "my_module", "MyClass"])
34    current_namespace: Vec<String>,
35}
36
37impl<'a> CodeGenerator<'a> {
38    /// Create a Python code generator for the given config.
39    pub fn new(config: &'a CodeGeneratorConfig) -> Self {
40        if config.enums.c_style {
41            panic!("Python 3 does not support generating c-style enums");
42        }
43        let mut external_qualified_names = HashMap::new();
44        for (module_path, names) in &config.external_definitions {
45            let module = {
46                let mut path = module_path.split('.').collect::<Vec<_>>();
47                if path.len() < 2 {
48                    module_path
49                } else {
50                    path.pop().unwrap()
51                }
52            };
53            for name in names {
54                external_qualified_names.insert(name.to_string(), format!("{module}.{name}"));
55            }
56        }
57        Self {
58            config,
59            serde_package_name: None,
60            external_qualified_names,
61        }
62    }
63
64    /// Whether the module providing Serde definitions is located within a package.
65    pub fn with_serde_package_name(mut self, serde_package_name: Option<String>) -> Self {
66        self.serde_package_name = serde_package_name;
67        self
68    }
69
70    /// Write container definitions in Python.
71    pub fn output(&self, out: &mut dyn Write, registry: &Registry) -> Result<()> {
72        let current_namespace = self
73            .config
74            .module_name
75            .split('.')
76            .map(String::from)
77            .collect();
78        let mut emitter = PythonEmitter {
79            out: IndentedWriter::new(out, IndentConfig::Space(4)),
80            generator: self,
81            current_namespace,
82        };
83        emitter.output_preamble()?;
84        for (name, format) in registry {
85            emitter.output_container(name, format)?;
86        }
87        Ok(())
88    }
89}
90
91impl<'a, T> PythonEmitter<'a, T>
92where
93    T: Write,
94{
95    fn quote_import(&self, module: &str) -> String {
96        let mut parts = module.split('.').collect::<Vec<_>>();
97        if parts.len() <= 1 {
98            format!("import {module}")
99        } else {
100            let module_name = parts.pop().unwrap();
101            format!("from {} import {}", parts.join("."), module_name)
102        }
103    }
104
105    fn output_preamble(&mut self) -> Result<()> {
106        let from_serde_package = match &self.generator.serde_package_name {
107            None => "".to_string(),
108            Some(name) => format!("from {name} "),
109        };
110        writeln!(
111            self.out,
112            r#"# pyre-strict
113from dataclasses import dataclass
114import typing
115{from_serde_package}import serde_types as st"#,
116        )?;
117        for encoding in &self.generator.config.encodings {
118            writeln!(self.out, "{}import {}", from_serde_package, encoding.name())?;
119        }
120        for module in self.generator.config.external_definitions.keys() {
121            writeln!(self.out, "{}\n", self.quote_import(module))?;
122        }
123        Ok(())
124    }
125
126    /// Compute a reference to the registry type `name`.
127    /// Use a qualified name in case of external definitions.
128    fn quote_qualified_name(&self, name: &str) -> String {
129        self.generator
130            .external_qualified_names
131            .get(name)
132            .cloned()
133            .unwrap_or_else(|| {
134                // Need quotes because of circular dependencies.
135                format!("\"{name}\"")
136            })
137    }
138
139    fn quote_type(&self, format: &Format) -> String {
140        use Format::*;
141        match format {
142            TypeName(x) => self.quote_qualified_name(x),
143            Unit => "st.unit".into(),
144            Bool => "bool".into(),
145            I8 => "st.int8".into(),
146            I16 => "st.int16".into(),
147            I32 => "st.int32".into(),
148            I64 => "st.int64".into(),
149            I128 => "st.int128".into(),
150            U8 => "st.uint8".into(),
151            U16 => "st.uint16".into(),
152            U32 => "st.uint32".into(),
153            U64 => "st.uint64".into(),
154            U128 => "st.uint128".into(),
155            F32 => "st.float32".into(),
156            F64 => "st.float64".into(),
157            Char => "st.char".into(),
158            Str => "str".into(),
159            Bytes => "bytes".into(),
160
161            Option(format) => format!("typing.Optional[{}]", self.quote_type(format)),
162            Seq(format) => format!("typing.Sequence[{}]", self.quote_type(format)),
163            Map { key, value } => format!(
164                "typing.Dict[{}, {}]",
165                self.quote_type(key),
166                self.quote_type(value)
167            ),
168            Tuple(formats) => {
169                if formats.is_empty() {
170                    "typing.Tuple[()]".into()
171                } else {
172                    format!("typing.Tuple[{}]", self.quote_types(formats))
173                }
174            }
175            TupleArray { content, size } => format!(
176                "typing.Tuple[{}]",
177                self.quote_types(&vec![content.as_ref().clone(); *size])
178            ), // Sadly, there are no fixed-size arrays in python.
179
180            Variable(_) => panic!("unexpected value"),
181        }
182    }
183
184    fn quote_types(&self, formats: &[Format]) -> String {
185        formats
186            .iter()
187            .map(|x| self.quote_type(x))
188            .collect::<Vec<_>>()
189            .join(", ")
190    }
191
192    fn output_comment(&mut self, name: &str) -> std::io::Result<()> {
193        let mut path = self.current_namespace.clone();
194        path.push(name.to_string());
195        if let Some(doc) = self.generator.config.comments.get(&path) {
196            writeln!(self.out, "\"\"\"{doc}\"\"\"")?;
197        }
198        Ok(())
199    }
200
201    fn output_custom_code(&mut self) -> std::io::Result<bool> {
202        match self
203            .generator
204            .config
205            .custom_code
206            .get(&self.current_namespace)
207        {
208            Some(code) => {
209                writeln!(self.out, "\n{code}")?;
210                Ok(true)
211            }
212            None => Ok(false),
213        }
214    }
215
216    fn output_fields(&mut self, fields: &[Named<Format>]) -> Result<()> {
217        if fields.is_empty() {
218            writeln!(self.out, "pass")?;
219            return Ok(());
220        }
221        for field in fields {
222            writeln!(
223                self.out,
224                "{}: {}",
225                field.name,
226                self.quote_type(&field.value)
227            )?;
228        }
229        Ok(())
230    }
231
232    fn output_variant(
233        &mut self,
234        base: &str,
235        name: &str,
236        index: u32,
237        variant: &VariantFormat,
238    ) -> Result<()> {
239        use VariantFormat::*;
240        let fields = match variant {
241            Unit => Vec::new(),
242            NewType(format) => vec![Named {
243                name: "value".to_string(),
244                value: format.as_ref().clone(),
245            }],
246            Tuple(formats) => vec![Named {
247                name: "value".to_string(),
248                value: Format::Tuple(formats.clone()),
249            }],
250            Struct(fields) => fields.clone(),
251            Variable(_) => panic!("incorrect value"),
252        };
253
254        // Regarding comments, we pretend the namespace is `[module, base, name]`.
255        writeln!(
256            self.out,
257            "\n@dataclass(frozen=True)\nclass {base}__{name}({base}):"
258        )?;
259        self.out.indent();
260        self.output_comment(name)?;
261        if self.generator.config.serialization {
262            writeln!(self.out, "INDEX = {index}  # type: int")?;
263        }
264        self.current_namespace.push(name.to_string());
265        self.output_fields(&fields)?;
266        self.output_custom_code()?;
267        self.current_namespace.pop();
268        self.out.unindent();
269        writeln!(self.out)
270    }
271
272    fn output_enum_container(
273        &mut self,
274        name: &str,
275        variants: &BTreeMap<u32, Named<VariantFormat>>,
276    ) -> Result<()> {
277        writeln!(self.out, "\nclass {name}:")?;
278        self.out.indent();
279        self.output_comment(name)?;
280        self.current_namespace.push(name.to_string());
281        if self.generator.config.serialization {
282            writeln!(
283                self.out,
284                "VARIANTS = []  # type: typing.Sequence[typing.Type[{name}]]"
285            )?;
286            for encoding in &self.generator.config.encodings {
287                self.output_serialize_method_for_encoding(name, *encoding)?;
288                self.output_deserialize_method_for_encoding(name, *encoding)?;
289            }
290        }
291        let wrote_custom_code = self.output_custom_code()?;
292        if !self.generator.config.serialization && !wrote_custom_code {
293            writeln!(self.out, "pass")?;
294        }
295        writeln!(self.out)?;
296        self.out.unindent();
297
298        for (index, variant) in variants {
299            self.output_variant(name, &variant.name, *index, &variant.value)?;
300        }
301        self.current_namespace.pop();
302
303        if self.generator.config.serialization {
304            writeln!(
305                self.out,
306                "{}.VARIANTS = [\n{}]\n",
307                name,
308                variants
309                    .values()
310                    .map(|v| format!("    {name}__{},\n", v.name))
311                    .collect::<Vec<_>>()
312                    .join("")
313            )?;
314        }
315        Ok(())
316    }
317
318    fn output_serialize_method_for_encoding(
319        &mut self,
320        name: &str,
321        encoding: Encoding,
322    ) -> Result<()> {
323        writeln!(
324            self.out,
325            r#"
326def {0}_serialize(self) -> bytes:
327    return {0}.serialize(self, {1})"#,
328            encoding.name(),
329            name
330        )
331    }
332
333    fn output_deserialize_method_for_encoding(
334        &mut self,
335        name: &str,
336        encoding: Encoding,
337    ) -> Result<()> {
338        writeln!(
339            self.out,
340            r#"
341@staticmethod
342def {0}_deserialize(input: bytes) -> '{1}':
343    v, buffer = {0}.deserialize(input, {1})
344    if buffer:
345        raise st.DeserializationError("Some input bytes were not read");
346    return v"#,
347            encoding.name(),
348            name
349        )
350    }
351
352    fn output_container(&mut self, name: &str, format: &ContainerFormat) -> Result<()> {
353        use ContainerFormat::*;
354        let fields = match format {
355            UnitStruct => Vec::new(),
356            NewTypeStruct(format) => vec![Named {
357                name: "value".to_string(),
358                value: format.as_ref().clone(),
359            }],
360            TupleStruct(formats) => vec![Named {
361                name: "value".to_string(),
362                value: Format::Tuple(formats.clone()),
363            }],
364            Struct(fields) => fields.clone(),
365            Enum(variants) => {
366                // Enum case.
367                self.output_enum_container(name, variants)?;
368                return Ok(());
369            }
370        };
371        // Struct case.
372        writeln!(self.out, "\n@dataclass(frozen=True)\nclass {name}:")?;
373        self.out.indent();
374        self.output_comment(name)?;
375        self.current_namespace.push(name.to_string());
376        self.output_fields(&fields)?;
377        for encoding in &self.generator.config.encodings {
378            self.output_serialize_method_for_encoding(name, *encoding)?;
379            self.output_deserialize_method_for_encoding(name, *encoding)?;
380        }
381        self.output_custom_code()?;
382        self.current_namespace.pop();
383        self.out.unindent();
384        writeln!(self.out)
385    }
386}
387
/// Writes generated Python modules, and the Python Serde runtime, to disk.
pub struct Installer {
    /// Root directory under which one directory per module is created.
    install_dir: PathBuf,
    /// Optional package holding `serde_types`/`serde_binary`; when set, installed
    /// sources import those runtime modules from this package.
    serde_package_name: Option<String>,
}
393
394impl Installer {
395    pub fn new(install_dir: PathBuf, serde_package_name: Option<String>) -> Self {
396        Installer {
397            install_dir,
398            serde_package_name,
399        }
400    }
401
402    fn create_module_init_file(&self, name: &str) -> Result<std::fs::File> {
403        let dir_path = self.install_dir.join(name);
404        std::fs::create_dir_all(&dir_path)?;
405        std::fs::File::create(dir_path.join("__init__.py"))
406    }
407
408    fn fix_serde_package(&self, content: &str) -> String {
409        match &self.serde_package_name {
410            None => content.into(),
411            Some(name) => content
412                .replace(
413                    "import serde_types",
414                    &format!("from {name} import serde_types"),
415                )
416                .replace(
417                    "import serde_binary",
418                    &format!("from {name} import serde_binary"),
419                ),
420        }
421    }
422}
423
424impl crate::SourceInstaller for Installer {
425    type Error = Box<dyn std::error::Error>;
426
427    fn install_module(
428        &self,
429        config: &crate::CodeGeneratorConfig,
430        registry: &Registry,
431    ) -> std::result::Result<(), Self::Error> {
432        let mut file = self.create_module_init_file(&config.module_name)?;
433        let generator =
434            CodeGenerator::new(config).with_serde_package_name(self.serde_package_name.clone());
435        generator.output(&mut file, registry)?;
436        Ok(())
437    }
438
439    fn install_serde_runtime(&self) -> std::result::Result<(), Self::Error> {
440        let mut file = self.create_module_init_file("serde_types")?;
441        write!(
442            file,
443            "{}",
444            self.fix_serde_package(include_str!("../runtime/python/serde_types/__init__.py"))
445        )?;
446        let mut file = self.create_module_init_file("serde_binary")?;
447        write!(
448            file,
449            "{}",
450            self.fix_serde_package(include_str!("../runtime/python/serde_binary/__init__.py"))
451        )?;
452        Ok(())
453    }
454
455    fn install_bincode_runtime(&self) -> std::result::Result<(), Self::Error> {
456        let mut file = self.create_module_init_file("bincode")?;
457        write!(
458            file,
459            "{}",
460            self.fix_serde_package(include_str!("../runtime/python/bincode/__init__.py"))
461        )?;
462        Ok(())
463    }
464
465    fn install_bcs_runtime(&self) -> std::result::Result<(), Self::Error> {
466        let mut file = self.create_module_init_file("bcs")?;
467        write!(
468            file,
469            "{}",
470            self.fix_serde_package(include_str!("../runtime/python/bcs/__init__.py"))
471        )?;
472        Ok(())
473    }
474}