Skip to main content

serde_generate/
golang.rs

1// Copyright (c) Facebook, Inc. and its affiliates
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
4use crate::{
5    common,
6    indent::{IndentConfig, IndentedWriter},
7    CodeGeneratorConfig, Encoding,
8};
9use heck::CamelCase;
10use serde_reflection::{ContainerFormat, Format, FormatHolder, Named, Registry, VariantFormat};
11use std::{
12    collections::{BTreeMap, HashMap},
13    io::{Result, Write},
14    path::PathBuf,
15};
16
17/// Main configuration object for code-generation in Go.
18pub struct CodeGenerator<'a> {
19    /// Language-independent configuration.
20    config: &'a CodeGeneratorConfig,
21    /// Module path where to find the serde runtime packages (serde, bcs, bincode).
22    /// Default: "github.com/novifinancial/serde-reflection/serde-generate/runtime/golang".
23    serde_module_path: String,
24    /// Mapping from external type names to fully-qualified class names (e.g. "MyClass" -> "com.my_org.my_package.MyClass").
25    /// Derived from `config.external_definitions`.
26    external_qualified_names: HashMap<String, String>,
27}
28
29/// Shared state for the code generation of a Go source file.
30struct GoEmitter<'a, T> {
31    /// Writer.
32    out: IndentedWriter<T>,
33    /// Generator.
34    generator: &'a CodeGenerator<'a>,
35    /// Current namespace (e.g. vec!["com", "my_org", "my_package", "MyClass"])
36    current_namespace: Vec<String>,
37}
38
39impl<'a> CodeGenerator<'a> {
40    /// Create a Go code generator for the given config.
41    pub fn new(config: &'a CodeGeneratorConfig) -> Self {
42        if config.enums.c_style {
43            panic!("Go does not support generating c-style enums");
44        }
45        let mut external_qualified_names = HashMap::new();
46        for (namespace, names) in &config.external_definitions {
47            let package_name = {
48                let path = namespace.rsplitn(2, '/').collect::<Vec<_>>();
49                if path.len() <= 1 {
50                    namespace
51                } else {
52                    path[0]
53                }
54            };
55            for name in names {
56                external_qualified_names.insert(name.to_string(), format!("{package_name}.{name}"));
57            }
58        }
59        Self {
60            config,
61            serde_module_path:
62                "github.com/novifinancial/serde-reflection/serde-generate/runtime/golang"
63                    .to_string(),
64            external_qualified_names,
65        }
66    }
67
68    /// Whether the package providing Serde definitions is located within a different module.
69    pub fn with_serde_module_path(mut self, serde_module_path: String) -> Self {
70        self.serde_module_path = serde_module_path;
71        self
72    }
73
74    /// Output class definitions for `registry`.
75    pub fn output(&self, out: &mut dyn Write, registry: &Registry) -> Result<()> {
76        let current_namespace = self
77            .config
78            .module_name
79            .split('.')
80            .map(String::from)
81            .collect::<Vec<_>>();
82
83        let mut emitter = GoEmitter {
84            // `go fmt` indents using tabs so let's do the same.
85            out: IndentedWriter::new(out, IndentConfig::Tab),
86            generator: self,
87            current_namespace,
88        };
89
90        emitter.output_preamble(registry)?;
91
92        for (name, format) in registry {
93            emitter.output_container(name, format)?;
94        }
95
96        if self.config.serialization {
97            emitter.output_trait_helpers(registry)?;
98        }
99
100        Ok(())
101    }
102}
103
104impl<'a, T> GoEmitter<'a, T>
105where
106    T: Write,
107{
108    fn output_preamble(&mut self, registry: &Registry) -> Result<()> {
109        writeln!(
110            self.out,
111            "package {}\n\n",
112            self.generator.config.module_name
113        )?;
114        // Go does not support disabling warnings on unused imports.
115        if registry.is_empty() {
116            return Ok(());
117        }
118        writeln!(self.out, "import (")?;
119        self.out.indent();
120        if self.generator.config.serialization
121            && (Self::has_enum(registry) || !self.generator.config.encodings.is_empty())
122        {
123            writeln!(self.out, "\"fmt\"")?;
124        }
125        if self.generator.config.serialization || Self::has_int128(registry) {
126            writeln!(self.out, "\"{}/serde\"", self.generator.serde_module_path)?;
127        }
128        if self.generator.config.serialization {
129            for encoding in &self.generator.config.encodings {
130                writeln!(
131                    self.out,
132                    "\"{}/{}\"",
133                    self.generator.serde_module_path,
134                    encoding.name()
135                )?;
136            }
137        }
138        for path in self.generator.config.external_definitions.keys() {
139            writeln!(self.out, "\"{path}\"")?;
140        }
141        self.out.unindent();
142        writeln!(self.out, ")\n")?;
143        Ok(())
144    }
145
146    fn has_int128(registry: &Registry) -> bool {
147        for format in registry.values() {
148            if format
149                .visit(&mut |f| match f {
150                    Format::I128 | Format::U128 => {
151                        // Interrupt the visit if we find a (u)int128
152                        Err(serde_reflection::Error::Custom(String::new()))
153                    }
154                    _ => Ok(()),
155                })
156                .is_err()
157            {
158                return true;
159            }
160        }
161        false
162    }
163
164    fn has_enum(registry: &Registry) -> bool {
165        for format in registry.values() {
166            if let ContainerFormat::Enum(_) = format {
167                return true;
168            }
169        }
170        false
171    }
172
173    /// Compute a reference to the registry type `name`.
174    fn quote_qualified_name(&self, name: &str) -> String {
175        self.generator
176            .external_qualified_names
177            .get(name)
178            .cloned()
179            .unwrap_or_else(|| name.to_string())
180    }
181
182    fn output_comment(&mut self, name: &str) -> std::io::Result<()> {
183        let mut path = self.current_namespace.clone();
184        path.push(name.to_string());
185        if let Some(doc) = self.generator.config.comments.get(&path) {
186            let text = textwrap::indent(doc, "// ").replace("\n\n", "\n//\n");
187            write!(self.out, "{text}")?;
188        }
189        Ok(())
190    }
191
192    fn output_custom_code(&mut self, name: &str) -> std::io::Result<()> {
193        let mut path = self.current_namespace.clone();
194        path.push(name.to_string());
195        if let Some(code) = self.generator.config.custom_code.get(&path) {
196            write!(self.out, "\n{code}")?;
197        }
198        Ok(())
199    }
200
201    fn quote_type(&self, format: &Format) -> String {
202        use Format::*;
203        match format {
204            TypeName(x) => self.quote_qualified_name(x),
205            Unit => "struct {}".into(),
206            Bool => "bool".into(),
207            I8 => "int8".into(),
208            I16 => "int16".into(),
209            I32 => "int32".into(),
210            I64 => "int64".into(),
211            I128 => "serde.Int128".into(),
212            U8 => "uint8".into(),
213            U16 => "uint16".into(),
214            U32 => "uint32".into(),
215            U64 => "uint64".into(),
216            U128 => "serde.Uint128".into(),
217            F32 => "float32".into(),
218            F64 => "float64".into(),
219            Char => "rune".into(),
220            Str => "string".into(),
221            Bytes => "[]byte".into(),
222
223            Option(format) => format!("*{}", self.quote_type(format)),
224            Seq(format) => format!("[]{}", self.quote_type(format)),
225            Map { key, value } => {
226                format!("map[{}]{}", self.quote_type(key), self.quote_type(value))
227            }
228            Tuple(formats) => format!(
229                "struct {{{}}}",
230                formats
231                    .iter()
232                    .enumerate()
233                    .map(|(index, format)| format!("Field{} {}", index, self.quote_type(format)))
234                    .collect::<Vec<_>>()
235                    .join("; ")
236            ),
237            TupleArray { content, size } => format!("[{}]{}", size, self.quote_type(content)),
238
239            Variable(_) => panic!("unexpected value"),
240        }
241    }
242
243    fn enter_class(&mut self, name: &str) {
244        self.out.indent();
245        self.current_namespace.push(name.to_string());
246    }
247
248    fn leave_class(&mut self) {
249        self.out.unindent();
250        self.current_namespace.pop();
251    }
252
253    fn output_trait_helpers(&mut self, registry: &Registry) -> Result<()> {
254        let mut subtypes = BTreeMap::new();
255        for format in registry.values() {
256            format
257                .visit(&mut |f| {
258                    if Self::needs_helper(f) {
259                        subtypes.insert(common::mangle_type(f), f.clone());
260                    }
261                    Ok(())
262                })
263                .unwrap();
264        }
265        for (mangled_name, subtype) in &subtypes {
266            self.output_serialization_helper(mangled_name, subtype)?;
267            self.output_deserialization_helper(mangled_name, subtype)?;
268        }
269        Ok(())
270    }
271
272    fn needs_helper(format: &Format) -> bool {
273        use Format::*;
274        matches!(
275            format,
276            Option(_) | Seq(_) | Map { .. } | Tuple(_) | TupleArray { .. }
277        )
278    }
279
280    fn quote_serialize_value(&self, value: &str, format: &Format) -> String {
281        use Format::*;
282        let expr = match format {
283            TypeName(_) => format!("{value}.Serialize(serializer)"),
284            Unit => format!("serializer.SerializeUnit({value})"),
285            Bool => format!("serializer.SerializeBool({value})"),
286            I8 => format!("serializer.SerializeI8({value})"),
287            I16 => format!("serializer.SerializeI16({value})"),
288            I32 => format!("serializer.SerializeI32({value})"),
289            I64 => format!("serializer.SerializeI64({value})"),
290            I128 => format!("serializer.SerializeI128({value})"),
291            U8 => format!("serializer.SerializeU8({value})"),
292            U16 => format!("serializer.SerializeU16({value})"),
293            U32 => format!("serializer.SerializeU32({value})"),
294            U64 => format!("serializer.SerializeU64({value})"),
295            U128 => format!("serializer.SerializeU128({value})"),
296            F32 => format!("serializer.SerializeF32({value})"),
297            F64 => format!("serializer.SerializeF64({value})"),
298            Char => format!("serializer.SerializeChar({value})"),
299            Str => format!("serializer.SerializeStr({value})"),
300            Bytes => format!("serializer.SerializeBytes({value})"),
301            _ => format!(
302                "serialize_{}({}, serializer)",
303                common::mangle_type(format),
304                value
305            ),
306        };
307        format!("if err := {expr}; err != nil {{ return err }}")
308    }
309
310    fn quote_deserialize(&self, format: &Format, dest: &str, fail: &str) -> String {
311        use Format::*;
312        let expr = match format {
313            TypeName(name) => format!(
314                "Deserialize{}(deserializer)",
315                self.quote_qualified_name(name)
316            ),
317            Unit => "deserializer.DeserializeUnit()".to_string(),
318            Bool => "deserializer.DeserializeBool()".to_string(),
319            I8 => "deserializer.DeserializeI8()".to_string(),
320            I16 => "deserializer.DeserializeI16()".to_string(),
321            I32 => "deserializer.DeserializeI32()".to_string(),
322            I64 => "deserializer.DeserializeI64()".to_string(),
323            I128 => "deserializer.DeserializeI128()".to_string(),
324            U8 => "deserializer.DeserializeU8()".to_string(),
325            U16 => "deserializer.DeserializeU16()".to_string(),
326            U32 => "deserializer.DeserializeU32()".to_string(),
327            U64 => "deserializer.DeserializeU64()".to_string(),
328            U128 => "deserializer.DeserializeU128()".to_string(),
329            F32 => "deserializer.DeserializeF32()".to_string(),
330            F64 => "deserializer.DeserializeF64()".to_string(),
331            Char => "deserializer.DeserializeChar()".to_string(),
332            Str => "deserializer.DeserializeStr()".to_string(),
333            Bytes => "deserializer.DeserializeBytes()".to_string(),
334            _ => format!("deserialize_{}(deserializer)", common::mangle_type(format)),
335        };
336        format!(
337            "if val, err := {expr}; err == nil {{ {dest} = val }} else {{ return {fail}, err }}"
338        )
339    }
340
/// Emit `serialize_<name>`, a Go helper that serializes a value of the composite format
/// `format0` (option, sequence, map, tuple, or fixed-size array).
///
/// `name` is the mangled name of `format0`. Elements are serialized through
/// `quote_serialize_value`, so nested composite formats call their own helpers.
///
/// # Panics
///
/// Panics if `format0` is not one of the composite formats listed above.
fn output_serialization_helper(&mut self, name: &str, format0: &Format) -> Result<()> {
    use Format::*;

    write!(
        self.out,
        "func serialize_{}(value {}, serializer serde.Serializer) error {{",
        name,
        self.quote_type(format0)
    )?;
    self.out.indent();
    match format0 {
        // Options are Go pointers: write a presence tag, then the pointee when non-nil.
        Option(format) => {
            write!(
                self.out,
                r#"
if value != nil {{
	if err := serializer.SerializeOptionTag(true); err != nil {{ return err }}
	{}
}} else {{
	if err := serializer.SerializeOptionTag(false); err != nil {{ return err }}
}}
"#,
                self.quote_serialize_value("(*value)", format)
            )?;
        }

        // Sequences: length prefix followed by each element.
        Seq(format) => {
            write!(
                self.out,
                r#"
if err := serializer.SerializeLen(uint64(len(value))); err != nil {{ return err }}
for _, item := range(value) {{
	{}
}}
"#,
                self.quote_serialize_value("item", format)
            )?;
        }

        // Maps: length prefix, then each key/value pair. The buffer offset at the start
        // of every entry is recorded and handed to `SortMapEntries` afterwards; presumably
        // this reorders the serialized entries canonically, since Go map iteration order
        // is not fixed (confirm against the runtime).
        Map { key, value } => {
            write!(
                self.out,
                r#"
if err := serializer.SerializeLen(uint64(len(value))); err != nil {{ return err }}
offsets := make([]uint64, len(value))
count := 0
for k, v := range(value) {{
	offsets[count] = serializer.GetBufferOffset()
	count += 1
	{}
	{}
}}
serializer.SortMapEntries(offsets);
"#,
                self.quote_serialize_value("k", key),
                self.quote_serialize_value("v", value)
            )?;
        }

        // Tuples: each field in order, with no length prefix.
        Tuple(formats) => {
            writeln!(self.out)?;
            for (index, format) in formats.iter().enumerate() {
                let expr = format!("value.Field{index}");
                writeln!(self.out, "{}", self.quote_serialize_value(&expr, format))?;
            }
        }

        // Fixed-size arrays: each element in order, with no length prefix.
        TupleArray { content, size: _ } => {
            write!(
                self.out,
                r#"
for _, item := range(value) {{
	{}
}}
"#,
                self.quote_serialize_value("item", content),
            )?;
        }

        _ => panic!("unexpected case"),
    }
    writeln!(self.out, "return nil")?;
    self.out.unindent();
    writeln!(self.out, "}}\n")
}
426
427    fn output_deserialization_helper(&mut self, name: &str, format0: &Format) -> Result<()> {
428        use Format::*;
429
430        write!(
431            self.out,
432            "func deserialize_{}(deserializer serde.Deserializer) ({}, error) {{",
433            name,
434            self.quote_type(format0),
435        )?;
436        self.out.indent();
437        match format0 {
438            Option(format) => {
439                write!(
440                    self.out,
441                    r#"
442tag, err := deserializer.DeserializeOptionTag()
443if err != nil {{ return nil, err }}
444if tag {{
445	value := new({})
446	{}
447        return value, nil
448}} else {{
449	return nil, nil
450}}
451"#,
452                    self.quote_type(format),
453                    self.quote_deserialize(format, "*value", "nil"),
454                )?;
455            }
456
457            Seq(format) => {
458                write!(
459                    self.out,
460                    r#"
461length, err := deserializer.DeserializeLen()
462if err != nil {{ return nil, err }}
463obj := make([]{}, length)
464for i := range(obj) {{
465	{}
466}}
467return obj, nil
468"#,
469                    self.quote_type(format),
470                    self.quote_deserialize(format, "obj[i]", "nil")
471                )?;
472            }
473
474            Map { key, value } => {
475                write!(
476                    self.out,
477                    r#"
478length, err := deserializer.DeserializeLen()
479if err != nil {{ return nil, err }}
480obj := make(map[{0}]{1})
481previous_slice := serde.Slice {{ 0, 0 }}
482for i := 0; i < int(length); i++ {{
483	var slice serde.Slice
484	slice.Start = deserializer.GetBufferOffset()
485	var key {0}
486	{2}
487	slice.End = deserializer.GetBufferOffset()
488	if i > 0 {{
489		err := deserializer.CheckThatKeySlicesAreIncreasing(previous_slice, slice)
490		if err != nil {{ return nil, err }}
491	}}
492	previous_slice = slice
493	{3}
494}}
495return obj, nil
496"#,
497                    self.quote_type(key),
498                    self.quote_type(value),
499                    self.quote_deserialize(key, "key", "nil"),
500                    self.quote_deserialize(value, "obj[key]", "nil"),
501                )?;
502            }
503
504            Tuple(formats) => {
505                write!(
506                    self.out,
507                    r#"
508var obj {}
509{}
510return obj, nil
511"#,
512                    self.quote_type(format0),
513                    formats
514                        .iter()
515                        .enumerate()
516                        .map(|(i, f)| self.quote_deserialize(f, &format!("obj.Field{i}"), "obj"))
517                        .collect::<Vec<_>>()
518                        .join("\n")
519                )?;
520            }
521
522            TupleArray { content, size } => {
523                write!(
524                    self.out,
525                    r#"
526var obj [{1}]{0}
527for i := range(obj) {{
528	{2}
529}}
530return obj, nil
531"#,
532                    self.quote_type(content),
533                    size,
534                    self.quote_deserialize(content, "obj[i]", "obj")
535                )?;
536            }
537
538            _ => panic!("unexpected case"),
539        }
540        self.out.unindent();
541        writeln!(self.out, "}}\n")
542    }
543
544    fn output_variant(
545        &mut self,
546        base: &str,
547        index: u32,
548        name: &str,
549        variant: &VariantFormat,
550    ) -> Result<()> {
551        use VariantFormat::*;
552        let fields = match variant {
553            Unit => Vec::new(),
554            NewType(format) => match format.as_ref() {
555                // We cannot define a "new type" (e.g. `type Foo Bar`) out of a typename `Bar` because `Bar`
556                // could point to a Go interface. This would make `Foo` an interface as well. Interfaces can't be used
557                // as structs (e.g. they cannot have methods).
558                //
559                // Similarly, option types are compiled as pointers but `type Foo *Bar` would prevent `Foo` from being a
560                // valid pointer receiver.
561                Format::TypeName(_) | Format::Option(_) => vec![Named {
562                    name: "Value".to_string(),
563                    value: format.as_ref().clone(),
564                }],
565                // Other cases are fine.
566                _ => {
567                    self.output_struct_or_variant_new_type_container(
568                        Some(base),
569                        Some(index),
570                        name,
571                        format,
572                    )?;
573                    return Ok(());
574                }
575            },
576            Tuple(formats) => formats
577                .iter()
578                .enumerate()
579                .map(|(i, f)| Named {
580                    name: format!("Field{i}"),
581                    value: f.clone(),
582                })
583                .collect(),
584            Struct(fields) => fields
585                .iter()
586                .map(|f| Named {
587                    name: f.name.to_camel_case(),
588                    value: f.value.clone(),
589                })
590                .collect(),
591            Variable(_) => panic!("incorrect value"),
592        };
593        self.output_struct_or_variant_container(Some(base), Some(index), name, &fields)
594    }
595
596    fn output_struct_or_variant_container(
597        &mut self,
598        variant_base: Option<&str>,
599        variant_index: Option<u32>,
600        name: &str,
601        fields: &[Named<Format>],
602    ) -> Result<()> {
603        let full_name = match variant_base {
604            None => name.to_string(),
605            Some(base) => format!("{base}__{name}"),
606        };
607        // Struct
608        writeln!(self.out)?;
609        self.output_comment(name)?;
610        writeln!(self.out, "type {full_name} struct {{")?;
611        self.enter_class(name);
612        for field in fields {
613            self.output_comment(&field.name)?;
614            writeln!(self.out, "{} {}", field.name, self.quote_type(&field.value))?;
615        }
616        self.leave_class();
617        writeln!(self.out, "}}")?;
618
619        // Link to base interface.
620        if let Some(base) = variant_base {
621            writeln!(self.out, "\nfunc (*{full_name}) is{base}() {{}}")?;
622        }
623
624        // Serialize
625        if self.generator.config.serialization {
626            writeln!(
627                self.out,
628                "\nfunc (obj *{full_name}) Serialize(serializer serde.Serializer) error {{"
629            )?;
630            self.out.indent();
631            writeln!(
632                self.out,
633                "if err := serializer.IncreaseContainerDepth(); err != nil {{ return err }}"
634            )?;
635            if let Some(index) = variant_index {
636                writeln!(self.out, "serializer.SerializeVariantIndex({index})")?;
637            }
638            for field in fields {
639                writeln!(
640                    self.out,
641                    "{}",
642                    self.quote_serialize_value(&format!("obj.{}", &field.name), &field.value)
643                )?;
644            }
645            writeln!(self.out, "serializer.DecreaseContainerDepth()")?;
646            writeln!(self.out, "return nil")?;
647            self.out.unindent();
648            writeln!(self.out, "}}")?;
649
650            for encoding in &self.generator.config.encodings {
651                self.output_struct_serialize_for_encoding(&full_name, *encoding)?;
652            }
653        }
654        // Deserialize (struct) or Load (variant)
655        if self.generator.config.serialization {
656            writeln!(
657                self.out,
658                "\nfunc {0}{1}(deserializer serde.Deserializer) ({1}, error) {{",
659                if variant_base.is_none() {
660                    "Deserialize"
661                } else {
662                    "load_"
663                },
664                full_name,
665            )?;
666            self.out.indent();
667            writeln!(self.out, "var obj {full_name}")?;
668            writeln!(
669                self.out,
670                "if err := deserializer.IncreaseContainerDepth(); err != nil {{ return obj, err }}"
671            )?;
672            for field in fields {
673                writeln!(
674                    self.out,
675                    "{}",
676                    self.quote_deserialize(&field.value, &format!("obj.{}", field.name), "obj")
677                )?;
678            }
679            writeln!(self.out, "deserializer.DecreaseContainerDepth()")?;
680            writeln!(self.out, "return obj, nil")?;
681            self.out.unindent();
682            writeln!(self.out, "}}")?;
683
684            if variant_base.is_none() {
685                for encoding in &self.generator.config.encodings {
686                    self.output_struct_deserialize_for_encoding(&full_name, *encoding)?;
687                }
688            }
689        }
690        // Custom code
691        self.output_custom_code(name)?;
692        Ok(())
693    }
694
695    // Same as output_struct_or_variant_container but we map the container with a single anonymous field
696    // to a new type in Go.
697    fn output_struct_or_variant_new_type_container(
698        &mut self,
699        variant_base: Option<&str>,
700        variant_index: Option<u32>,
701        name: &str,
702        format: &Format,
703    ) -> Result<()> {
704        let full_name = match variant_base {
705            None => name.to_string(),
706            Some(base) => format!("{base}__{name}"),
707        };
708        // Struct
709        writeln!(self.out)?;
710        self.output_comment(name)?;
711        writeln!(self.out, "type {} {}", full_name, self.quote_type(format))?;
712
713        // Link to base interface.
714        if let Some(base) = variant_base {
715            writeln!(self.out, "\nfunc (*{full_name}) is{base}() {{}}")?;
716        }
717
718        // Serialize
719        if self.generator.config.serialization {
720            writeln!(
721                self.out,
722                "\nfunc (obj *{full_name}) Serialize(serializer serde.Serializer) error {{"
723            )?;
724            self.out.indent();
725            writeln!(
726                self.out,
727                "if err := serializer.IncreaseContainerDepth(); err != nil {{ return err }}"
728            )?;
729            if let Some(index) = variant_index {
730                writeln!(self.out, "serializer.SerializeVariantIndex({index})")?;
731            }
732            writeln!(
733                self.out,
734                "{}",
735                self.quote_serialize_value(
736                    &format!("(({})(*obj))", self.quote_type(format)),
737                    format
738                )
739            )?;
740            writeln!(self.out, "serializer.DecreaseContainerDepth()")?;
741            writeln!(self.out, "return nil")?;
742            self.out.unindent();
743            writeln!(self.out, "}}")?;
744
745            for encoding in &self.generator.config.encodings {
746                self.output_struct_serialize_for_encoding(&full_name, *encoding)?;
747            }
748        }
749        // Deserialize (struct) or Load (variant)
750        if self.generator.config.serialization {
751            writeln!(
752                self.out,
753                "\nfunc {0}{1}(deserializer serde.Deserializer) ({1}, error) {{",
754                if variant_base.is_none() {
755                    "Deserialize"
756                } else {
757                    "load_"
758                },
759                full_name,
760            )?;
761            self.out.indent();
762            writeln!(self.out, "var obj {}", self.quote_type(format))?;
763            writeln!(self.out, "if err := deserializer.IncreaseContainerDepth(); err != nil {{ return ({full_name})(obj), err }}")?;
764            writeln!(
765                self.out,
766                "{}",
767                self.quote_deserialize(format, "obj", &format!("(({full_name})(obj))"))
768            )?;
769            writeln!(self.out, "deserializer.DecreaseContainerDepth()")?;
770            writeln!(self.out, "return ({full_name})(obj), nil")?;
771            self.out.unindent();
772            writeln!(self.out, "}}")?;
773
774            if variant_base.is_none() {
775                for encoding in &self.generator.config.encodings {
776                    self.output_struct_deserialize_for_encoding(&full_name, *encoding)?;
777                }
778            }
779        }
780        // Custom code
781        self.output_custom_code(name)?;
782        Ok(())
783    }
784
785    fn output_struct_serialize_for_encoding(
786        &mut self,
787        name: &str,
788        encoding: Encoding,
789    ) -> Result<()> {
790        writeln!(
791            self.out,
792            r#"
793func (obj *{0}) {2}Serialize() ([]byte, error) {{
794	if obj == nil {{
795		return nil, fmt.Errorf("Cannot serialize null object")
796	}}
797	serializer := {1}.NewSerializer();
798	if err := obj.Serialize(serializer); err != nil {{ return nil, err }}
799	return serializer.GetBytes(), nil
800}}"#,
801            name,
802            encoding.name(),
803            encoding.name().to_camel_case()
804        )
805    }
806
807    fn output_struct_deserialize_for_encoding(
808        &mut self,
809        name: &str,
810        encoding: Encoding,
811    ) -> Result<()> {
812        writeln!(
813            self.out,
814            r#"
815func {2}Deserialize{0}(input []byte) ({0}, error) {{
816	if input == nil {{
817		var obj {0}
818		return obj, fmt.Errorf("Cannot deserialize null array")
819	}}
820	deserializer := {1}.NewDeserializer(input);
821	obj, err := Deserialize{0}(deserializer)
822	if err == nil && deserializer.GetBufferOffset() < uint64(len(input)) {{
823		return obj, fmt.Errorf("Some input bytes were not read")
824	}}
825	return obj, err
826}}"#,
827            name,
828            encoding.name(),
829            encoding.name().to_camel_case(),
830        )
831    }
832
/// Emit the Go representation of the enum `name`: an interface implemented by one Go
/// type per variant (each emitted through `output_variant`).
///
/// The interface declares the `is<name>()` marker method. When serialization is
/// enabled, it also declares `Serialize` plus one `<Encoding>Serialize()` per
/// encoding, and this function emits:
/// * `Deserialize<name>`, which reads the variant index and dispatches to the
///   matching `load_<name>__<Variant>` function;
/// * the per-encoding `<Encoding>Deserialize<name>` entry points.
///
/// `variants` maps each variant index to the variant's name and format.
fn output_enum_container(
    &mut self,
    name: &str,
    variants: &BTreeMap<u32, Named<VariantFormat>>,
) -> Result<()> {
    writeln!(self.out)?;
    self.output_comment(name)?;
    writeln!(self.out, "type {name} interface {{")?;
    // Variants (and their comments / custom code) are scoped under the enum's name
    // until the matching `pop` below.
    self.current_namespace.push(name.to_string());
    self.out.indent();
    writeln!(self.out, "is{name}()")?;
    if self.generator.config.serialization {
        writeln!(self.out, "Serialize(serializer serde.Serializer) error")?;
        for encoding in &self.generator.config.encodings {
            writeln!(
                self.out,
                "{}Serialize() ([]byte, error)",
                encoding.name().to_camel_case()
            )?;
        }
    }
    self.out.unindent();
    writeln!(self.out, "}}")?;

    if self.generator.config.serialization {
        write!(
            self.out,
            "\nfunc Deserialize{name}(deserializer serde.Deserializer) ({name}, error) {{"
        )?;
        self.out.indent();
        writeln!(
            self.out,
            r#"
index, err := deserializer.DeserializeVariantIndex()
if err != nil {{ return nil, err }}

switch index {{"#,
        )?;
        // One case per variant. `load_…` returns the variant value; its address is
        // returned because the `is<name>()` marker is declared on the pointer type.
        for (index, variant) in variants {
            writeln!(
                self.out,
                r#"case {}:
	if val, err := load_{}__{}(deserializer); err == nil {{
		return &val, nil
	}} else {{
		return nil, err
	}}
"#,
                index, name, variant.name
            )?;
        }
        writeln!(
            self.out,
            "default:
	return nil, fmt.Errorf(\"Unknown variant index for {name}: %d\", index)",
        )?;
        writeln!(self.out, "}}")?;
        self.out.unindent();
        writeln!(self.out, "}}")?;

        for encoding in &self.generator.config.encodings {
            self.output_struct_deserialize_for_encoding(name, *encoding)?;
        }
    }

    for (index, variant) in variants {
        self.output_variant(name, *index, &variant.name, &variant.value)?;
    }
    // Leave the enum's scope before looking up the enum's own custom code.
    self.current_namespace.pop();
    // Custom code
    self.output_custom_code(name)?;
    Ok(())
}
906
907    fn output_container(&mut self, name: &str, format: &ContainerFormat) -> Result<()> {
908        use ContainerFormat::*;
909        let fields = match format {
910            UnitStruct => Vec::new(),
911            NewTypeStruct(format) => match format.as_ref() {
912                // See comment in `output_variant`.
913                Format::TypeName(_) | Format::Option(_) => vec![Named {
914                    name: "Value".to_string(),
915                    value: format.as_ref().clone(),
916                }],
917                _ => {
918                    self.output_struct_or_variant_new_type_container(None, None, name, format)?;
919                    return Ok(());
920                }
921            },
922            TupleStruct(formats) => formats
923                .iter()
924                .enumerate()
925                .map(|(i, f)| Named {
926                    name: format!("Field{i}"),
927                    value: f.clone(),
928                })
929                .collect(),
930            Struct(fields) => fields
931                .iter()
932                .map(|f| Named {
933                    name: f.name.to_camel_case(),
934                    value: f.value.clone(),
935                })
936                .collect(),
937            Enum(variants) => {
938                let variants = variants
939                    .iter()
940                    .map(|(i, f)| {
941                        (
942                            *i,
943                            Named {
944                                name: f.name.to_camel_case(),
945                                value: f.value.clone(),
946                            },
947                        )
948                    })
949                    .collect();
950                self.output_enum_container(name, &variants)?;
951                return Ok(());
952            }
953        };
954        self.output_struct_or_variant_container(None, None, name, &fields)
955    }
956}
957
/// Installer for generated source files in Go.
///
/// Only the generated module itself is written to disk. For the serde, bincode and
/// bcs runtimes, the installer just prints a message saying the published package
/// is not installed.
#[derive(Debug, Clone)]
pub struct Installer {
    /// Directory under which each generated module gets its own sub-directory.
    install_dir: PathBuf,
    /// Optional override of the Go module path hosting the runtime packages,
    /// forwarded to `CodeGenerator::with_serde_module_path`.
    serde_module_path: Option<String>,
}
963
964impl Installer {
965    pub fn new(install_dir: PathBuf, serde_module_path: Option<String>) -> Self {
966        Installer {
967            install_dir,
968            serde_module_path,
969        }
970    }
971
972    fn runtime_installation_message(&self, name: &str) {
973        eprintln!(
974            "Not installing sources for published package {}{}",
975            match &self.serde_module_path {
976                None => String::new(),
977                Some(path) => format!("{path}/"),
978            },
979            name
980        );
981    }
982}
983
984impl crate::SourceInstaller for Installer {
985    type Error = Box<dyn std::error::Error>;
986
987    fn install_module(
988        &self,
989        config: &CodeGeneratorConfig,
990        registry: &Registry,
991    ) -> std::result::Result<(), Self::Error> {
992        let dir_path = self.install_dir.join(&config.module_name);
993        std::fs::create_dir_all(&dir_path)?;
994        let source_path = dir_path.join("lib.go");
995        let mut file = std::fs::File::create(source_path)?;
996
997        let mut generator = CodeGenerator::new(config);
998        if let Some(path) = &self.serde_module_path {
999            generator = generator.with_serde_module_path(path.clone());
1000        }
1001        generator.output(&mut file, registry)?;
1002        Ok(())
1003    }
1004
1005    fn install_serde_runtime(&self) -> std::result::Result<(), Self::Error> {
1006        self.runtime_installation_message("serde");
1007        Ok(())
1008    }
1009
1010    fn install_bincode_runtime(&self) -> std::result::Result<(), Self::Error> {
1011        self.runtime_installation_message("bincode");
1012        Ok(())
1013    }
1014
1015    fn install_bcs_runtime(&self) -> std::result::Result<(), Self::Error> {
1016        self.runtime_installation_message("bcs");
1017        Ok(())
1018    }
1019}