Skip to main content

salsa_macro_rules/
setup_tracked_struct.rs

/// Macro for setting up a tracked struct.
///
/// Expands to the user-visible struct (a `Copy` handle wrapping a `salsa::Id` plus a
/// `PhantomData` for the `'db` lifetime), its `tracked_struct::Configuration` impl, the jar
/// registration, the constructor named by `new_fn`, one getter per tracked and untracked
/// field, and (conditionally, based on `persist` / `generate_debug_impl`) serde and `Debug`
/// impls.
#[macro_export]
macro_rules! setup_tracked_struct {
    (
        // Attributes on the struct.
        attrs: [$(#[$attr:meta]),*],

        // Visibility of the struct.
        vis: $vis:vis,

        // Name of the struct.
        Struct: $Struct:ident,

        // Name of the `'db` lifetime that the user gave.
        db_lt: $db_lt:lifetime,

        // Name user gave for `new`.
        new_fn: $new_fn:ident,

        // Field names.
        field_ids: [$($field_id:ident),*],

        // Tracked field names.
        tracked_ids: [$($tracked_id:ident),*],

        // Visibility and names of tracked fields.
        tracked_getters: [$($tracked_getter_vis:vis $tracked_getter_id:ident),*],

        // Visibility and names of untracked fields.
        untracked_getters: [$($untracked_getter_vis:vis $untracked_getter_id:ident),*],

        // Field types, may reference `db_lt`.
        field_tys: [$($field_ty:ty),*],

        // Tracked field types.
        tracked_tys: [$($tracked_ty:ty),*],

        // Untracked field types.
        untracked_tys: [$($untracked_ty:ty),*],

        // Indices for each field from 0..N -- must be unsuffixed (e.g., `0`, `1`).
        field_indices: [$($field_index:tt),*],

        // Absolute indices of any tracked fields, relative to all other fields of this struct.
        absolute_tracked_indices: [$($absolute_tracked_index:tt),*],

        // Indices of any tracked fields, relative to only tracked fields on this struct.
        relative_tracked_indices: [$($relative_tracked_index:tt),*],

        // Absolute indices of any untracked fields.
        absolute_untracked_indices: [$($absolute_untracked_index:tt),*],

        // Update function for each tracked field. `update_fields` calls it with a pointer to
        // the old field value and the new value; a `true` result means the field changed.
        tracked_maybe_updates: [$($tracked_maybe_update:tt),*],

        // Update function for each untracked field (same calling convention as
        // `tracked_maybe_updates`).
        untracked_maybe_updates: [$($untracked_maybe_update:tt),*],

        // A set of "field options" for each tracked field.
        //
        // Each field option is a tuple `(return_mode, maybe_default)` where:
        //
        // * `return_mode` is an identifier as specified in `salsa_macros::options::Option::returns`
        // * `maybe_default` is either the identifier `default` or `required`
        //
        // These are used to drive conditional logic for each field via recursive macro invocation
        // (see the `return_mode_ty!` / `return_mode_expression!` calls in the getters below).
        tracked_options: [$($tracked_option:tt),*],

        // A set of "field options" for each untracked field.
        // (see docs for `tracked_options`).
        untracked_options: [$($untracked_option:tt),*],

        // Attrs for each field.
        tracked_field_attrs: [$([$(#[$tracked_field_attr:meta]),*]),*],
        untracked_field_attrs: [$([$(#[$untracked_field_attr:meta]),*]),*],

        // Number of tracked fields.
        num_tracked_fields: $N:literal,

        // If true, generate a debug impl.
        generate_debug_impl: $generate_debug_impl:tt,

        // The function used to implement `C::heap_size`.
        heap_size_fn: $($heap_size_fn:path)?,

        // If `true`, `serialize_fn` and `deserialize_fn` have been provided.
        persist: $persist:tt,

        // The path to the `serialize` function for the value's fields.
        serialize_fn: $($serialize_fn:path)?,

        // The path to the `deserialize` function for the value's fields.
        deserialize_fn: $($deserialize_fn:path)?,

        // Annoyingly macro-rules hygiene does not extend to items defined in the macro.
        // We have the procedural macro generate names for those items that are
        // not used elsewhere in the user's code.
        // NOTE(review): `$CACHE` is currently unused; the ingredient cache below is a
        // function-local `static CACHE`.
        unused_names: [
            $zalsa:ident,
            $zalsa_struct:ident,
            $Configuration:ident,
            $CACHE:ident,
            $Db:ident,
            $Revision:ident,
        ]
    ) => {
        $(#[$attr])*
        #[derive(Copy, Clone, PartialEq, Eq, Hash)]
        $vis struct $Struct<$db_lt>(
            ::salsa::Id,
            ::std::marker::PhantomData<fn() -> &$db_lt ()>
        );

        // Everything else lives in an anonymous `const` so the `use` aliases and the
        // `$Configuration` type alias below are scoped to this expansion.
        #[allow(dead_code)]
        #[allow(clippy::all)]
        const _: () = {
            use ::salsa::plumbing as $zalsa;
            use $zalsa::tracked_struct as $zalsa_struct;
            use $zalsa::Revision as $Revision;

            // The `'static` instantiation of the struct doubles as its `Configuration` type.
            type $Configuration = $Struct<'static>;

            impl<$db_lt> $zalsa::HasJar for $Struct<$db_lt> {
                type Jar = $zalsa_struct::JarImpl<$Configuration>;
                const KIND: $zalsa::JarKind = $zalsa::JarKind::Struct;
            }

            $zalsa::register_jar! {
                $zalsa::ErasedJar::erase::<$Struct<'static>>()
            }

            impl $zalsa_struct::Configuration for $Configuration {
                const LOCATION: $zalsa::Location = $zalsa::Location {
                    file: file!(),
                    line: line!(),
                };
                const DEBUG_NAME: &'static str = stringify!($Struct);

                const TRACKED_FIELD_NAMES: &'static [&'static str] = &[
                    $(stringify!($tracked_id),)*
                ];

                const TRACKED_FIELD_INDICES: &'static [usize] = &[
                    $($relative_tracked_index,)*
                ];

                const PERSIST: bool = $persist;

                type Fields<$db_lt> = ($($field_ty,)*);

                // One revision slot per tracked field, indexed by relative tracked index.
                type Revisions = [$zalsa::AtomicRevision; $N];

                type Struct<$db_lt> = $Struct<$db_lt>;

                fn untracked_fields(fields: &Self::Fields<'_>) -> impl ::std::hash::Hash {
                    ( $( &fields.$absolute_untracked_index ),* )
                }

                fn new_revisions(current_revision: $Revision) -> Self::Revisions {
                    ::std::array::from_fn(|_| $zalsa::AtomicRevision::new(current_revision))
                }

                unsafe fn update_fields<$db_lt>(
                    current_revision: $Revision,
                    revisions: &Self::Revisions,
                    old_fields: *mut Self::Fields<$db_lt>,
                    new_fields: Self::Fields<$db_lt>,
                ) -> bool {
                    use $zalsa::UpdateFallback as _;
                    // SAFETY: upheld by the caller of this `unsafe fn`.
                    // NOTE(review): presumably `old_fields` must be valid for reads and writes
                    // and not otherwise accessed during the call — confirm against the trait docs.
                    unsafe {
                        // Bump the stored revision of every tracked field whose value changed.
                        $(
                            if $tracked_maybe_update(::std::ptr::addr_of_mut!((*old_fields).$absolute_tracked_index), new_fields.$absolute_tracked_index) {
                                revisions[$relative_tracked_index].store(current_revision);
                            }
                        )*;

                        // If any untracked field has changed, return `true`, indicating that the tracked struct
                        // itself should be considered changed. The non-short-circuiting `|` is deliberate:
                        // every untracked field is updated even after an earlier one reports a change.
                        $(
                            $untracked_maybe_update(
                                &mut (*old_fields).$absolute_untracked_index,
                                new_fields.$absolute_untracked_index,
                            )
                            |
                        )* false
                    }
                }

                $(
                    fn heap_size(value: &Self::Fields<'_>) -> ::std::option::Option<usize> {
                        ::std::option::Option::Some($heap_size_fn(value))
                    }
                )?

                fn serialize<S: $zalsa::serde::Serializer>(
                    fields: &Self::Fields<'_>,
                    serializer: S,
                ) -> ::std::result::Result<S::Ok, S::Error> {
                    $zalsa::macro_if! {
                        if $persist {
                            $($serialize_fn(fields, serializer))?
                        } else {
                            panic!("attempted to serialize value not marked with `persist` attribute")
                        }
                    }
                }

                fn deserialize<'de, D: $zalsa::serde::Deserializer<'de>>(
                    deserializer: D,
                ) -> ::std::result::Result<Self::Fields<'static>, D::Error> {
                    $zalsa::macro_if! {
                        if $persist {
                            $($deserialize_fn(deserializer))?
                        } else {
                            panic!("attempted to deserialize value not marked with `persist` attribute")
                        }
                    }
                }
            }

            impl $Configuration {
                pub fn ingredient(db: &dyn $zalsa::Database) -> &$zalsa_struct::IngredientImpl<Self> {
                    Self::ingredient_(db.zalsa())
                }

                fn ingredient_(zalsa: &$zalsa::Zalsa) -> &$zalsa_struct::IngredientImpl<Self> {
                    static CACHE: $zalsa::IngredientCache<$zalsa_struct::IngredientImpl<$Configuration>> =
                        $zalsa::IngredientCache::new();

                    // SAFETY: `lookup_jar_by_type` returns a valid ingredient index, and the only
                    // ingredient created by our jar is the struct ingredient.
                    unsafe {
                        CACHE.get_or_create(zalsa, || {
                            zalsa.lookup_jar_by_type::<$zalsa_struct::JarImpl<$Configuration>>()
                        })
                    }
                }
            }

            impl<$db_lt> $zalsa::FromId for $Struct<$db_lt> {
                #[inline]
                fn from_id(id: ::salsa::Id) -> Self {
                    $Struct(id, ::std::marker::PhantomData)
                }
            }

            impl $zalsa::AsId for $Struct<'_> {
                #[inline]
                fn as_id(&self) -> $zalsa::Id {
                    self.0
                }
            }

            impl $zalsa::SalsaStructInDb for $Struct<'_> {
                type MemoIngredientMap = $zalsa::MemoIngredientSingletonIndex;

                fn lookup_ingredient_index(aux: &$zalsa::Zalsa) -> $zalsa::IngredientIndices {
                    aux.lookup_jar_by_type::<$zalsa_struct::JarImpl<$Configuration>>().into()
                }

                fn entries(
                    zalsa: &$zalsa::Zalsa
                ) -> impl ::std::iter::Iterator<Item = $zalsa::DatabaseKeyIndex> + '_ {
                    // `ingredient_` already performs (and caches) the jar lookup.
                    <$Configuration>::ingredient_(zalsa).entries(zalsa).map(|entry| entry.key())
                }

                #[inline]
                fn cast(id: $zalsa::Id, type_id: $zalsa::TypeId) -> $zalsa::Option<Self> {
                    if type_id == $zalsa::TypeId::of::<$Struct<'static>>() {
                        $zalsa::Some(<$Struct<'static> as $zalsa::FromId>::from_id(id))
                    } else {
                        $zalsa::None
                    }
                }

                #[inline]
                unsafe fn memo_table(
                    zalsa: &$zalsa::Zalsa,
                    id: $zalsa::Id,
                    current_revision: $zalsa::Revision,
                ) -> $zalsa::MemoTableWithTypes<'_> {
                    // SAFETY: Guaranteed by caller.
                    unsafe { zalsa.table().memos::<$zalsa_struct::Value<$Configuration>>(id, current_revision) }
                }
            }

            impl $zalsa::TrackedStructInDb for $Struct<'_> {
                fn database_key_index(zalsa: &$zalsa::Zalsa, id: $zalsa::Id) -> $zalsa::DatabaseKeyIndex {
                    $Configuration::ingredient_(zalsa).database_key_index(id)
                }
            }

            // With `persist`, the struct handle (de)serializes as its underlying `Id`.
            $zalsa::macro_if! { $persist =>
                impl $zalsa::serde::Serialize for $Struct<'_> {
                    fn serialize<S>(&self, serializer: S) -> ::std::result::Result<S::Ok, S::Error>
                    where
                        S: $zalsa::serde::Serializer,
                    {
                        $zalsa::serde::Serialize::serialize(&$zalsa::AsId::as_id(self), serializer)
                    }
                }

                impl<'de> $zalsa::serde::Deserialize<'de> for $Struct<'_> {
                    fn deserialize<D>(deserializer: D) -> ::std::result::Result<Self, D::Error>
                    where
                        D: $zalsa::serde::Deserializer<'de>,
                    {
                        // Fully qualified so the `Deserialize` trait need not be in scope at the call site.
                        let id = <$zalsa::Id as $zalsa::serde::Deserialize<'de>>::deserialize(deserializer)?;
                        ::std::result::Result::Ok($zalsa::FromId::from_id(id))
                    }
                }
            }

            // The struct is only an `Id` handle plus `PhantomData`; the field values it refers
            // to are stored in the database ingredient (see `new_struct` / `tracked_field`).
            unsafe impl ::std::marker::Send for $Struct<'_> {}

            unsafe impl ::std::marker::Sync for $Struct<'_> {}

            $zalsa::macro_if! { $generate_debug_impl =>
                impl ::std::fmt::Debug for $Struct<'_> {
                    fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result {
                        Self::default_debug_fmt(*self, f)
                    }
                }
            }

            // Handles compare by `Id` (the derived `PartialEq`; `PhantomData` is always equal),
            // so an update overwrites the old handle and reports a change only if it differs.
            unsafe impl $zalsa::Update for $Struct<'_> {
                unsafe fn maybe_update(old_pointer: *mut Self, new_value: Self) -> bool {
                    // SAFETY: `old_pointer` being valid for reads and writes is the caller's
                    // obligation under this `unsafe fn`.
                    if unsafe { *old_pointer } != new_value {
                        // SAFETY: as above.
                        unsafe { *old_pointer = new_value };
                        true
                    } else {
                        false
                    }
                }
            }

            impl<$db_lt> $Struct<$db_lt> {
                // Constructor: hands all field values to the struct ingredient.
                pub fn $new_fn<$Db>(db: &$db_lt $Db, $($field_id: $field_ty),*) -> Self
                where
                    // FIXME(rust-lang/rust#65991): The `db` argument *should* have the type `dyn Database`
                    $Db: ?Sized + $zalsa::Database,
                {
                    let (zalsa, zalsa_local) = db.zalsas();
                    $Configuration::ingredient_(zalsa).new_struct(
                        zalsa, zalsa_local,
                        ($($field_id,)*)
                    )
                }

                // Tracked getters go through `tracked_field`, which additionally receives
                // `zalsa_local` and the field's relative tracked index (presumably to record a
                // per-field dependency); untracked getters below use `untracked_field` instead.
                $(
                    $(#[$tracked_field_attr])*
                    $tracked_getter_vis fn $tracked_getter_id<$Db>(self, db: &$db_lt $Db) -> $crate::return_mode_ty!($tracked_option, $db_lt, $tracked_ty)
                    where
                        // FIXME(rust-lang/rust#65991): The `db` argument *should* have the type `dyn Database`
                        $Db: ?Sized + $zalsa::Database,
                    {
                        let (zalsa, zalsa_local) = db.zalsas();
                        let fields = $Configuration::ingredient_(zalsa).tracked_field(zalsa, zalsa_local, self, $relative_tracked_index);
                        $crate::return_mode_expression!(
                            $tracked_option,
                            $tracked_ty,
                            &fields.$absolute_tracked_index,
                        )
                    }
                )*

                $(
                    $(#[$untracked_field_attr])*
                    $untracked_getter_vis fn $untracked_getter_id<$Db>(self, db: &$db_lt $Db) -> $crate::return_mode_ty!($untracked_option, $db_lt, $untracked_ty)
                    where
                        // FIXME(rust-lang/rust#65991): The `db` argument *should* have the type `dyn Database`
                        $Db: ?Sized + $zalsa::Database,
                    {
                        let zalsa = db.zalsa();
                        let fields = $Configuration::ingredient_(zalsa).untracked_field(zalsa, self);
                        $crate::return_mode_expression!(
                            $untracked_option,
                            $untracked_ty,
                            &fields.$absolute_untracked_index,
                        )
                    }
                )*
            }

            #[allow(unused_lifetimes)]
            impl<'_db> $Struct<'_db> {
                /// Default debug formatting for this struct (may be useful if you define your own `Debug` impl)
                pub fn default_debug_fmt(this: Self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result
                where
                    // `zalsa::with_attached_database` has a local lifetime for the database,
                    // so we need this function to be higher-ranked over the db lifetime.
                    // Thus the actual lifetime of `Self` does not matter here, so we discard
                    // it with the `'_db` lifetime name as we cannot shadow lifetimes.
                    $(for<$db_lt> $field_ty: ::std::fmt::Debug),*
                {
                    // With an attached database, print every field; otherwise fall back to
                    // printing only the salsa id.
                    $zalsa::with_attached_database(|db| {
                        let zalsa = db.zalsa();
                        let fields = $Configuration::ingredient_(zalsa).leak_fields(zalsa, this);
                        let mut f = f.debug_struct(stringify!($Struct));
                        let f = f.field("[salsa id]", &$zalsa::AsId::as_id(&this));
                        $(
                            let f = f.field(stringify!($field_id), &fields.$field_index);
                        )*
                        f.finish()
                    }).unwrap_or_else(|| {
                        f.debug_struct(stringify!($Struct))
                            .field("[salsa id]", &$zalsa::AsId::as_id(&this))
                            .finish()
                    })
                }
            }
        };
    };
}