serde_generate/
golang.rs

1// Copyright (c) Facebook, Inc. and its affiliates
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
4use crate::{
5    common,
6    indent::{IndentConfig, IndentedWriter},
7    CodeGeneratorConfig, Encoding,
8};
9use heck::CamelCase;
10use serde_reflection::{ContainerFormat, Format, FormatHolder, Named, Registry, VariantFormat};
11use std::{
12    collections::{BTreeMap, HashMap},
13    io::{Result, Write},
14    path::PathBuf,
15};
16
17/// Main configuration object for code-generation in Go.
18pub struct CodeGenerator<'a> {
19    /// Language-independent configuration.
20    config: &'a CodeGeneratorConfig,
21    /// Module path where to find the serde runtime packages (serde, bcs, bincode).
22    /// Default: "github.com/novifinancial/serde-reflection/serde-generate/runtime/golang".
23    serde_module_path: String,
24    /// Mapping from external type names to fully-qualified class names (e.g. "MyClass" -> "com.my_org.my_package.MyClass").
25    /// Derived from `config.external_definitions`.
26    external_qualified_names: HashMap<String, String>,
27}
28
29/// Shared state for the code generation of a Go source file.
30struct GoEmitter<'a, T> {
31    /// Writer.
32    out: IndentedWriter<T>,
33    /// Generator.
34    generator: &'a CodeGenerator<'a>,
35    /// Current namespace (e.g. vec!["com", "my_org", "my_package", "MyClass"])
36    current_namespace: Vec<String>,
37}
38
39impl<'a> CodeGenerator<'a> {
40    /// Create a Go code generator for the given config.
41    pub fn new(config: &'a CodeGeneratorConfig) -> Self {
42        if config.c_style_enums {
43            panic!("Go does not support generating c-style enums");
44        }
45        let mut external_qualified_names = HashMap::new();
46        for (namespace, names) in &config.external_definitions {
47            let package_name = {
48                let path = namespace.rsplitn(2, '/').collect::<Vec<_>>();
49                if path.len() <= 1 {
50                    namespace
51                } else {
52                    path[0]
53                }
54            };
55            for name in names {
56                external_qualified_names
57                    .insert(name.to_string(), format!("{}.{}", package_name, name));
58            }
59        }
60        Self {
61            config,
62            serde_module_path:
63                "github.com/novifinancial/serde-reflection/serde-generate/runtime/golang"
64                    .to_string(),
65            external_qualified_names,
66        }
67    }
68
69    /// Whether the package providing Serde definitions is located within a different module.
70    pub fn with_serde_module_path(mut self, serde_module_path: String) -> Self {
71        self.serde_module_path = serde_module_path;
72        self
73    }
74
75    /// Output class definitions for `registry`.
76    pub fn output(&self, out: &mut dyn Write, registry: &Registry) -> Result<()> {
77        let current_namespace = self
78            .config
79            .module_name
80            .split('.')
81            .map(String::from)
82            .collect::<Vec<_>>();
83
84        let mut emitter = GoEmitter {
85            // `go fmt` indents using tabs so let's do the same.
86            out: IndentedWriter::new(out, IndentConfig::Tab),
87            generator: self,
88            current_namespace,
89        };
90
91        emitter.output_preamble(registry)?;
92
93        for (name, format) in registry {
94            emitter.output_container(name, format)?;
95        }
96
97        if self.config.serialization {
98            emitter.output_trait_helpers(registry)?;
99        }
100
101        Ok(())
102    }
103}
104
105impl<'a, T> GoEmitter<'a, T>
106where
107    T: Write,
108{
109    fn output_preamble(&mut self, registry: &Registry) -> Result<()> {
110        writeln!(
111            self.out,
112            "package {}\n\n",
113            self.generator.config.module_name
114        )?;
115        // Go does not support disabling warnings on unused imports.
116        if registry.is_empty() {
117            return Ok(());
118        }
119        writeln!(self.out, "import (")?;
120        self.out.indent();
121        if self.generator.config.serialization
122            && (Self::has_enum(registry) || !self.generator.config.encodings.is_empty())
123        {
124            writeln!(self.out, "\"fmt\"")?;
125        }
126        if self.generator.config.serialization || Self::has_int128(registry) {
127            writeln!(self.out, "\"{}/serde\"", self.generator.serde_module_path)?;
128        }
129        if self.generator.config.serialization {
130            for encoding in &self.generator.config.encodings {
131                writeln!(
132                    self.out,
133                    "\"{}/{}\"",
134                    self.generator.serde_module_path,
135                    encoding.name()
136                )?;
137            }
138        }
139        for path in self.generator.config.external_definitions.keys() {
140            writeln!(self.out, "\"{}\"", path)?;
141        }
142        self.out.unindent();
143        writeln!(self.out, ")\n")?;
144        Ok(())
145    }
146
147    fn has_int128(registry: &Registry) -> bool {
148        for format in registry.values() {
149            if format
150                .visit(&mut |f| match f {
151                    Format::I128 | Format::U128 => {
152                        // Interrupt the visit if we find a (u)int128
153                        Err(serde_reflection::Error::Custom(String::new()))
154                    }
155                    _ => Ok(()),
156                })
157                .is_err()
158            {
159                return true;
160            }
161        }
162        false
163    }
164
165    fn has_enum(registry: &Registry) -> bool {
166        for format in registry.values() {
167            if let ContainerFormat::Enum(_) = format {
168                return true;
169            }
170        }
171        false
172    }
173
174    /// Compute a reference to the registry type `name`.
175    fn quote_qualified_name(&self, name: &str) -> String {
176        self.generator
177            .external_qualified_names
178            .get(name)
179            .cloned()
180            .unwrap_or_else(|| name.to_string())
181    }
182
183    fn output_comment(&mut self, name: &str) -> std::io::Result<()> {
184        let mut path = self.current_namespace.clone();
185        path.push(name.to_string());
186        if let Some(doc) = self.generator.config.comments.get(&path) {
187            let text = textwrap::indent(doc, "// ").replace("\n\n", "\n//\n");
188            write!(self.out, "{}", text)?;
189        }
190        Ok(())
191    }
192
193    fn output_custom_code(&mut self, name: &str) -> std::io::Result<()> {
194        let mut path = self.current_namespace.clone();
195        path.push(name.to_string());
196        if let Some(code) = self.generator.config.custom_code.get(&path) {
197            write!(self.out, "\n{}", code)?;
198        }
199        Ok(())
200    }
201
202    fn quote_type(&self, format: &Format) -> String {
203        use Format::*;
204        match format {
205            TypeName(x) => self.quote_qualified_name(x),
206            Unit => "struct {}".into(),
207            Bool => "bool".into(),
208            I8 => "int8".into(),
209            I16 => "int16".into(),
210            I32 => "int32".into(),
211            I64 => "int64".into(),
212            I128 => "serde.Int128".into(),
213            U8 => "uint8".into(),
214            U16 => "uint16".into(),
215            U32 => "uint32".into(),
216            U64 => "uint64".into(),
217            U128 => "serde.Uint128".into(),
218            F32 => "float32".into(),
219            F64 => "float64".into(),
220            Char => "rune".into(),
221            Str => "string".into(),
222            Bytes => "[]byte".into(),
223
224            Option(format) => format!("*{}", self.quote_type(format)),
225            Seq(format) => format!("[]{}", self.quote_type(format)),
226            Map { key, value } => {
227                format!("map[{}]{}", self.quote_type(key), self.quote_type(value))
228            }
229            Tuple(formats) => format!(
230                "struct {{{}}}",
231                formats
232                    .iter()
233                    .enumerate()
234                    .map(|(index, format)| format!("Field{} {}", index, self.quote_type(format)))
235                    .collect::<Vec<_>>()
236                    .join("; ")
237            ),
238            TupleArray { content, size } => format!("[{}]{}", size, self.quote_type(content)),
239
240            Variable(_) => panic!("unexpected value"),
241        }
242    }
243
244    fn enter_class(&mut self, name: &str) {
245        self.out.indent();
246        self.current_namespace.push(name.to_string());
247    }
248
249    fn leave_class(&mut self) {
250        self.out.unindent();
251        self.current_namespace.pop();
252    }
253
254    fn output_trait_helpers(&mut self, registry: &Registry) -> Result<()> {
255        let mut subtypes = BTreeMap::new();
256        for format in registry.values() {
257            format
258                .visit(&mut |f| {
259                    if Self::needs_helper(f) {
260                        subtypes.insert(common::mangle_type(f), f.clone());
261                    }
262                    Ok(())
263                })
264                .unwrap();
265        }
266        for (mangled_name, subtype) in &subtypes {
267            self.output_serialization_helper(mangled_name, subtype)?;
268            self.output_deserialization_helper(mangled_name, subtype)?;
269        }
270        Ok(())
271    }
272
273    fn needs_helper(format: &Format) -> bool {
274        use Format::*;
275        matches!(
276            format,
277            Option(_) | Seq(_) | Map { .. } | Tuple(_) | TupleArray { .. }
278        )
279    }
280
281    fn quote_serialize_value(&self, value: &str, format: &Format) -> String {
282        use Format::*;
283        let expr = match format {
284            TypeName(_) => format!("{}.Serialize(serializer)", value),
285            Unit => format!("serializer.SerializeUnit({})", value),
286            Bool => format!("serializer.SerializeBool({})", value),
287            I8 => format!("serializer.SerializeI8({})", value),
288            I16 => format!("serializer.SerializeI16({})", value),
289            I32 => format!("serializer.SerializeI32({})", value),
290            I64 => format!("serializer.SerializeI64({})", value),
291            I128 => format!("serializer.SerializeI128({})", value),
292            U8 => format!("serializer.SerializeU8({})", value),
293            U16 => format!("serializer.SerializeU16({})", value),
294            U32 => format!("serializer.SerializeU32({})", value),
295            U64 => format!("serializer.SerializeU64({})", value),
296            U128 => format!("serializer.SerializeU128({})", value),
297            F32 => format!("serializer.SerializeF32({})", value),
298            F64 => format!("serializer.SerializeF64({})", value),
299            Char => format!("serializer.SerializeChar({})", value),
300            Str => format!("serializer.SerializeStr({})", value),
301            Bytes => format!("serializer.SerializeBytes({})", value),
302            _ => format!(
303                "serialize_{}({}, serializer)",
304                common::mangle_type(format),
305                value
306            ),
307        };
308        format!("if err := {}; err != nil {{ return err }}", expr)
309    }
310
311    fn quote_deserialize(&self, format: &Format, dest: &str, fail: &str) -> String {
312        use Format::*;
313        let expr = match format {
314            TypeName(name) => format!(
315                "Deserialize{}(deserializer)",
316                self.quote_qualified_name(name)
317            ),
318            Unit => "deserializer.DeserializeUnit()".to_string(),
319            Bool => "deserializer.DeserializeBool()".to_string(),
320            I8 => "deserializer.DeserializeI8()".to_string(),
321            I16 => "deserializer.DeserializeI16()".to_string(),
322            I32 => "deserializer.DeserializeI32()".to_string(),
323            I64 => "deserializer.DeserializeI64()".to_string(),
324            I128 => "deserializer.DeserializeI128()".to_string(),
325            U8 => "deserializer.DeserializeU8()".to_string(),
326            U16 => "deserializer.DeserializeU16()".to_string(),
327            U32 => "deserializer.DeserializeU32()".to_string(),
328            U64 => "deserializer.DeserializeU64()".to_string(),
329            U128 => "deserializer.DeserializeU128()".to_string(),
330            F32 => "deserializer.DeserializeF32()".to_string(),
331            F64 => "deserializer.DeserializeF64()".to_string(),
332            Char => "deserializer.DeserializeChar()".to_string(),
333            Str => "deserializer.DeserializeStr()".to_string(),
334            Bytes => "deserializer.DeserializeBytes()".to_string(),
335            _ => format!("deserialize_{}(deserializer)", common::mangle_type(format)),
336        };
337        format!(
338            "if val, err := {}; err == nil {{ {} = val }} else {{ return {}, err }}",
339            expr, dest, fail
340        )
341    }
342
343    fn output_serialization_helper(&mut self, name: &str, format0: &Format) -> Result<()> {
344        use Format::*;
345
346        write!(
347            self.out,
348            "func serialize_{}(value {}, serializer serde.Serializer) error {{",
349            name,
350            self.quote_type(format0)
351        )?;
352        self.out.indent();
353        match format0 {
354            Option(format) => {
355                write!(
356                    self.out,
357                    r#"
358if value != nil {{
359	if err := serializer.SerializeOptionTag(true); err != nil {{ return err }}
360	{}
361}} else {{
362	if err := serializer.SerializeOptionTag(false); err != nil {{ return err }}
363}}
364"#,
365                    self.quote_serialize_value("(*value)", format)
366                )?;
367            }
368
369            Seq(format) => {
370                write!(
371                    self.out,
372                    r#"
373if err := serializer.SerializeLen(uint64(len(value))); err != nil {{ return err }}
374for _, item := range(value) {{
375	{}
376}}
377"#,
378                    self.quote_serialize_value("item", format)
379                )?;
380            }
381
382            Map { key, value } => {
383                write!(
384                    self.out,
385                    r#"
386if err := serializer.SerializeLen(uint64(len(value))); err != nil {{ return err }}
387offsets := make([]uint64, len(value))
388count := 0
389for k, v := range(value) {{
390	offsets[count] = serializer.GetBufferOffset()
391	count += 1
392	{}
393	{}
394}}
395serializer.SortMapEntries(offsets);
396"#,
397                    self.quote_serialize_value("k", key),
398                    self.quote_serialize_value("v", value)
399                )?;
400            }
401
402            Tuple(formats) => {
403                writeln!(self.out)?;
404                for (index, format) in formats.iter().enumerate() {
405                    let expr = format!("value.Field{}", index);
406                    writeln!(self.out, "{}", self.quote_serialize_value(&expr, format))?;
407                }
408            }
409
410            TupleArray { content, size: _ } => {
411                write!(
412                    self.out,
413                    r#"
414for _, item := range(value) {{
415	{}
416}}
417"#,
418                    self.quote_serialize_value("item", content),
419                )?;
420            }
421
422            _ => panic!("unexpected case"),
423        }
424        writeln!(self.out, "return nil")?;
425        self.out.unindent();
426        writeln!(self.out, "}}\n")
427    }
428
429    fn output_deserialization_helper(&mut self, name: &str, format0: &Format) -> Result<()> {
430        use Format::*;
431
432        write!(
433            self.out,
434            "func deserialize_{}(deserializer serde.Deserializer) ({}, error) {{",
435            name,
436            self.quote_type(format0),
437        )?;
438        self.out.indent();
439        match format0 {
440            Option(format) => {
441                write!(
442                    self.out,
443                    r#"
444tag, err := deserializer.DeserializeOptionTag()
445if err != nil {{ return nil, err }}
446if tag {{
447	value := new({})
448	{}
449        return value, nil
450}} else {{
451	return nil, nil
452}}
453"#,
454                    self.quote_type(format),
455                    self.quote_deserialize(format, "*value", "nil"),
456                )?;
457            }
458
459            Seq(format) => {
460                write!(
461                    self.out,
462                    r#"
463length, err := deserializer.DeserializeLen()
464if err != nil {{ return nil, err }}
465obj := make([]{}, length)
466for i := range(obj) {{
467	{}
468}}
469return obj, nil
470"#,
471                    self.quote_type(format),
472                    self.quote_deserialize(format, "obj[i]", "nil")
473                )?;
474            }
475
476            Map { key, value } => {
477                write!(
478                    self.out,
479                    r#"
480length, err := deserializer.DeserializeLen()
481if err != nil {{ return nil, err }}
482obj := make(map[{0}]{1})
483previous_slice := serde.Slice {{ 0, 0 }}
484for i := 0; i < int(length); i++ {{
485	var slice serde.Slice
486	slice.Start = deserializer.GetBufferOffset()
487	var key {0}
488	{2}
489	slice.End = deserializer.GetBufferOffset()
490	if i > 0 {{
491		err := deserializer.CheckThatKeySlicesAreIncreasing(previous_slice, slice)
492		if err != nil {{ return nil, err }}
493	}}
494	previous_slice = slice
495	{3}
496}}
497return obj, nil
498"#,
499                    self.quote_type(key),
500                    self.quote_type(value),
501                    self.quote_deserialize(key, "key", "nil"),
502                    self.quote_deserialize(value, "obj[key]", "nil"),
503                )?;
504            }
505
506            Tuple(formats) => {
507                write!(
508                    self.out,
509                    r#"
510var obj {}
511{}
512return obj, nil
513"#,
514                    self.quote_type(format0),
515                    formats
516                        .iter()
517                        .enumerate()
518                        .map(|(i, f)| self.quote_deserialize(f, &format!("obj.Field{}", i), "obj"))
519                        .collect::<Vec<_>>()
520                        .join("\n")
521                )?;
522            }
523
524            TupleArray { content, size } => {
525                write!(
526                    self.out,
527                    r#"
528var obj [{1}]{0}
529for i := range(obj) {{
530	{2}
531}}
532return obj, nil
533"#,
534                    self.quote_type(content),
535                    size,
536                    self.quote_deserialize(content, "obj[i]", "obj")
537                )?;
538            }
539
540            _ => panic!("unexpected case"),
541        }
542        self.out.unindent();
543        writeln!(self.out, "}}\n")
544    }
545
546    fn output_variant(
547        &mut self,
548        base: &str,
549        index: u32,
550        name: &str,
551        variant: &VariantFormat,
552    ) -> Result<()> {
553        use VariantFormat::*;
554        let fields = match variant {
555            Unit => Vec::new(),
556            NewType(format) => match format.as_ref() {
557                // We cannot define a "new type" (e.g. `type Foo Bar`) out of a typename `Bar` because `Bar`
558                // could point to a Go interface. This would make `Foo` an interface as well. Interfaces can't be used
559                // as structs (e.g. they cannot have methods).
560                //
561                // Similarly, option types are compiled as pointers but `type Foo *Bar` would prevent `Foo` from being a
562                // valid pointer receiver.
563                Format::TypeName(_) | Format::Option(_) => vec![Named {
564                    name: "Value".to_string(),
565                    value: format.as_ref().clone(),
566                }],
567                // Other cases are fine.
568                _ => {
569                    self.output_struct_or_variant_new_type_container(
570                        Some(base),
571                        Some(index),
572                        name,
573                        format,
574                    )?;
575                    return Ok(());
576                }
577            },
578            Tuple(formats) => formats
579                .iter()
580                .enumerate()
581                .map(|(i, f)| Named {
582                    name: format!("Field{}", i),
583                    value: f.clone(),
584                })
585                .collect(),
586            Struct(fields) => fields
587                .iter()
588                .map(|f| Named {
589                    name: f.name.to_camel_case(),
590                    value: f.value.clone(),
591                })
592                .collect(),
593            Variable(_) => panic!("incorrect value"),
594        };
595        self.output_struct_or_variant_container(Some(base), Some(index), name, &fields)
596    }
597
598    fn output_struct_or_variant_container(
599        &mut self,
600        variant_base: Option<&str>,
601        variant_index: Option<u32>,
602        name: &str,
603        fields: &[Named<Format>],
604    ) -> Result<()> {
605        let full_name = match variant_base {
606            None => name.to_string(),
607            Some(base) => format!("{}__{}", base, name),
608        };
609        // Struct
610        writeln!(self.out)?;
611        self.output_comment(name)?;
612        writeln!(self.out, "type {} struct {{", full_name)?;
613        self.enter_class(name);
614        for field in fields {
615            self.output_comment(&field.name)?;
616            writeln!(self.out, "{} {}", field.name, self.quote_type(&field.value))?;
617        }
618        self.leave_class();
619        writeln!(self.out, "}}")?;
620
621        // Link to base interface.
622        if let Some(base) = variant_base {
623            writeln!(self.out, "\nfunc (*{}) is{}() {{}}", full_name, base)?;
624        }
625
626        // Serialize
627        if self.generator.config.serialization {
628            writeln!(
629                self.out,
630                "\nfunc (obj *{}) Serialize(serializer serde.Serializer) error {{",
631                full_name
632            )?;
633            self.out.indent();
634            writeln!(
635                self.out,
636                "if err := serializer.IncreaseContainerDepth(); err != nil {{ return err }}"
637            )?;
638            if let Some(index) = variant_index {
639                writeln!(self.out, "serializer.SerializeVariantIndex({})", index)?;
640            }
641            for field in fields {
642                writeln!(
643                    self.out,
644                    "{}",
645                    self.quote_serialize_value(&format!("obj.{}", &field.name), &field.value)
646                )?;
647            }
648            writeln!(self.out, "serializer.DecreaseContainerDepth()")?;
649            writeln!(self.out, "return nil")?;
650            self.out.unindent();
651            writeln!(self.out, "}}")?;
652
653            for encoding in &self.generator.config.encodings {
654                self.output_struct_serialize_for_encoding(&full_name, *encoding)?;
655            }
656        }
657        // Deserialize (struct) or Load (variant)
658        if self.generator.config.serialization {
659            writeln!(
660                self.out,
661                "\nfunc {0}{1}(deserializer serde.Deserializer) ({1}, error) {{",
662                if variant_base.is_none() {
663                    "Deserialize"
664                } else {
665                    "load_"
666                },
667                full_name,
668            )?;
669            self.out.indent();
670            writeln!(self.out, "var obj {}", full_name)?;
671            writeln!(
672                self.out,
673                "if err := deserializer.IncreaseContainerDepth(); err != nil {{ return obj, err }}"
674            )?;
675            for field in fields {
676                writeln!(
677                    self.out,
678                    "{}",
679                    self.quote_deserialize(&field.value, &format!("obj.{}", field.name), "obj")
680                )?;
681            }
682            writeln!(self.out, "deserializer.DecreaseContainerDepth()")?;
683            writeln!(self.out, "return obj, nil")?;
684            self.out.unindent();
685            writeln!(self.out, "}}")?;
686
687            if variant_base.is_none() {
688                for encoding in &self.generator.config.encodings {
689                    self.output_struct_deserialize_for_encoding(&full_name, *encoding)?;
690                }
691            }
692        }
693        // Custom code
694        self.output_custom_code(name)?;
695        Ok(())
696    }
697
698    // Same as output_struct_or_variant_container but we map the container with a single anonymous field
699    // to a new type in Go.
700    fn output_struct_or_variant_new_type_container(
701        &mut self,
702        variant_base: Option<&str>,
703        variant_index: Option<u32>,
704        name: &str,
705        format: &Format,
706    ) -> Result<()> {
707        let full_name = match variant_base {
708            None => name.to_string(),
709            Some(base) => format!("{}__{}", base, name),
710        };
711        // Struct
712        writeln!(self.out)?;
713        self.output_comment(name)?;
714        writeln!(self.out, "type {} {}", full_name, self.quote_type(format))?;
715
716        // Link to base interface.
717        if let Some(base) = variant_base {
718            writeln!(self.out, "\nfunc (*{}) is{}() {{}}", full_name, base)?;
719        }
720
721        // Serialize
722        if self.generator.config.serialization {
723            writeln!(
724                self.out,
725                "\nfunc (obj *{}) Serialize(serializer serde.Serializer) error {{",
726                full_name
727            )?;
728            self.out.indent();
729            writeln!(
730                self.out,
731                "if err := serializer.IncreaseContainerDepth(); err != nil {{ return err }}"
732            )?;
733            if let Some(index) = variant_index {
734                writeln!(self.out, "serializer.SerializeVariantIndex({})", index)?;
735            }
736            writeln!(
737                self.out,
738                "{}",
739                self.quote_serialize_value(
740                    &format!("(({})(*obj))", self.quote_type(format)),
741                    format
742                )
743            )?;
744            writeln!(self.out, "serializer.DecreaseContainerDepth()")?;
745            writeln!(self.out, "return nil")?;
746            self.out.unindent();
747            writeln!(self.out, "}}")?;
748
749            for encoding in &self.generator.config.encodings {
750                self.output_struct_serialize_for_encoding(&full_name, *encoding)?;
751            }
752        }
753        // Deserialize (struct) or Load (variant)
754        if self.generator.config.serialization {
755            writeln!(
756                self.out,
757                "\nfunc {0}{1}(deserializer serde.Deserializer) ({1}, error) {{",
758                if variant_base.is_none() {
759                    "Deserialize"
760                } else {
761                    "load_"
762                },
763                full_name,
764            )?;
765            self.out.indent();
766            writeln!(self.out, "var obj {}", self.quote_type(format))?;
767            writeln!(self.out, "if err := deserializer.IncreaseContainerDepth(); err != nil {{ return ({})(obj), err }}", full_name)?;
768            writeln!(
769                self.out,
770                "{}",
771                self.quote_deserialize(format, "obj", &format!("(({})(obj))", full_name))
772            )?;
773            writeln!(self.out, "deserializer.DecreaseContainerDepth()")?;
774            writeln!(self.out, "return ({})(obj), nil", full_name)?;
775            self.out.unindent();
776            writeln!(self.out, "}}")?;
777
778            if variant_base.is_none() {
779                for encoding in &self.generator.config.encodings {
780                    self.output_struct_deserialize_for_encoding(&full_name, *encoding)?;
781                }
782            }
783        }
784        // Custom code
785        self.output_custom_code(name)?;
786        Ok(())
787    }
788
789    fn output_struct_serialize_for_encoding(
790        &mut self,
791        name: &str,
792        encoding: Encoding,
793    ) -> Result<()> {
794        writeln!(
795            self.out,
796            r#"
797func (obj *{0}) {2}Serialize() ([]byte, error) {{
798	if obj == nil {{
799		return nil, fmt.Errorf("Cannot serialize null object")
800	}}
801	serializer := {1}.NewSerializer();
802	if err := obj.Serialize(serializer); err != nil {{ return nil, err }}
803	return serializer.GetBytes(), nil
804}}"#,
805            name,
806            encoding.name(),
807            encoding.name().to_camel_case()
808        )
809    }
810
811    fn output_struct_deserialize_for_encoding(
812        &mut self,
813        name: &str,
814        encoding: Encoding,
815    ) -> Result<()> {
816        writeln!(
817            self.out,
818            r#"
819func {2}Deserialize{0}(input []byte) ({0}, error) {{
820	if input == nil {{
821		var obj {0}
822		return obj, fmt.Errorf("Cannot deserialize null array")
823	}}
824	deserializer := {1}.NewDeserializer(input);
825	obj, err := Deserialize{0}(deserializer)
826	if err == nil && deserializer.GetBufferOffset() < uint64(len(input)) {{
827		return obj, fmt.Errorf("Some input bytes were not read")
828	}}
829	return obj, err
830}}"#,
831            name,
832            encoding.name(),
833            encoding.name().to_camel_case(),
834        )
835    }
836
837    fn output_enum_container(
838        &mut self,
839        name: &str,
840        variants: &BTreeMap<u32, Named<VariantFormat>>,
841    ) -> Result<()> {
842        writeln!(self.out)?;
843        self.output_comment(name)?;
844        writeln!(self.out, "type {} interface {{", name)?;
845        self.current_namespace.push(name.to_string());
846        self.out.indent();
847        writeln!(self.out, "is{}()", name)?;
848        if self.generator.config.serialization {
849            writeln!(self.out, "Serialize(serializer serde.Serializer) error")?;
850            for encoding in &self.generator.config.encodings {
851                writeln!(
852                    self.out,
853                    "{}Serialize() ([]byte, error)",
854                    encoding.name().to_camel_case()
855                )?;
856            }
857        }
858        self.out.unindent();
859        writeln!(self.out, "}}")?;
860
861        if self.generator.config.serialization {
862            write!(
863                self.out,
864                "\nfunc Deserialize{0}(deserializer serde.Deserializer) ({0}, error) {{",
865                name
866            )?;
867            self.out.indent();
868            writeln!(
869                self.out,
870                r#"
871index, err := deserializer.DeserializeVariantIndex()
872if err != nil {{ return nil, err }}
873
874switch index {{"#,
875            )?;
876            for (index, variant) in variants {
877                writeln!(
878                    self.out,
879                    r#"case {}:
880	if val, err := load_{}__{}(deserializer); err == nil {{
881		return &val, nil
882	}} else {{
883		return nil, err
884	}}
885"#,
886                    index, name, variant.name
887                )?;
888            }
889            writeln!(
890                self.out,
891                "default:
892	return nil, fmt.Errorf(\"Unknown variant index for {}: %d\", index)",
893                name,
894            )?;
895            writeln!(self.out, "}}")?;
896            self.out.unindent();
897            writeln!(self.out, "}}")?;
898
899            for encoding in &self.generator.config.encodings {
900                self.output_struct_deserialize_for_encoding(name, *encoding)?;
901            }
902        }
903
904        for (index, variant) in variants {
905            self.output_variant(name, *index, &variant.name, &variant.value)?;
906        }
907        self.current_namespace.pop();
908        // Custom code
909        self.output_custom_code(name)?;
910        Ok(())
911    }
912
913    fn output_container(&mut self, name: &str, format: &ContainerFormat) -> Result<()> {
914        use ContainerFormat::*;
915        let fields = match format {
916            UnitStruct => Vec::new(),
917            NewTypeStruct(format) => match format.as_ref() {
918                // See comment in `output_variant`.
919                Format::TypeName(_) | Format::Option(_) => vec![Named {
920                    name: "Value".to_string(),
921                    value: format.as_ref().clone(),
922                }],
923                _ => {
924                    self.output_struct_or_variant_new_type_container(None, None, name, format)?;
925                    return Ok(());
926                }
927            },
928            TupleStruct(formats) => formats
929                .iter()
930                .enumerate()
931                .map(|(i, f)| Named {
932                    name: format!("Field{}", i),
933                    value: f.clone(),
934                })
935                .collect(),
936            Struct(fields) => fields
937                .iter()
938                .map(|f| Named {
939                    name: f.name.to_camel_case(),
940                    value: f.value.clone(),
941                })
942                .collect(),
943            Enum(variants) => {
944                let variants = variants
945                    .iter()
946                    .map(|(i, f)| {
947                        (
948                            *i,
949                            Named {
950                                name: f.name.to_camel_case(),
951                                value: f.value.clone(),
952                            },
953                        )
954                    })
955                    .collect();
956                self.output_enum_container(name, &variants)?;
957                return Ok(());
958            }
959        };
960        self.output_struct_or_variant_container(None, None, name, &fields)
961    }
962}
963
/// Installer for generated source files in Go.
pub struct Installer {
    /// Root directory; each module is written into a sub-directory named after the module.
    install_dir: PathBuf,
    /// Optional module path of the serde runtime packages. When set, it is passed to the
    /// code generator and used as a prefix in the runtime installation messages.
    serde_module_path: Option<String>,
}

impl Installer {
    /// Create an installer writing under `install_dir`, optionally overriding the
    /// module path of the serde runtime packages.
    pub fn new(install_dir: PathBuf, serde_module_path: Option<String>) -> Self {
        Self {
            install_dir,
            serde_module_path,
        }
    }

    /// Print to stderr that the sources of the runtime package `name` are not installed.
    ///
    /// The package name is prefixed with `<serde_module_path>/` when a module path is set.
    fn runtime_installation_message(&self, name: &str) {
        let prefix = self
            .serde_module_path
            .as_ref()
            .map(|path| format!("{}/", path))
            .unwrap_or_default();
        eprintln!(
            "Not installing sources for published package {}{}",
            prefix, name
        );
    }
}
989
990impl crate::SourceInstaller for Installer {
991    type Error = Box<dyn std::error::Error>;
992
993    fn install_module(
994        &self,
995        config: &CodeGeneratorConfig,
996        registry: &Registry,
997    ) -> std::result::Result<(), Self::Error> {
998        let dir_path = self.install_dir.join(&config.module_name);
999        std::fs::create_dir_all(&dir_path)?;
1000        let source_path = dir_path.join("lib.go");
1001        let mut file = std::fs::File::create(source_path)?;
1002
1003        let mut generator = CodeGenerator::new(config);
1004        if let Some(path) = &self.serde_module_path {
1005            generator = generator.with_serde_module_path(path.clone());
1006        }
1007        generator.output(&mut file, registry)?;
1008        Ok(())
1009    }
1010
1011    fn install_serde_runtime(&self) -> std::result::Result<(), Self::Error> {
1012        self.runtime_installation_message("serde");
1013        Ok(())
1014    }
1015
1016    fn install_bincode_runtime(&self) -> std::result::Result<(), Self::Error> {
1017        self.runtime_installation_message("bincode");
1018        Ok(())
1019    }
1020
1021    fn install_bcs_runtime(&self) -> std::result::Result<(), Self::Error> {
1022        self.runtime_installation_message("bcs");
1023        Ok(())
1024    }
1025}