Skip to main content

salsa/
database.rs

1use std::borrow::Cow;
2use std::ptr::NonNull;
3
4use crate::views::DatabaseDownCaster;
5use crate::zalsa::{IngredientIndex, ZalsaDatabase};
6use crate::zalsa_local::CancellationToken;
7use crate::{Durability, Revision};
8
9#[derive(Copy, Clone)]
10pub struct RawDatabase<'db> {
11    pub(crate) ptr: NonNull<()>,
12    _marker: std::marker::PhantomData<&'db dyn Database>,
13}
14
15impl<'db, Db: Database + ?Sized> From<&'db Db> for RawDatabase<'db> {
16    #[inline]
17    fn from(db: &'db Db) -> Self {
18        RawDatabase {
19            ptr: NonNull::from(db).cast(),
20            _marker: std::marker::PhantomData,
21        }
22    }
23}
24
25impl<'db, Db: Database + ?Sized> From<&'db mut Db> for RawDatabase<'db> {
26    #[inline]
27    fn from(db: &'db mut Db) -> Self {
28        RawDatabase {
29            ptr: NonNull::from(db).cast(),
30            _marker: std::marker::PhantomData,
31        }
32    }
33}
34
35/// The trait implemented by all Salsa databases.
36/// You can create your own subtraits of this trait using the `#[salsa::db]`(`crate::db`) procedural macro.
37pub trait Database: Send + ZalsaDatabase + AsDynDatabase {
38    /// Enforces current LRU limits, evicting entries if necessary.
39    ///
40    /// **WARNING:** Just like an ordinary write, this method triggers
41    /// cancellation. If you invoke it while a snapshot exists, it
42    /// will block until that snapshot is dropped -- if that snapshot
43    /// is owned by the current thread, this could trigger deadlock.
44    fn trigger_lru_eviction(&mut self) {
45        let zalsa_mut = self.zalsa_mut();
46        zalsa_mut.evict_lru();
47    }
48
49    /// A "synthetic write" causes the system to act *as though* some
50    /// input of durability `durability` has changed, triggering a new revision.
51    /// This is mostly useful for profiling scenarios.
52    ///
53    /// **WARNING:** Just like an ordinary write, this method triggers
54    /// cancellation. If you invoke it while a snapshot exists, it
55    /// will block until that snapshot is dropped -- if that snapshot
56    /// is owned by the current thread, this could trigger deadlock.
57    fn synthetic_write(&mut self, durability: Durability) {
58        let zalsa_mut = self.zalsa_mut();
59        zalsa_mut.new_revision();
60        zalsa_mut.runtime_mut().report_tracked_write(durability);
61    }
62
63    /// This method cancels all outstanding computations.
64    /// If you invoke it while a snapshot exists, it
65    /// will block until that snapshot is dropped -- if that snapshot
66    /// is owned by the current thread, this could trigger deadlock.
67    fn trigger_cancellation(&mut self) {
68        let _ = self.zalsa_mut();
69    }
70
71    /// Retrieves a [`CancellationToken`] for the current database handle.
72    fn cancellation_token(&self) -> CancellationToken {
73        self.zalsa_local().cancellation_token()
74    }
75
76    /// Reports that the query depends on some state unknown to salsa.
77    ///
78    /// Queries which report untracked reads will be re-executed in the next
79    /// revision.
80    fn report_untracked_read(&self) {
81        let (zalsa, zalsa_local) = self.zalsas();
82        zalsa_local.report_untracked_read(zalsa.current_revision())
83    }
84
85    /// Return the "debug name" (i.e., the struct name, etc) for an "ingredient",
86    /// which are the fine-grained components we use to track data. This is intended
87    /// for debugging and the contents of the returned string are not semver-guaranteed.
88    ///
89    /// Ingredient indices can be extracted from [`DatabaseKeyIndex`](`crate::DatabaseKeyIndex`) values.
90    fn ingredient_debug_name(&self, ingredient_index: IngredientIndex) -> Cow<'_, str> {
91        Cow::Borrowed(
92            self.zalsa()
93                .lookup_ingredient(ingredient_index)
94                .debug_name(),
95        )
96    }
97
98    /// Starts unwinding the stack if the current revision is cancelled.
99    ///
100    /// This method can be called by query implementations that perform
101    /// potentially expensive computations, in order to speed up propagation of
102    /// cancellation.
103    ///
104    /// Cancellation will automatically be triggered by salsa on any query
105    /// invocation.
106    ///
107    /// This method should not be overridden by `Database` implementors. A
108    /// `salsa_event` is emitted when this method is called, so that should be
109    /// used instead.
110    fn unwind_if_revision_cancelled(&self) {
111        let (zalsa, zalsa_local) = self.zalsas();
112        zalsa.unwind_if_revision_cancelled(zalsa_local);
113    }
114
115    /// Execute `op` with the database in thread-local storage for debug print-outs.
116    #[inline(always)]
117    fn attach<R>(&self, op: impl FnOnce(&Self) -> R) -> R
118    where
119        Self: Sized,
120    {
121        crate::attach::attach(self, || op(self))
122    }
123
124    #[cold]
125    #[inline(never)]
126    #[doc(hidden)]
127    fn zalsa_register_downcaster(&self) -> &DatabaseDownCaster<dyn Database> {
128        self.zalsa().views().downcaster_for::<dyn Database>()
129        // The no-op downcaster is special cased in view caster construction.
130    }
131
132    #[doc(hidden)]
133    #[inline(always)]
134    fn downcast(&self) -> &dyn Database
135    where
136        Self: Sized,
137    {
138        // No-op
139        self
140    }
141}
142
143/// Upcast to a `dyn Database`.
144///
145/// Only required because upcasting does not work for unsized generic parameters.
146pub trait AsDynDatabase {
147    fn as_dyn_database(&self) -> &dyn Database;
148}
149
150impl<T: Database> AsDynDatabase for T {
151    #[inline(always)]
152    fn as_dyn_database(&self) -> &dyn Database {
153        self
154    }
155}
156
157pub fn current_revision<Db: ?Sized + Database>(db: &Db) -> Revision {
158    db.zalsa().current_revision()
159}
160
#[cfg(feature = "persistence")]
mod persistence {
    use crate::plumbing::Ingredient;
    use crate::zalsa::Zalsa;
    use crate::{Database, IngredientIndex, Runtime};

