Skip to main content

vox_codegen/targets/swift/
server.rs

1//! Swift server/handler generation.
2//!
3//! Generates handler protocol and dispatcher for routing incoming calls.
4
5use facet_core::Shape;
6use heck::{ToLowerCamelCase, ToUpperCamelCase};
7use vox_types::{MethodDescriptor, ServiceDescriptor, ShapeKind, classify_shape, is_rx, is_tx};
8
9use super::decode::generate_decode_stmt_with_cursor;
10use super::encode::{generate_encode_closure, generate_encode_stmt};
11use super::types::{format_doc, is_channel, swift_type_server_arg, swift_type_server_return};
12use crate::code_writer::CodeWriter;
13use crate::cw_writeln;
14use crate::render::hex_u64;
15
16fn swift_retry_policy_literal(method: &MethodDescriptor) -> &'static str {
17    match (method.retry.persist, method.retry.idem) {
18        (false, false) => ".volatile",
19        (false, true) => ".idem",
20        (true, false) => ".persist",
21        (true, true) => ".persistIdem",
22    }
23}
24
/// Name of the private per-method Swift dispatch helper, e.g. `dispatch_getUser`
/// for the (already lower-camel-cased) method name `getUser`.
fn dispatch_helper_name(method_name: &str) -> String {
    ["dispatch_", method_name].concat()
}
28
29/// Generate complete server code (handler protocol + dispatchers).
30pub fn generate_server(service: &ServiceDescriptor) -> String {
31    let mut out = String::new();
32    out.push_str(&generate_handler_protocol(service));
33    // Emit only the channel-capable dispatcher.
34    out.push_str(&generate_dispatcher(service));
35    out
36}
37
38/// Generate handler protocol (for handling incoming calls).
39fn generate_handler_protocol(service: &ServiceDescriptor) -> String {
40    let mut out = String::new();
41    let service_name = service.service_name.to_upper_camel_case();
42
43    if let Some(doc) = &service.doc {
44        out.push_str(&format_doc(doc, ""));
45    }
46    out.push_str(&format!(
47        "public protocol {service_name}Handler: Sendable {{\n"
48    ));
49
50    for method in service.methods {
51        let method_name = method.method_name.to_lower_camel_case();
52
53        if let Some(doc) = &method.doc {
54            out.push_str(&format_doc(doc, "    "));
55        }
56
57        // Server perspective
58        let args: Vec<String> = method
59            .args
60            .iter()
61            .map(|a| {
62                format!(
63                    "{}: {}",
64                    a.name.to_lower_camel_case(),
65                    swift_type_server_arg(a.shape)
66                )
67            })
68            .collect();
69
70        let ret_type = swift_type_server_return(method.return_shape);
71
72        if ret_type == "Void" {
73            out.push_str(&format!(
74                "    func {method_name}({}) async throws\n",
75                args.join(", ")
76            ));
77        } else {
78            out.push_str(&format!(
79                "    func {method_name}({}) async throws -> {ret_type}\n",
80                args.join(", ")
81            ));
82        }
83    }
84
85    out.push_str("}\n\n");
86    out
87}
88
89/// Generate dispatcher for handling incoming calls with channel support.
90fn generate_dispatcher(service: &ServiceDescriptor) -> String {
91    let mut out = String::new();
92    let mut w = CodeWriter::with_indent_spaces(&mut out, 4);
93    let service_name = service.service_name.to_upper_camel_case();
94
95    let service_name_lower = service.service_name.to_lower_camel_case();
96
97    cw_writeln!(
98        w,
99        "public final class {service_name}Dispatcher: ServiceDispatcher {{"
100    )
101    .unwrap();
102    {
103        let _indent = w.indent();
104        cw_writeln!(w, "private let handler: {service_name}Handler").unwrap();
105        cw_writeln!(w, "private let schemaRegistry: [UInt64: Schema]").unwrap();
106        cw_writeln!(w, "private let methodSchemas: [UInt64: MethodSchemaInfo]").unwrap();
107        w.blank_line().unwrap();
108
109        cw_writeln!(
110            w,
111            "public init(handler: {service_name}Handler, schemaRegistry: [UInt64: Schema] = {service_name_lower}_schema_registry, methodSchemas: [UInt64: MethodSchemaInfo] = {service_name_lower}_method_schemas) {{"
112        )
113        .unwrap();
114        {
115            let _indent = w.indent();
116            w.writeln("self.handler = handler").unwrap();
117            w.writeln("self.schemaRegistry = schemaRegistry").unwrap();
118            w.writeln("self.methodSchemas = methodSchemas").unwrap();
119        }
120        w.writeln("}").unwrap();
121        w.blank_line().unwrap();
122
123        // Main dispatch method matching ServiceDispatcher protocol
124        w.writeln(
125            "public func dispatch(methodId: UInt64, payload: [UInt8], requestId: UInt64, registry: ChannelRegistry, schemaSendTracker _: SchemaSendTracker, taskTx: @escaping @Sendable (TaskMessage) -> Void) async {",
126        )
127        .unwrap();
128        {
129            let _indent = w.indent();
130            w.writeln("var buffer = ByteBufferAllocator().buffer(capacity: payload.count)")
131                .unwrap();
132            w.writeln("buffer.writeBytes(payload)").unwrap();
133            w.writeln("let taskSender: TaskSender = taskTx").unwrap();
134            w.writeln("switch methodId {").unwrap();
135            for method in service.methods {
136                let method_name = method.method_name.to_lower_camel_case();
137                let method_id = crate::method_id(method);
138                let dispatch_name = dispatch_helper_name(&method_name);
139                cw_writeln!(w, "case {}:", hex_u64(method_id)).unwrap();
140                cw_writeln!(
141                    w,
142                    "    await {dispatch_name}(methodId: methodId, requestId: requestId, buffer: &buffer, registry: registry, taskSender: taskSender)"
143                )
144                .unwrap();
145            }
146            w.writeln("default:").unwrap();
147            w.writeln(
148                "    taskSender(.response(requestId: requestId, payload: encodeUnknownMethodError()))",
149            )
150            .unwrap();
151            w.writeln("}").unwrap();
152        }
153        w.writeln("}").unwrap();
154        w.blank_line().unwrap();
155
156        w.writeln("public func retryPolicy(methodId: UInt64) -> RetryPolicy {")
157            .unwrap();
158        {
159            let _indent = w.indent();
160            w.writeln("switch methodId {").unwrap();
161            for method in service.methods {
162                let method_id = crate::method_id(method);
163                let retry_policy = swift_retry_policy_literal(method);
164                cw_writeln!(w, "case {}:", hex_u64(method_id)).unwrap();
165                cw_writeln!(w, "    return {retry_policy}").unwrap();
166            }
167            w.writeln("default:").unwrap();
168            w.writeln("    return .volatile").unwrap();
169            w.writeln("}").unwrap();
170        }
171        w.writeln("}").unwrap();
172        w.blank_line().unwrap();
173
174        // Generate preregisterChannels method
175        generate_preregister_channels(&mut w, service);
176        w.blank_line().unwrap();
177
178        // Individual dispatch methods
179        for method in service.methods {
180            generate_channeling_dispatch_method(&mut w, method);
181            w.blank_line().unwrap();
182        }
183    }
184    w.writeln("}").unwrap();
185    w.blank_line().unwrap();
186
187    out
188}
189
190/// Generate preregisterChannels method.
191fn generate_preregister_channels(w: &mut CodeWriter<&mut String>, service: &ServiceDescriptor) {
192    w.writeln("/// Pre-register Rx channel IDs from request payloads.")
193        .unwrap();
194    w.writeln("/// Call this synchronously before spawning the dispatch task to avoid")
195        .unwrap();
196    w.writeln("/// race conditions where Data arrives before channels are registered.")
197        .unwrap();
198    w.writeln(
199        "public func preregister(methodId: UInt64, payload: [UInt8], registry: ChannelRegistry) async {",
200    )
201        .unwrap();
202    {
203        let _indent = w.indent();
204        w.writeln("var buffer = ByteBufferAllocator().buffer(capacity: payload.count)")
205            .unwrap();
206        w.writeln("buffer.writeBytes(payload)").unwrap();
207        w.writeln("switch methodId {").unwrap();
208
209        for method in service.methods {
210            let method_id = crate::method_id(method);
211            let has_channel_args = method.args.iter().any(|a| is_rx(a.shape) || is_tx(a.shape));
212
213            if has_channel_args {
214                cw_writeln!(w, "case {}:", hex_u64(method_id)).unwrap();
215                w.writeln("    do {").unwrap();
216                {
217                    let _indent = w.indent();
218                    for arg in method.args {
219                        let arg_name = arg.name.to_lower_camel_case();
220                        if is_rx(arg.shape) {
221                            cw_writeln!(
222                                w,
223                                "let {arg_name}ChannelId = try decodeVarint(from: &buffer)"
224                            )
225                            .unwrap();
226                            cw_writeln!(w, "await registry.markKnown({arg_name}ChannelId)")
227                                .unwrap();
228                        } else if is_tx(arg.shape) {
229                            w.writeln("_ = try decodeVarint(from: &buffer)").unwrap();
230                        } else {
231                            let discard_name = format!("_discard_{arg_name}");
232                            let decode_stmt = generate_decode_stmt_with_cursor(
233                                arg.shape,
234                                &discard_name,
235                                "",
236                                "buffer",
237                            );
238                            for line in decode_stmt.lines() {
239                                w.writeln(line).unwrap();
240                            }
241                        }
242                    }
243                }
244                w.writeln("    } catch {").unwrap();
245                w.writeln("        return").unwrap();
246                w.writeln("    }").unwrap();
247            }
248        }
249
250        w.writeln("default:").unwrap();
251        w.writeln("    break").unwrap();
252        w.writeln("}").unwrap();
253    }
254    w.writeln("}").unwrap();
255}
256
257/// Generate a single dispatch method.
258fn generate_channeling_dispatch_method(w: &mut CodeWriter<&mut String>, method: &MethodDescriptor) {
259    let method_name = method.method_name.to_lower_camel_case();
260    let dispatch_name = dispatch_helper_name(&method_name);
261    let has_channeling = method.args.iter().any(|a| is_channel(a.shape));
262    cw_writeln!(
263        w,
264        "private func {dispatch_name}(methodId: UInt64, requestId: UInt64, buffer: inout ByteBuffer, registry: IncomingChannelRegistry, taskSender: @escaping TaskSender) async {{"
265    )
266    .unwrap();
267    {
268        let _indent = w.indent();
269        // Build response schema payload for this method.
270        w.writeln("guard let methodInfo = methodSchemas[methodId] else {")
271            .unwrap();
272        w.writeln(
273            "    taskSender(.response(requestId: requestId, payload: encodeUnknownMethodError()))",
274        )
275        .unwrap();
276        w.writeln("    return").unwrap();
277        w.writeln("}").unwrap();
278        w.writeln(
279            "let responseSchemaPayload = methodInfo.buildPayload(direction: .response, registry: schemaRegistry)",
280        )
281            .unwrap();
282        w.writeln("do {").unwrap();
283        {
284            let _indent = w.indent();
285            for arg in method.args {
286                let arg_name = arg.name.to_lower_camel_case();
287                generate_channeling_decode_arg(w, &arg_name, arg.shape);
288            }
289            let arg_names: Vec<String> = method
290                .args
291                .iter()
292                .map(|a| {
293                    let name = a.name.to_lower_camel_case();
294                    format!("{name}: {name}")
295                })
296                .collect();
297
298            let ret_type = swift_type_server_return(method.return_shape);
299
300            w.writeln("do {").unwrap();
301            {
302                let _indent = w.indent();
303                if has_channeling {
304                    if ret_type == "Void" {
305                        cw_writeln!(
306                            w,
307                            "try await handler.{method_name}({})",
308                            arg_names.join(", ")
309                        )
310                        .unwrap();
311                    } else {
312                        cw_writeln!(
313                            w,
314                            "let result = try await handler.{method_name}({})",
315                            arg_names.join(", ")
316                        )
317                        .unwrap();
318                    }
319
320                    for arg in method.args {
321                        if is_tx(arg.shape) {
322                            let arg_name = arg.name.to_lower_camel_case();
323                            cw_writeln!(w, "{arg_name}.close()").unwrap();
324                        }
325                    }
326
327                    if ret_type == "Void" {
328                        w.writeln(
329                            "taskSender(.response(requestId: requestId, payload: encodeResultOkUnit(), methodId: methodId, schemaPayload: responseSchemaPayload))",
330                        )
331                        .unwrap();
332                    } else {
333                        let encode_closure = generate_encode_closure(method.return_shape);
334                        cw_writeln!(
335                            w,
336                            "let _encoded = encodeResultOk(result, encoder: {encode_closure})"
337                        )
338                        .unwrap();
339                        w.writeln(
340                            "taskSender(.response(requestId: requestId, payload: _encoded, methodId: methodId, schemaPayload: responseSchemaPayload))",
341                        )
342                        .unwrap();
343                    }
344                } else if ret_type == "Void" {
345                    cw_writeln!(
346                        w,
347                        "try await handler.{method_name}({})",
348                        arg_names.join(", ")
349                    )
350                    .unwrap();
351                    w.writeln(
352                        "taskSender(.response(requestId: requestId, payload: encodeResultOkUnit(), methodId: methodId, schemaPayload: responseSchemaPayload))",
353                    )
354                    .unwrap();
355                } else {
356                    cw_writeln!(
357                        w,
358                        "let result = try await handler.{method_name}({})",
359                        arg_names.join(", ")
360                    )
361                    .unwrap();
362                    if let ShapeKind::Result { ok, err } = classify_shape(method.return_shape) {
363                        let ok_stmt = generate_encode_stmt(ok, "v");
364                        let err_stmt = generate_encode_stmt(err, "e");
365                        w.writeln("let _encoded: [UInt8] = {").unwrap();
366                        {
367                            let _indent = w.indent();
368                            w.writeln("var buffer = ByteBufferAllocator().buffer(capacity: 64)")
369                                .unwrap();
370                            w.writeln("switch result {").unwrap();
371                            w.writeln("case .success(let v):").unwrap();
372                            {
373                                let _indent = w.indent();
374                                w.writeln("encodeVarint(UInt64(0), into: &buffer)").unwrap();
375                                for line in ok_stmt.lines() {
376                                    w.writeln(line).unwrap();
377                                }
378                            }
379                            w.writeln("case .failure(let e):").unwrap();
380                            {
381                                let _indent = w.indent();
382                                w.writeln("encodeVarint(UInt64(1), into: &buffer)").unwrap();
383                                w.writeln("encodeU8(0, into: &buffer)").unwrap();
384                                for line in err_stmt.lines() {
385                                    w.writeln(line).unwrap();
386                                }
387                            }
388                            w.writeln("}").unwrap();
389                            w.writeln(
390                                "return buffer.readBytes(length: buffer.readableBytes) ?? []",
391                            )
392                            .unwrap();
393                        }
394                        w.writeln("}()").unwrap();
395                        w.writeln(
396                            "taskSender(.response(requestId: requestId, payload: _encoded, methodId: methodId, schemaPayload: responseSchemaPayload))",
397                        )
398                        .unwrap();
399                    } else {
400                        let encode_closure = generate_encode_closure(method.return_shape);
401                        cw_writeln!(
402                            w,
403                            "let _encoded = encodeResultOk(result, encoder: {encode_closure})"
404                        )
405                        .unwrap();
406                        w.writeln(
407                            "taskSender(.response(requestId: requestId, payload: _encoded, methodId: methodId, schemaPayload: responseSchemaPayload))",
408                        )
409                        .unwrap();
410                    }
411                }
412            }
413            w.writeln("} catch {").unwrap();
414            {
415                let _indent = w.indent();
416                if method.retry.persist {
417                    w.writeln(
418                        "taskSender(.response(requestId: requestId, payload: encodeIndeterminateError(), methodId: methodId, schemaPayload: responseSchemaPayload))",
419                    )
420                    .unwrap();
421                } else {
422                    w.writeln(
423                        "taskSender(.response(requestId: requestId, payload: encodeInvalidPayloadError(reason: String(describing: error)), methodId: methodId, schemaPayload: responseSchemaPayload))",
424                    )
425                    .unwrap();
426                }
427            }
428            w.writeln("}").unwrap();
429        }
430        w.writeln("} catch {").unwrap();
431        {
432            let _indent = w.indent();
433            w.writeln(
434                "taskSender(.response(requestId: requestId, payload: encodeInvalidPayloadError(reason: String(describing: error)), methodId: methodId, schemaPayload: responseSchemaPayload))",
435            )
436            .unwrap();
437        }
438        w.writeln("}").unwrap();
439    }
440    w.writeln("}").unwrap();
441}
442
443/// Generate code to decode a single argument for dispatch.
444/// All decodes read from `buffer: inout ByteBuffer` in scope.
445fn generate_channeling_decode_arg(
446    w: &mut CodeWriter<&mut String>,
447    name: &str,
448    shape: &'static Shape,
449) {
450    match classify_shape(shape) {
451        ShapeKind::Rx { inner } => {
452            let decode_closure = generate_decode_closure_for_channel(inner);
453            cw_writeln!(w, "let {name}ChannelId = try decodeVarint(from: &buffer)").unwrap();
454            cw_writeln!(
455                w,
456                "let {name}Receiver = await registry.register({name}ChannelId, initialCredit: 16, onConsumed: {{ [taskSender] additional in taskSender(.grantCredit(channelId: {name}ChannelId, bytes: additional)) }})"
457            )
458            .unwrap();
459            cw_writeln!(
460                w,
461                "let {name} = createServerRx(channelId: {name}ChannelId, receiver: {name}Receiver, deserialize: {decode_closure})"
462            )
463            .unwrap();
464        }
465        ShapeKind::Tx { inner } => {
466            let encode_closure = generate_encode_closure(inner);
467            cw_writeln!(w, "let {name}ChannelId = try decodeVarint(from: &buffer)").unwrap();
468            cw_writeln!(
469                w,
470                "let {name} = await createServerTx(channelId: {name}ChannelId, taskSender: taskSender, registry: registry, initialCredit: 16, serialize: {encode_closure})"
471            )
472            .unwrap();
473        }
474        _ => {
475            let decode_stmt = generate_decode_stmt_with_cursor(shape, name, "", "buffer");
476            for line in decode_stmt.lines() {
477                w.writeln(line).unwrap();
478            }
479        }
480    }
481}
482
483/// Generate a deserialize closure for use with createServerRx.
484/// The closure takes `inout ByteBuffer` and returns the decoded value.
485fn generate_decode_closure_for_channel(inner: &'static Shape) -> String {
486    use super::decode::generate_decode_closure;
487    generate_decode_closure(inner)
488}