1use super::common::{GENERATED_HEADER, is_void_input};
2use super::typescript::{emit_jsdoc, rust_type_to_ts};
3use crate::model::{Manifest, ProcedureKind};
4
/// TypeScript source of the `RpcError` class written into the generated client.
///
/// The class extends `Error` and carries the HTTP `status` plus the response
/// payload (`data`). `rpcFetch` (see `FETCH_HELPER`) constructs it for every
/// non-OK response.
///
/// The text is emitted verbatim, so it must contain nothing but TypeScript.
const ERROR_CLASS: &str = r#"export class RpcError extends Error {
  readonly status: number;
  readonly data: unknown;

  constructor(status: number, message: string, data?: unknown) {
    super(message);
    this.name = "RpcError";
    this.status = status;
    this.data = data;
  }
}"#;
17
/// TypeScript declaration of `RequestContext`, the argument of the `onRequest`
/// hook declared in `CONFIG_INTERFACE`.
///
/// `rpcFetch` builds one context per attempt and sends the request with the
/// context's `headers` object after the hook has run.
///
/// The text is emitted verbatim, so it must contain nothing but TypeScript.
const REQUEST_CONTEXT_INTERFACE: &str = r#"export interface RequestContext {
  procedure: string;
  method: "GET" | "POST";
  url: string;
  headers: Record<string, string>;
  input?: unknown;
}"#;
26
/// TypeScript declaration of `ResponseContext`, the argument of the
/// `onResponse` hook declared in `CONFIG_INTERFACE`.
///
/// `rpcFetch` fills `data` with the unwrapped result and `duration` with the
/// `Date.now()` difference measured from just before the first attempt.
///
/// The text is emitted verbatim, so it must contain nothing but TypeScript.
const RESPONSE_CONTEXT_INTERFACE: &str = r#"export interface ResponseContext {
  procedure: string;
  method: "GET" | "POST";
  url: string;
  response: Response;
  data: unknown;
  duration: number;
}"#;
36
/// TypeScript declaration of `ErrorContext`, the argument of the `onError`
/// hook declared in `CONFIG_INTERFACE`.
///
/// `attempt` is 1-based (the retry loop in `rpcFetch` starts at 1) and
/// `willRetry` tells the hook whether another attempt will follow.
///
/// The text is emitted verbatim, so it must contain nothing but TypeScript.
const ERROR_CONTEXT_INTERFACE: &str = r#"export interface ErrorContext {
  procedure: string;
  method: "GET" | "POST";
  url: string;
  error: unknown;
  attempt: number;
  willRetry: boolean;
}"#;
46
/// TypeScript declaration of `RetryPolicy`, the type of `RpcClientConfig.retry`.
///
/// As consumed by `rpcFetch`:
/// - `attempts` is the number of *additional* tries (`maxAttempts = 1 + attempts`);
/// - `delay` is passed to `setTimeout` between tries, either directly or as the
///   result of calling it with the attempt number;
/// - `retryOn` lists the HTTP statuses worth retrying and falls back to
///   `DEFAULT_RETRY_ON` when omitted.
///
/// The text is emitted verbatim, so it must contain nothing but TypeScript.
const RETRY_POLICY_INTERFACE: &str = r#"export interface RetryPolicy {
  attempts: number;
  delay: number | ((attempt: number) => number);
  retryOn?: number[];
}"#;
53
/// TypeScript declaration of `RpcClientConfig`, the single argument of the
/// generated `createRpcClient` factory.
///
/// Only `baseUrl` is required. The optional members are read by `rpcFetch`
/// (custom `fetch`, static or lazily computed `headers`, lifecycle hooks,
/// `retry`, `timeout`, custom `serialize`/`deserialize`, a client-wide abort
/// `signal`) and by the generated `query` method (`dedupe`).
///
/// The text is emitted verbatim, so it must contain nothing but TypeScript.
const CONFIG_INTERFACE: &str = r#"export interface RpcClientConfig {
  baseUrl: string;
  fetch?: typeof globalThis.fetch;
  headers?:
    | Record<string, string>
    | (() => Record<string, string> | Promise<Record<string, string>>);
  onRequest?: (ctx: RequestContext) => void | Promise<void>;
  onResponse?: (ctx: ResponseContext) => void | Promise<void>;
  onError?: (ctx: ErrorContext) => void | Promise<void>;
  retry?: RetryPolicy;
  timeout?: number;
  serialize?: (input: unknown) => string;
  deserialize?: (text: string) => unknown;
  // AbortSignal for cancelling all requests made by this client.
  signal?: AbortSignal;
  dedupe?: boolean;
}"#;
72
/// TypeScript declaration of `CallOptions`, the optional trailing argument of
/// every generated `query`/`mutate` overload.
///
/// Per-call values win over the client configuration: `headers` are spread
/// after the configured headers, `timeout` and `dedupe` are tried before their
/// `config` counterparts, and `signal` is combined with `config.signal`.
///
/// The text is emitted verbatim, so it must contain nothing but TypeScript.
const CALL_OPTIONS_INTERFACE: &str = r#"export interface CallOptions {
  headers?: Record<string, string>;
  timeout?: number;
  signal?: AbortSignal;
  dedupe?: boolean;
}"#;
80
/// TypeScript source of `dedupKey`, which builds the key under which the
/// generated `query` method tracks in-flight requests.
///
/// The key is `procedure + ":" + serialized input`; an `undefined` input
/// serializes to the empty string, and `config.serialize` is preferred over
/// `JSON.stringify` when present. `generate_client_file` only emits this
/// helper when the manifest contains at least one query.
///
/// The text is emitted verbatim, so it must contain nothing but TypeScript.
const DEDUP_KEY_FN: &str = r#"function dedupKey(procedure: string, input: unknown, config: RpcClientConfig): string {
  const serialized = input === undefined
    ? ""
    : config.serialize
      ? config.serialize(input)
      : JSON.stringify(input);
  return procedure + ":" + serialized;
}"#;
90
/// TypeScript source of `wrapWithSignal`, which lets one caller abandon a
/// shared promise without affecting the promise itself.
///
/// Without a signal the promise is returned untouched; with an already aborted
/// signal the result rejects immediately with `signal.reason`; otherwise the
/// wrapper settles with whichever happens first, the promise or the abort
/// event, and removes its listener once the promise settles. The generated
/// `query` method applies it to deduplicated requests. `generate_client_file`
/// only emits this helper when the manifest contains at least one query.
///
/// The text is emitted verbatim, so it must contain nothing but TypeScript.
const WRAP_WITH_SIGNAL_FN: &str = r#"function wrapWithSignal<T>(promise: Promise<T>, signal?: AbortSignal): Promise<T> {
  if (!signal) return promise;
  if (signal.aborted) return Promise.reject(signal.reason);
  return new Promise<T>((resolve, reject) => {
    const onAbort = () => reject(signal.reason);
    signal.addEventListener("abort", onAbort, { once: true });
    promise.then(
      (value) => { signal.removeEventListener("abort", onAbort); resolve(value); },
      (error) => { signal.removeEventListener("abort", onAbort); reject(error); },
    );
  });
}"#;
104
/// TypeScript source of `DEFAULT_RETRY_ON` and `rpcFetch`, the transport used
/// by every generated `query`/`mutate` call.
///
/// What the emitted function does:
/// - builds `${baseUrl}/${procedure}`; a GET input is serialized into the
///   `?input=` query parameter, a POST input becomes the request body with a
///   JSON `Content-Type`;
/// - merges headers (`config.headers`, then `callOptions.headers`) and hands a
///   fresh copy to `onRequest` on every attempt;
/// - combines `config.signal`, `callOptions.signal` and a per-attempt timeout
///   signal into the request's abort signal;
/// - makes up to `1 + retry.attempts` attempts, waiting `retry.delay` between
///   them. HTTP failures are retried only for statuses in `retryOn`;
/// - on success returns `json.result.data` when present, else the whole
///   payload, after calling `onResponse`;
/// - on a non-OK response throws `RpcError` carrying the parsed error body.
///
/// Two details are deliberate:
/// - The error body is read once with `res.text()` and then parsed with
///   `JSON.parse`. A fetch body can be consumed only once, so trying
///   `res.json()` first and `res.text()` afterwards always loses non-JSON
///   bodies.
/// - A failure is not retried once `config.signal` or `callOptions.signal` is
///   aborted; retrying a request the caller cancelled only delays the
///   rejection. Timeouts use their own controller and remain retryable.
///
/// The text is emitted verbatim, so it must contain nothing but TypeScript.
const FETCH_HELPER: &str = r#"const DEFAULT_RETRY_ON = [408, 429, 500, 502, 503, 504];

