vox_codegen/targets/swift/
server.rs1use facet_core::Shape;
6use heck::{ToLowerCamelCase, ToUpperCamelCase};
7use vox_types::{MethodDescriptor, ServiceDescriptor, ShapeKind, classify_shape, is_rx, is_tx};
8
9use super::decode::generate_decode_stmt_with_cursor;
10use super::encode::{generate_encode_closure, generate_encode_stmt};
11use super::types::{format_doc, is_channel, swift_type_server_arg, swift_type_server_return};
12use crate::code_writer::CodeWriter;
13use crate::cw_writeln;
14use crate::render::hex_u64;
15
16fn swift_retry_policy_literal(method: &MethodDescriptor) -> &'static str {
17 match (method.retry.persist, method.retry.idem) {
18 (false, false) => ".volatile",
19 (false, true) => ".idem",
20 (true, false) => ".persist",
21 (true, true) => ".persistIdem",
22 }
23}
24
/// Name of the private per-method Swift helper that the generated
/// `dispatch(methodId:…)` switch forwards to (`dispatch_` + the method name).
fn dispatch_helper_name(method_name: &str) -> String {
    ["dispatch_", method_name].concat()
}
28
29pub fn generate_server(service: &ServiceDescriptor) -> String {
31 let mut out = String::new();
32 out.push_str(&generate_handler_protocol(service));
33 out.push_str(&generate_dispatcher(service));
35 out
36}
37
38fn generate_handler_protocol(service: &ServiceDescriptor) -> String {
40 let mut out = String::new();
41 let service_name = service.service_name.to_upper_camel_case();
42
43 if let Some(doc) = &service.doc {
44 out.push_str(&format_doc(doc, ""));
45 }
46 out.push_str(&format!(
47 "public protocol {service_name}Handler: Sendable {{\n"
48 ));
49
50 for method in service.methods {
51 let method_name = method.method_name.to_lower_camel_case();
52
53 if let Some(doc) = &method.doc {
54 out.push_str(&format_doc(doc, " "));
55 }
56
57 let args: Vec<String> = method
59 .args
60 .iter()
61 .map(|a| {
62 format!(
63 "{}: {}",
64 a.name.to_lower_camel_case(),
65 swift_type_server_arg(a.shape)
66 )
67 })
68 .collect();
69
70 let ret_type = swift_type_server_return(method.return_shape);
71
72 if ret_type == "Void" {
73 out.push_str(&format!(
74 " func {method_name}({}) async throws\n",
75 args.join(", ")
76 ));
77 } else {
78 out.push_str(&format!(
79 " func {method_name}({}) async throws -> {ret_type}\n",
80 args.join(", ")
81 ));
82 }
83 }
84
85 out.push_str("}\n\n");
86 out
87}
88
89fn generate_dispatcher(service: &ServiceDescriptor) -> String {
91 let mut out = String::new();
92 let mut w = CodeWriter::with_indent_spaces(&mut out, 4);
93 let service_name = service.service_name.to_upper_camel_case();
94
95 let service_name_lower = service.service_name.to_lower_camel_case();
96
97 cw_writeln!(
98 w,
99 "public final class {service_name}Dispatcher: ServiceDispatcher {{"
100 )
101 .unwrap();
102 {
103 let _indent = w.indent();
104 cw_writeln!(w, "private let handler: {service_name}Handler").unwrap();
105 cw_writeln!(w, "private let schemaRegistry: [UInt64: Schema]").unwrap();
106 cw_writeln!(w, "private let methodSchemas: [UInt64: MethodSchemaInfo]").unwrap();
107 w.blank_line().unwrap();
108
109 cw_writeln!(
110 w,
111 "public init(handler: {service_name}Handler, schemaRegistry: [UInt64: Schema] = {service_name_lower}_schema_registry, methodSchemas: [UInt64: MethodSchemaInfo] = {service_name_lower}_method_schemas) {{"
112 )
113 .unwrap();
114 {
115 let _indent = w.indent();
116 w.writeln("self.handler = handler").unwrap();
117 w.writeln("self.schemaRegistry = schemaRegistry").unwrap();
118 w.writeln("self.methodSchemas = methodSchemas").unwrap();
119 }
120 w.writeln("}").unwrap();
121 w.blank_line().unwrap();
122
123 w.writeln(
125 "public func dispatch(methodId: UInt64, payload: [UInt8], requestId: UInt64, registry: ChannelRegistry, schemaSendTracker _: SchemaSendTracker, taskTx: @escaping @Sendable (TaskMessage) -> Void) async {",
126 )
127 .unwrap();
128 {
129 let _indent = w.indent();
130 w.writeln("var buffer = ByteBufferAllocator().buffer(capacity: payload.count)")
131 .unwrap();
132 w.writeln("buffer.writeBytes(payload)").unwrap();
133 w.writeln("let taskSender: TaskSender = taskTx").unwrap();
134 w.writeln("switch methodId {").unwrap();
135 for method in service.methods {
136 let method_name = method.method_name.to_lower_camel_case();
137 let method_id = crate::method_id(method);
138 let dispatch_name = dispatch_helper_name(&method_name);
139 cw_writeln!(w, "case {}:", hex_u64(method_id)).unwrap();
140 cw_writeln!(
141 w,
142 " await {dispatch_name}(methodId: methodId, requestId: requestId, buffer: &buffer, registry: registry, taskSender: taskSender)"
143 )
144 .unwrap();
145 }
146 w.writeln("default:").unwrap();
147 w.writeln(
148 " taskSender(.response(requestId: requestId, payload: encodeUnknownMethodError()))",
149 )
150 .unwrap();
151 w.writeln("}").unwrap();
152 }
153 w.writeln("}").unwrap();
154 w.blank_line().unwrap();
155
156 w.writeln("public func retryPolicy(methodId: UInt64) -> RetryPolicy {")
157 .unwrap();
158 {
159 let _indent = w.indent();
160 w.writeln("switch methodId {").unwrap();
161 for method in service.methods {
162 let method_id = crate::method_id(method);
163 let retry_policy = swift_retry_policy_literal(method);
164 cw_writeln!(w, "case {}:", hex_u64(method_id)).unwrap();
165 cw_writeln!(w, " return {retry_policy}").unwrap();
166 }
167 w.writeln("default:").unwrap();
168 w.writeln(" return .volatile").unwrap();
169 w.writeln("}").unwrap();
170 }
171 w.writeln("}").unwrap();
172 w.blank_line().unwrap();
173
174 generate_preregister_channels(&mut w, service);
176 w.blank_line().unwrap();
177
178 for method in service.methods {
180 generate_channeling_dispatch_method(&mut w, method);
181 w.blank_line().unwrap();
182 }
183 }
184 w.writeln("}").unwrap();
185 w.blank_line().unwrap();
186
187 out
188}
189
190fn generate_preregister_channels(w: &mut CodeWriter<&mut String>, service: &ServiceDescriptor) {
192 w.writeln("/// Pre-register Rx channel IDs from request payloads.")
193 .unwrap();
194 w.writeln("/// Call this synchronously before spawning the dispatch task to avoid")
195 .unwrap();
196 w.writeln("/// race conditions where Data arrives before channels are registered.")
197 .unwrap();
198 w.writeln(
199 "public func preregister(methodId: UInt64, payload: [UInt8], registry: ChannelRegistry) async {",
200 )
201 .unwrap();
202 {
203 let _indent = w.indent();
204 w.writeln("var buffer = ByteBufferAllocator().buffer(capacity: payload.count)")
205 .unwrap();
206 w.writeln("buffer.writeBytes(payload)").unwrap();
207 w.writeln("switch methodId {").unwrap();
208
209 for method in service.methods {
210 let method_id = crate::method_id(method);
211 let has_channel_args = method.args.iter().any(|a| is_rx(a.shape) || is_tx(a.shape));
212
213 if has_channel_args {
214 cw_writeln!(w, "case {}:", hex_u64(method_id)).unwrap();
215 w.writeln(" do {").unwrap();
216 {
217 let _indent = w.indent();
218 for arg in method.args {
219 let arg_name = arg.name.to_lower_camel_case();
220 if is_rx(arg.shape) {
221 cw_writeln!(
222 w,
223 "let {arg_name}ChannelId = try decodeVarint(from: &buffer)"
224 )
225 .unwrap();
226 cw_writeln!(w, "await registry.markKnown({arg_name}ChannelId)")
227 .unwrap();
228 } else if is_tx(arg.shape) {
229 w.writeln("_ = try decodeVarint(from: &buffer)").unwrap();
230 } else {
231 let discard_name = format!("_discard_{arg_name}");
232 let decode_stmt = generate_decode_stmt_with_cursor(
233 arg.shape,
234 &discard_name,
235 "",
236 "buffer",
237 );
238 for line in decode_stmt.lines() {
239 w.writeln(line).unwrap();
240 }
241 }
242 }
243 }
244 w.writeln(" } catch {").unwrap();
245 w.writeln(" return").unwrap();
246 w.writeln(" }").unwrap();
247 }
248 }
249
250 w.writeln("default:").unwrap();
251 w.writeln(" break").unwrap();
252 w.writeln("}").unwrap();
253 }
254 w.writeln("}").unwrap();
255}
256
257fn generate_channeling_dispatch_method(w: &mut CodeWriter<&mut String>, method: &MethodDescriptor) {
259 let method_name = method.method_name.to_lower_camel_case();
260 let dispatch_name = dispatch_helper_name(&method_name);
261 let has_channeling = method.args.iter().any(|a| is_channel(a.shape));
262 cw_writeln!(
263 w,
264 "private func {dispatch_name}(methodId: UInt64, requestId: UInt64, buffer: inout ByteBuffer, registry: IncomingChannelRegistry, taskSender: @escaping TaskSender) async {{"
265 )
266 .unwrap();
267 {
268 let _indent = w.indent();
269 w.writeln("guard let methodInfo = methodSchemas[methodId] else {")
271 .unwrap();
272 w.writeln(
273 " taskSender(.response(requestId: requestId, payload: encodeUnknownMethodError()))",
274 )
275 .unwrap();
276 w.writeln(" return").unwrap();
277 w.writeln("}").unwrap();
278 w.writeln(
279 "let responseSchemaPayload = methodInfo.buildPayload(direction: .response, registry: schemaRegistry)",
280 )
281 .unwrap();
282 w.writeln("do {").unwrap();
283 {
284 let _indent = w.indent();
285 for arg in method.args {
286 let arg_name = arg.name.to_lower_camel_case();
287 generate_channeling_decode_arg(w, &arg_name, arg.shape);
288 }
289 let arg_names: Vec<String> = method
290 .args
291 .iter()
292 .map(|a| {
293 let name = a.name.to_lower_camel_case();
294 format!("{name}: {name}")
295 })
296 .collect();
297
298 let ret_type = swift_type_server_return(method.return_shape);
299
300 w.writeln("do {").unwrap();
301 {
302 let _indent = w.indent();
303 if has_channeling {
304 if ret_type == "Void" {
305 cw_writeln!(
306 w,
307 "try await handler.{method_name}({})",
308 arg_names.join(", ")
309 )
310 .unwrap();
311 } else {
312 cw_writeln!(
313 w,
314 "let result = try await handler.{method_name}({})",
315 arg_names.join(", ")
316 )
317 .unwrap();
318 }
319
320 if ret_type == "Void" {
321 w.writeln(
322 "taskSender(.response(requestId: requestId, payload: encodeResultOkUnit(), methodId: methodId, schemaPayload: responseSchemaPayload))",
323 )
324 .unwrap();
325 } else {
326 let encode_closure = generate_encode_closure(method.return_shape);
327 cw_writeln!(
328 w,
329 "let _encoded = encodeResultOk(result, encoder: {encode_closure})"
330 )
331 .unwrap();
332 w.writeln(
333 "taskSender(.response(requestId: requestId, payload: _encoded, methodId: methodId, schemaPayload: responseSchemaPayload))",
334 )
335 .unwrap();
336 }
337 } else if ret_type == "Void" {
338 cw_writeln!(
339 w,
340 "try await handler.{method_name}({})",
341 arg_names.join(", ")
342 )
343 .unwrap();
344 w.writeln(
345 "taskSender(.response(requestId: requestId, payload: encodeResultOkUnit(), methodId: methodId, schemaPayload: responseSchemaPayload))",
346 )
347 .unwrap();
348 } else {
349 cw_writeln!(
350 w,
351 "let result = try await handler.{method_name}({})",
352 arg_names.join(", ")
353 )
354 .unwrap();
355 if let ShapeKind::Result { ok, err } = classify_shape(method.return_shape) {
356 let ok_stmt = generate_encode_stmt(ok, "v");
357 let err_stmt = generate_encode_stmt(err, "e");
358 w.writeln("let _encoded: [UInt8] = {").unwrap();
359 {
360 let _indent = w.indent();
361 w.writeln("var buffer = ByteBufferAllocator().buffer(capacity: 64)")
362 .unwrap();
363 w.writeln("switch result {").unwrap();
364 w.writeln("case .success(let v):").unwrap();
365 {
366 let _indent = w.indent();
367 w.writeln("encodeVarint(UInt64(0), into: &buffer)").unwrap();
368 for line in ok_stmt.lines() {
369 w.writeln(line).unwrap();
370 }
371 }
372 w.writeln("case .failure(let e):").unwrap();
373 {
374 let _indent = w.indent();
375 w.writeln("encodeVarint(UInt64(1), into: &buffer)").unwrap();
376 w.writeln("encodeU8(0, into: &buffer)").unwrap();
377 for line in err_stmt.lines() {
378 w.writeln(line).unwrap();
379 }
380 }
381 w.writeln("}").unwrap();
382 w.writeln(
383 "return buffer.readBytes(length: buffer.readableBytes) ?? []",
384 )
385 .unwrap();
386 }
387 w.writeln("}()").unwrap();
388 w.writeln(
389 "taskSender(.response(requestId: requestId, payload: _encoded, methodId: methodId, schemaPayload: responseSchemaPayload))",
390 )
391 .unwrap();
392 } else {
393 let encode_closure = generate_encode_closure(method.return_shape);
394 cw_writeln!(
395 w,
396 "let _encoded = encodeResultOk(result, encoder: {encode_closure})"
397 )
398 .unwrap();
399 w.writeln(
400 "taskSender(.response(requestId: requestId, payload: _encoded, methodId: methodId, schemaPayload: responseSchemaPayload))",
401 )
402 .unwrap();
403 }
404 }
405 }
406 w.writeln("} catch {").unwrap();
407 {
408 let _indent = w.indent();
409 if method.retry.persist {
410 w.writeln(
411 "taskSender(.response(requestId: requestId, payload: encodeIndeterminateError(), methodId: methodId, schemaPayload: responseSchemaPayload))",
412 )
413 .unwrap();
414 } else {
415 w.writeln(
416 "taskSender(.response(requestId: requestId, payload: encodeInvalidPayloadError(reason: String(describing: error)), methodId: methodId, schemaPayload: responseSchemaPayload))",
417 )
418 .unwrap();
419 }
420 }
421 w.writeln("}").unwrap();
422 }
423 w.writeln("} catch {").unwrap();
424 {
425 let _indent = w.indent();
426 w.writeln(
427 "taskSender(.response(requestId: requestId, payload: encodeInvalidPayloadError(reason: String(describing: error)), methodId: methodId, schemaPayload: responseSchemaPayload))",
428 )
429 .unwrap();
430 }
431 w.writeln("}").unwrap();
432 }
433 w.writeln("}").unwrap();
434}
435
436fn generate_channeling_decode_arg(
439 w: &mut CodeWriter<&mut String>,
440 name: &str,
441 shape: &'static Shape,
442) {
443 match classify_shape(shape) {
444 ShapeKind::Rx { inner } => {
445 let decode_closure = generate_decode_closure_for_channel(inner);
446 cw_writeln!(w, "let {name}ChannelId = try decodeVarint(from: &buffer)").unwrap();
447 cw_writeln!(
448 w,
449 "let {name}Receiver = await registry.register({name}ChannelId, initialCredit: 16, onConsumed: {{ [taskSender] additional in taskSender(.grantCredit(channelId: {name}ChannelId, bytes: additional)) }})"
450 )
451 .unwrap();
452 cw_writeln!(
453 w,
454 "let {name} = createServerRx(channelId: {name}ChannelId, receiver: {name}Receiver, deserialize: {decode_closure})"
455 )
456 .unwrap();
457 }
458 ShapeKind::Tx { inner } => {
459 let encode_closure = generate_encode_closure(inner);
460 cw_writeln!(w, "let {name}ChannelId = try decodeVarint(from: &buffer)").unwrap();
461 cw_writeln!(
462 w,
463 "let {name} = await createServerTx(channelId: {name}ChannelId, taskSender: taskSender, registry: registry, initialCredit: 16, serialize: {encode_closure})"
464 )
465 .unwrap();
466 }
467 _ => {
468 let decode_stmt = generate_decode_stmt_with_cursor(shape, name, "", "buffer");
469 for line in decode_stmt.lines() {
470 w.writeln(line).unwrap();
471 }
472 }
473 }
474}
475
476fn generate_decode_closure_for_channel(inner: &'static Shape) -> String {
479 use super::decode::generate_decode_closure;
480 generate_decode_closure(inner)
481}