vox_codegen/targets/swift/
client.rs1use heck::{ToLowerCamelCase, ToUpperCamelCase};
10use vox_types::{MethodDescriptor, ServiceDescriptor, ShapeKind, classify_shape};
11
12use super::decode::generate_decode_stmt_from_with_cursor;
13use super::encode::generate_encode_expr;
14use super::types::{format_doc, is_channel, swift_type_client_arg, swift_type_client_return};
15use crate::code_writer::CodeWriter;
16use crate::cw_writeln;
17use crate::render::hex_u64;
18
19fn swift_retry_policy_literal(method: &MethodDescriptor) -> &'static str {
20 match (method.retry.persist, method.retry.idem) {
21 (false, false) => ".volatile",
22 (false, true) => ".idem",
23 (true, false) => ".persist",
24 (true, true) => ".persistIdem",
25 }
26}
27
28pub fn generate_client(service: &ServiceDescriptor) -> String {
33 let mut out = String::new();
34 out.push_str(&generate_caller_protocol(service));
35 out.push_str(&generate_client_impl(service));
36 out
37}
38
39fn generate_caller_protocol(service: &ServiceDescriptor) -> String {
41 let mut out = String::new();
42 let service_name = service.service_name.to_upper_camel_case();
43
44 if let Some(doc) = &service.doc {
45 out.push_str(&format_doc(doc, ""));
46 }
47 out.push_str(&format!("public protocol {service_name}Caller {{\n"));
48
49 for method in service.methods {
50 let method_name = method.method_name.to_lower_camel_case();
51
52 if let Some(doc) = &method.doc {
53 out.push_str(&format_doc(doc, " "));
54 }
55
56 let args: Vec<String> = method
57 .args
58 .iter()
59 .map(|a| {
60 format!(
61 "{}: {}",
62 a.name.to_lower_camel_case(),
63 swift_type_client_arg(a.shape)
64 )
65 })
66 .collect();
67
68 let ret_type = swift_type_client_return(method.return_shape);
69
70 if ret_type == "Void" {
71 out.push_str(&format!(
72 " func {method_name}({}) async throws\n",
73 args.join(", ")
74 ));
75 } else {
76 out.push_str(&format!(
77 " func {method_name}({}) async throws -> {ret_type}\n",
78 args.join(", ")
79 ));
80 }
81 }
82
83 out.push_str("}\n\n");
84 out
85}
86
87fn generate_client_impl(service: &ServiceDescriptor) -> String {
89 let mut out = String::new();
90 let mut w = CodeWriter::with_indent_spaces(&mut out, 4);
91 let service_name = service.service_name.to_upper_camel_case();
92
93 w.writeln(&format!(
94 "public final class {service_name}Client: {service_name}Caller, Sendable {{"
95 ))
96 .unwrap();
97 {
98 let _indent = w.indent();
99 w.writeln("private let connection: VoxConnection")
100 .unwrap();
101 w.writeln("private let timeout: TimeInterval?").unwrap();
102 w.blank_line().unwrap();
103 w.writeln("public init(connection: VoxConnection, timeout: TimeInterval? = 30.0) {")
104 .unwrap();
105 {
106 let _indent = w.indent();
107 w.writeln("self.connection = connection").unwrap();
108 w.writeln("self.timeout = timeout").unwrap();
109 }
110 w.writeln("}").unwrap();
111
112 for method in service.methods {
113 w.blank_line().unwrap();
114 generate_client_method(&mut w, method, &service_name);
115 }
116 }
117 w.writeln("}").unwrap();
118 w.blank_line().unwrap();
119
120 out
121}
122
123fn generate_client_method(
130 w: &mut CodeWriter<&mut String>,
131 method: &MethodDescriptor,
132 service_name: &str,
133) {
134 let method_name = method.method_name.to_lower_camel_case();
135 let method_id_name = method.method_name.to_lower_camel_case();
136
137 let args: Vec<String> = method
138 .args
139 .iter()
140 .map(|a| {
141 format!(
142 "{}: {}",
143 a.name.to_lower_camel_case(),
144 swift_type_client_arg(a.shape)
145 )
146 })
147 .collect();
148
149 let ret_type = swift_type_client_return(method.return_shape);
150 let has_streaming = method.args.iter().any(|a| is_channel(a.shape));
151 let retry_policy = swift_retry_policy_literal(method);
152
153 if ret_type == "Void" {
155 cw_writeln!(
156 w,
157 "public func {method_name}({}) async throws {{",
158 args.join(", ")
159 )
160 .unwrap();
161 } else {
162 cw_writeln!(
163 w,
164 "public func {method_name}({}) async throws -> {ret_type} {{",
165 args.join(", ")
166 )
167 .unwrap();
168 }
169
170 {
171 let _indent = w.indent();
172 let cursor_var = unique_decode_cursor_name(method.args);
173
174 let service_name_lower = service_name.to_lower_camel_case();
175 let method_id = crate::method_id(method);
176
177 if has_streaming {
178 generate_streaming_client_body(
179 w,
180 method,
181 service_name,
182 &method_id_name,
183 &cursor_var,
184 retry_policy,
185 );
186 } else {
187 generate_encode_args(w, method.args);
189
190 cw_writeln!(
192 w,
193 "let schemaInfo = ClientSchemaInfo(methodInfo: {service_name_lower}_method_schemas[{}]!, schemaRegistry: {service_name_lower}_schema_registry)",
194 hex_u64(method_id)
195 )
196 .unwrap();
197
198 cw_writeln!(
200 w,
201 "let response = try await connection.call(methodId: {}, metadata: [], payload: payload, retry: {retry_policy}, timeout: timeout, prepareRetry: nil, finalizeChannels: nil, schemaInfo: schemaInfo)",
202 hex_u64(method_id),
203 )
204 .unwrap();
205 generate_response_decode(w, method, &cursor_var, "response");
206 }
207 }
208 w.writeln("}").unwrap();
209}
210
211fn generate_encode_args(w: &mut CodeWriter<&mut String>, args: &[vox_types::ArgDescriptor]) {
213 if args.is_empty() {
214 w.writeln("let payload = Data()").unwrap();
215 return;
216 }
217
218 w.writeln("var payloadBytes: [UInt8] = []").unwrap();
219 for arg in args {
220 let arg_name = arg.name.to_lower_camel_case();
221 let encode_expr = generate_encode_expr(arg.shape, &arg_name);
222 cw_writeln!(w, "payloadBytes += {encode_expr}").unwrap();
223 }
224 w.writeln("let payload = Data(payloadBytes)").unwrap();
225}
226
227fn generate_streaming_client_body(
234 w: &mut CodeWriter<&mut String>,
235 method: &MethodDescriptor,
236 service_name: &str,
237 method_id_name: &str,
238 cursor_var: &str,
239 retry_policy: &str,
240) {
241 let service_name_lower = service_name.to_lower_camel_case();
242
243 let arg_names: Vec<String> = method
244 .args
245 .iter()
246 .map(|a| a.name.to_lower_camel_case())
247 .collect();
248
249 let method_id = crate::method_id(method);
250
251 cw_writeln!(
253 w,
254 "let schemaInfo = ClientSchemaInfo(methodInfo: {service_name_lower}_method_schemas[{}]!, schemaRegistry: {service_name_lower}_schema_registry)",
255 hex_u64(method_id)
256 )
257 .unwrap();
258
259 w.writeln("let prepareRetry: @Sendable () async -> PreparedRetryRequest = { [connection] in")
260 .unwrap();
261 {
262 let _indent = w.indent();
263 w.writeln("await bindChannels(").unwrap();
264 {
265 let _indent = w.indent();
266 cw_writeln!(
267 w,
268 "schemas: {service_name_lower}_schemas[\"{method_id_name}\"]!.args,"
269 )
270 .unwrap();
271 cw_writeln!(w, "args: [{}],", arg_names.join(", ")).unwrap();
272 w.writeln("allocator: connection.channelAllocator,")
273 .unwrap();
274 w.writeln("incomingRegistry: connection.incomingChannelRegistry,")
275 .unwrap();
276 w.writeln("taskSender: connection.taskSender,").unwrap();
277 cw_writeln!(w, "serializers: {service_name}Serializers()").unwrap();
278 }
279 w.writeln(")").unwrap();
280 w.blank_line().unwrap();
281 generate_encode_args(w, method.args);
282 w.writeln("return PreparedRetryRequest(payload: Array(payload))")
283 .unwrap();
284 }
285 w.writeln("}").unwrap();
286 w.writeln("let prepared = await prepareRetry()").unwrap();
287 w.blank_line().unwrap();
288
289 let ret_type = swift_type_client_return(method.return_shape);
291 let _ = ret_type;
292 cw_writeln!(
293 w,
294 "let response = try await connection.call(methodId: {}, metadata: [], payload: Data(prepared.payload), retry: {retry_policy}, timeout: timeout, prepareRetry: prepareRetry, finalizeChannels: {{ finalizeBoundChannels(schemas: {service_name_lower}_schemas[\"{method_id_name}\"]!.args, args: [{}]) }}, schemaInfo: schemaInfo)",
295 hex_u64(method_id),
296 arg_names.join(", ")
297 )
298 .unwrap();
299 generate_response_decode(w, method, cursor_var, "response");
300}
301
302fn unique_decode_cursor_name(args: &[vox_types::ArgDescriptor]) -> String {
303 let arg_names: Vec<String> = args.iter().map(|a| a.name.to_lower_camel_case()).collect();
304 let mut candidate = String::from("cursor");
305 while arg_names.iter().any(|name| name == &candidate) {
306 candidate.push('_');
307 }
308 candidate
309}
310
311fn generate_response_decode(
314 w: &mut CodeWriter<&mut String>,
315 method: &MethodDescriptor,
316 cursor_var: &str,
317 response_var: &str,
318) {
319 let ret_type = swift_type_client_return(method.return_shape);
320 let result_disc_var = format!("_{cursor_var}_resultDisc");
321 let error_code_var = format!("_{cursor_var}_errorCode");
322 let is_fallible = matches!(
323 classify_shape(method.return_shape),
324 ShapeKind::Result { .. }
325 );
326
327 cw_writeln!(w, "var {cursor_var} = 0").unwrap();
328 cw_writeln!(
329 w,
330 "let {result_disc_var} = try decodeVarint(from: {response_var}, offset: &{cursor_var})"
331 )
332 .unwrap();
333 cw_writeln!(w, "switch {result_disc_var} {{").unwrap();
334
335 w.writeln("case 0:").unwrap();
336 {
337 let _indent = w.indent();
338 if is_fallible {
339 let ShapeKind::Result { ok, .. } = classify_shape(method.return_shape) else {
340 unreachable!()
341 };
342 let decode_ok =
343 generate_decode_stmt_from_with_cursor(ok, "value", "", response_var, cursor_var);
344 for line in decode_ok.lines() {
345 w.writeln(line).unwrap();
346 }
347 w.writeln("return .success(value)").unwrap();
348 } else if ret_type == "Void" {
349 w.writeln("return").unwrap();
350 } else {
351 let decode_stmt = generate_decode_stmt_from_with_cursor(
352 method.return_shape,
353 "result",
354 "",
355 response_var,
356 cursor_var,
357 );
358 for line in decode_stmt.lines() {
359 w.writeln(line).unwrap();
360 }
361 w.writeln("return result").unwrap();
362 }
363 }
364
365 w.writeln("case 1:").unwrap();
366 {
367 let _indent = w.indent();
368 cw_writeln!(
369 w,
370 "let {error_code_var} = try decodeU8(from: {response_var}, offset: &{cursor_var})"
371 )
372 .unwrap();
373 cw_writeln!(w, "switch {error_code_var} {{").unwrap();
374
375 w.writeln("case 0:").unwrap();
376 {
377 let _indent = w.indent();
378 if is_fallible {
379 let ShapeKind::Result { err, .. } = classify_shape(method.return_shape) else {
380 unreachable!()
381 };
382 let decode_err = generate_decode_stmt_from_with_cursor(
383 err,
384 "userError",
385 "",
386 response_var,
387 cursor_var,
388 );
389 for line in decode_err.lines() {
390 w.writeln(line).unwrap();
391 }
392 w.writeln("return .failure(userError)").unwrap();
393 } else {
394 w.writeln(
395 "throw VoxError.decodeError(\"unexpected user error for infallible method\")",
396 )
397 .unwrap();
398 }
399 }
400 w.writeln("case 1:").unwrap();
401 w.writeln(" throw VoxError.unknownMethod").unwrap();
402 w.writeln("case 2:").unwrap();
403 w.writeln(" throw VoxError.decodeError(\"invalid payload\")")
404 .unwrap();
405 w.writeln("case 3:").unwrap();
406 w.writeln(" throw VoxError.cancelled").unwrap();
407 w.writeln("case 4:").unwrap();
408 w.writeln(" throw VoxError.indeterminate").unwrap();
409 w.writeln("default:").unwrap();
410 cw_writeln!(
411 w,
412 " throw VoxError.decodeError(\"invalid VoxError discriminant: \\({error_code_var})\")"
413 )
414 .unwrap();
415 w.writeln("}").unwrap();
416 }
417
418 w.writeln("default:").unwrap();
419 cw_writeln!(
420 w,
421 " throw VoxError.decodeError(\"invalid Result discriminant: \\({result_disc_var})\")"
422 )
423 .unwrap();
424 w.writeln("}").unwrap();
425}