Skip to main content

vox_codegen/targets/swift/
server.rs

1//! Swift server/handler generation.
2//!
3//! Generates handler protocol and dispatcher for routing incoming calls.
4
5use facet_core::Shape;
6use heck::{ToLowerCamelCase, ToUpperCamelCase};
7use vox_types::{MethodDescriptor, ServiceDescriptor, ShapeKind, classify_shape, is_rx, is_tx};
8
9use super::decode::{generate_decode_stmt_with_cursor, generate_inline_decode};
10use super::encode::generate_encode_closure;
11use super::types::{format_doc, is_channel, swift_type_server_arg, swift_type_server_return};
12use crate::code_writer::CodeWriter;
13use crate::cw_writeln;
14use crate::render::hex_u64;
15
16fn swift_retry_policy_literal(method: &MethodDescriptor) -> &'static str {
17    match (method.retry.persist, method.retry.idem) {
18        (false, false) => ".volatile",
19        (false, true) => ".idem",
20        (true, false) => ".persist",
21        (true, true) => ".persistIdem",
22    }
23}
24
/// Name of the private Swift helper that dispatches `method_name`
/// (e.g. `fooBar` -> `dispatch_fooBar`).
fn dispatch_helper_name(method_name: &str) -> String {
    ["dispatch_", method_name].concat()
}
28
29/// Generate complete server code (handler protocol + dispatchers).
30pub fn generate_server(service: &ServiceDescriptor) -> String {
31    let mut out = String::new();
32    out.push_str(&generate_handler_protocol(service));
33    // Emit only the channel-capable dispatcher.
34    out.push_str(&generate_channeling_dispatcher(service));
35    out
36}
37
38/// Generate handler protocol (for handling incoming calls).
39fn generate_handler_protocol(service: &ServiceDescriptor) -> String {
40    let mut out = String::new();
41    let service_name = service.service_name.to_upper_camel_case();
42
43    if let Some(doc) = &service.doc {
44        out.push_str(&format_doc(doc, ""));
45    }
46    out.push_str(&format!("public protocol {service_name}Handler {{\n"));
47
48    for method in service.methods {
49        let method_name = method.method_name.to_lower_camel_case();
50
51        if let Some(doc) = &method.doc {
52            out.push_str(&format_doc(doc, "    "));
53        }
54
55        // Server perspective
56        let args: Vec<String> = method
57            .args
58            .iter()
59            .map(|a| {
60                format!(
61                    "{}: {}",
62                    a.name.to_lower_camel_case(),
63                    swift_type_server_arg(a.shape)
64                )
65            })
66            .collect();
67
68        let ret_type = swift_type_server_return(method.return_shape);
69
70        if ret_type == "Void" {
71            out.push_str(&format!(
72                "    func {method_name}({}) async throws\n",
73                args.join(", ")
74            ));
75        } else {
76            out.push_str(&format!(
77                "    func {method_name}({}) async throws -> {ret_type}\n",
78                args.join(", ")
79            ));
80        }
81    }
82
83    out.push_str("}\n\n");
84    out
85}
86
87/// Generate channeling dispatcher for handling incoming calls with channel support.
88fn generate_channeling_dispatcher(service: &ServiceDescriptor) -> String {
89    let mut out = String::new();
90    let mut w = CodeWriter::with_indent_spaces(&mut out, 4);
91    let service_name = service.service_name.to_upper_camel_case();
92
93    let service_name_lower = service.service_name.to_lower_camel_case();
94
95    cw_writeln!(
96        w,
97        "public final class {service_name}ChannelingDispatcher {{"
98    )
99    .unwrap();
100    {
101        let _indent = w.indent();
102        cw_writeln!(w, "private let handler: {service_name}Handler").unwrap();
103        w.writeln("private let registry: IncomingChannelRegistry")
104            .unwrap();
105        w.writeln("private let taskSender: TaskSender").unwrap();
106        cw_writeln!(w, "private let schemaRegistry: [UInt64: Schema]").unwrap();
107        cw_writeln!(w, "private let methodSchemas: [UInt64: MethodSchemaInfo]").unwrap();
108        w.blank_line().unwrap();
109
110        cw_writeln!(
111            w,
112            "public init(handler: {service_name}Handler, registry: IncomingChannelRegistry, taskSender: @escaping TaskSender, schemaSendTracker _: SchemaSendTracker, schemaRegistry: [UInt64: Schema] = {service_name_lower}_schema_registry, methodSchemas: [UInt64: MethodSchemaInfo] = {service_name_lower}_method_schemas) {{"
113        )
114        .unwrap();
115        {
116            let _indent = w.indent();
117            w.writeln("self.handler = handler").unwrap();
118            w.writeln("self.registry = registry").unwrap();
119            w.writeln("self.taskSender = taskSender").unwrap();
120            w.writeln("self.schemaRegistry = schemaRegistry").unwrap();
121            w.writeln("self.methodSchemas = methodSchemas").unwrap();
122        }
123        w.writeln("}").unwrap();
124        w.blank_line().unwrap();
125
126        // Main dispatch method
127        w.writeln(
128            "public func dispatch(methodId: UInt64, requestId: UInt64, payload: Data) async {",
129        )
130        .unwrap();
131        {
132            let _indent = w.indent();
133            w.writeln("switch methodId {").unwrap();
134            for method in service.methods {
135                let method_name = method.method_name.to_lower_camel_case();
136                let method_id = crate::method_id(method);
137                let dispatch_name = dispatch_helper_name(&method_name);
138                cw_writeln!(w, "case {}:", hex_u64(method_id)).unwrap();
139                cw_writeln!(
140                    w,
141                    "    await {dispatch_name}(methodId: methodId, requestId: requestId, payload: payload)"
142                )
143                .unwrap();
144            }
145            w.writeln("default:").unwrap();
146            w.writeln(
147                "    taskSender(.response(requestId: requestId, payload: encodeUnknownMethodError()))",
148            )
149            .unwrap();
150            w.writeln("}").unwrap();
151        }
152        w.writeln("}").unwrap();
153        w.blank_line().unwrap();
154
155        w.writeln("public static func retryPolicy(methodId: UInt64) -> RetryPolicy {")
156            .unwrap();
157        {
158            let _indent = w.indent();
159            w.writeln("switch methodId {").unwrap();
160            for method in service.methods {
161                let method_id = crate::method_id(method);
162                let retry_policy = swift_retry_policy_literal(method);
163                cw_writeln!(w, "case {}:", hex_u64(method_id)).unwrap();
164                cw_writeln!(w, "    return {retry_policy}").unwrap();
165            }
166            w.writeln("default:").unwrap();
167            w.writeln("    return .volatile").unwrap();
168            w.writeln("}").unwrap();
169        }
170        w.writeln("}").unwrap();
171        w.blank_line().unwrap();
172
173        // Generate preregisterChannels method
174        generate_preregister_channels(&mut w, service);
175        w.blank_line().unwrap();
176
177        // Individual dispatch methods
178        for method in service.methods {
179            generate_channeling_dispatch_method(&mut w, method);
180            w.blank_line().unwrap();
181        }
182    }
183    w.writeln("}").unwrap();
184    w.blank_line().unwrap();
185
186    out
187}
188
189/// Generate preregisterChannels method.
190fn generate_preregister_channels(w: &mut CodeWriter<&mut String>, service: &ServiceDescriptor) {
191    w.writeln("/// Pre-register Rx channel IDs from request payloads.")
192        .unwrap();
193    w.writeln("/// Call this synchronously before spawning the dispatch task to avoid")
194        .unwrap();
195    w.writeln("/// race conditions where Data arrives before channels are registered.")
196        .unwrap();
197    w.writeln(
198        "public static func preregisterChannels(methodId: UInt64, payload: Data, registry: ChannelRegistry) async {",
199    )
200        .unwrap();
201    {
202        let _indent = w.indent();
203        w.writeln("switch methodId {").unwrap();
204
205        for method in service.methods {
206            let method_id = crate::method_id(method);
207            let has_channel_args = method.args.iter().any(|a| is_rx(a.shape) || is_tx(a.shape));
208
209            if has_channel_args {
210                cw_writeln!(w, "case {}:", hex_u64(method_id)).unwrap();
211                w.writeln("    do {").unwrap();
212                {
213                    let _indent = w.indent();
214                    if !method.args.is_empty() {
215                        w.writeln("var preregisterCursor = 0").unwrap();
216                    }
217                    for arg in method.args {
218                        let arg_name = arg.name.to_lower_camel_case();
219                        if is_rx(arg.shape) {
220                            cw_writeln!(
221                                w,
222                                "let {arg_name}ChannelId = try decodeVarint(from: payload, offset: &preregisterCursor)"
223                            )
224                            .unwrap();
225                            cw_writeln!(w, "await registry.markKnown({arg_name}ChannelId)")
226                                .unwrap();
227                        } else if is_tx(arg.shape) {
228                            w.writeln(
229                                "_ = try decodeVarint(from: payload, offset: &preregisterCursor)",
230                            )
231                            .unwrap();
232                        } else {
233                            let discard_name = format!("_discard_{arg_name}");
234                            let decode_stmt = generate_decode_stmt_with_cursor(
235                                arg.shape,
236                                &discard_name,
237                                "",
238                                "preregisterCursor",
239                            );
240                            for line in decode_stmt.lines() {
241                                w.writeln(line).unwrap();
242                            }
243                        }
244                    }
245                }
246                w.writeln("    } catch {").unwrap();
247                w.writeln("        return").unwrap();
248                w.writeln("    }").unwrap();
249            }
250        }
251
252        w.writeln("default:").unwrap();
253        w.writeln("    break").unwrap();
254        w.writeln("}").unwrap();
255    }
256    w.writeln("}").unwrap();
257}
258
259/// Generate a single channeling dispatch method.
260fn generate_channeling_dispatch_method(w: &mut CodeWriter<&mut String>, method: &MethodDescriptor) {
261    let method_name = method.method_name.to_lower_camel_case();
262    let dispatch_name = dispatch_helper_name(&method_name);
263    let has_channeling = method.args.iter().any(|a| is_channel(a.shape));
264    let handler_error_payload = if method.retry.persist {
265        "encodeIndeterminateError()"
266    } else {
267        "encodeInvalidPayloadError()"
268    };
269
270    cw_writeln!(
271        w,
272        "private func {dispatch_name}(methodId: UInt64, requestId: UInt64, payload: Data) async {{"
273    )
274    .unwrap();
275    {
276        let _indent = w.indent();
277        // Build response schema payload for this method.
278        w.writeln("guard let methodInfo = methodSchemas[methodId] else {")
279            .unwrap();
280        w.writeln(
281            "    taskSender(.response(requestId: requestId, payload: encodeUnknownMethodError()))",
282        )
283        .unwrap();
284        w.writeln("    return").unwrap();
285        w.writeln("}").unwrap();
286        w.writeln(
287            "let responseSchemaPayload = methodInfo.buildPayload(direction: .response, registry: schemaRegistry)",
288        )
289            .unwrap();
290        w.writeln("do {").unwrap();
291        {
292            let _indent = w.indent();
293            let cursor_var = if !method.args.is_empty() {
294                let name = unique_decode_cursor_name(method.args);
295                cw_writeln!(w, "var {name} = 0").unwrap();
296                Some(name)
297            } else {
298                None
299            };
300
301            for arg in method.args {
302                let arg_name = arg.name.to_lower_camel_case();
303                generate_channeling_decode_arg(w, &arg_name, arg.shape, cursor_var.as_deref());
304            }
305            let arg_names: Vec<String> = method
306                .args
307                .iter()
308                .map(|a| {
309                    let name = a.name.to_lower_camel_case();
310                    format!("{name}: {name}")
311                })
312                .collect();
313
314            let ret_type = swift_type_server_return(method.return_shape);
315
316            w.writeln("do {").unwrap();
317            {
318                let _indent = w.indent();
319                if has_channeling {
320                    if ret_type == "Void" {
321                        cw_writeln!(
322                            w,
323                            "try await handler.{method_name}({})",
324                            arg_names.join(", ")
325                        )
326                        .unwrap();
327                    } else {
328                        cw_writeln!(
329                            w,
330                            "let result = try await handler.{method_name}({})",
331                            arg_names.join(", ")
332                        )
333                        .unwrap();
334                    }
335
336                    for arg in method.args {
337                        if is_tx(arg.shape) {
338                            let arg_name = arg.name.to_lower_camel_case();
339                            cw_writeln!(w, "{arg_name}.close()").unwrap();
340                        }
341                    }
342
343                    if ret_type == "Void" {
344                        w.writeln(
345                            "taskSender(.response(requestId: requestId, payload: encodeResultOk((), encoder: { _ in [] }), methodId: methodId, schemaPayload: responseSchemaPayload))",
346                        )
347                        .unwrap();
348                    } else {
349                        let encode_closure = generate_encode_closure(method.return_shape);
350                        cw_writeln!(
351                            w,
352                            "taskSender(.response(requestId: requestId, payload: encodeResultOk(result, encoder: {encode_closure}), methodId: methodId, schemaPayload: responseSchemaPayload))"
353                        )
354                        .unwrap();
355                    }
356                } else if ret_type == "Void" {
357                    cw_writeln!(
358                        w,
359                        "try await handler.{method_name}({})",
360                        arg_names.join(", ")
361                    )
362                    .unwrap();
363                    w.writeln(
364                        "taskSender(.response(requestId: requestId, payload: encodeResultOk((), encoder: { _ in [] }), methodId: methodId, schemaPayload: responseSchemaPayload))",
365                    )
366                    .unwrap();
367                } else {
368                    cw_writeln!(
369                        w,
370                        "let result = try await handler.{method_name}({})",
371                        arg_names.join(", ")
372                    )
373                    .unwrap();
374                    if let ShapeKind::Result { ok, err } = classify_shape(method.return_shape) {
375                        let ok_encode = generate_encode_closure(ok);
376                        let err_encode = generate_encode_closure(err);
377                        cw_writeln!(
378                            w,
379                            "taskSender(.response(requestId: requestId, payload: {{ switch result {{ case .success(let v): return [UInt8(0)] + {ok_encode}(v); case .failure(let e): return [UInt8(1), UInt8(0)] + {err_encode}(e) }} }}(), methodId: methodId, schemaPayload: responseSchemaPayload))"
380                        )
381                        .unwrap();
382                    } else {
383                        let encode_closure = generate_encode_closure(method.return_shape);
384                        cw_writeln!(
385                            w,
386                            "taskSender(.response(requestId: requestId, payload: encodeResultOk(result, encoder: {encode_closure}), methodId: methodId, schemaPayload: responseSchemaPayload))"
387                        )
388                        .unwrap();
389                    }
390                }
391            }
392            w.writeln("} catch {").unwrap();
393            {
394                let _indent = w.indent();
395                cw_writeln!(
396                    w,
397                    "taskSender(.response(requestId: requestId, payload: {handler_error_payload}, methodId: methodId, schemaPayload: responseSchemaPayload))"
398                )
399                .unwrap();
400            }
401            w.writeln("}").unwrap();
402        }
403        w.writeln("} catch {").unwrap();
404        {
405            let _indent = w.indent();
406            w.writeln(
407                "taskSender(.response(requestId: requestId, payload: encodeInvalidPayloadError(), methodId: methodId, schemaPayload: responseSchemaPayload))",
408            )
409            .unwrap();
410        }
411        w.writeln("}").unwrap();
412    }
413    w.writeln("}").unwrap();
414}
415
416/// Generate code to decode a single argument for channeling dispatch.
417fn generate_channeling_decode_arg(
418    w: &mut CodeWriter<&mut String>,
419    name: &str,
420    shape: &'static Shape,
421    cursor_var: Option<&str>,
422) {
423    match classify_shape(shape) {
424        ShapeKind::Rx { inner } => {
425            // Schema Rx = client passes Rx to method, sends via paired Tx
426            // Server needs to receive → create server Rx
427            let inline_decode = generate_inline_decode(inner, "Data(bytes)", "off");
428            let cursor_var = cursor_var.expect("payload cursor required for channeling args");
429            cw_writeln!(
430                w,
431                "let {name}ChannelId = try decodeVarint(from: payload, offset: &{cursor_var})"
432            )
433            .unwrap();
434            cw_writeln!(
435                w,
436                "let {name}Receiver = await registry.register({name}ChannelId, initialCredit: 16, onConsumed: {{ [taskSender = self.taskSender] additional in taskSender(.grantCredit(channelId: {name}ChannelId, bytes: additional)) }})"
437            )
438            .unwrap();
439            cw_writeln!(
440                w,
441                "let {name} = createServerRx(channelId: {name}ChannelId, receiver: {name}Receiver, deserialize: {{ bytes in"
442            )
443            .unwrap();
444            cw_writeln!(w, "    var off = 0").unwrap();
445            cw_writeln!(w, "    return try {inline_decode}").unwrap();
446            w.writeln("})").unwrap();
447        }
448        ShapeKind::Tx { inner } => {
449            // Schema Tx = client passes Tx to method, receives via paired Rx
450            // Server needs to send → create server Tx
451            let encode_closure = generate_encode_closure(inner);
452            let cursor_var = cursor_var.expect("payload cursor required for channeling args");
453            cw_writeln!(
454                w,
455                "let {name}ChannelId = try decodeVarint(from: payload, offset: &{cursor_var})"
456            )
457            .unwrap();
458            cw_writeln!(
459                w,
460                "let {name} = await createServerTx(channelId: {name}ChannelId, taskSender: taskSender, registry: registry, initialCredit: 16, serialize: ({encode_closure}))"
461            )
462            .unwrap();
463        }
464        _ => {
465            // Non-channeling argument - use standard decode
466            let cursor_var = cursor_var.expect("payload cursor required for non-channel args");
467            let decode_stmt = generate_decode_stmt_with_cursor(shape, name, "", cursor_var);
468            for line in decode_stmt.lines() {
469                w.writeln(line).unwrap();
470            }
471        }
472    }
473}
474
475fn unique_decode_cursor_name(args: &[vox_types::ArgDescriptor]) -> String {
476    let arg_names: Vec<String> = args.iter().map(|a| a.name.to_lower_camel_case()).collect();
477    let mut candidate = String::from("cursor");
478    while arg_names.iter().any(|name| name == &candidate) {
479        candidate.push('_');
480    }
481    candidate
482}