vox_codegen/targets/swift/
client.rs1use heck::{ToLowerCamelCase, ToUpperCamelCase};
10use vox_types::{MethodDescriptor, ServiceDescriptor, ShapeKind, classify_shape};
11
12use super::decode::generate_decode_stmt_from_with_cursor;
13use super::encode::generate_encode_expr;
14use super::types::{format_doc, is_channel, swift_type_client_arg, swift_type_client_return};
15use crate::code_writer::CodeWriter;
16use crate::cw_writeln;
17use crate::render::hex_u64;
18
19fn swift_retry_policy_literal(method: &MethodDescriptor) -> &'static str {
20 match (method.retry.persist, method.retry.idem) {
21 (false, false) => ".volatile",
22 (false, true) => ".idem",
23 (true, false) => ".persist",
24 (true, true) => ".persistIdem",
25 }
26}
27
28pub fn generate_client(service: &ServiceDescriptor) -> String {
33 let mut out = String::new();
34 out.push_str(&generate_caller_protocol(service));
35 out.push_str(&generate_client_impl(service));
36 out
37}
38
39fn generate_caller_protocol(service: &ServiceDescriptor) -> String {
41 let mut out = String::new();
42 let service_name = service.service_name.to_upper_camel_case();
43
44 if let Some(doc) = &service.doc {
45 out.push_str(&format_doc(doc, ""));
46 }
47 out.push_str(&format!("public protocol {service_name}Caller {{\n"));
48
49 for method in service.methods {
50 let method_name = method.method_name.to_lower_camel_case();
51
52 if let Some(doc) = &method.doc {
53 out.push_str(&format_doc(doc, " "));
54 }
55
56 let args: Vec<String> = method
57 .args
58 .iter()
59 .map(|a| {
60 format!(
61 "{}: {}",
62 a.name.to_lower_camel_case(),
63 swift_type_client_arg(a.shape)
64 )
65 })
66 .collect();
67
68 let ret_type = swift_type_client_return(method.return_shape);
69
70 if ret_type == "Void" {
71 out.push_str(&format!(
72 " func {method_name}({}) async throws\n",
73 args.join(", ")
74 ));
75 } else {
76 out.push_str(&format!(
77 " func {method_name}({}) async throws -> {ret_type}\n",
78 args.join(", ")
79 ));
80 }
81 }
82
83 out.push_str("}\n\n");
84 out
85}
86
87fn generate_client_impl(service: &ServiceDescriptor) -> String {
89 let mut out = String::new();
90 let mut w = CodeWriter::with_indent_spaces(&mut out, 4);
91 let service_name = service.service_name.to_upper_camel_case();
92
93 w.writeln(&format!(
94 "public final class {service_name}Client: {service_name}Caller, Sendable {{"
95 ))
96 .unwrap();
97 {
98 let _indent = w.indent();
99 w.writeln("private let connection: VoxConnection").unwrap();
100 w.writeln("private let timeout: TimeInterval?").unwrap();
101 w.blank_line().unwrap();
102 w.writeln("public init(connection: VoxConnection, timeout: TimeInterval? = 30.0) {")
103 .unwrap();
104 {
105 let _indent = w.indent();
106 w.writeln("self.connection = connection").unwrap();
107 w.writeln("self.timeout = timeout").unwrap();
108 }
109 w.writeln("}").unwrap();
110
111 for method in service.methods {
112 w.blank_line().unwrap();
113 generate_client_method(&mut w, method, &service_name);
114 }
115 }
116 w.writeln("}").unwrap();
117 w.blank_line().unwrap();
118
119 out
120}
121
122fn generate_client_method(
129 w: &mut CodeWriter<&mut String>,
130 method: &MethodDescriptor,
131 service_name: &str,
132) {
133 let method_name = method.method_name.to_lower_camel_case();
134 let method_id_name = method.method_name.to_lower_camel_case();
135
136 let args: Vec<String> = method
137 .args
138 .iter()
139 .map(|a| {
140 format!(
141 "{}: {}",
142 a.name.to_lower_camel_case(),
143 swift_type_client_arg(a.shape)
144 )
145 })
146 .collect();
147
148 let ret_type = swift_type_client_return(method.return_shape);
149 let has_streaming = method.args.iter().any(|a| is_channel(a.shape));
150 let retry_policy = swift_retry_policy_literal(method);
151
152 if ret_type == "Void" {
154 cw_writeln!(
155 w,
156 "public func {method_name}({}) async throws {{",
157 args.join(", ")
158 )
159 .unwrap();
160 } else {
161 cw_writeln!(
162 w,
163 "public func {method_name}({}) async throws -> {ret_type} {{",
164 args.join(", ")
165 )
166 .unwrap();
167 }
168
169 {
170 let _indent = w.indent();
171 let cursor_var = unique_decode_cursor_name(method.args);
172
173 let service_name_lower = service_name.to_lower_camel_case();
174 let method_id = crate::method_id(method);
175
176 if has_streaming {
177 generate_streaming_client_body(
178 w,
179 method,
180 service_name,
181 &method_id_name,
182 &cursor_var,
183 retry_policy,
184 );
185 } else {
186 generate_encode_args(w, method.args);
188
189 cw_writeln!(
191 w,
192 "let schemaInfo = ClientSchemaInfo(methodInfo: {service_name_lower}_method_schemas[{}]!, schemaRegistry: {service_name_lower}_schema_registry)",
193 hex_u64(method_id)
194 )
195 .unwrap();
196
197 cw_writeln!(
199 w,
200 "let response = try await connection.call(methodId: {}, metadata: [], payload: payload, retry: {retry_policy}, timeout: timeout, prepareRetry: nil, finalizeChannels: nil, schemaInfo: schemaInfo)",
201 hex_u64(method_id),
202 )
203 .unwrap();
204 generate_response_decode(w, method, &cursor_var, "response");
205 }
206 }
207 w.writeln("}").unwrap();
208}
209
210fn generate_encode_args(w: &mut CodeWriter<&mut String>, args: &[vox_types::ArgDescriptor]) {
212 if args.is_empty() {
213 w.writeln("let payload = Data()").unwrap();
214 return;
215 }
216
217 w.writeln("var payloadBytes: [UInt8] = []").unwrap();
218 for arg in args {
219 let arg_name = arg.name.to_lower_camel_case();
220 let encode_expr = generate_encode_expr(arg.shape, &arg_name);
221 cw_writeln!(w, "payloadBytes += {encode_expr}").unwrap();
222 }
223 w.writeln("let payload = Data(payloadBytes)").unwrap();
224}
225
226fn generate_streaming_client_body(
233 w: &mut CodeWriter<&mut String>,
234 method: &MethodDescriptor,
235 service_name: &str,
236 method_id_name: &str,
237 cursor_var: &str,
238 retry_policy: &str,
239) {
240 let service_name_lower = service_name.to_lower_camel_case();
241
242 let arg_names: Vec<String> = method
243 .args
244 .iter()
245 .map(|a| a.name.to_lower_camel_case())
246 .collect();
247
248 let method_id = crate::method_id(method);
249
250 cw_writeln!(
252 w,
253 "let schemaInfo = ClientSchemaInfo(methodInfo: {service_name_lower}_method_schemas[{}]!, schemaRegistry: {service_name_lower}_schema_registry)",
254 hex_u64(method_id)
255 )
256 .unwrap();
257
258 w.writeln("let prepareRetry: @Sendable () async -> PreparedRetryRequest = { [connection] in")
259 .unwrap();
260 {
261 let _indent = w.indent();
262 w.writeln("await bindChannels(").unwrap();
263 {
264 let _indent = w.indent();
265 cw_writeln!(
266 w,
267 "schemas: {service_name_lower}_schemas[\"{method_id_name}\"]!.args,"
268 )
269 .unwrap();
270 cw_writeln!(w, "args: [{}],", arg_names.join(", ")).unwrap();
271 w.writeln("allocator: connection.channelAllocator,")
272 .unwrap();
273 w.writeln("incomingRegistry: connection.incomingChannelRegistry,")
274 .unwrap();
275 w.writeln("taskSender: connection.taskSender,").unwrap();
276 cw_writeln!(w, "serializers: {service_name}Serializers()").unwrap();
277 }
278 w.writeln(")").unwrap();
279 w.blank_line().unwrap();
280 generate_encode_args(w, method.args);
281 w.writeln("return PreparedRetryRequest(payload: Array(payload))")
282 .unwrap();
283 }
284 w.writeln("}").unwrap();
285 w.writeln("let prepared = await prepareRetry()").unwrap();
286 w.blank_line().unwrap();
287
288 let ret_type = swift_type_client_return(method.return_shape);
290 let _ = ret_type;
291 cw_writeln!(
292 w,
293 "let response = try await connection.call(methodId: {}, metadata: [], payload: Data(prepared.payload), retry: {retry_policy}, timeout: timeout, prepareRetry: prepareRetry, finalizeChannels: {{ finalizeBoundChannels(schemas: {service_name_lower}_schemas[\"{method_id_name}\"]!.args, args: [{}]) }}, schemaInfo: schemaInfo)",
294 hex_u64(method_id),
295 arg_names.join(", ")
296 )
297 .unwrap();
298 generate_response_decode(w, method, cursor_var, "response");
299}
300
301fn unique_decode_cursor_name(args: &[vox_types::ArgDescriptor]) -> String {
302 let arg_names: Vec<String> = args.iter().map(|a| a.name.to_lower_camel_case()).collect();
303 let mut candidate = String::from("cursor");
304 while arg_names.iter().any(|name| name == &candidate) {
305 candidate.push('_');
306 }
307 candidate
308}
309
310fn generate_response_decode(
313 w: &mut CodeWriter<&mut String>,
314 method: &MethodDescriptor,
315 cursor_var: &str,
316 response_var: &str,
317) {
318 let ret_type = swift_type_client_return(method.return_shape);
319 let result_disc_var = format!("_{cursor_var}_resultDisc");
320 let error_code_var = format!("_{cursor_var}_errorCode");
321 let is_fallible = matches!(
322 classify_shape(method.return_shape),
323 ShapeKind::Result { .. }
324 );
325
326 cw_writeln!(w, "var {cursor_var} = 0").unwrap();
327 cw_writeln!(
328 w,
329 "let {result_disc_var} = try decodeVarint(from: {response_var}, offset: &{cursor_var})"
330 )
331 .unwrap();
332 cw_writeln!(w, "switch {result_disc_var} {{").unwrap();
333
334 w.writeln("case 0:").unwrap();
335 {
336 let _indent = w.indent();
337 if is_fallible {
338 let ShapeKind::Result { ok, .. } = classify_shape(method.return_shape) else {
339 unreachable!()
340 };
341 let decode_ok =
342 generate_decode_stmt_from_with_cursor(ok, "value", "", response_var, cursor_var);
343 for line in decode_ok.lines() {
344 w.writeln(line).unwrap();
345 }
346 w.writeln("return .success(value)").unwrap();
347 } else if ret_type == "Void" {
348 w.writeln("return").unwrap();
349 } else {
350 let decode_stmt = generate_decode_stmt_from_with_cursor(
351 method.return_shape,
352 "result",
353 "",
354 response_var,
355 cursor_var,
356 );
357 for line in decode_stmt.lines() {
358 w.writeln(line).unwrap();
359 }
360 w.writeln("return result").unwrap();
361 }
362 }
363
364 w.writeln("case 1:").unwrap();
365 {
366 let _indent = w.indent();
367 cw_writeln!(
368 w,
369 "let {error_code_var} = try decodeU8(from: {response_var}, offset: &{cursor_var})"
370 )
371 .unwrap();
372 cw_writeln!(w, "switch {error_code_var} {{").unwrap();
373
374 w.writeln("case 0:").unwrap();
375 {
376 let _indent = w.indent();
377 if is_fallible {
378 let ShapeKind::Result { err, .. } = classify_shape(method.return_shape) else {
379 unreachable!()
380 };
381 let decode_err = generate_decode_stmt_from_with_cursor(
382 err,
383 "userError",
384 "",
385 response_var,
386 cursor_var,
387 );
388 for line in decode_err.lines() {
389 w.writeln(line).unwrap();
390 }
391 w.writeln("return .failure(userError)").unwrap();
392 } else {
393 w.writeln(
394 "throw VoxError.decodeError(\"unexpected user error for infallible method\")",
395 )
396 .unwrap();
397 }
398 }
399 w.writeln("case 1:").unwrap();
400 w.writeln(" throw VoxError.unknownMethod").unwrap();
401 w.writeln("case 2:").unwrap();
402 w.writeln(" throw VoxError.decodeError(\"invalid payload\")")
403 .unwrap();
404 w.writeln("case 3:").unwrap();
405 w.writeln(" throw VoxError.cancelled").unwrap();
406 w.writeln("case 4:").unwrap();
407 w.writeln(" throw VoxError.indeterminate").unwrap();
408 w.writeln("default:").unwrap();
409 cw_writeln!(
410 w,
411 " throw VoxError.decodeError(\"invalid VoxError discriminant: \\({error_code_var})\")"
412 )
413 .unwrap();
414 w.writeln("}").unwrap();
415 }
416
417 w.writeln("default:").unwrap();
418 cw_writeln!(
419 w,
420 " throw VoxError.decodeError(\"invalid Result discriminant: \\({result_disc_var})\")"
421 )
422 .unwrap();
423 w.writeln("}").unwrap();
424}