spacetimedb_cli/subcommands/generate/
rust.rs

1use super::code_indenter::{CodeIndenter, Indenter};
2use super::util::{collect_case, iter_reducers, print_lines, type_ref_name};
3use super::Lang;
4use crate::detect::{has_rust_fmt, has_rust_up};
5use crate::generate::util::{iter_tables, iter_types, iter_unique_cols, print_auto_generated_file_comment};
6use anyhow::Context;
7use convert_case::{Case, Casing};
8use duct::cmd;
9use spacetimedb_lib::sats::AlgebraicTypeRef;
10use spacetimedb_schema::def::{ModuleDef, ReducerDef, ScopedTypeName, TableDef, TypeDef};
11use spacetimedb_schema::identifier::Identifier;
12use spacetimedb_schema::schema::{Schema, TableSchema};
13use spacetimedb_schema::type_for_generate::{AlgebraicTypeDef, AlgebraicTypeUse, PrimitiveType};
14use std::collections::BTreeSet;
15use std::fmt::{self, Write};
16use std::ops::Deref;
17use std::path::PathBuf;
18
19/// Pairs of (module_name, TypeName).
20type Imports = BTreeSet<AlgebraicTypeRef>;
21
22const INDENT: &str = "    ";
23
24pub struct Rust;
25
26impl Lang for Rust {
27    fn table_filename(
28        &self,
29        _module: &spacetimedb_schema::def::ModuleDef,
30        table: &spacetimedb_schema::def::TableDef,
31    ) -> String {
32        table_module_name(&table.name) + ".rs"
33    }
34
35    fn type_filename(&self, type_name: &ScopedTypeName) -> String {
36        type_module_name(type_name) + ".rs"
37    }
38
39    fn reducer_filename(&self, reducer_name: &Identifier) -> String {
40        reducer_module_name(reducer_name) + ".rs"
41    }
42
43    fn format_files(&self, generated_files: BTreeSet<PathBuf>) -> anyhow::Result<()> {
44        if !has_rust_fmt() {
45            if has_rust_up() {
46                cmd!("rustup", "component", "add", "rustfmt")
47                    .run()
48                    .context("Failed to install rustfmt with Rustup")?;
49            } else {
50                anyhow::bail!("rustfmt is not installed. Please install it.");
51            }
52        }
53        cmd!("rustfmt", "--edition", "2021")
54            .before_spawn(move |cmd| {
55                cmd.args(&generated_files);
56                Ok(())
57            })
58            .run()?;
59        Ok(())
60    }
61
62    fn generate_type(&self, module: &ModuleDef, typ: &TypeDef) -> String {
63        let type_name = collect_case(Case::Pascal, typ.name.name_segments());
64
65        let mut output = CodeIndenter::new(String::new(), INDENT);
66        let out = &mut output;
67
68        print_file_header(out);
69        out.newline();
70
71        match &module.typespace_for_generate()[typ.ty] {
72            AlgebraicTypeDef::Product(product) => {
73                gen_and_print_imports(module, out, &product.elements, &[typ.ty]);
74                out.newline();
75                define_struct_for_product(module, out, &type_name, &product.elements, "pub");
76            }
77            AlgebraicTypeDef::Sum(sum) => {
78                gen_and_print_imports(module, out, &sum.variants, &[typ.ty]);
79                out.newline();
80                define_enum_for_sum(module, out, &type_name, &sum.variants);
81            }
82            AlgebraicTypeDef::PlainEnum(plain_enum) => {
83                let variants = plain_enum
84                    .variants
85                    .iter()
86                    .cloned()
87                    .map(|var| (var, AlgebraicTypeUse::Unit))
88                    .collect::<Vec<_>>();
89                define_enum_for_sum(module, out, &type_name, &variants);
90            }
91        }
92        out.newline();
93
94        writeln!(
95            out,
96            "
97impl __sdk::InModule for {type_name} {{
98    type Module = super::RemoteModule;
99}}
100",
101        );
102
103        output.into_inner()
104    }
105    fn generate_table(&self, module: &ModuleDef, table: &TableDef) -> String {
106        let schema = TableSchema::from_module_def(module, table, (), 0.into())
107            .validated()
108            .expect("Failed to generate table due to validation errors");
109
110        let type_ref = table.product_type_ref;
111
112        let mut output = CodeIndenter::new(String::new(), INDENT);
113        let out = &mut output;
114
115        print_file_header(out);
116
117        let row_type = type_ref_name(module, type_ref);
118        let row_type_module = type_ref_module_name(module, type_ref);
119
120        writeln!(out, "use super::{row_type_module}::{row_type};");
121
122        let product_def = module.typespace_for_generate()[type_ref].as_product().unwrap();
123
124        // Import the types of all fields.
125        // We only need to import fields which have indices or unique constraints,
126        // but it's easier to just import all of 'em, since we have `#![allow(unused)]` anyway.
127        gen_and_print_imports(
128            module,
129            out,
130            &product_def.elements,
131            &[], // No need to skip any imports; we're not defining a type, so there's no chance of circular imports.
132        );
133
134        let table_name = table.name.deref();
135        let table_name_pascalcase = table.name.deref().to_case(Case::Pascal);
136        let table_handle = table_name_pascalcase.clone() + "TableHandle";
137        let insert_callback_id = table_name_pascalcase.clone() + "InsertCallbackId";
138        let delete_callback_id = table_name_pascalcase.clone() + "DeleteCallbackId";
139        let accessor_trait = table_access_trait_name(&table.name);
140        let accessor_method = table_method_name(&table.name);
141
142        write!(
143            out,
144            "
145/// Table handle for the table `{table_name}`.
146///
147/// Obtain a handle from the [`{accessor_trait}::{accessor_method}`] method on [`super::RemoteTables`],
148/// like `ctx.db.{accessor_method}()`.
149///
150/// Users are encouraged not to explicitly reference this type,
151/// but to directly chain method calls,
152/// like `ctx.db.{accessor_method}().on_insert(...)`.
153pub struct {table_handle}<'ctx> {{
154    imp: __sdk::TableHandle<{row_type}>,
155    ctx: std::marker::PhantomData<&'ctx super::RemoteTables>,
156}}
157
158#[allow(non_camel_case_types)]
159/// Extension trait for access to the table `{table_name}`.
160///
161/// Implemented for [`super::RemoteTables`].
162pub trait {accessor_trait} {{
163    #[allow(non_snake_case)]
164    /// Obtain a [`{table_handle}`], which mediates access to the table `{table_name}`.
165    fn {accessor_method}(&self) -> {table_handle}<'_>;
166}}
167
168impl {accessor_trait} for super::RemoteTables {{
169    fn {accessor_method}(&self) -> {table_handle}<'_> {{
170        {table_handle} {{
171            imp: self.imp.get_table::<{row_type}>({table_name:?}),
172            ctx: std::marker::PhantomData,
173        }}
174    }}
175}}
176
177pub struct {insert_callback_id}(__sdk::CallbackId);
178pub struct {delete_callback_id}(__sdk::CallbackId);
179
180impl<'ctx> __sdk::Table for {table_handle}<'ctx> {{
181    type Row = {row_type};
182    type EventContext = super::EventContext;
183
184    fn count(&self) -> u64 {{ self.imp.count() }}
185    fn iter(&self) -> impl Iterator<Item = {row_type}> + '_ {{ self.imp.iter() }}
186
187    type InsertCallbackId = {insert_callback_id};
188
189    fn on_insert(
190        &self,
191        callback: impl FnMut(&Self::EventContext, &Self::Row) + Send + 'static,
192    ) -> {insert_callback_id} {{
193        {insert_callback_id}(self.imp.on_insert(Box::new(callback)))
194    }}
195
196    fn remove_on_insert(&self, callback: {insert_callback_id}) {{
197        self.imp.remove_on_insert(callback.0)
198    }}
199
200    type DeleteCallbackId = {delete_callback_id};
201
202    fn on_delete(
203        &self,
204        callback: impl FnMut(&Self::EventContext, &Self::Row) + Send + 'static,
205    ) -> {delete_callback_id} {{
206        {delete_callback_id}(self.imp.on_delete(Box::new(callback)))
207    }}
208
209    fn remove_on_delete(&self, callback: {delete_callback_id}) {{
210        self.imp.remove_on_delete(callback.0)
211    }}
212}}
213"
214        );
215
216        out.delimited_block(
217            "
218#[doc(hidden)]
219pub(super) fn register_table(client_cache: &mut __sdk::ClientCache<super::RemoteModule>) {
220",
221            |out| {
222                writeln!(out, "let _table = client_cache.get_or_make_table::<{row_type}>({table_name:?});");
223                for (unique_field_ident, unique_field_type_use) in iter_unique_cols(&schema, product_def) {
224                    let unique_field_name = unique_field_ident.deref().to_case(Case::Snake);
225                    let unique_field_type = type_name(module, unique_field_type_use);
226                    writeln!(
227                        out,
228                        "_table.add_unique_constraint::<{unique_field_type}>({unique_field_name:?}, |row| &row.{unique_field_name});",
229                    );
230                }
231            },
232            "}",
233        );
234
235        if schema.pk().is_some() {
236            let update_callback_id = table_name_pascalcase.clone() + "UpdateCallbackId";
237            write!(
238                out,
239                "
240pub struct {update_callback_id}(__sdk::CallbackId);
241
242impl<'ctx> __sdk::TableWithPrimaryKey for {table_handle}<'ctx> {{
243    type UpdateCallbackId = {update_callback_id};
244
245    fn on_update(
246        &self,
247        callback: impl FnMut(&Self::EventContext, &Self::Row, &Self::Row) + Send + 'static,
248    ) -> {update_callback_id} {{
249        {update_callback_id}(self.imp.on_update(Box::new(callback)))
250    }}
251
252    fn remove_on_update(&self, callback: {update_callback_id}) {{
253        self.imp.remove_on_update(callback.0)
254    }}
255}}
256"
257            );
258        }
259
260        out.newline();
261
262        write!(
263            out,
264            "
265#[doc(hidden)]
266pub(super) fn parse_table_update(
267    raw_updates: __ws::TableUpdate<__ws::BsatnFormat>,
268) -> __sdk::Result<__sdk::TableUpdate<{row_type}>> {{
269    __sdk::TableUpdate::parse_table_update(raw_updates).map_err(|e| {{
270        __sdk::InternalError::failed_parse(
271            \"TableUpdate<{row_type}>\",
272            \"TableUpdate\",
273        ).with_cause(e).into()
274    }})
275}}
276"
277        );
278
279        for (unique_field_ident, unique_field_type_use) in iter_unique_cols(&schema, product_def) {
280            let unique_field_name = unique_field_ident.deref().to_case(Case::Snake);
281            let unique_field_name_pascalcase = unique_field_name.to_case(Case::Pascal);
282
283            let unique_constraint = table_name_pascalcase.clone() + &unique_field_name_pascalcase + "Unique";
284            let unique_field_type = type_name(module, unique_field_type_use);
285
286            write!(
287                out,
288                "
289        /// Access to the `{unique_field_name}` unique index on the table `{table_name}`,
290        /// which allows point queries on the field of the same name
291        /// via the [`{unique_constraint}::find`] method.
292        ///
293        /// Users are encouraged not to explicitly reference this type,
294        /// but to directly chain method calls,
295        /// like `ctx.db.{accessor_method}().{unique_field_name}().find(...)`.
296        pub struct {unique_constraint}<'ctx> {{
297            imp: __sdk::UniqueConstraintHandle<{row_type}, {unique_field_type}>,
298            phantom: std::marker::PhantomData<&'ctx super::RemoteTables>,
299        }}
300
301        impl<'ctx> {table_handle}<'ctx> {{
302            /// Get a handle on the `{unique_field_name}` unique index on the table `{table_name}`.
303            pub fn {unique_field_name}(&self) -> {unique_constraint}<'ctx> {{
304                {unique_constraint} {{
305                    imp: self.imp.get_unique_constraint::<{unique_field_type}>({unique_field_name:?}),
306                    phantom: std::marker::PhantomData,
307                }}
308            }}
309        }}
310
311        impl<'ctx> {unique_constraint}<'ctx> {{
312            /// Find the subscribed row whose `{unique_field_name}` column value is equal to `col_val`,
313            /// if such a row is present in the client cache.
314            pub fn find(&self, col_val: &{unique_field_type}) -> Option<{row_type}> {{
315                self.imp.find(col_val)
316            }}
317        }}
318        "
319            );
320        }
321
322        // TODO: expose non-unique indices.
323
324        output.into_inner()
325    }
326    fn generate_reducer(&self, module: &ModuleDef, reducer: &ReducerDef) -> String {
327        let mut output = CodeIndenter::new(String::new(), INDENT);
328        let out = &mut output;
329
330        print_file_header(out);
331
332        out.newline();
333
334        gen_and_print_imports(
335            module,
336            out,
337            &reducer.params_for_generate.elements,
338            // No need to skip any imports; we're not emitting a type that other modules can import.
339            &[],
340        );
341
342        out.newline();
343
344        let reducer_name = reducer.name.deref();
345        let func_name = reducer_function_name(reducer);
346        let set_reducer_flags_trait = reducer_flags_trait_name(reducer);
347        let args_type = reducer_args_type_name(&reducer.name);
348        let enum_variant_name = reducer_variant_name(&reducer.name);
349
350        // Define an "args struct" for the reducer.
351        // This is not user-facing (note the `pub(super)` visibility);
352        // it is an internal helper for serialization and deserialization.
353        // We actually want to ser/de instances of `enum Reducer`, but:
354        // - `Reducer` will have struct-like variants, which SATS ser/de does not support.
355        // - The WS format does not contain a BSATN-serialized `Reducer` instance;
356        //   it holds the reducer name or ID separately from the argument bytes.
357        //   We could work up some magic with `DeserializeSeed`
358        //   and/or custom `Serializer` and `Deserializer` types
359        //   to account for this, but it's much easier to just use an intermediate struct per reducer.
360        define_struct_for_product(
361            module,
362            out,
363            &args_type,
364            &reducer.params_for_generate.elements,
365            "pub(super)",
366        );
367
368        out.newline();
369
370        let callback_id = reducer_callback_id_name(&reducer.name);
371
372        // The reducer arguments as `ident: ty, ident: ty, ident: ty,`,
373        // like an argument list.
374        let mut arglist = String::new();
375        write_arglist_no_delimiters(module, &mut arglist, &reducer.params_for_generate.elements, None).unwrap();
376
377        // The reducer argument types as `&ty, &ty, &ty`,
378        // for use as the params in a `FnMut` closure type.
379        let mut arg_types_ref_list = String::new();
380        // The reducer argument names as `ident, ident, ident`,
381        // for passing to function call and struct literal expressions.
382        let mut arg_names_list = String::new();
383        for (arg_ident, arg_ty) in &reducer.params_for_generate.elements[..] {
384            arg_types_ref_list += "&";
385            write_type(module, &mut arg_types_ref_list, arg_ty).unwrap();
386            arg_types_ref_list += ", ";
387
388            let arg_name = arg_ident.deref().to_case(Case::Snake);
389            arg_names_list += &arg_name;
390            arg_names_list += ", ";
391        }
392
393        write!(out, "impl From<{args_type}> for super::Reducer ");
394        out.delimited_block(
395            "{",
396            |out| {
397                write!(out, "fn from(args: {args_type}) -> Self ");
398                out.delimited_block(
399                    "{",
400                    |out| {
401                        write!(out, "Self::{enum_variant_name}");
402                        if !reducer.params_for_generate.elements.is_empty() {
403                            // We generate "struct variants" for reducers with arguments,
404                            // but "unit variants" for reducers of no arguments.
405                            // These use different constructor syntax.
406                            out.delimited_block(
407                                " {",
408                                |out| {
409                                    for (arg_ident, _ty) in &reducer.params_for_generate.elements[..] {
410                                        let arg_name = arg_ident.deref().to_case(Case::Snake);
411                                        writeln!(out, "{arg_name}: args.{arg_name},");
412                                    }
413                                },
414                                "}",
415                            );
416                        }
417                        out.newline();
418                    },
419                    "}\n",
420                );
421            },
422            "}\n",
423        );
424
425        // TODO: check for lifecycle reducers and do not generate the invoke method.
426
427        writeln!(
428            out,
429            "
430impl __sdk::InModule for {args_type} {{
431    type Module = super::RemoteModule;
432}}
433
434pub struct {callback_id}(__sdk::CallbackId);
435
436#[allow(non_camel_case_types)]
437/// Extension trait for access to the reducer `{reducer_name}`.
438///
439/// Implemented for [`super::RemoteReducers`].
440pub trait {func_name} {{
441    /// Request that the remote module invoke the reducer `{reducer_name}` to run as soon as possible.
442    ///
443    /// This method returns immediately, and errors only if we are unable to send the request.
444    /// The reducer will run asynchronously in the future,
445    ///  and its status can be observed by listening for [`Self::on_{func_name}`] callbacks.
446    fn {func_name}(&self, {arglist}) -> __sdk::Result<()>;
447    /// Register a callback to run whenever we are notified of an invocation of the reducer `{reducer_name}`.
448    ///
449    /// Callbacks should inspect the [`__sdk::ReducerEvent`] contained in the [`super::ReducerEventContext`]
450    /// to determine the reducer's status.
451    ///
452    /// The returned [`{callback_id}`] can be passed to [`Self::remove_on_{func_name}`]
453    /// to cancel the callback.
454    fn on_{func_name}(&self, callback: impl FnMut(&super::ReducerEventContext, {arg_types_ref_list}) + Send + 'static) -> {callback_id};
455    /// Cancel a callback previously registered by [`Self::on_{func_name}`],
456    /// causing it not to run in the future.
457    fn remove_on_{func_name}(&self, callback: {callback_id});
458}}
459
460impl {func_name} for super::RemoteReducers {{
461    fn {func_name}(&self, {arglist}) -> __sdk::Result<()> {{
462        self.imp.call_reducer({reducer_name:?}, {args_type} {{ {arg_names_list} }})
463    }}
464    fn on_{func_name}(
465        &self,
466        mut callback: impl FnMut(&super::ReducerEventContext, {arg_types_ref_list}) + Send + 'static,
467    ) -> {callback_id} {{
468        {callback_id}(self.imp.on_reducer(
469            {reducer_name:?},
470            Box::new(move |ctx: &super::ReducerEventContext| {{
471                let super::ReducerEventContext {{
472                    event: __sdk::ReducerEvent {{
473                        reducer: super::Reducer::{enum_variant_name} {{
474                            {arg_names_list}
475                        }},
476                        ..
477                    }},
478                    ..
479                }} = ctx else {{ unreachable!() }};
480                callback(ctx, {arg_names_list})
481            }}),
482        ))
483    }}
484    fn remove_on_{func_name}(&self, callback: {callback_id}) {{
485        self.imp.remove_on_reducer({reducer_name:?}, callback.0)
486    }}
487}}
488
489#[allow(non_camel_case_types)]
490#[doc(hidden)]
491/// Extension trait for setting the call-flags for the reducer `{reducer_name}`.
492///
493/// Implemented for [`super::SetReducerFlags`].
494///
495/// This type is currently unstable and may be removed without a major version bump.
496pub trait {set_reducer_flags_trait} {{
497    /// Set the call-reducer flags for the reducer `{reducer_name}` to `flags`.
498    ///
499    /// This type is currently unstable and may be removed without a major version bump.
500    fn {func_name}(&self, flags: __ws::CallReducerFlags);
501}}
502
503impl {set_reducer_flags_trait} for super::SetReducerFlags {{
504    fn {func_name}(&self, flags: __ws::CallReducerFlags) {{
505        self.imp.set_call_reducer_flags({reducer_name:?}, flags);
506    }}
507}}
508"
509        );
510
511        output.into_inner()
512    }
513
514    fn generate_globals(&self, module: &ModuleDef) -> Vec<(String, String)> {
515        let mut output = CodeIndenter::new(String::new(), INDENT);
516        let out = &mut output;
517
518        print_file_header(out);
519
520        out.newline();
521
522        // Declare `pub mod` for each of the files generated.
523        print_module_decls(module, out);
524
525        out.newline();
526
527        // Re-export all the modules for the generated files.
528        print_module_reexports(module, out);
529
530        out.newline();
531
532        // Define `enum Reducer`.
533        print_reducer_enum_defn(module, out);
534
535        out.newline();
536
537        // Define `DbUpdate`.
538        print_db_update_defn(module, out);
539
540        out.newline();
541
542        // Define `AppliedDiff`.
543        print_applied_diff_defn(module, out);
544
545        out.newline();
546
547        // Define `RemoteModule`, `DbConnection`, `EventContext`, `RemoteTables`, `RemoteReducers` and `SubscriptionHandle`.
548        // Note that these do not change based on the module.
549        print_const_db_context_types(out);
550
551        out.newline();
552
553        // Implement `SpacetimeModule` for `RemoteModule`.
554        // This includes a method for initializing the tables in the client cache.
555        print_impl_spacetime_module(module, out);
556
557        vec![("mod.rs".to_string(), (output.into_inner()))]
558    }
559
560    fn clap_value() -> clap::builder::PossibleValue {
561        clap::builder::PossibleValue::new("rust").aliases(["rs", "RS"])
562    }
563}
564
565pub fn write_type<W: Write>(module: &ModuleDef, out: &mut W, ty: &AlgebraicTypeUse) -> fmt::Result {
566    match ty {
567        AlgebraicTypeUse::Unit => write!(out, "()")?,
568        AlgebraicTypeUse::Never => write!(out, "std::convert::Infallible")?,
569        AlgebraicTypeUse::Identity => write!(out, "__sdk::Identity")?,
570        AlgebraicTypeUse::ConnectionId => write!(out, "__sdk::ConnectionId")?,
571        AlgebraicTypeUse::Timestamp => write!(out, "__sdk::Timestamp")?,
572        AlgebraicTypeUse::TimeDuration => write!(out, "__sdk::TimeDuration")?,
573        AlgebraicTypeUse::ScheduleAt => write!(out, "__sdk::ScheduleAt")?,
574        AlgebraicTypeUse::Option(inner_ty) => {
575            write!(out, "Option::<")?;
576            write_type(module, out, inner_ty)?;
577            write!(out, ">")?;
578        }
579        AlgebraicTypeUse::Primitive(prim) => match prim {
580            PrimitiveType::Bool => write!(out, "bool")?,
581            PrimitiveType::I8 => write!(out, "i8")?,
582            PrimitiveType::U8 => write!(out, "u8")?,
583            PrimitiveType::I16 => write!(out, "i16")?,
584            PrimitiveType::U16 => write!(out, "u16")?,
585            PrimitiveType::I32 => write!(out, "i32")?,
586            PrimitiveType::U32 => write!(out, "u32")?,
587            PrimitiveType::I64 => write!(out, "i64")?,
588            PrimitiveType::U64 => write!(out, "u64")?,
589            PrimitiveType::I128 => write!(out, "i128")?,
590            PrimitiveType::U128 => write!(out, "u128")?,
591            PrimitiveType::I256 => write!(out, "__sats::i256")?,
592            PrimitiveType::U256 => write!(out, "__sats::u256")?,
593            PrimitiveType::F32 => write!(out, "f32")?,
594            PrimitiveType::F64 => write!(out, "f64")?,
595        },
596        AlgebraicTypeUse::String => write!(out, "String")?,
597        AlgebraicTypeUse::Array(elem_ty) => {
598            write!(out, "Vec::<")?;
599            write_type(module, out, elem_ty)?;
600            write!(out, ">")?;
601        }
602        AlgebraicTypeUse::Ref(r) => {
603            write!(out, "{}", type_ref_name(module, *r))?;
604        }
605    }
606    Ok(())
607}
608
609pub fn type_name(module: &ModuleDef, ty: &AlgebraicTypeUse) -> String {
610    let mut s = String::new();
611    write_type(module, &mut s, ty).unwrap();
612    s
613}
614
615const ALLOW_LINTS: &str = "#![allow(unused, clippy::all)]";
616
617const SPACETIMEDB_IMPORTS: &[&str] = &[
618    "use spacetimedb_sdk::__codegen::{",
619    "\tself as __sdk,",
620    "\t__lib,",
621    "\t__sats,",
622    "\t__ws,",
623    "};",
624];
625
626fn print_spacetimedb_imports(output: &mut Indenter) {
627    print_lines(output, SPACETIMEDB_IMPORTS);
628}
629
630fn print_file_header(output: &mut Indenter) {
631    print_auto_generated_file_comment(output);
632    writeln!(output, "{ALLOW_LINTS}");
633    print_spacetimedb_imports(output);
634}
635
636// TODO: figure out if/when sum types should derive:
637// - Clone
638// - Debug
639// - Copy
640// - PartialEq, Eq
641// - Hash
642//    - Complicated because `HashMap` is not `Hash`.
643// - others?
644
645const ENUM_DERIVES: &[&str] = &[
646    "#[derive(__lib::ser::Serialize, __lib::de::Deserialize, Clone, PartialEq, Debug)]",
647    "#[sats(crate = __lib)]",
648];
649
650fn print_enum_derives(output: &mut Indenter) {
651    print_lines(output, ENUM_DERIVES);
652}
653
654/// Generate a file which defines an `enum` corresponding to the `sum_type`.
655pub fn define_enum_for_sum(
656    module: &ModuleDef,
657    out: &mut Indenter,
658    name: &str,
659    variants: &[(Identifier, AlgebraicTypeUse)],
660) {
661    print_enum_derives(out);
662    write!(out, "pub enum {name} ");
663
664    out.delimited_block(
665        "{",
666        |out| {
667            for (ident, ty) in variants {
668                write_enum_variant(module, out, ident, ty);
669                out.newline();
670            }
671        },
672        "}\n",
673    );
674
675    out.newline()
676}
677
678fn write_enum_variant(module: &ModuleDef, out: &mut Indenter, ident: &Identifier, ty: &AlgebraicTypeUse) {
679    let name = ident.deref().to_case(Case::Pascal);
680    write!(out, "{name}");
681
682    // If the contained type is the unit type, i.e. this variant has no members,
683    // write it without parens or braces, like
684    // ```
685    // Foo,
686    // ```
687    if !matches!(ty, AlgebraicTypeUse::Unit) {
688        // If the contained type is not a product, i.e. this variant has a single
689        // member, write it tuple-style, with parens.
690        write!(out, "(");
691        write_type(module, out, ty).unwrap();
692        write!(out, ")");
693    }
694    writeln!(out, ",");
695}
696
697fn write_struct_type_fields_in_braces(
698    module: &ModuleDef,
699    out: &mut Indenter,
700    elements: &[(Identifier, AlgebraicTypeUse)],
701
702    // Whether to print a `pub` qualifier on the fields. Necessary for `struct` defns,
703    // disallowed for `enum` defns.
704    pub_qualifier: bool,
705) {
706    out.delimited_block(
707        "{",
708        |out| write_arglist_no_delimiters(module, out, elements, pub_qualifier.then_some("pub")).unwrap(),
709        "}",
710    );
711}
712
713fn write_arglist_no_delimiters(
714    module: &ModuleDef,
715    out: &mut impl Write,
716    elements: &[(Identifier, AlgebraicTypeUse)],
717
718    // Written before each line. Useful for `pub`.
719    prefix: Option<&str>,
720) -> anyhow::Result<()> {
721    for (ident, ty) in elements {
722        if let Some(prefix) = prefix {
723            write!(out, "{prefix} ")?;
724        }
725
726        let name = ident.deref().to_case(Case::Snake);
727
728        write!(out, "{name}: ")?;
729        write_type(module, out, ty)?;
730        writeln!(out, ",")?;
731    }
732
733    Ok(())
734}
735
736// TODO: figure out if/when product types should derive:
737// - Clone
738// - Debug
739// - Copy
740// - PartialEq, Eq
741// - Hash
742//    - Complicated because `HashMap` is not `Hash`.
743// - others?
744
745const STRUCT_DERIVES: &[&str] = &[
746    "#[derive(__lib::ser::Serialize, __lib::de::Deserialize, Clone, PartialEq, Debug)]",
747    "#[sats(crate = __lib)]",
748];
749
750fn print_struct_derives(output: &mut Indenter) {
751    print_lines(output, STRUCT_DERIVES);
752}
753
754fn define_struct_for_product(
755    module: &ModuleDef,
756    out: &mut Indenter,
757    name: &str,
758    elements: &[(Identifier, AlgebraicTypeUse)],
759    vis: &str,
760) {
761    print_struct_derives(out);
762
763    write!(out, "{vis} struct {name} ");
764
765    // TODO: if elements is empty, define a unit struct with no brace-delimited list of fields.
766    write_struct_type_fields_in_braces(
767        module, out, elements, true, // `pub`-qualify fields.
768    );
769
770    out.newline();
771}
772
773fn type_ref_module_name(module: &ModuleDef, type_ref: AlgebraicTypeRef) -> String {
774    let (name, _) = module.type_def_from_ref(type_ref).unwrap();
775    type_module_name(name)
776}
777
778fn type_module_name(type_name: &ScopedTypeName) -> String {
779    collect_case(Case::Snake, type_name.name_segments()) + "_type"
780}
781
782fn table_module_name(table_name: &Identifier) -> String {
783    table_name.deref().to_case(Case::Snake) + "_table"
784}
785
786fn table_method_name(table_name: &Identifier) -> String {
787    table_name.deref().to_case(Case::Snake)
788}
789
790fn table_access_trait_name(table_name: &Identifier) -> String {
791    table_name.deref().to_case(Case::Pascal) + "TableAccess"
792}
793
794fn reducer_args_type_name(reducer_name: &Identifier) -> String {
795    reducer_name.deref().to_case(Case::Pascal) + "Args"
796}
797
798fn reducer_variant_name(reducer_name: &Identifier) -> String {
799    reducer_name.deref().to_case(Case::Pascal)
800}
801
802fn reducer_callback_id_name(reducer_name: &Identifier) -> String {
803    reducer_name.deref().to_case(Case::Pascal) + "CallbackId"
804}
805
806fn reducer_module_name(reducer_name: &Identifier) -> String {
807    reducer_name.deref().to_case(Case::Snake) + "_reducer"
808}
809
810fn reducer_function_name(reducer: &ReducerDef) -> String {
811    reducer.name.deref().to_case(Case::Snake)
812}
813
814fn reducer_flags_trait_name(reducer: &ReducerDef) -> String {
815    format!("set_flags_for_{}", reducer_function_name(reducer))
816}
817
818/// Iterate over all of the Rust `mod`s for types, reducers and tables in the `module`.
819fn iter_module_names(module: &ModuleDef) -> impl Iterator<Item = String> + '_ {
820    itertools::chain!(
821        iter_types(module).map(|ty| type_module_name(&ty.name)),
822        iter_reducers(module).map(|r| reducer_module_name(&r.name)),
823        iter_tables(module).map(|tbl| table_module_name(&tbl.name)),
824    )
825}
826
827/// Print `pub mod` declarations for all the files that will be generated for `items`.
828fn print_module_decls(module: &ModuleDef, out: &mut Indenter) {
829    for module_name in iter_module_names(module) {
830        writeln!(out, "pub mod {module_name};");
831    }
832}
833
834/// Print appropriate reexports for all the files that will be generated for `items`.
835fn print_module_reexports(module: &ModuleDef, out: &mut Indenter) {
836    for ty in iter_types(module) {
837        let mod_name = type_module_name(&ty.name);
838        let type_name = collect_case(Case::Pascal, ty.name.name_segments());
839        writeln!(out, "pub use {mod_name}::{type_name};")
840    }
841    for table in iter_tables(module) {
842        let mod_name = table_module_name(&table.name);
843        // TODO: More precise reexport: we want:
844        // - The trait name.
845        // - The insert, delete and possibly update callback ids.
846        // We do not want:
847        // - The table handle.
848        writeln!(out, "pub use {mod_name}::*;");
849    }
850    for reducer in iter_reducers(module) {
851        let mod_name = reducer_module_name(&reducer.name);
852        let reducer_trait_name = reducer_function_name(reducer);
853        let flags_trait_name = reducer_flags_trait_name(reducer);
854        let callback_id_name = reducer_callback_id_name(&reducer.name);
855        writeln!(
856            out,
857            "pub use {mod_name}::{{{reducer_trait_name}, {flags_trait_name}, {callback_id_name}}};"
858        );
859    }
860}
861
862fn print_reducer_enum_defn(module: &ModuleDef, out: &mut Indenter) {
863    // Don't derive ser/de on this enum;
864    // it's not a proper SATS enum and the derive will fail.
865    writeln!(out, "#[derive(Clone, PartialEq, Debug)]");
866    writeln!(
867        out,
868        "
869/// One of the reducers defined by this module.
870///
871/// Contained within a [`__sdk::ReducerEvent`] in [`EventContext`]s for reducer events
872/// to indicate which reducer caused the event.
873",
874    );
875    out.delimited_block(
876        "pub enum Reducer {",
877        |out| {
878            for reducer in iter_reducers(module) {
879                write!(out, "{} ", reducer_variant_name(&reducer.name));
880                if !reducer.params_for_generate.elements.is_empty() {
881                    // If the reducer has any arguments, generate a "struct variant,"
882                    // like `Foo { bar: Baz, }`.
883                    // If it doesn't, generate a "unit variant" instead,
884                    // like `Foo,`.
885                    write_struct_type_fields_in_braces(module, out, &reducer.params_for_generate.elements, false);
886                }
887                writeln!(out, ",");
888            }
889        },
890        "}\n",
891    );
892    out.newline();
893    writeln!(
894        out,
895        "
896impl __sdk::InModule for Reducer {{
897    type Module = RemoteModule;
898}}
899",
900    );
901
902    out.delimited_block(
903        "impl __sdk::Reducer for Reducer {",
904        |out| {
905            out.delimited_block(
906                "fn reducer_name(&self) -> &'static str {",
907                |out| {
908                    out.delimited_block(
909                        "match self {",
910                        |out| {
911                            for reducer in iter_reducers(module) {
912                                write!(out, "Reducer::{}", reducer_variant_name(&reducer.name));
913                                if !reducer.params_for_generate.elements.is_empty() {
914                                    // Because we're emitting unit variants when the payload is empty,
915                                    // we will emit different patterns for empty vs non-empty variants.
916                                    // This is not strictly required;
917                                    // Rust allows matching a struct-like pattern
918                                    // against a unit-like enum variant,
919                                    // but we prefer the clarity of not including the braces for unit variants.
920                                    write!(out, " {{ .. }}");
921                                }
922                                writeln!(out, " => {:?},", reducer.name.deref());
923                            }
924                        },
925                        "}\n",
926                    );
927                },
928                "}\n",
929            );
930        },
931        "}\n",
932    );
933
934    out.delimited_block(
935        "impl TryFrom<__ws::ReducerCallInfo<__ws::BsatnFormat>> for Reducer {",
936        |out| {
937            writeln!(out, "type Error = __sdk::Error;");
938            // We define an "args struct" for each reducer in `generate_reducer`.
939            // This is not user-facing, and is not exported past the "root" `mod.rs`;
940            // it is an internal helper for serialization and deserialization.
941            // We actually want to ser/de instances of `enum Reducer`, but:
942            //
943            // - `Reducer` will have struct-like variants, which SATS ser/de does not support.
944            // - The WS format does not contain a BSATN-serialized `Reducer` instance;
945            //   it holds the reducer name or ID separately from the argument bytes.
946            //   We could work up some magic with `DeserializeSeed`
947            //   and/or custom `Serializer` and `Deserializer` types
948            //   to account for this, but it's much easier to just use an intermediate struct per reducer.
949            //
950            // As such, we deserialize from the `value.args` bytes into that "args struct,"
951            // then convert it into a `Reducer` variant via `Into::into`,
952            // which we also implement in `generate_reducer`.
953            out.delimited_block(
954                "fn try_from(value: __ws::ReducerCallInfo<__ws::BsatnFormat>) -> __sdk::Result<Self> {",
955                |out| {
956                    out.delimited_block(
957                        "match &value.reducer_name[..] {",
958                        |out| {
959                            for reducer in iter_reducers(module) {
960                                writeln!(
961                                    out,
962                                    "{:?} => Ok(__sdk::parse_reducer_args::<{}::{}>({:?}, &value.args)?.into()),",
963                                    reducer.name.deref(),
964                                    reducer_module_name(&reducer.name),
965                                    reducer_args_type_name(&reducer.name),
966                                    reducer.name.deref(),
967                                );
968                            }
969                            writeln!(
970                                out,
971                                "unknown => Err(__sdk::InternalError::unknown_name(\"reducer\", unknown, \"ReducerCallInfo\").into()),",
972                            );
973                        },
974                        "}\n",
975                    )
976                },
977                "}\n",
978            );
979        },
980        "}\n",
981    )
982}
983
984fn print_db_update_defn(module: &ModuleDef, out: &mut Indenter) {
985    writeln!(out, "#[derive(Default)]");
986    writeln!(out, "#[allow(non_snake_case)]");
987    writeln!(out, "#[doc(hidden)]");
988    out.delimited_block(
989        "pub struct DbUpdate {",
990        |out| {
991            for table in iter_tables(module) {
992                writeln!(
993                    out,
994                    "{}: __sdk::TableUpdate<{}>,",
995                    table_method_name(&table.name),
996                    type_ref_name(module, table.product_type_ref),
997                );
998            }
999        },
1000        "}\n",
1001    );
1002
1003    out.newline();
1004
1005    out.delimited_block(
1006        "
1007impl TryFrom<__ws::DatabaseUpdate<__ws::BsatnFormat>> for DbUpdate {
1008    type Error = __sdk::Error;
1009    fn try_from(raw: __ws::DatabaseUpdate<__ws::BsatnFormat>) -> Result<Self, Self::Error> {
1010        let mut db_update = DbUpdate::default();
1011        for table_update in raw.tables {
1012            match &table_update.table_name[..] {
1013",
1014        |out| {
1015            for table in iter_tables(module) {
1016                writeln!(
1017                    out,
1018                    "{:?} => db_update.{} = {}::parse_table_update(table_update)?,",
1019                    table.name.deref(),
1020                    table_method_name(&table.name),
1021                    table_module_name(&table.name),
1022                );
1023            }
1024        },
1025        "
1026                unknown => {
1027                    return Err(__sdk::InternalError::unknown_name(
1028                        \"table\",
1029                        unknown,
1030                        \"DatabaseUpdate\",
1031                    ).into());
1032                }
1033            }
1034        }
1035        Ok(db_update)
1036    }
1037}",
1038    );
1039
1040    out.newline();
1041
1042    writeln!(
1043        out,
1044        "
1045impl __sdk::InModule for DbUpdate {{
1046    type Module = RemoteModule;
1047}}
1048",
1049    );
1050
1051    out.delimited_block(
1052        "impl __sdk::DbUpdate for DbUpdate {",
1053        |out| {
1054            out.delimited_block(
1055                "fn apply_to_client_cache(&self, cache: &mut __sdk::ClientCache<RemoteModule>) -> AppliedDiff<'_> {
1056                    let mut diff = AppliedDiff::default();
1057                ",
1058                |out| {
1059                    for table in iter_tables(module) {
1060                        let with_updates = table
1061                            .primary_key
1062                            .map(|col| {
1063                                let pk_field = table.get_column(col).unwrap().name.deref().to_case(Case::Snake);
1064                                format!(".with_updates_by_pk(|row| &row.{pk_field})")
1065                            })
1066                            .unwrap_or_default();
1067
1068                        let field_name = table_method_name(&table.name);
1069                        writeln!(
1070                            out,
1071                            "diff.{field_name} = cache.apply_diff_to_table::<{}>({:?}, &self.{field_name}){with_updates};",
1072                            type_ref_name(module, table.product_type_ref),
1073                            table.name.deref(),
1074                        );
1075                    }
1076                },
1077                "
1078                    diff
1079                }\n",
1080            );
1081        },
1082        "}\n",
1083    );
1084}
1085
1086fn print_applied_diff_defn(module: &ModuleDef, out: &mut Indenter) {
1087    writeln!(out, "#[derive(Default)]");
1088    writeln!(out, "#[allow(non_snake_case)]");
1089    writeln!(out, "#[doc(hidden)]");
1090    out.delimited_block(
1091        "pub struct AppliedDiff<'r> {",
1092        |out| {
1093            for table in iter_tables(module) {
1094                writeln!(
1095                    out,
1096                    "{}: __sdk::TableAppliedDiff<'r, {}>,",
1097                    table_method_name(&table.name),
1098                    type_ref_name(module, table.product_type_ref),
1099                );
1100            }
1101        },
1102        "}\n",
1103    );
1104
1105    out.newline();
1106
1107    writeln!(
1108        out,
1109        "
1110impl __sdk::InModule for AppliedDiff<'_> {{
1111    type Module = RemoteModule;
1112}}
1113",
1114    );
1115
1116    out.delimited_block(
1117        "impl<'r> __sdk::AppliedDiff<'r> for AppliedDiff<'r> {",
1118        |out| {
1119            out.delimited_block(
1120                "fn invoke_row_callbacks(&self, event: &EventContext, callbacks: &mut __sdk::DbCallbacks<RemoteModule>) {",
1121                |out| {
1122                    for table in iter_tables(module) {
1123                        writeln!(
1124                            out,
1125                            "callbacks.invoke_table_row_callbacks::<{}>({:?}, &self.{}, event);",
1126                            type_ref_name(module, table.product_type_ref),
1127                            table.name.deref(),
1128                            table_method_name(&table.name),
1129                        );
1130                    }
1131                },
1132                "}\n",
1133            );
1134        },
1135        "}\n",
1136    );
1137}
1138
1139fn print_impl_spacetime_module(module: &ModuleDef, out: &mut Indenter) {
1140    out.delimited_block(
1141        "impl __sdk::SpacetimeModule for RemoteModule {",
1142        |out| {
1143            writeln!(
1144                out,
1145                "
1146type DbConnection = DbConnection;
1147type EventContext = EventContext;
1148type ReducerEventContext = ReducerEventContext;
1149type SubscriptionEventContext = SubscriptionEventContext;
1150type ErrorContext = ErrorContext;
1151type Reducer = Reducer;
1152type DbView = RemoteTables;
1153type Reducers = RemoteReducers;
1154type SetReducerFlags = SetReducerFlags;
1155type DbUpdate = DbUpdate;
1156type AppliedDiff<'r> = AppliedDiff<'r>;
1157type SubscriptionHandle = SubscriptionHandle;
1158"
1159            );
1160            out.delimited_block(
1161                "fn register_tables(client_cache: &mut __sdk::ClientCache<Self>) {",
1162                |out| {
1163                    for table in iter_tables(module) {
1164                        writeln!(out, "{}::register_table(client_cache);", table_module_name(&table.name));
1165                    }
1166                },
1167                "}\n",
1168            );
1169        },
1170        "}\n",
1171    );
1172}
1173
1174fn print_const_db_context_types(out: &mut Indenter) {
1175    writeln!(
1176        out,
1177        "
1178#[doc(hidden)]
1179pub struct RemoteModule;
1180
1181impl __sdk::InModule for RemoteModule {{
1182    type Module = Self;
1183}}
1184
1185/// The `reducers` field of [`EventContext`] and [`DbConnection`],
1186/// with methods provided by extension traits for each reducer defined by the module.
1187pub struct RemoteReducers {{
1188    imp: __sdk::DbContextImpl<RemoteModule>,
1189}}
1190
1191impl __sdk::InModule for RemoteReducers {{
1192    type Module = RemoteModule;
1193}}
1194
1195#[doc(hidden)]
1196/// The `set_reducer_flags` field of [`DbConnection`],
1197/// with methods provided by extension traits for each reducer defined by the module.
1198/// Each method sets the flags for the reducer with the same name.
1199///
1200/// This type is currently unstable and may be removed without a major version bump.
1201pub struct SetReducerFlags {{
1202    imp: __sdk::DbContextImpl<RemoteModule>,
1203}}
1204
1205impl __sdk::InModule for SetReducerFlags {{
1206    type Module = RemoteModule;
1207}}
1208
1209/// The `db` field of [`EventContext`] and [`DbConnection`],
1210/// with methods provided by extension traits for each table defined by the module.
1211pub struct RemoteTables {{
1212    imp: __sdk::DbContextImpl<RemoteModule>,
1213}}
1214
1215impl __sdk::InModule for RemoteTables {{
1216    type Module = RemoteModule;
1217}}
1218
1219/// A connection to a remote module, including a materialized view of a subset of the database.
1220///
1221/// Connect to a remote module by calling [`DbConnection::builder`]
1222/// and using the [`__sdk::DbConnectionBuilder`] builder-pattern constructor.
1223///
1224/// You must explicitly advance the connection by calling any one of:
1225///
1226/// - [`DbConnection::frame_tick`].
1227/// - [`DbConnection::run_threaded`].
1228/// - [`DbConnection::run_async`].
1229/// - [`DbConnection::advance_one_message`].
1230/// - [`DbConnection::advance_one_message_blocking`].
1231/// - [`DbConnection::advance_one_message_async`].
1232///
1233/// Which of these methods you should call depends on the specific needs of your application,
1234/// but you must call one of them, or else the connection will never progress.
1235pub struct DbConnection {{
1236    /// Access to tables defined by the module via extension traits implemented for [`RemoteTables`].
1237    pub db: RemoteTables,
1238    /// Access to reducers defined by the module via extension traits implemented for [`RemoteReducers`].
1239    pub reducers: RemoteReducers,
1240    #[doc(hidden)]
1241    /// Access to setting the call-flags of each reducer defined for each reducer defined by the module
1242    /// via extension traits implemented for [`SetReducerFlags`].
1243    ///
1244    /// This type is currently unstable and may be removed without a major version bump.
1245    pub set_reducer_flags: SetReducerFlags,
1246
1247    imp: __sdk::DbContextImpl<RemoteModule>,
1248}}
1249
1250impl __sdk::InModule for DbConnection {{
1251    type Module = RemoteModule;
1252}}
1253
1254impl __sdk::DbContext for DbConnection {{
1255    type DbView = RemoteTables;
1256    type Reducers = RemoteReducers;
1257    type SetReducerFlags = SetReducerFlags;
1258
1259    fn db(&self) -> &Self::DbView {{
1260        &self.db
1261    }}
1262    fn reducers(&self) -> &Self::Reducers {{
1263        &self.reducers
1264    }}
1265    fn set_reducer_flags(&self) -> &Self::SetReducerFlags {{
1266        &self.set_reducer_flags
1267    }}
1268
1269    fn is_active(&self) -> bool {{
1270        self.imp.is_active()
1271    }}
1272
1273    fn disconnect(&self) -> __sdk::Result<()> {{
1274        self.imp.disconnect()
1275    }}
1276
1277    type SubscriptionBuilder = __sdk::SubscriptionBuilder<RemoteModule>;
1278
1279    fn subscription_builder(&self) -> Self::SubscriptionBuilder {{
1280        __sdk::SubscriptionBuilder::new(&self.imp)
1281    }}
1282
1283    fn try_identity(&self) -> Option<__sdk::Identity> {{
1284        self.imp.try_identity()
1285    }}
1286    fn connection_id(&self) -> __sdk::ConnectionId {{
1287        self.imp.connection_id()
1288    }}
1289}}
1290
1291impl DbConnection {{
1292    /// Builder-pattern constructor for a connection to a remote module.
1293    ///
1294    /// See [`__sdk::DbConnectionBuilder`] for required and optional configuration for the new connection.
1295    pub fn builder() -> __sdk::DbConnectionBuilder<RemoteModule> {{
1296        __sdk::DbConnectionBuilder::new()
1297    }}
1298
1299    /// If any WebSocket messages are waiting, process one of them.
1300    ///
1301    /// Returns `true` if a message was processed, or `false` if the queue is empty.
1302    /// Callers should invoke this message in a loop until it returns `false`
1303    /// or for as much time is available to process messages.
1304    ///
1305    /// Returns an error if the connection is disconnected.
1306    /// If the disconnection in question was normal,
1307    ///  i.e. the result of a call to [`__sdk::DbContext::disconnect`],
1308    /// the returned error will be downcastable to [`__sdk::DisconnectedError`].
1309    ///
1310    /// This is a low-level primitive exposed for power users who need significant control over scheduling.
1311    /// Most applications should call [`Self::frame_tick`] each frame
1312    /// to fully exhaust the queue whenever time is available.
1313    pub fn advance_one_message(&self) -> __sdk::Result<bool> {{
1314        self.imp.advance_one_message()
1315    }}
1316
1317    /// Process one WebSocket message, potentially blocking the current thread until one is received.
1318    ///
1319    /// Returns an error if the connection is disconnected.
1320    /// If the disconnection in question was normal,
1321    ///  i.e. the result of a call to [`__sdk::DbContext::disconnect`],
1322    /// the returned error will be downcastable to [`__sdk::DisconnectedError`].
1323    ///
1324    /// This is a low-level primitive exposed for power users who need significant control over scheduling.
1325    /// Most applications should call [`Self::run_threaded`] to spawn a thread
1326    /// which advances the connection automatically.
1327    pub fn advance_one_message_blocking(&self) -> __sdk::Result<()> {{
1328        self.imp.advance_one_message_blocking()
1329    }}
1330
1331    /// Process one WebSocket message, `await`ing until one is received.
1332    ///
1333    /// Returns an error if the connection is disconnected.
1334    /// If the disconnection in question was normal,
1335    ///  i.e. the result of a call to [`__sdk::DbContext::disconnect`],
1336    /// the returned error will be downcastable to [`__sdk::DisconnectedError`].
1337    ///
1338    /// This is a low-level primitive exposed for power users who need significant control over scheduling.
1339    /// Most applications should call [`Self::run_async`] to run an `async` loop
1340    /// which advances the connection when polled.
1341    pub async fn advance_one_message_async(&self) -> __sdk::Result<()> {{
1342        self.imp.advance_one_message_async().await
1343    }}
1344
1345    /// Process all WebSocket messages waiting in the queue,
1346    /// then return without `await`ing or blocking the current thread.
1347    pub fn frame_tick(&self) -> __sdk::Result<()> {{
1348        self.imp.frame_tick()
1349    }}
1350
1351    /// Spawn a thread which processes WebSocket messages as they are received.
1352    pub fn run_threaded(&self) -> std::thread::JoinHandle<()> {{
1353        self.imp.run_threaded()
1354    }}
1355
1356    /// Run an `async` loop which processes WebSocket messages when polled.
1357    pub async fn run_async(&self) -> __sdk::Result<()> {{
1358        self.imp.run_async().await
1359    }}
1360}}
1361
1362impl __sdk::DbConnection for DbConnection {{
1363    fn new(imp: __sdk::DbContextImpl<RemoteModule>) -> Self {{
1364        Self {{
1365            db: RemoteTables {{ imp: imp.clone() }},
1366            reducers: RemoteReducers {{ imp: imp.clone() }},
1367            set_reducer_flags: SetReducerFlags {{ imp: imp.clone() }},
1368            imp,
1369        }}
1370    }}
1371}}
1372
1373/// A handle on a subscribed query.
1374// TODO: Document this better after implementing the new subscription API.
1375#[derive(Clone)]
1376pub struct SubscriptionHandle {{
1377    imp: __sdk::SubscriptionHandleImpl<RemoteModule>,
1378}}
1379
1380impl __sdk::InModule for SubscriptionHandle {{
1381    type Module = RemoteModule;
1382}}
1383
1384impl __sdk::SubscriptionHandle for SubscriptionHandle {{
1385    fn new(imp: __sdk::SubscriptionHandleImpl<RemoteModule>) -> Self {{
1386        Self {{ imp }}
1387    }}
1388
1389    /// Returns true if this subscription has been terminated due to an unsubscribe call or an error.
1390    fn is_ended(&self) -> bool {{
1391        self.imp.is_ended()
1392    }}
1393
1394    /// Returns true if this subscription has been applied and has not yet been unsubscribed.
1395    fn is_active(&self) -> bool {{
1396        self.imp.is_active()
1397    }}
1398
1399    /// Unsubscribe from the query controlled by this `SubscriptionHandle`,
1400    /// then run `on_end` when its rows are removed from the client cache.
1401    fn unsubscribe_then(self, on_end: __sdk::OnEndedCallback<RemoteModule>) -> __sdk::Result<()> {{
1402        self.imp.unsubscribe_then(Some(on_end))
1403    }}
1404
1405    fn unsubscribe(self) -> __sdk::Result<()> {{
1406        self.imp.unsubscribe_then(None)
1407    }}
1408
1409}}
1410
1411/// Alias trait for a [`__sdk::DbContext`] connected to this module,
1412/// with that trait's associated types bounded to this module's concrete types.
1413///
1414/// Users can use this trait as a boundary on definitions which should accept
1415/// either a [`DbConnection`] or an [`EventContext`] and operate on either.
1416pub trait RemoteDbContext: __sdk::DbContext<
1417    DbView = RemoteTables,
1418    Reducers = RemoteReducers,
1419    SetReducerFlags = SetReducerFlags,
1420    SubscriptionBuilder = __sdk::SubscriptionBuilder<RemoteModule>,
1421> {{}}
1422impl<Ctx: __sdk::DbContext<
1423    DbView = RemoteTables,
1424    Reducers = RemoteReducers,
1425    SetReducerFlags = SetReducerFlags,
1426    SubscriptionBuilder = __sdk::SubscriptionBuilder<RemoteModule>,
1427>> RemoteDbContext for Ctx {{}}
1428",
1429    );
1430
1431    define_event_context(
1432        out,
1433        "EventContext",
1434        Some("__sdk::Event<Reducer>"),
1435        "[`__sdk::Table::on_insert`], [`__sdk::Table::on_delete`] and [`__sdk::TableWithPrimaryKey::on_update`] callbacks",
1436        Some("[`__sdk::Event`]"),
1437    );
1438
1439    define_event_context(
1440        out,
1441        "ReducerEventContext",
1442        Some("__sdk::ReducerEvent<Reducer>"),
1443        "on-reducer callbacks", // There's no single trait or method for reducer callbacks, so we can't usefully link to them.
1444        Some("[`__sdk::ReducerEvent`]"),
1445    );
1446
1447    define_event_context(
1448        out,
1449        "SubscriptionEventContext",
1450        None, // SubscriptionEventContexts have no additional `event` info, so they don't even get that field.
1451        "[`__sdk::SubscriptionBuilder::on_applied`] and [`SubscriptionHandle::unsubscribe_then`] callbacks",
1452        None,
1453    );
1454
1455    define_event_context(
1456        out,
1457        "ErrorContext",
1458        Some("Option<__sdk::Error>"),
1459        "[`__sdk::DbConnectionBuilder::on_disconnect`], [`__sdk::DbConnectionBuilder::on_connect_error`] and [`__sdk::SubscriptionBuilder::on_error`] callbacks",
1460        Some("[`__sdk::Error`]"),
1461    );
1462}
1463
1464/// Define a type that implements `AbstractEventContext` and one of its concrete subtraits.
1465///
1466/// `struct_and_trait_name` should be the name of an event context trait,
1467/// and will also be used as the new struct's name.
1468///
1469/// `event_type`, if `Some`, should be a Rust type which will be the type of the new struct's `event` field.
1470/// If `None`, the new struct will not have such a field.
1471/// The `SubscriptionEventContext` will pass `None`, since there is no useful information to add.
1472///
1473/// `passed_to_callbacks_doc_link` should be a rustdoc-formatted phrase
1474/// which links to the callback-registering functions for the callbacks which accept this event context type.
1475/// It should be of the form "foo callbacks" or "foo, bar and baz callbacks",
1476/// with link formatting where appropriate, and no trailing punctuation.
1477///
1478/// If `event_type` is `Some`, `event_type_doc_link` should be as well.
1479/// It should be a rustdoc-formatted link (including square brackets and all) to the `event_type`.
1480/// This may differ (in the `strcmp` sense) from `event_type` because it should not inlcude generic parameters.
1481fn define_event_context(
1482    out: &mut Indenter,
1483    struct_and_trait_name: &str,
1484    event_type: Option<&str>,
1485    passed_to_callbacks_doc_link: &str,
1486    event_type_doc_link: Option<&str>,
1487) {
1488    if let (Some(event_type), Some(event_type_doc_link)) = (event_type, event_type_doc_link) {
1489        write!(
1490            out,
1491            "
1492/// An [`__sdk::DbContext`] augmented with a {event_type_doc_link},
1493/// passed to {passed_to_callbacks_doc_link}.
1494pub struct {struct_and_trait_name} {{
1495    /// Access to tables defined by the module via extension traits implemented for [`RemoteTables`].
1496    pub db: RemoteTables,
1497    /// Access to reducers defined by the module via extension traits implemented for [`RemoteReducers`].
1498    pub reducers: RemoteReducers,
1499    /// Access to setting the call-flags of each reducer defined for each reducer defined by the module
1500    /// via extension traits implemented for [`SetReducerFlags`].
1501    ///
1502    /// This type is currently unstable and may be removed without a major version bump.
1503    pub set_reducer_flags: SetReducerFlags,
1504    /// The event which caused these callbacks to run.
1505    pub event: {event_type},
1506    imp: __sdk::DbContextImpl<RemoteModule>,
1507}}
1508
1509impl __sdk::AbstractEventContext for {struct_and_trait_name} {{
1510    type Event = {event_type};
1511    fn event(&self) -> &Self::Event {{
1512        &self.event
1513    }}
1514    fn new(imp: __sdk::DbContextImpl<RemoteModule>, event: Self::Event) -> Self {{
1515        Self {{
1516            db: RemoteTables {{ imp: imp.clone() }},
1517            reducers: RemoteReducers {{ imp: imp.clone() }},
1518            set_reducer_flags: SetReducerFlags {{ imp: imp.clone() }},
1519            event,
1520            imp,
1521        }}
1522    }}
1523}}
1524",
1525        );
1526    } else {
1527        debug_assert!(event_type.is_none() && event_type_doc_link.is_none());
1528        write!(
1529            out,
1530            "
1531/// An [`__sdk::DbContext`] passed to {passed_to_callbacks_doc_link}.
1532pub struct {struct_and_trait_name} {{
1533    /// Access to tables defined by the module via extension traits implemented for [`RemoteTables`].
1534    pub db: RemoteTables,
1535    /// Access to reducers defined by the module via extension traits implemented for [`RemoteReducers`].
1536    pub reducers: RemoteReducers,
1537    /// Access to setting the call-flags of each reducer defined for each reducer defined by the module
1538    /// via extension traits implemented for [`SetReducerFlags`].
1539    ///
1540    /// This type is currently unstable and may be removed without a major version bump.
1541    pub set_reducer_flags: SetReducerFlags,
1542    imp: __sdk::DbContextImpl<RemoteModule>,
1543}}
1544
1545impl __sdk::AbstractEventContext for {struct_and_trait_name} {{
1546    type Event = ();
1547    fn event(&self) -> &Self::Event {{
1548        &()
1549    }}
1550    fn new(imp: __sdk::DbContextImpl<RemoteModule>, _event: Self::Event) -> Self {{
1551        Self {{
1552            db: RemoteTables {{ imp: imp.clone() }},
1553            reducers: RemoteReducers {{ imp: imp.clone() }},
1554            set_reducer_flags: SetReducerFlags {{ imp: imp.clone() }},
1555            imp,
1556        }}
1557    }}
1558}}
1559",
1560        );
1561    }
1562
1563    write!(
1564        out,
1565        "
1566impl __sdk::InModule for {struct_and_trait_name} {{
1567    type Module = RemoteModule;
1568}}
1569
1570impl __sdk::DbContext for {struct_and_trait_name} {{
1571    type DbView = RemoteTables;
1572    type Reducers = RemoteReducers;
1573    type SetReducerFlags = SetReducerFlags;
1574
1575    fn db(&self) -> &Self::DbView {{
1576        &self.db
1577    }}
1578    fn reducers(&self) -> &Self::Reducers {{
1579        &self.reducers
1580    }}
1581    fn set_reducer_flags(&self) -> &Self::SetReducerFlags {{
1582        &self.set_reducer_flags
1583    }}
1584
1585    fn is_active(&self) -> bool {{
1586        self.imp.is_active()
1587    }}
1588
1589    fn disconnect(&self) -> __sdk::Result<()> {{
1590        self.imp.disconnect()
1591    }}
1592
1593    type SubscriptionBuilder = __sdk::SubscriptionBuilder<RemoteModule>;
1594
1595    fn subscription_builder(&self) -> Self::SubscriptionBuilder {{
1596        __sdk::SubscriptionBuilder::new(&self.imp)
1597    }}
1598
1599    fn try_identity(&self) -> Option<__sdk::Identity> {{
1600        self.imp.try_identity()
1601    }}
1602    fn connection_id(&self) -> __sdk::ConnectionId {{
1603        self.imp.connection_id()
1604    }}
1605}}
1606
1607impl __sdk::{struct_and_trait_name} for {struct_and_trait_name} {{}}
1608"
1609    );
1610}
1611
1612/// Print `use super::` imports for each of the `imports`.
1613fn print_imports(module: &ModuleDef, out: &mut Indenter, imports: Imports) {
1614    for typeref in imports {
1615        let module_name = type_ref_module_name(module, typeref);
1616        let type_name = type_ref_name(module, typeref);
1617        writeln!(out, "use super::{module_name}::{type_name};");
1618    }
1619}
1620
1621/// Use `search_function` on `roots` to detect required imports, then print them with `print_imports`.
1622///
1623/// `this_file` is passed and excluded for the case of recursive types:
1624/// without it, the definition for a type like `struct Foo { foos: Vec<Foo> }`
1625/// would attempt to include `import super::foo::Foo`, which fails to compile.
1626fn gen_and_print_imports(
1627    module: &ModuleDef,
1628    out: &mut Indenter,
1629    roots: &[(Identifier, AlgebraicTypeUse)],
1630    dont_import: &[AlgebraicTypeRef],
1631) {
1632    let mut imports = BTreeSet::new();
1633
1634    for (_, ty) in roots {
1635        ty.for_each_ref(|r| {
1636            imports.insert(r);
1637        });
1638    }
1639    for skip in dont_import {
1640        imports.remove(skip);
1641    }
1642
1643    print_imports(module, out, imports);
1644}