// tfschema_bindgen/emit.rs

// Copyright (c) Facebook, Inc. and its affiliates
// SPDX-License-Identifier: MIT OR Apache-2.0

//!
//! Stripped-down version of serde-reflection's code generator for Rust.
//! The main changes add support for qualified names in `Registry` entries, as well
//! as customizing how stubs are generated.
//!
9use crate::config::CodeGeneratorConfig;
10use serde_generate::indent::{IndentConfig, IndentedWriter};
11use serde_reflection::{ContainerFormat, Format, Named, VariantFormat};
12use std::borrow::Cow;
13use std::collections::{BTreeMap, BTreeSet, HashSet};
14use std::io::{Result, Write};
15
16/// A map of container formats indexed by a qualified name
17pub type QualifiedName = (Option<String>, String);
18pub type Registry = BTreeMap<QualifiedName, ContainerFormat>;
19
20/// Main configuration object for code-generation in Rust.
21pub struct CodeGenerator<'a> {
22    /// Language-independent configuration.
23    config: &'a CodeGeneratorConfig,
24    /// Which derive macros should be added (independently from serialization).
25    derive_macros: Vec<String>,
26    /// Additional block of text added before each new container definition.
27    custom_derive_block: Option<String>,
28    /// Whether definitions and fields should be marked as `pub`.
29    track_visibility: bool,
30}
31
32/// Shared state for the code generation of a Rust source file.
33struct RustEmitter<'a, T> {
34    /// Writer.
35    out: IndentedWriter<T>,
36    /// Generator.
37    generator: &'a CodeGenerator<'a>,
38    /// Track which definitions have a known size. (Used to add `Box` types.)
39    known_sizes: Cow<'a, HashSet<&'a str>>,
40    /// Current namespace (e.g. vec!["my_package", "my_module", "MyClass"])
41    current_namespace: Vec<String>,
42}
43
44impl<'a> CodeGenerator<'a> {
45    /// Create a Rust code generator for the given config.
46    pub fn new(config: &'a CodeGeneratorConfig) -> Self {
47        Self {
48            config,
49            derive_macros: vec!["Clone", "Debug", "PartialEq", "PartialOrd"]
50                .into_iter()
51                .map(String::from)
52                .collect(),
53            custom_derive_block: None,
54            track_visibility: true,
55        }
56    }
57
58    /// Which derive macros should be added (independently from serialization).
59    pub fn with_derive_macros(mut self, derive_macros: Vec<String>) -> Self {
60        self.derive_macros = derive_macros;
61        self
62    }
63
64    /// Additional block of text added after `derive_macros` (if any), before each new
65    /// container definition.
66    pub fn with_custom_derive_block(mut self, custom_derive_block: Option<String>) -> Self {
67        self.custom_derive_block = custom_derive_block;
68        self
69    }
70
71    /// Whether definitions and fields should be marked as `pub`.
72    pub fn with_track_visibility(mut self, track_visibility: bool) -> Self {
73        self.track_visibility = track_visibility;
74        self
75    }
76
77    /// Write container definitions in Rust.
78    pub fn output(
79        &self,
80        out: &mut dyn Write,
81        registry: &Registry,
82    ) -> std::result::Result<(), Box<dyn std::error::Error>> {
83        let external_names: BTreeSet<String> = self
84            .config
85            .external_definitions
86            .values()
87            .cloned()
88            .flatten()
89            .collect();
90
91        let known_sizes = external_names
92            .iter()
93            .map(<String as std::ops::Deref>::deref)
94            .collect::<HashSet<_>>();
95
96        let current_namespace = self
97            .config
98            .module_name
99            .split('.')
100            .map(String::from)
101            .collect();
102        let mut emitter = RustEmitter {
103            out: IndentedWriter::new(out, IndentConfig::Space(4)),
104            generator: self,
105            known_sizes: Cow::Owned(known_sizes),
106            current_namespace,
107        };
108
109        emitter.output_preamble()?;
110        for ((ns, name), format) in registry {
111            emitter.output_container(ns, name, format)?;
112            emitter.known_sizes.to_mut().insert(name);
113        }
114        Ok(())
115    }
116}
117
118impl<'a, T> RustEmitter<'a, T>
119where
120    T: std::io::Write,
121{
122    fn output_comment(&mut self, name: &str) -> std::io::Result<()> {
123        let mut path = self.current_namespace.clone();
124        path.push(name.to_string());
125        if let Some(doc) = self.generator.config.comments.get(&path) {
126            let text = textwrap::indent(doc, "/// ").replace("\n\n", "\n///\n");
127            write!(self.out, "\n{}", text)?;
128        }
129        Ok(())
130    }
131
132    fn output_preamble(&mut self) -> Result<()> {
133        let external_names = self
134            .generator
135            .config
136            .external_definitions
137            .values()
138            .cloned()
139            .flatten()
140            .collect::<HashSet<_>>();
141        writeln!(self.out, "#![allow(unused_imports, non_snake_case, non_camel_case_types, non_upper_case_globals)]")?;
142        if !external_names.contains("Map") {
143            writeln!(self.out, "use std::collections::BTreeMap as Map;")?;
144        }
145        writeln!(self.out, "use serde::{{Serialize, Deserialize}};")?;
146        if !external_names.contains("Bytes") {
147            writeln!(self.out, "use serde_bytes::ByteBuf as Bytes;")?;
148        }
149        for (module, definitions) in &self.generator.config.external_definitions {
150            // Skip the empty module name.
151            if !module.is_empty() {
152                writeln!(
153                    self.out,
154                    "use {}::{{{}}};",
155                    module,
156                    definitions.to_vec().join(", "),
157                )?;
158            }
159        }
160        writeln!(self.out)?;
161        Ok(())
162    }
163
164    fn output_field_annotation(&mut self, format: &Format) -> std::io::Result<()> {
165        use Format::*;
166        match format {
167            Str => writeln!(
168                self.out,
169                "#[serde(skip_serializing_if = \"String::is_empty\")]"
170            )?,
171            Option(_) => writeln!(
172                self.out,
173                "#[serde(skip_serializing_if = \"Option::is_none\")]"
174            )?,
175            Seq(_) => writeln!(
176                self.out,
177                "#[serde(skip_serializing_if = \"Vec::is_empty\")]"
178            )?,
179            _ => (),
180        }
181
182        Ok(())
183    }
184
185    fn quote_type(format: &Format, known_sizes: Option<&HashSet<&str>>) -> String {
186        use Format::*;
187        match format {
188            TypeName(x) => {
189                if let Some(set) = known_sizes {
190                    if !set.contains(x.as_str()) && !x.as_str().starts_with("Vec") {
191                        return format!("Box<{}>", x);
192                    }
193                }
194                x.to_string()
195            }
196            Unit => "()".into(),
197            Bool => "bool".into(),
198            I8 => "i8".into(),
199            I16 => "i16".into(),
200            I32 => "i32".into(),
201            I64 => "i64".into(),
202            I128 => "i128".into(),
203            U8 => "u8".into(),
204            U16 => "u16".into(),
205            U32 => "u32".into(),
206            U64 => "u64".into(),
207            U128 => "u128".into(),
208            F32 => "f32".into(),
209            F64 => "f64".into(),
210            Char => "char".into(),
211            Str => "String".into(),
212            Bytes => "Bytes".into(),
213
214            Option(format) => format!("Option<{}>", Self::quote_type(format, known_sizes)),
215            Seq(format) => format!("Vec<{}>", Self::quote_type(format, None)),
216            Map { key, value } => format!(
217                "Map<{}, {}>",
218                Self::quote_type(key, None),
219                Self::quote_type(value, None)
220            ),
221            Tuple(formats) => format!("({})", Self::quote_types(formats, known_sizes)),
222            TupleArray { content, size } => {
223                format!("[{}; {}]", Self::quote_type(content, known_sizes), *size)
224            }
225
226            Variable(_) => panic!("unexpected value"),
227        }
228    }
229
230    fn quote_types(formats: &[Format], known_sizes: Option<&HashSet<&str>>) -> String {
231        formats
232            .iter()
233            .map(|x| Self::quote_type(x, known_sizes))
234            .collect::<Vec<_>>()
235            .join(", ")
236    }
237
238    fn output_fields(&mut self, base: &[&str], fields: &[Named<Format>]) -> Result<()> {
239        // Do not add 'pub' within variants.
240        let prefix = if base.len() <= 1 && self.generator.track_visibility {
241            "pub "
242        } else {
243            ""
244        };
245        for field in fields {
246            self.output_comment(&field.name)?;
247            self.output_field_annotation(&field.value)?;
248            writeln!(
249                self.out,
250                "{}{}: {},",
251                prefix,
252                field.name,
253                Self::quote_type(&field.value, Some(&self.known_sizes)),
254            )?;
255        }
256        Ok(())
257    }
258
259    fn output_variant(&mut self, base: &str, name: &str, variant: &VariantFormat) -> Result<()> {
260        self.output_comment(name)?;
261        use VariantFormat::*;
262        match variant {
263            Unit => writeln!(self.out, "{},", name),
264            NewType(format) => writeln!(
265                self.out,
266                "{}({}),",
267                name,
268                Self::quote_type(format, Some(&self.known_sizes))
269            ),
270            Tuple(formats) => writeln!(
271                self.out,
272                "{}({}),",
273                name,
274                Self::quote_types(formats, Some(&self.known_sizes))
275            ),
276            Struct(fields) => {
277                writeln!(self.out, "{} {{", name)?;
278                self.current_namespace.push(name.to_string());
279                self.out.indent();
280                self.output_fields(&[base, name], fields)?;
281                self.out.unindent();
282                self.current_namespace.pop();
283                writeln!(self.out, "}},")
284            }
285            Variable(_) => panic!("incorrect value"),
286        }
287    }
288
289    fn output_variants(
290        &mut self,
291        base: &str,
292        variants: &BTreeMap<u32, Named<VariantFormat>>,
293    ) -> Result<()> {
294        for (expected_index, (index, variant)) in variants.iter().enumerate() {
295            assert_eq!(*index, expected_index as u32);
296            self.output_variant(base, &variant.name, &variant.value)?;
297        }
298        Ok(())
299    }
300
301    fn output_container(
302        &mut self,
303        namespace: &Option<String>,
304        name: &str,
305        format: &ContainerFormat,
306    ) -> Result<()> {
307        self.output_comment(name)?;
308        let mut derive_macros = self.generator.derive_macros.clone();
309        derive_macros.push("Serialize".to_string());
310        derive_macros.push("Deserialize".to_string());
311        let mut prefix = String::new();
312        if !derive_macros.is_empty() {
313            prefix.push_str(&format!("#[derive({})]\n", derive_macros.join(", ")));
314        }
315        if let Some(text) = &self.generator.custom_derive_block {
316            prefix.push_str(text);
317            prefix.push('\n');
318        }
319
320        use ContainerFormat::*;
321        match format {
322            UnitStruct => writeln!(self.out, "{}struct {};\n", prefix, name),
323            NewTypeStruct(format) => writeln!(
324                self.out,
325                "{}struct {}({}{});\n",
326                prefix,
327                name,
328                if self.generator.track_visibility {
329                    "pub "
330                } else {
331                    ""
332                },
333                Self::quote_type(format, Some(&self.known_sizes))
334            ),
335            TupleStruct(formats) => writeln!(
336                self.out,
337                "{}struct {}({});\n",
338                prefix,
339                name,
340                Self::quote_types(formats, Some(&self.known_sizes))
341            ),
342            Struct(fields) => {
343                let mut struct_name = name.to_string();
344                prefix.clear();
345                derive_macros.push("Default".to_string());
346                prefix.push_str(&format!("#[derive({})]\n", derive_macros.join(", ")));
347
348                if let Some(ns) = namespace {
349                    prefix.push_str(&format!("#[serde(rename = \"{}\")]\n", name));
350                    struct_name = format!("{}_{}", ns, name)
351                }
352
353                if self.generator.track_visibility {
354                    prefix.push_str("pub ");
355                }
356
357                writeln!(self.out, "{}struct {} {{", prefix, struct_name)?;
358                self.current_namespace.push(name.to_string());
359                self.out.indent();
360                self.output_fields(&[name], fields)?;
361                self.out.unindent();
362                self.current_namespace.pop();
363                writeln!(self.out, "}}\n")
364            }
365            Enum(variants) => {
366                if self.generator.track_visibility {
367                    prefix.push_str("pub ");
368                }
369
370                writeln!(self.out, "{}enum {} {{", prefix, name)?;
371                self.current_namespace.push(name.to_string());
372                self.out.indent();
373                self.output_variants(name, variants)?;
374                self.out.unindent();
375                self.current_namespace.pop();
376                writeln!(self.out, "}}\n")
377            }
378        }
379    }
380}