Skip to main content

specta_go/
go.rs

1use std::{borrow::Cow, path::Path};
2
3use specta::{
4    Format, Types,
5    datatype::{DataType, Fields, Reference},
6};
7
8use crate::{
9    Error,
10    primitives::{self, GoContext},
11};
12
/// Allows configuring the format of the final file.
///
/// Defaults to [`Layout::FlatFile`].
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub enum Layout {
    /// Flatten all types into a single file. (Idiomatic for Go packages)
    #[default]
    FlatFile,
    /// Produce a dedicated file for each type (Not recommended for Go)
    ///
    /// Note: [`Go::export_to`] currently rejects this layout and returns
    /// [`Error::UnableToExport`].
    Files,
}
22
/// Go language exporter.
///
/// Create one with [`Go::new`] (or [`Default`]) and configure it with the builder methods
/// [`Go::package_name`] and [`Go::header`]. Then call [`Go::export`] to get the generated
/// source as a string, or [`Go::export_to`] to write it to disk.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct Go {
    /// Content written before the generated Go package declaration.
    ///
    /// When empty it is omitted entirely; otherwise it is followed by a single newline.
    pub header: Cow<'static, str>,
    /// The output file layout.
    ///
    /// Only [`Layout::FlatFile`] is currently accepted by [`Go::export_to`].
    pub layout: Layout,
    /// Name emitted in the generated `package` clause (`bindings` by default).
    package_name: String,
}
33
34impl Default for Go {
35    fn default() -> Self {
36        Self {
37            header: Cow::Borrowed(""),
38            layout: Layout::FlatFile,
39            package_name: "bindings".into(),
40        }
41    }
42}
43
44impl Go {
45    /// Creates a Go exporter using the default configuration.
46    pub fn new() -> Self {
47        Default::default()
48    }
49
50    /// Sets the generated Go package name.
51    pub fn package_name(mut self, name: impl Into<String>) -> Self {
52        self.package_name = name.into();
53        self
54    }
55
56    /// Sets content written before the generated Go package declaration.
57    pub fn header(mut self, header: impl Into<Cow<'static, str>>) -> Self {
58        self.header = header.into();
59        self
60    }
61
62    /// Exports the provided types into a Go source file string.
63    pub fn export(&self, types: &Types, format: impl Format) -> Result<String, Error> {
64        let mut ctx = GoContext::default();
65        let mut body = String::new();
66
67        let exporter = self.clone();
68        let formatted_types = format_types(&exporter, types, &format)?;
69        let types = formatted_types.as_ref();
70
71        for ndt in types.into_sorted_iter() {
72            let type_def = primitives::export(&exporter, types, ndt, &mut ctx)?;
73            body.push_str(&type_def);
74            body.push('\n');
75        }
76
77        let mut out = String::new();
78        if !exporter.header.is_empty() {
79            out.push_str(&exporter.header);
80            out.push('\n');
81        }
82
83        out.push_str("package ");
84        out.push_str(&exporter.package_name);
85        out.push_str("\n\n");
86
87        if !ctx.imports.is_empty() {
88            out.push_str("import (\n");
89            let mut sorted: Vec<_> = ctx.imports.iter().collect();
90            sorted.sort();
91            for imp in sorted {
92                out.push_str(&format!("\t\"{}\"\n", imp));
93            }
94            out.push_str(")\n\n");
95        }
96
97        out.push_str(&body);
98        Ok(out)
99    }
100
101    /// Exports the provided types to a Go source file at the given path.
102    pub fn export_to(
103        &self,
104        path: impl AsRef<Path>,
105        types: &Types,
106        format: impl Format,
107    ) -> Result<(), Error> {
108        if self.layout == Layout::Files {
109            return Err(Error::UnableToExport(Layout::Files));
110        }
111
112        let content = self.export(types, format)?;
113        if let Some(parent) = path.as_ref().parent() {
114            std::fs::create_dir_all(parent)?;
115        }
116        std::fs::write(path, content)?;
117        Ok(())
118    }
119}
120
121fn format_types<'a>(
122    exporter: &Go,
123    types: &'a Types,
124    format: &dyn Format,
125) -> Result<Cow<'a, Types>, Error> {
126    let mapped_types = format
127        .map_types(types)
128        .map_err(|err| Error::format("type graph formatter failed", err))?;
129    Ok(Cow::Owned(
130        map_types_for_datatype_format(exporter, mapped_types.as_ref(), Some(format))?.into_owned(),
131    ))
132}
133
134fn map_datatype_format(
135    exporter: &Go,
136    format: Option<&dyn Format>,
137    types: &Types,
138    dt: &DataType,
139) -> Result<DataType, Error> {
140    let Some(format) = format else {
141        return Ok(dt.clone());
142    };
143
144    let mapped = format
145        .map_type(types, dt)
146        .map_err(|err| Error::format("datatype formatter failed", err))?;
147
148    match mapped {
149        Cow::Borrowed(dt) => {
150            map_datatype_format_children(exporter, Some(format), types, dt.clone())
151        }
152        Cow::Owned(dt) => map_datatype_format_children(exporter, Some(format), types, dt),
153    }
154}
155
156fn map_datatype_format_children(
157    exporter: &Go,
158    format: Option<&dyn Format>,
159    types: &Types,
160    mut dt: DataType,
161) -> Result<DataType, Error> {
162    match &mut dt {
163        DataType::Primitive(_) => {}
164        DataType::List(list) => {
165            *list.ty = map_datatype_format(exporter, format, types, &list.ty)?;
166        }
167        DataType::Map(map) => {
168            let key = map_datatype_format(exporter, format, types, map.key_ty())?;
169            let value = map_datatype_format(exporter, format, types, map.value_ty())?;
170            map.set_key_ty(key);
171            map.set_value_ty(value);
172        }
173        DataType::Nullable(inner) => {
174            **inner = map_datatype_format(exporter, format, types, inner)?;
175        }
176        DataType::Struct(strct) => map_datatype_fields(exporter, format, types, &mut strct.fields)?,
177        DataType::Enum(enm) => {
178            for (_, variant) in &mut enm.variants {
179                map_datatype_fields(exporter, format, types, &mut variant.fields)?;
180            }
181        }
182        DataType::Tuple(tuple) => {
183            for element in &mut tuple.elements {
184                *element = map_datatype_format(exporter, format, types, element)?;
185            }
186        }
187        DataType::Intersection(intersection) => {
188            for element in intersection {
189                *element = map_datatype_format(exporter, format, types, element)?;
190            }
191        }
192        DataType::Reference(Reference::Named(reference)) => {
193            if let specta::datatype::NamedReferenceType::Reference { generics, .. } =
194                &mut reference.inner
195            {
196                for (_, generic) in generics {
197                    *generic = map_datatype_format(exporter, format, types, generic)?;
198                }
199            }
200        }
201        DataType::Reference(Reference::Opaque(_)) | DataType::Generic(_) => {}
202    }
203
204    Ok(dt)
205}
206
207fn map_datatype_fields(
208    exporter: &Go,
209    format: Option<&dyn Format>,
210    types: &Types,
211    fields: &mut Fields,
212) -> Result<(), Error> {
213    match fields {
214        Fields::Unit => {}
215        Fields::Unnamed(unnamed) => {
216            for field in &mut unnamed.fields {
217                if let Some(ty) = field.ty.as_mut() {
218                    *ty = map_datatype_format(exporter, format, types, ty)?;
219                }
220            }
221        }
222        Fields::Named(named) => {
223            for (_, field) in &mut named.fields {
224                if let Some(ty) = field.ty.as_mut() {
225                    *ty = map_datatype_format(exporter, format, types, ty)?;
226                }
227            }
228        }
229    }
230
231    Ok(())
232}
233
234fn map_types_for_datatype_format<'a>(
235    exporter: &Go,
236    types: &'a Types,
237    format: Option<&dyn Format>,
238) -> Result<Cow<'a, Types>, Error> {
239    if format.is_none() {
240        return Ok(Cow::Borrowed(types));
241    }
242
243    let mut mapped_types = types.clone();
244    let mut map_err = None;
245    mapped_types.iter_mut(|ndt| {
246        if map_err.is_some() {
247            return;
248        }
249
250        let Some(ty) = &ndt.ty else {
251            return;
252        };
253
254        match map_datatype_format(exporter, format, types, ty) {
255            Ok(mapped) => ndt.ty = Some(mapped),
256            Err(err) => map_err = Some(err),
257        }
258    });
259
260    if let Some(err) = map_err {
261        return Err(err);
262    }
263
264    Ok(Cow::Owned(mapped_types))
265}