    use std::fmt;

    use serde::de::{self, DeserializeSeed, SeqAccess};
    use serde::ser::SerializeMap;

    impl dyn Database {
        /// Returns a type implementing [`serde::Serialize`], that can be used to serialize the
        /// current state of the database.
        ///
        /// This takes `&mut self` so the database stays exclusively borrowed for as long as the
        /// returned value is alive. The `unsafe` ingredient serialization relies on that.
        pub fn as_serialize(&mut self) -> impl serde::Serialize + '_ {
            SerializeDatabase {
                runtime: self.zalsa().runtime(),
                ingredients: SerializeIngredients(self.zalsa()),
            }
        }

        /// Deserialize the database using a [`serde::Deserializer`].
        ///
        /// This method will modify the database in-place based on the serialized data.
        /// Ingredients are deserialized first; the runtime state is applied afterwards.
        ///
        /// # Errors
        ///
        /// Returns the deserializer's error if the input is malformed. In that case the
        /// database may be left partially updated. Ingredients deserialized before the error
        /// keep their new state, and the runtime state is not applied.
        pub fn deserialize<'db, D>(&mut self, deserializer: D) -> Result<(), D::Error>
        where
            D: serde::Deserializer<'db>,
        {
            DeserializeDatabase(self.zalsa_mut()).deserialize(deserializer)
        }
    }

    /// Serializable view of a database: its runtime plus its serializable ingredients.
    #[derive(serde::Serialize)]
    #[serde(rename = "Database")]
    pub struct SerializeDatabase<'db> {
        pub runtime: &'db Runtime,
        pub ingredients: SerializeIngredients<'db>,
    }

    /// Serializes the ingredients of a [`Zalsa`] as a map from ingredient index to ingredient.
    pub struct SerializeIngredients<'db>(pub &'db Zalsa);

    impl serde::Serialize for SerializeIngredients<'_> {
        fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
        where
            S: serde::Serializer,
        {
            let SerializeIngredients(zalsa) = self;

            let mut ingredients = zalsa
                .ingredients()
                .filter(|ingredient| ingredient.should_serialize(zalsa))
                .collect::<Vec<_>>();

            // Ensure structs are serialized before tracked functions, as deserializing a
            // memo requires its input struct to have been deserialized. The sort is stable, so
            // ingredients of the same kind keep their original relative order.
            ingredients.sort_by_key(|ingredient| ingredient.jar_kind());

            let mut map = serializer.serialize_map(Some(ingredients.len()))?;
            for ingredient in ingredients {
                map.serialize_entry(
                    &ingredient.ingredient_index().as_u32(),
                    &SerializeIngredient(ingredient, zalsa),
                )?;
            }

            map.end()
        }
    }

    /// Serializes a single ingredient through its type-erased `Ingredient::serialize` hook.
    struct SerializeIngredient<'db>(&'db dyn Ingredient, &'db Zalsa);

    impl serde::Serialize for SerializeIngredient<'_> {
        fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
        where
            S: serde::Serializer,
        {
            let mut result = None;
            let mut serializer = Some(serializer);

            // SAFETY: `<dyn Database>::as_serialize` takes `&mut self`, so the database is
            // exclusively borrowed while this serialization runs.
            unsafe {
                self.0.serialize(self.1, &mut |serialize| {
                    let serializer = serializer.take().expect(
                        "`Ingredient::serialize` must invoke the serialization callback only once",
                    );

                    result = Some(erased_serde::serialize(&serialize, serializer))
                })
            };

            result.expect("`Ingredient::serialize` must invoke the serialization callback")
        }
    }

    /// Field names of the serialized `Database` struct.
    #[derive(serde::Deserialize)]
    #[serde(field_identifier, rename_all = "lowercase")]
    enum DatabaseField {
        Runtime,
        Ingredients,
    }

    /// Deserializes a serialized `Database` struct in place into the given [`Zalsa`].
    pub struct DeserializeDatabase<'db>(pub &'db mut Zalsa);

    impl<'de> de::DeserializeSeed<'de> for DeserializeDatabase<'_> {
        type Value = ();

        fn deserialize<D>(self, deserializer: D) -> Result<Self::Value, D::Error>
        where
            D: de::Deserializer<'de>,
        {
            // Note that we have to deserialize using a manual visitor here because the
            // `Deserialize` derive does not support fields that use `DeserializeSeed`.
            deserializer.deserialize_struct("Database", &["runtime", "ingredients"], self)
        }
    }

    impl<'de> serde::de::Visitor<'de> for DeserializeDatabase<'_> {
        type Value = ();

        fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
            formatter.write_str("struct Database")
        }

        fn visit_seq<V>(self, mut seq: V) -> Result<(), V::Error>
        where
            V: SeqAccess<'de>,
        {
            let mut runtime = seq
                .next_element()?
                .ok_or_else(|| de::Error::invalid_length(0, &self))?;
            let () = seq
                .next_element_seed(DeserializeIngredients(self.0))?
                .ok_or_else(|| de::Error::invalid_length(1, &self))?;

            // The runtime state is only applied once all ingredients were read successfully.
            self.0.runtime_mut().deserialize_from(&mut runtime);
            Ok(())
        }

        fn visit_map<V>(self, mut map: V) -> Result<(), V::Error>
        where
            V: serde::de::MapAccess<'de>,
        {
            let mut runtime = None;
            let mut ingredients = None;

            while let Some(key) = map.next_key()? {
                match key {
                    DatabaseField::Runtime => {
                        if runtime.is_some() {
                            return Err(serde::de::Error::duplicate_field("runtime"));
                        }

                        runtime = Some(map.next_value()?);
                    }
                    DatabaseField::Ingredients => {
                        if ingredients.is_some() {
                            return Err(serde::de::Error::duplicate_field("ingredients"));
                        }

                        ingredients = Some(map.next_value_seed(DeserializeIngredients(self.0))?);
                    }
                }
            }

            let mut runtime = runtime.ok_or_else(|| serde::de::Error::missing_field("runtime"))?;
            let () = ingredients.ok_or_else(|| serde::de::Error::missing_field("ingredients"))?;

            // The runtime state is only applied once all ingredients were read successfully.
            self.0.runtime_mut().deserialize_from(&mut runtime);

            Ok(())
        }
    }

    /// Deserializes a map of ingredient index -> ingredient state in place into a [`Zalsa`].
    struct DeserializeIngredients<'db>(&'db mut Zalsa);

    impl<'de> serde::de::Visitor<'de> for DeserializeIngredients<'_> {
        type Value = ();

        fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
            formatter.write_str("a map")
        }

        fn visit_map<M>(self, mut access: M) -> Result<Self::Value, M::Error>
        where
            M: serde::de::MapAccess<'de>,
        {
            let DeserializeIngredients(zalsa) = self;

            while let Some(index) = access.next_key::<u32>()? {
                // NOTE(review): `index` comes straight from the serialized input; confirm that
                // `take_ingredient` copes with indices that do not exist in this database.
                let index = IngredientIndex::new(index);

                // Remove the ingredient temporarily, to avoid holding an overlapping mutable borrow
                // to the ingredient as well as the database.
                let mut ingredient = zalsa.take_ingredient(index);

                // Deserialize the ingredient, then put it back *before* propagating any error.
                // Returning early with `?` here would leave the ingredient slot empty.
                let result =
                    access.next_value_seed(DeserializeIngredient(&mut *ingredient, zalsa));
                zalsa.replace_ingredient(index, ingredient);
                result?;
            }

            Ok(())
        }
    }

    impl<'de> serde::de::DeserializeSeed<'de> for DeserializeIngredients<'_> {
        type Value = ();

        fn deserialize<D>(self, deserializer: D) -> Result<Self::Value, D::Error>
        where
            D: serde::Deserializer<'de>,
        {
            deserializer.deserialize_map(self)
        }
    }

    /// Deserializes one ingredient's state through its type-erased `Ingredient::deserialize`.
    struct DeserializeIngredient<'db>(&'db mut dyn Ingredient, &'db mut Zalsa);

    impl<'de> serde::de::DeserializeSeed<'de> for DeserializeIngredient<'_> {
        type Value = ();

        fn deserialize<D>(self, deserializer: D) -> Result<Self::Value, D::Error>
        where
            D: serde::Deserializer<'de>,
        {
            let deserializer = &mut <dyn erased_serde::Deserializer>::erase(deserializer);

            // Errors from the erased deserializer are converted back into `D::Error`.
            self.0
                .deserialize(self.1, deserializer)
                .map_err(serde::de::Error::custom)
        }
    }
}
394
395#[cfg(feature = "salsa_unstable")]
396pub use memory_usage::IngredientInfo;
397
398#[cfg(feature = "salsa_unstable")]
399pub(crate) use memory_usage::{MemoInfo, SlotInfo};
400
#[cfg(feature = "salsa_unstable")]
mod memory_usage {
    use hashbrown::HashMap;

