Skip to main content

vercel_rpc_cli/codegen/
client.rs

1use super::common::{GENERATED_HEADER, is_void_input};
2use super::typescript::{emit_jsdoc, rust_type_to_ts};
3use crate::model::{Manifest, ProcedureKind};
4
/// Standard RPC error class with status code and structured error data.
///
/// Emitted verbatim into the generated client. The generated `rpcFetch` throws
/// it for non-2xx responses: `status` carries the HTTP status code and `data`
/// whatever payload `rpcFetch` managed to read from the error response.
const ERROR_CLASS: &str = r#"export class RpcError extends Error {
  readonly status: number;
  readonly data: unknown;

  constructor(status: number, message: string, data?: unknown) {
    super(message);
    this.name = "RpcError";
    this.status = status;
    this.data = data;
  }
}"#;
17
/// Context passed to the `onRequest` lifecycle hook.
///
/// The generated `rpcFetch` builds a fresh context for every attempt.
/// `headers` is a copy of the merged client-level and per-call headers. The
/// request is then sent with that same object, so an `onRequest` hook can add
/// or change headers for the attempt. For GET requests, `url` already contains
/// the serialized `input` query parameter.
const REQUEST_CONTEXT_INTERFACE: &str = r#"export interface RequestContext {
  procedure: string;
  method: "GET" | "POST";
  url: string;
  headers: Record<string, string>;
  input?: unknown;
}"#;
26
/// Context passed to the `onResponse` lifecycle hook.
///
/// `data` is the value the call resolves with, after `rpcFetch` unwraps any
/// `result.data` field. `duration` is measured in milliseconds (`Date.now()`)
/// from before the first attempt, so it includes time spent on retries.
const RESPONSE_CONTEXT_INTERFACE: &str = r#"export interface ResponseContext {
  procedure: string;
  method: "GET" | "POST";
  url: string;
  response: Response;
  data: unknown;
  duration: number;
}"#;
36
/// Context passed to the `onError` lifecycle hook.
///
/// Reported once per failed attempt:
/// - `attempt` is 1-based.
/// - `willRetry` tells the hook whether `rpcFetch` will try again; when it is
///   false, the same error is thrown to the caller next.
/// - `error` is an `RpcError` for non-2xx responses. Otherwise it is the value
///   caught while performing the request (for example a network failure or
///   an abort).
const ERROR_CONTEXT_INTERFACE: &str = r#"export interface ErrorContext {
  procedure: string;
  method: "GET" | "POST";
  url: string;
  error: unknown;
  attempt: number;
  willRetry: boolean;
}"#;
46
/// Retry policy configuration.
///
/// Fields:
/// - `attempts` is the number of retries on top of the first attempt.
/// - `delay` is the pause in milliseconds before the next attempt. In function
///   form it receives the 1-based number of the attempt that just failed.
/// - `retryOn` lists the HTTP statuses worth retrying. It defaults to
///   `DEFAULT_RETRY_ON` inside `rpcFetch`; failures without an HTTP response
///   are handled separately from this list.
///
/// Without a policy, `rpcFetch` makes exactly one attempt.
const RETRY_POLICY_INTERFACE: &str = r#"export interface RetryPolicy {
  attempts: number;
  delay: number | ((attempt: number) => number);
  retryOn?: number[];
}"#;
53
/// Configuration interface for the RPC client.
///
/// Accepted by the generated `createRpcClient`. Fields:
/// - `headers` may be a static map or a (possibly async) function that
///   `rpcFetch` evaluates on every call.
/// - `timeout` is in milliseconds and is armed separately for each attempt.
/// - `serialize` / `deserialize` replace `JSON.stringify` for request input and
///   `res.json()` for successful response bodies.
/// - `dedupe` only affects queries and defaults to `true`.
///
/// The TypeScript text, including its one inline comment, is emitted verbatim.
const CONFIG_INTERFACE: &str = r#"export interface RpcClientConfig {
  baseUrl: string;
  fetch?: typeof globalThis.fetch;
  headers?:
    | Record<string, string>
    | (() => Record<string, string> | Promise<Record<string, string>>);
  onRequest?: (ctx: RequestContext) => void | Promise<void>;
  onResponse?: (ctx: ResponseContext) => void | Promise<void>;
  onError?: (ctx: ErrorContext) => void | Promise<void>;
  retry?: RetryPolicy;
  timeout?: number;
  serialize?: (input: unknown) => string;
  deserialize?: (text: string) => unknown;
  // AbortSignal for cancelling all requests made by this client.
  signal?: AbortSignal;
  dedupe?: boolean;
}"#;
72
/// Per-call options that override client-level defaults for a single request.
///
/// Fields:
/// - `headers` are spread over the client headers, so per-call values win.
/// - `timeout` replaces the client timeout.
/// - `signal` is combined with the client-level signal rather than replacing it.
/// - `dedupe` overrides the client's dedupe setting for queries.
const CALL_OPTIONS_INTERFACE: &str = r#"export interface CallOptions {
  headers?: Record<string, string>;
  timeout?: number;
  signal?: AbortSignal;
  dedupe?: boolean;
}"#;
80
/// Computes a dedup map key from procedure name and serialized input.
///
/// Emitted only when the manifest contains queries. The key is
/// `"<procedure>:<serialized input>"`:
/// - The input is serialized with the same serializer the GET query string
///   uses (`config.serialize`, else `JSON.stringify`).
/// - `undefined` input serializes to an empty string.
/// - Headers and other call options are not part of the key.
const DEDUP_KEY_FN: &str = r#"function dedupKey(procedure: string, input: unknown, config: RpcClientConfig): string {
  const serialized = input === undefined
    ? ""
    : config.serialize
      ? config.serialize(input)
      : JSON.stringify(input);
  return procedure + ":" + serialized;
}"#;
90
/// Wraps a shared promise so that a per-caller AbortSignal can reject independently.
///
/// Behavior of the emitted function:
/// - Without a signal, the shared promise is returned unchanged.
/// - An already-aborted signal rejects immediately with `signal.reason`.
/// - Otherwise the wrapper settles with the shared promise or rejects on abort,
///   whichever happens first.
///
/// The abort listener is removed once the shared promise settles. Aborting
/// rejects only the wrapper; this function does not cancel the underlying
/// promise.
const WRAP_WITH_SIGNAL_FN: &str = r#"function wrapWithSignal<T>(promise: Promise<T>, signal?: AbortSignal): Promise<T> {
  if (!signal) return promise;
  if (signal.aborted) return Promise.reject(signal.reason);
  return new Promise<T>((resolve, reject) => {
    const onAbort = () => reject(signal.reason);
    signal.addEventListener("abort", onAbort, { once: true });
    promise.then(
      (value) => { signal.removeEventListener("abort", onAbort); resolve(value); },
      (error) => { signal.removeEventListener("abort", onAbort); reject(error); },
    );
  });
}"#;
104
/// Internal fetch helper shared by the generated `query` and `mutate` methods.
///
/// The emitted `rpcFetch`:
/// - Requests `${baseUrl}/${procedure}`. GET input goes into an `input` query
///   parameter and POST input into the body, serialized with
///   `config.serialize` or `JSON.stringify`.
/// - Merges client-level and per-call headers, with per-call values winning.
/// - Makes up to `1 + retry.attempts` attempts. It retries HTTP statuses in
///   `retry.retryOn` (default `DEFAULT_RETRY_ON`) and request failures
///   without a response.
/// - Never retries once a caller-supplied `AbortSignal` (client-level or
///   per-call) has aborted. Only per-attempt timeouts and genuine network
///   failures get another attempt.
/// - Unwraps a `result.data` envelope by key presence, so a `null` `data` is
///   returned as `null` rather than as the whole envelope.
/// - Reads an error body once, parsing it as JSON when possible and otherwise
///   keeping the raw text.
/// - Runs the `onResponse` / `onError` hooks outside the retry-guarded region.
///   A throwing hook therefore propagates to the caller and never re-sends a
///   request that already completed.
const FETCH_HELPER: &str = r#"const DEFAULT_RETRY_ON = [408, 429, 500, 502, 503, 504];

