salsa/
database.rs

1use std::borrow::Cow;
2use std::ptr::NonNull;
3
4use crate::views::DatabaseDownCaster;
5use crate::zalsa::{IngredientIndex, ZalsaDatabase};
6use crate::{Durability, Revision};
7
8#[derive(Copy, Clone)]
9pub struct RawDatabase<'db> {
10    pub(crate) ptr: NonNull<()>,
11    _marker: std::marker::PhantomData<&'db dyn Database>,
12}
13
14impl<'db, Db: Database + ?Sized> From<&'db Db> for RawDatabase<'db> {
15    #[inline]
16    fn from(db: &'db Db) -> Self {
17        RawDatabase {
18            ptr: NonNull::from(db).cast(),
19            _marker: std::marker::PhantomData,
20        }
21    }
22}
23
24impl<'db, Db: Database + ?Sized> From<&'db mut Db> for RawDatabase<'db> {
25    #[inline]
26    fn from(db: &'db mut Db) -> Self {
27        RawDatabase {
28            ptr: NonNull::from(db).cast(),
29            _marker: std::marker::PhantomData,
30        }
31    }
32}
33
34/// The trait implemented by all Salsa databases.
35/// You can create your own subtraits of this trait using the `#[salsa::db]`(`crate::db`) procedural macro.
36pub trait Database: Send + ZalsaDatabase + AsDynDatabase {
37    /// Enforces current LRU limits, evicting entries if necessary.
38    ///
39    /// **WARNING:** Just like an ordinary write, this method triggers
40    /// cancellation. If you invoke it while a snapshot exists, it
41    /// will block until that snapshot is dropped -- if that snapshot
42    /// is owned by the current thread, this could trigger deadlock.
43    fn trigger_lru_eviction(&mut self) {
44        let zalsa_mut = self.zalsa_mut();
45        zalsa_mut.evict_lru();
46    }
47
48    /// A "synthetic write" causes the system to act *as though* some
49    /// input of durability `durability` has changed, triggering a new revision.
50    /// This is mostly useful for profiling scenarios.
51    ///
52    /// **WARNING:** Just like an ordinary write, this method triggers
53    /// cancellation. If you invoke it while a snapshot exists, it
54    /// will block until that snapshot is dropped -- if that snapshot
55    /// is owned by the current thread, this could trigger deadlock.
56    fn synthetic_write(&mut self, durability: Durability) {
57        let zalsa_mut = self.zalsa_mut();
58        zalsa_mut.new_revision();
59        zalsa_mut.runtime_mut().report_tracked_write(durability);
60    }
61
62    /// This method triggers cancellation.
63    /// If you invoke it while a snapshot exists, it
64    /// will block until that snapshot is dropped -- if that snapshot
65    /// is owned by the current thread, this could trigger deadlock.
66    fn trigger_cancellation(&mut self) {
67        let _ = self.zalsa_mut();
68    }
69
70    /// Reports that the query depends on some state unknown to salsa.
71    ///
72    /// Queries which report untracked reads will be re-executed in the next
73    /// revision.
74    fn report_untracked_read(&self) {
75        let (zalsa, zalsa_local) = self.zalsas();
76        zalsa_local.report_untracked_read(zalsa.current_revision())
77    }
78
79    /// Return the "debug name" (i.e., the struct name, etc) for an "ingredient",
80    /// which are the fine-grained components we use to track data. This is intended
81    /// for debugging and the contents of the returned string are not semver-guaranteed.
82    ///
83    /// Ingredient indices can be extracted from [`DatabaseKeyIndex`](`crate::DatabaseKeyIndex`) values.
84    fn ingredient_debug_name(&self, ingredient_index: IngredientIndex) -> Cow<'_, str> {
85        Cow::Borrowed(
86            self.zalsa()
87                .lookup_ingredient(ingredient_index)
88                .debug_name(),
89        )
90    }
91
92    /// Starts unwinding the stack if the current revision is cancelled.
93    ///
94    /// This method can be called by query implementations that perform
95    /// potentially expensive computations, in order to speed up propagation of
96    /// cancellation.
97    ///
98    /// Cancellation will automatically be triggered by salsa on any query
99    /// invocation.
100    ///
101    /// This method should not be overridden by `Database` implementors. A
102    /// `salsa_event` is emitted when this method is called, so that should be
103    /// used instead.
104    fn unwind_if_revision_cancelled(&self) {
105        let (zalsa, zalsa_local) = self.zalsas();
106        zalsa.unwind_if_revision_cancelled(zalsa_local);
107    }
108
109    /// Execute `op` with the database in thread-local storage for debug print-outs.
110    #[inline(always)]
111    fn attach<R>(&self, op: impl FnOnce(&Self) -> R) -> R
112    where
113        Self: Sized,
114    {
115        crate::attach::attach(self, || op(self))
116    }
117
118    #[cold]
119    #[inline(never)]
120    #[doc(hidden)]
121    fn zalsa_register_downcaster(&self) -> &DatabaseDownCaster<dyn Database> {
122        self.zalsa().views().downcaster_for::<dyn Database>()
123        // The no-op downcaster is special cased in view caster construction.
124    }
125
126    #[doc(hidden)]
127    #[inline(always)]
128    fn downcast(&self) -> &dyn Database
129    where
130        Self: Sized,
131    {
132        // No-op
133        self
134    }
135}
136
137/// Upcast to a `dyn Database`.
138///
139/// Only required because upcasting does not work for unsized generic parameters.
140pub trait AsDynDatabase {
141    fn as_dyn_database(&self) -> &dyn Database;
142}
143
144impl<T: Database> AsDynDatabase for T {
145    #[inline(always)]
146    fn as_dyn_database(&self) -> &dyn Database {
147        self
148    }
149}
150
151pub fn current_revision<Db: ?Sized + Database>(db: &Db) -> Revision {
152    db.zalsa().current_revision()
153}
154
#[cfg(feature = "persistence")]
mod persistence {
    use crate::plumbing::Ingredient;
    use crate::zalsa::Zalsa;
    use crate::{Database, IngredientIndex, Runtime};