async function rpcFetch(
  config: RpcClientConfig,
  method: "GET" | "POST",
  procedure: string,
  input?: unknown,
  callOptions?: CallOptions,
): Promise<unknown> {
  let url = `${config.baseUrl}/${procedure}`;
  const customHeaders = typeof config.headers === "function"
    ? await config.headers()
    : config.headers;
  const baseHeaders: Record<string, string> = { ...customHeaders, ...callOptions?.headers };

  if (method === "GET" && input !== undefined) {
    const serialized = config.serialize ? config.serialize(input) : JSON.stringify(input);
    url += `?input=${encodeURIComponent(serialized)}`;
  } else if (method === "POST" && input !== undefined) {
    baseHeaders["Content-Type"] = "application/json";
  }

  const fetchFn = config.fetch ?? globalThis.fetch;
  const maxAttempts = 1 + (config.retry?.attempts ?? 0);
  const retryOn = config.retry?.retryOn ?? DEFAULT_RETRY_ON;
  const effectiveTimeout = callOptions?.timeout ?? config.timeout;
  const start = Date.now();

  for (let attempt = 1; attempt <= maxAttempts; attempt++) {
    const reqCtx: RequestContext = { procedure, method, url, headers: { ...baseHeaders }, input };
    await config.onRequest?.(reqCtx);

    const init: RequestInit = { method, headers: reqCtx.headers };
    if (method === "POST" && input !== undefined) {
      init.body = config.serialize ? config.serialize(input) : JSON.stringify(input);
    }

    let timeoutId: ReturnType<typeof setTimeout> | undefined;
    const signals: AbortSignal[] = [];
    if (config.signal) signals.push(config.signal);
    if (callOptions?.signal) signals.push(callOptions.signal);
    if (effectiveTimeout) {
      const controller = new AbortController();
      timeoutId = setTimeout(() => controller.abort(), effectiveTimeout);
      signals.push(controller.signal);
    }
    if (signals.length > 0) {
      init.signal = signals.length === 1 ? signals[0] : AbortSignal.any(signals);
    }

    try {
      const res = await fetchFn(url, init);

      if (!res.ok) {
        // A response body can only be consumed once: read the text, then try JSON.
        const text = await res.text().catch(() => null);
        let data: unknown = text;
        if (text !== null) {
          try {
            data = JSON.parse(text);
          } catch {
            // Not JSON: keep the raw text.
          }
        }
        const rpcError = new RpcError(
          res.status,
          `RPC error on "${procedure}": ${res.status} ${res.statusText}`,
          data,
        );
        const canRetry = retryOn.includes(res.status) && attempt < maxAttempts;
        await config.onError?.({ procedure, method, url, error: rpcError, attempt, willRetry: canRetry });
        if (!canRetry) throw rpcError;
      } else {
        const json = config.deserialize ? config.deserialize(await res.text()) : await res.json();
        const result = json?.result?.data ?? json;
        const duration = Date.now() - start;
        await config.onResponse?.({ procedure, method, url, response: res, data: result, duration });
        return result;
      }
    } catch (err) {
      if (err instanceof RpcError) throw err;
      // Never retry a request the caller cancelled; timeouts stay retryable.
      const cancelled = Boolean(config.signal?.aborted || callOptions?.signal?.aborted);
      const willRetry = !cancelled && attempt < maxAttempts;
      await config.onError?.({ procedure, method, url, error: err, attempt, willRetry });
      if (!willRetry) throw err;
    } finally {
      if (timeoutId !== undefined) clearTimeout(timeoutId);
    }

    if (config.retry) {
      const d = typeof config.retry.delay === "function"
        ? config.retry.delay(attempt) : config.retry.delay;
      await new Promise(r => setTimeout(r, d));
    }
  }
}"#;
197
198pub fn generate_client_file(
207 manifest: &Manifest,
208 types_import_path: &str,
209 preserve_docs: bool,
210) -> String {
211 let mut out = String::with_capacity(2048);
212
213 out.push_str(GENERATED_HEADER);
215 out.push('\n');
216
217 let type_names: Vec<&str> = manifest
219 .structs
220 .iter()
221 .map(|s| s.name.as_str())
222 .chain(manifest.enums.iter().map(|e| e.name.as_str()))
223 .collect();
224
225 if type_names.is_empty() {
227 emit!(
228 out,
229 "import type {{ Procedures }} from \"{types_import_path}\";\n"
230 );
231 emit!(out, "export type {{ Procedures }};\n");
232 } else {
233 let types_csv = type_names.join(", ");
234 emit!(
235 out,
236 "import type {{ Procedures, {types_csv} }} from \"{types_import_path}\";\n"
237 );
238 emit!(out, "export type {{ Procedures, {types_csv} }};\n");
239 }
240
241 emit!(out, "{ERROR_CLASS}\n");
243
244 emit!(out, "{REQUEST_CONTEXT_INTERFACE}\n");
246 emit!(out, "{RESPONSE_CONTEXT_INTERFACE}\n");
247 emit!(out, "{ERROR_CONTEXT_INTERFACE}\n");
248
249 emit!(out, "{RETRY_POLICY_INTERFACE}\n");
251
252 emit!(out, "{CONFIG_INTERFACE}\n");
254
255 emit!(out, "{CALL_OPTIONS_INTERFACE}\n");
257
258 emit!(out, "{FETCH_HELPER}\n");
260
261 let has_queries = manifest
263 .procedures
264 .iter()
265 .any(|p| p.kind == ProcedureKind::Query);
266 if has_queries {
267 emit!(out, "{DEDUP_KEY_FN}\n");
268 emit!(out, "{WRAP_WITH_SIGNAL_FN}\n");
269 }
270
271 generate_type_helpers(&mut out);
273 out.push('\n');
274
275 generate_client_factory(manifest, preserve_docs, &mut out);
277
278 out
279}
280
281fn generate_type_helpers(out: &mut String) {
283 emit!(out, "type QueryKey = keyof Procedures[\"queries\"];");
284 emit!(out, "type MutationKey = keyof Procedures[\"mutations\"];");
285 emit!(
286 out,
287 "type QueryInput<K extends QueryKey> = Procedures[\"queries\"][K][\"input\"];"
288 );
289 emit!(
290 out,
291 "type QueryOutput<K extends QueryKey> = Procedures[\"queries\"][K][\"output\"];"
292 );
293 emit!(
294 out,
295 "type MutationInput<K extends MutationKey> = Procedures[\"mutations\"][K][\"input\"];"
296 );
297 emit!(
298 out,
299 "type MutationOutput<K extends MutationKey> = Procedures[\"mutations\"][K][\"output\"];"
300 );
301}
302
303fn generate_client_factory(manifest: &Manifest, preserve_docs: bool, out: &mut String) {
305 let queries: Vec<_> = manifest
306 .procedures
307 .iter()
308 .filter(|p| p.kind == ProcedureKind::Query)
309 .collect();
310 let mutations: Vec<_> = manifest
311 .procedures
312 .iter()
313 .filter(|p| p.kind == ProcedureKind::Mutation)
314 .collect();
315 let has_queries = !queries.is_empty();
316 let has_mutations = !mutations.is_empty();
317
318 let void_queries: Vec<_> = queries.iter().filter(|p| is_void_input(p)).collect();
320 let non_void_queries: Vec<_> = queries.iter().filter(|p| !is_void_input(p)).collect();
321 let void_mutations: Vec<_> = mutations.iter().filter(|p| is_void_input(p)).collect();
322 let non_void_mutations: Vec<_> = mutations.iter().filter(|p| !is_void_input(p)).collect();
323
324 let query_mixed = !void_queries.is_empty() && !non_void_queries.is_empty();
325 let mutation_mixed = !void_mutations.is_empty() && !non_void_mutations.is_empty();
326
327 if query_mixed {
329 let names: Vec<_> = void_queries
330 .iter()
331 .map(|p| format!("\"{}\"", p.name))
332 .collect();
333 emit!(
334 out,
335 "const VOID_QUERIES: Set<string> = new Set([{}]);",
336 names.join(", ")
337 );
338 out.push('\n');
339 }
340 if mutation_mixed {
341 let names: Vec<_> = void_mutations
342 .iter()
343 .map(|p| format!("\"{}\"", p.name))
344 .collect();
345 emit!(
346 out,
347 "const VOID_MUTATIONS: Set<string> = new Set([{}]);",
348 names.join(", ")
349 );
350 out.push('\n');
351 }
352
353 emit!(out, "export interface RpcClient {{");
355
356 if has_queries {
357 generate_query_overloads(manifest, preserve_docs, out);
358 }
359
360 if has_mutations {
361 if has_queries {
362 out.push('\n');
363 }
364 generate_mutation_overloads(manifest, preserve_docs, out);
365 }
366
367 emit!(out, "}}");
368 out.push('\n');
369
370 emit!(
372 out,
373 "export function createRpcClient(config: RpcClientConfig): RpcClient {{"
374 );
375
376 if has_queries {
377 emit!(
378 out,
379 " const inflight = new Map<string, Promise<unknown>>();\n"
380 );
381 }
382
383 emit!(out, " return {{");
384
385 if has_queries {
386 emit!(
387 out,
388 " query(key: QueryKey, ...args: unknown[]): Promise<unknown> {{"
389 );
390
391 if query_mixed {
393 emit!(out, " let input: unknown;");
394 emit!(out, " let callOptions: CallOptions | undefined;");
395 emit!(out, " if (VOID_QUERIES.has(key)) {{");
396 emit!(out, " input = undefined;");
397 emit!(
398 out,
399 " callOptions = args[0] as CallOptions | undefined;"
400 );
401 emit!(out, " }} else {{");
402 emit!(out, " input = args[0];");
403 emit!(
404 out,
405 " callOptions = args[1] as CallOptions | undefined;"
406 );
407 emit!(out, " }}");
408 } else if !void_queries.is_empty() {
409 emit!(out, " const input = undefined;");
410 emit!(
411 out,
412 " const callOptions = args[0] as CallOptions | undefined;"
413 );
414 } else {
415 emit!(out, " const input = args[0];");
416 emit!(
417 out,
418 " const callOptions = args[1] as CallOptions | undefined;"
419 );
420 }
421
422 emit!(
424 out,
425 " const shouldDedupe = callOptions?.dedupe ?? config.dedupe ?? true;"
426 );
427 emit!(out, " if (shouldDedupe) {{");
428 emit!(out, " const k = dedupKey(key, input, config);");
429 emit!(out, " const existing = inflight.get(k);");
430 emit!(
431 out,
432 " if (existing) return wrapWithSignal(existing, callOptions?.signal);"
433 );
434 emit!(
435 out,
436 " const promise = rpcFetch(config, \"GET\", key, input, callOptions)"
437 );
438 emit!(out, " .finally(() => inflight.delete(k));");
439 emit!(out, " inflight.set(k, promise);");
440 emit!(
441 out,
442 " return wrapWithSignal(promise, callOptions?.signal);"
443 );
444 emit!(out, " }}");
445 emit!(
446 out,
447 " return rpcFetch(config, \"GET\", key, input, callOptions);"
448 );
449 emit!(out, " }},");
450 }
451
452 if has_mutations {
453 emit!(
454 out,
455 " mutate(key: MutationKey, ...args: unknown[]): Promise<unknown> {{"
456 );
457 if mutation_mixed {
458 emit!(out, " if (VOID_MUTATIONS.has(key)) {{");
460 emit!(
461 out,
462 " return rpcFetch(config, \"POST\", key, undefined, args[0] as CallOptions | undefined);"
463 );
464 emit!(out, " }}");
465 emit!(
466 out,
467 " return rpcFetch(config, \"POST\", key, args[0], args[1] as CallOptions | undefined);"
468 );
469 } else if !void_mutations.is_empty() {
470 emit!(
472 out,
473 " return rpcFetch(config, \"POST\", key, undefined, args[0] as CallOptions | undefined);"
474 );
475 } else {
476 emit!(
478 out,
479 " return rpcFetch(config, \"POST\", key, args[0], args[1] as CallOptions | undefined);"
480 );
481 }
482 emit!(out, " }},");
483 }
484
485 emit!(out, " }} as RpcClient;");
486 emit!(out, "}}");
487}
488
489fn generate_query_overloads(manifest: &Manifest, preserve_docs: bool, out: &mut String) {
491 let (void_queries, non_void_queries): (Vec<_>, Vec<_>) = manifest
492 .procedures
493 .iter()
494 .filter(|p| p.kind == ProcedureKind::Query)
495 .partition(|p| is_void_input(p));
496
497 for proc in &void_queries {
499 if preserve_docs && let Some(doc) = &proc.docs {
500 emit_jsdoc(doc, " ", out);
501 }
502 let output_ts = proc
503 .output
504 .as_ref()
505 .map(rust_type_to_ts)
506 .unwrap_or_else(|| "void".to_string());
507 emit!(
508 out,
509 " query(key: \"{}\"): Promise<{}>;",
510 proc.name,
511 output_ts,
512 );
513 emit!(
514 out,
515 " query(key: \"{}\", options: CallOptions): Promise<{}>;",
516 proc.name,
517 output_ts,
518 );
519 }
520
521 for proc in &non_void_queries {
523 if preserve_docs && let Some(doc) = &proc.docs {
524 emit_jsdoc(doc, " ", out);
525 }
526 let input_ts = proc
527 .input
528 .as_ref()
529 .map(rust_type_to_ts)
530 .unwrap_or_else(|| "void".to_string());
531 let output_ts = proc
532 .output
533 .as_ref()
534 .map(rust_type_to_ts)
535 .unwrap_or_else(|| "void".to_string());
536 emit!(
537 out,
538 " query(key: \"{}\", input: {}): Promise<{}>;",
539 proc.name,
540 input_ts,
541 output_ts,
542 );
543 emit!(
544 out,
545 " query(key: \"{}\", input: {}, options: CallOptions): Promise<{}>;",
546 proc.name,
547 input_ts,
548 output_ts,
549 );
550 }
551}
552
553fn generate_mutation_overloads(manifest: &Manifest, preserve_docs: bool, out: &mut String) {
555 let (void_mutations, non_void_mutations): (Vec<_>, Vec<_>) = manifest
556 .procedures
557 .iter()
558 .filter(|p| p.kind == ProcedureKind::Mutation)
559 .partition(|p| is_void_input(p));
560
561 for proc in &void_mutations {
563 if preserve_docs && let Some(doc) = &proc.docs {
564 emit_jsdoc(doc, " ", out);
565 }
566 let output_ts = proc
567 .output
568 .as_ref()
569 .map(rust_type_to_ts)
570 .unwrap_or_else(|| "void".to_string());
571 emit!(
572 out,
573 " mutate(key: \"{}\"): Promise<{}>;",
574 proc.name,
575 output_ts,
576 );
577 emit!(
578 out,
579 " mutate(key: \"{}\", options: CallOptions): Promise<{}>;",
580 proc.name,
581 output_ts,
582 );
583 }
584
585 for proc in &non_void_mutations {
587 if preserve_docs && let Some(doc) = &proc.docs {
588 emit_jsdoc(doc, " ", out);
589 }
590 let input_ts = proc
591 .input
592 .as_ref()
593 .map(rust_type_to_ts)
594 .unwrap_or_else(|| "void".to_string());
595 let output_ts = proc
596 .output
597 .as_ref()
598 .map(rust_type_to_ts)
599 .unwrap_or_else(|| "void".to_string());
600 emit!(
601 out,
602 " mutate(key: \"{}\", input: {}): Promise<{}>;",
603 proc.name,
604 input_ts,
605 output_ts,
606 );
607 emit!(
608 out,
609 " mutate(key: \"{}\", input: {}, options: CallOptions): Promise<{}>;",
610 proc.name,
611 input_ts,
612 output_ts,
613 );
614 }
615}