async function rpcFetch(
  config: RpcClientConfig,
  method: "GET" | "POST",
  procedure: string,
  input?: unknown,
  callOptions?: CallOptions,
): Promise<unknown> {
  let url = `${config.baseUrl}/${procedure}`;
  const customHeaders = typeof config.headers === "function"
    ? await config.headers()
    : config.headers;
  const baseHeaders: Record<string, string> = { ...customHeaders, ...callOptions?.headers };

  if (method === "GET" && input !== undefined) {
    const serialized = config.serialize ? config.serialize(input) : JSON.stringify(input);
    url += `?input=${encodeURIComponent(serialized)}`;
  } else if (method === "POST" && input !== undefined) {
    baseHeaders["Content-Type"] = "application/json";
  }

  const fetchFn = config.fetch ?? globalThis.fetch;
  const maxAttempts = 1 + (config.retry?.attempts ?? 0);
  const retryOn = config.retry?.retryOn ?? DEFAULT_RETRY_ON;
  const effectiveTimeout = callOptions?.timeout ?? config.timeout;
  const start = Date.now();

  for (let attempt = 1; attempt <= maxAttempts; attempt++) {
    const reqCtx: RequestContext = { procedure, method, url, headers: { ...baseHeaders }, input };
    await config.onRequest?.(reqCtx);

    const init: RequestInit = { method, headers: reqCtx.headers };
    if (method === "POST" && input !== undefined) {
      init.body = config.serialize ? config.serialize(input) : JSON.stringify(input);
    }

    let timeoutId: ReturnType<typeof setTimeout> | undefined;
    const signals: AbortSignal[] = [];
    if (config.signal) signals.push(config.signal);
    if (callOptions?.signal) signals.push(callOptions.signal);
    if (effectiveTimeout) {
      const controller = new AbortController();
      timeoutId = setTimeout(() => controller.abort(), effectiveTimeout);
      signals.push(controller.signal);
    }
    if (signals.length > 0) {
      init.signal = signals.length === 1 ? signals[0] : AbortSignal.any(signals);
    }

    // Only the network exchange runs inside try. Lifecycle hooks run afterwards,
    // so an exception thrown by a hook is never mistaken for a failed request
    // (which would re-send a possibly non-idempotent POST).
    let outcome:
      | { ok: true; response: Response; data: unknown }
      | { ok: false; error: unknown; retryable: boolean };
    try {
      const res = await fetchFn(url, init);
      if (res.ok) {
        const json = config.deserialize ? config.deserialize(await res.text()) : await res.json();
        // Unwrap the result.data envelope by key presence: a nullish fallback
        // would return the whole envelope whenever data is null.
        const envelope = json?.result;
        const data = envelope !== null && typeof envelope === "object" && "data" in envelope
          ? envelope.data
          : json;
        outcome = { ok: true, response: res, data };
      } else {
        // Read the body exactly once: after a failed res.json() the stream is
        // already consumed, so a res.text() fallback could never succeed.
        const text = await res.text().catch(() => "");
        let data: unknown = null;
        if (text) {
          try {
            data = JSON.parse(text);
          } catch {
            data = text;
          }
        }
        const rpcError = new RpcError(
          res.status,
          `RPC error on "${procedure}": ${res.status} ${res.statusText}`,
          data,
        );
        outcome = { ok: false, error: rpcError, retryable: retryOn.includes(res.status) };
      }
    } catch (err) {
      // A caller-initiated abort is final; retrying against an aborted signal
      // would only fail again after the retry delay.
      const cancelled = Boolean(config.signal?.aborted || callOptions?.signal?.aborted);
      outcome = { ok: false, error: err, retryable: !cancelled };
    } finally {
      if (timeoutId !== undefined) clearTimeout(timeoutId);
    }

    if (outcome.ok) {
      const duration = Date.now() - start;
      await config.onResponse?.({ procedure, method, url, response: outcome.response, data: outcome.data, duration });
      return outcome.data;
    }

    const willRetry = outcome.retryable && attempt < maxAttempts;
    await config.onError?.({ procedure, method, url, error: outcome.error, attempt, willRetry });
    if (!willRetry) throw outcome.error;

    if (config.retry) {
      const d = typeof config.retry.delay === "function"
        ? config.retry.delay(attempt) : config.retry.delay;
      await new Promise(r => setTimeout(r, d));
    }
  }
}"#;
197
198/// Generates the complete `rpc-client.ts` file content from a manifest.
199///
200/// The output includes:
201/// 1. Auto-generation header
202/// 2. Re-export of `Procedures` type from the types file
203/// 3. `RpcError` class for structured error handling
204/// 4. Internal `rpcFetch` helper
205/// 5. `createRpcClient` factory function with fully typed `query` / `mutate` methods
206pub fn generate_client_file(
207    manifest: &Manifest,
208    types_import_path: &str,
209    preserve_docs: bool,
210) -> String {
211    let mut out = String::with_capacity(2048);
212
213    // Header
214    out.push_str(GENERATED_HEADER);
215    out.push('\n');
216
217    // Collect all user-defined type names (structs + enums) for import
218    let type_names: Vec<&str> = manifest
219        .structs
220        .iter()
221        .map(|s| s.name.as_str())
222        .chain(manifest.enums.iter().map(|e| e.name.as_str()))
223        .collect();
224
225    // Import Procedures type (and any referenced types) from the types file
226    if type_names.is_empty() {
227        emit!(
228            out,
229            "import type {{ Procedures }} from \"{types_import_path}\";\n"
230        );
231        emit!(out, "export type {{ Procedures }};\n");
232    } else {
233        let types_csv = type_names.join(", ");
234        emit!(
235            out,
236            "import type {{ Procedures, {types_csv} }} from \"{types_import_path}\";\n"
237        );
238        emit!(out, "export type {{ Procedures, {types_csv} }};\n");
239    }
240
241    // Error class
242    emit!(out, "{ERROR_CLASS}\n");
243
244    // Lifecycle hook context interfaces
245    emit!(out, "{REQUEST_CONTEXT_INTERFACE}\n");
246    emit!(out, "{RESPONSE_CONTEXT_INTERFACE}\n");
247    emit!(out, "{ERROR_CONTEXT_INTERFACE}\n");
248
249    // Retry policy interface
250    emit!(out, "{RETRY_POLICY_INTERFACE}\n");
251
252    // Client config interface
253    emit!(out, "{CONFIG_INTERFACE}\n");
254
255    // Per-call options interface
256    emit!(out, "{CALL_OPTIONS_INTERFACE}\n");
257
258    // Internal fetch helper
259    emit!(out, "{FETCH_HELPER}\n");
260
261    // Dedup helpers (only when the manifest has queries)
262    let has_queries = manifest
263        .procedures
264        .iter()
265        .any(|p| p.kind == ProcedureKind::Query);
266    if has_queries {
267        emit!(out, "{DEDUP_KEY_FN}\n");
268        emit!(out, "{WRAP_WITH_SIGNAL_FN}\n");
269    }
270
271    // Type helpers for ergonomic API
272    generate_type_helpers(&mut out);
273    out.push('\n');
274
275    // Client factory
276    generate_client_factory(manifest, preserve_docs, &mut out);
277
278    out
279}
280
281/// Emits utility types that power the typed client API.
282fn generate_type_helpers(out: &mut String) {
283    emit!(out, "type QueryKey = keyof Procedures[\"queries\"];");
284    emit!(out, "type MutationKey = keyof Procedures[\"mutations\"];");
285    emit!(
286        out,
287        "type QueryInput<K extends QueryKey> = Procedures[\"queries\"][K][\"input\"];"
288    );
289    emit!(
290        out,
291        "type QueryOutput<K extends QueryKey> = Procedures[\"queries\"][K][\"output\"];"
292    );
293    emit!(
294        out,
295        "type MutationInput<K extends MutationKey> = Procedures[\"mutations\"][K][\"input\"];"
296    );
297    emit!(
298        out,
299        "type MutationOutput<K extends MutationKey> = Procedures[\"mutations\"][K][\"output\"];"
300    );
301}
302
303/// Generates the `createRpcClient` factory using an interface for typed overloads.
304fn generate_client_factory(manifest: &Manifest, preserve_docs: bool, out: &mut String) {
305    let queries: Vec<_> = manifest
306        .procedures
307        .iter()
308        .filter(|p| p.kind == ProcedureKind::Query)
309        .collect();
310    let mutations: Vec<_> = manifest
311        .procedures
312        .iter()
313        .filter(|p| p.kind == ProcedureKind::Mutation)
314        .collect();
315    let has_queries = !queries.is_empty();
316    let has_mutations = !mutations.is_empty();
317
318    // Partition queries and mutations by void/non-void input
319    let void_queries: Vec<_> = queries.iter().filter(|p| is_void_input(p)).collect();
320    let non_void_queries: Vec<_> = queries.iter().filter(|p| !is_void_input(p)).collect();
321    let void_mutations: Vec<_> = mutations.iter().filter(|p| is_void_input(p)).collect();
322    let non_void_mutations: Vec<_> = mutations.iter().filter(|p| !is_void_input(p)).collect();
323
324    let query_mixed = !void_queries.is_empty() && !non_void_queries.is_empty();
325    let mutation_mixed = !void_mutations.is_empty() && !non_void_mutations.is_empty();
326
327    // Emit VOID_QUERIES/VOID_MUTATIONS sets when mixed void/non-void exists
328    if query_mixed {
329        let names: Vec<_> = void_queries
330            .iter()
331            .map(|p| format!("\"{}\"", p.name))
332            .collect();
333        emit!(
334            out,
335            "const VOID_QUERIES: Set<string> = new Set([{}]);",
336            names.join(", ")
337        );
338        out.push('\n');
339    }
340    if mutation_mixed {
341        let names: Vec<_> = void_mutations
342            .iter()
343            .map(|p| format!("\"{}\"", p.name))
344            .collect();
345        emit!(
346            out,
347            "const VOID_MUTATIONS: Set<string> = new Set([{}]);",
348            names.join(", ")
349        );
350        out.push('\n');
351    }
352
353    // Emit the RpcClient interface with overloaded method signatures
354    emit!(out, "export interface RpcClient {{");
355
356    if has_queries {
357        generate_query_overloads(manifest, preserve_docs, out);
358    }
359
360    if has_mutations {
361        if has_queries {
362            out.push('\n');
363        }
364        generate_mutation_overloads(manifest, preserve_docs, out);
365    }
366
367    emit!(out, "}}");
368    out.push('\n');
369
370    // Emit the factory function
371    emit!(
372        out,
373        "export function createRpcClient(config: RpcClientConfig): RpcClient {{"
374    );
375
376    if has_queries {
377        emit!(
378            out,
379            "  const inflight = new Map<string, Promise<unknown>>();\n"
380        );
381    }
382
383    emit!(out, "  return {{");
384
385    if has_queries {
386        emit!(
387            out,
388            "    query(key: QueryKey, ...args: unknown[]): Promise<unknown> {{"
389        );
390
391        // Extract input and callOptions into locals based on void/non-void branching
392        if query_mixed {
393            emit!(out, "      let input: unknown;");
394            emit!(out, "      let callOptions: CallOptions | undefined;");
395            emit!(out, "      if (VOID_QUERIES.has(key)) {{");
396            emit!(out, "        input = undefined;");
397            emit!(
398                out,
399                "        callOptions = args[0] as CallOptions | undefined;"
400            );
401            emit!(out, "      }} else {{");
402            emit!(out, "        input = args[0];");
403            emit!(
404                out,
405                "        callOptions = args[1] as CallOptions | undefined;"
406            );
407            emit!(out, "      }}");
408        } else if !void_queries.is_empty() {
409            emit!(out, "      const input = undefined;");
410            emit!(
411                out,
412                "      const callOptions = args[0] as CallOptions | undefined;"
413            );
414        } else {
415            emit!(out, "      const input = args[0];");
416            emit!(
417                out,
418                "      const callOptions = args[1] as CallOptions | undefined;"
419            );
420        }
421
422        // Dedup logic
423        emit!(
424            out,
425            "      const shouldDedupe = callOptions?.dedupe ?? config.dedupe ?? true;"
426        );
427        emit!(out, "      if (shouldDedupe) {{");
428        emit!(out, "        const k = dedupKey(key, input, config);");
429        emit!(out, "        const existing = inflight.get(k);");
430        emit!(
431            out,
432            "        if (existing) return wrapWithSignal(existing, callOptions?.signal);"
433        );
434        emit!(
435            out,
436            "        const promise = rpcFetch(config, \"GET\", key, input, callOptions)"
437        );
438        emit!(out, "          .finally(() => inflight.delete(k));");
439        emit!(out, "        inflight.set(k, promise);");
440        emit!(
441            out,
442            "        return wrapWithSignal(promise, callOptions?.signal);"
443        );
444        emit!(out, "      }}");
445        emit!(
446            out,
447            "      return rpcFetch(config, \"GET\", key, input, callOptions);"
448        );
449        emit!(out, "    }},");
450    }
451
452    if has_mutations {
453        emit!(
454            out,
455            "    mutate(key: MutationKey, ...args: unknown[]): Promise<unknown> {{"
456        );
457        if mutation_mixed {
458            // Mixed: use VOID_MUTATIONS set to branch at runtime
459            emit!(out, "      if (VOID_MUTATIONS.has(key)) {{");
460            emit!(
461                out,
462                "        return rpcFetch(config, \"POST\", key, undefined, args[0] as CallOptions | undefined);"
463            );
464            emit!(out, "      }}");
465            emit!(
466                out,
467                "      return rpcFetch(config, \"POST\", key, args[0], args[1] as CallOptions | undefined);"
468            );
469        } else if !void_mutations.is_empty() {
470            // All void: args[0] is always CallOptions
471            emit!(
472                out,
473                "      return rpcFetch(config, \"POST\", key, undefined, args[0] as CallOptions | undefined);"
474            );
475        } else {
476            // All non-void: args[0] is input, args[1] is CallOptions
477            emit!(
478                out,
479                "      return rpcFetch(config, \"POST\", key, args[0], args[1] as CallOptions | undefined);"
480            );
481        }
482        emit!(out, "    }},");
483    }
484
485    emit!(out, "  }} as RpcClient;");
486    emit!(out, "}}");
487}
488
489/// Generates query overload signatures for the RpcClient interface.
490fn generate_query_overloads(manifest: &Manifest, preserve_docs: bool, out: &mut String) {
491    let (void_queries, non_void_queries): (Vec<_>, Vec<_>) = manifest
492        .procedures
493        .iter()
494        .filter(|p| p.kind == ProcedureKind::Query)
495        .partition(|p| is_void_input(p));
496
497    // Overload signatures for void-input queries (no input argument required)
498    for proc in &void_queries {
499        if preserve_docs && let Some(doc) = &proc.docs {
500            emit_jsdoc(doc, "  ", out);
501        }
502        let output_ts = proc
503            .output
504            .as_ref()
505            .map(rust_type_to_ts)
506            .unwrap_or_else(|| "void".to_string());
507        emit!(
508            out,
509            "  query(key: \"{}\"): Promise<{}>;",
510            proc.name,
511            output_ts,
512        );
513        emit!(
514            out,
515            "  query(key: \"{}\", options: CallOptions): Promise<{}>;",
516            proc.name,
517            output_ts,
518        );
519    }
520
521    // Overload signatures for non-void-input queries
522    for proc in &non_void_queries {
523        if preserve_docs && let Some(doc) = &proc.docs {
524            emit_jsdoc(doc, "  ", out);
525        }
526        let input_ts = proc
527            .input
528            .as_ref()
529            .map(rust_type_to_ts)
530            .unwrap_or_else(|| "void".to_string());
531        let output_ts = proc
532            .output
533            .as_ref()
534            .map(rust_type_to_ts)
535            .unwrap_or_else(|| "void".to_string());
536        emit!(
537            out,
538            "  query(key: \"{}\", input: {}): Promise<{}>;",
539            proc.name,
540            input_ts,
541            output_ts,
542        );
543        emit!(
544            out,
545            "  query(key: \"{}\", input: {}, options: CallOptions): Promise<{}>;",
546            proc.name,
547            input_ts,
548            output_ts,
549        );
550    }
551}
552
553/// Generates mutation overload signatures for the RpcClient interface.
554fn generate_mutation_overloads(manifest: &Manifest, preserve_docs: bool, out: &mut String) {
555    let (void_mutations, non_void_mutations): (Vec<_>, Vec<_>) = manifest
556        .procedures
557        .iter()
558        .filter(|p| p.kind == ProcedureKind::Mutation)
559        .partition(|p| is_void_input(p));
560
561    // Overload signatures for void-input mutations
562    for proc in &void_mutations {
563        if preserve_docs && let Some(doc) = &proc.docs {
564            emit_jsdoc(doc, "  ", out);
565        }
566        let output_ts = proc
567            .output
568            .as_ref()
569            .map(rust_type_to_ts)
570            .unwrap_or_else(|| "void".to_string());
571        emit!(
572            out,
573            "  mutate(key: \"{}\"): Promise<{}>;",
574            proc.name,
575            output_ts,
576        );
577        emit!(
578            out,
579            "  mutate(key: \"{}\", options: CallOptions): Promise<{}>;",
580            proc.name,
581            output_ts,
582        );
583    }
584
585    // Overload signatures for non-void-input mutations
586    for proc in &non_void_mutations {
587        if preserve_docs && let Some(doc) = &proc.docs {
588            emit_jsdoc(doc, "  ", out);
589        }
590        let input_ts = proc
591            .input
592            .as_ref()
593            .map(rust_type_to_ts)
594            .unwrap_or_else(|| "void".to_string());
595        let output_ts = proc
596            .output
597            .as_ref()
598            .map(rust_type_to_ts)
599            .unwrap_or_else(|| "void".to_string());
600        emit!(
601            out,
602            "  mutate(key: \"{}\", input: {}): Promise<{}>;",
603            proc.name,
604            input_ts,
605            output_ts,
606        );
607        emit!(
608            out,
609            "  mutate(key: \"{}\", input: {}, options: CallOptions): Promise<{}>;",
610            proc.name,
611            input_ts,
612            output_ts,
613        );
614    }
615}