// salsa/runtime.rs

1use crate::durability::Durability;
2use crate::plumbing::CycleDetected;
3use crate::revision::{AtomicRevision, Revision};
4use crate::{CycleError, Database, DatabaseKeyIndex, Event, EventKind};
5use log::debug;
6use parking_lot::lock_api::{RawRwLock, RawRwLockRecursive};
7use parking_lot::{Mutex, RwLock};
8use rustc_hash::{FxHashMap, FxHasher};
9use smallvec::SmallVec;
10use std::hash::{BuildHasherDefault, Hash};
11use std::sync::atomic::{AtomicUsize, Ordering};
12use std::sync::Arc;
13
14pub(crate) type FxIndexSet<K> = indexmap::IndexSet<K, BuildHasherDefault<FxHasher>>;
15pub(crate) type FxIndexMap<K, V> = indexmap::IndexMap<K, V, BuildHasherDefault<FxHasher>>;
16
17mod local_state;
18use local_state::LocalState;
19
20/// The salsa runtime stores the storage for all queries as well as
21/// tracking the query stack and dependencies between cycles.
22///
23/// Each new runtime you create (e.g., via `Runtime::new` or
24/// `Runtime::default`) will have an independent set of query storage
25/// associated with it. Normally, therefore, you only do this once, at
26/// the start of your application.
27pub struct Runtime {
28    /// Our unique runtime id.
29    id: RuntimeId,
30
31    /// If this is a "forked" runtime, then the `revision_guard` will
32    /// be `Some`; this guard holds a read-lock on the global query
33    /// lock.
34    revision_guard: Option<RevisionGuard>,
35
36    /// Local state that is specific to this runtime (thread).
37    local_state: LocalState,
38
39    /// Shared state that is accessible via all runtimes.
40    shared_state: Arc<SharedState>,
41}
42
43impl Default for Runtime {
44    fn default() -> Self {
45        Runtime {
46            id: RuntimeId { counter: 0 },
47            revision_guard: None,
48            shared_state: Default::default(),
49            local_state: Default::default(),
50        }
51    }
52}
53
54impl std::fmt::Debug for Runtime {
55    fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
56        fmt.debug_struct("Runtime")
57            .field("id", &self.id())
58            .field("forked", &self.revision_guard.is_some())
59            .field("shared_state", &self.shared_state)
60            .finish()
61    }
62}
63
64impl Runtime {
65    /// Create a new runtime; equivalent to `Self::default`. This is
66    /// used when creating a new database.
67    pub fn new() -> Self {
68        Self::default()
69    }
70
71    /// See [`crate::storage::Storage::snapshot`].
72    pub(crate) fn snapshot(&self) -> Self {
73        if self.local_state.query_in_progress() {
74            panic!("it is not legal to `snapshot` during a query (see salsa-rs/salsa#80)");
75        }
76
77        let revision_guard = RevisionGuard::new(&self.shared_state);
78
79        let id = RuntimeId {
80            counter: self.shared_state.next_id.fetch_add(1, Ordering::SeqCst),
81        };
82
83        Runtime {
84            id,
85            revision_guard: Some(revision_guard),
86            shared_state: self.shared_state.clone(),
87            local_state: Default::default(),
88        }
89    }
90
91    /// A "synthetic write" causes the system to act *as though* some
92    /// input of durability `durability` has changed. This is mostly
93    /// useful for profiling scenarios, but it also has interactions
94    /// with garbage collection. In general, a synthetic write to
95    /// durability level D will cause the system to fully trace all
96    /// queries of durability level D and below. When running a GC, then:
97    ///
98    /// - Synthetic writes will cause more derived values to be
99    ///   *retained*.  This is because derived values are only
100    ///   retained if they are traced, and a synthetic write can cause
101    ///   more things to be traced.
102    /// - Synthetic writes can cause more interned values to be
103    ///   *collected*. This is because interned values can only be
104    ///   collected if they were not yet traced in the current
105    ///   revision. Therefore, if you issue a synthetic write, execute
106    ///   some query Q, and then start collecting interned values, you
107    ///   will be able to recycle interned values not used in Q.
108    ///
109    /// In general, then, one can do a "full GC" that retains only
110    /// those things that are used by some query Q by (a) doing a
111    /// synthetic write at `Durability::HIGH`, (b) executing the query
112    /// Q and then (c) doing a sweep.
113    ///
114    /// **WARNING:** Just like an ordinary write, this method triggers
115    /// cancellation. If you invoke it while a snapshot exists, it
116    /// will block until that snapshot is dropped -- if that snapshot
117    /// is owned by the current thread, this could trigger deadlock.
118    pub fn synthetic_write(&mut self, durability: Durability) {
119        self.with_incremented_revision(&mut |_next_revision| Some(durability));
120    }
121
122    /// The unique identifier attached to this `SalsaRuntime`. Each
123    /// snapshotted runtime has a distinct identifier.
124    #[inline]
125    pub fn id(&self) -> RuntimeId {
126        self.id
127    }
128
129    /// Returns the database-key for the query that this thread is
130    /// actively executing (if any).
131    pub fn active_query(&self) -> Option<DatabaseKeyIndex> {
132        self.local_state.active_query()
133    }
134
135    /// Read current value of the revision counter.
136    #[inline]
137    pub(crate) fn current_revision(&self) -> Revision {
138        self.shared_state.revisions[0].load()
139    }
140
141    /// The revision in which values with durability `d` may have last
142    /// changed.  For D0, this is just the current revision. But for
143    /// higher levels of durability, this value may lag behind the
144    /// current revision. If we encounter a value of durability Di,
145    /// then, we can check this function to get a "bound" on when the
146    /// value may have changed, which allows us to skip walking its
147    /// dependencies.
148    #[inline]
149    pub(crate) fn last_changed_revision(&self, d: Durability) -> Revision {
150        self.shared_state.revisions[d.index()].load()
151    }
152
153    /// Read current value of the revision counter.
154    #[inline]
155    fn pending_revision(&self) -> Revision {
156        self.shared_state.pending_revision.load()
157    }
158
159    /// Check if the current revision is canceled. If this method ever
160    /// returns true, the currently executing query is also marked as
161    /// having an *untracked read* -- this means that, in the next
162    /// revision, we will always recompute its value "as if" some
163    /// input had changed. This means that, if your revision is
164    /// canceled (which indicates that current query results will be
165    /// ignored) your query is free to shortcircuit and return
166    /// whatever it likes.
167    ///
168    /// This method is useful for implementing cancellation of queries.
169    /// You can do it in one of two ways, via `Result`s or via unwinding.
170    ///
171    /// The `Result` approach looks like this:
172    ///
173    ///   * Some queries invoke `is_current_revision_canceled` and
174    ///     return a special value, like `Err(Canceled)`, if it returns
175    ///     `true`.
176    ///   * Other queries propagate the special value using `?` operator.
177    ///   * API around top-level queries checks if the result is `Ok` or
178    ///     `Err(Canceled)`.
179    ///
180    /// The `panic` approach works in a similar way:
181    ///
182    ///   * Some queries invoke `is_current_revision_canceled` and
183    ///     panic with a special value, like `Canceled`, if it returns
184    ///     true.
185    ///   * The implementation of `Database` trait overrides
186    ///     `on_propagated_panic` to throw this special value as well.
187    ///     This way, panic gets propagated naturally through dependant
188    ///     queries, even across the threads.
189    ///   * API around top-level queries converts a `panic` into `Result` by
190    ///     catching the panic (using either `std::panic::catch_unwind` or
191    ///     threads) and downcasting the payload to `Canceled` (re-raising
192    ///     panic if downcast fails).
193    ///
194    /// Note that salsa is explicitly designed to be panic-safe, so cancellation
195    /// via unwinding is 100% valid approach to cancellation.
196    #[inline]
197    pub fn is_current_revision_canceled(&self) -> bool {
198        let current_revision = self.current_revision();
199        let pending_revision = self.pending_revision();
200        debug!(
201            "is_current_revision_canceled: current_revision={:?}, pending_revision={:?}",
202            current_revision, pending_revision
203        );
204        if pending_revision > current_revision {
205            self.report_untracked_read();
206            true
207        } else {
208            // Subtle: If the current revision is not canceled, we
209            // still report an **anonymous** read, which will bump up
210            // the revision number to be at least the last
211            // non-canceled revision. This is needed to ensure
212            // deterministic reads and avoid salsa-rs/salsa#66. The
213            // specific scenario we are trying to avoid is tested by
214            // `no_back_dating_in_cancellation`; it works like
215            // this. Imagine we have 3 queries, where Query3 invokes
216            // Query2 which invokes Query1. Then:
217            //
218            // - In Revision R1:
219            //   - Query1: Observes cancelation and returns sentinel S.
220            //     - Recorded inputs: Untracked, because we observed cancelation.
221            //   - Query2: Reads Query1 and propagates sentinel S.
222            //     - Recorded inputs: Query1, changed-at=R1
223            //   - Query3: Reads Query2 and propagates sentinel S. (Inputs = Query2, ChangedAt R1)
224            //     - Recorded inputs: Query2, changed-at=R1
225            // - In Revision R2:
226            //   - Query1: Observes no cancelation. All of its inputs last changed in R0,
227            //     so it returns a valid value with "changed at" of R0.
228            //     - Recorded inputs: ..., changed-at=R0
229            //   - Query2: Recomputes its value and returns correct result.
230            //     - Recorded inputs: Query1, changed-at=R0 <-- key problem!
231            //   - Query3: sees that Query2's result last changed in R0, so it thinks it
232            //     can re-use its value from R1 (which is the sentinel value).
233            //
234            // The anonymous read here prevents that scenario: Query1
235            // winds up with a changed-at setting of R2, which is the
236            // "pending revision", and hence Query2 and Query3
237            // are recomputed.
238            assert_eq!(pending_revision, current_revision);
239            self.report_anon_read(pending_revision);
240            false
241        }
242    }
243
244    /// Acquires the **global query write lock** (ensuring that no queries are
245    /// executing) and then increments the current revision counter; invokes
246    /// `op` with the global query write lock still held.
247    ///
248    /// While we wait to acquire the global query write lock, this method will
249    /// also increment `pending_revision_increments`, thus signalling to queries
250    /// that their results are "canceled" and they should abort as expeditiously
251    /// as possible.
252    ///
253    /// The `op` closure should actually perform the writes needed. It is given
254    /// the new revision as an argument, and its return value indicates whether
255    /// any pre-existing value was modified:
256    ///
257    /// - returning `None` means that no pre-existing value was modified (this
258    ///   could occur e.g. when setting some key on an input that was never set
259    ///   before)
260    /// - returning `Some(d)` indicates that a pre-existing value was modified
261    ///   and it had the durability `d`. This will update the records for when
262    ///   values with each durability were modified.
263    ///
264    /// Note that, given our writer model, we can assume that only one thread is
265    /// attempting to increment the global revision at a time.
266    pub(crate) fn with_incremented_revision(
267        &mut self,
268        op: &mut dyn FnMut(Revision) -> Option<Durability>,
269    ) {
270        log::debug!("increment_revision()");
271
272        if !self.permits_increment() {
273            panic!("increment_revision invoked during a query computation");
274        }
275
276        // Set the `pending_revision` field so that people
277        // know current revision is canceled.
278        let current_revision = self.shared_state.pending_revision.fetch_then_increment();
279
280        // To modify the revision, we need the lock.
281        let shared_state = self.shared_state.clone();
282        let _lock = shared_state.query_lock.write();
283
284        let old_revision = self.shared_state.revisions[0].fetch_then_increment();
285        assert_eq!(current_revision, old_revision);
286
287        let new_revision = current_revision.next();
288
289        debug!("increment_revision: incremented to {:?}", new_revision);
290
291        if let Some(d) = op(new_revision) {
292            for rev in &self.shared_state.revisions[1..=d.index()] {
293                rev.store(new_revision);
294            }
295        }
296    }
297
298    pub(crate) fn permits_increment(&self) -> bool {
299        self.revision_guard.is_none() && !self.local_state.query_in_progress()
300    }
301
302    pub(crate) fn execute_query_implementation<DB, V>(
303        &self,
304        db: &DB,
305        database_key_index: DatabaseKeyIndex,
306        execute: impl FnOnce() -> V,
307    ) -> ComputedQueryResult<V>
308    where
309        DB: ?Sized + Database,
310    {
311        debug!(
312            "{:?}: execute_query_implementation invoked",
313            database_key_index
314        );
315
316        db.salsa_event(Event {
317            runtime_id: self.id(),
318            kind: EventKind::WillExecute {
319                database_key: database_key_index,
320            },
321        });
322
323        // Push the active query onto the stack.
324        let max_durability = Durability::MAX;
325        let active_query = self
326            .local_state
327            .push_query(database_key_index, max_durability);
328
329        // Execute user's code, accumulating inputs etc.
330        let value = execute();
331
332        // Extract accumulated inputs.
333        let ActiveQuery {
334            dependencies,
335            changed_at,
336            durability,
337            cycle,
338            ..
339        } = active_query.complete();
340
341        ComputedQueryResult {
342            value,
343            durability,
344            changed_at,
345            dependencies,
346            cycle,
347        }
348    }
349
350    /// Reports that the currently active query read the result from
351    /// another query.
352    ///
353    /// # Parameters
354    ///
355    /// - `database_key`: the query whose result was read
356    /// - `changed_revision`: the last revision in which the result of that
357    ///   query had changed
358    pub(crate) fn report_query_read<'hack>(
359        &self,
360        input: DatabaseKeyIndex,
361        durability: Durability,
362        changed_at: Revision,
363    ) {
364        self.local_state
365            .report_query_read(input, durability, changed_at);
366    }
367
368    /// Reports that the query depends on some state unknown to salsa.
369    ///
370    /// Queries which report untracked reads will be re-executed in the next
371    /// revision.
372    pub fn report_untracked_read(&self) {
373        self.local_state
374            .report_untracked_read(self.current_revision());
375    }
376
377    /// Acts as though the current query had read an input with the given durability; this will force the current query's durability to be at most `durability`.
378    ///
379    /// This is mostly useful to control the durability level for [on-demand inputs](https://salsa-rs.github.io/salsa/common_patterns/on_demand_inputs.html).
380    pub fn report_synthetic_read(&self, durability: Durability) {
381        self.local_state.report_synthetic_read(durability);
382    }
383
384    /// An "anonymous" read is a read that doesn't come from executing
385    /// a query, but from some other internal operation. It just
386    /// modifies the "changed at" to be at least the given revision.
387    /// (It also does not disqualify a query from being considered
388    /// constant, since it is used for queries that don't give back
389    /// actual *data*.)
390    ///
391    /// This is used when queries check if they have been canceled.
392    fn report_anon_read(&self, revision: Revision) {
393        self.local_state.report_anon_read(revision)
394    }
395
396    /// Obviously, this should be user configurable at some point.
397    pub(crate) fn report_unexpected_cycle(
398        &self,
399        database_key_index: DatabaseKeyIndex,
400        error: CycleDetected,
401        changed_at: Revision,
402    ) -> crate::CycleError<DatabaseKeyIndex> {
403        debug!(
404            "report_unexpected_cycle(database_key={:?})",
405            database_key_index
406        );
407
408        let mut query_stack = self.local_state.borrow_query_stack_mut();
409
410        if error.from == error.to {
411            // All queries in the cycle is local
412            let start_index = query_stack
413                .iter()
414                .rposition(|active_query| active_query.database_key_index == database_key_index)
415                .unwrap();
416            let mut cycle = Vec::new();
417            let cycle_participants = &mut query_stack[start_index..];
418            for active_query in &mut *cycle_participants {
419                cycle.push(active_query.database_key_index);
420            }
421
422            assert!(!cycle.is_empty());
423
424            for active_query in cycle_participants {
425                active_query.cycle = cycle.clone();
426            }
427
428            crate::CycleError {
429                cycle,
430                changed_at,
431                durability: Durability::MAX,
432            }
433        } else {
434            // Part of the cycle is on another thread so we need to lock and inspect the shared
435            // state
436            let dependency_graph = self.shared_state.dependency_graph.lock();
437
438            let mut cycle = Vec::new();
439            dependency_graph.push_cycle_path(
440                database_key_index,
441                error.to,
442                query_stack.iter().map(|query| query.database_key_index),
443                &mut cycle,
444            );
445            cycle.push(database_key_index);
446
447            assert!(!cycle.is_empty());
448
449            for active_query in query_stack
450                .iter_mut()
451                .filter(|query| cycle.iter().any(|key| *key == query.database_key_index))
452            {
453                active_query.cycle = cycle.clone();
454            }
455
456            crate::CycleError {
457                cycle,
458                changed_at,
459                durability: Durability::MAX,
460            }
461        }
462    }
463
464    pub(crate) fn mark_cycle_participants(&self, err: &CycleError<DatabaseKeyIndex>) {
465        for active_query in self
466            .local_state
467            .borrow_query_stack_mut()
468            .iter_mut()
469            .rev()
470            .take_while(|active_query| {
471                err.cycle
472                    .iter()
473                    .any(|e| *e == active_query.database_key_index)
474            })
475        {
476            active_query.cycle = err.cycle.clone();
477        }
478    }
479
480    /// Try to make this runtime blocked on `other_id`. Returns true
481    /// upon success or false if `other_id` is already blocked on us.
482    pub(crate) fn try_block_on(&self, database_key: DatabaseKeyIndex, other_id: RuntimeId) -> bool {
483        self.shared_state.dependency_graph.lock().add_edge(
484            self.id(),
485            database_key,
486            other_id,
487            self.local_state
488                .borrow_query_stack()
489                .iter()
490                .map(|query| query.database_key_index),
491        )
492    }
493
494    pub(crate) fn unblock_queries_blocked_on_self(&self, database_key_index: DatabaseKeyIndex) {
495        self.shared_state
496            .dependency_graph
497            .lock()
498            .remove_edge(database_key_index, self.id())
499    }
500}
501
502/// State that will be common to all threads (when we support multiple threads)
503struct SharedState {
504    /// Stores the next id to use for a snapshotted runtime (starts at 1).
505    next_id: AtomicUsize,
506
507    /// Whenever derived queries are executing, they acquire this lock
508    /// in read mode. Mutating inputs (and thus creating a new
509    /// revision) requires a write lock (thus guaranteeing that no
510    /// derived queries are in progress). Note that this is not needed
511    /// to prevent **race conditions** -- the revision counter itself
512    /// is stored in an `AtomicUsize` so it can be cheaply read
513    /// without acquiring the lock.  Rather, the `query_lock` is used
514    /// to ensure a higher-level consistency property.
515    query_lock: RwLock<()>,
516
517    /// This is typically equal to `revision` -- set to `revision+1`
518    /// when a new revision is pending (which implies that the current
519    /// revision is canceled).
520    pending_revision: AtomicRevision,
521
522    /// Stores the "last change" revision for values of each duration.
523    /// This vector is always of length at least 1 (for Durability 0)
524    /// but its total length depends on the number of durations. The
525    /// element at index 0 is special as it represents the "current
526    /// revision".  In general, we have the invariant that revisions
527    /// in here are *declining* -- that is, `revisions[i] >=
528    /// revisions[i + 1]`, for all `i`. This is because when you
529    /// modify a value with durability D, that implies that values
530    /// with durability less than D may have changed too.
531    revisions: Vec<AtomicRevision>,
532
533    /// The dependency graph tracks which runtimes are blocked on one
534    /// another, waiting for queries to terminate.
535    dependency_graph: Mutex<DependencyGraph<DatabaseKeyIndex>>,
536}
537
538impl SharedState {
539    fn with_durabilities(durabilities: usize) -> Self {
540        SharedState {
541            next_id: AtomicUsize::new(1),
542            query_lock: Default::default(),
543            revisions: (0..durabilities).map(|_| AtomicRevision::start()).collect(),
544            pending_revision: AtomicRevision::start(),
545            dependency_graph: Default::default(),
546        }
547    }
548}
549
550impl std::panic::RefUnwindSafe for SharedState {}
551
552impl Default for SharedState {
553    fn default() -> Self {
554        Self::with_durabilities(Durability::LEN)
555    }
556}
557
558impl std::fmt::Debug for SharedState {
559    fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
560        let query_lock = if self.query_lock.try_write().is_some() {
561            "<unlocked>"
562        } else if self.query_lock.try_read().is_some() {
563            "<rlocked>"
564        } else {
565            "<wlocked>"
566        };
567        fmt.debug_struct("SharedState")
568            .field("query_lock", &query_lock)
569            .field("revisions", &self.revisions)
570            .field("pending_revision", &self.pending_revision)
571            .finish()
572    }
573}
574
575struct ActiveQuery {
576    /// What query is executing
577    database_key_index: DatabaseKeyIndex,
578
579    /// Minimum durability of inputs observed so far.
580    durability: Durability,
581
582    /// Maximum revision of all inputs observed. If we observe an
583    /// untracked read, this will be set to the most recent revision.
584    changed_at: Revision,
585
586    /// Set of subqueries that were accessed thus far, or `None` if
587    /// there was an untracked the read.
588    dependencies: Option<FxIndexSet<DatabaseKeyIndex>>,
589
590    /// Stores the entire cycle, if one is found and this query is part of it.
591    cycle: Vec<DatabaseKeyIndex>,
592}
593
594pub(crate) struct ComputedQueryResult<V> {
595    /// Final value produced
596    pub(crate) value: V,
597
598    /// Minimum durability of inputs observed so far.
599    pub(crate) durability: Durability,
600
601    /// Maximum revision of all inputs observed. If we observe an
602    /// untracked read, this will be set to the most recent revision.
603    pub(crate) changed_at: Revision,
604
605    /// Complete set of subqueries that were accessed, or `None` if
606    /// there was an untracked read.
607    pub(crate) dependencies: Option<FxIndexSet<DatabaseKeyIndex>>,
608
609    /// The cycle if one occured while computing this value
610    pub(crate) cycle: Vec<DatabaseKeyIndex>,
611}
612
613impl ActiveQuery {
614    fn new(database_key_index: DatabaseKeyIndex, max_durability: Durability) -> Self {
615        ActiveQuery {
616            database_key_index,
617            durability: max_durability,
618            changed_at: Revision::start(),
619            dependencies: Some(FxIndexSet::default()),
620            cycle: Vec::new(),
621        }
622    }
623
624    fn add_read(&mut self, input: DatabaseKeyIndex, durability: Durability, revision: Revision) {
625        if let Some(set) = &mut self.dependencies {
626            set.insert(input);
627        }
628
629        self.durability = self.durability.min(durability);
630        self.changed_at = self.changed_at.max(revision);
631    }
632
633    fn add_untracked_read(&mut self, changed_at: Revision) {
634        self.dependencies = None;
635        self.durability = Durability::LOW;
636        self.changed_at = changed_at;
637    }
638
639    fn add_synthetic_read(&mut self, durability: Durability) {
640        self.durability = self.durability.min(durability);
641    }
642
643    fn add_anon_read(&mut self, changed_at: Revision) {
644        self.changed_at = self.changed_at.max(changed_at);
645    }
646}
647
/// Identifies one runtime among those sharing the same state. The root
/// runtime has counter 0, and every snapshot is handed a fresh counter
/// when it is created. Once a snapshot is complete, its `RuntimeId` may
/// potentially be re-used.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct RuntimeId {
    /// Numeric id drawn from `SharedState::next_id` (0 for the root).
    counter: usize,
}
655
656#[derive(Clone, Debug)]
657pub(crate) struct StampedValue<V> {
658    pub(crate) value: V,
659    pub(crate) durability: Durability,
660    pub(crate) changed_at: Revision,
661}
662
663#[derive(Debug)]
664struct Edge<K> {
665    id: RuntimeId,
666    path: Vec<K>,
667}
668
669#[derive(Debug)]
670struct DependencyGraph<K: Hash + Eq> {
671    /// A `(K -> V)` pair in this map indicates that the the runtime
672    /// `K` is blocked on some query executing in the runtime `V`.
673    /// This encodes a graph that must be acyclic (or else deadlock
674    /// will result).
675    edges: FxHashMap<RuntimeId, Edge<K>>,
676    labels: FxHashMap<K, SmallVec<[RuntimeId; 4]>>,
677}
678
679impl<K> Default for DependencyGraph<K>
680where
681    K: Hash + Eq,
682{
683    fn default() -> Self {
684        DependencyGraph {
685            edges: Default::default(),
686            labels: Default::default(),
687        }
688    }
689}
690
691impl<K> DependencyGraph<K>
692where
693    K: Hash + Eq + Clone,
694{
695    /// Attempt to add an edge `from_id -> to_id` into the result graph.
696    fn add_edge(
697        &mut self,
698        from_id: RuntimeId,
699        database_key: K,
700        to_id: RuntimeId,
701        path: impl IntoIterator<Item = K>,
702    ) -> bool {
703        assert_ne!(from_id, to_id);
704        debug_assert!(!self.edges.contains_key(&from_id));
705
706        // First: walk the chain of things that `to_id` depends on,
707        // looking for us.
708        let mut p = to_id;
709        while let Some(q) = self.edges.get(&p).map(|edge| edge.id) {
710            if q == from_id {
711                return false;
712            }
713
714            p = q;
715        }
716
717        self.edges.insert(
718            from_id,
719            Edge {
720                id: to_id,
721                path: path.into_iter().chain(Some(database_key.clone())).collect(),
722            },
723        );
724        self.labels
725            .entry(database_key.clone())
726            .or_default()
727            .push(from_id);
728        true
729    }
730
731    fn remove_edge(&mut self, database_key: K, to_id: RuntimeId) {
732        let vec = self.labels.remove(&database_key).unwrap_or_default();
733
734        for from_id in &vec {
735            let to_id1 = self.edges.remove(from_id).map(|edge| edge.id);
736            assert_eq!(Some(to_id), to_id1);
737        }
738    }
739
740    fn push_cycle_path<'a>(
741        &'a self,
742        database_key: K,
743        to: RuntimeId,
744        local_path: impl IntoIterator<Item = K>,
745        output: &mut Vec<K>,
746    ) where
747        K: std::fmt::Debug,
748    {
749        let mut current = Some((to, std::slice::from_ref(&database_key)));
750        let mut last = None;
751        let mut local_path = Some(local_path);
752
753        loop {
754            match current.take() {
755                Some((id, path)) => {
756                    let link_key = path.last().unwrap();
757
758                    output.extend(path.iter().cloned());
759
760                    current = self.edges.get(&id).map(|edge| {
761                        let i = edge.path.iter().rposition(|p| p == link_key).unwrap();
762                        (edge.id, &edge.path[i + 1..])
763                    });
764
765                    if current.is_none() {
766                        last = local_path.take().map(|local_path| {
767                            local_path
768                                .into_iter()
769                                .skip_while(move |p| *p != *link_key)
770                                .skip(1)
771                        });
772                    }
773                }
774                None => break,
775            }
776        }
777
778        if let Some(iter) = &mut last {
779            output.extend(iter);
780        }
781    }
782}
783
784struct RevisionGuard {
785    shared_state: Arc<SharedState>,
786}
787
788impl RevisionGuard {
789    fn new(shared_state: &Arc<SharedState>) -> Self {
790        // Subtle: we use a "recursive" lock here so that it is not an
791        // error to acquire a read-lock when one is already held (this
792        // happens when a query uses `snapshot` to spawn off parallel
793        // workers, for example).
794        //
795        // This has the side-effect that we are responsible to ensure
796        // that people contending for the write lock do not starve,
797        // but this is what we achieve via the cancellation mechanism.
798        //
799        // (In particular, since we only ever have one "mutating
800        // handle" to the database, the only contention for the global
801        // query lock occurs when there are "futures" evaluating
802        // queries in parallel, and those futures hold a read-lock
803        // already, so the starvation problem is more about them bring
804        // themselves to a close, versus preventing other people from
805        // *starting* work).
806        unsafe {
807            shared_state.query_lock.raw().lock_shared_recursive();
808        }
809
810        Self {
811            shared_state: shared_state.clone(),
812        }
813    }
814}
815
816impl Drop for RevisionGuard {
817    fn drop(&mut self) {
818        // Release our read-lock without using RAII. As documented in
819        // `Snapshot::new` above, this requires the unsafe keyword.
820        unsafe {
821            self.shared_state.query_lock.raw().unlock_shared();
822        }
823    }
824}
825
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn dependency_graph_path1() {
        let mut graph = DependencyGraph::default();
        let a = RuntimeId { counter: 0 };
        let b = RuntimeId { counter: 1 };
        // `a` (stack [1]) is blocked on query 2, executing in `b`.
        assert!(graph.add_edge(a, 2, b, vec![1]));
        // `b` (stack [3, 2]) now wants query 1, held by `a`: that edge would
        // close the cycle, so it must be rejected.
        assert!(!graph.add_edge(b, 1, a, vec![3, 2]));
        let mut v = vec![];
        graph.push_cycle_path(1, a, vec![3, 2], &mut v);
        assert_eq!(v, vec![1, 2]);
    }

    #[test]
    fn dependency_graph_path2() {
        let mut graph = DependencyGraph::default();
        let a = RuntimeId { counter: 0 };
        let b = RuntimeId { counter: 1 };
        let c = RuntimeId { counter: 2 };
        assert!(graph.add_edge(a, 3, b, vec![1]));
        assert!(graph.add_edge(b, 4, c, vec![2, 3]));
        // `c` waiting on query 1 (held by `a`) would close a -> b -> c -> a.
        assert!(!graph.add_edge(c, 1, a, vec![5, 6, 4, 7]));
        let mut v = vec![];
        graph.push_cycle_path(1, a, vec![5, 6, 4, 7], &mut v);
        assert_eq!(v, vec![1, 3, 4, 7]);
    }

    #[test]
    fn dependency_graph_remove_edge_unblocks() {
        let mut graph = DependencyGraph::default();
        let a = RuntimeId { counter: 0 };
        let b = RuntimeId { counter: 1 };
        assert!(graph.add_edge(a, 2, b, vec![1]));
        // Query 2 completes in `b`, releasing `a`.
        graph.remove_edge(2, b);
        // With that edge gone, `b` may block on `a` without forming a cycle.
        assert!(graph.add_edge(b, 1, a, vec![3]));
    }
}