    use crate::Database;

    impl dyn Database {
        /// Returns memory usage information about ingredients in the database.
        ///
        /// Every ingredient whose `memory_usage` returns `Some` contributes one entry to
        /// [`DatabaseInfo::structs`], summarizing its slots. The memos attached to those slots
        /// are aggregated into [`DatabaseInfo::queries`], keyed by the memo's debug name.
        pub fn memory_usage(&self) -> DatabaseInfo {
            let mut queries = HashMap::new();
            let mut structs = Vec::new();

            for input_ingredient in self.zalsa().ingredients() {
                // Ingredients that do not report memory usage are skipped entirely.
                let Some(input_info) = input_ingredient.memory_usage(self) else {
                    continue;
                };

                let mut size_of_fields = 0;
                let mut size_of_metadata = 0;
                let mut count = 0;
                // Stays `None` unless at least one slot reports a heap size.
                let mut heap_size_of_fields: Option<usize> = None;

                for input_slot in input_info {
                    count += 1;
                    size_of_fields += input_slot.size_of_fields;
                    size_of_metadata += input_slot.size_of_metadata;

                    if let Some(slot_heap_size) = input_slot.heap_size_of_fields {
                        *heap_size_of_fields.get_or_insert(0) += slot_heap_size;
                    }

                    for memo in input_slot.memos {
                        // Entries are keyed by `memo.debug_name`; the aggregated info records
                        // the debug name of the memo's output.
                        let info = queries.entry(memo.debug_name).or_insert(IngredientInfo {
                            debug_name: memo.output.debug_name,
                            ..Default::default()
                        });

                        info.count += 1;
                        info.size_of_fields += memo.output.size_of_fields;
                        info.size_of_metadata += memo.output.size_of_metadata;

                        if let Some(memo_heap_size) = memo.output.heap_size_of_fields {
                            *info.heap_size_of_fields.get_or_insert(0) += memo_heap_size;
                        }
                    }
                }

                structs.push(IngredientInfo {
                    count,
                    size_of_fields,
                    size_of_metadata,
                    heap_size_of_fields,
                    debug_name: input_ingredient.debug_name(),
                });
            }

            DatabaseInfo { structs, queries }
        }
    }

