1use std::collections::HashSet;
9
10use facet_core::{ScalarType, Shape};
11use heck::ToLowerCamelCase;
12use vox_types::{
13 EnumInfo, ServiceDescriptor, ShapeKind, StructInfo, VariantKind, classify_shape,
14 classify_variant, is_bytes, is_rx, is_tx,
15};
16
17pub fn collect_named_types(service: &ServiceDescriptor) -> Vec<(String, &'static Shape)> {
20 let mut seen: HashSet<String> = HashSet::new();
21 let mut types = Vec::new();
22
23 fn visit(
24 shape: &'static Shape,
25 seen: &mut HashSet<String>,
26 types: &mut Vec<(String, &'static Shape)>,
27 ) {
28 match classify_shape(shape) {
29 ShapeKind::Struct(StructInfo {
30 name: Some(name),
31 fields,
32 ..
33 }) if seen.insert(name.to_string()) => {
34 for field in fields {
36 visit(field.shape(), seen, types);
37 }
38 types.push((name.to_string(), shape));
39 }
40 ShapeKind::Enum(EnumInfo {
41 name: Some(name),
42 variants,
43 }) if seen.insert(name.to_string()) => {
44 for variant in variants {
46 match classify_variant(variant) {
47 VariantKind::Newtype { inner } => visit(inner, seen, types),
48 VariantKind::Struct { fields } | VariantKind::Tuple { fields } => {
49 for field in fields {
50 visit(field.shape(), seen, types);
51 }
52 }
53 VariantKind::Unit => {}
54 }
55 }
56 types.push((name.to_string(), shape));
57 }
58 ShapeKind::List { element }
59 | ShapeKind::Slice { element }
60 | ShapeKind::Option { inner: element }
61 | ShapeKind::Array { element, .. }
62 | ShapeKind::Set { element } => visit(element, seen, types),
63 ShapeKind::Map { key, value } => {
64 visit(key, seen, types);
65 visit(value, seen, types);
66 }
67 ShapeKind::Tuple { elements } => {
68 for param in elements {
69 visit(param.shape, seen, types);
70 }
71 }
72 ShapeKind::Tx { inner } | ShapeKind::Rx { inner } => visit(inner, seen, types),
73 ShapeKind::Pointer { pointee } => visit(pointee, seen, types),
74 ShapeKind::Result { ok, err } => {
75 visit(ok, seen, types);
76 visit(err, seen, types);
77 }
78 _ => {}
79 }
80 }
81
82 for method in service.methods {
83 for arg in method.args {
84 visit(arg.shape, &mut seen, &mut types);
85 }
86 visit(method.return_shape, &mut seen, &mut types);
87 }
88
89 types
90}
91
92pub fn generate_named_types(named_types: &[(String, &'static Shape)]) -> String {
94 let mut out = String::new();
95
96 for (name, shape) in named_types {
97 match classify_shape(shape) {
98 ShapeKind::Struct(StructInfo { fields, .. }) => {
99 out.push_str(&format!("public struct {name}: Codable, Sendable {{\n"));
100 for field in fields {
101 let field_name = swift_field_name(field.name);
102 let field_type = swift_type_base(field.shape());
103 out.push_str(&format!(" public var {field_name}: {field_type}\n"));
104 }
105 out.push('\n');
106 out.push_str(" nonisolated public init(");
108 for (i, field) in fields.iter().enumerate() {
109 if i > 0 {
110 out.push_str(", ");
111 }
112 let field_name = swift_field_name(field.name);
113 let field_type = swift_type_base(field.shape());
114 out.push_str(&format!("{field_name}: {field_type}"));
115 }
116 out.push_str(") {\n");
117 for field in fields {
118 let field_name = swift_field_name(field.name);
119 out.push_str(&format!(" self.{field_name} = {field_name}\n"));
120 }
121 out.push_str(" }\n");
122 out.push_str("}\n\n");
123 }
124 ShapeKind::Enum(EnumInfo { variants, .. }) => {
125 let protocols = if name.ends_with("Error") {
127 "Codable, Sendable, Error"
128 } else {
129 "Codable, Sendable"
130 };
131 out.push_str(&format!("public enum {name}: {protocols} {{\n"));
132 for variant in variants {
133 let variant_name = swift_field_name(variant.name);
134 match classify_variant(variant) {
135 VariantKind::Unit => {
136 out.push_str(&format!(" case {variant_name}\n"));
137 }
138 VariantKind::Newtype { inner } => {
139 let inner_type = swift_type_base(inner);
140 out.push_str(&format!(" case {variant_name}({inner_type})\n"));
141 }
142 VariantKind::Tuple { fields } => {
143 let field_types: Vec<_> =
144 fields.iter().map(|f| swift_type_base(f.shape())).collect();
145 out.push_str(&format!(
146 " case {variant_name}({})\n",
147 field_types.join(", ")
148 ));
149 }
150 VariantKind::Struct { fields } => {
151 let field_decls: Vec<_> = fields
152 .iter()
153 .map(|f| {
154 format!(
155 "{}: {}",
156 swift_field_name(f.name),
157 swift_type_base(f.shape())
158 )
159 })
160 .collect();
161 out.push_str(&format!(
162 " case {variant_name}({})\n",
163 field_decls.join(", ")
164 ));
165 }
166 }
167 }
168 out.push_str("}\n\n");
169 }
170 _ => {}
171 }
172 }
173
174 out
175}
176
177pub fn swift_field_name(name: &str) -> String {
185 if name.chars().next().is_some_and(|c| c.is_ascii_digit()) {
186 return format!("_{name}");
187 }
188 let lower = name.to_lower_camel_case();
189 if SWIFT_RESERVED.binary_search(&lower.as_str()).is_ok() {
190 format!("`{lower}`")
191 } else {
192 lower
193 }
194}
195
/// Swift words that need backtick escaping when used as identifiers.
///
/// INVARIANT: the entries must stay in ascending byte order, because
/// `swift_field_name` looks names up with `binary_search`. The capitalised
/// entries (`Any`, `Self`) therefore come before all lowercase ones.
#[rustfmt::skip]
const SWIFT_RESERVED: &[&str] = &[
    "Any", "Self",
    "as", "associatedtype",
    "break",
    "case", "catch", "class", "continue",
    "default", "defer", "deinit", "do",
    "else", "enum", "extension",
    "fallthrough", "false", "fileprivate", "for", "func",
    "guard",
    "if", "import", "in", "init", "inout", "internal", "is",
    "let",
    "nil",
    "open", "operator",
    "precedencegroup", "private", "protocol", "public",
    "repeat", "rethrows", "return",
    "self", "static", "struct", "subscript", "super", "switch",
    "throw", "throws", "true", "try", "typealias",
    "var",
    "where", "while",
];
255
256pub fn swift_scalar_type(scalar: ScalarType) -> String {
258 match scalar {
259 ScalarType::Bool => "Bool".into(),
260 ScalarType::U8 => "UInt8".into(),
261 ScalarType::U16 => "UInt16".into(),
262 ScalarType::U32 => "UInt32".into(),
263 ScalarType::U64 => "UInt64".into(),
264 ScalarType::U128 => "UInt128".into(),
265 ScalarType::USize => "UInt".into(),
266 ScalarType::I8 => "Int8".into(),
267 ScalarType::I16 => "Int16".into(),
268 ScalarType::I32 => "Int32".into(),
269 ScalarType::I64 => "Int64".into(),
270 ScalarType::I128 => "Int128".into(),
271 ScalarType::ISize => "Int".into(),
272 ScalarType::F32 => "Float".into(),
273 ScalarType::F64 => "Double".into(),
274 ScalarType::Char | ScalarType::Str | ScalarType::String | ScalarType::CowStr => {
275 "String".into()
276 }
277 ScalarType::Unit => "Void".into(),
278 _ => "Data".into(),
279 }
280}
281
282pub fn swift_type_base(shape: &'static Shape) -> String {
284 if is_bytes(shape) {
286 return "Data".into();
287 }
288
289 match classify_shape(shape) {
290 ShapeKind::Scalar(scalar) => swift_scalar_type(scalar),
291 ShapeKind::List { element } => format!("[{}]", swift_type_base(element)),
292 ShapeKind::Slice { element } => format!("[{}]", swift_type_base(element)),
293 ShapeKind::Option { inner } => format!("{}?", swift_type_base(inner)),
294 ShapeKind::Array { element, .. } => format!("[{}]", swift_type_base(element)),
295 ShapeKind::Map { key, value } => {
296 format!("[{}: {}]", swift_type_base(key), swift_type_base(value))
297 }
298 ShapeKind::Set { element } => format!("Set<{}>", swift_type_base(element)),
299 ShapeKind::Tuple { elements } => {
300 if elements.is_empty() {
301 "Void".into()
302 } else {
303 let types: Vec<_> = elements.iter().map(|p| swift_type_base(p.shape)).collect();
304 format!("({})", types.join(", "))
305 }
306 }
307 ShapeKind::Tx { inner } => format!("Tx<{}>", swift_type_base(inner)),
308 ShapeKind::Rx { inner } => format!("Rx<{}>", swift_type_base(inner)),
309 ShapeKind::Struct(StructInfo {
310 name: Some(name), ..
311 }) => name.to_string(),
312 ShapeKind::Enum(EnumInfo {
313 name: Some(name), ..
314 }) => name.to_string(),
315 ShapeKind::Struct(StructInfo {
316 name: None, fields, ..
317 }) => {
318 let types: Vec<_> = fields.iter().map(|f| swift_type_base(f.shape())).collect();
320 format!("({})", types.join(", "))
321 }
322 ShapeKind::Enum(EnumInfo {
323 name: None,
324 variants,
325 }) => {
326 let _ = variants; "Any".into()
329 }
330 ShapeKind::Pointer { pointee } => swift_type_base(pointee),
331 ShapeKind::Result { ok, err } => {
332 format!("Result<{}, {}>", swift_type_base(ok), swift_type_base(err))
333 }
334 ShapeKind::TupleStruct { fields } => {
335 let types: Vec<_> = fields.iter().map(|f| swift_type_base(f.shape())).collect();
336 format!("({})", types.join(", "))
337 }
338 ShapeKind::Opaque => "Data".into(),
339 }
340}
341
342pub fn swift_type_client_arg(shape: &'static Shape) -> String {
344 match classify_shape(shape) {
345 ShapeKind::Tx { inner } => format!("UnboundTx<{}>", swift_type_base(inner)),
346 ShapeKind::Rx { inner } => format!("UnboundRx<{}>", swift_type_base(inner)),
347 _ => swift_type_base(shape),
348 }
349}
350
351pub fn swift_type_client_return(shape: &'static Shape) -> String {
353 assert_no_channels_in_return_shape(shape);
354 match classify_shape(shape) {
355 ShapeKind::Scalar(ScalarType::Unit) => "Void".into(),
356 ShapeKind::Tuple { elements: [] } => "Void".into(),
357 _ => swift_type_base(shape),
358 }
359}
360
361pub fn swift_type_server_arg(shape: &'static Shape) -> String {
363 match classify_shape(shape) {
364 ShapeKind::Tx { inner } => format!("Tx<{}>", swift_type_base(inner)),
365 ShapeKind::Rx { inner } => format!("Rx<{}>", swift_type_base(inner)),
366 _ => swift_type_base(shape),
367 }
368}
369
370pub fn swift_type_server_return(shape: &'static Shape) -> String {
372 assert_no_channels_in_return_shape(shape);
373 match classify_shape(shape) {
374 ShapeKind::Scalar(ScalarType::Unit) => "Void".into(),
375 ShapeKind::Tuple { elements: [] } => "Void".into(),
376 _ => swift_type_base(shape),
377 }
378}
379
380pub fn is_channel(shape: &'static Shape) -> bool {
382 is_tx(shape) || is_rx(shape)
383}
384
/// Formats documentation text as Swift `///` doc-comment lines.
///
/// Every line of `doc` (as split by `str::lines`) is emitted as
/// `{indent}/// {line}` followed by a newline. An empty `doc` yields an empty
/// string; an empty line inside `doc` yields `{indent}/// ` (with the trailing
/// space).
pub fn format_doc(doc: &str, indent: &str) -> String {
    let mut rendered = String::new();
    for text in doc.lines() {
        rendered.push_str(indent);
        rendered.push_str("/// ");
        rendered.push_str(text);
        rendered.push('\n');
    }
    rendered
}
391
392pub fn assert_no_channels_in_return_shape(shape: &'static Shape) {
393 fn has_channel(shape: &'static Shape) -> bool {
394 matches!(
395 classify_shape(shape),
396 ShapeKind::Tx { .. } | ShapeKind::Rx { .. }
397 )
398 }
399 assert!(
400 !has_channel(shape),
401 "channels are not allowed in return types"
402 );
403}