    use std::fmt;

    use serde::de::{self, DeserializeSeed, SeqAccess};
    use serde::ser::SerializeMap;

    impl dyn Database {
        /// Returns a type implementing [`serde::Serialize`], that can be used to serialize the
        /// current state of the database.
        ///
        /// This takes `&mut self` so that no other borrow of the database can be alive while
        /// the returned value exists; the `unsafe` call in `SerializeIngredient` relies on it.
        pub fn as_serialize(&mut self) -> impl serde::Serialize + '_ {
            SerializeDatabase {
                runtime: self.zalsa().runtime(),
                ingredients: SerializeIngredients(self.zalsa()),
            }
        }

        /// Deserialize the database using a [`serde::Deserializer`].
        ///
        /// This method will modify the database in-place based on the serialized data.
        ///
        /// # Errors
        ///
        /// Returns the deserializer's error if the input is malformed. Ingredients are
        /// updated one at a time, so on error, ingredients processed before the failure keep
        /// their deserialized state, while the runtime is left untouched (it is only applied
        /// after all ingredients succeed). Every ingredient remains registered in the database.
        pub fn deserialize<'db, D>(&mut self, deserializer: D) -> Result<(), D::Error>
        where
            D: serde::Deserializer<'db>,
        {
            DeserializeDatabase(self.zalsa_mut()).deserialize(deserializer)
        }
    }

    /// Serialized form of the whole database: the runtime plus all serializable ingredients.
    #[derive(serde::Serialize)]
    #[serde(rename = "Database")]
    pub struct SerializeDatabase<'db> {
        pub runtime: &'db Runtime,
        pub ingredients: SerializeIngredients<'db>,
    }

    /// Serializes every ingredient that opts into serialization, as a map from ingredient
    /// index to ingredient data.
    pub struct SerializeIngredients<'db>(pub &'db Zalsa);

    impl serde::Serialize for SerializeIngredients<'_> {
        fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
        where
            S: serde::Serializer,
        {
            let SerializeIngredients(zalsa) = self;

            let mut ingredients = zalsa
                .ingredients()
                .filter(|ingredient| ingredient.should_serialize(zalsa))
                .collect::<Vec<_>>();

            // Ensure structs are serialized before tracked functions, as deserializing a
            // memo requires its input struct to have been deserialized.
            ingredients.sort_by_key(|ingredient| ingredient.jar_kind());

            let mut map = serializer.serialize_map(Some(ingredients.len()))?;
            for ingredient in ingredients {
                map.serialize_entry(
                    &ingredient.ingredient_index().as_u32(),
                    &SerializeIngredient(ingredient, zalsa),
                )?;
            }

            map.end()
        }
    }

    /// Serializes a single ingredient through its type-erased `Ingredient::serialize` hook.
    struct SerializeIngredient<'db>(&'db dyn Ingredient, &'db Zalsa);

    impl serde::Serialize for SerializeIngredient<'_> {
        fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
        where
            S: serde::Serializer,
        {
            let mut result = None;
            // The serializer is consumed by value, but the callback is an `FnMut`;
            // keep it in an `Option` so it can be moved out exactly once.
            let mut serializer = Some(serializer);

            // SAFETY: `<dyn Database>::as_serialize` takes `&mut self`, so the database is
            // exclusively borrowed for as long as this serializer can be used.
            unsafe {
                self.0.serialize(self.1, &mut |serialize| {
                    let serializer = serializer.take().expect(
                        "`Ingredient::serialize` must invoke the serialization callback only once",
                    );

                    result = Some(erased_serde::serialize(&serialize, serializer))
                })
            };

            result.expect("`Ingredient::serialize` must invoke the serialization callback")
        }
    }

    /// Field names accepted when deserializing a database from a map.
    #[derive(serde::Deserialize)]
    #[serde(field_identifier, rename_all = "lowercase")]
    enum DatabaseField {
        Runtime,
        Ingredients,
    }

    /// Deserializes a whole database in place into the wrapped `Zalsa`.
    pub struct DeserializeDatabase<'db>(pub &'db mut Zalsa);

    impl<'de> de::DeserializeSeed<'de> for DeserializeDatabase<'_> {
        type Value = ();

        fn deserialize<D>(self, deserializer: D) -> Result<Self::Value, D::Error>
        where
            D: de::Deserializer<'de>,
        {
            // Note that we have to deserialize using a manual visitor here because the
            // `Deserialize` derive does not support fields that use `DeserializeSeed`.
            deserializer.deserialize_struct("Database", &["runtime", "ingredients"], self)
        }
    }

    impl<'de> serde::de::Visitor<'de> for DeserializeDatabase<'_> {
        type Value = ();

        fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
            formatter.write_str("struct Database")
        }

        fn visit_seq<V>(self, mut seq: V) -> Result<(), V::Error>
        where
            V: SeqAccess<'de>,
        {
            let mut runtime = seq
                .next_element()?
                .ok_or_else(|| de::Error::invalid_length(0, &self))?;
            let () = seq
                .next_element_seed(DeserializeIngredients(self.0))?
                .ok_or_else(|| de::Error::invalid_length(1, &self))?;

            // The runtime is applied only after all ingredients were deserialized successfully.
            self.0.runtime_mut().deserialize_from(&mut runtime);
            Ok(())
        }

        fn visit_map<V>(self, mut map: V) -> Result<(), V::Error>
        where
            V: serde::de::MapAccess<'de>,
        {
            let mut runtime = None;
            let mut ingredients = None;

            while let Some(key) = map.next_key()? {
                match key {
                    DatabaseField::Runtime => {
                        if runtime.is_some() {
                            return Err(serde::de::Error::duplicate_field("runtime"));
                        }

                        runtime = Some(map.next_value()?);
                    }
                    DatabaseField::Ingredients => {
                        if ingredients.is_some() {
                            return Err(serde::de::Error::duplicate_field("ingredients"));
                        }

                        ingredients = Some(map.next_value_seed(DeserializeIngredients(self.0))?);
                    }
                }
            }

            let mut runtime = runtime.ok_or_else(|| serde::de::Error::missing_field("runtime"))?;
            let () = ingredients.ok_or_else(|| serde::de::Error::missing_field("ingredients"))?;

            // The runtime may appear before the ingredients in the input, but it is applied
            // only once both fields were read successfully.
            self.0.runtime_mut().deserialize_from(&mut runtime);

            Ok(())
        }
    }

    /// Deserializes a map from ingredient index to ingredient data, updating each
    /// ingredient in place.
    struct DeserializeIngredients<'db>(&'db mut Zalsa);

    impl<'de> serde::de::Visitor<'de> for DeserializeIngredients<'_> {
        type Value = ();

        fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
            formatter.write_str("a map")
        }

        fn visit_map<M>(self, mut access: M) -> Result<Self::Value, M::Error>
        where
            M: serde::de::MapAccess<'de>,
        {
            let DeserializeIngredients(zalsa) = self;

            while let Some(index) = access.next_key::<u32>()? {
                // NOTE(review): `index` comes straight from the input; confirm that
                // `take_ingredient` handles indices that do not exist.
                let index = IngredientIndex::new(index);

                // Remove the ingredient temporarily, to avoid holding an overlapping mutable borrow
                // to the ingredient as well as the database.
                let mut ingredient = zalsa.take_ingredient(index);

                // Deserialize the ingredient, but always put it back *before* propagating an
                // error: returning early here would leave the database without this ingredient.
                let result = access.next_value_seed(DeserializeIngredient(&mut *ingredient, zalsa));

                zalsa.replace_ingredient(index, ingredient);

                result?;
            }

            Ok(())
        }
    }

    impl<'de> serde::de::DeserializeSeed<'de> for DeserializeIngredients<'_> {
        type Value = ();

        fn deserialize<D>(self, deserializer: D) -> Result<Self::Value, D::Error>
        where
            D: serde::Deserializer<'de>,
        {
            deserializer.deserialize_map(self)
        }
    }

    /// Deserializes a single (temporarily detached) ingredient through its type-erased
    /// `Ingredient::deserialize` hook.
    struct DeserializeIngredient<'db>(&'db mut dyn Ingredient, &'db mut Zalsa);

    impl<'de> serde::de::DeserializeSeed<'de> for DeserializeIngredient<'_> {
        type Value = ();

        fn deserialize<D>(self, deserializer: D) -> Result<Self::Value, D::Error>
        where
            D: serde::Deserializer<'de>,
        {
            let deserializer = &mut <dyn erased_serde::Deserializer>::erase(deserializer);

            self.0
                .deserialize(self.1, deserializer)
                .map_err(serde::de::Error::custom)
        }
    }
}
388
389#[cfg(feature = "salsa_unstable")]
390pub use memory_usage::IngredientInfo;
391
392#[cfg(feature = "salsa_unstable")]
393pub(crate) use memory_usage::{MemoInfo, SlotInfo};
394
#[cfg(feature = "salsa_unstable")]
mod memory_usage {
    use hashbrown::HashMap;

