Skip to main content

vox_codegen/targets/swift/
client.rs

1//! Swift client generation.
2//!
3//! Generates the caller protocol and client implementation for making
4//! application-level RPC calls. Each generated client method represents one
5//! logical call; the runtime may realize that call with one request attempt or
6//! multiple request attempts if retry/session recovery creates later attempts
7//! for the same operation.
8
9use heck::{ToLowerCamelCase, ToUpperCamelCase};
10use vox_types::{MethodDescriptor, ServiceDescriptor, ShapeKind, classify_shape};
11
12use super::decode::generate_decode_stmt_with_buf;
13use super::encode::generate_encode_stmt;
14use super::types::{format_doc, is_channel, swift_type_client_arg, swift_type_client_return};
15use crate::code_writer::CodeWriter;
16use crate::cw_writeln;
17use crate::render::hex_u64;
18
19fn swift_retry_policy_literal(method: &MethodDescriptor) -> &'static str {
20    match (method.retry.persist, method.retry.idem) {
21        (false, false) => ".volatile",
22        (false, true) => ".idem",
23        (true, false) => ".persist",
24        (true, true) => ".persistIdem",
25    }
26}
27
28/// Generate complete client code (caller protocol + client implementation).
29///
30/// The generated API speaks in terms of application-level calls, while the
31/// runtime beneath it sends wire-level request attempts and receives responses.
32pub fn generate_client(service: &ServiceDescriptor) -> String {
33    let mut out = String::new();
34    out.push_str(&generate_caller_protocol(service));
35    out.push_str(&generate_client_impl(service));
36    out
37}
38
39/// Generate caller protocol for making application-level service calls.
40fn generate_caller_protocol(service: &ServiceDescriptor) -> String {
41    let mut out = String::new();
42    let service_name = service.service_name.to_upper_camel_case();
43
44    if let Some(doc) = &service.doc {
45        out.push_str(&format_doc(doc, ""));
46    }
47    out.push_str(&format!("public protocol {service_name}Caller {{\n"));
48
49    for method in service.methods {
50        let method_name = method.method_name.to_lower_camel_case();
51
52        if let Some(doc) = &method.doc {
53            out.push_str(&format_doc(doc, "    "));
54        }
55
56        let args: Vec<String> = method
57            .args
58            .iter()
59            .map(|a| {
60                format!(
61                    "{}: {}",
62                    a.name.to_lower_camel_case(),
63                    swift_type_client_arg(a.shape)
64                )
65            })
66            .collect();
67
68        let ret_type = swift_type_client_return(method.return_shape);
69
70        if ret_type == "Void" {
71            out.push_str(&format!(
72                "    func {method_name}({}) async throws\n",
73                args.join(", ")
74            ));
75        } else {
76            out.push_str(&format!(
77                "    func {method_name}({}) async throws -> {ret_type}\n",
78                args.join(", ")
79            ));
80        }
81    }
82
83    out.push_str("}\n\n");
84    out
85}
86
87/// Generate client implementation for making application-level service calls.
88fn generate_client_impl(service: &ServiceDescriptor) -> String {
89    let mut out = String::new();
90    let mut w = CodeWriter::with_indent_spaces(&mut out, 4);
91    let service_name = service.service_name.to_upper_camel_case();
92
93    w.writeln(&format!(
94        "public final class {service_name}Client: {service_name}Caller, Sendable {{"
95    ))
96    .unwrap();
97    {
98        let _indent = w.indent();
99        w.writeln("private let connection: VoxConnection").unwrap();
100        w.writeln("private let timeout: TimeInterval?").unwrap();
101        w.blank_line().unwrap();
102        w.writeln("public init(connection: VoxConnection, timeout: TimeInterval? = 30.0) {")
103            .unwrap();
104        {
105            let _indent = w.indent();
106            w.writeln("self.connection = connection").unwrap();
107            w.writeln("self.timeout = timeout").unwrap();
108        }
109        w.writeln("}").unwrap();
110
111        for method in service.methods {
112            w.blank_line().unwrap();
113            generate_client_method(&mut w, method, &service_name);
114        }
115    }
116    w.writeln("}").unwrap();
117    w.blank_line().unwrap();
118
119    // ExpectedRootClient conformance — lets `Session.initiator(...,
120    // expecting: <Service>Client.self, ...)` inject the right
121    // `vox-service` metadata so the server's per-connection router
122    // dispatches to this service.
123    w.writeln(&format!(
124        "extension {service_name}Client: ExpectedRootClient {{"
125    ))
126    .unwrap();
127    {
128        let _indent = w.indent();
129        w.writeln(&format!(
130            "public static let voxServiceName = \"{}\"",
131            service.service_name
132        ))
133        .unwrap();
134    }
135    w.writeln("}").unwrap();
136    w.blank_line().unwrap();
137
138    out
139}
140
141/// Generate a single client method implementation.
142///
143/// One generated method corresponds to one logical call. At runtime, the
144/// underlying connection sends one request attempt immediately, and may later
145/// send additional request attempts for the same logical operation if retry or
146/// session recovery requires it.
147fn generate_client_method(
148    w: &mut CodeWriter<&mut String>,
149    method: &MethodDescriptor,
150    service_name: &str,
151) {
152    let method_name = method.method_name.to_lower_camel_case();
153    let method_id_name = method.method_name.to_lower_camel_case();
154
155    let args: Vec<String> = method
156        .args
157        .iter()
158        .map(|a| {
159            format!(
160                "{}: {}",
161                a.name.to_lower_camel_case(),
162                swift_type_client_arg(a.shape)
163            )
164        })
165        .collect();
166
167    let ret_type = swift_type_client_return(method.return_shape);
168    let has_streaming = method.args.iter().any(|a| is_channel(a.shape));
169    let retry_policy = swift_retry_policy_literal(method);
170
171    // Method signature
172    if ret_type == "Void" {
173        cw_writeln!(
174            w,
175            "public func {method_name}({}) async throws {{",
176            args.join(", ")
177        )
178        .unwrap();
179    } else {
180        cw_writeln!(
181            w,
182            "public func {method_name}({}) async throws -> {ret_type} {{",
183            args.join(", ")
184        )
185        .unwrap();
186    }
187
188    {
189        let _indent = w.indent();
190        let cursor_var = unique_decode_cursor_name(method.args);
191
192        let service_name_lower = service_name.to_lower_camel_case();
193        let method_id = crate::method_id(method);
194
195        if has_streaming {
196            generate_streaming_client_body(
197                w,
198                method,
199                service_name,
200                &method_id_name,
201                &cursor_var,
202                retry_policy,
203            );
204        } else {
205            // Encode arguments
206            generate_encode_args(w, method.args);
207
208            // Build schema info for the request
209            cw_writeln!(
210                w,
211                "let schemaInfo = ClientSchemaInfo(methodInfo: {service_name_lower}_method_schemas[{}]!, schemaRegistry: {service_name_lower}_schema_registry)",
212                hex_u64(method_id)
213            )
214            .unwrap();
215
216            // Start the first request attempt for this logical call.
217            cw_writeln!(
218                w,
219                "let response = try await connection.call(methodId: {}, metadata: [], payload: payload, retry: {retry_policy}, timeout: timeout, prepareRetry: nil, finalizeChannels: nil, schemaInfo: schemaInfo)",
220                hex_u64(method_id),
221            )
222            .unwrap();
223            generate_response_decode(w, method, &cursor_var, "response");
224        }
225    }
226    w.writeln("}").unwrap();
227}
228
229/// Generate code to encode method arguments (for client).
230fn generate_encode_args(w: &mut CodeWriter<&mut String>, args: &[vox_types::ArgDescriptor]) {
231    if args.is_empty() {
232        w.writeln("let payload: [UInt8] = []").unwrap();
233        return;
234    }
235
236    w.writeln("var buffer = ByteBufferAllocator().buffer(capacity: 64)")
237        .unwrap();
238    for arg in args {
239        let arg_name = arg.name.to_lower_camel_case();
240        let stmt = generate_encode_stmt(arg.shape, &arg_name);
241        for line in stmt.lines() {
242            w.writeln(line).unwrap();
243        }
244    }
245    w.writeln("let payload = buffer.readBytes(length: buffer.readableBytes) ?? []")
246        .unwrap();
247}
248
249/// Generate client body for channel-bearing methods.
250///
251/// These methods still represent one logical call at the API level, but the
252/// request payload and channel bindings may need to be rebuilt for later
253/// request attempts if retry/session recovery triggers another attempt for the
254/// same operation.
255fn generate_streaming_client_body(
256    w: &mut CodeWriter<&mut String>,
257    method: &MethodDescriptor,
258    service_name: &str,
259    _method_id_name: &str,
260    cursor_var: &str,
261    retry_policy: &str,
262) {
263    let service_name_lower = service_name.to_lower_camel_case();
264
265    let arg_names: Vec<String> = method
266        .args
267        .iter()
268        .map(|a| a.name.to_lower_camel_case())
269        .collect();
270
271    let method_id = crate::method_id(method);
272
273    // Build schema info for the request
274    cw_writeln!(
275        w,
276        "let schemaInfo = ClientSchemaInfo(methodInfo: {service_name_lower}_method_schemas[{}]!, schemaRegistry: {service_name_lower}_schema_registry)",
277        hex_u64(method_id)
278    )
279    .unwrap();
280
281    w.writeln("let prepareRetry: @Sendable () async -> PreparedRetryRequest = { [connection] in")
282        .unwrap();
283    {
284        let _indent = w.indent();
285        w.writeln("await bindChannels(").unwrap();
286        {
287            let _indent = w.indent();
288            cw_writeln!(
289                w,
290                "argsRoot: {service_name_lower}_method_schemas[{}]!.argsRoot,",
291                hex_u64(method_id)
292            )
293            .unwrap();
294            cw_writeln!(w, "schemaRegistry: {service_name_lower}_schema_registry,").unwrap();
295            cw_writeln!(w, "args: [{}],", arg_names.join(", ")).unwrap();
296            w.writeln("allocator: connection.channelAllocator,")
297                .unwrap();
298            w.writeln("incomingRegistry: connection.incomingChannelRegistry,")
299                .unwrap();
300            w.writeln("taskSender: connection.taskSender,").unwrap();
301        }
302        w.writeln(")").unwrap();
303        w.blank_line().unwrap();
304        generate_encode_args(w, method.args);
305        w.writeln("return PreparedRetryRequest(payload: payload)")
306            .unwrap();
307    }
308    w.writeln("}").unwrap();
309    w.writeln("let prepared = await prepareRetry()").unwrap();
310    w.blank_line().unwrap();
311
312    // Start the first request attempt for this logical call.
313    let ret_type = swift_type_client_return(method.return_shape);
314    let _ = ret_type;
315    cw_writeln!(
316        w,
317        "let response = try await connection.call(methodId: {}, metadata: [], payload: prepared.payload, retry: {retry_policy}, timeout: timeout, prepareRetry: prepareRetry, finalizeChannels: {{ finalizeBoundChannels(argsRoot: {service_name_lower}_method_schemas[{}]!.argsRoot, schemaRegistry: {service_name_lower}_schema_registry, args: [{}]) }}, schemaInfo: schemaInfo)",
318        hex_u64(method_id),
319        hex_u64(method_id),
320        arg_names.join(", ")
321    )
322    .unwrap();
323    generate_response_decode(w, method, cursor_var, "response");
324}
325
326fn unique_decode_cursor_name(args: &[vox_types::ArgDescriptor]) -> String {
327    let arg_names: Vec<String> = args.iter().map(|a| a.name.to_lower_camel_case()).collect();
328    let mut candidate = String::from("cursor");
329    while arg_names.iter().any(|name| name == &candidate) {
330        candidate.push('_');
331    }
332    candidate
333}
334
335/// Generate code to decode the wire response payload for one request attempt
336/// by delegating to a VoxRuntime helper (`decodeInfallibleResponse` or
337/// `decodeFallibleResponse`).
338fn generate_response_decode(
339    w: &mut CodeWriter<&mut String>,
340    method: &MethodDescriptor,
341    _cursor_var: &str,
342    response_var: &str,
343) {
344    let ret_type = swift_type_client_return(method.return_shape);
345
346    if let ShapeKind::Result { ok, err } = classify_shape(method.return_shape) {
347        // Case 1: Fallible – delegate to decodeFallibleResponse
348        let decode_ok = generate_decode_stmt_with_buf(ok, "value", "", "buf");
349        let decode_err = generate_decode_stmt_with_buf(err, "userError", "", "buf");
350
351        cw_writeln!(w, "return try decodeFallibleResponse({response_var},").unwrap();
352        {
353            let _indent = w.indent();
354            w.writeln("decodeOk: { buf in").unwrap();
355            {
356                let _indent = w.indent();
357                for line in decode_ok.lines() {
358                    w.writeln(line).unwrap();
359                }
360                w.writeln("return value").unwrap();
361            }
362            w.writeln("},").unwrap();
363            w.writeln("decodeErr: { buf in").unwrap();
364            {
365                let _indent = w.indent();
366                for line in decode_err.lines() {
367                    w.writeln(line).unwrap();
368                }
369                w.writeln("return userError").unwrap();
370            }
371            w.writeln("})").unwrap();
372        }
373    } else if ret_type == "Void" {
374        // Case 2: Infallible Void
375        cw_writeln!(w, "try decodeInfallibleResponse({response_var}) {{ _ in }}").unwrap();
376    } else {
377        // Case 3: Infallible non-Void
378        let decode_stmt = generate_decode_stmt_with_buf(method.return_shape, "result", "", "buf");
379        cw_writeln!(
380            w,
381            "return try decodeInfallibleResponse({response_var}) {{ buf in"
382        )
383        .unwrap();
384        {
385            let _indent = w.indent();
386            for line in decode_stmt.lines() {
387                w.writeln(line).unwrap();
388            }
389            w.writeln("return result").unwrap();
390        }
391        w.writeln("}").unwrap();
392    }
393}