Skip to main content

vox_codegen/targets/swift/
server.rs

1//! Swift server/handler generation.
2//!
3//! Generates handler protocol and dispatcher for routing incoming calls.
4
5use facet_core::Shape;
6use heck::{ToLowerCamelCase, ToUpperCamelCase};
7use vox_types::{MethodDescriptor, ServiceDescriptor, ShapeKind, classify_shape, is_rx, is_tx};
8
9use super::decode::generate_decode_stmt_with_cursor;
10use super::encode::{generate_encode_closure, generate_encode_stmt};
11use super::types::{format_doc, is_channel, swift_type_server_arg, swift_type_server_return};
12use crate::code_writer::CodeWriter;
13use crate::cw_writeln;
14use crate::render::hex_u64;
15
16fn swift_retry_policy_literal(method: &MethodDescriptor) -> &'static str {
17    match (method.retry.persist, method.retry.idem) {
18        (false, false) => ".volatile",
19        (false, true) => ".idem",
20        (true, false) => ".persist",
21        (true, true) => ".persistIdem",
22    }
23}
24
/// Build the name of the private Swift helper that dispatches `method_name`.
///
/// The result is simply `method_name` prefixed with `dispatch_`; the input is
/// used as-is (no case conversion happens here).
fn dispatch_helper_name(method_name: &str) -> String {
    ["dispatch_", method_name].concat()
}
28
29/// Generate complete server code (handler protocol + dispatchers).
30pub fn generate_server(service: &ServiceDescriptor) -> String {
31    let mut out = String::new();
32    out.push_str(&generate_handler_protocol(service));
33    // Emit only the channel-capable dispatcher.
34    out.push_str(&generate_dispatcher(service));
35    out
36}
37
38/// Generate handler protocol (for handling incoming calls).
39fn generate_handler_protocol(service: &ServiceDescriptor) -> String {
40    let mut out = String::new();
41    let service_name = service.service_name.to_upper_camel_case();
42
43    if let Some(doc) = &service.doc {
44        out.push_str(&format_doc(doc, ""));
45    }
46    out.push_str(&format!(
47        "public protocol {service_name}Handler: Sendable {{\n"
48    ));
49
50    for method in service.methods {
51        let method_name = method.method_name.to_lower_camel_case();
52
53        if let Some(doc) = &method.doc {
54            out.push_str(&format_doc(doc, "    "));
55        }
56
57        // Server perspective
58        let args: Vec<String> = method
59            .args
60            .iter()
61            .map(|a| {
62                format!(
63                    "{}: {}",
64                    a.name.to_lower_camel_case(),
65                    swift_type_server_arg(a.shape)
66                )
67            })
68            .collect();
69
70        let ret_type = swift_type_server_return(method.return_shape);
71
72        if ret_type == "Void" {
73            out.push_str(&format!(
74                "    func {method_name}({}) async throws\n",
75                args.join(", ")
76            ));
77        } else {
78            out.push_str(&format!(
79                "    func {method_name}({}) async throws -> {ret_type}\n",
80                args.join(", ")
81            ));
82        }
83    }
84
85    out.push_str("}\n\n");
86    out
87}
88
89/// Generate dispatcher for handling incoming calls with channel support.
90fn generate_dispatcher(service: &ServiceDescriptor) -> String {
91    let mut out = String::new();
92    let mut w = CodeWriter::with_indent_spaces(&mut out, 4);
93    let service_name = service.service_name.to_upper_camel_case();
94
95    let service_name_lower = service.service_name.to_lower_camel_case();
96
97    cw_writeln!(
98        w,
99        "public final class {service_name}Dispatcher: ServiceDispatcher {{"
100    )
101    .unwrap();
102    {
103        let _indent = w.indent();
104        cw_writeln!(w, "private let handler: {service_name}Handler").unwrap();
105        cw_writeln!(w, "private let schemaRegistry: [UInt64: Schema]").unwrap();
106        cw_writeln!(w, "private let methodSchemas: [UInt64: MethodSchemaInfo]").unwrap();
107        w.blank_line().unwrap();
108
109        cw_writeln!(
110            w,
111            "public init(handler: {service_name}Handler, schemaRegistry: [UInt64: Schema] = {service_name_lower}_schema_registry, methodSchemas: [UInt64: MethodSchemaInfo] = {service_name_lower}_method_schemas) {{"
112        )
113        .unwrap();
114        {
115            let _indent = w.indent();
116            w.writeln("self.handler = handler").unwrap();
117            w.writeln("self.schemaRegistry = schemaRegistry").unwrap();
118            w.writeln("self.methodSchemas = methodSchemas").unwrap();
119        }
120        w.writeln("}").unwrap();
121        w.blank_line().unwrap();
122
123        // Main dispatch method matching ServiceDispatcher protocol
124        w.writeln(
125            "public func dispatch(methodId: UInt64, payload: [UInt8], requestId: UInt64, registry: ChannelRegistry, schemaSendTracker _: SchemaSendTracker, taskTx: @escaping @Sendable (TaskMessage) -> Void) async {",
126        )
127        .unwrap();
128        {
129            let _indent = w.indent();
130            w.writeln("var buffer = ByteBufferAllocator().buffer(capacity: payload.count)")
131                .unwrap();
132            w.writeln("buffer.writeBytes(payload)").unwrap();
133            w.writeln("let taskSender: TaskSender = taskTx").unwrap();
134            w.writeln("switch methodId {").unwrap();
135            for method in service.methods {
136                let method_name = method.method_name.to_lower_camel_case();
137                let method_id = crate::method_id(method);
138                let dispatch_name = dispatch_helper_name(&method_name);
139                cw_writeln!(w, "case {}:", hex_u64(method_id)).unwrap();
140                cw_writeln!(
141                    w,
142                    "    await {dispatch_name}(methodId: methodId, requestId: requestId, buffer: &buffer, registry: registry, taskSender: taskSender)"
143                )
144                .unwrap();
145            }
146            w.writeln("default:").unwrap();
147            w.writeln(
148                "    taskSender(.response(requestId: requestId, payload: encodeUnknownMethodError()))",
149            )
150            .unwrap();
151            w.writeln("}").unwrap();
152        }
153        w.writeln("}").unwrap();
154        w.blank_line().unwrap();
155
156        w.writeln("public func retryPolicy(methodId: UInt64) -> RetryPolicy {")
157            .unwrap();
158        {
159            let _indent = w.indent();
160            w.writeln("switch methodId {").unwrap();
161            for method in service.methods {
162                let method_id = crate::method_id(method);
163                let retry_policy = swift_retry_policy_literal(method);
164                cw_writeln!(w, "case {}:", hex_u64(method_id)).unwrap();
165                cw_writeln!(w, "    return {retry_policy}").unwrap();
166            }
167            w.writeln("default:").unwrap();
168            w.writeln("    return .volatile").unwrap();
169            w.writeln("}").unwrap();
170        }
171        w.writeln("}").unwrap();
172        w.blank_line().unwrap();
173
174        // Generate preregisterChannels method
175        generate_preregister_channels(&mut w, service);
176        w.blank_line().unwrap();
177
178        // Individual dispatch methods
179        for method in service.methods {
180            generate_channeling_dispatch_method(&mut w, method);
181            w.blank_line().unwrap();
182        }
183    }
184    w.writeln("}").unwrap();
185    w.blank_line().unwrap();
186
187    out
188}
189
190/// Generate preregisterChannels method.
191fn generate_preregister_channels(w: &mut CodeWriter<&mut String>, service: &ServiceDescriptor) {
192    w.writeln("/// Pre-register Rx channel IDs from request payloads.")
193        .unwrap();
194    w.writeln("/// Call this synchronously before spawning the dispatch task to avoid")
195        .unwrap();
196    w.writeln("/// race conditions where Data arrives before channels are registered.")
197        .unwrap();
198    w.writeln(
199        "public func preregister(methodId: UInt64, payload: [UInt8], registry: ChannelRegistry) async {",
200    )
201        .unwrap();
202    {
203        let _indent = w.indent();
204        w.writeln("var buffer = ByteBufferAllocator().buffer(capacity: payload.count)")
205            .unwrap();
206        w.writeln("buffer.writeBytes(payload)").unwrap();
207        w.writeln("switch methodId {").unwrap();
208
209        for method in service.methods {
210            let method_id = crate::method_id(method);
211            let has_channel_args = method.args.iter().any(|a| is_rx(a.shape) || is_tx(a.shape));
212
213            if has_channel_args {
214                cw_writeln!(w, "case {}:", hex_u64(method_id)).unwrap();
215                w.writeln("    do {").unwrap();
216                {
217                    let _indent = w.indent();
218                    for arg in method.args {
219                        let arg_name = arg.name.to_lower_camel_case();
220                        if is_rx(arg.shape) {
221                            cw_writeln!(
222                                w,
223                                "let {arg_name}ChannelId = try decodeVarint(from: &buffer)"
224                            )
225                            .unwrap();
226                            cw_writeln!(w, "await registry.markKnown({arg_name}ChannelId)")
227                                .unwrap();
228                        } else if is_tx(arg.shape) {
229                            w.writeln("_ = try decodeVarint(from: &buffer)").unwrap();
230                        } else {
231                            let discard_name = format!("_discard_{arg_name}");
232                            let decode_stmt = generate_decode_stmt_with_cursor(
233                                arg.shape,
234                                &discard_name,
235                                "",
236                                "buffer",
237                            );
238                            for line in decode_stmt.lines() {
239                                w.writeln(line).unwrap();
240                            }
241                        }
242                    }
243                }
244                w.writeln("    } catch {").unwrap();
245                w.writeln("        return").unwrap();
246                w.writeln("    }").unwrap();
247            }
248        }
249
250        w.writeln("default:").unwrap();
251        w.writeln("    break").unwrap();
252        w.writeln("}").unwrap();
253    }
254    w.writeln("}").unwrap();
255}
256
257/// Generate a single dispatch method.
258fn generate_channeling_dispatch_method(w: &mut CodeWriter<&mut String>, method: &MethodDescriptor) {
259    let method_name = method.method_name.to_lower_camel_case();
260    let dispatch_name = dispatch_helper_name(&method_name);
261    let has_channeling = method.args.iter().any(|a| is_channel(a.shape));
262    cw_writeln!(
263        w,
264        "private func {dispatch_name}(methodId: UInt64, requestId: UInt64, buffer: inout ByteBuffer, registry: IncomingChannelRegistry, taskSender: @escaping TaskSender) async {{"
265    )
266    .unwrap();
267    {
268        let _indent = w.indent();
269        // Build response schema payload for this method.
270        w.writeln("guard let methodInfo = methodSchemas[methodId] else {")
271            .unwrap();
272        w.writeln(
273            "    taskSender(.response(requestId: requestId, payload: encodeUnknownMethodError()))",
274        )
275        .unwrap();
276        w.writeln("    return").unwrap();
277        w.writeln("}").unwrap();
278        w.writeln(
279            "let responseSchemaPayload = methodInfo.buildPayload(direction: .response, registry: schemaRegistry)",
280        )
281            .unwrap();
282        w.writeln("do {").unwrap();
283        {
284            let _indent = w.indent();
285            for arg in method.args {
286                let arg_name = arg.name.to_lower_camel_case();
287                generate_channeling_decode_arg(w, &arg_name, arg.shape);
288            }
289            let arg_names: Vec<String> = method
290                .args
291                .iter()
292                .map(|a| {
293                    let name = a.name.to_lower_camel_case();
294                    format!("{name}: {name}")
295                })
296                .collect();
297
298            let ret_type = swift_type_server_return(method.return_shape);
299
300            w.writeln("do {").unwrap();
301            {
302                let _indent = w.indent();
303                if has_channeling {
304                    if ret_type == "Void" {
305                        cw_writeln!(
306                            w,
307                            "try await handler.{method_name}({})",
308                            arg_names.join(", ")
309                        )
310                        .unwrap();
311                    } else {
312                        cw_writeln!(
313                            w,
314                            "let result = try await handler.{method_name}({})",
315                            arg_names.join(", ")
316                        )
317                        .unwrap();
318                    }
319
320                    if ret_type == "Void" {
321                        w.writeln(
322                            "taskSender(.response(requestId: requestId, payload: encodeResultOkUnit(), methodId: methodId, schemaPayload: responseSchemaPayload))",
323                        )
324                        .unwrap();
325                    } else {
326                        let encode_closure = generate_encode_closure(method.return_shape);
327                        cw_writeln!(
328                            w,
329                            "let _encoded = encodeResultOk(result, encoder: {encode_closure})"
330                        )
331                        .unwrap();
332                        w.writeln(
333                            "taskSender(.response(requestId: requestId, payload: _encoded, methodId: methodId, schemaPayload: responseSchemaPayload))",
334                        )
335                        .unwrap();
336                    }
337                } else if ret_type == "Void" {
338                    cw_writeln!(
339                        w,
340                        "try await handler.{method_name}({})",
341                        arg_names.join(", ")
342                    )
343                    .unwrap();
344                    w.writeln(
345                        "taskSender(.response(requestId: requestId, payload: encodeResultOkUnit(), methodId: methodId, schemaPayload: responseSchemaPayload))",
346                    )
347                    .unwrap();
348                } else {
349                    cw_writeln!(
350                        w,
351                        "let result = try await handler.{method_name}({})",
352                        arg_names.join(", ")
353                    )
354                    .unwrap();
355                    if let ShapeKind::Result { ok, err } = classify_shape(method.return_shape) {
356                        let ok_stmt = generate_encode_stmt(ok, "v");
357                        let err_stmt = generate_encode_stmt(err, "e");
358                        w.writeln("let _encoded: [UInt8] = {").unwrap();
359                        {
360                            let _indent = w.indent();
361                            w.writeln("var buffer = ByteBufferAllocator().buffer(capacity: 64)")
362                                .unwrap();
363                            w.writeln("switch result {").unwrap();
364                            w.writeln("case .success(let v):").unwrap();
365                            {
366                                let _indent = w.indent();
367                                w.writeln("encodeVarint(UInt64(0), into: &buffer)").unwrap();
368                                for line in ok_stmt.lines() {
369                                    w.writeln(line).unwrap();
370                                }
371                            }
372                            w.writeln("case .failure(let e):").unwrap();
373                            {
374                                let _indent = w.indent();
375                                w.writeln("encodeVarint(UInt64(1), into: &buffer)").unwrap();
376                                w.writeln("encodeU8(0, into: &buffer)").unwrap();
377                                for line in err_stmt.lines() {
378                                    w.writeln(line).unwrap();
379                                }
380                            }
381                            w.writeln("}").unwrap();
382                            w.writeln(
383                                "return buffer.readBytes(length: buffer.readableBytes) ?? []",
384                            )
385                            .unwrap();
386                        }
387                        w.writeln("}()").unwrap();
388                        w.writeln(
389                            "taskSender(.response(requestId: requestId, payload: _encoded, methodId: methodId, schemaPayload: responseSchemaPayload))",
390                        )
391                        .unwrap();
392                    } else {
393                        let encode_closure = generate_encode_closure(method.return_shape);
394                        cw_writeln!(
395                            w,
396                            "let _encoded = encodeResultOk(result, encoder: {encode_closure})"
397                        )
398                        .unwrap();
399                        w.writeln(
400                            "taskSender(.response(requestId: requestId, payload: _encoded, methodId: methodId, schemaPayload: responseSchemaPayload))",
401                        )
402                        .unwrap();
403                    }
404                }
405            }
406            w.writeln("} catch {").unwrap();
407            {
408                let _indent = w.indent();
409                if method.retry.persist {
410                    w.writeln(
411                        "taskSender(.response(requestId: requestId, payload: encodeIndeterminateError(), methodId: methodId, schemaPayload: responseSchemaPayload))",
412                    )
413                    .unwrap();
414                } else {
415                    w.writeln(
416                        "taskSender(.response(requestId: requestId, payload: encodeInvalidPayloadError(reason: String(describing: error)), methodId: methodId, schemaPayload: responseSchemaPayload))",
417                    )
418                    .unwrap();
419                }
420            }
421            w.writeln("}").unwrap();
422        }
423        w.writeln("} catch {").unwrap();
424        {
425            let _indent = w.indent();
426            w.writeln(
427                "taskSender(.response(requestId: requestId, payload: encodeInvalidPayloadError(reason: String(describing: error)), methodId: methodId, schemaPayload: responseSchemaPayload))",
428            )
429            .unwrap();
430        }
431        w.writeln("}").unwrap();
432    }
433    w.writeln("}").unwrap();
434}
435
436/// Generate code to decode a single argument for dispatch.
437/// All decodes read from `buffer: inout ByteBuffer` in scope.
438fn generate_channeling_decode_arg(
439    w: &mut CodeWriter<&mut String>,
440    name: &str,
441    shape: &'static Shape,
442) {
443    match classify_shape(shape) {
444        ShapeKind::Rx { inner } => {
445            let decode_closure = generate_decode_closure_for_channel(inner);
446            cw_writeln!(w, "let {name}ChannelId = try decodeVarint(from: &buffer)").unwrap();
447            cw_writeln!(
448                w,
449                "let {name}Receiver = await registry.register({name}ChannelId, initialCredit: 16, onConsumed: {{ [taskSender] additional in taskSender(.grantCredit(channelId: {name}ChannelId, bytes: additional)) }})"
450            )
451            .unwrap();
452            cw_writeln!(
453                w,
454                "let {name} = createServerRx(channelId: {name}ChannelId, receiver: {name}Receiver, deserialize: {decode_closure})"
455            )
456            .unwrap();
457        }
458        ShapeKind::Tx { inner } => {
459            let encode_closure = generate_encode_closure(inner);
460            cw_writeln!(w, "let {name}ChannelId = try decodeVarint(from: &buffer)").unwrap();
461            cw_writeln!(
462                w,
463                "let {name} = await createServerTx(channelId: {name}ChannelId, taskSender: taskSender, registry: registry, initialCredit: 16, serialize: {encode_closure})"
464            )
465            .unwrap();
466        }
467        _ => {
468            let decode_stmt = generate_decode_stmt_with_cursor(shape, name, "", "buffer");
469            for line in decode_stmt.lines() {
470                w.writeln(line).unwrap();
471            }
472        }
473    }
474}
475
476/// Generate a deserialize closure for use with createServerRx.
477/// The closure takes `inout ByteBuffer` and returns the decoded value.
478fn generate_decode_closure_for_channel(inner: &'static Shape) -> String {
479    use super::decode::generate_decode_closure;
480    generate_decode_closure(inner)
481}