Skip to main content

serde_generate/
kotlin.rs

1// Copyright (c) Zefchain Labs, Inc.
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
4use crate::{
5    common,
6    indent::{IndentConfig, IndentedWriter},
7    CodeGeneratorConfig, Encoding,
8};
9use heck::CamelCase;
10use include_dir::include_dir as include_directory;
11use serde_reflection::{ContainerFormat, Format, FormatHolder, Named, Registry, VariantFormat};
12use std::{
13    collections::{BTreeMap, HashMap},
14    io::{Result, Write},
15    path::PathBuf,
16};
17
18/// Main configuration object for code-generation in Kotlin.
19pub struct CodeGenerator<'a> {
20    /// Language-independent configuration.
21    config: &'a CodeGeneratorConfig,
22    /// Mapping from external type names to fully-qualified class names.
23    /// Derived from `config.external_definitions`.
24    external_qualified_names: HashMap<String, String>,
25}
26
27/// Shared state for the code generation of a Kotlin source file.
28struct KotlinEmitter<'a, T> {
29    /// Writer.
30    out: IndentedWriter<T>,
31    /// Generator.
32    generator: &'a CodeGenerator<'a>,
33    /// Current namespace (e.g. vec!["com", "my_org", "my_package", "MyClass"])
34    current_namespace: Vec<String>,
35}
36
37impl<'a> CodeGenerator<'a> {
38    /// Create a Kotlin code generator for the given config.
39    pub fn new(config: &'a CodeGeneratorConfig) -> Self {
40        if config.enums.c_style {
41            panic!("Kotlin does not support generating c-style enums");
42        }
43        let mut external_qualified_names = HashMap::new();
44        for (namespace, names) in &config.external_definitions {
45            for name in names {
46                external_qualified_names.insert(name.to_string(), format!("{namespace}.{name}"));
47            }
48        }
49        Self {
50            config,
51            external_qualified_names,
52        }
53    }
54
55    /// Output class definitions for `registry` in separate source files.
56    /// Source files will be created in a subdirectory of `install_dir` corresponding to the
57    /// package name (if any, otherwise `install_dir` itself).
58    pub fn write_source_files(
59        &self,
60        install_dir: std::path::PathBuf,
61        registry: &Registry,
62    ) -> Result<()> {
63        let current_namespace = self
64            .config
65            .module_name
66            .split('.')
67            .map(String::from)
68            .collect::<Vec<_>>();
69
70        let mut dir_path = install_dir;
71        for part in &current_namespace {
72            dir_path = dir_path.join(part);
73        }
74        std::fs::create_dir_all(&dir_path)?;
75
76        for (name, format) in registry {
77            self.write_container_class(&dir_path, current_namespace.clone(), name, format)?;
78        }
79        if self.config.serialization {
80            self.write_helper_class(&dir_path, current_namespace, registry)?;
81        }
82        Ok(())
83    }
84
85    fn write_container_class(
86        &self,
87        dir_path: &std::path::Path,
88        current_namespace: Vec<String>,
89        name: &str,
90        format: &ContainerFormat,
91    ) -> Result<()> {
92        let mut file = std::fs::File::create(dir_path.join(name.to_string() + ".kt"))?;
93        let mut emitter = KotlinEmitter {
94            out: IndentedWriter::new(&mut file, IndentConfig::Space(4)),
95            generator: self,
96            current_namespace,
97        };
98
99        emitter.output_preamble()?;
100        emitter.output_container(name, format)
101    }
102
103    fn write_helper_class(
104        &self,
105        dir_path: &std::path::Path,
106        current_namespace: Vec<String>,
107        registry: &Registry,
108    ) -> Result<()> {
109        let mut file = std::fs::File::create(dir_path.join("TraitHelpers.kt"))?;
110        let mut emitter = KotlinEmitter {
111            out: IndentedWriter::new(&mut file, IndentConfig::Space(4)),
112            generator: self,
113            current_namespace,
114        };
115
116        emitter.output_preamble()?;
117        emitter.output_trait_helpers(registry)
118    }
119}
120
121impl<'a, T> KotlinEmitter<'a, T>
122where
123    T: Write,
124{
125    fn output_preamble(&mut self) -> Result<()> {
126        writeln!(self.out, "package {}\n", self.generator.config.module_name)?;
127        Ok(())
128    }
129
130    /// Compute a reference to the registry type `name`.
131    fn quote_qualified_name(&self, name: &str) -> String {
132        self.generator
133            .external_qualified_names
134            .get(name)
135            .cloned()
136            .unwrap_or_else(|| format!("{}.{}", self.generator.config.module_name, name))
137    }
138
139    fn output_comment(&mut self, name: &str) -> std::io::Result<()> {
140        let mut path = self.current_namespace.clone();
141        path.push(name.to_string());
142        if let Some(doc) = self.generator.config.comments.get(&path) {
143            let text = textwrap::indent(doc, "// ").replace("\n\n", "\n//\n");
144            write!(self.out, "{text}")?;
145        }
146        Ok(())
147    }
148
149    fn output_custom_code(&mut self) -> std::io::Result<()> {
150        if let Some(code) = self
151            .generator
152            .config
153            .custom_code
154            .get(&self.current_namespace)
155        {
156            writeln!(self.out, "\n{code}")?;
157        }
158        Ok(())
159    }
160
161    fn quote_type(&self, format: &Format) -> String {
162        use Format::*;
163        match format {
164            TypeName(x) => self.quote_qualified_name(x),
165            Unit => "Unit".into(),
166            Bool => "Boolean".into(),
167            I8 => "Byte".into(),
168            I16 => "Short".into(),
169            I32 => "Int".into(),
170            I64 => "Long".into(),
171            I128 => "com.novi.serde.Int128".into(),
172            U8 => "UByte".into(),
173            U16 => "UShort".into(),
174            U32 => "UInt".into(),
175            U64 => "ULong".into(),
176            U128 => "com.novi.serde.UInt128".into(),
177            F32 => "Float".into(),
178            F64 => "Double".into(),
179            Char => "Char".into(),
180            Str => "String".into(),
181            Bytes => "com.novi.serde.Bytes".into(),
182
183            Option(format) => {
184                let inner = self.quote_type(format);
185                if inner.ends_with('?') {
186                    inner
187                } else {
188                    format!("{inner}?")
189                }
190            }
191            Seq(format) => format!("kotlin.collections.List<{}>", self.quote_type(format)),
192            Map { key, value } => format!(
193                "kotlin.collections.Map<{}, {}>",
194                self.quote_type(key),
195                self.quote_type(value)
196            ),
197            Tuple(formats) => match formats.len() {
198                2 => format!(
199                    "Pair<{}, {}>",
200                    self.quote_type(&formats[0]),
201                    self.quote_type(&formats[1])
202                ),
203                3 => format!(
204                    "Triple<{}, {}, {}>",
205                    self.quote_type(&formats[0]),
206                    self.quote_type(&formats[1]),
207                    self.quote_type(&formats[2])
208                ),
209                _ => format!(
210                    "com.novi.serde.Tuple{}<{}>",
211                    formats.len(),
212                    self.quote_types(formats)
213                ),
214            },
215            TupleArray { content, size: _ } => {
216                format!("kotlin.collections.List<{}>", self.quote_type(content))
217            }
218
219            Variable(_) => panic!("unexpected value"),
220        }
221    }
222
223    fn quote_types<'b, I>(&'b self, formats: I) -> String
224    where
225        I: IntoIterator<Item = &'b Format>,
226    {
227        formats
228            .into_iter()
229            .map(|format| self.quote_type(format))
230            .collect::<Vec<_>>()
231            .join(", ")
232    }
233
234    fn enter_class(&mut self, name: &str) {
235        self.out.indent();
236        self.current_namespace.push(name.to_string());
237    }
238
239    fn leave_class(&mut self) {
240        self.out.unindent();
241        self.current_namespace.pop();
242    }
243
244    fn output_trait_helpers(&mut self, registry: &Registry) -> Result<()> {
245        let mut subtypes = BTreeMap::new();
246        for format in registry.values() {
247            format
248                .visit(&mut |f| {
249                    if Self::needs_helper(f) {
250                        subtypes.insert(common::mangle_type(f), f.clone());
251                    }
252                    Ok(())
253                })
254                .unwrap();
255        }
256        writeln!(self.out, "object TraitHelpers {{")?;
257        self.enter_class("TraitHelpers");
258        for (mangled_name, subtype) in &subtypes {
259            self.output_serialization_helper(mangled_name, subtype)?;
260            self.output_deserialization_helper(mangled_name, subtype)?;
261        }
262        self.leave_class();
263        writeln!(self.out, "}}\n")
264    }
265
266    fn needs_helper(format: &Format) -> bool {
267        use Format::*;
268        matches!(
269            format,
270            Option(_) | Seq(_) | Map { .. } | Tuple(_) | TupleArray { .. }
271        )
272    }
273
274    fn quote_serialize_value(&self, value: &str, format: &Format) -> String {
275        use Format::*;
276        match format {
277            TypeName(_) => format!("{value}.serialize(serializer)"),
278            Unit => format!("serializer.serialize_unit({value})"),
279            Bool => format!("serializer.serialize_bool({value})"),
280            I8 => format!("serializer.serialize_i8({value})"),
281            I16 => format!("serializer.serialize_i16({value})"),
282            I32 => format!("serializer.serialize_i32({value})"),
283            I64 => format!("serializer.serialize_i64({value})"),
284            I128 => format!("serializer.serialize_i128({value})"),
285            U8 => format!("serializer.serialize_u8({value})"),
286            U16 => format!("serializer.serialize_u16({value})"),
287            U32 => format!("serializer.serialize_u32({value})"),
288            U64 => format!("serializer.serialize_u64({value})"),
289            U128 => format!("serializer.serialize_u128({value})"),
290            F32 => format!("serializer.serialize_f32({value})"),
291            F64 => format!("serializer.serialize_f64({value})"),
292            Char => format!("serializer.serialize_char({value})"),
293            Str => format!("serializer.serialize_str({value})"),
294            Bytes => format!("serializer.serialize_bytes({value})"),
295            _ => format!(
296                "TraitHelpers.serialize_{}({}, serializer)",
297                common::mangle_type(format),
298                value
299            ),
300        }
301    }
302
303    fn quote_deserialize(&self, format: &Format) -> String {
304        use Format::*;
305        match format {
306            TypeName(name) => format!(
307                "{}.deserialize(deserializer)",
308                self.quote_qualified_name(name)
309            ),
310            Unit => "deserializer.deserialize_unit()".to_string(),
311            Bool => "deserializer.deserialize_bool()".to_string(),
312            I8 => "deserializer.deserialize_i8()".to_string(),
313            I16 => "deserializer.deserialize_i16()".to_string(),
314            I32 => "deserializer.deserialize_i32()".to_string(),
315            I64 => "deserializer.deserialize_i64()".to_string(),
316            I128 => "deserializer.deserialize_i128()".to_string(),
317            U8 => "deserializer.deserialize_u8()".to_string(),
318            U16 => "deserializer.deserialize_u16()".to_string(),
319            U32 => "deserializer.deserialize_u32()".to_string(),
320            U64 => "deserializer.deserialize_u64()".to_string(),
321            U128 => "deserializer.deserialize_u128()".to_string(),
322            F32 => "deserializer.deserialize_f32()".to_string(),
323            F64 => "deserializer.deserialize_f64()".to_string(),
324            Char => "deserializer.deserialize_char()".to_string(),
325            Str => "deserializer.deserialize_str()".to_string(),
326            Bytes => "deserializer.deserialize_bytes()".to_string(),
327            _ => format!(
328                "TraitHelpers.deserialize_{}(deserializer)",
329                common::mangle_type(format),
330            ),
331        }
332    }
333
334    fn output_serialization_helper(&mut self, name: &str, format0: &Format) -> Result<()> {
335        use Format::*;
336
337        write!(
338            self.out,
339            "@Throws(com.novi.serde.SerializationError::class)\nfun serialize_{}(value: {}, serializer: com.novi.serde.Serializer) {{",
340            name,
341            self.quote_type(format0)
342        )?;
343        self.out.indent();
344        match format0 {
345            Option(format) => {
346                write!(
347                    self.out,
348                    r#"
349if (value == null) {{
350    serializer.serialize_option_tag(false)
351}} else {{
352    serializer.serialize_option_tag(true)
353    {}
354}}
355"#,
356                    self.quote_serialize_value("value", format)
357                )?;
358            }
359
360            Seq(format) => {
361                write!(
362                    self.out,
363                    r#"
364serializer.serialize_len(value.size.toLong())
365for (item in value) {{
366    {}
367}}
368"#,
369                    self.quote_serialize_value("item", format)
370                )?;
371            }
372
373            Map { key, value } => {
374                write!(
375                    self.out,
376                    r#"
377serializer.serialize_len(value.size.toLong())
378val offsets = IntArray(value.size)
379var count = 0
380for (entry in value.entries) {{
381    offsets[count++] = serializer.get_buffer_offset()
382    {}
383    {}
384}}
385serializer.sort_map_entries(offsets)
386"#,
387                    self.quote_serialize_value("entry.key", key),
388                    self.quote_serialize_value("entry.value", value)
389                )?;
390            }
391
392            Tuple(formats) => {
393                writeln!(self.out)?;
394                for (index, format) in formats.iter().enumerate() {
395                    let expr = match formats.len() {
396                        2 => match index {
397                            0 => "value.first".to_string(),
398                            1 => "value.second".to_string(),
399                            _ => unreachable!(),
400                        },
401                        3 => match index {
402                            0 => "value.first".to_string(),
403                            1 => "value.second".to_string(),
404                            2 => "value.third".to_string(),
405                            _ => unreachable!(),
406                        },
407                        _ => format!("value.field{index}"),
408                    };
409                    writeln!(self.out, "{}", self.quote_serialize_value(&expr, format))?;
410                }
411            }
412
413            TupleArray { content, size } => {
414                write!(
415                    self.out,
416                    r#"
417if (value.size != {0}) {{
418    throw IllegalArgumentException("Invalid length for fixed-size array: " + value.size + " instead of " + {0})
419}}
420for (item in value) {{
421    {1}
422}}
423"#,
424                    size,
425                    self.quote_serialize_value("item", content),
426                )?;
427            }
428
429            _ => panic!("unexpected case"),
430        }
431        self.out.unindent();
432        writeln!(self.out, "}}\n")
433    }
434
435    fn output_deserialization_helper(&mut self, name: &str, format0: &Format) -> Result<()> {
436        use Format::*;
437
438        write!(
439            self.out,
440            "@Throws(com.novi.serde.DeserializationError::class)\nfun deserialize_{}(deserializer: com.novi.serde.Deserializer): {} {{",
441            name,
442            self.quote_type(format0),
443        )?;
444        self.out.indent();
445        match format0 {
446            Option(format) => {
447                write!(
448                    self.out,
449                    r#"
450val tag = deserializer.deserialize_option_tag()
451return if (!tag) {{
452    null
453}} else {{
454    {}
455}}
456"#,
457                    self.quote_deserialize(format),
458                )?;
459            }
460
461            Seq(format) => {
462                write!(
463                    self.out,
464                    r#"
465val length = deserializer.deserialize_len()
466val obj = ArrayList<{0}>(length.toInt())
467var i = 0L
468while (i < length) {{
469    obj.add({1})
470    i += 1
471}}
472return obj
473"#,
474                    self.quote_type(format),
475                    self.quote_deserialize(format)
476                )?;
477            }
478
479            Map { key, value } => {
480                write!(
481                    self.out,
482                    r#"
483val length = deserializer.deserialize_len()
484val obj = HashMap<{0}, {1}>()
485var previousKeyStart = 0
486var previousKeyEnd = 0
487var i = 0L
488while (i < length) {{
489    val keyStart = deserializer.get_buffer_offset()
490    val key = {2}
491    val keyEnd = deserializer.get_buffer_offset()
492    if (i > 0) {{
493        deserializer.check_that_key_slices_are_increasing(
494            com.novi.serde.Slice(previousKeyStart, previousKeyEnd),
495            com.novi.serde.Slice(keyStart, keyEnd))
496    }}
497    previousKeyStart = keyStart
498    previousKeyEnd = keyEnd
499    val value = {3}
500    obj[key] = value
501    i += 1
502}}
503return obj
504"#,
505                    self.quote_type(key),
506                    self.quote_type(value),
507                    self.quote_deserialize(key),
508                    self.quote_deserialize(value),
509                )?;
510            }
511
512            Tuple(formats) => {
513                let constructor = match formats.len() {
514                    2 => "Pair".to_string(),
515                    3 => "Triple".to_string(),
516                    _ => self.quote_type(format0),
517                };
518                write!(
519                    self.out,
520                    r#"
521return {0}({1}
522)
523"#,
524                    constructor,
525                    formats
526                        .iter()
527                        .map(|f| format!("\n    {}", self.quote_deserialize(f)))
528                        .collect::<Vec<_>>()
529                        .join(",")
530                )?;
531            }
532
533            TupleArray { content, size } => {
534                write!(
535                    self.out,
536                    r#"
537val obj = ArrayList<{0}>({1})
538for (i in 0 until {1}) {{
539    obj.add({2})
540}}
541return obj
542"#,
543                    self.quote_type(content),
544                    size,
545                    self.quote_deserialize(content)
546                )?;
547            }
548
549            _ => panic!("unexpected case"),
550        }
551        self.out.unindent();
552        writeln!(self.out, "}}\n")
553    }
554
555    fn output_variant(
556        &mut self,
557        base: &str,
558        index: u32,
559        name: &str,
560        variant: &VariantFormat,
561    ) -> Result<()> {
562        use VariantFormat::*;
563        let fields = match variant {
564            Unit => Vec::new(),
565            NewType(format) => vec![Named {
566                name: "value".to_string(),
567                value: format.as_ref().clone(),
568            }],
569            Tuple(formats) => formats
570                .iter()
571                .enumerate()
572                .map(|(i, f)| Named {
573                    name: format!("field{i}"),
574                    value: f.clone(),
575                })
576                .collect(),
577            Struct(fields) => fields.clone(),
578            Variable(_) => panic!("incorrect value"),
579        };
580        self.output_struct_or_variant_container(Some(base), Some(index), name, &fields)
581    }
582
583    fn output_variants(
584        &mut self,
585        base: &str,
586        variants: &BTreeMap<u32, Named<VariantFormat>>,
587    ) -> Result<()> {
588        for (index, variant) in variants {
589            self.output_variant(base, *index, &variant.name, &variant.value)?;
590        }
591        Ok(())
592    }
593
594    fn output_fields_in_constructor(
595        &mut self,
596        class_name: &str,
597        fields: &[Named<Format>],
598    ) -> Result<()> {
599        self.out.indent();
600        let mut base_path = self.current_namespace.clone();
601        base_path.push(class_name.to_string());
602        for (index, field) in fields.iter().enumerate() {
603            let mut path = base_path.clone();
604            path.push(field.name.to_string());
605            if let Some(doc) = self.generator.config.comments.get(&path) {
606                let text = textwrap::indent(doc, "// ").replace("\n\n", "\n//\n");
607                write!(self.out, "{text}")?;
608            }
609            let separator = if index + 1 == fields.len() { "" } else { "," };
610            writeln!(
611                self.out,
612                "val {}: {}{}",
613                field.name,
614                self.quote_type(&field.value),
615                separator
616            )?;
617        }
618        self.out.unindent();
619        Ok(())
620    }
621
622    fn output_struct_or_variant_container(
623        &mut self,
624        variant_base: Option<&str>,
625        variant_index: Option<u32>,
626        name: &str,
627        fields: &[Named<Format>],
628    ) -> Result<()> {
629        writeln!(self.out)?;
630        self.output_comment(name)?;
631        match (variant_base, fields.is_empty()) {
632            (Some(base), true) => {
633                writeln!(self.out, "object {name} : {base}() {{")?;
634            }
635            (Some(base), false) => {
636                writeln!(self.out, "data class {name}(")?;
637                self.output_fields_in_constructor(name, fields)?;
638                writeln!(self.out, ") : {base}() {{")?;
639            }
640            (None, true) => {
641                writeln!(self.out, "class {name} {{")?;
642            }
643            (None, false) => {
644                writeln!(self.out, "data class {name}(")?;
645                self.output_fields_in_constructor(name, fields)?;
646                writeln!(self.out, ") {{")?;
647            }
648        }
649        self.enter_class(name);
650
651        if self.generator.config.serialization {
652            let prefix = if variant_index.is_some() {
653                "override "
654            } else {
655                ""
656            };
657            writeln!(
658                self.out,
659                "\n@Throws(com.novi.serde.SerializationError::class)\n{prefix}fun serialize(serializer: com.novi.serde.Serializer) {{"
660            )?;
661            self.out.indent();
662            writeln!(self.out, "serializer.increase_container_depth()")?;
663            if let Some(index) = variant_index {
664                writeln!(self.out, "serializer.serialize_variant_index({index})")?;
665            }
666            for field in fields {
667                writeln!(
668                    self.out,
669                    "{}",
670                    self.quote_serialize_value(&format!("this.{}", field.name), &field.value)
671                )?;
672            }
673            writeln!(self.out, "serializer.decrease_container_depth()")?;
674            self.out.unindent();
675            writeln!(self.out, "}}")?;
676
677            if variant_index.is_none() {
678                for encoding in &self.generator.config.encodings {
679                    self.output_class_serialize_for_encoding(*encoding)?;
680                }
681            }
682        }
683
684        if self.generator.config.serialization {
685            if variant_index.is_some() {
686                if fields.is_empty() {
687                    writeln!(
688                        self.out,
689                        "\n@Throws(com.novi.serde.DeserializationError::class)\nfun load(deserializer: com.novi.serde.Deserializer): {name} {{"
690                    )?;
691                    self.out.indent();
692                    writeln!(self.out, "deserializer.increase_container_depth()")?;
693                    writeln!(self.out, "deserializer.decrease_container_depth()")?;
694                    writeln!(self.out, "return {name}")?;
695                    self.out.unindent();
696                    writeln!(self.out, "}}")?;
697                } else {
698                    writeln!(self.out, "\ncompanion object {{")?;
699                    self.out.indent();
700                    writeln!(
701                        self.out,
702                        "@Throws(com.novi.serde.DeserializationError::class)\nfun load(deserializer: com.novi.serde.Deserializer): {name} {{"
703                    )?;
704                    self.out.indent();
705                    writeln!(self.out, "deserializer.increase_container_depth()")?;
706                    for field in fields {
707                        writeln!(
708                            self.out,
709                            "val {} = {}",
710                            field.name,
711                            self.quote_deserialize(&field.value)
712                        )?;
713                    }
714                    writeln!(self.out, "deserializer.decrease_container_depth()")?;
715                    let result = format!(
716                        "{}({})",
717                        name,
718                        fields
719                            .iter()
720                            .map(|f| f.name.to_string())
721                            .collect::<Vec<_>>()
722                            .join(", ")
723                    );
724                    writeln!(self.out, "return {result}")?;
725                    self.out.unindent();
726                    writeln!(self.out, "}}")?;
727                    self.out.unindent();
728                    writeln!(self.out, "}}")?;
729                }
730            } else {
731                writeln!(self.out, "\ncompanion object {{")?;
732                self.out.indent();
733                writeln!(
734                    self.out,
735                    "@Throws(com.novi.serde.DeserializationError::class)\nfun deserialize(deserializer: com.novi.serde.Deserializer): {name} {{"
736                )?;
737                self.out.indent();
738                writeln!(self.out, "deserializer.increase_container_depth()")?;
739                for field in fields {
740                    writeln!(
741                        self.out,
742                        "val {} = {}",
743                        field.name,
744                        self.quote_deserialize(&field.value)
745                    )?;
746                }
747                writeln!(self.out, "deserializer.decrease_container_depth()")?;
748                let result = if fields.is_empty() {
749                    format!("{name}()")
750                } else {
751                    format!(
752                        "{}({})",
753                        name,
754                        fields
755                            .iter()
756                            .map(|f| f.name.to_string())
757                            .collect::<Vec<_>>()
758                            .join(", ")
759                    )
760                };
761                writeln!(self.out, "return {result}")?;
762                self.out.unindent();
763                writeln!(self.out, "}}")?;
764
765                for encoding in &self.generator.config.encodings {
766                    self.output_class_deserialize_for_encoding(name, *encoding)?;
767                }
768                self.out.unindent();
769                writeln!(self.out, "}}")?;
770            }
771        }
772
773        if variant_base.is_none() && fields.is_empty() {
774            writeln!(
775                self.out,
776                r#"
777override fun equals(other: Any?): Boolean {{
778    return other is {name}
779}}
780
781override fun hashCode(): Int {{
782    return 7
783}}"#
784            )?;
785        }
786
787        self.output_custom_code()?;
788        self.leave_class();
789        writeln!(self.out, "}}")
790    }
791
792    fn output_enum_container(
793        &mut self,
794        name: &str,
795        variants: &BTreeMap<u32, Named<VariantFormat>>,
796    ) -> Result<()> {
797        writeln!(self.out)?;
798        self.output_comment(name)?;
799        writeln!(self.out, "sealed class {name} {{")?;
800        self.enter_class(name);
801        if self.generator.config.serialization {
802            writeln!(
803                self.out,
804                "@Throws(com.novi.serde.SerializationError::class)\nabstract fun serialize(serializer: com.novi.serde.Serializer)"
805            )?;
806            writeln!(self.out, "\ncompanion object {{")?;
807            self.out.indent();
808            writeln!(
809                self.out,
810                "@Throws(com.novi.serde.DeserializationError::class)\nfun deserialize(deserializer: com.novi.serde.Deserializer): {name} {{"
811            )?;
812            self.out.indent();
813            writeln!(
814                self.out,
815                "val index = deserializer.deserialize_variant_index()"
816            )?;
817            writeln!(self.out, "return when (index) {{")?;
818            self.out.indent();
819            for (index, variant) in variants {
820                writeln!(self.out, "{} -> {}.load(deserializer)", index, variant.name,)?;
821            }
822            writeln!(
823                self.out,
824                "else -> throw com.novi.serde.DeserializationError(\"Unknown variant index for {name}: \" + index)"
825            )?;
826            self.out.unindent();
827            writeln!(self.out, "}}")?;
828            self.out.unindent();
829            writeln!(self.out, "}}")?;
830            for encoding in &self.generator.config.encodings {
831                self.output_class_deserialize_for_encoding(name, *encoding)?;
832            }
833            self.out.unindent();
834            writeln!(self.out, "}}")?;
835
836            for encoding in &self.generator.config.encodings {
837                self.output_class_serialize_for_encoding(*encoding)?;
838            }
839        }
840
841        self.output_variants(name, variants)?;
842        self.output_custom_code()?;
843        self.leave_class();
844        writeln!(self.out, "}}\n")
845    }
846
847    fn output_class_serialize_for_encoding(&mut self, encoding: Encoding) -> Result<()> {
848        writeln!(
849            self.out,
850            r#"
851@Throws(com.novi.serde.SerializationError::class)
852fun {0}Serialize(): ByteArray {{
853    val serializer = com.novi.{0}.{1}Serializer()
854    serialize(serializer)
855    return serializer.get_bytes()
856}}"#,
857            encoding.name(),
858            encoding.name().to_camel_case()
859        )
860    }
861
862    fn output_class_deserialize_for_encoding(
863        &mut self,
864        name: &str,
865        encoding: Encoding,
866    ) -> Result<()> {
867        writeln!(
868            self.out,
869            r#"
870@Throws(com.novi.serde.DeserializationError::class)
871fun {1}Deserialize(input: ByteArray): {0} {{
872    val deserializer = com.novi.{1}.{2}Deserializer(input)
873    val value = deserialize(deserializer)
874    if (deserializer.get_buffer_offset() < input.size) {{
875        throw com.novi.serde.DeserializationError("Some input bytes were not read")
876    }}
877    return value
878}}"#,
879            name,
880            encoding.name(),
881            encoding.name().to_camel_case()
882        )
883    }
884
885    fn output_container(&mut self, name: &str, format: &ContainerFormat) -> Result<()> {
886        use ContainerFormat::*;
887        let fields = match format {
888            UnitStruct => Vec::new(),
889            NewTypeStruct(format) => vec![Named {
890                name: "value".to_string(),
891                value: format.as_ref().clone(),
892            }],
893            TupleStruct(formats) => formats
894                .iter()
895                .enumerate()
896                .map(|(i, f)| Named {
897                    name: format!("field{i}"),
898                    value: f.clone(),
899                })
900                .collect::<Vec<_>>(),
901            Struct(fields) => fields.clone(),
902            Enum(variants) => {
903                self.output_enum_container(name, variants)?;
904                return Ok(());
905            }
906        };
907        self.output_struct_or_variant_container(None, None, name, &fields)
908    }
909}
910
911/// Installer for generated source files in Kotlin.
912pub struct Installer {
913    install_dir: PathBuf,
914}
915
916impl Installer {
917    pub fn new(install_dir: PathBuf) -> Self {
918        Installer { install_dir }
919    }
920
921    fn install_runtime(
922        &self,
923        source_dir: include_dir::Dir,
924        path: &str,
925    ) -> std::result::Result<(), Box<dyn std::error::Error>> {
926        let dir_path = self.install_dir.join(path);
927        std::fs::create_dir_all(&dir_path)?;
928        for entry in source_dir.files() {
929            let mut file = std::fs::File::create(dir_path.join(entry.path()))?;
930            file.write_all(entry.contents())?;
931        }
932        Ok(())
933    }
934}
935
936impl crate::SourceInstaller for Installer {
937    type Error = Box<dyn std::error::Error>;
938
939    fn install_module(
940        &self,
941        config: &CodeGeneratorConfig,
942        registry: &Registry,
943    ) -> std::result::Result<(), Self::Error> {
944        let generator = CodeGenerator::new(config);
945        generator.write_source_files(self.install_dir.clone(), registry)?;
946        Ok(())
947    }
948
949    fn install_serde_runtime(&self) -> std::result::Result<(), Self::Error> {
950        self.install_runtime(
951            include_directory!("runtime/kotlin/com/novi/serde"),
952            "com/novi/serde",
953        )
954    }
955
956    fn install_bincode_runtime(&self) -> std::result::Result<(), Self::Error> {
957        self.install_runtime(
958            include_directory!("runtime/kotlin/com/novi/bincode"),
959            "com/novi/bincode",
960        )
961    }
962
963    fn install_bcs_runtime(&self) -> std::result::Result<(), Self::Error> {
964        self.install_runtime(
965            include_directory!("runtime/kotlin/com/novi/bcs"),
966            "com/novi/bcs",
967        )
968    }
969}