1use std::{borrow::Cow, path::Path};
2
3use specta::{
4 Format, Types,
5 datatype::{DataType, Fields, Reference},
6};
7
8use crate::{
9 Error,
10 primitives::{self, GoContext},
11};
12
/// How the generated Go bindings are laid out.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub enum Layout {
    /// Single-file output: `Go::export` renders one string and
    /// `Go::export_to` writes it to one path. This is the default.
    #[default]
    FlatFile,
    /// Multi-file layout.
    ///
    /// NOTE(review): how types would be split across files is not visible
    /// here. `Go::export_to` currently rejects this layout with
    /// `Error::UnableToExport`.
    Files,
}
22
/// Exporter that renders specta types as Go source code.
///
/// Configure it with the builder methods (`package_name`, `header`) and
/// render with `Go::export` or `Go::export_to`.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct Go {
    /// Text written verbatim (followed by a newline) above the `package`
    /// clause. Nothing is written when it is empty, which is the default.
    pub header: Cow<'static, str>,
    /// Output layout. `Go::export_to` currently rejects `Layout::Files`.
    pub layout: Layout,
    /// Name emitted in the generated `package` clause; defaults to
    /// `bindings`.
    package_name: String,
}
33
34impl Default for Go {
35 fn default() -> Self {
36 Self {
37 header: Cow::Borrowed(""),
38 layout: Layout::FlatFile,
39 package_name: "bindings".into(),
40 }
41 }
42}
43
44impl Go {
45 pub fn new() -> Self {
47 Default::default()
48 }
49
50 pub fn package_name(mut self, name: impl Into<String>) -> Self {
52 self.package_name = name.into();
53 self
54 }
55
56 pub fn header(mut self, header: impl Into<Cow<'static, str>>) -> Self {
58 self.header = header.into();
59 self
60 }
61
62 pub fn export(&self, types: &Types, format: impl Format) -> Result<String, Error> {
64 let mut ctx = GoContext::default();
65 let mut body = String::new();
66
67 let exporter = self.clone();
68 let formatted_types = format_types(&exporter, types, &format)?;
69 let types = formatted_types.as_ref();
70
71 for ndt in types.into_sorted_iter() {
72 let type_def = primitives::export(&exporter, types, ndt, &mut ctx)?;
73 body.push_str(&type_def);
74 body.push('\n');
75 }
76
77 let mut out = String::new();
78 if !exporter.header.is_empty() {
79 out.push_str(&exporter.header);
80 out.push('\n');
81 }
82
83 out.push_str("package ");
84 out.push_str(&exporter.package_name);
85 out.push_str("\n\n");
86
87 if !ctx.imports.is_empty() {
88 out.push_str("import (\n");
89 let mut sorted: Vec<_> = ctx.imports.iter().collect();
90 sorted.sort();
91 for imp in sorted {
92 out.push_str(&format!("\t\"{}\"\n", imp));
93 }
94 out.push_str(")\n\n");
95 }
96
97 out.push_str(&body);
98 Ok(out)
99 }
100
101 pub fn export_to(
103 &self,
104 path: impl AsRef<Path>,
105 types: &Types,
106 format: impl Format,
107 ) -> Result<(), Error> {
108 if self.layout == Layout::Files {
109 return Err(Error::UnableToExport(Layout::Files));
110 }
111
112 let content = self.export(types, format)?;
113 if let Some(parent) = path.as_ref().parent() {
114 std::fs::create_dir_all(parent)?;
115 }
116 std::fs::write(path, content)?;
117 Ok(())
118 }
119}
120
121fn format_types<'a>(
122 exporter: &Go,
123 types: &'a Types,
124 format: &dyn Format,
125) -> Result<Cow<'a, Types>, Error> {
126 let mapped_types = format
127 .map_types(types)
128 .map_err(|err| Error::format("type graph formatter failed", err))?;
129 Ok(Cow::Owned(
130 map_types_for_datatype_format(exporter, mapped_types.as_ref(), Some(format))?.into_owned(),
131 ))
132}
133
134fn map_datatype_format(
135 exporter: &Go,
136 format: Option<&dyn Format>,
137 types: &Types,
138 dt: &DataType,
139) -> Result<DataType, Error> {
140 let Some(format) = format else {
141 return Ok(dt.clone());
142 };
143
144 let mapped = format
145 .map_type(types, dt)
146 .map_err(|err| Error::format("datatype formatter failed", err))?;
147
148 match mapped {
149 Cow::Borrowed(dt) => {
150 map_datatype_format_children(exporter, Some(format), types, dt.clone())
151 }
152 Cow::Owned(dt) => map_datatype_format_children(exporter, Some(format), types, dt),
153 }
154}
155
/// Runs the datatype formatter over every child `DataType` directly nested
/// in `dt`, and returns the updated `dt`.
///
/// Each child goes through `map_datatype_format`, which formats it and then
/// recurses into its own children, so the whole subtree below `dt` is
/// visited. `dt` itself is not formatted here. In this file, its caller
/// `map_datatype_format` has already run the formatter on it.
///
/// # Errors
///
/// Propagates the first error returned while mapping a child.
fn map_datatype_format_children(
    exporter: &Go,
    format: Option<&dyn Format>,
    types: &Types,
    mut dt: DataType,
) -> Result<DataType, Error> {
    match &mut dt {
        // Leaf: nothing nested to map.
        DataType::Primitive(_) => {}
        DataType::List(list) => {
            *list.ty = map_datatype_format(exporter, format, types, &list.ty)?;
        }
        DataType::Map(map) => {
            // Map key/value types are reached through getters/setters, so
            // both are mapped first and then written back.
            let key = map_datatype_format(exporter, format, types, map.key_ty())?;
            let value = map_datatype_format(exporter, format, types, map.value_ty())?;
            map.set_key_ty(key);
            map.set_value_ty(value);
        }
        DataType::Nullable(inner) => {
            **inner = map_datatype_format(exporter, format, types, inner)?;
        }
        // Struct fields and every enum variant's fields are handled by
        // `map_datatype_fields`.
        DataType::Struct(strct) => map_datatype_fields(exporter, format, types, &mut strct.fields)?,
        DataType::Enum(enm) => {
            for (_, variant) in &mut enm.variants {
                map_datatype_fields(exporter, format, types, &mut variant.fields)?;
            }
        }
        DataType::Tuple(tuple) => {
            for element in &mut tuple.elements {
                *element = map_datatype_format(exporter, format, types, element)?;
            }
        }
        DataType::Intersection(intersection) => {
            for element in intersection {
                *element = map_datatype_format(exporter, format, types, element)?;
            }
        }
        DataType::Reference(Reference::Named(reference)) => {
            // Only the generic arguments of the `Reference` variant are
            // mapped; other named-reference kinds are left untouched.
            if let specta::datatype::NamedReferenceType::Reference { generics, .. } =
                &mut reference.inner
            {
                for (_, generic) in generics {
                    *generic = map_datatype_format(exporter, format, types, generic)?;
                }
            }
        }
        // Opaque references and generic placeholders are left as-is.
        DataType::Reference(Reference::Opaque(_)) | DataType::Generic(_) => {}
    }

    Ok(dt)
}
206
207fn map_datatype_fields(
208 exporter: &Go,
209 format: Option<&dyn Format>,
210 types: &Types,
211 fields: &mut Fields,
212) -> Result<(), Error> {
213 match fields {
214 Fields::Unit => {}
215 Fields::Unnamed(unnamed) => {
216 for field in &mut unnamed.fields {
217 if let Some(ty) = field.ty.as_mut() {
218 *ty = map_datatype_format(exporter, format, types, ty)?;
219 }
220 }
221 }
222 Fields::Named(named) => {
223 for (_, field) in &mut named.fields {
224 if let Some(ty) = field.ty.as_mut() {
225 *ty = map_datatype_format(exporter, format, types, ty)?;
226 }
227 }
228 }
229 }
230
231 Ok(())
232}
233
234fn map_types_for_datatype_format<'a>(
235 exporter: &Go,
236 types: &'a Types,
237 format: Option<&dyn Format>,
238) -> Result<Cow<'a, Types>, Error> {
239 if format.is_none() {
240 return Ok(Cow::Borrowed(types));
241 }
242
243 let mut mapped_types = types.clone();
244 let mut map_err = None;
245 mapped_types.iter_mut(|ndt| {
246 if map_err.is_some() {
247 return;
248 }
249
250 let Some(ty) = &ndt.ty else {
251 return;
252 };
253
254 match map_datatype_format(exporter, format, types, ty) {
255 Ok(mapped) => ndt.ty = Some(mapped),
256 Err(err) => map_err = Some(err),
257 }
258 });
259
260 if let Some(err) = map_err {
261 return Err(err);
262 }
263
264 Ok(Cow::Owned(mapped_types))
265}