Skip to main content

serde_generate/
cpp.rs

1// Copyright (c) Facebook, Inc. and its affiliates
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
4use crate::{
5    analyzer,
6    indent::{IndentConfig, IndentedWriter},
7    CodeGeneratorConfig, Encoding,
8};
9use heck::CamelCase;
10use serde_reflection::{ContainerFormat, Format, Named, Registry, VariantFormat};
11use std::{
12    collections::{BTreeMap, HashMap, HashSet},
13    io::{Result, Write},
14    path::PathBuf,
15};
16
17/// Main configuration object for code-generation in C++.
18pub struct CodeGenerator<'a> {
19    /// Language-independent configuration.
20    config: &'a CodeGeneratorConfig,
21    /// Mapping from external type names to suitably qualified names (e.g. "MyClass" -> "name::MyClass").
22    /// Derived from `config.external_definitions`.
23    external_qualified_names: HashMap<String, String>,
24}
25
26/// Shared state for the code generation of a C++ source file.
27struct CppEmitter<'a, T> {
28    /// Writer.
29    out: IndentedWriter<T>,
30    /// Generator.
31    generator: &'a CodeGenerator<'a>,
32    /// Track which type names have been declared so far. (Used to add forward declarations.)
33    known_names: HashSet<&'a str>,
34    /// Track which definitions have a known size. (Used to add shared pointers.)
35    known_sizes: HashSet<&'a str>,
36    /// Current namespace (e.g. vec!["name", "MyClass"])
37    current_namespace: Vec<String>,
38}
39
40impl<'a> CodeGenerator<'a> {
41    /// Create a C++ code generator for the given config.
42    pub fn new(config: &'a CodeGeneratorConfig) -> Self {
43        if config.enums.c_style {
44            panic!("C++ does not support generating c-style enums");
45        }
46        let mut external_qualified_names = HashMap::new();
47        for (namespace, names) in &config.external_definitions {
48            for name in names {
49                external_qualified_names.insert(name.to_string(), format!("{namespace}::{name}"));
50            }
51        }
52        Self {
53            config,
54            external_qualified_names,
55        }
56    }
57
58    pub fn output(
59        &self,
60        out: &mut dyn Write,
61        registry: &Registry,
62    ) -> std::result::Result<(), Box<dyn std::error::Error>> {
63        let current_namespace = self
64            .config
65            .module_name
66            .split("::")
67            .map(String::from)
68            .collect();
69        let mut emitter = CppEmitter {
70            out: IndentedWriter::new(out, IndentConfig::Space(4)),
71            generator: self,
72            known_names: HashSet::new(),
73            known_sizes: HashSet::new(),
74            current_namespace,
75        };
76
77        emitter.output_preamble()?;
78        emitter.output_open_namespace()?;
79
80        let dependencies = analyzer::get_dependency_map(registry)?;
81        let entries = analyzer::best_effort_topological_sort(&dependencies);
82
83        for name in entries {
84            for dependency in &dependencies[name] {
85                if !emitter.known_names.contains(dependency) {
86                    emitter.output_container_forward_definition(dependency)?;
87                    emitter.known_names.insert(*dependency);
88                }
89            }
90            let format = &registry[name];
91            emitter.output_container(name, format)?;
92            emitter.known_sizes.insert(name);
93            emitter.known_names.insert(name);
94        }
95
96        emitter.output_close_namespace()?;
97        writeln!(emitter.out)?;
98        for (name, format) in registry {
99            emitter.output_container_traits(name, format)?;
100        }
101        Ok(())
102    }
103}
104
105impl<'a, T> CppEmitter<'a, T>
106where
107    T: std::io::Write,
108{
109    fn output_preamble(&mut self) -> Result<()> {
110        writeln!(
111            self.out,
112            r#"#pragma once
113
114#include "serde.hpp""#
115        )?;
116        if self.generator.config.serialization {
117            for encoding in &self.generator.config.encodings {
118                writeln!(self.out, "#include \"{}.hpp\"", encoding.name())?;
119            }
120        }
121        Ok(())
122    }
123
124    fn output_open_namespace(&mut self) -> Result<()> {
125        writeln!(
126            self.out,
127            "\nnamespace {} {{",
128            self.generator.config.module_name
129        )?;
130        self.out.indent();
131        Ok(())
132    }
133
134    fn output_close_namespace(&mut self) -> Result<()> {
135        self.out.unindent();
136        writeln!(
137            self.out,
138            "\n}} // end of namespace {}",
139            self.generator.config.module_name
140        )?;
141        Ok(())
142    }
143
144    fn enter_class(&mut self, name: &str) {
145        self.out.indent();
146        self.current_namespace.push(name.to_string());
147    }
148
149    fn leave_class(&mut self) {
150        self.out.unindent();
151        self.current_namespace.pop();
152    }
153
154    fn output_comment(&mut self, name: &str) -> std::io::Result<()> {
155        let mut path = self.current_namespace.clone();
156        path.push(name.to_string());
157        if let Some(doc) = self.generator.config.comments.get(&path) {
158            let text = textwrap::indent(doc, "/// ").replace("\n\n", "\n///\n");
159            write!(self.out, "{text}")?;
160        }
161        Ok(())
162    }
163
164    fn output_custom_code(&mut self) -> std::io::Result<()> {
165        if let Some(code) = self
166            .generator
167            .config
168            .custom_code
169            .get(&self.current_namespace)
170        {
171            write!(self.out, "\n{code}")?;
172        }
173        Ok(())
174    }
175
176    /// Compute a fully qualified reference to the container type `name`.
177    fn quote_qualified_name(&self, name: &str) -> String {
178        self.generator
179            .external_qualified_names
180            .get(name)
181            .cloned()
182            .unwrap_or_else(|| format!("{}::{}", self.generator.config.module_name, name))
183    }
184
185    fn quote_type(&self, format: &Format, require_known_size: bool) -> String {
186        use Format::*;
187        match format {
188            TypeName(x) => {
189                let qname = self.quote_qualified_name(x);
190                if require_known_size && !self.known_sizes.contains(x.as_str()) {
191                    // Cannot use unique_ptr because we need a copy constructor (e.g. for vectors)
192                    // and in-depth equality.
193                    format!("serde::value_ptr<{qname}>")
194                } else {
195                    qname
196                }
197            }
198            Unit => "std::monostate".into(),
199            Bool => "bool".into(),
200            I8 => "int8_t".into(),
201            I16 => "int16_t".into(),
202            I32 => "int32_t".into(),
203            I64 => "int64_t".into(),
204            I128 => "serde::int128_t".into(),
205            U8 => "uint8_t".into(),
206            U16 => "uint16_t".into(),
207            U32 => "uint32_t".into(),
208            U64 => "uint64_t".into(),
209            U128 => "serde::uint128_t".into(),
210            F32 => "float".into(),
211            F64 => "double".into(),
212            Char => "char32_t".into(),
213            Str => "std::string".into(),
214            Bytes => "std::vector<uint8_t>".into(),
215
216            Option(format) => format!(
217                "std::optional<{}>",
218                self.quote_type(format, require_known_size)
219            ),
220            Seq(format) => format!("std::vector<{}>", self.quote_type(format, false)),
221            Map { key, value } => format!(
222                "std::map<{}, {}>",
223                self.quote_type(key, false),
224                self.quote_type(value, false)
225            ),
226            Tuple(formats) => format!(
227                "std::tuple<{}>",
228                self.quote_types(formats, require_known_size)
229            ),
230            TupleArray { content, size } => format!(
231                "std::array<{}, {}>",
232                self.quote_type(content, require_known_size),
233                *size
234            ),
235
236            Variable(_) => panic!("unexpected value"),
237        }
238    }
239
240    fn quote_types(&self, formats: &[Format], require_known_size: bool) -> String {
241        formats
242            .iter()
243            .map(|x| self.quote_type(x, require_known_size))
244            .collect::<Vec<_>>()
245            .join(", ")
246    }
247
248    fn output_struct_or_variant_container(
249        &mut self,
250        name: &str,
251        fields: &[Named<Format>],
252    ) -> Result<()> {
253        writeln!(self.out)?;
254        self.output_comment(name)?;
255        writeln!(self.out, "struct {name} {{")?;
256        self.enter_class(name);
257        for field in fields {
258            self.output_comment(&field.name)?;
259            writeln!(
260                self.out,
261                "{} {};",
262                self.quote_type(&field.value, true),
263                field.name
264            )?;
265        }
266        if !fields.is_empty() {
267            writeln!(self.out)?;
268        }
269        self.output_class_method_declarations(name)?;
270        self.output_custom_code()?;
271        self.leave_class();
272        writeln!(self.out, "}};")
273    }
274
275    fn output_variant(&mut self, name: &str, variant: &VariantFormat) -> Result<()> {
276        use VariantFormat::*;
277        let fields = match variant {
278            Unit => Vec::new(),
279            NewType(format) => vec![Named {
280                name: "value".to_string(),
281                value: format.as_ref().clone(),
282            }],
283            Tuple(formats) => vec![Named {
284                name: "value".to_string(),
285                value: Format::Tuple(formats.clone()),
286            }],
287            Struct(fields) => fields.clone(),
288            Variable(_) => panic!("incorrect value"),
289        };
290        self.output_struct_or_variant_container(name, &fields)
291    }
292
293    fn output_container_forward_definition(&mut self, name: &str) -> Result<()> {
294        writeln!(self.out, "\nstruct {name};")
295    }
296
297    fn output_enum_container(
298        &mut self,
299        name: &str,
300        variants: &BTreeMap<u32, Named<VariantFormat>>,
301    ) -> Result<()> {
302        writeln!(self.out)?;
303        self.output_comment(name)?;
304        writeln!(self.out, "struct {name} {{")?;
305        self.enter_class(name);
306        for (expected_index, (index, variant)) in variants.iter().enumerate() {
307            assert_eq!(*index, expected_index as u32);
308            self.output_variant(&variant.name, &variant.value)?;
309        }
310        writeln!(
311            self.out,
312            "\nstd::variant<{}> value;",
313            variants
314                .values()
315                .map(|v| v.name.clone())
316                .collect::<Vec<_>>()
317                .join(", "),
318        )?;
319        writeln!(self.out)?;
320        self.output_class_method_declarations(name)?;
321        self.output_custom_code()?;
322        self.leave_class();
323        writeln!(self.out, "}};")
324    }
325
326    fn output_class_method_declarations(&mut self, name: &str) -> Result<()> {
327        writeln!(
328            self.out,
329            "friend bool operator==(const {name}&, const {name}&);"
330        )?;
331        if self.generator.config.serialization {
332            for encoding in &self.generator.config.encodings {
333                writeln!(
334                    self.out,
335                    "std::vector<uint8_t> {}Serialize() const;",
336                    encoding.name()
337                )?;
338                writeln!(
339                    self.out,
340                    "static {} {}Deserialize(std::vector<uint8_t>);",
341                    name,
342                    encoding.name()
343                )?;
344            }
345        }
346        Ok(())
347    }
348
349    fn output_container(&mut self, name: &str, format: &ContainerFormat) -> Result<()> {
350        use ContainerFormat::*;
351        let fields = match format {
352            UnitStruct => Vec::new(),
353            NewTypeStruct(format) => vec![Named {
354                name: "value".to_string(),
355                value: format.as_ref().clone(),
356            }],
357            TupleStruct(formats) => vec![Named {
358                name: "value".to_string(),
359                value: Format::Tuple(formats.clone()),
360            }],
361            Struct(fields) => fields.clone(),
362            Enum(variants) => {
363                self.output_enum_container(name, variants)?;
364                return Ok(());
365            }
366        };
367        self.output_struct_or_variant_container(name, &fields)
368    }
369
370    fn output_struct_equality_test(&mut self, name: &str, fields: &[&str]) -> Result<()> {
371        writeln!(
372            self.out,
373            "\ninline bool operator==(const {name} &lhs, const {name} &rhs) {{",
374        )?;
375        self.out.indent();
376        for field in fields {
377            writeln!(
378                self.out,
379                "if (!(lhs.{field} == rhs.{field})) {{ return false; }}",
380            )?;
381        }
382        writeln!(self.out, "return true;")?;
383        self.out.unindent();
384        writeln!(self.out, "}}")
385    }
386
387    fn output_struct_serialize_for_encoding(
388        &mut self,
389        name: &str,
390        encoding: Encoding,
391    ) -> Result<()> {
392        writeln!(
393            self.out,
394            r#"
395inline std::vector<uint8_t> {}::{}Serialize() const {{
396    auto serializer = serde::{}Serializer();
397    serde::Serializable<{}>::serialize(*this, serializer);
398    return std::move(serializer).bytes();
399}}"#,
400            name,
401            encoding.name(),
402            encoding.name().to_camel_case(),
403            name
404        )
405    }
406
407    fn output_struct_deserialize_for_encoding(
408        &mut self,
409        name: &str,
410        encoding: Encoding,
411    ) -> Result<()> {
412        writeln!(
413            self.out,
414            r#"
415inline {} {}::{}Deserialize(std::vector<uint8_t> input) {{
416    auto deserializer = serde::{}Deserializer(input);
417    auto value = serde::Deserializable<{}>::deserialize(deserializer);
418    if (deserializer.get_buffer_offset() < input.size()) {{
419        throw serde::deserialization_error("Some input bytes were not read");
420    }}
421    return value;
422}}"#,
423            name,
424            name,
425            encoding.name(),
426            encoding.name().to_camel_case(),
427            name,
428        )
429    }
430
431    fn output_struct_serializable(
432        &mut self,
433        name: &str,
434        fields: &[&str],
435        is_container: bool,
436    ) -> Result<()> {
437        writeln!(
438            self.out,
439            r#"
440template <>
441template <typename Serializer>
442void serde::Serializable<{name}>::serialize(const {name} &obj, Serializer &serializer) {{"#,
443        )?;
444        self.out.indent();
445        if is_container {
446            writeln!(self.out, "serializer.increase_container_depth();")?;
447        }
448        for field in fields {
449            writeln!(
450                self.out,
451                "serde::Serializable<decltype(obj.{field})>::serialize(obj.{field}, serializer);",
452            )?;
453        }
454        if is_container {
455            writeln!(self.out, "serializer.decrease_container_depth();")?;
456        }
457        self.out.unindent();
458        writeln!(self.out, "}}")
459    }
460
461    fn output_struct_deserializable(
462        &mut self,
463        name: &str,
464        fields: &[&str],
465        is_container: bool,
466    ) -> Result<()> {
467        writeln!(
468            self.out,
469            r#"
470template <>
471template <typename Deserializer>
472{name} serde::Deserializable<{name}>::deserialize(Deserializer &deserializer) {{"#,
473        )?;
474        self.out.indent();
475        if is_container {
476            writeln!(self.out, "deserializer.increase_container_depth();")?;
477        }
478        writeln!(self.out, "{name} obj;")?;
479        for field in fields {
480            writeln!(
481                self.out,
482                "obj.{field} = serde::Deserializable<decltype(obj.{field})>::deserialize(deserializer);",
483            )?;
484        }
485        if is_container {
486            writeln!(self.out, "deserializer.decrease_container_depth();")?;
487        }
488        writeln!(self.out, "return obj;")?;
489        self.out.unindent();
490        writeln!(self.out, "}}")
491    }
492
493    fn output_struct_traits(
494        &mut self,
495        name: &str,
496        fields: &[&str],
497        is_container: bool,
498    ) -> Result<()> {
499        self.output_open_namespace()?;
500        self.output_struct_equality_test(name, fields)?;
501        if self.generator.config.serialization {
502            for encoding in &self.generator.config.encodings {
503                self.output_struct_serialize_for_encoding(name, *encoding)?;
504                self.output_struct_deserialize_for_encoding(name, *encoding)?;
505            }
506        }
507        self.output_close_namespace()?;
508        let namespaced_name = self.quote_qualified_name(name);
509        if self.generator.config.serialization {
510            self.output_struct_serializable(&namespaced_name, fields, is_container)?;
511            self.output_struct_deserializable(&namespaced_name, fields, is_container)?;
512        }
513        Ok(())
514    }
515
516    fn get_variant_fields(format: &VariantFormat) -> Vec<&str> {
517        use VariantFormat::*;
518        match format {
519            Unit => Vec::new(),
520            NewType(_format) => vec!["value"],
521            Tuple(_formats) => vec!["value"],
522            Struct(fields) => fields
523                .iter()
524                .map(|field| field.name.as_str())
525                .collect::<Vec<_>>(),
526            Variable(_) => panic!("incorrect value"),
527        }
528    }
529
530    fn output_container_traits(&mut self, name: &str, format: &ContainerFormat) -> Result<()> {
531        use ContainerFormat::*;
532        match format {
533            UnitStruct => self.output_struct_traits(name, &[], true),
534            NewTypeStruct(_format) => self.output_struct_traits(name, &["value"], true),
535            TupleStruct(_formats) => self.output_struct_traits(name, &["value"], true),
536            Struct(fields) => self.output_struct_traits(
537                name,
538                &fields
539                    .iter()
540                    .map(|field| field.name.as_str())
541                    .collect::<Vec<_>>(),
542                true,
543            ),
544            Enum(variants) => {
545                self.output_struct_traits(name, &["value"], true)?;
546                for variant in variants.values() {
547                    self.output_struct_traits(
548                        &format!("{}::{}", name, variant.name),
549                        &Self::get_variant_fields(&variant.value),
550                        false,
551                    )?;
552                }
553                Ok(())
554            }
555        }
556    }
557}
558
/// Writes generated C++ headers, and the C++ serde runtime headers, into a directory.
pub struct Installer {
    install_dir: PathBuf,
}

impl Installer {
    /// Creates an installer targeting `install_dir`.
    ///
    /// The directory is only created when the first header is written.
    pub fn new(install_dir: PathBuf) -> Self {
        Self { install_dir }
    }

    /// Creates (or truncates) `<install_dir>/<name>.hpp`.
    ///
    /// `install_dir` and any missing parent directories are created first.
    fn create_header_file(&self, name: &str) -> Result<std::fs::File> {
        std::fs::create_dir_all(&self.install_dir)?;
        let header_path = self.install_dir.join(format!("{name}.hpp"));
        std::fs::File::create(header_path)
    }
}
575
576impl crate::SourceInstaller for Installer {
577    type Error = Box<dyn std::error::Error>;
578
579    fn install_module(
580        &self,
581        config: &crate::CodeGeneratorConfig,
582        registry: &Registry,
583    ) -> std::result::Result<(), Self::Error> {
584        let mut file = self.create_header_file(&config.module_name)?;
585        let generator = CodeGenerator::new(config);
586        generator.output(&mut file, registry)
587    }
588
589    fn install_serde_runtime(&self) -> std::result::Result<(), Self::Error> {
590        let mut file = self.create_header_file("serde")?;
591        write!(file, "{}", include_str!("../runtime/cpp/serde.hpp"))?;
592        let mut file = self.create_header_file("binary")?;
593        write!(file, "{}", include_str!("../runtime/cpp/binary.hpp"))?;
594        Ok(())
595    }
596
597    fn install_bincode_runtime(&self) -> std::result::Result<(), Self::Error> {
598        let mut file = self.create_header_file("bincode")?;
599        write!(file, "{}", include_str!("../runtime/cpp/bincode.hpp"))?;
600        Ok(())
601    }
602
603    fn install_bcs_runtime(&self) -> std::result::Result<(), Self::Error> {
604        let mut file = self.create_header_file("bcs")?;
605        write!(file, "{}", include_str!("../runtime/cpp/bcs.hpp"))?;
606        Ok(())
607    }
608}