vox_codegen/targets/swift/
server.rs1use facet_core::Shape;
6use heck::{ToLowerCamelCase, ToUpperCamelCase};
7use vox_types::{MethodDescriptor, ServiceDescriptor, ShapeKind, classify_shape, is_rx, is_tx};
8
9use super::decode::{generate_decode_stmt_with_cursor, generate_inline_decode};
10use super::encode::generate_encode_closure;
11use super::types::{format_doc, is_channel, swift_type_server_arg, swift_type_server_return};
12use crate::code_writer::CodeWriter;
13use crate::cw_writeln;
14use crate::render::hex_u64;
15
16fn swift_retry_policy_literal(method: &MethodDescriptor) -> &'static str {
17 match (method.retry.persist, method.retry.idem) {
18 (false, false) => ".volatile",
19 (false, true) => ".idem",
20 (true, false) => ".persist",
21 (true, true) => ".persistIdem",
22 }
23}
24
/// Returns the name of the private Swift helper that serves `method_name`
/// (already lower-camel-cased by the caller), e.g. `dispatch_getUser`.
fn dispatch_helper_name(method_name: &str) -> String {
    const PREFIX: &str = "dispatch_";
    let mut helper = String::with_capacity(PREFIX.len() + method_name.len());
    helper.push_str(PREFIX);
    helper.push_str(method_name);
    helper
}
28
29pub fn generate_server(service: &ServiceDescriptor) -> String {
31 let mut out = String::new();
32 out.push_str(&generate_handler_protocol(service));
33 out.push_str(&generate_channeling_dispatcher(service));
35 out
36}
37
38fn generate_handler_protocol(service: &ServiceDescriptor) -> String {
40 let mut out = String::new();
41 let service_name = service.service_name.to_upper_camel_case();
42
43 if let Some(doc) = &service.doc {
44 out.push_str(&format_doc(doc, ""));
45 }
46 out.push_str(&format!("public protocol {service_name}Handler {{\n"));
47
48 for method in service.methods {
49 let method_name = method.method_name.to_lower_camel_case();
50
51 if let Some(doc) = &method.doc {
52 out.push_str(&format_doc(doc, " "));
53 }
54
55 let args: Vec<String> = method
57 .args
58 .iter()
59 .map(|a| {
60 format!(
61 "{}: {}",
62 a.name.to_lower_camel_case(),
63 swift_type_server_arg(a.shape)
64 )
65 })
66 .collect();
67
68 let ret_type = swift_type_server_return(method.return_shape);
69
70 if ret_type == "Void" {
71 out.push_str(&format!(
72 " func {method_name}({}) async throws\n",
73 args.join(", ")
74 ));
75 } else {
76 out.push_str(&format!(
77 " func {method_name}({}) async throws -> {ret_type}\n",
78 args.join(", ")
79 ));
80 }
81 }
82
83 out.push_str("}\n\n");
84 out
85}
86
87fn generate_channeling_dispatcher(service: &ServiceDescriptor) -> String {
89 let mut out = String::new();
90 let mut w = CodeWriter::with_indent_spaces(&mut out, 4);
91 let service_name = service.service_name.to_upper_camel_case();
92
93 let service_name_lower = service.service_name.to_lower_camel_case();
94
95 cw_writeln!(
96 w,
97 "public final class {service_name}ChannelingDispatcher {{"
98 )
99 .unwrap();
100 {
101 let _indent = w.indent();
102 cw_writeln!(w, "private let handler: {service_name}Handler").unwrap();
103 w.writeln("private let registry: IncomingChannelRegistry")
104 .unwrap();
105 w.writeln("private let taskSender: TaskSender").unwrap();
106 cw_writeln!(w, "private let schemaRegistry: [UInt64: Schema]").unwrap();
107 cw_writeln!(w, "private let methodSchemas: [UInt64: MethodSchemaInfo]").unwrap();
108 w.blank_line().unwrap();
109
110 cw_writeln!(
111 w,
112 "public init(handler: {service_name}Handler, registry: IncomingChannelRegistry, taskSender: @escaping TaskSender, schemaSendTracker _: SchemaSendTracker, schemaRegistry: [UInt64: Schema] = {service_name_lower}_schema_registry, methodSchemas: [UInt64: MethodSchemaInfo] = {service_name_lower}_method_schemas) {{"
113 )
114 .unwrap();
115 {
116 let _indent = w.indent();
117 w.writeln("self.handler = handler").unwrap();
118 w.writeln("self.registry = registry").unwrap();
119 w.writeln("self.taskSender = taskSender").unwrap();
120 w.writeln("self.schemaRegistry = schemaRegistry").unwrap();
121 w.writeln("self.methodSchemas = methodSchemas").unwrap();
122 }
123 w.writeln("}").unwrap();
124 w.blank_line().unwrap();
125
126 w.writeln(
128 "public func dispatch(methodId: UInt64, requestId: UInt64, payload: Data) async {",
129 )
130 .unwrap();
131 {
132 let _indent = w.indent();
133 w.writeln("switch methodId {").unwrap();
134 for method in service.methods {
135 let method_name = method.method_name.to_lower_camel_case();
136 let method_id = crate::method_id(method);
137 let dispatch_name = dispatch_helper_name(&method_name);
138 cw_writeln!(w, "case {}:", hex_u64(method_id)).unwrap();
139 cw_writeln!(
140 w,
141 " await {dispatch_name}(methodId: methodId, requestId: requestId, payload: payload)"
142 )
143 .unwrap();
144 }
145 w.writeln("default:").unwrap();
146 w.writeln(
147 " taskSender(.response(requestId: requestId, payload: encodeUnknownMethodError()))",
148 )
149 .unwrap();
150 w.writeln("}").unwrap();
151 }
152 w.writeln("}").unwrap();
153 w.blank_line().unwrap();
154
155 w.writeln("public static func retryPolicy(methodId: UInt64) -> RetryPolicy {")
156 .unwrap();
157 {
158 let _indent = w.indent();
159 w.writeln("switch methodId {").unwrap();
160 for method in service.methods {
161 let method_id = crate::method_id(method);
162 let retry_policy = swift_retry_policy_literal(method);
163 cw_writeln!(w, "case {}:", hex_u64(method_id)).unwrap();
164 cw_writeln!(w, " return {retry_policy}").unwrap();
165 }
166 w.writeln("default:").unwrap();
167 w.writeln(" return .volatile").unwrap();
168 w.writeln("}").unwrap();
169 }
170 w.writeln("}").unwrap();
171 w.blank_line().unwrap();
172
173 generate_preregister_channels(&mut w, service);
175 w.blank_line().unwrap();
176
177 for method in service.methods {
179 generate_channeling_dispatch_method(&mut w, method);
180 w.blank_line().unwrap();
181 }
182 }
183 w.writeln("}").unwrap();
184 w.blank_line().unwrap();
185
186 out
187}
188
189fn generate_preregister_channels(w: &mut CodeWriter<&mut String>, service: &ServiceDescriptor) {
191 w.writeln("/// Pre-register Rx channel IDs from request payloads.")
192 .unwrap();
193 w.writeln("/// Call this synchronously before spawning the dispatch task to avoid")
194 .unwrap();
195 w.writeln("/// race conditions where Data arrives before channels are registered.")
196 .unwrap();
197 w.writeln(
198 "public static func preregisterChannels(methodId: UInt64, payload: Data, registry: ChannelRegistry) async {",
199 )
200 .unwrap();
201 {
202 let _indent = w.indent();
203 w.writeln("switch methodId {").unwrap();
204
205 for method in service.methods {
206 let method_id = crate::method_id(method);
207 let has_channel_args = method.args.iter().any(|a| is_rx(a.shape) || is_tx(a.shape));
208
209 if has_channel_args {
210 cw_writeln!(w, "case {}:", hex_u64(method_id)).unwrap();
211 w.writeln(" do {").unwrap();
212 {
213 let _indent = w.indent();
214 if !method.args.is_empty() {
215 w.writeln("var preregisterCursor = 0").unwrap();
216 }
217 for arg in method.args {
218 let arg_name = arg.name.to_lower_camel_case();
219 if is_rx(arg.shape) {
220 cw_writeln!(
221 w,
222 "let {arg_name}ChannelId = try decodeVarint(from: payload, offset: &preregisterCursor)"
223 )
224 .unwrap();
225 cw_writeln!(w, "await registry.markKnown({arg_name}ChannelId)")
226 .unwrap();
227 } else if is_tx(arg.shape) {
228 w.writeln(
229 "_ = try decodeVarint(from: payload, offset: &preregisterCursor)",
230 )
231 .unwrap();
232 } else {
233 let discard_name = format!("_discard_{arg_name}");
234 let decode_stmt = generate_decode_stmt_with_cursor(
235 arg.shape,
236 &discard_name,
237 "",
238 "preregisterCursor",
239 );
240 for line in decode_stmt.lines() {
241 w.writeln(line).unwrap();
242 }
243 }
244 }
245 }
246 w.writeln(" } catch {").unwrap();
247 w.writeln(" return").unwrap();
248 w.writeln(" }").unwrap();
249 }
250 }
251
252 w.writeln("default:").unwrap();
253 w.writeln(" break").unwrap();
254 w.writeln("}").unwrap();
255 }
256 w.writeln("}").unwrap();
257}
258
259fn generate_channeling_dispatch_method(w: &mut CodeWriter<&mut String>, method: &MethodDescriptor) {
261 let method_name = method.method_name.to_lower_camel_case();
262 let dispatch_name = dispatch_helper_name(&method_name);
263 let has_channeling = method.args.iter().any(|a| is_channel(a.shape));
264 let handler_error_payload = if method.retry.persist {
265 "encodeIndeterminateError()"
266 } else {
267 "encodeInvalidPayloadError()"
268 };
269
270 cw_writeln!(
271 w,
272 "private func {dispatch_name}(methodId: UInt64, requestId: UInt64, payload: Data) async {{"
273 )
274 .unwrap();
275 {
276 let _indent = w.indent();
277 w.writeln("guard let methodInfo = methodSchemas[methodId] else {")
279 .unwrap();
280 w.writeln(
281 " taskSender(.response(requestId: requestId, payload: encodeUnknownMethodError()))",
282 )
283 .unwrap();
284 w.writeln(" return").unwrap();
285 w.writeln("}").unwrap();
286 w.writeln(
287 "let responseSchemaPayload = methodInfo.buildPayload(direction: .response, registry: schemaRegistry)",
288 )
289 .unwrap();
290 w.writeln("do {").unwrap();
291 {
292 let _indent = w.indent();
293 let cursor_var = if !method.args.is_empty() {
294 let name = unique_decode_cursor_name(method.args);
295 cw_writeln!(w, "var {name} = 0").unwrap();
296 Some(name)
297 } else {
298 None
299 };
300
301 for arg in method.args {
302 let arg_name = arg.name.to_lower_camel_case();
303 generate_channeling_decode_arg(w, &arg_name, arg.shape, cursor_var.as_deref());
304 }
305 let arg_names: Vec<String> = method
306 .args
307 .iter()
308 .map(|a| {
309 let name = a.name.to_lower_camel_case();
310 format!("{name}: {name}")
311 })
312 .collect();
313
314 let ret_type = swift_type_server_return(method.return_shape);
315
316 w.writeln("do {").unwrap();
317 {
318 let _indent = w.indent();
319 if has_channeling {
320 if ret_type == "Void" {
321 cw_writeln!(
322 w,
323 "try await handler.{method_name}({})",
324 arg_names.join(", ")
325 )
326 .unwrap();
327 } else {
328 cw_writeln!(
329 w,
330 "let result = try await handler.{method_name}({})",
331 arg_names.join(", ")
332 )
333 .unwrap();
334 }
335
336 for arg in method.args {
337 if is_tx(arg.shape) {
338 let arg_name = arg.name.to_lower_camel_case();
339 cw_writeln!(w, "{arg_name}.close()").unwrap();
340 }
341 }
342
343 if ret_type == "Void" {
344 w.writeln(
345 "taskSender(.response(requestId: requestId, payload: encodeResultOk((), encoder: { _ in [] }), methodId: methodId, schemaPayload: responseSchemaPayload))",
346 )
347 .unwrap();
348 } else {
349 let encode_closure = generate_encode_closure(method.return_shape);
350 cw_writeln!(
351 w,
352 "taskSender(.response(requestId: requestId, payload: encodeResultOk(result, encoder: {encode_closure}), methodId: methodId, schemaPayload: responseSchemaPayload))"
353 )
354 .unwrap();
355 }
356 } else if ret_type == "Void" {
357 cw_writeln!(
358 w,
359 "try await handler.{method_name}({})",
360 arg_names.join(", ")
361 )
362 .unwrap();
363 w.writeln(
364 "taskSender(.response(requestId: requestId, payload: encodeResultOk((), encoder: { _ in [] }), methodId: methodId, schemaPayload: responseSchemaPayload))",
365 )
366 .unwrap();
367 } else {
368 cw_writeln!(
369 w,
370 "let result = try await handler.{method_name}({})",
371 arg_names.join(", ")
372 )
373 .unwrap();
374 if let ShapeKind::Result { ok, err } = classify_shape(method.return_shape) {
375 let ok_encode = generate_encode_closure(ok);
376 let err_encode = generate_encode_closure(err);
377 cw_writeln!(
378 w,
379 "taskSender(.response(requestId: requestId, payload: {{ switch result {{ case .success(let v): return [UInt8(0)] + {ok_encode}(v); case .failure(let e): return [UInt8(1), UInt8(0)] + {err_encode}(e) }} }}(), methodId: methodId, schemaPayload: responseSchemaPayload))"
380 )
381 .unwrap();
382 } else {
383 let encode_closure = generate_encode_closure(method.return_shape);
384 cw_writeln!(
385 w,
386 "taskSender(.response(requestId: requestId, payload: encodeResultOk(result, encoder: {encode_closure}), methodId: methodId, schemaPayload: responseSchemaPayload))"
387 )
388 .unwrap();
389 }
390 }
391 }
392 w.writeln("} catch {").unwrap();
393 {
394 let _indent = w.indent();
395 cw_writeln!(
396 w,
397 "taskSender(.response(requestId: requestId, payload: {handler_error_payload}, methodId: methodId, schemaPayload: responseSchemaPayload))"
398 )
399 .unwrap();
400 }
401 w.writeln("}").unwrap();
402 }
403 w.writeln("} catch {").unwrap();
404 {
405 let _indent = w.indent();
406 w.writeln(
407 "taskSender(.response(requestId: requestId, payload: encodeInvalidPayloadError(), methodId: methodId, schemaPayload: responseSchemaPayload))",
408 )
409 .unwrap();
410 }
411 w.writeln("}").unwrap();
412 }
413 w.writeln("}").unwrap();
414}
415
416fn generate_channeling_decode_arg(
418 w: &mut CodeWriter<&mut String>,
419 name: &str,
420 shape: &'static Shape,
421 cursor_var: Option<&str>,
422) {
423 match classify_shape(shape) {
424 ShapeKind::Rx { inner } => {
425 let inline_decode = generate_inline_decode(inner, "Data(bytes)", "off");
428 let cursor_var = cursor_var.expect("payload cursor required for channeling args");
429 cw_writeln!(
430 w,
431 "let {name}ChannelId = try decodeVarint(from: payload, offset: &{cursor_var})"
432 )
433 .unwrap();
434 cw_writeln!(
435 w,
436 "let {name}Receiver = await registry.register({name}ChannelId, initialCredit: 16, onConsumed: {{ [taskSender = self.taskSender] additional in taskSender(.grantCredit(channelId: {name}ChannelId, bytes: additional)) }})"
437 )
438 .unwrap();
439 cw_writeln!(
440 w,
441 "let {name} = createServerRx(channelId: {name}ChannelId, receiver: {name}Receiver, deserialize: {{ bytes in"
442 )
443 .unwrap();
444 cw_writeln!(w, " var off = 0").unwrap();
445 cw_writeln!(w, " return try {inline_decode}").unwrap();
446 w.writeln("})").unwrap();
447 }
448 ShapeKind::Tx { inner } => {
449 let encode_closure = generate_encode_closure(inner);
452 let cursor_var = cursor_var.expect("payload cursor required for channeling args");
453 cw_writeln!(
454 w,
455 "let {name}ChannelId = try decodeVarint(from: payload, offset: &{cursor_var})"
456 )
457 .unwrap();
458 cw_writeln!(
459 w,
460 "let {name} = await createServerTx(channelId: {name}ChannelId, taskSender: taskSender, registry: registry, initialCredit: 16, serialize: ({encode_closure}))"
461 )
462 .unwrap();
463 }
464 _ => {
465 let cursor_var = cursor_var.expect("payload cursor required for non-channel args");
467 let decode_stmt = generate_decode_stmt_with_cursor(shape, name, "", cursor_var);
468 for line in decode_stmt.lines() {
469 w.writeln(line).unwrap();
470 }
471 }
472 }
473}
474
475fn unique_decode_cursor_name(args: &[vox_types::ArgDescriptor]) -> String {
476 let arg_names: Vec<String> = args.iter().map(|a| a.name.to_lower_camel_case()).collect();
477 let mut candidate = String::from("cursor");
478 while arg_names.iter().any(|name| name == &candidate) {
479 candidate.push('_');
480 }
481 candidate
482}