serde_generate/
cpp.rs

1// Copyright (c) Facebook, Inc. and its affiliates
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
4use crate::{
5    analyzer,
6    indent::{IndentConfig, IndentedWriter},
7    CodeGeneratorConfig, Encoding,
8};
9use heck::CamelCase;
10use serde_reflection::{ContainerFormat, Format, Named, Registry, VariantFormat};
11use std::{
12    collections::{BTreeMap, HashMap, HashSet},
13    io::{Result, Write},
14    path::PathBuf,
15};
16
/// Main configuration object for code-generation in C++.
///
/// Created with `CodeGenerator::new`; the generated header is written by
/// `CodeGenerator::output`.
pub struct CodeGenerator<'a> {
    /// Language-independent configuration.
    config: &'a CodeGeneratorConfig,
    /// Mapping from external type names to suitably qualified names (e.g. "MyClass" -> "name::MyClass").
    /// Derived from `config.external_definitions`.
    external_qualified_names: HashMap<String, String>,
}
25
/// Shared state for the code generation of a C++ source file.
///
/// `T` is the underlying output; the methods of this type require `T: std::io::Write`.
struct CppEmitter<'a, T> {
    /// Writer. Indentation is adjusted through its `indent()` / `unindent()` methods.
    out: IndentedWriter<T>,
    /// Generator.
    generator: &'a CodeGenerator<'a>,
    /// Track which type names have been declared so far. (Used to add forward declarations.)
    known_names: HashSet<&'a str>,
    /// Track which definitions have a known size. (Used to add shared pointers.)
    known_sizes: HashSet<&'a str>,
    /// Current namespace (e.g. vec!["name", "MyClass"])
    /// Also used as the lookup prefix for comments and custom code.
    current_namespace: Vec<String>,
}
39
40impl<'a> CodeGenerator<'a> {
41    /// Create a C++ code generator for the given config.
42    pub fn new(config: &'a CodeGeneratorConfig) -> Self {
43        if config.c_style_enums {
44            panic!("C++ does not support generating c-style enums");
45        }
46        let mut external_qualified_names = HashMap::new();
47        for (namespace, names) in &config.external_definitions {
48            for name in names {
49                external_qualified_names
50                    .insert(name.to_string(), format!("{}::{}", namespace, name));
51            }
52        }
53        Self {
54            config,
55            external_qualified_names,
56        }
57    }
58
59    pub fn output(
60        &self,
61        out: &mut dyn Write,
62        registry: &Registry,
63    ) -> std::result::Result<(), Box<dyn std::error::Error>> {
64        let current_namespace = self
65            .config
66            .module_name
67            .split("::")
68            .map(String::from)
69            .collect();
70        let mut emitter = CppEmitter {
71            out: IndentedWriter::new(out, IndentConfig::Space(4)),
72            generator: self,
73            known_names: HashSet::new(),
74            known_sizes: HashSet::new(),
75            current_namespace,
76        };
77
78        emitter.output_preamble()?;
79        emitter.output_open_namespace()?;
80
81        let dependencies = analyzer::get_dependency_map(registry)?;
82        let entries = analyzer::best_effort_topological_sort(&dependencies);
83
84        for name in entries {
85            for dependency in &dependencies[name] {
86                if !emitter.known_names.contains(dependency) {
87                    emitter.output_container_forward_definition(dependency)?;
88                    emitter.known_names.insert(*dependency);
89                }
90            }
91            let format = &registry[name];
92            emitter.output_container(name, format)?;
93            emitter.known_sizes.insert(name);
94            emitter.known_names.insert(name);
95        }
96
97        emitter.output_close_namespace()?;
98        writeln!(emitter.out)?;
99        for (name, format) in registry {
100            emitter.output_container_traits(name, format)?;
101        }
102        Ok(())
103    }
104}
105
106impl<'a, T> CppEmitter<'a, T>
107where
108    T: std::io::Write,
109{
110    fn output_preamble(&mut self) -> Result<()> {
111        writeln!(
112            self.out,
113            r#"#pragma once
114
115#include "serde.hpp""#
116        )?;
117        if self.generator.config.serialization {
118            for encoding in &self.generator.config.encodings {
119                writeln!(self.out, "#include \"{}.hpp\"", encoding.name())?;
120            }
121        }
122        Ok(())
123    }
124
125    fn output_open_namespace(&mut self) -> Result<()> {
126        writeln!(
127            self.out,
128            "\nnamespace {} {{",
129            self.generator.config.module_name
130        )?;
131        self.out.indent();
132        Ok(())
133    }
134
135    fn output_close_namespace(&mut self) -> Result<()> {
136        self.out.unindent();
137        writeln!(
138            self.out,
139            "\n}} // end of namespace {}",
140            self.generator.config.module_name
141        )?;
142        Ok(())
143    }
144
145    fn enter_class(&mut self, name: &str) {
146        self.out.indent();
147        self.current_namespace.push(name.to_string());
148    }
149
150    fn leave_class(&mut self) {
151        self.out.unindent();
152        self.current_namespace.pop();
153    }
154
155    fn output_comment(&mut self, name: &str) -> std::io::Result<()> {
156        let mut path = self.current_namespace.clone();
157        path.push(name.to_string());
158        if let Some(doc) = self.generator.config.comments.get(&path) {
159            let text = textwrap::indent(doc, "/// ").replace("\n\n", "\n///\n");
160            write!(self.out, "{}", text)?;
161        }
162        Ok(())
163    }
164
165    fn output_custom_code(&mut self) -> std::io::Result<()> {
166        if let Some(code) = self
167            .generator
168            .config
169            .custom_code
170            .get(&self.current_namespace)
171        {
172            write!(self.out, "\n{}", code)?;
173        }
174        Ok(())
175    }
176
177    /// Compute a fully qualified reference to the container type `name`.
178    fn quote_qualified_name(&self, name: &str) -> String {
179        self.generator
180            .external_qualified_names
181            .get(name)
182            .cloned()
183            .unwrap_or_else(|| format!("{}::{}", self.generator.config.module_name, name))
184    }
185
186    fn quote_type(&self, format: &Format, require_known_size: bool) -> String {
187        use Format::*;
188        match format {
189            TypeName(x) => {
190                let qname = self.quote_qualified_name(x);
191                if require_known_size && !self.known_sizes.contains(x.as_str()) {
192                    // Cannot use unique_ptr because we need a copy constructor (e.g. for vectors)
193                    // and in-depth equality.
194                    format!("serde::value_ptr<{}>", qname)
195                } else {
196                    qname
197                }
198            }
199            Unit => "std::monostate".into(),
200            Bool => "bool".into(),
201            I8 => "int8_t".into(),
202            I16 => "int16_t".into(),
203            I32 => "int32_t".into(),
204            I64 => "int64_t".into(),
205            I128 => "serde::int128_t".into(),
206            U8 => "uint8_t".into(),
207            U16 => "uint16_t".into(),
208            U32 => "uint32_t".into(),
209            U64 => "uint64_t".into(),
210            U128 => "serde::uint128_t".into(),
211            F32 => "float".into(),
212            F64 => "double".into(),
213            Char => "char32_t".into(),
214            Str => "std::string".into(),
215            Bytes => "std::vector<uint8_t>".into(),
216
217            Option(format) => format!(
218                "std::optional<{}>",
219                self.quote_type(format, require_known_size)
220            ),
221            Seq(format) => format!("std::vector<{}>", self.quote_type(format, false)),
222            Map { key, value } => format!(
223                "std::map<{}, {}>",
224                self.quote_type(key, false),
225                self.quote_type(value, false)
226            ),
227            Tuple(formats) => format!(
228                "std::tuple<{}>",
229                self.quote_types(formats, require_known_size)
230            ),
231            TupleArray { content, size } => format!(
232                "std::array<{}, {}>",
233                self.quote_type(content, require_known_size),
234                *size
235            ),
236
237            Variable(_) => panic!("unexpected value"),
238        }
239    }
240
241    fn quote_types(&self, formats: &[Format], require_known_size: bool) -> String {
242        formats
243            .iter()
244            .map(|x| self.quote_type(x, require_known_size))
245            .collect::<Vec<_>>()
246            .join(", ")
247    }
248
249    fn output_struct_or_variant_container(
250        &mut self,
251        name: &str,
252        fields: &[Named<Format>],
253    ) -> Result<()> {
254        writeln!(self.out)?;
255        self.output_comment(name)?;
256        writeln!(self.out, "struct {} {{", name)?;
257        self.enter_class(name);
258        for field in fields {
259            self.output_comment(&field.name)?;
260            writeln!(
261                self.out,
262                "{} {};",
263                self.quote_type(&field.value, true),
264                field.name
265            )?;
266        }
267        if !fields.is_empty() {
268            writeln!(self.out)?;
269        }
270        self.output_class_method_declarations(name)?;
271        self.output_custom_code()?;
272        self.leave_class();
273        writeln!(self.out, "}};")
274    }
275
276    fn output_variant(&mut self, name: &str, variant: &VariantFormat) -> Result<()> {
277        use VariantFormat::*;
278        let fields = match variant {
279            Unit => Vec::new(),
280            NewType(format) => vec![Named {
281                name: "value".to_string(),
282                value: format.as_ref().clone(),
283            }],
284            Tuple(formats) => vec![Named {
285                name: "value".to_string(),
286                value: Format::Tuple(formats.clone()),
287            }],
288            Struct(fields) => fields.clone(),
289            Variable(_) => panic!("incorrect value"),
290        };
291        self.output_struct_or_variant_container(name, &fields)
292    }
293
294    fn output_container_forward_definition(&mut self, name: &str) -> Result<()> {
295        writeln!(self.out, "\nstruct {};", name)
296    }
297
298    fn output_enum_container(
299        &mut self,
300        name: &str,
301        variants: &BTreeMap<u32, Named<VariantFormat>>,
302    ) -> Result<()> {
303        writeln!(self.out)?;
304        self.output_comment(name)?;
305        writeln!(self.out, "struct {} {{", name)?;
306        self.enter_class(name);
307        for (expected_index, (index, variant)) in variants.iter().enumerate() {
308            assert_eq!(*index, expected_index as u32);
309            self.output_variant(&variant.name, &variant.value)?;
310        }
311        writeln!(
312            self.out,
313            "\nstd::variant<{}> value;",
314            variants
315                .iter()
316                .map(|(_, v)| v.name.clone())
317                .collect::<Vec<_>>()
318                .join(", "),
319        )?;
320        writeln!(self.out)?;
321        self.output_class_method_declarations(name)?;
322        self.output_custom_code()?;
323        self.leave_class();
324        writeln!(self.out, "}};")
325    }
326
327    fn output_class_method_declarations(&mut self, name: &str) -> Result<()> {
328        writeln!(
329            self.out,
330            "friend bool operator==(const {}&, const {}&);",
331            name, name
332        )?;
333        if self.generator.config.serialization {
334            for encoding in &self.generator.config.encodings {
335                writeln!(
336                    self.out,
337                    "std::vector<uint8_t> {}Serialize() const;",
338                    encoding.name()
339                )?;
340                writeln!(
341                    self.out,
342                    "static {} {}Deserialize(std::vector<uint8_t>);",
343                    name,
344                    encoding.name()
345                )?;
346            }
347        }
348        Ok(())
349    }
350
351    fn output_container(&mut self, name: &str, format: &ContainerFormat) -> Result<()> {
352        use ContainerFormat::*;
353        let fields = match format {
354            UnitStruct => Vec::new(),
355            NewTypeStruct(format) => vec![Named {
356                name: "value".to_string(),
357                value: format.as_ref().clone(),
358            }],
359            TupleStruct(formats) => vec![Named {
360                name: "value".to_string(),
361                value: Format::Tuple(formats.clone()),
362            }],
363            Struct(fields) => fields.clone(),
364            Enum(variants) => {
365                self.output_enum_container(name, variants)?;
366                return Ok(());
367            }
368        };
369        self.output_struct_or_variant_container(name, &fields)
370    }
371
372    fn output_struct_equality_test(&mut self, name: &str, fields: &[&str]) -> Result<()> {
373        writeln!(
374            self.out,
375            "\ninline bool operator==(const {0} &lhs, const {0} &rhs) {{",
376            name,
377        )?;
378        self.out.indent();
379        for field in fields {
380            writeln!(
381                self.out,
382                "if (!(lhs.{0} == rhs.{0})) {{ return false; }}",
383                field,
384            )?;
385        }
386        writeln!(self.out, "return true;")?;
387        self.out.unindent();
388        writeln!(self.out, "}}")
389    }
390
391    fn output_struct_serialize_for_encoding(
392        &mut self,
393        name: &str,
394        encoding: Encoding,
395    ) -> Result<()> {
396        writeln!(
397            self.out,
398            r#"
399inline std::vector<uint8_t> {}::{}Serialize() const {{
400    auto serializer = serde::{}Serializer();
401    serde::Serializable<{}>::serialize(*this, serializer);
402    return std::move(serializer).bytes();
403}}"#,
404            name,
405            encoding.name(),
406            encoding.name().to_camel_case(),
407            name
408        )
409    }
410
411    fn output_struct_deserialize_for_encoding(
412        &mut self,
413        name: &str,
414        encoding: Encoding,
415    ) -> Result<()> {
416        writeln!(
417            self.out,
418            r#"
419inline {} {}::{}Deserialize(std::vector<uint8_t> input) {{
420    auto deserializer = serde::{}Deserializer(input);
421    auto value = serde::Deserializable<{}>::deserialize(deserializer);
422    if (deserializer.get_buffer_offset() < input.size()) {{
423        throw serde::deserialization_error("Some input bytes were not read");
424    }}
425    return value;
426}}"#,
427            name,
428            name,
429            encoding.name(),
430            encoding.name().to_camel_case(),
431            name,
432        )
433    }
434
435    fn output_struct_serializable(
436        &mut self,
437        name: &str,
438        fields: &[&str],
439        is_container: bool,
440    ) -> Result<()> {
441        writeln!(
442            self.out,
443            r#"
444template <>
445template <typename Serializer>
446void serde::Serializable<{0}>::serialize(const {0} &obj, Serializer &serializer) {{"#,
447            name,
448        )?;
449        self.out.indent();
450        if is_container {
451            writeln!(self.out, "serializer.increase_container_depth();")?;
452        }
453        for field in fields {
454            writeln!(
455                self.out,
456                "serde::Serializable<decltype(obj.{0})>::serialize(obj.{0}, serializer);",
457                field,
458            )?;
459        }
460        if is_container {
461            writeln!(self.out, "serializer.decrease_container_depth();")?;
462        }
463        self.out.unindent();
464        writeln!(self.out, "}}")
465    }
466
467    fn output_struct_deserializable(
468        &mut self,
469        name: &str,
470        fields: &[&str],
471        is_container: bool,
472    ) -> Result<()> {
473        writeln!(
474            self.out,
475            r#"
476template <>
477template <typename Deserializer>
478{0} serde::Deserializable<{0}>::deserialize(Deserializer &deserializer) {{"#,
479            name,
480        )?;
481        self.out.indent();
482        if is_container {
483            writeln!(self.out, "deserializer.increase_container_depth();")?;
484        }
485        writeln!(self.out, "{} obj;", name)?;
486        for field in fields {
487            writeln!(
488                self.out,
489                "obj.{0} = serde::Deserializable<decltype(obj.{0})>::deserialize(deserializer);",
490                field,
491            )?;
492        }
493        if is_container {
494            writeln!(self.out, "deserializer.decrease_container_depth();")?;
495        }
496        writeln!(self.out, "return obj;")?;
497        self.out.unindent();
498        writeln!(self.out, "}}")
499    }
500
501    fn output_struct_traits(
502        &mut self,
503        name: &str,
504        fields: &[&str],
505        is_container: bool,
506    ) -> Result<()> {
507        self.output_open_namespace()?;
508        self.output_struct_equality_test(name, fields)?;
509        if self.generator.config.serialization {
510            for encoding in &self.generator.config.encodings {
511                self.output_struct_serialize_for_encoding(name, *encoding)?;
512                self.output_struct_deserialize_for_encoding(name, *encoding)?;
513            }
514        }
515        self.output_close_namespace()?;
516        let namespaced_name = self.quote_qualified_name(name);
517        if self.generator.config.serialization {
518            self.output_struct_serializable(&namespaced_name, fields, is_container)?;
519            self.output_struct_deserializable(&namespaced_name, fields, is_container)?;
520        }
521        Ok(())
522    }
523
524    fn get_variant_fields(format: &VariantFormat) -> Vec<&str> {
525        use VariantFormat::*;
526        match format {
527            Unit => Vec::new(),
528            NewType(_format) => vec!["value"],
529            Tuple(_formats) => vec!["value"],
530            Struct(fields) => fields
531                .iter()
532                .map(|field| field.name.as_str())
533                .collect::<Vec<_>>(),
534            Variable(_) => panic!("incorrect value"),
535        }
536    }
537
538    fn output_container_traits(&mut self, name: &str, format: &ContainerFormat) -> Result<()> {
539        use ContainerFormat::*;
540        match format {
541            UnitStruct => self.output_struct_traits(name, &[], true),
542            NewTypeStruct(_format) => self.output_struct_traits(name, &["value"], true),
543            TupleStruct(_formats) => self.output_struct_traits(name, &["value"], true),
544            Struct(fields) => self.output_struct_traits(
545                name,
546                &fields
547                    .iter()
548                    .map(|field| field.name.as_str())
549                    .collect::<Vec<_>>(),
550                true,
551            ),
552            Enum(variants) => {
553                self.output_struct_traits(name, &["value"], true)?;
554                for variant in variants.values() {
555                    self.output_struct_traits(
556                        &format!("{}::{}", name, variant.name),
557                        &Self::get_variant_fields(&variant.value),
558                        false,
559                    )?;
560                }
561                Ok(())
562            }
563        }
564    }
565}
566
/// Installer for generated source files in C++.
pub struct Installer {
    /// Directory in which the `.hpp` files are created.
    install_dir: PathBuf,
}

impl Installer {
    /// Create an installer that writes its headers under `install_dir`.
    pub fn new(install_dir: PathBuf) -> Self {
        Self { install_dir }
    }

    /// Create `<install_dir>/<name>.hpp` and return the open file.
    ///
    /// The installation directory (and any missing parent) is created first.
    fn create_header_file(&self, name: &str) -> Result<std::fs::File> {
        std::fs::create_dir_all(&self.install_dir)?;
        let header_path = self.install_dir.join(format!("{}.hpp", name));
        std::fs::File::create(header_path)
    }
}
583
584impl crate::SourceInstaller for Installer {
585    type Error = Box<dyn std::error::Error>;
586
587    fn install_module(
588        &self,
589        config: &crate::CodeGeneratorConfig,
590        registry: &Registry,
591    ) -> std::result::Result<(), Self::Error> {
592        let mut file = self.create_header_file(&config.module_name)?;
593        let generator = CodeGenerator::new(config);
594        generator.output(&mut file, registry)
595    }
596
597    fn install_serde_runtime(&self) -> std::result::Result<(), Self::Error> {
598        let mut file = self.create_header_file("serde")?;
599        write!(file, "{}", include_str!("../runtime/cpp/serde.hpp"))?;
600        let mut file = self.create_header_file("binary")?;
601        write!(file, "{}", include_str!("../runtime/cpp/binary.hpp"))?;
602        Ok(())
603    }
604
605    fn install_bincode_runtime(&self) -> std::result::Result<(), Self::Error> {
606        let mut file = self.create_header_file("bincode")?;
607        write!(file, "{}", include_str!("../runtime/cpp/bincode.hpp"))?;
608        Ok(())
609    }
610
611    fn install_bcs_runtime(&self) -> std::result::Result<(), Self::Error> {
612        let mut file = self.create_header_file("bcs")?;
613        write!(file, "{}", include_str!("../runtime/cpp/bcs.hpp"))?;
614        Ok(())
615    }
616}