1use std::{borrow::Cow, fmt, path::Path};
4
5use specta::{
6 Format, Types,
7 datatype::{DataType, Fields, Reference},
8};
9
10use crate::Error;
11use crate::primitives::{export_type, is_duration_struct};
12
13#[derive(Clone)]
15pub struct Swift {
16 pub header: Cow<'static, str>,
18 pub indent: IndentStyle,
20 pub naming: NamingConvention,
22 pub generics: GenericStyle,
24 pub optionals: OptionalStyle,
26 pub protocols: Vec<Cow<'static, str>>,
28}
29
30impl fmt::Debug for Swift {
31 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
32 f.debug_struct("Swift")
33 .field("header", &self.header)
34 .field("indent", &self.indent)
35 .field("naming", &self.naming)
36 .field("generics", &self.generics)
37 .field("optionals", &self.optionals)
38 .field("protocols", &self.protocols)
39 .finish()
40 }
41}
42
/// Indentation setting stored in [`Swift::indent`].
///
/// The default is `Spaces(4)`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum IndentStyle {
    /// Space-based indentation; the payload is the number of spaces
    /// (presumably per nesting level — confirm in the type exporter).
    Spaces(usize),
    /// Tab-based indentation.
    Tabs,
}

impl Default for IndentStyle {
    /// Four spaces.
    fn default() -> Self {
        IndentStyle::Spaces(4)
    }
}
57
/// Identifier casing applied by `NamingConvention::convert` and the related helpers.
///
/// The default is [`NamingConvention::PascalCase`].
#[derive(Default, Clone, Copy, Debug, Eq, PartialEq)]
pub enum NamingConvention {
    /// Under `convert`, `foo_bar` becomes `FooBar`.
    #[default]
    PascalCase,
    /// Under `convert`, `foo_bar` becomes `fooBar`.
    CamelCase,
    /// Under `convert`, `FooBar` becomes `foo_bar`.
    SnakeCase,
}
69
/// Generic-rendering setting stored in [`Swift::generics`].
///
/// The default is [`GenericStyle::Protocol`].
/// NOTE(review): how each variant is rendered is decided by the type exporter, not
/// in this file — confirm there before relying on specific output.
#[derive(Default, Clone, Copy, Debug, Eq, PartialEq)]
pub enum GenericStyle {
    /// Default setting.
    #[default]
    Protocol,
    /// Alternative setting.
    Typealias,
}
79
/// Optional-rendering setting stored in [`Swift::optionals`].
///
/// The default is [`OptionalStyle::QuestionMark`].
/// NOTE(review): presumably `QuestionMark` renders `T?` and `Optional` renders
/// `Optional<T>` — confirm in the type exporter.
#[derive(Default, Clone, Copy, Debug, Eq, PartialEq)]
pub enum OptionalStyle {
    /// Default setting.
    #[default]
    QuestionMark,
    /// Alternative setting.
    Optional,
}
89
90impl Default for Swift {
91 fn default() -> Self {
92 Self {
93 header: "// This file has been generated by Specta. DO NOT EDIT.".into(),
94 indent: IndentStyle::default(),
95 naming: NamingConvention::default(),
96 generics: GenericStyle::default(),
97 optionals: OptionalStyle::default(),
98 protocols: vec![],
99 }
100 }
101}
102
103impl Swift {
104 pub fn new() -> Self {
106 Self::default()
107 }
108
109 pub fn header(mut self, header: impl Into<Cow<'static, str>>) -> Self {
111 self.header = header.into();
112 self
113 }
114
115 pub fn indent(mut self, style: IndentStyle) -> Self {
117 self.indent = style;
118 self
119 }
120
121 pub fn naming(mut self, convention: NamingConvention) -> Self {
123 self.naming = convention;
124 self
125 }
126
127 pub fn generics(mut self, style: GenericStyle) -> Self {
129 self.generics = style;
130 self
131 }
132
133 pub fn optionals(mut self, style: OptionalStyle) -> Self {
135 self.optionals = style;
136 self
137 }
138
139 pub fn add_protocol(mut self, protocol: impl Into<Cow<'static, str>>) -> Self {
141 self.protocols.push(protocol.into());
142 self
143 }
144
145 pub fn export(&self, types: &Types, format: impl Format) -> Result<String, Error> {
147 let exporter = self.clone();
148 let formatted_types = format_types(types, &format)?.into_owned();
149 let raw_types = &formatted_types;
150
151 let mut result = String::new();
152
153 if !exporter.header.is_empty() {
155 result.push_str(&exporter.header);
156 result.push('\n');
157 }
158
159 result.push_str("import Foundation\n");
161 for protocol in &exporter.protocols {
162 result.push_str(&format!("import {}\n", protocol));
163 }
164 result.push('\n');
165
166 if needs_duration_helper(raw_types) {
168 result.push_str(&generate_duration_helper());
169 }
170
171 for ndt in raw_types.into_sorted_iter() {
173 let exported = export_type(&exporter, Some(&format), raw_types, ndt)?;
174 if !exported.is_empty() {
175 result.push_str(&exported);
176 result.push_str("\n\n");
177 }
178 }
179
180 Ok(result)
181 }
182
183 pub fn export_to(
185 &self,
186 path: impl AsRef<Path>,
187 types: &Types,
188 format: impl Format,
189 ) -> Result<(), Error> {
190 let content = self.export(types, format)?;
191 std::fs::write(path, content)?;
192 Ok(())
193 }
194}
195
196fn format_types<'a>(types: &'a Types, format: &'a dyn Format) -> Result<Cow<'a, Types>, Error> {
197 format
198 .map_types(types)
199 .map_err(|err| Error::format("type graph formatter failed", err))
200}
201
202impl NamingConvention {
203 pub fn convert(&self, name: &str) -> String {
205 match self {
206 Self::PascalCase => self.to_pascal_case(name),
207 Self::CamelCase => self.to_camel_case(name),
208 Self::SnakeCase => self.to_snake_case(name),
209 }
210 }
211
212 pub fn convert_to_camel_case(&self, name: &str) -> String {
214 self.to_camel_case(name)
215 }
216
217 pub fn convert_field(&self, name: &str) -> String {
219 match self {
220 Self::PascalCase => self.to_camel_case(name), Self::CamelCase => self.to_camel_case(name),
222 Self::SnakeCase => self.to_snake_case(name),
223 }
224 }
225
226 pub fn convert_enum_case(&self, name: &str) -> String {
228 match self {
229 Self::PascalCase => self.to_camel_case(name), Self::CamelCase => self.to_camel_case(name),
231 Self::SnakeCase => self.to_snake_case(name),
232 }
233 }
234
235 #[allow(clippy::wrong_self_convention)]
236 fn to_camel_case(&self, name: &str) -> String {
237 if name.contains('_') {
239 let parts: Vec<&str> = name.split('_').collect();
241 if parts.is_empty() {
242 return name.to_string();
243 }
244
245 let mut result = String::new();
246 for (i, part) in parts.iter().enumerate() {
247 if i == 0 {
248 result.push_str(&part.to_lowercase());
249 } else {
250 let mut chars = part.chars();
251 match chars.next() {
252 None => continue,
253 Some(first) => {
254 result.push(first.to_uppercase().next().unwrap_or(first));
255 for c in chars {
256 result.extend(c.to_lowercase());
257 }
258 }
259 }
260 }
261 }
262 result
263 } else {
264 if name.chars().any(|c| c.is_ascii_alphabetic())
265 && name
266 .chars()
267 .all(|c| !c.is_ascii_alphabetic() || c.is_ascii_uppercase())
268 {
269 return name.to_ascii_lowercase();
270 }
271
272 let mut chars = name.chars();
274 match chars.next() {
275 None => name.to_string(),
276 Some(first) => {
277 let mut result = String::new();
278 result.push(first.to_lowercase().next().unwrap_or(first));
279 for c in chars {
280 result.push(c); }
282 result
283 }
284 }
285 }
286 }
287
288 #[allow(clippy::wrong_self_convention)]
289 fn to_pascal_case(&self, name: &str) -> String {
290 name.split('_')
292 .map(|part| {
293 let mut chars = part.chars();
294 match chars.next() {
295 None => String::new(),
296 Some(first) => first.to_uppercase().chain(chars).collect(),
297 }
298 })
299 .collect()
300 }
301
302 #[allow(clippy::wrong_self_convention)]
303 fn to_snake_case(&self, name: &str) -> String {
304 let mut result = String::new();
306 let chars = name.chars();
307
308 for c in chars {
309 if c.is_uppercase() && !result.is_empty() {
310 result.push('_');
311 }
312 result.push(c.to_lowercase().next().unwrap_or(c));
313 }
314
315 result
316 }
317}
318
319fn needs_duration_helper(types: &Types) -> bool {
321 for ndt in types.into_sorted_iter() {
322 if ndt.name == "Duration" {
323 return true;
324 }
325 if let Some(DataType::Struct(s)) = &ndt.ty
327 && let Fields::Named(fields) = &s.fields
328 {
329 for (_, field) in &fields.fields {
330 if let Some(ty) = field.ty.as_ref() {
331 if let DataType::Reference(Reference::Named(r)) = ty
332 && let Some(referenced_ndt) = types.get(r)
333 && referenced_ndt.name == "Duration"
334 {
335 return true;
336 }
337 if let DataType::Struct(struct_ty) = ty
339 && is_duration_struct(struct_ty)
340 {
341 return true;
342 }
343 }
344 }
345 }
346 }
347 false
348}
349
/// Swift source for the `RustDuration` helper, followed by the "Generated Types" marker.
///
/// The emitted struct decodes the `{"secs": u64, "nanos": u32}` shape and exposes it as a
/// `TimeInterval`.
fn generate_duration_helper() -> String {
    // Assembled at compile time so the text is a single static string.
    const HELPER: &str = concat!(
        "// MARK: - Duration Helper\n",
        "/// Helper struct to decode Rust Duration format {\"secs\": u64, \"nanos\": u32}\n",
        "public struct RustDuration: Codable {\n",
        " public let secs: UInt64\n",
        " public let nanos: UInt32\n",
        " \n",
        " public var timeInterval: TimeInterval {\n",
        " return Double(secs) + Double(nanos) / 1_000_000_000.0\n",
        " }\n",
        "}\n\n",
        "// MARK: - Generated Types\n\n",
    );
    HELPER.to_owned()
}