Skip to main content

vox_codegen/targets/swift/
client.rs

1//! Swift client generation.
2//!
3//! Generates the caller protocol and client implementation for making
4//! application-level RPC calls. Each generated client method represents one
5//! logical call; the runtime may realize that call with one request attempt or
6//! multiple request attempts if retry/session recovery creates later attempts
7//! for the same operation.
8
9use heck::{ToLowerCamelCase, ToUpperCamelCase};
10use vox_types::{MethodDescriptor, ServiceDescriptor, ShapeKind, classify_shape};
11
12use super::decode::generate_decode_stmt_with_buf;
13use super::encode::generate_encode_stmt;
14use super::types::{format_doc, is_channel, swift_type_client_arg, swift_type_client_return};
15use crate::code_writer::CodeWriter;
16use crate::cw_writeln;
17use crate::render::hex_u64;
18
19fn swift_retry_policy_literal(method: &MethodDescriptor) -> &'static str {
20    match (method.retry.persist, method.retry.idem) {
21        (false, false) => ".volatile",
22        (false, true) => ".idem",
23        (true, false) => ".persist",
24        (true, true) => ".persistIdem",
25    }
26}
27
28/// Generate complete client code (caller protocol + client implementation).
29///
30/// The generated API speaks in terms of application-level calls, while the
31/// runtime beneath it sends wire-level request attempts and receives responses.
32pub fn generate_client(service: &ServiceDescriptor) -> String {
33    let mut out = String::new();
34    out.push_str(&generate_caller_protocol(service));
35    out.push_str(&generate_client_impl(service));
36    out
37}
38
39/// Generate caller protocol for making application-level service calls.
40fn generate_caller_protocol(service: &ServiceDescriptor) -> String {
41    let mut out = String::new();
42    let service_name = service.service_name.to_upper_camel_case();
43
44    if let Some(doc) = &service.doc {
45        out.push_str(&format_doc(doc, ""));
46    }
47    out.push_str(&format!("public protocol {service_name}Caller {{\n"));
48
49    for method in service.methods {
50        let method_name = method.method_name.to_lower_camel_case();
51
52        if let Some(doc) = &method.doc {
53            out.push_str(&format_doc(doc, "    "));
54        }
55
56        let args: Vec<String> = method
57            .args
58            .iter()
59            .map(|a| {
60                format!(
61                    "{}: {}",
62                    a.name.to_lower_camel_case(),
63                    swift_type_client_arg(a.shape)
64                )
65            })
66            .collect();
67
68        let ret_type = swift_type_client_return(method.return_shape);
69
70        if ret_type == "Void" {
71            out.push_str(&format!(
72                "    func {method_name}({}) async throws\n",
73                args.join(", ")
74            ));
75        } else {
76            out.push_str(&format!(
77                "    func {method_name}({}) async throws -> {ret_type}\n",
78                args.join(", ")
79            ));
80        }
81    }
82
83    out.push_str("}\n\n");
84    out
85}
86
87/// Generate client implementation for making application-level service calls.
88fn generate_client_impl(service: &ServiceDescriptor) -> String {
89    let mut out = String::new();
90    let mut w = CodeWriter::with_indent_spaces(&mut out, 4);
91    let service_name = service.service_name.to_upper_camel_case();
92
93    w.writeln(&format!(
94        "public final class {service_name}Client: {service_name}Caller, Sendable {{"
95    ))
96    .unwrap();
97    {
98        let _indent = w.indent();
99        w.writeln("private let connection: VoxConnection").unwrap();
100        w.writeln("private let timeout: TimeInterval?").unwrap();
101        w.blank_line().unwrap();
102        w.writeln("public init(connection: VoxConnection, timeout: TimeInterval? = 30.0) {")
103            .unwrap();
104        {
105            let _indent = w.indent();
106            w.writeln("self.connection = connection").unwrap();
107            w.writeln("self.timeout = timeout").unwrap();
108        }
109        w.writeln("}").unwrap();
110
111        for method in service.methods {
112            w.blank_line().unwrap();
113            generate_client_method(&mut w, method, &service_name);
114        }
115    }
116    w.writeln("}").unwrap();
117    w.blank_line().unwrap();
118
119    out
120}
121
122/// Generate a single client method implementation.
123///
124/// One generated method corresponds to one logical call. At runtime, the
125/// underlying connection sends one request attempt immediately, and may later
126/// send additional request attempts for the same logical operation if retry or
127/// session recovery requires it.
128fn generate_client_method(
129    w: &mut CodeWriter<&mut String>,
130    method: &MethodDescriptor,
131    service_name: &str,
132) {
133    let method_name = method.method_name.to_lower_camel_case();
134    let method_id_name = method.method_name.to_lower_camel_case();
135
136    let args: Vec<String> = method
137        .args
138        .iter()
139        .map(|a| {
140            format!(
141                "{}: {}",
142                a.name.to_lower_camel_case(),
143                swift_type_client_arg(a.shape)
144            )
145        })
146        .collect();
147
148    let ret_type = swift_type_client_return(method.return_shape);
149    let has_streaming = method.args.iter().any(|a| is_channel(a.shape));
150    let retry_policy = swift_retry_policy_literal(method);
151
152    // Method signature
153    if ret_type == "Void" {
154        cw_writeln!(
155            w,
156            "public func {method_name}({}) async throws {{",
157            args.join(", ")
158        )
159        .unwrap();
160    } else {
161        cw_writeln!(
162            w,
163            "public func {method_name}({}) async throws -> {ret_type} {{",
164            args.join(", ")
165        )
166        .unwrap();
167    }
168
169    {
170        let _indent = w.indent();
171        let cursor_var = unique_decode_cursor_name(method.args);
172
173        let service_name_lower = service_name.to_lower_camel_case();
174        let method_id = crate::method_id(method);
175
176        if has_streaming {
177            generate_streaming_client_body(
178                w,
179                method,
180                service_name,
181                &method_id_name,
182                &cursor_var,
183                retry_policy,
184            );
185        } else {
186            // Encode arguments
187            generate_encode_args(w, method.args);
188
189            // Build schema info for the request
190            cw_writeln!(
191                w,
192                "let schemaInfo = ClientSchemaInfo(methodInfo: {service_name_lower}_method_schemas[{}]!, schemaRegistry: {service_name_lower}_schema_registry)",
193                hex_u64(method_id)
194            )
195            .unwrap();
196
197            // Start the first request attempt for this logical call.
198            cw_writeln!(
199                w,
200                "let response = try await connection.call(methodId: {}, metadata: [], payload: payload, retry: {retry_policy}, timeout: timeout, prepareRetry: nil, finalizeChannels: nil, schemaInfo: schemaInfo)",
201                hex_u64(method_id),
202            )
203            .unwrap();
204            generate_response_decode(w, method, &cursor_var, "response");
205        }
206    }
207    w.writeln("}").unwrap();
208}
209
210/// Generate code to encode method arguments (for client).
211fn generate_encode_args(w: &mut CodeWriter<&mut String>, args: &[vox_types::ArgDescriptor]) {
212    if args.is_empty() {
213        w.writeln("let payload: [UInt8] = []").unwrap();
214        return;
215    }
216
217    w.writeln("var buffer = ByteBufferAllocator().buffer(capacity: 64)")
218        .unwrap();
219    for arg in args {
220        let arg_name = arg.name.to_lower_camel_case();
221        let stmt = generate_encode_stmt(arg.shape, &arg_name);
222        for line in stmt.lines() {
223            w.writeln(line).unwrap();
224        }
225    }
226    w.writeln("let payload = buffer.readBytes(length: buffer.readableBytes) ?? []")
227        .unwrap();
228}
229
230/// Generate client body for channel-bearing methods.
231///
232/// These methods still represent one logical call at the API level, but the
233/// request payload and channel bindings may need to be rebuilt for later
234/// request attempts if retry/session recovery triggers another attempt for the
235/// same operation.
236fn generate_streaming_client_body(
237    w: &mut CodeWriter<&mut String>,
238    method: &MethodDescriptor,
239    service_name: &str,
240    method_id_name: &str,
241    cursor_var: &str,
242    retry_policy: &str,
243) {
244    let service_name_lower = service_name.to_lower_camel_case();
245
246    let arg_names: Vec<String> = method
247        .args
248        .iter()
249        .map(|a| a.name.to_lower_camel_case())
250        .collect();
251
252    let method_id = crate::method_id(method);
253
254    // Build schema info for the request
255    cw_writeln!(
256        w,
257        "let schemaInfo = ClientSchemaInfo(methodInfo: {service_name_lower}_method_schemas[{}]!, schemaRegistry: {service_name_lower}_schema_registry)",
258        hex_u64(method_id)
259    )
260    .unwrap();
261
262    w.writeln("let prepareRetry: @Sendable () async -> PreparedRetryRequest = { [connection] in")
263        .unwrap();
264    {
265        let _indent = w.indent();
266        w.writeln("await bindChannels(").unwrap();
267        {
268            let _indent = w.indent();
269            cw_writeln!(
270                w,
271                "schemas: {service_name_lower}_schemas[\"{method_id_name}\"]!.args,"
272            )
273            .unwrap();
274            cw_writeln!(w, "args: [{}],", arg_names.join(", ")).unwrap();
275            w.writeln("allocator: connection.channelAllocator,")
276                .unwrap();
277            w.writeln("incomingRegistry: connection.incomingChannelRegistry,")
278                .unwrap();
279            w.writeln("taskSender: connection.taskSender,").unwrap();
280            cw_writeln!(w, "serializers: {service_name}Serializers()").unwrap();
281        }
282        w.writeln(")").unwrap();
283        w.blank_line().unwrap();
284        generate_encode_args(w, method.args);
285        w.writeln("return PreparedRetryRequest(payload: payload)")
286            .unwrap();
287    }
288    w.writeln("}").unwrap();
289    w.writeln("let prepared = await prepareRetry()").unwrap();
290    w.blank_line().unwrap();
291
292    // Start the first request attempt for this logical call.
293    let ret_type = swift_type_client_return(method.return_shape);
294    let _ = ret_type;
295    cw_writeln!(
296        w,
297        "let response = try await connection.call(methodId: {}, metadata: [], payload: prepared.payload, retry: {retry_policy}, timeout: timeout, prepareRetry: prepareRetry, finalizeChannels: {{ finalizeBoundChannels(schemas: {service_name_lower}_schemas[\"{method_id_name}\"]!.args, args: [{}]) }}, schemaInfo: schemaInfo)",
298        hex_u64(method_id),
299        arg_names.join(", ")
300    )
301    .unwrap();
302    generate_response_decode(w, method, cursor_var, "response");
303}
304
305fn unique_decode_cursor_name(args: &[vox_types::ArgDescriptor]) -> String {
306    let arg_names: Vec<String> = args.iter().map(|a| a.name.to_lower_camel_case()).collect();
307    let mut candidate = String::from("cursor");
308    while arg_names.iter().any(|name| name == &candidate) {
309        candidate.push('_');
310    }
311    candidate
312}
313
314/// Generate code to decode the wire response payload for one request attempt
315/// by delegating to a VoxRuntime helper (`decodeInfallibleResponse` or
316/// `decodeFallibleResponse`).
317fn generate_response_decode(
318    w: &mut CodeWriter<&mut String>,
319    method: &MethodDescriptor,
320    _cursor_var: &str,
321    response_var: &str,
322) {
323    let ret_type = swift_type_client_return(method.return_shape);
324
325    if let ShapeKind::Result { ok, err } = classify_shape(method.return_shape) {
326        // Case 1: Fallible – delegate to decodeFallibleResponse
327        let decode_ok = generate_decode_stmt_with_buf(ok, "value", "", "buf");
328        let decode_err = generate_decode_stmt_with_buf(err, "userError", "", "buf");
329
330        cw_writeln!(w, "return try decodeFallibleResponse({response_var},").unwrap();
331        {
332            let _indent = w.indent();
333            w.writeln("decodeOk: { buf in").unwrap();
334            {
335                let _indent = w.indent();
336                for line in decode_ok.lines() {
337                    w.writeln(line).unwrap();
338                }
339                w.writeln("return value").unwrap();
340            }
341            w.writeln("},").unwrap();
342            w.writeln("decodeErr: { buf in").unwrap();
343            {
344                let _indent = w.indent();
345                for line in decode_err.lines() {
346                    w.writeln(line).unwrap();
347                }
348                w.writeln("return userError").unwrap();
349            }
350            w.writeln("})").unwrap();
351        }
352    } else if ret_type == "Void" {
353        // Case 2: Infallible Void
354        cw_writeln!(w, "try decodeInfallibleResponse({response_var}) {{ _ in }}").unwrap();
355    } else {
356        // Case 3: Infallible non-Void
357        let decode_stmt = generate_decode_stmt_with_buf(method.return_shape, "result", "", "buf");
358        cw_writeln!(
359            w,
360            "return try decodeInfallibleResponse({response_var}) {{ buf in"
361        )
362        .unwrap();
363        {
364            let _indent = w.indent();
365            for line in decode_stmt.lines() {
366                w.writeln(line).unwrap();
367            }
368            w.writeln("return result").unwrap();
369        }
370        w.writeln("}").unwrap();
371    }
372}