    use crate::Database;

    impl dyn Database {
        /// Returns memory usage information about ingredients in the database.
        ///
        /// Every ingredient that reports memory usage contributes one [`IngredientInfo`] to
        /// [`DatabaseInfo::structs`], summarizing all of its slots. The memos attached to those
        /// slots are aggregated into [`DatabaseInfo::queries`], keyed by the memo's debug name.
        pub fn memory_usage(&self) -> DatabaseInfo {
            let mut queries = HashMap::new();
            let mut structs = Vec::new();

            for input_ingredient in self.zalsa().ingredients() {
                // Ingredients that do not report memory usage are left out entirely.
                let Some(input_info) = input_ingredient.memory_usage(self) else {
                    continue;
                };

                let mut size_of_fields = 0;
                let mut size_of_metadata = 0;
                let mut count = 0;
                let mut heap_size_of_fields = None;

                for input_slot in input_info {
                    count += 1;
                    size_of_fields += input_slot.size_of_fields;
                    size_of_metadata += input_slot.size_of_metadata;
                    accumulate_heap_size(&mut heap_size_of_fields, input_slot.heap_size_of_fields);

                    for memo in input_slot.memos {
                        // Build the default entry lazily, only on first sight of this query.
                        let info = queries.entry(memo.debug_name).or_insert_with(|| IngredientInfo {
                            debug_name: memo.output.debug_name,
                            ..Default::default()
                        });

                        info.count += 1;
                        info.size_of_fields += memo.output.size_of_fields;
                        info.size_of_metadata += memo.output.size_of_metadata;
                        accumulate_heap_size(
                            &mut info.heap_size_of_fields,
                            memo.output.heap_size_of_fields,
                        );
                    }
                }

                structs.push(IngredientInfo {
                    count,
                    size_of_fields,
                    size_of_metadata,
                    heap_size_of_fields,
                    debug_name: input_ingredient.debug_name(),
                });
            }

            DatabaseInfo { structs, queries }
        }
    }

