Skip to main content

serde_generate/
swift.rs

1// Copyright (c) Facebook, Inc. and its affiliates
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
4#![allow(dead_code)]
5
6use crate::{
7    common,
8    indent::{IndentConfig, IndentedWriter},
9    CodeGeneratorConfig, Encoding,
10};
11use heck::{CamelCase, MixedCase};
12use include_dir::include_dir as include_directory;
13use serde_reflection::{ContainerFormat, Format, FormatHolder, Named, Registry, VariantFormat};
14use std::{
15    collections::{BTreeMap, HashMap},
16    io::{Result, Write},
17    path::PathBuf,
18};
19
20/// Main configuration object for code-generation in Swift.
21pub struct CodeGenerator<'a> {
22    /// Language-independent configuration.
23    config: &'a CodeGeneratorConfig,
24    /// Mapping from external type names to fully-qualified class names (e.g. "MyClass" -> "com.my_org.my_package.MyClass").
25    /// Derived from `config.external_definitions`.
26    external_qualified_names: HashMap<String, String>,
27}
28
29/// Shared state for the code generation of a Swift source file.
30struct SwiftEmitter<'a, T> {
31    /// Writer.
32    out: IndentedWriter<T>,
33    /// Generator.
34    generator: &'a CodeGenerator<'a>,
35    /// Current namespace (e.g. vec!["Package", "MyClass"])
36    current_namespace: Vec<String>,
37}
38
39impl<'a> CodeGenerator<'a> {
40    /// Create a Swift code generator for the given config.
41    pub fn new(config: &'a CodeGeneratorConfig) -> Self {
42        if config.enums.c_style {
43            panic!("Swift does not support generating c-style enums");
44        }
45        let mut external_qualified_names = HashMap::new();
46        for (namespace, names) in &config.external_definitions {
47            let package_name = {
48                let path = namespace.rsplitn(2, '/').collect::<Vec<_>>();
49                if path.len() <= 1 {
50                    namespace
51                } else {
52                    path[0]
53                }
54            };
55            for name in names {
56                external_qualified_names.insert(name.to_string(), format!("{package_name}.{name}"));
57            }
58        }
59        Self {
60            config,
61            external_qualified_names,
62        }
63    }
64
65    /// Output class definitions for `registry`.
66    pub fn output(&self, out: &mut dyn Write, registry: &Registry) -> Result<()> {
67        let current_namespace = self
68            .config
69            .module_name
70            .split('.')
71            .map(String::from)
72            .collect::<Vec<_>>();
73
74        let mut emitter = SwiftEmitter {
75            out: IndentedWriter::new(out, IndentConfig::Space(4)),
76            generator: self,
77            current_namespace,
78        };
79
80        emitter.output_preamble()?;
81
82        for (name, format) in registry {
83            emitter.output_container(name, format)?;
84        }
85
86        if self.config.serialization {
87            writeln!(emitter.out)?;
88            emitter.output_trait_helpers(registry)?;
89        }
90
91        Ok(())
92    }
93}
94
95impl<'a, T> SwiftEmitter<'a, T>
96where
97    T: Write,
98{
99    fn output_preamble(&mut self) -> Result<()> {
100        writeln!(self.out, "import Serde\n")?;
101        Ok(())
102    }
103
104    /// Compute a reference to the registry type `name`.
105    fn quote_qualified_name(&self, name: &str) -> String {
106        self.generator
107            .external_qualified_names
108            .get(name)
109            .cloned()
110            .unwrap_or_else(|| format!("{}.{}", self.generator.config.module_name, name))
111    }
112
113    fn output_comment(&mut self, name: &str) -> std::io::Result<()> {
114        let mut path = self.current_namespace.clone();
115        path.push(name.to_string());
116        if let Some(doc) = self.generator.config.comments.get(&path) {
117            let text = textwrap::indent(doc, "// ").replace("\n\n", "\n//\n");
118            write!(self.out, "{text}")?;
119        }
120        Ok(())
121    }
122
123    fn output_custom_code(&mut self, name: &str) -> std::io::Result<()> {
124        let mut path = self.current_namespace.clone();
125        path.push(name.to_string());
126        if let Some(code) = self.generator.config.custom_code.get(&path) {
127            writeln!(self.out, "\n{code}")?;
128        }
129        Ok(())
130    }
131
132    fn quote_type(&self, format: &Format) -> String {
133        use Format::*;
134        match format {
135            TypeName(x) => self.quote_qualified_name(x),
136            Unit => "Unit".into(),
137            Bool => "Bool".into(),
138            I8 => "Int8".into(),
139            I16 => "Int16".into(),
140            I32 => "Int32".into(),
141            I64 => "Int64".into(),
142            I128 => "Int128".into(),
143            U8 => "UInt8".into(),
144            U16 => "UInt16".into(),
145            U32 => "UInt32".into(),
146            U64 => "UInt64".into(),
147            U128 => "UInt128".into(),
148            F32 => "Float".into(),
149            F64 => "Double".into(),
150            Char => "Character".into(),
151            Str => "String".into(),
152            Bytes => "[UInt8]".into(),
153
154            Option(format) => format!("{}?", self.quote_type(format)),
155            Seq(format) => format!("[{}]", self.quote_type(format)),
156            Map { key, value } => {
157                format!("[{}: {}]", self.quote_type(key), self.quote_type(value))
158            }
159            // Sadly, Swift tuples are not hashable.
160            Tuple(formats) => format!("Tuple{}<{}>", formats.len(), self.quote_types(formats)),
161            TupleArray { content, size: _ } => {
162                // Sadly, there are no fixed-size arrays in Swift.
163                format!("[{}]", self.quote_type(content))
164            }
165
166            Variable(_) => panic!("unexpected value"),
167        }
168    }
169
170    fn quote_types<'b, I>(&'b self, formats: I) -> String
171    where
172        I: IntoIterator<Item = &'b Format>,
173    {
174        formats
175            .into_iter()
176            .map(|format| self.quote_type(format))
177            .collect::<Vec<_>>()
178            .join(", ")
179    }
180
181    fn enter_class(&mut self, name: &str) {
182        self.out.indent();
183        self.current_namespace.push(name.to_string());
184    }
185
186    fn leave_class(&mut self) {
187        self.out.unindent();
188        self.current_namespace.pop();
189    }
190
191    fn output_trait_helpers(&mut self, registry: &Registry) -> Result<()> {
192        let mut subtypes = BTreeMap::new();
193        for format in registry.values() {
194            format
195                .visit(&mut |f| {
196                    if Self::needs_helper(f) {
197                        subtypes.insert(common::mangle_type(f), f.clone());
198                    }
199                    Ok(())
200                })
201                .unwrap();
202        }
203        for (mangled_name, subtype) in &subtypes {
204            self.output_serialization_helper(mangled_name, subtype)?;
205            self.output_deserialization_helper(mangled_name, subtype)?;
206        }
207        Ok(())
208    }
209
210    fn needs_helper(format: &Format) -> bool {
211        use Format::*;
212        matches!(
213            format,
214            Option(_) | Seq(_) | Map { .. } | Tuple(_) | TupleArray { .. }
215        )
216    }
217
218    fn quote_serialize_value(&self, value: &str, format: &Format) -> String {
219        use Format::*;
220        match format {
221            TypeName(_) => format!("try {value}.serialize(serializer: serializer)"),
222            Unit => format!("try serializer.serialize_unit(value: {value})"),
223            Bool => format!("try serializer.serialize_bool(value: {value})"),
224            I8 => format!("try serializer.serialize_i8(value: {value})"),
225            I16 => format!("try serializer.serialize_i16(value: {value})"),
226            I32 => format!("try serializer.serialize_i32(value: {value})"),
227            I64 => format!("try serializer.serialize_i64(value: {value})"),
228            I128 => format!("try serializer.serialize_i128(value: {value})"),
229            U8 => format!("try serializer.serialize_u8(value: {value})"),
230            U16 => format!("try serializer.serialize_u16(value: {value})"),
231            U32 => format!("try serializer.serialize_u32(value: {value})"),
232            U64 => format!("try serializer.serialize_u64(value: {value})"),
233            U128 => format!("try serializer.serialize_u128(value: {value})"),
234            F32 => format!("try serializer.serialize_f32(value: {value})"),
235            F64 => format!("try serializer.serialize_f64(value: {value})"),
236            Char => format!("try serializer.serialize_char(value: {value})"),
237            Str => format!("try serializer.serialize_str(value: {value})"),
238            Bytes => format!("try serializer.serialize_bytes(value: {value})"),
239            _ => format!(
240                "try serialize_{}(value: {}, serializer: serializer)",
241                common::mangle_type(format),
242                value
243            ),
244        }
245    }
246
247    fn quote_deserialize(&self, format: &Format) -> String {
248        use Format::*;
249        match format {
250            TypeName(name) => format!(
251                "try {}.deserialize(deserializer: deserializer)",
252                self.quote_qualified_name(name)
253            ),
254            Unit => "try deserializer.deserialize_unit()".to_string(),
255            Bool => "try deserializer.deserialize_bool()".to_string(),
256            I8 => "try deserializer.deserialize_i8()".to_string(),
257            I16 => "try deserializer.deserialize_i16()".to_string(),
258            I32 => "try deserializer.deserialize_i32()".to_string(),
259            I64 => "try deserializer.deserialize_i64()".to_string(),
260            I128 => "try deserializer.deserialize_i128()".to_string(),
261            U8 => "try deserializer.deserialize_u8()".to_string(),
262            U16 => "try deserializer.deserialize_u16()".to_string(),
263            U32 => "try deserializer.deserialize_u32()".to_string(),
264            U64 => "try deserializer.deserialize_u64()".to_string(),
265            U128 => "try deserializer.deserialize_u128()".to_string(),
266            F32 => "try deserializer.deserialize_f32()".to_string(),
267            F64 => "try deserializer.deserialize_f64()".to_string(),
268            Char => "try deserializer.deserialize_char()".to_string(),
269            Str => "try deserializer.deserialize_str()".to_string(),
270            Bytes => "try deserializer.deserialize_bytes()".to_string(),
271            _ => format!(
272                "try deserialize_{}(deserializer: deserializer)",
273                common::mangle_type(format)
274            ),
275        }
276    }
277
278    // TODO: Should this be an extension for Serializer?
279    fn output_serialization_helper(&mut self, name: &str, format0: &Format) -> Result<()> {
280        use Format::*;
281
282        write!(
283            self.out,
284            "func serialize_{}<S: Serializer>(value: {}, serializer: S) throws {{",
285            name,
286            self.quote_type(format0)
287        )?;
288        self.out.indent();
289        match format0 {
290            Option(format) => {
291                write!(
292                    self.out,
293                    r#"
294if let value = value {{
295    try serializer.serialize_option_tag(value: true)
296    {}
297}} else {{
298    try serializer.serialize_option_tag(value: false)
299}}
300"#,
301                    self.quote_serialize_value("value", format)
302                )?;
303            }
304
305            Seq(format) => {
306                write!(
307                    self.out,
308                    r#"
309try serializer.serialize_len(value: value.count)
310for item in value {{
311    {}
312}}
313"#,
314                    self.quote_serialize_value("item", format)
315                )?;
316            }
317
318            Map { key, value } => {
319                write!(
320                    self.out,
321                    r#"
322try serializer.serialize_len(value: value.count)
323var offsets : [Int]  = []
324for (key, value) in value {{
325    offsets.append(serializer.get_buffer_offset())
326    {}
327    {}
328}}
329serializer.sort_map_entries(offsets: offsets)
330"#,
331                    self.quote_serialize_value("key", key),
332                    self.quote_serialize_value("value", value)
333                )?;
334            }
335
336            Tuple(formats) => {
337                writeln!(self.out)?;
338                for (index, format) in formats.iter().enumerate() {
339                    let expr = format!("value.field{index}");
340                    writeln!(self.out, "{}", self.quote_serialize_value(&expr, format))?;
341                }
342            }
343
344            TupleArray { content, size: _ } => {
345                write!(
346                    self.out,
347                    r#"
348for item in value {{
349    {}
350}}
351"#,
352                    self.quote_serialize_value("item", content),
353                )?;
354            }
355
356            _ => panic!("unexpected case"),
357        }
358        self.out.unindent();
359        writeln!(self.out, "}}\n")
360    }
361
362    fn output_deserialization_helper(&mut self, name: &str, format0: &Format) -> Result<()> {
363        use Format::*;
364
365        write!(
366            self.out,
367            "func deserialize_{}<D: Deserializer>(deserializer: D) throws -> {} {{",
368            name,
369            self.quote_type(format0),
370        )?;
371        self.out.indent();
372        match format0 {
373            Option(format) => {
374                write!(
375                    self.out,
376                    r#"
377let tag = try deserializer.deserialize_option_tag()
378if tag {{
379    return {}
380}} else {{
381    return nil
382}}
383"#,
384                    self.quote_deserialize(format),
385                )?;
386            }
387
388            Seq(format) => {
389                write!(
390                    self.out,
391                    r#"
392let length = try deserializer.deserialize_len()
393var obj : [{}] = []
394for _ in 0..<length {{
395    obj.append({})
396}}
397return obj
398"#,
399                    self.quote_type(format),
400                    self.quote_deserialize(format)
401                )?;
402            }
403
404            Map { key, value } => {
405                write!(
406                    self.out,
407                    r#"
408let length = try deserializer.deserialize_len()
409var obj : [{0}: {1}] = [:]
410var previous_slice = Slice(start: 0, end: 0)
411for i in 0..<length {{
412    var slice = Slice(start: 0, end: 0)
413    slice.start = deserializer.get_buffer_offset()
414    let key = {2}
415    slice.end = deserializer.get_buffer_offset()
416    if i > 0 {{
417        try deserializer.check_that_key_slices_are_increasing(key1: previous_slice, key2: slice)
418    }}
419    previous_slice = slice
420    obj[key] = {3}
421}}
422return obj
423"#,
424                    self.quote_type(key),
425                    self.quote_type(value),
426                    self.quote_deserialize(key),
427                    self.quote_deserialize(value),
428                )?;
429            }
430
431            Tuple(formats) => {
432                write!(
433                    self.out,
434                    r#"
435return Tuple{}.init({})
436"#,
437                    formats.len(),
438                    formats
439                        .iter()
440                        .map(|f| self.quote_deserialize(f))
441                        .collect::<Vec<_>>()
442                        .join(", ")
443                )?;
444            }
445
446            TupleArray { content, size } => {
447                write!(
448                    self.out,
449                    r#"
450var obj : [{}] = []
451for _ in 0..<{} {{
452    obj.append({})
453}}
454return obj
455"#,
456                    self.quote_type(content),
457                    size,
458                    self.quote_deserialize(content)
459                )?;
460            }
461
462            _ => panic!("unexpected case"),
463        }
464        self.out.unindent();
465        writeln!(self.out, "}}\n")
466    }
467
468    fn output_variant(&mut self, name: &str, variant: &VariantFormat) -> Result<()> {
469        use VariantFormat::*;
470        self.output_comment(name)?;
471        let name = common::lowercase_first_letter(name).to_mixed_case();
472        match variant {
473            Unit => {
474                writeln!(self.out, "case {name}")?;
475            }
476            NewType(format) => {
477                writeln!(self.out, "case {}({})", name, self.quote_type(format))?;
478            }
479            Tuple(formats) => {
480                if formats.is_empty() {
481                    writeln!(self.out, "case {name}")?;
482                } else {
483                    writeln!(self.out, "case {}({})", name, self.quote_types(formats))?;
484                }
485            }
486            Struct(fields) => {
487                if fields.is_empty() {
488                    writeln!(self.out, "case {name}")?;
489                } else {
490                    writeln!(
491                        self.out,
492                        "case {}({})",
493                        name,
494                        fields
495                            .iter()
496                            .map(|f| format!("{}: {}", f.name, self.quote_type(&f.value)))
497                            .collect::<Vec<_>>()
498                            .join(", ")
499                    )?;
500                }
501            }
502            Variable(_) => panic!("incorrect value"),
503        }
504        Ok(())
505    }
506
507    fn variant_fields(variant: &VariantFormat) -> Vec<Named<Format>> {
508        use VariantFormat::*;
509        match variant {
510            Unit => Vec::new(),
511            NewType(format) => vec![Named {
512                name: "x".to_string(),
513                value: format.as_ref().clone(),
514            }],
515            Tuple(formats) => formats
516                .clone()
517                .into_iter()
518                .enumerate()
519                .map(|(i, f)| Named {
520                    name: format!("x{i}"),
521                    value: f,
522                })
523                .collect(),
524            Struct(fields) => fields.clone(),
525            Variable(_) => panic!("incorrect value"),
526        }
527    }
528
529    fn output_struct_container(&mut self, name: &str, fields: &[Named<Format>]) -> Result<()> {
530        // Struct
531        writeln!(self.out)?;
532        self.output_comment(name)?;
533        writeln!(self.out, "public struct {name}: Hashable {{")?;
534        self.enter_class(name);
535        for field in fields {
536            self.output_comment(&field.name)?;
537            writeln!(
538                self.out,
539                "@Indirect public var {}: {}",
540                field.name,
541                self.quote_type(&field.value)
542            )?;
543        }
544        // Public constructor
545        writeln!(
546            self.out,
547            "\npublic init({}) {{",
548            fields
549                .iter()
550                .map(|f| format!("{}: {}", &f.name, self.quote_type(&f.value)))
551                .collect::<Vec<_>>()
552                .join(", ")
553        )?;
554        self.out.indent();
555        for field in fields {
556            writeln!(self.out, "self.{0} = {0}", &field.name)?;
557        }
558        self.out.unindent();
559        writeln!(self.out, "}}")?;
560        // Serialize
561        if self.generator.config.serialization {
562            writeln!(
563                self.out,
564                "\npublic func serialize<S: Serializer>(serializer: S) throws {{",
565            )?;
566            self.out.indent();
567            writeln!(self.out, "try serializer.increase_container_depth()")?;
568            for field in fields {
569                writeln!(
570                    self.out,
571                    "{}",
572                    self.quote_serialize_value(&format!("self.{}", &field.name), &field.value)
573                )?;
574            }
575            writeln!(self.out, "try serializer.decrease_container_depth()")?;
576            self.out.unindent();
577            writeln!(self.out, "}}")?;
578
579            for encoding in &self.generator.config.encodings {
580                self.output_struct_serialize_for_encoding(*encoding)?;
581            }
582        }
583        // Deserialize
584        if self.generator.config.serialization {
585            writeln!(
586                self.out,
587                "\npublic static func deserialize<D: Deserializer>(deserializer: D) throws -> {name} {{",
588            )?;
589            self.out.indent();
590            writeln!(self.out, "try deserializer.increase_container_depth()")?;
591            for field in fields {
592                writeln!(
593                    self.out,
594                    "let {} = {}",
595                    field.name,
596                    self.quote_deserialize(&field.value)
597                )?;
598            }
599            writeln!(self.out, "try deserializer.decrease_container_depth()")?;
600            writeln!(
601                self.out,
602                "return {}.init({})",
603                name,
604                fields
605                    .iter()
606                    .map(|f| format!("{0}: {0}", &f.name))
607                    .collect::<Vec<_>>()
608                    .join(", ")
609            )?;
610            self.out.unindent();
611            writeln!(self.out, "}}")?;
612
613            for encoding in &self.generator.config.encodings {
614                self.output_struct_deserialize_for_encoding(name, *encoding)?;
615            }
616        }
617        // Custom code
618        self.output_custom_code(name)?;
619        self.leave_class();
620        writeln!(self.out, "}}")?;
621        Ok(())
622    }
623
624    fn output_struct_serialize_for_encoding(&mut self, encoding: Encoding) -> Result<()> {
625        writeln!(
626            self.out,
627            r#"
628public func {0}Serialize() throws -> [UInt8] {{
629    let serializer = {1}Serializer.init();
630    try self.serialize(serializer: serializer)
631    return serializer.get_bytes()
632}}"#,
633            encoding.name(),
634            encoding.name().to_camel_case()
635        )
636    }
637
638    fn output_struct_deserialize_for_encoding(
639        &mut self,
640        name: &str,
641        encoding: Encoding,
642    ) -> Result<()> {
643        writeln!(
644            self.out,
645            r#"
646public static func {1}Deserialize(input: [UInt8]) throws -> {0} {{
647    let deserializer = {2}Deserializer.init(input: input);
648    let obj = try deserialize(deserializer: deserializer)
649    if deserializer.get_buffer_offset() < input.count {{
650        throw DeserializationError.invalidInput(issue: "Some input bytes were not read")
651    }}
652    return obj
653}}"#,
654            name,
655            encoding.name(),
656            encoding.name().to_camel_case(),
657        )
658    }
659
660    fn output_enum_container(
661        &mut self,
662        name: &str,
663        variants: &BTreeMap<u32, Named<VariantFormat>>,
664    ) -> Result<()> {
665        writeln!(self.out)?;
666        self.output_comment(name)?;
667        writeln!(self.out, "indirect public enum {name}: Hashable {{")?;
668        self.current_namespace.push(name.to_string());
669        self.out.indent();
670        for variant in variants.values() {
671            self.output_variant(&variant.name, &variant.value)?;
672        }
673
674        // Serialize
675        if self.generator.config.serialization {
676            writeln!(
677                self.out,
678                "\npublic func serialize<S: Serializer>(serializer: S) throws {{",
679            )?;
680            self.out.indent();
681            writeln!(self.out, "try serializer.increase_container_depth()")?;
682            writeln!(self.out, "switch self {{")?;
683            for (index, variant) in variants {
684                let fields = Self::variant_fields(&variant.value);
685                let formatted_variant_name =
686                    common::lowercase_first_letter(&variant.name).to_mixed_case();
687                if fields.is_empty() {
688                    writeln!(self.out, "case .{formatted_variant_name}:")?;
689                } else {
690                    writeln!(
691                        self.out,
692                        "case .{}({}):",
693                        formatted_variant_name,
694                        fields
695                            .iter()
696                            .map(|f| format!("let {}", f.name))
697                            .collect::<Vec<_>>()
698                            .join(", ")
699                    )?;
700                }
701                self.out.indent();
702                writeln!(
703                    self.out,
704                    "try serializer.serialize_variant_index(value: {index})"
705                )?;
706                for field in fields {
707                    writeln!(
708                        self.out,
709                        "{}",
710                        self.quote_serialize_value(&field.name, &field.value)
711                    )?;
712                }
713                self.out.unindent();
714            }
715            writeln!(self.out, "}}")?;
716            writeln!(self.out, "try serializer.decrease_container_depth()")?;
717            self.out.unindent();
718            writeln!(self.out, "}}")?;
719
720            for encoding in &self.generator.config.encodings {
721                self.output_struct_serialize_for_encoding(*encoding)?;
722            }
723        }
724        // Deserialize
725        if self.generator.config.serialization {
726            write!(
727                self.out,
728                "\npublic static func deserialize<D: Deserializer>(deserializer: D) throws -> {name} {{"
729            )?;
730            self.out.indent();
731            writeln!(
732                self.out,
733                r#"
734let index = try deserializer.deserialize_variant_index()
735try deserializer.increase_container_depth()
736switch index {{"#,
737            )?;
738            for (index, variant) in variants {
739                writeln!(self.out, "case {index}:")?;
740                self.out.indent();
741                let formatted_variant_name =
742                    common::lowercase_first_letter(&variant.name).to_mixed_case();
743                let fields = Self::variant_fields(&variant.value);
744                if fields.is_empty() {
745                    writeln!(self.out, "try deserializer.decrease_container_depth()")?;
746                    writeln!(self.out, "return .{formatted_variant_name}")?;
747                    self.out.unindent();
748                    continue;
749                }
750                for field in &fields {
751                    writeln!(
752                        self.out,
753                        "let {} = {}",
754                        field.name,
755                        self.quote_deserialize(&field.value)
756                    )?;
757                }
758                writeln!(self.out, "try deserializer.decrease_container_depth()")?;
759                let init_values = match &variant.value {
760                    VariantFormat::Struct(_) => fields
761                        .iter()
762                        .map(|f| format!("{0}: {0}", f.name))
763                        .collect::<Vec<_>>()
764                        .join(", "),
765                    _ => fields
766                        .iter()
767                        .map(|f| f.name.to_string())
768                        .collect::<Vec<_>>()
769                        .join(", "),
770                };
771                writeln!(self.out, "return .{formatted_variant_name}({init_values})")?;
772                self.out.unindent();
773            }
774            writeln!(
775                self.out,
776                "default: throw DeserializationError.invalidInput(issue: \"Unknown variant index for {name}: \\(index)\")",
777            )?;
778            writeln!(self.out, "}}")?;
779            self.out.unindent();
780            writeln!(self.out, "}}")?;
781
782            for encoding in &self.generator.config.encodings {
783                self.output_struct_deserialize_for_encoding(name, *encoding)?;
784            }
785        }
786
787        self.current_namespace.pop();
788        // Custom code
789        self.output_custom_code(name)?;
790        self.out.unindent();
791        writeln!(self.out, "}}")?;
792        Ok(())
793    }
794
795    fn output_container(&mut self, name: &str, format: &ContainerFormat) -> Result<()> {
796        use ContainerFormat::*;
797        let fields = match format {
798            UnitStruct => Vec::new(),
799            NewTypeStruct(format) => vec![Named {
800                name: "value".to_string(),
801                value: format.as_ref().clone(),
802            }],
803            TupleStruct(formats) => formats
804                .iter()
805                .enumerate()
806                .map(|(i, f)| Named {
807                    name: format!("field{i}"),
808                    value: f.clone(),
809                })
810                .collect(),
811            Struct(fields) => fields
812                .iter()
813                .map(|f| Named {
814                    name: f.name.to_mixed_case(),
815                    value: f.value.clone(),
816                })
817                .collect(),
818            Enum(variants) => {
819                self.output_enum_container(name, variants)?;
820                return Ok(());
821            }
822        };
823        self.output_struct_container(name, &fields)
824    }
825}
826
/// Writes generated Swift sources and runtime files below a target directory.
pub struct Installer {
    /// Root directory under which all files are installed.
    install_dir: PathBuf,
}
831
832impl Installer {
833    pub fn new(install_dir: PathBuf) -> Self {
834        Installer { install_dir }
835    }
836
837    fn install_runtime(
838        &self,
839        source_dir: include_dir::Dir,
840        path: &str,
841    ) -> std::result::Result<(), Box<dyn std::error::Error>> {
842        let dir_path = self.install_dir.join(path);
843        std::fs::create_dir_all(&dir_path)?;
844        for entry in source_dir.files() {
845            let mut file = std::fs::File::create(dir_path.join(entry.path()))?;
846            file.write_all(entry.contents())?;
847        }
848        Ok(())
849    }
850}
851
852impl crate::SourceInstaller for Installer {
853    type Error = Box<dyn std::error::Error>;
854
855    fn install_module(
856        &self,
857        config: &CodeGeneratorConfig,
858        registry: &Registry,
859    ) -> std::result::Result<(), Self::Error> {
860        let dir_path = self.install_dir.join("Sources").join(&config.module_name);
861        std::fs::create_dir_all(&dir_path)?;
862        let source_path = dir_path.join(format!("{}.swift", config.module_name.to_camel_case()));
863        let mut file = std::fs::File::create(source_path)?;
864        let generator = CodeGenerator::new(config);
865        generator.output(&mut file, registry)?;
866        Ok(())
867    }
868
869    fn install_serde_runtime(&self) -> std::result::Result<(), Self::Error> {
870        self.install_runtime(
871            include_directory!("runtime/swift/Sources/Serde"),
872            "Sources/Serde",
873        )
874    }
875
876    fn install_bincode_runtime(&self) -> std::result::Result<(), Self::Error> {
877        // Ignored. Currently always installed with Serde.
878        Ok(())
879    }
880
881    fn install_bcs_runtime(&self) -> std::result::Result<(), Self::Error> {
882        // Ignored. Currently always installed with Serde.
883        Ok(())
884    }
885}