spacetimedb_cli/subcommands/generate/
typescript.rs

1use crate::generate::util::{is_reducer_invokable, iter_reducers, iter_tables, iter_types, iter_unique_cols};
2use crate::indent_scope;
3
4use super::util::{collect_case, print_auto_generated_file_comment, type_ref_name};
5
6use std::collections::BTreeSet;
7use std::fmt::{self, Write};
8use std::ops::Deref;
9
10use convert_case::{Case, Casing};
11use spacetimedb_lib::sats::AlgebraicTypeRef;
12use spacetimedb_schema::def::{ModuleDef, ReducerDef, ScopedTypeName, TableDef, TypeDef};
13use spacetimedb_schema::identifier::Identifier;
14use spacetimedb_schema::schema::{Schema, TableSchema};
15use spacetimedb_schema::type_for_generate::{AlgebraicTypeDef, AlgebraicTypeUse, PrimitiveType};
16
17use super::code_indenter::{CodeIndenter, Indenter};
18use super::Lang;
19use std::path::PathBuf;
20
21type Imports = BTreeSet<AlgebraicTypeRef>;
22
23const INDENT: &str = "  ";
24
25pub struct TypeScript;
26
27impl Lang for TypeScript {
28    fn table_filename(
29        &self,
30        _module: &spacetimedb_schema::def::ModuleDef,
31        table: &spacetimedb_schema::def::TableDef,
32    ) -> String {
33        table_module_name(&table.name) + ".ts"
34    }
35
36    fn type_filename(&self, type_name: &ScopedTypeName) -> String {
37        type_module_name(type_name) + ".ts"
38    }
39
40    fn reducer_filename(&self, reducer_name: &Identifier) -> String {
41        reducer_module_name(reducer_name) + ".ts"
42    }
43
44    fn format_files(&self, _generated_files: BTreeSet<PathBuf>) -> anyhow::Result<()> {
45        // TODO: implement formatting.
46        Ok(())
47    }
48
49    fn generate_type(&self, module: &ModuleDef, typ: &TypeDef) -> String {
50        // TODO(cloutiertyler): I do think TypeScript does support namespaces:
51        // https://www.typescriptlang.org/docs/handbook/namespaces.html
52        let type_name = collect_case(Case::Pascal, typ.name.name_segments());
53
54        let mut output = CodeIndenter::new(String::new(), INDENT);
55        let out = &mut output;
56
57        print_file_header(out);
58
59        match &module.typespace_for_generate()[typ.ty] {
60            AlgebraicTypeDef::Product(product) => {
61                gen_and_print_imports(module, out, &product.elements, &[typ.ty]);
62                define_namespace_and_object_type_for_product(module, out, &type_name, &product.elements);
63            }
64            AlgebraicTypeDef::Sum(sum) => {
65                gen_and_print_imports(module, out, &sum.variants, &[typ.ty]);
66                define_namespace_and_types_for_sum(module, out, &type_name, &sum.variants);
67            }
68            AlgebraicTypeDef::PlainEnum(plain_enum) => {
69                let variants = plain_enum
70                    .variants
71                    .iter()
72                    .cloned()
73                    .map(|var| (var, AlgebraicTypeUse::Unit))
74                    .collect::<Vec<_>>();
75                define_namespace_and_types_for_sum(module, out, &type_name, &variants);
76            }
77        }
78        out.newline();
79
80        output.into_inner()
81    }
82
83    fn generate_table(&self, module: &ModuleDef, table: &TableDef) -> String {
84        let schema = TableSchema::from_module_def(module, table, (), 0.into())
85            .validated()
86            .expect("Failed to generate table due to validation errors");
87
88        let mut output = CodeIndenter::new(String::new(), INDENT);
89        let out = &mut output;
90
91        print_file_header(out);
92
93        let type_ref = table.product_type_ref;
94        let row_type = type_ref_name(module, type_ref);
95        let row_type_module = type_ref_module_name(module, type_ref);
96
97        writeln!(out, "import {{ {row_type} }} from \"./{row_type_module}\";");
98
99        let product_def = module.typespace_for_generate()[type_ref].as_product().unwrap();
100
101        // Import the types of all fields.
102        // We only need to import fields which have indices or unique constraints,
103        // but it's easier to just import all of 'em, since we have `// @ts-nocheck` anyway.
104        gen_and_print_imports(
105            module,
106            out,
107            &product_def.elements,
108            &[], // No need to skip any imports; we're not defining a type, so there's no chance of circular imports.
109        );
110
111        writeln!(
112            out,
113            "import {{ EventContext, Reducer, RemoteReducers, RemoteTables }} from \".\";"
114        );
115
116        let table_name = table.name.deref();
117        let table_name_pascalcase = table.name.deref().to_case(Case::Pascal);
118        let table_handle = table_name_pascalcase.clone() + "TableHandle";
119        let accessor_method = table_method_name(&table.name);
120
121        writeln!(out);
122
123        write!(
124            out,
125            "/**
126 * Table handle for the table `{table_name}`.
127 *
128 * Obtain a handle from the [`{accessor_method}`] property on [`RemoteTables`],
129 * like `ctx.db.{accessor_method}`.
130 *
131 * Users are encouraged not to explicitly reference this type,
132 * but to directly chain method calls,
133 * like `ctx.db.{accessor_method}.on_insert(...)`.
134 */
135export class {table_handle} {{
136"
137        );
138        out.indent(1);
139        writeln!(out, "tableCache: TableCache<{row_type}>;");
140        writeln!(out);
141        writeln!(out, "constructor(tableCache: TableCache<{row_type}>) {{");
142        out.with_indent(|out| writeln!(out, "this.tableCache = tableCache;"));
143        writeln!(out, "}}");
144        writeln!(out);
145        writeln!(out, "count(): number {{");
146        out.with_indent(|out| {
147            writeln!(out, "return this.tableCache.count();");
148        });
149        writeln!(out, "}}");
150        writeln!(out);
151        writeln!(out, "iter(): Iterable<{row_type}> {{");
152        out.with_indent(|out| {
153            writeln!(out, "return this.tableCache.iter();");
154        });
155        writeln!(out, "}}");
156
157        for (unique_field_ident, unique_field_type_use) in
158            iter_unique_cols(module.typespace_for_generate(), &schema, product_def)
159        {
160            let unique_field_name = unique_field_ident.deref().to_case(Case::Camel);
161            let unique_field_name_pascalcase = unique_field_name.to_case(Case::Pascal);
162
163            let unique_constraint = table_name_pascalcase.clone() + &unique_field_name_pascalcase + "Unique";
164            let unique_field_type = type_name(module, unique_field_type_use);
165
166            writeln!(
167                out,
168                "/**
169 * Access to the `{unique_field_name}` unique index on the table `{table_name}`,
170 * which allows point queries on the field of the same name
171 * via the [`{unique_constraint}.find`] method.
172 *
173 * Users are encouraged not to explicitly reference this type,
174 * but to directly chain method calls,
175 * like `ctx.db.{accessor_method}.{unique_field_name}().find(...)`.
176 *
177 * Get a handle on the `{unique_field_name}` unique index on the table `{table_name}`.
178 */"
179            );
180            writeln!(out, "{unique_field_name} = {{");
181            out.with_indent(|out| {
182                writeln!(
183                    out,
184                    "// Find the subscribed row whose `{unique_field_name}` column value is equal to `col_val`,"
185                );
186                writeln!(out, "// if such a row is present in the client cache.");
187                writeln!(
188                    out,
189                    "find: (col_val: {unique_field_type}): {row_type} | undefined => {{"
190                );
191                out.with_indent(|out| {
192                    writeln!(out, "for (let row of this.tableCache.iter()) {{");
193                    out.with_indent(|out| {
194                        writeln!(out, "if (deepEqual(row.{unique_field_name}, col_val)) {{");
195                        out.with_indent(|out| {
196                            writeln!(out, "return row;");
197                        });
198                        writeln!(out, "}}");
199                    });
200                    writeln!(out, "}}");
201                });
202                writeln!(out, "}},");
203            });
204            writeln!(out, "}};");
205        }
206
207        writeln!(out);
208
209        // TODO: expose non-unique indices.
210
211        writeln!(
212            out,
213            "onInsert = (cb: (ctx: EventContext, row: {row_type}) => void) => {{
214{INDENT}return this.tableCache.onInsert(cb);
215}}
216
217removeOnInsert = (cb: (ctx: EventContext, row: {row_type}) => void) => {{
218{INDENT}return this.tableCache.removeOnInsert(cb);
219}}
220
221onDelete = (cb: (ctx: EventContext, row: {row_type}) => void) => {{
222{INDENT}return this.tableCache.onDelete(cb);
223}}
224
225removeOnDelete = (cb: (ctx: EventContext, row: {row_type}) => void) => {{
226{INDENT}return this.tableCache.removeOnDelete(cb);
227}}"
228        );
229
230        if schema.pk().is_some() {
231            write!(
232                out,
233                "
234// Updates are only defined for tables with primary keys.
235onUpdate = (cb: (ctx: EventContext, oldRow: {row_type}, newRow: {row_type}) => void) => {{
236{INDENT}return this.tableCache.onUpdate(cb);
237}}
238
239removeOnUpdate = (cb: (ctx: EventContext, onRow: {row_type}, newRow: {row_type}) => void) => {{
240{INDENT}return this.tableCache.removeOnUpdate(cb);
241}}"
242            );
243        }
244        out.dedent(1);
245
246        writeln!(out, "}}");
247        output.into_inner()
248    }
249
250    fn generate_reducer(&self, module: &ModuleDef, reducer: &ReducerDef) -> String {
251        let mut output = CodeIndenter::new(String::new(), INDENT);
252        let out = &mut output;
253
254        print_file_header(out);
255
256        out.newline();
257
258        gen_and_print_imports(
259            module,
260            out,
261            &reducer.params_for_generate.elements,
262            // No need to skip any imports; we're not emitting a type that other modules can import.
263            &[],
264        );
265
266        let args_type = reducer_args_type_name(&reducer.name);
267
268        define_namespace_and_object_type_for_product(module, out, &args_type, &reducer.params_for_generate.elements);
269
270        output.into_inner()
271    }
272
273    fn generate_globals(&self, module: &ModuleDef) -> Vec<(String, String)> {
274        let mut output = CodeIndenter::new(String::new(), INDENT);
275        let out = &mut output;
276
277        print_file_header(out);
278
279        out.newline();
280
281        writeln!(out, "// Import and reexport all reducer arg types");
282        for reducer in iter_reducers(module) {
283            let reducer_name = &reducer.name;
284            let reducer_module_name = reducer_module_name(reducer_name) + ".ts";
285            let args_type = reducer_args_type_name(&reducer.name);
286            writeln!(out, "import {{ {args_type} }} from \"./{reducer_module_name}\";");
287            writeln!(out, "export {{ {args_type} }};");
288        }
289
290        writeln!(out);
291        writeln!(out, "// Import and reexport all table handle types");
292        for table in iter_tables(module) {
293            let table_name = &table.name;
294            let table_module_name = table_module_name(table_name) + ".ts";
295            let table_name_pascalcase = table.name.deref().to_case(Case::Pascal);
296            let table_handle = table_name_pascalcase.clone() + "TableHandle";
297            writeln!(out, "import {{ {table_handle} }} from \"./{table_module_name}\";");
298            writeln!(out, "export {{ {table_handle} }};");
299        }
300
301        writeln!(out);
302        writeln!(out, "// Import and reexport all types");
303        for ty in iter_types(module) {
304            let type_name = collect_case(Case::Pascal, ty.name.name_segments());
305            let type_module_name = type_module_name(&ty.name) + ".ts";
306            writeln!(out, "import {{ {type_name} }} from \"./{type_module_name}\";");
307            writeln!(out, "export {{ {type_name} }};");
308        }
309
310        out.newline();
311
312        // Define SpacetimeModule
313        writeln!(out, "const REMOTE_MODULE = {{");
314        out.indent(1);
315        writeln!(out, "tables: {{");
316        out.indent(1);
317        for table in iter_tables(module) {
318            let type_ref = table.product_type_ref;
319            let row_type = type_ref_name(module, type_ref);
320            let schema = TableSchema::from_module_def(module, table, (), 0.into())
321                .validated()
322                .expect("Failed to generate table due to validation errors");
323            writeln!(out, "{}: {{", table.name);
324            out.indent(1);
325            writeln!(out, "tableName: \"{}\",", table.name);
326            writeln!(out, "rowType: {row_type}.getTypeScriptAlgebraicType(),");
327            if let Some(pk) = schema.pk() {
328                writeln!(out, "primaryKey: \"{}\",", pk.col_name.to_string().to_case(Case::Camel));
329            }
330            out.dedent(1);
331            writeln!(out, "}},");
332        }
333        out.dedent(1);
334        writeln!(out, "}},");
335        writeln!(out, "reducers: {{");
336        out.indent(1);
337        for reducer in iter_reducers(module) {
338            writeln!(out, "{}: {{", reducer.name);
339            out.indent(1);
340            writeln!(out, "reducerName: \"{}\",", reducer.name);
341            writeln!(
342                out,
343                "argsType: {args_type}.getTypeScriptAlgebraicType(),",
344                args_type = reducer_args_type_name(&reducer.name)
345            );
346            out.dedent(1);
347            writeln!(out, "}},");
348        }
349        out.dedent(1);
350        writeln!(out, "}},");
351        writeln!(
352            out,
353            "// Constructors which are used by the DbConnectionImpl to
354// extract type information from the generated RemoteModule.
355//
356// NOTE: This is not strictly necessary for `eventContextConstructor` because
357// all we do is build a TypeScript object which we could have done inside the
358// SDK, but if in the future we wanted to create a class this would be
359// necessary because classes have methods, so we'll keep it.
360eventContextConstructor: (imp: DbConnectionImpl, event: Event<Reducer>) => {{
361  return {{
362    ...(imp as DbConnection),
363    event
364  }}
365}},
366dbViewConstructor: (imp: DbConnectionImpl) => {{
367  return new RemoteTables(imp);
368}},
369reducersConstructor: (imp: DbConnectionImpl, setReducerFlags: SetReducerFlags) => {{
370  return new RemoteReducers(imp, setReducerFlags);
371}},
372setReducerFlagsConstructor: () => {{
373  return new SetReducerFlags();
374}}"
375        );
376        out.dedent(1);
377        writeln!(out, "}}");
378
379        // Define `type Reducer` enum.
380        writeln!(out);
381        print_reducer_enum_defn(module, out);
382
383        out.newline();
384
385        print_remote_reducers(module, out);
386
387        out.newline();
388
389        print_set_reducer_flags(module, out);
390
391        out.newline();
392
393        print_remote_tables(module, out);
394
395        out.newline();
396
397        print_subscription_builder(module, out);
398
399        out.newline();
400
401        print_db_connection(module, out);
402
403        out.newline();
404
405        writeln!(
406            out,
407            "export type EventContext = EventContextInterface<RemoteTables, RemoteReducers, SetReducerFlags, Reducer>;"
408        );
409
410        writeln!(
411            out,
412            "export type ReducerEventContext = ReducerEventContextInterface<RemoteTables, RemoteReducers, SetReducerFlags, Reducer>;"
413        );
414
415        writeln!(
416            out,
417            "export type SubscriptionEventContext = SubscriptionEventContextInterface<RemoteTables, RemoteReducers, SetReducerFlags>;"
418        );
419
420        writeln!(
421            out,
422            "export type ErrorContext = ErrorContextInterface<RemoteTables, RemoteReducers, SetReducerFlags>;"
423        );
424
425        vec![("index.ts".to_string(), (output.into_inner()))]
426    }
427
428    fn clap_value() -> clap::builder::PossibleValue {
429        clap::builder::PossibleValue::new("typescript").aliases(["ts", "TS"])
430    }
431}
432
433fn print_remote_reducers(module: &ModuleDef, out: &mut Indenter) {
434    writeln!(out, "export class RemoteReducers {{");
435    out.indent(1);
436    writeln!(
437        out,
438        "constructor(private connection: DbConnectionImpl, private setCallReducerFlags: SetReducerFlags) {{}}"
439    );
440    out.newline();
441
442    for reducer in iter_reducers(module) {
443        // The reducer argument names and types as `ident: ty, ident: ty, ident: ty`,
444        // and the argument names as `ident, ident, ident`
445        // for passing to function call and struct literal expressions.
446        let mut arg_list = "".to_string();
447        let mut arg_name_list = "".to_string();
448        for (arg_ident, arg_ty) in &reducer.params_for_generate.elements[..] {
449            let arg_name = arg_ident.deref().to_case(Case::Camel);
450            arg_name_list += &arg_name;
451            arg_list += &arg_name;
452            arg_list += ": ";
453            write_type(module, &mut arg_list, arg_ty, None).unwrap();
454            arg_list += ", ";
455            arg_name_list += ", ";
456        }
457        let arg_list = arg_list.trim_end_matches(", ");
458        let arg_name_list = arg_name_list.trim_end_matches(", ");
459
460        let reducer_name = &reducer.name;
461
462        if is_reducer_invokable(reducer) {
463            let reducer_function_name = reducer_function_name(reducer);
464            let reducer_variant = reducer_variant_name(&reducer.name);
465            if reducer.params_for_generate.elements.is_empty() {
466                writeln!(out, "{reducer_function_name}() {{");
467                out.with_indent(|out| {
468                    writeln!(
469                        out,
470                        "this.connection.callReducer(\"{reducer_name}\", new Uint8Array(0), this.setCallReducerFlags.{reducer_function_name}Flags);"
471                    );
472                });
473            } else {
474                writeln!(out, "{reducer_function_name}({arg_list}) {{");
475                out.with_indent(|out| {
476                    writeln!(out, "const __args = {{ {arg_name_list} }};");
477                    writeln!(out, "let __writer = new BinaryWriter(1024);");
478                    writeln!(
479                        out,
480                        "{reducer_variant}.getTypeScriptAlgebraicType().serialize(__writer, __args);"
481                    );
482                    writeln!(out, "let __argsBuffer = __writer.getBuffer();");
483                    writeln!(out, "this.connection.callReducer(\"{reducer_name}\", __argsBuffer, this.setCallReducerFlags.{reducer_function_name}Flags);");
484                });
485            }
486            writeln!(out, "}}");
487            out.newline();
488        }
489
490        let arg_list_padded = if arg_list.is_empty() {
491            String::new()
492        } else {
493            format!(", {arg_list}")
494        };
495        let reducer_name_pascal = reducer_name.deref().to_case(Case::Pascal);
496        writeln!(
497            out,
498            "on{reducer_name_pascal}(callback: (ctx: ReducerEventContext{arg_list_padded}) => void) {{"
499        );
500        out.indent(1);
501        writeln!(out, "this.connection.onReducer(\"{reducer_name}\", callback);");
502        out.dedent(1);
503        writeln!(out, "}}");
504        out.newline();
505        writeln!(
506            out,
507            "removeOn{reducer_name_pascal}(callback: (ctx: ReducerEventContext{arg_list_padded}) => void) {{"
508        );
509        out.indent(1);
510        writeln!(out, "this.connection.offReducer(\"{reducer_name}\", callback);");
511        out.dedent(1);
512        writeln!(out, "}}");
513        out.newline();
514    }
515
516    out.dedent(1);
517    writeln!(out, "}}");
518}
519
520fn print_set_reducer_flags(module: &ModuleDef, out: &mut Indenter) {
521    writeln!(out, "export class SetReducerFlags {{");
522    out.indent(1);
523
524    for reducer in iter_reducers(module).filter(|r| is_reducer_invokable(r)) {
525        let reducer_function_name = reducer_function_name(reducer);
526        writeln!(out, "{reducer_function_name}Flags: CallReducerFlags = 'FullUpdate';");
527        writeln!(out, "{reducer_function_name}(flags: CallReducerFlags) {{");
528        out.with_indent(|out| {
529            writeln!(out, "this.{reducer_function_name}Flags = flags;");
530        });
531        writeln!(out, "}}");
532        out.newline();
533    }
534
535    out.dedent(1);
536    writeln!(out, "}}");
537}
538
539fn print_remote_tables(module: &ModuleDef, out: &mut Indenter) {
540    writeln!(out, "export class RemoteTables {{");
541    out.indent(1);
542    writeln!(out, "constructor(private connection: DbConnectionImpl) {{}}");
543
544    for table in iter_tables(module) {
545        writeln!(out);
546        let table_name = table.name.deref();
547        let table_name_pascalcase = table.name.deref().to_case(Case::Pascal);
548        let table_name_camelcase = table.name.deref().to_case(Case::Camel);
549        let table_handle = table_name_pascalcase.clone() + "TableHandle";
550        let type_ref = table.product_type_ref;
551        let row_type = type_ref_name(module, type_ref);
552        writeln!(out, "get {table_name_camelcase}(): {table_handle} {{");
553        out.with_indent(|out| {
554            writeln!(
555                out,
556                "return new {table_handle}(this.connection.clientCache.getOrCreateTable<{row_type}>(REMOTE_MODULE.tables.{table_name}));"
557            );
558        });
559        writeln!(out, "}}");
560    }
561
562    out.dedent(1);
563    writeln!(out, "}}");
564}
565
566fn print_subscription_builder(_module: &ModuleDef, out: &mut Indenter) {
567    writeln!(
568        out,
569        "export class SubscriptionBuilder extends SubscriptionBuilderImpl<RemoteTables, RemoteReducers, SetReducerFlags> {{ }}"
570    );
571}
572
573fn print_db_connection(_module: &ModuleDef, out: &mut Indenter) {
574    writeln!(
575        out,
576        "export class DbConnection extends DbConnectionImpl<RemoteTables, RemoteReducers, SetReducerFlags> {{"
577    );
578    out.indent(1);
579    writeln!(
580        out,
581        "static builder = (): DbConnectionBuilder<DbConnection, ErrorContext, SubscriptionEventContext> => {{"
582    );
583    out.indent(1);
584    writeln!(
585        out,
586        "return new DbConnectionBuilder<DbConnection, ErrorContext, SubscriptionEventContext>(REMOTE_MODULE, (imp: DbConnectionImpl) => imp as DbConnection);"
587    );
588    out.dedent(1);
589    writeln!(out, "}}");
590    writeln!(out, "subscriptionBuilder = (): SubscriptionBuilder => {{");
591    out.indent(1);
592    writeln!(out, "return new SubscriptionBuilder(this);");
593    out.dedent(1);
594    writeln!(out, "}}");
595    out.dedent(1);
596    writeln!(out, "}}");
597}
598
599fn print_reducer_enum_defn(module: &ModuleDef, out: &mut Indenter) {
600    writeln!(out, "// A type representing all the possible variants of a reducer.");
601    writeln!(out, "export type Reducer = never");
602    for reducer in iter_reducers(module) {
603        writeln!(
604            out,
605            "| {{ name: \"{}\", args: {} }}",
606            reducer_variant_name(&reducer.name),
607            reducer_args_type_name(&reducer.name)
608        );
609    }
610    writeln!(out, ";");
611}
612
613fn print_spacetimedb_imports(out: &mut Indenter) {
614    let mut types = [
615        "AlgebraicType",
616        "ProductType",
617        "ProductTypeElement",
618        "SumType",
619        "SumTypeVariant",
620        "AlgebraicValue",
621        "Identity",
622        "ConnectionId",
623        "Timestamp",
624        "TimeDuration",
625        "DbConnectionBuilder",
626        "TableCache",
627        "BinaryWriter",
628        "CallReducerFlags",
629        "EventContextInterface",
630        "ReducerEventContextInterface",
631        "SubscriptionEventContextInterface",
632        "ErrorContextInterface",
633        "SubscriptionBuilderImpl",
634        "BinaryReader",
635        "DbConnectionImpl",
636        "DbContext",
637        "Event",
638        "deepEqual",
639    ];
640    types.sort();
641    writeln!(out, "import {{");
642    out.indent(1);
643    for ty in &types {
644        writeln!(out, "{ty},");
645    }
646    out.dedent(1);
647    writeln!(out, "}} from \"@clockworklabs/spacetimedb-sdk\";");
648}
649
650fn print_file_header(output: &mut Indenter) {
651    print_auto_generated_file_comment(output);
652    print_lint_suppression(output);
653    print_spacetimedb_imports(output);
654}
655
656fn print_lint_suppression(output: &mut Indenter) {
657    writeln!(output, "/* eslint-disable */");
658    writeln!(output, "/* tslint:disable */");
659    writeln!(output, "// @ts-nocheck");
660}
661
662fn write_get_algebraic_type_for_product(
663    module: &ModuleDef,
664    out: &mut Indenter,
665    elements: &[(Identifier, AlgebraicTypeUse)],
666) {
667    writeln!(
668        out,
669        "/**
670* A function which returns this type represented as an AlgebraicType.
671* This function is derived from the AlgebraicType used to generate this type.
672*/"
673    );
674    writeln!(out, "export function getTypeScriptAlgebraicType(): AlgebraicType {{");
675    {
676        out.indent(1);
677        write!(out, "return ");
678        convert_product_type(module, out, elements, "__");
679        writeln!(out, ";");
680        out.dedent(1);
681    }
682    writeln!(out, "}}");
683}
684
685fn define_namespace_and_object_type_for_product(
686    module: &ModuleDef,
687    out: &mut Indenter,
688    name: &str,
689    elements: &[(Identifier, AlgebraicTypeUse)],
690) {
691    write!(out, "export type {name} = {{");
692    if elements.is_empty() {
693        writeln!(out, "}};");
694    } else {
695        writeln!(out);
696        out.with_indent(|out| write_arglist_no_delimiters(module, out, elements, None, true).unwrap());
697        writeln!(out, "}};");
698    }
699
700    out.newline();
701
702    writeln!(
703        out,
704        "/**
705 * A namespace for generated helper functions.
706 */"
707    );
708    writeln!(out, "export namespace {name} {{");
709    out.indent(1);
710    write_get_algebraic_type_for_product(module, out, elements);
711    writeln!(out);
712
713    writeln!(
714        out,
715        "export function serialize(writer: BinaryWriter, value: {name}): void {{"
716    );
717    out.indent(1);
718    writeln!(out, "{name}.getTypeScriptAlgebraicType().serialize(writer, value);");
719    out.dedent(1);
720    writeln!(out, "}}");
721    writeln!(out);
722
723    writeln!(out, "export function deserialize(reader: BinaryReader): {name} {{");
724    out.indent(1);
725    writeln!(out, "return {name}.getTypeScriptAlgebraicType().deserialize(reader);");
726    out.dedent(1);
727    writeln!(out, "}}");
728    writeln!(out);
729
730    out.dedent(1);
731    writeln!(out, "}}");
732
733    out.newline();
734}
735
736fn write_arglist_no_delimiters(
737    module: &ModuleDef,
738    out: &mut impl Write,
739    elements: &[(Identifier, AlgebraicTypeUse)],
740    prefix: Option<&str>,
741    convert_case: bool,
742) -> anyhow::Result<()> {
743    for (ident, ty) in elements {
744        if let Some(prefix) = prefix {
745            write!(out, "{prefix} ")?;
746        }
747
748        let name = if convert_case {
749            ident.deref().to_case(Case::Camel)
750        } else {
751            ident.deref().into()
752        };
753
754        write!(out, "{name}: ")?;
755        write_type(module, out, ty, Some("__"))?;
756        writeln!(out, ",")?;
757    }
758
759    Ok(())
760}
761
762fn write_sum_variant_type(module: &ModuleDef, out: &mut Indenter, ident: &Identifier, ty: &AlgebraicTypeUse) {
763    let name = ident.deref().to_case(Case::Pascal);
764    write!(out, "export type {name} = ");
765
766    // If the contained type is the unit type, i.e. this variant has no members,
767    // write only the tag.
768    // ```
769    // { tag: "Foo" }
770    // ```
771    write!(out, "{{ ");
772    write!(out, "tag: \"{name}\"");
773
774    // If the contained type is not the unit type, write the tag and the value.
775    // ```
776    // { tag: "Bar", value: Bar }
777    // { tag: "Bar", value: number }
778    // { tag: "Bar", value: string }
779    // ```
780    // Note you could alternatively do:
781    // ```
782    // { tag: "Bar" } & Bar
783    // ```
784    // for non-primitive types but that doesn't extend to primitives.
785    // Another alternative would be to name the value field the same as the tag field, but lowercased
786    // ```
787    // { tag: "Bar", bar: Bar }
788    // { tag: "Bar", bar: number }
789    // { tag: "Bar", bar: string }
790    // ```
791    // but this is a departure from our previous convention and is not much different.
792    if !matches!(ty, AlgebraicTypeUse::Unit) {
793        write!(out, ", value: ");
794        write_type(module, out, ty, Some("__")).unwrap();
795    }
796
797    writeln!(out, " }};");
798}
799
800fn write_variant_types(module: &ModuleDef, out: &mut Indenter, variants: &[(Identifier, AlgebraicTypeUse)]) {
801    // Write all the variant types.
802    for (ident, ty) in variants {
803        write_sum_variant_type(module, out, ident, ty);
804    }
805}
806
807fn write_variant_constructors(
808    module: &ModuleDef,
809    out: &mut Indenter,
810    name: &str,
811    variants: &[(Identifier, AlgebraicTypeUse)],
812) {
813    // Write all the variant constructors.
814    // Write all of the variant constructors.
815    for (ident, ty) in variants {
816        if matches!(ty, AlgebraicTypeUse::Unit) {
817            // If the variant has no members, we can export a simple object.
818            // ```
819            // export const Foo = { tag: "Foo" };
820            // ```
821            write!(out, "export const {ident} = ");
822            writeln!(out, "{{ tag: \"{ident}\" }};");
823            continue;
824        }
825        let variant_name = ident.deref().to_case(Case::Pascal);
826        write!(out, "export const {variant_name} = (value: ");
827        write_type(module, out, ty, Some("__")).unwrap();
828        writeln!(out, "): {name} => ({{ tag: \"{variant_name}\", value }});");
829    }
830}
831
832fn write_get_algebraic_type_for_sum(
833    module: &ModuleDef,
834    out: &mut Indenter,
835    variants: &[(Identifier, AlgebraicTypeUse)],
836) {
837    writeln!(out, "export function getTypeScriptAlgebraicType(): AlgebraicType {{");
838    {
839        indent_scope!(out);
840        write!(out, "return ");
841        convert_sum_type(module, &mut out, variants, "__");
842        writeln!(out, ";");
843    }
844    writeln!(out, "}}");
845}
846
847fn define_namespace_and_types_for_sum(
848    module: &ModuleDef,
849    out: &mut Indenter,
850    name: &str,
851    variants: &[(Identifier, AlgebraicTypeUse)],
852) {
853    writeln!(out, "// A namespace for generated variants and helper functions.");
854    writeln!(out, "export namespace {name} {{");
855    out.indent(1);
856
857    // Write all of the variant types.
858    writeln!(
859        out,
860        "// These are the generated variant types for each variant of the tagged union.
861// One type is generated per variant and will be used in the `value` field of
862// the tagged union."
863    );
864    write_variant_types(module, out, variants);
865    writeln!(out);
866
867    // Write all of the variant constructors.
868    writeln!(
869        out,
870        "// Helper functions for constructing each variant of the tagged union.
871// ```
872// const foo = Foo.A(42);
873// assert!(foo.tag === \"A\");
874// assert!(foo.value === 42);
875// ```"
876    );
877    write_variant_constructors(module, out, name, variants);
878    writeln!(out);
879
880    // Write the function that generates the algebraic type.
881    write_get_algebraic_type_for_sum(module, out, variants);
882    writeln!(out);
883
884    writeln!(
885        out,
886        "export function serialize(writer: BinaryWriter, value: {name}): void {{
887    {name}.getTypeScriptAlgebraicType().serialize(writer, value);
888}}"
889    );
890    writeln!(out);
891
892    writeln!(
893        out,
894        "export function deserialize(reader: BinaryReader): {name} {{
895    return {name}.getTypeScriptAlgebraicType().deserialize(reader);
896}}"
897    );
898    writeln!(out);
899
900    out.dedent(1);
901
902    writeln!(out, "}}");
903    out.newline();
904
905    writeln!(out, "// The tagged union or sum type for the algebraic type `{name}`.");
906    write!(out, "export type {name} = ");
907
908    let names = variants
909        .iter()
910        .map(|(ident, _)| format!("{name}.{}", ident.deref().to_case(Case::Pascal)))
911        .collect::<Vec<String>>()
912        .join(" | ");
913
914    writeln!(out, "{names};");
915    out.newline();
916
917    writeln!(out, "export default {name};");
918}
919
920fn type_ref_module_name(module: &ModuleDef, type_ref: AlgebraicTypeRef) -> String {
921    let (name, _) = module.type_def_from_ref(type_ref).unwrap();
922    type_module_name(name)
923}
924
925fn type_module_name(type_name: &ScopedTypeName) -> String {
926    collect_case(Case::Snake, type_name.name_segments()) + "_type"
927}
928
929fn table_module_name(table_name: &Identifier) -> String {
930    table_name.deref().to_case(Case::Snake) + "_table"
931}
932
933fn table_method_name(table_name: &Identifier) -> String {
934    table_name.deref().to_case(Case::Camel)
935}
936
937fn reducer_args_type_name(reducer_name: &Identifier) -> String {
938    reducer_name.deref().to_case(Case::Pascal)
939}
940
941fn reducer_variant_name(reducer_name: &Identifier) -> String {
942    reducer_name.deref().to_case(Case::Pascal)
943}
944
945fn reducer_module_name(reducer_name: &Identifier) -> String {
946    reducer_name.deref().to_case(Case::Snake) + "_reducer"
947}
948
949fn reducer_function_name(reducer: &ReducerDef) -> String {
950    reducer.name.deref().to_case(Case::Camel)
951}
952
953pub fn type_name(module: &ModuleDef, ty: &AlgebraicTypeUse) -> String {
954    let mut s = String::new();
955    write_type(module, &mut s, ty, None).unwrap();
956    s
957}
958
959// This should return true if we should wrap the type in parentheses when it is the element type of
960// an array. This is needed if the type has a `|` in it, e.g. `Option<T>` or `Foo | Bar`, since
961// without parens, `Foo | Bar[]` would be parsed as `Foo | (Bar[])`.
962fn needs_parens_within_array(ty: &AlgebraicTypeUse) -> bool {
963    match ty {
964        AlgebraicTypeUse::Unit
965        | AlgebraicTypeUse::Never
966        | AlgebraicTypeUse::Identity
967        | AlgebraicTypeUse::ConnectionId
968        | AlgebraicTypeUse::Timestamp
969        | AlgebraicTypeUse::TimeDuration
970        | AlgebraicTypeUse::Primitive(_)
971        | AlgebraicTypeUse::Array(_)
972        | AlgebraicTypeUse::Ref(_) // We use the type name for these.
973        | AlgebraicTypeUse::String => {
974            false
975        }
976        AlgebraicTypeUse::ScheduleAt | AlgebraicTypeUse::Option(_) => {
977            true
978        }
979    }
980}
981
982pub fn write_type<W: Write>(
983    module: &ModuleDef,
984    out: &mut W,
985    ty: &AlgebraicTypeUse,
986    ref_prefix: Option<&str>,
987) -> fmt::Result {
988    match ty {
989        AlgebraicTypeUse::Unit => write!(out, "void")?,
990        AlgebraicTypeUse::Never => write!(out, "never")?,
991        AlgebraicTypeUse::Identity => write!(out, "Identity")?,
992        AlgebraicTypeUse::ConnectionId => write!(out, "ConnectionId")?,
993        AlgebraicTypeUse::Timestamp => write!(out, "Timestamp")?,
994        AlgebraicTypeUse::TimeDuration => write!(out, "TimeDuration")?,
995        AlgebraicTypeUse::ScheduleAt => write!(
996            out,
997            "{{ tag: \"Interval\", value: TimeDuration }} | {{ tag: \"Time\", value: Timestamp }}"
998        )?,
999        AlgebraicTypeUse::Option(inner_ty) => {
1000            write_type(module, out, inner_ty, ref_prefix)?;
1001            write!(out, " | undefined")?;
1002        }
1003        AlgebraicTypeUse::Primitive(prim) => match prim {
1004            PrimitiveType::Bool => write!(out, "boolean")?,
1005            PrimitiveType::I8 => write!(out, "number")?,
1006            PrimitiveType::U8 => write!(out, "number")?,
1007            PrimitiveType::I16 => write!(out, "number")?,
1008            PrimitiveType::U16 => write!(out, "number")?,
1009            PrimitiveType::I32 => write!(out, "number")?,
1010            PrimitiveType::U32 => write!(out, "number")?,
1011            PrimitiveType::I64 => write!(out, "bigint")?,
1012            PrimitiveType::U64 => write!(out, "bigint")?,
1013            PrimitiveType::I128 => write!(out, "bigint")?,
1014            PrimitiveType::U128 => write!(out, "bigint")?,
1015            PrimitiveType::I256 => write!(out, "bigint")?,
1016            PrimitiveType::U256 => write!(out, "bigint")?,
1017            PrimitiveType::F32 => write!(out, "number")?,
1018            PrimitiveType::F64 => write!(out, "number")?,
1019        },
1020        AlgebraicTypeUse::String => write!(out, "string")?,
1021        AlgebraicTypeUse::Array(elem_ty) => {
1022            if matches!(&**elem_ty, AlgebraicTypeUse::Primitive(PrimitiveType::U8)) {
1023                return write!(out, "Uint8Array");
1024            }
1025            let needs_parens = needs_parens_within_array(elem_ty);
1026            // We wrap the inner type in parentheses to avoid ambiguity with the [] binding.
1027            if needs_parens {
1028                write!(out, "(")?;
1029            }
1030            write_type(module, out, elem_ty, ref_prefix)?;
1031            if needs_parens {
1032                write!(out, ")")?;
1033            }
1034            write!(out, "[]")?;
1035        }
1036        AlgebraicTypeUse::Ref(r) => {
1037            if let Some(prefix) = ref_prefix {
1038                write!(out, "{prefix}")?;
1039            }
1040            write!(out, "{}", type_ref_name(module, *r))?;
1041        }
1042    }
1043    Ok(())
1044}
1045
1046fn convert_algebraic_type<'a>(
1047    module: &'a ModuleDef,
1048    out: &mut Indenter,
1049    ty: &'a AlgebraicTypeUse,
1050    ref_prefix: &'a str,
1051) {
1052    match ty {
1053        AlgebraicTypeUse::ScheduleAt => write!(out, "AlgebraicType.createScheduleAtType()"),
1054        AlgebraicTypeUse::Identity => write!(out, "AlgebraicType.createIdentityType()"),
1055        AlgebraicTypeUse::ConnectionId => write!(out, "AlgebraicType.createConnectionIdType()"),
1056        AlgebraicTypeUse::Timestamp => write!(out, "AlgebraicType.createTimestampType()"),
1057        AlgebraicTypeUse::TimeDuration => write!(out, "AlgebraicType.createTimeDurationType()"),
1058        AlgebraicTypeUse::Option(inner_ty) => {
1059            write!(out, "AlgebraicType.createOptionType(");
1060            convert_algebraic_type(module, out, inner_ty, ref_prefix);
1061            write!(out, ")");
1062        }
1063        AlgebraicTypeUse::Array(ty) => {
1064            write!(out, "AlgebraicType.createArrayType(");
1065            convert_algebraic_type(module, out, ty, ref_prefix);
1066            write!(out, ")");
1067        }
1068        AlgebraicTypeUse::Ref(r) => write!(
1069            out,
1070            "{ref_prefix}{}.getTypeScriptAlgebraicType()",
1071            type_ref_name(module, *r)
1072        ),
1073        AlgebraicTypeUse::Primitive(prim) => {
1074            write!(out, "AlgebraicType.create{prim:?}Type()");
1075        }
1076        AlgebraicTypeUse::Unit => write!(out, "AlgebraicType.createProductType([])"),
1077        AlgebraicTypeUse::Never => unimplemented!(),
1078        AlgebraicTypeUse::String => write!(out, "AlgebraicType.createStringType()"),
1079    }
1080}
1081
1082fn convert_sum_type<'a>(
1083    module: &'a ModuleDef,
1084    out: &mut Indenter,
1085    variants: &'a [(Identifier, AlgebraicTypeUse)],
1086    ref_prefix: &'a str,
1087) {
1088    writeln!(out, "AlgebraicType.createSumType([");
1089    out.indent(1);
1090    for (ident, ty) in variants {
1091        write!(out, "new SumTypeVariant(\"{ident}\", ",);
1092        convert_algebraic_type(module, out, ty, ref_prefix);
1093        writeln!(out, "),");
1094    }
1095    out.dedent(1);
1096    write!(out, "])")
1097}
1098
1099fn convert_product_type<'a>(
1100    module: &'a ModuleDef,
1101    out: &mut Indenter,
1102    elements: &'a [(Identifier, AlgebraicTypeUse)],
1103    ref_prefix: &'a str,
1104) {
1105    writeln!(out, "AlgebraicType.createProductType([");
1106    out.indent(1);
1107    for (ident, ty) in elements {
1108        write!(
1109            out,
1110            "new ProductTypeElement(\"{}\", ",
1111            ident.deref().to_case(Case::Camel)
1112        );
1113        convert_algebraic_type(module, out, ty, ref_prefix);
1114        writeln!(out, "),");
1115    }
1116    out.dedent(1);
1117    write!(out, "])")
1118}
1119
1120/// Print imports for each of the `imports`.
1121fn print_imports(module: &ModuleDef, out: &mut Indenter, imports: Imports) {
1122    for typeref in imports {
1123        let module_name = type_ref_module_name(module, typeref);
1124        let type_name = type_ref_name(module, typeref);
1125        writeln!(
1126            out,
1127            "import {{ {type_name} as __{type_name} }} from \"./{module_name}\";"
1128        );
1129    }
1130}
1131
1132/// Use `search_function` on `roots` to detect required imports, then print them with `print_imports`.
1133///
1134/// `this_file` is passed and excluded for the case of recursive types:
1135/// without it, the definition for a type like `struct Foo { foos: Vec<Foo> }`
1136/// would attempt to include `import { Foo } from "./foo"`.
1137fn gen_and_print_imports(
1138    module: &ModuleDef,
1139    out: &mut Indenter,
1140    roots: &[(Identifier, AlgebraicTypeUse)],
1141    dont_import: &[AlgebraicTypeRef],
1142) {
1143    let mut imports = BTreeSet::new();
1144
1145    for (_, ty) in roots {
1146        ty.for_each_ref(|r| {
1147            imports.insert(r);
1148        });
1149    }
1150    for skip in dont_import {
1151        imports.remove(skip);
1152    }
1153    let len = imports.len();
1154
1155    print_imports(module, out, imports);
1156
1157    if len > 0 {
1158        out.newline();
1159    }
1160}
1161
1162// const RESERVED_KEYWORDS: [&str; 36] = [
1163//     "break",
1164//     "case",
1165//     "catch",
1166//     "class",
1167//     "const",
1168//     "continue",
1169//     "debugger",
1170//     "default",
1171//     "delete",
1172//     "do",
1173//     "else",
1174//     "enum",
1175//     "export",
1176//     "extends",
1177//     "false",
1178//     "finally",
1179//     "for",
1180//     "function",
1181//     "if",
1182//     "import",
1183//     "in",
1184//     "instanceof",
1185//     "new",
1186//     "null",
1187//     "return",
1188//     "super",
1189//     "switch",
1190//     "this",
1191//     "throw",
1192//     "true",
1193//     "try",
1194//     "typeof",
1195//     "var",
1196//     "void",
1197//     "while",
1198//     "with",
1199// ];
1200
1201// fn typescript_field_name(field_name: String) -> String {
1202//     if RESERVED_KEYWORDS
1203//         .into_iter()
1204//         .map(String::from)
1205//         .collect::<Vec<String>>()
1206//         .contains(&field_name)
1207//     {
1208//         return format!("_{field_name}");
1209//     }
1210
1211//     field_name
1212// }