    /// Adds `slot_heap_size` to the running heap-size `total`, if the slot reports one.
    ///
    /// `total` stays `None` until at least one slot reports a heap size, so ingredients
    /// without a `heap_size` function are distinguishable from ones using zero bytes.
    fn accumulate_heap_size(total: &mut Option<usize>, slot_heap_size: Option<usize>) {
        if let Some(slot_heap_size) = slot_heap_size {
            *total.get_or_insert(0) += slot_heap_size;
        }
    }

    /// Memory usage information about ingredients in the Salsa database.
    #[derive(Debug)]
    pub struct DatabaseInfo {
        /// Information about any Salsa structs.
        pub structs: Vec<IngredientInfo>,

        /// Memory usage information for memoized values of a given query, keyed
        /// by the query function name.
        pub queries: HashMap<&'static str, IngredientInfo>,
    }

    /// Information about instances of a particular Salsa ingredient.
    ///
    /// The derived ordering compares `debug_name` first, then the remaining fields in
    /// declaration order.
    #[derive(Default, Debug, PartialEq, Eq, PartialOrd, Ord)]
    pub struct IngredientInfo {
        debug_name: &'static str,
        count: usize,
        size_of_metadata: usize,
        size_of_fields: usize,
        heap_size_of_fields: Option<usize>,
    }

