Skip to main content

vox_codegen/targets/swift/
client.rs

//! Swift client generation.
//!
//! Generates the caller protocol and client implementation for making
//! application-level RPC calls. Each generated client method represents one
//! logical call; the runtime may realize that call with one request attempt or
//! multiple request attempts if retry/session recovery creates later attempts
//! for the same operation.
8
9use heck::{ToLowerCamelCase, ToUpperCamelCase};
10use vox_types::{MethodDescriptor, ServiceDescriptor, ShapeKind, classify_shape};
11
12use super::decode::generate_decode_stmt_from_with_cursor;
13use super::encode::generate_encode_expr;
14use super::types::{format_doc, is_channel, swift_type_client_arg, swift_type_client_return};
15use crate::code_writer::CodeWriter;
16use crate::cw_writeln;
17use crate::render::hex_u64;
18
19fn swift_retry_policy_literal(method: &MethodDescriptor) -> &'static str {
20    match (method.retry.persist, method.retry.idem) {
21        (false, false) => ".volatile",
22        (false, true) => ".idem",
23        (true, false) => ".persist",
24        (true, true) => ".persistIdem",
25    }
26}
27
28/// Generate complete client code (caller protocol + client implementation).
29///
30/// The generated API speaks in terms of application-level calls, while the
31/// runtime beneath it sends wire-level request attempts and receives responses.
32pub fn generate_client(service: &ServiceDescriptor) -> String {
33    let mut out = String::new();
34    out.push_str(&generate_caller_protocol(service));
35    out.push_str(&generate_client_impl(service));
36    out
37}
38
39/// Generate caller protocol for making application-level service calls.
40fn generate_caller_protocol(service: &ServiceDescriptor) -> String {
41    let mut out = String::new();
42    let service_name = service.service_name.to_upper_camel_case();
43
44    if let Some(doc) = &service.doc {
45        out.push_str(&format_doc(doc, ""));
46    }
47    out.push_str(&format!("public protocol {service_name}Caller {{\n"));
48
49    for method in service.methods {
50        let method_name = method.method_name.to_lower_camel_case();
51
52        if let Some(doc) = &method.doc {
53            out.push_str(&format_doc(doc, "    "));
54        }
55
56        let args: Vec<String> = method
57            .args
58            .iter()
59            .map(|a| {
60                format!(
61                    "{}: {}",
62                    a.name.to_lower_camel_case(),
63                    swift_type_client_arg(a.shape)
64                )
65            })
66            .collect();
67
68        let ret_type = swift_type_client_return(method.return_shape);
69
70        if ret_type == "Void" {
71            out.push_str(&format!(
72                "    func {method_name}({}) async throws\n",
73                args.join(", ")
74            ));
75        } else {
76            out.push_str(&format!(
77                "    func {method_name}({}) async throws -> {ret_type}\n",
78                args.join(", ")
79            ));
80        }
81    }
82
83    out.push_str("}\n\n");
84    out
85}
86
87/// Generate client implementation for making application-level service calls.
88fn generate_client_impl(service: &ServiceDescriptor) -> String {
89    let mut out = String::new();
90    let mut w = CodeWriter::with_indent_spaces(&mut out, 4);
91    let service_name = service.service_name.to_upper_camel_case();
92
93    w.writeln(&format!(
94        "public final class {service_name}Client: {service_name}Caller, Sendable {{"
95    ))
96    .unwrap();
97    {
98        let _indent = w.indent();
99        w.writeln("private let connection: VoxConnection")
100            .unwrap();
101        w.writeln("private let timeout: TimeInterval?").unwrap();
102        w.blank_line().unwrap();
103        w.writeln("public init(connection: VoxConnection, timeout: TimeInterval? = 30.0) {")
104            .unwrap();
105        {
106            let _indent = w.indent();
107            w.writeln("self.connection = connection").unwrap();
108            w.writeln("self.timeout = timeout").unwrap();
109        }
110        w.writeln("}").unwrap();
111
112        for method in service.methods {
113            w.blank_line().unwrap();
114            generate_client_method(&mut w, method, &service_name);
115        }
116    }
117    w.writeln("}").unwrap();
118    w.blank_line().unwrap();
119
120    out
121}
122
123/// Generate a single client method implementation.
124///
125/// One generated method corresponds to one logical call. At runtime, the
126/// underlying connection sends one request attempt immediately, and may later
127/// send additional request attempts for the same logical operation if retry or
128/// session recovery requires it.
129fn generate_client_method(
130    w: &mut CodeWriter<&mut String>,
131    method: &MethodDescriptor,
132    service_name: &str,
133) {
134    let method_name = method.method_name.to_lower_camel_case();
135    let method_id_name = method.method_name.to_lower_camel_case();
136
137    let args: Vec<String> = method
138        .args
139        .iter()
140        .map(|a| {
141            format!(
142                "{}: {}",
143                a.name.to_lower_camel_case(),
144                swift_type_client_arg(a.shape)
145            )
146        })
147        .collect();
148
149    let ret_type = swift_type_client_return(method.return_shape);
150    let has_streaming = method.args.iter().any(|a| is_channel(a.shape));
151    let retry_policy = swift_retry_policy_literal(method);
152
153    // Method signature
154    if ret_type == "Void" {
155        cw_writeln!(
156            w,
157            "public func {method_name}({}) async throws {{",
158            args.join(", ")
159        )
160        .unwrap();
161    } else {
162        cw_writeln!(
163            w,
164            "public func {method_name}({}) async throws -> {ret_type} {{",
165            args.join(", ")
166        )
167        .unwrap();
168    }
169
170    {
171        let _indent = w.indent();
172        let cursor_var = unique_decode_cursor_name(method.args);
173
174        let service_name_lower = service_name.to_lower_camel_case();
175        let method_id = crate::method_id(method);
176
177        if has_streaming {
178            generate_streaming_client_body(
179                w,
180                method,
181                service_name,
182                &method_id_name,
183                &cursor_var,
184                retry_policy,
185            );
186        } else {
187            // Encode arguments
188            generate_encode_args(w, method.args);
189
190            // Build schema info for the request
191            cw_writeln!(
192                w,
193                "let schemaInfo = ClientSchemaInfo(methodInfo: {service_name_lower}_method_schemas[{}]!, schemaRegistry: {service_name_lower}_schema_registry)",
194                hex_u64(method_id)
195            )
196            .unwrap();
197
198            // Start the first request attempt for this logical call.
199            cw_writeln!(
200                w,
201                "let response = try await connection.call(methodId: {}, metadata: [], payload: payload, retry: {retry_policy}, timeout: timeout, prepareRetry: nil, finalizeChannels: nil, schemaInfo: schemaInfo)",
202                hex_u64(method_id),
203            )
204            .unwrap();
205            generate_response_decode(w, method, &cursor_var, "response");
206        }
207    }
208    w.writeln("}").unwrap();
209}
210
211/// Generate code to encode method arguments (for client).
212fn generate_encode_args(w: &mut CodeWriter<&mut String>, args: &[vox_types::ArgDescriptor]) {
213    if args.is_empty() {
214        w.writeln("let payload = Data()").unwrap();
215        return;
216    }
217
218    w.writeln("var payloadBytes: [UInt8] = []").unwrap();
219    for arg in args {
220        let arg_name = arg.name.to_lower_camel_case();
221        let encode_expr = generate_encode_expr(arg.shape, &arg_name);
222        cw_writeln!(w, "payloadBytes += {encode_expr}").unwrap();
223    }
224    w.writeln("let payload = Data(payloadBytes)").unwrap();
225}
226
227/// Generate client body for channel-bearing methods.
228///
229/// These methods still represent one logical call at the API level, but the
230/// request payload and channel bindings may need to be rebuilt for later
231/// request attempts if retry/session recovery triggers another attempt for the
232/// same operation.
233fn generate_streaming_client_body(
234    w: &mut CodeWriter<&mut String>,
235    method: &MethodDescriptor,
236    service_name: &str,
237    method_id_name: &str,
238    cursor_var: &str,
239    retry_policy: &str,
240) {
241    let service_name_lower = service_name.to_lower_camel_case();
242
243    let arg_names: Vec<String> = method
244        .args
245        .iter()
246        .map(|a| a.name.to_lower_camel_case())
247        .collect();
248
249    let method_id = crate::method_id(method);
250
251    // Build schema info for the request
252    cw_writeln!(
253        w,
254        "let schemaInfo = ClientSchemaInfo(methodInfo: {service_name_lower}_method_schemas[{}]!, schemaRegistry: {service_name_lower}_schema_registry)",
255        hex_u64(method_id)
256    )
257    .unwrap();
258
259    w.writeln("let prepareRetry: @Sendable () async -> PreparedRetryRequest = { [connection] in")
260        .unwrap();
261    {
262        let _indent = w.indent();
263        w.writeln("await bindChannels(").unwrap();
264        {
265            let _indent = w.indent();
266            cw_writeln!(
267                w,
268                "schemas: {service_name_lower}_schemas[\"{method_id_name}\"]!.args,"
269            )
270            .unwrap();
271            cw_writeln!(w, "args: [{}],", arg_names.join(", ")).unwrap();
272            w.writeln("allocator: connection.channelAllocator,")
273                .unwrap();
274            w.writeln("incomingRegistry: connection.incomingChannelRegistry,")
275                .unwrap();
276            w.writeln("taskSender: connection.taskSender,").unwrap();
277            cw_writeln!(w, "serializers: {service_name}Serializers()").unwrap();
278        }
279        w.writeln(")").unwrap();
280        w.blank_line().unwrap();
281        generate_encode_args(w, method.args);
282        w.writeln("return PreparedRetryRequest(payload: Array(payload))")
283            .unwrap();
284    }
285    w.writeln("}").unwrap();
286    w.writeln("let prepared = await prepareRetry()").unwrap();
287    w.blank_line().unwrap();
288
289    // Start the first request attempt for this logical call.
290    let ret_type = swift_type_client_return(method.return_shape);
291    let _ = ret_type;
292    cw_writeln!(
293        w,
294        "let response = try await connection.call(methodId: {}, metadata: [], payload: Data(prepared.payload), retry: {retry_policy}, timeout: timeout, prepareRetry: prepareRetry, finalizeChannels: {{ finalizeBoundChannels(schemas: {service_name_lower}_schemas[\"{method_id_name}\"]!.args, args: [{}]) }}, schemaInfo: schemaInfo)",
295        hex_u64(method_id),
296        arg_names.join(", ")
297    )
298    .unwrap();
299    generate_response_decode(w, method, cursor_var, "response");
300}
301
302fn unique_decode_cursor_name(args: &[vox_types::ArgDescriptor]) -> String {
303    let arg_names: Vec<String> = args.iter().map(|a| a.name.to_lower_camel_case()).collect();
304    let mut candidate = String::from("cursor");
305    while arg_names.iter().any(|name| name == &candidate) {
306        candidate.push('_');
307    }
308    candidate
309}
310
311/// Generate code to decode the wire response payload for one request attempt:
312/// `Result<T, VoxError<E>>`.
313fn generate_response_decode(
314    w: &mut CodeWriter<&mut String>,
315    method: &MethodDescriptor,
316    cursor_var: &str,
317    response_var: &str,
318) {
319    let ret_type = swift_type_client_return(method.return_shape);
320    let result_disc_var = format!("_{cursor_var}_resultDisc");
321    let error_code_var = format!("_{cursor_var}_errorCode");
322    let is_fallible = matches!(
323        classify_shape(method.return_shape),
324        ShapeKind::Result { .. }
325    );
326
327    cw_writeln!(w, "var {cursor_var} = 0").unwrap();
328    cw_writeln!(
329        w,
330        "let {result_disc_var} = try decodeVarint(from: {response_var}, offset: &{cursor_var})"
331    )
332    .unwrap();
333    cw_writeln!(w, "switch {result_disc_var} {{").unwrap();
334
335    w.writeln("case 0:").unwrap();
336    {
337        let _indent = w.indent();
338        if is_fallible {
339            let ShapeKind::Result { ok, .. } = classify_shape(method.return_shape) else {
340                unreachable!()
341            };
342            let decode_ok =
343                generate_decode_stmt_from_with_cursor(ok, "value", "", response_var, cursor_var);
344            for line in decode_ok.lines() {
345                w.writeln(line).unwrap();
346            }
347            w.writeln("return .success(value)").unwrap();
348        } else if ret_type == "Void" {
349            w.writeln("return").unwrap();
350        } else {
351            let decode_stmt = generate_decode_stmt_from_with_cursor(
352                method.return_shape,
353                "result",
354                "",
355                response_var,
356                cursor_var,
357            );
358            for line in decode_stmt.lines() {
359                w.writeln(line).unwrap();
360            }
361            w.writeln("return result").unwrap();
362        }
363    }
364
365    w.writeln("case 1:").unwrap();
366    {
367        let _indent = w.indent();
368        cw_writeln!(
369            w,
370            "let {error_code_var} = try decodeU8(from: {response_var}, offset: &{cursor_var})"
371        )
372        .unwrap();
373        cw_writeln!(w, "switch {error_code_var} {{").unwrap();
374
375        w.writeln("case 0:").unwrap();
376        {
377            let _indent = w.indent();
378            if is_fallible {
379                let ShapeKind::Result { err, .. } = classify_shape(method.return_shape) else {
380                    unreachable!()
381                };
382                let decode_err = generate_decode_stmt_from_with_cursor(
383                    err,
384                    "userError",
385                    "",
386                    response_var,
387                    cursor_var,
388                );
389                for line in decode_err.lines() {
390                    w.writeln(line).unwrap();
391                }
392                w.writeln("return .failure(userError)").unwrap();
393            } else {
394                w.writeln(
395                    "throw VoxError.decodeError(\"unexpected user error for infallible method\")",
396                )
397                .unwrap();
398            }
399        }
400        w.writeln("case 1:").unwrap();
401        w.writeln("    throw VoxError.unknownMethod").unwrap();
402        w.writeln("case 2:").unwrap();
403        w.writeln("    throw VoxError.decodeError(\"invalid payload\")")
404            .unwrap();
405        w.writeln("case 3:").unwrap();
406        w.writeln("    throw VoxError.cancelled").unwrap();
407        w.writeln("case 4:").unwrap();
408        w.writeln("    throw VoxError.indeterminate").unwrap();
409        w.writeln("default:").unwrap();
410        cw_writeln!(
411            w,
412            "    throw VoxError.decodeError(\"invalid VoxError discriminant: \\({error_code_var})\")"
413        )
414        .unwrap();
415        w.writeln("}").unwrap();
416    }
417
418    w.writeln("default:").unwrap();
419    cw_writeln!(
420        w,
421        "    throw VoxError.decodeError(\"invalid Result discriminant: \\({result_disc_var})\")"
422    )
423    .unwrap();
424    w.writeln("}").unwrap();
425}