    /// Memory usage information about ingredients in the Salsa database.
    #[derive(Debug)]
    pub struct DatabaseInfo {
        /// Information about any Salsa structs.
        pub structs: Vec<IngredientInfo>,

        /// Memory usage information for memoized values of a given query, keyed
        /// by the query function name.
        pub queries: HashMap<&'static str, IngredientInfo>,
    }

    /// Information about instances of a particular Salsa ingredient.
    #[derive(Default, Debug, PartialEq, Eq, PartialOrd, Ord)]
    pub struct IngredientInfo {
        debug_name: &'static str,
        count: usize,
        size_of_metadata: usize,
        size_of_fields: usize,
        heap_size_of_fields: Option<usize>,
    }

    impl IngredientInfo {
        /// Returns the debug name of the ingredient.
        pub fn debug_name(&self) -> &'static str {
            self.debug_name
        }

        /// Returns the total stack size of the fields of any instances of this ingredient, in bytes.
        pub fn size_of_fields(&self) -> usize {
            self.size_of_fields
        }

        /// Returns the total heap size of the fields of any instances of this ingredient, in bytes.
        ///
        /// Returns `None` if the ingredient doesn't specify a `heap_size` function.
        pub fn heap_size_of_fields(&self) -> Option<usize> {
            self.heap_size_of_fields
        }

        /// Returns the total size of Salsa metadata of any instances of this ingredient, in bytes.
        pub fn size_of_metadata(&self) -> usize {
            self.size_of_metadata
        }

        /// Returns the number of instances of this ingredient.
        pub fn count(&self) -> usize {
            self.count
        }
    }

    /// Memory usage information about a particular instance of struct, input or output.
    pub struct SlotInfo {
        pub(crate) debug_name: &'static str,
        pub(crate) size_of_metadata: usize,
        pub(crate) size_of_fields: usize,
        pub(crate) heap_size_of_fields: Option<usize>,
        pub(crate) memos: Vec<MemoInfo>,
    }

    /// Memory usage information about a particular memo.
    pub struct MemoInfo {
        pub(crate) debug_name: &'static str,
        pub(crate) output: SlotInfo,
    }
}