    impl IngredientInfo {
        /// Returns the debug name of the ingredient.
        pub fn debug_name(&self) -> &'static str {
            self.debug_name
        }

        /// Returns the total stack size of the fields of any instances of this ingredient, in bytes.
        pub fn size_of_fields(&self) -> usize {
            self.size_of_fields
        }

        /// Returns the total heap size of the fields of any instances of this ingredient, in bytes.
        ///
        /// Returns `None` if the ingredient doesn't specify a `heap_size` function.
        pub fn heap_size_of_fields(&self) -> Option<usize> {
            self.heap_size_of_fields
        }

        /// Returns the total size of Salsa metadata of any instances of this ingredient, in bytes.
        pub fn size_of_metadata(&self) -> usize {
            self.size_of_metadata
        }

        /// Returns the number of instances of this ingredient.
        pub fn count(&self) -> usize {
            self.count
        }
    }

    /// Memory usage information about a particular instance of struct, input or output.
    pub struct SlotInfo {
        pub(crate) debug_name: &'static str,
        pub(crate) size_of_metadata: usize,
        pub(crate) size_of_fields: usize,
        pub(crate) heap_size_of_fields: Option<usize>,
        pub(crate) memos: Vec<MemoInfo>,
    }

    /// Memory usage information about a particular memo.
    pub struct MemoInfo {
        pub(crate) debug_name: &'static str,
        pub(crate) output: SlotInfo,
    }
}