salsa/runtime.rs
1use crate::durability::Durability;
2use crate::plumbing::CycleDetected;
3use crate::revision::{AtomicRevision, Revision};
4use crate::{CycleError, Database, DatabaseKeyIndex, Event, EventKind};
5use log::debug;
6use parking_lot::lock_api::{RawRwLock, RawRwLockRecursive};
7use parking_lot::{Mutex, RwLock};
8use rustc_hash::{FxHashMap, FxHasher};
9use smallvec::SmallVec;
10use std::hash::{BuildHasherDefault, Hash};
11use std::sync::atomic::{AtomicUsize, Ordering};
12use std::sync::Arc;
13
/// Insertion-order-preserving set hashed with the fast `FxHasher`
/// (keys are internal indices, not attacker-controlled).
pub(crate) type FxIndexSet<K> = indexmap::IndexSet<K, BuildHasherDefault<FxHasher>>;
/// Insertion-order-preserving map, same hasher choice as [`FxIndexSet`].
pub(crate) type FxIndexMap<K, V> = indexmap::IndexMap<K, V, BuildHasherDefault<FxHasher>>;
16
17mod local_state;
18use local_state::LocalState;
19
/// The salsa runtime stores the storage for all queries as well as
/// tracking the query stack and dependencies between cycles.
///
/// Each new runtime you create (e.g., via `Runtime::new` or
/// `Runtime::default`) will have an independent set of query storage
/// associated with it. Normally, therefore, you only do this once, at
/// the start of your application.
pub struct Runtime {
    /// Our unique runtime id.
    id: RuntimeId,

    /// If this is a "forked" runtime, then the `revision_guard` will
    /// be `Some`; this guard holds a read-lock on the global query
    /// lock.
    revision_guard: Option<RevisionGuard>,

    /// Local state that is specific to this runtime (thread).
    local_state: LocalState,

    /// Shared state that is accessible via all runtimes.
    shared_state: Arc<SharedState>,
}
42
43impl Default for Runtime {
44 fn default() -> Self {
45 Runtime {
46 id: RuntimeId { counter: 0 },
47 revision_guard: None,
48 shared_state: Default::default(),
49 local_state: Default::default(),
50 }
51 }
52}
53
54impl std::fmt::Debug for Runtime {
55 fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
56 fmt.debug_struct("Runtime")
57 .field("id", &self.id())
58 .field("forked", &self.revision_guard.is_some())
59 .field("shared_state", &self.shared_state)
60 .finish()
61 }
62}
63
impl Runtime {
    /// Create a new runtime; equivalent to `Self::default`. This is
    /// used when creating a new database.
    pub fn new() -> Self {
        Self::default()
    }

    /// See [`crate::storage::Storage::snapshot`].
    ///
    /// Forks this runtime: the returned `Runtime` shares the same
    /// `SharedState` but has a fresh id and its own (empty) local
    /// query stack. The fork holds a read-guard on the global query
    /// lock, which blocks revision increments until it is dropped.
    pub(crate) fn snapshot(&self) -> Self {
        if self.local_state.query_in_progress() {
            panic!("it is not legal to `snapshot` during a query (see salsa-rs/salsa#80)");
        }

        let revision_guard = RevisionGuard::new(&self.shared_state);

        let id = RuntimeId {
            counter: self.shared_state.next_id.fetch_add(1, Ordering::SeqCst),
        };

        Runtime {
            id,
            revision_guard: Some(revision_guard),
            shared_state: self.shared_state.clone(),
            local_state: Default::default(),
        }
    }

    /// A "synthetic write" causes the system to act *as though* some
    /// input of durability `durability` has changed. This is mostly
    /// useful for profiling scenarios, but it also has interactions
    /// with garbage collection. In general, a synthetic write to
    /// durability level D will cause the system to fully trace all
    /// queries of durability level D and below. When running a GC, then:
    ///
    /// - Synthetic writes will cause more derived values to be
    ///   *retained*. This is because derived values are only
    ///   retained if they are traced, and a synthetic write can cause
    ///   more things to be traced.
    /// - Synthetic writes can cause more interned values to be
    ///   *collected*. This is because interned values can only be
    ///   collected if they were not yet traced in the current
    ///   revision. Therefore, if you issue a synthetic write, execute
    ///   some query Q, and then start collecting interned values, you
    ///   will be able to recycle interned values not used in Q.
    ///
    /// In general, then, one can do a "full GC" that retains only
    /// those things that are used by some query Q by (a) doing a
    /// synthetic write at `Durability::HIGH`, (b) executing the query
    /// Q and then (c) doing a sweep.
    ///
    /// **WARNING:** Just like an ordinary write, this method triggers
    /// cancellation. If you invoke it while a snapshot exists, it
    /// will block until that snapshot is dropped -- if that snapshot
    /// is owned by the current thread, this could trigger deadlock.
    pub fn synthetic_write(&mut self, durability: Durability) {
        self.with_incremented_revision(&mut |_next_revision| Some(durability));
    }

    /// The unique identifier attached to this `SalsaRuntime`. Each
    /// snapshotted runtime has a distinct identifier.
    #[inline]
    pub fn id(&self) -> RuntimeId {
        self.id
    }

    /// Returns the database-key for the query that this thread is
    /// actively executing (if any).
    pub fn active_query(&self) -> Option<DatabaseKeyIndex> {
        self.local_state.active_query()
    }

    /// Read current value of the revision counter.
    #[inline]
    pub(crate) fn current_revision(&self) -> Revision {
        // Slot 0 of `revisions` is, by convention, the current revision.
        self.shared_state.revisions[0].load()
    }

    /// The revision in which values with durability `d` may have last
    /// changed. For D0, this is just the current revision. But for
    /// higher levels of durability, this value may lag behind the
    /// current revision. If we encounter a value of durability Di,
    /// then, we can check this function to get a "bound" on when the
    /// value may have changed, which allows us to skip walking its
    /// dependencies.
    #[inline]
    pub(crate) fn last_changed_revision(&self, d: Durability) -> Revision {
        self.shared_state.revisions[d.index()].load()
    }

    /// Read the value of the *pending* revision counter (equals the
    /// current revision unless a write is in flight, in which case it
    /// is one greater -- see `SharedState::pending_revision`).
    #[inline]
    fn pending_revision(&self) -> Revision {
        self.shared_state.pending_revision.load()
    }

    /// Check if the current revision is canceled. If this method ever
    /// returns true, the currently executing query is also marked as
    /// having an *untracked read* -- this means that, in the next
    /// revision, we will always recompute its value "as if" some
    /// input had changed. This means that, if your revision is
    /// canceled (which indicates that current query results will be
    /// ignored) your query is free to shortcircuit and return
    /// whatever it likes.
    ///
    /// This method is useful for implementing cancellation of queries.
    /// You can do it in one of two ways, via `Result`s or via unwinding.
    ///
    /// The `Result` approach looks like this:
    ///
    ///   * Some queries invoke `is_current_revision_canceled` and
    ///     return a special value, like `Err(Canceled)`, if it returns
    ///     `true`.
    ///   * Other queries propagate the special value using `?` operator.
    ///   * API around top-level queries checks if the result is `Ok` or
    ///     `Err(Canceled)`.
    ///
    /// The `panic` approach works in a similar way:
    ///
    ///   * Some queries invoke `is_current_revision_canceled` and
    ///     panic with a special value, like `Canceled`, if it returns
    ///     true.
    ///   * The implementation of `Database` trait overrides
    ///     `on_propagated_panic` to throw this special value as well.
    ///     This way, panic gets propagated naturally through dependent
    ///     queries, even across the threads.
    ///   * API around top-level queries converts a `panic` into `Result` by
    ///     catching the panic (using either `std::panic::catch_unwind` or
    ///     threads) and downcasting the payload to `Canceled` (re-raising
    ///     panic if downcast fails).
    ///
    /// Note that salsa is explicitly designed to be panic-safe, so cancellation
    /// via unwinding is 100% valid approach to cancellation.
    #[inline]
    pub fn is_current_revision_canceled(&self) -> bool {
        let current_revision = self.current_revision();
        let pending_revision = self.pending_revision();
        debug!(
            "is_current_revision_canceled: current_revision={:?}, pending_revision={:?}",
            current_revision, pending_revision
        );
        if pending_revision > current_revision {
            self.report_untracked_read();
            true
        } else {
            // Subtle: If the current revision is not canceled, we
            // still report an **anonymous** read, which will bump up
            // the revision number to be at least the last
            // non-canceled revision. This is needed to ensure
            // deterministic reads and avoid salsa-rs/salsa#66. The
            // specific scenario we are trying to avoid is tested by
            // `no_back_dating_in_cancellation`; it works like
            // this. Imagine we have 3 queries, where Query3 invokes
            // Query2 which invokes Query1. Then:
            //
            // - In Revision R1:
            //   - Query1: Observes cancelation and returns sentinel S.
            //     - Recorded inputs: Untracked, because we observed cancelation.
            //   - Query2: Reads Query1 and propagates sentinel S.
            //     - Recorded inputs: Query1, changed-at=R1
            //   - Query3: Reads Query2 and propagates sentinel S. (Inputs = Query2, ChangedAt R1)
            //     - Recorded inputs: Query2, changed-at=R1
            // - In Revision R2:
            //   - Query1: Observes no cancelation. All of its inputs last changed in R0,
            //     so it returns a valid value with "changed at" of R0.
            //     - Recorded inputs: ..., changed-at=R0
            //   - Query2: Recomputes its value and returns correct result.
            //     - Recorded inputs: Query1, changed-at=R0 <-- key problem!
            //   - Query3: sees that Query2's result last changed in R0, so it thinks it
            //     can re-use its value from R1 (which is the sentinel value).
            //
            // The anonymous read here prevents that scenario: Query1
            // winds up with a changed-at setting of R2, which is the
            // "pending revision", and hence Query2 and Query3
            // are recomputed.
            assert_eq!(pending_revision, current_revision);
            self.report_anon_read(pending_revision);
            false
        }
    }

    /// Acquires the **global query write lock** (ensuring that no queries are
    /// executing) and then increments the current revision counter; invokes
    /// `op` with the global query write lock still held.
    ///
    /// While we wait to acquire the global query write lock, this method will
    /// also increment `pending_revision_increments`, thus signalling to queries
    /// that their results are "canceled" and they should abort as expeditiously
    /// as possible.
    ///
    /// The `op` closure should actually perform the writes needed. It is given
    /// the new revision as an argument, and its return value indicates whether
    /// any pre-existing value was modified:
    ///
    /// - returning `None` means that no pre-existing value was modified (this
    ///   could occur e.g. when setting some key on an input that was never set
    ///   before)
    /// - returning `Some(d)` indicates that a pre-existing value was modified
    ///   and it had the durability `d`. This will update the records for when
    ///   values with each durability were modified.
    ///
    /// Note that, given our writer model, we can assume that only one thread is
    /// attempting to increment the global revision at a time.
    pub(crate) fn with_incremented_revision(
        &mut self,
        op: &mut dyn FnMut(Revision) -> Option<Durability>,
    ) {
        log::debug!("increment_revision()");

        if !self.permits_increment() {
            panic!("increment_revision invoked during a query computation");
        }

        // Set the `pending_revision` field so that people
        // know current revision is canceled. This happens *before*
        // taking the write lock so in-flight readers can observe
        // cancellation while we wait for them to finish.
        let current_revision = self.shared_state.pending_revision.fetch_then_increment();

        // To modify the revision, we need the lock.
        let shared_state = self.shared_state.clone();
        let _lock = shared_state.query_lock.write();

        let old_revision = self.shared_state.revisions[0].fetch_then_increment();
        // Writer model invariant: nobody else bumped the revision between
        // our two fetch-and-increments above.
        assert_eq!(current_revision, old_revision);

        let new_revision = current_revision.next();

        debug!("increment_revision: incremented to {:?}", new_revision);

        if let Some(d) = op(new_revision) {
            // Bump the "last changed" marker for durability `d` and
            // everything below it (slot 0 was already updated above).
            for rev in &self.shared_state.revisions[1..=d.index()] {
                rev.store(new_revision);
            }
        }
    }

    /// True when it is legal to increment the revision: this runtime is the
    /// root (not a fork holding the read-guard) and no query is executing.
    pub(crate) fn permits_increment(&self) -> bool {
        self.revision_guard.is_none() && !self.local_state.query_in_progress()
    }

    /// Runs `execute` with `database_key_index` pushed on the local query
    /// stack, recording every dependency it reads; returns the value
    /// together with the accumulated durability/changed-at/dependency data.
    pub(crate) fn execute_query_implementation<DB, V>(
        &self,
        db: &DB,
        database_key_index: DatabaseKeyIndex,
        execute: impl FnOnce() -> V,
    ) -> ComputedQueryResult<V>
    where
        DB: ?Sized + Database,
    {
        debug!(
            "{:?}: execute_query_implementation invoked",
            database_key_index
        );

        db.salsa_event(Event {
            runtime_id: self.id(),
            kind: EventKind::WillExecute {
                database_key: database_key_index,
            },
        });

        // Push the active query onto the stack.
        let max_durability = Durability::MAX;
        let active_query = self
            .local_state
            .push_query(database_key_index, max_durability);

        // Execute user's code, accumulating inputs etc.
        let value = execute();

        // Extract accumulated inputs.
        let ActiveQuery {
            dependencies,
            changed_at,
            durability,
            cycle,
            ..
        } = active_query.complete();

        ComputedQueryResult {
            value,
            durability,
            changed_at,
            dependencies,
            cycle,
        }
    }

    /// Reports that the currently active query read the result from
    /// another query.
    ///
    /// # Parameters
    ///
    /// - `database_key`: the query whose result was read
    /// - `changed_revision`: the last revision in which the result of that
    ///   query had changed
    //
    // NOTE(review): the `'hack` lifetime parameter is unused in this
    // signature -- presumably kept for macro-expansion compatibility;
    // confirm with callers before removing.
    pub(crate) fn report_query_read<'hack>(
        &self,
        input: DatabaseKeyIndex,
        durability: Durability,
        changed_at: Revision,
    ) {
        self.local_state
            .report_query_read(input, durability, changed_at);
    }

    /// Reports that the query depends on some state unknown to salsa.
    ///
    /// Queries which report untracked reads will be re-executed in the next
    /// revision.
    pub fn report_untracked_read(&self) {
        self.local_state
            .report_untracked_read(self.current_revision());
    }

    /// Acts as though the current query had read an input with the given durability; this will force the current query's durability to be at most `durability`.
    ///
    /// This is mostly useful to control the durability level for [on-demand inputs](https://salsa-rs.github.io/salsa/common_patterns/on_demand_inputs.html).
    pub fn report_synthetic_read(&self, durability: Durability) {
        self.local_state.report_synthetic_read(durability);
    }

    /// An "anonymous" read is a read that doesn't come from executing
    /// a query, but from some other internal operation. It just
    /// modifies the "changed at" to be at least the given revision.
    /// (It also does not disqualify a query from being considered
    /// constant, since it is used for queries that don't give back
    /// actual *data*.)
    ///
    /// This is used when queries check if they have been canceled.
    fn report_anon_read(&self, revision: Revision) {
        self.local_state.report_anon_read(revision)
    }

    /// Obviously, this should be user configurable at some point.
    ///
    /// Builds the `CycleError` describing the cycle that
    /// `database_key_index` participates in, and marks every active
    /// query in the cycle (so they can recover appropriately).
    pub(crate) fn report_unexpected_cycle(
        &self,
        database_key_index: DatabaseKeyIndex,
        error: CycleDetected,
        changed_at: Revision,
    ) -> crate::CycleError<DatabaseKeyIndex> {
        debug!(
            "report_unexpected_cycle(database_key={:?})",
            database_key_index
        );

        let mut query_stack = self.local_state.borrow_query_stack_mut();

        if error.from == error.to {
            // All queries in the cycle is local
            let start_index = query_stack
                .iter()
                .rposition(|active_query| active_query.database_key_index == database_key_index)
                .unwrap();
            let mut cycle = Vec::new();
            let cycle_participants = &mut query_stack[start_index..];
            for active_query in &mut *cycle_participants {
                cycle.push(active_query.database_key_index);
            }

            assert!(!cycle.is_empty());

            for active_query in cycle_participants {
                active_query.cycle = cycle.clone();
            }

            crate::CycleError {
                cycle,
                changed_at,
                durability: Durability::MAX,
            }
        } else {
            // Part of the cycle is on another thread so we need to lock and inspect the shared
            // state
            let dependency_graph = self.shared_state.dependency_graph.lock();

            let mut cycle = Vec::new();
            dependency_graph.push_cycle_path(
                database_key_index,
                error.to,
                query_stack.iter().map(|query| query.database_key_index),
                &mut cycle,
            );
            cycle.push(database_key_index);

            assert!(!cycle.is_empty());

            // Mark only those local frames that actually appear in the cycle.
            for active_query in query_stack
                .iter_mut()
                .filter(|query| cycle.iter().any(|key| *key == query.database_key_index))
            {
                active_query.cycle = cycle.clone();
            }

            crate::CycleError {
                cycle,
                changed_at,
                durability: Durability::MAX,
            }
        }
    }

    /// Walks the local query stack from the top down, tagging each frame
    /// that belongs to `err.cycle` with the cycle (stops at the first
    /// frame outside the cycle).
    pub(crate) fn mark_cycle_participants(&self, err: &CycleError<DatabaseKeyIndex>) {
        for active_query in self
            .local_state
            .borrow_query_stack_mut()
            .iter_mut()
            .rev()
            .take_while(|active_query| {
                err.cycle
                    .iter()
                    .any(|e| *e == active_query.database_key_index)
            })
        {
            active_query.cycle = err.cycle.clone();
        }
    }

    /// Try to make this runtime blocked on `other_id`. Returns true
    /// upon success or false if `other_id` is already blocked on us
    /// (i.e., adding the edge would create a deadlock cycle).
    pub(crate) fn try_block_on(&self, database_key: DatabaseKeyIndex, other_id: RuntimeId) -> bool {
        self.shared_state.dependency_graph.lock().add_edge(
            self.id(),
            database_key,
            other_id,
            self.local_state
                .borrow_query_stack()
                .iter()
                .map(|query| query.database_key_index),
        )
    }

    /// Removes the dependency-graph edges of runtimes that were blocked
    /// waiting for us to finish computing `database_key_index`.
    pub(crate) fn unblock_queries_blocked_on_self(&self, database_key_index: DatabaseKeyIndex) {
        self.shared_state
            .dependency_graph
            .lock()
            .remove_edge(database_key_index, self.id())
    }
}
501
/// State that will be common to all threads (when we support multiple threads)
struct SharedState {
    /// Stores the next id to use for a snapshotted runtime (starts at 1).
    next_id: AtomicUsize,

    /// Whenever derived queries are executing, they acquire this lock
    /// in read mode. Mutating inputs (and thus creating a new
    /// revision) requires a write lock (thus guaranteeing that no
    /// derived queries are in progress). Note that this is not needed
    /// to prevent **race conditions** -- the revision counter itself
    /// is stored in an `AtomicUsize` so it can be cheaply read
    /// without acquiring the lock. Rather, the `query_lock` is used
    /// to ensure a higher-level consistency property.
    query_lock: RwLock<()>,

    /// This is typically equal to `revision` -- set to `revision+1`
    /// when a new revision is pending (which implies that the current
    /// revision is canceled).
    pending_revision: AtomicRevision,

    /// Stores the "last change" revision for values of each duration.
    /// This vector is always of length at least 1 (for Durability 0)
    /// but its total length depends on the number of durations. The
    /// element at index 0 is special as it represents the "current
    /// revision". In general, we have the invariant that revisions
    /// in here are *declining* -- that is, `revisions[i] >=
    /// revisions[i + 1]`, for all `i`. This is because when you
    /// modify a value with durability D, that implies that values
    /// with durability less than D may have changed too.
    revisions: Vec<AtomicRevision>,

    /// The dependency graph tracks which runtimes are blocked on one
    /// another, waiting for queries to terminate.
    dependency_graph: Mutex<DependencyGraph<DatabaseKeyIndex>>,
}
537
538impl SharedState {
539 fn with_durabilities(durabilities: usize) -> Self {
540 SharedState {
541 next_id: AtomicUsize::new(1),
542 query_lock: Default::default(),
543 revisions: (0..durabilities).map(|_| AtomicRevision::start()).collect(),
544 pending_revision: AtomicRevision::start(),
545 dependency_graph: Default::default(),
546 }
547 }
548}
549
550impl std::panic::RefUnwindSafe for SharedState {}
551
552impl Default for SharedState {
553 fn default() -> Self {
554 Self::with_durabilities(Durability::LEN)
555 }
556}
557
558impl std::fmt::Debug for SharedState {
559 fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
560 let query_lock = if self.query_lock.try_write().is_some() {
561 "<unlocked>"
562 } else if self.query_lock.try_read().is_some() {
563 "<rlocked>"
564 } else {
565 "<wlocked>"
566 };
567 fmt.debug_struct("SharedState")
568 .field("query_lock", &query_lock)
569 .field("revisions", &self.revisions)
570 .field("pending_revision", &self.pending_revision)
571 .finish()
572 }
573}
574
/// Bookkeeping for one frame of the query stack: everything observed
/// while a single query executes.
struct ActiveQuery {
    /// What query is executing
    database_key_index: DatabaseKeyIndex,

    /// Minimum durability of inputs observed so far.
    durability: Durability,

    /// Maximum revision of all inputs observed. If we observe an
    /// untracked read, this will be set to the most recent revision.
    changed_at: Revision,

    /// Set of subqueries that were accessed thus far, or `None` if
    /// there was an untracked read.
    dependencies: Option<FxIndexSet<DatabaseKeyIndex>>,

    /// Stores the entire cycle, if one is found and this query is part of it.
    cycle: Vec<DatabaseKeyIndex>,
}
593
/// The value produced by `Runtime::execute_query_implementation`,
/// bundled with the dependency data accumulated during execution.
pub(crate) struct ComputedQueryResult<V> {
    /// Final value produced
    pub(crate) value: V,

    /// Minimum durability of inputs observed so far.
    pub(crate) durability: Durability,

    /// Maximum revision of all inputs observed. If we observe an
    /// untracked read, this will be set to the most recent revision.
    pub(crate) changed_at: Revision,

    /// Complete set of subqueries that were accessed, or `None` if
    /// there was an untracked read.
    pub(crate) dependencies: Option<FxIndexSet<DatabaseKeyIndex>>,

    /// The cycle if one occurred while computing this value
    pub(crate) cycle: Vec<DatabaseKeyIndex>,
}
612
613impl ActiveQuery {
614 fn new(database_key_index: DatabaseKeyIndex, max_durability: Durability) -> Self {
615 ActiveQuery {
616 database_key_index,
617 durability: max_durability,
618 changed_at: Revision::start(),
619 dependencies: Some(FxIndexSet::default()),
620 cycle: Vec::new(),
621 }
622 }
623
624 fn add_read(&mut self, input: DatabaseKeyIndex, durability: Durability, revision: Revision) {
625 if let Some(set) = &mut self.dependencies {
626 set.insert(input);
627 }
628
629 self.durability = self.durability.min(durability);
630 self.changed_at = self.changed_at.max(revision);
631 }
632
633 fn add_untracked_read(&mut self, changed_at: Revision) {
634 self.dependencies = None;
635 self.durability = Durability::LOW;
636 self.changed_at = changed_at;
637 }
638
639 fn add_synthetic_read(&mut self, durability: Durability) {
640 self.durability = self.durability.min(durability);
641 }
642
643 fn add_anon_read(&mut self, changed_at: Revision) {
644 self.changed_at = self.changed_at.max(changed_at);
645 }
646}
647
/// A unique identifier for a particular runtime. Each time you create
/// a snapshot, a fresh `RuntimeId` is generated. Once a snapshot is
/// complete, its `RuntimeId` may potentially be re-used.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct RuntimeId {
    // Drawn from `SharedState::next_id`; 0 is the root runtime.
    counter: usize,
}
655
/// A value paired with the revision metadata ("stamp") needed to decide
/// whether it can be reused in later revisions.
#[derive(Clone, Debug)]
pub(crate) struct StampedValue<V> {
    pub(crate) value: V,
    // Minimum durability of the inputs that produced `value`.
    pub(crate) durability: Durability,
    // Revision in which `value` last changed.
    pub(crate) changed_at: Revision,
}
662
/// One edge in the `DependencyGraph`: the blocked runtime waits on `id`.
#[derive(Debug)]
struct Edge<K> {
    /// The runtime being waited on.
    id: RuntimeId,
    /// The blocked runtime's query stack at the time the edge was added,
    /// with the awaited key appended (see `DependencyGraph::add_edge`).
    path: Vec<K>,
}
668
#[derive(Debug)]
struct DependencyGraph<K: Hash + Eq> {
    /// A `(K -> V)` pair in this map indicates that the runtime
    /// `K` is blocked on some query executing in the runtime `V`.
    /// This encodes a graph that must be acyclic (or else deadlock
    /// will result).
    edges: FxHashMap<RuntimeId, Edge<K>>,
    /// For each awaited key, the runtimes currently blocked on it
    /// (used to find all edges to remove when the key completes).
    labels: FxHashMap<K, SmallVec<[RuntimeId; 4]>>,
}
678
679impl<K> Default for DependencyGraph<K>
680where
681 K: Hash + Eq,
682{
683 fn default() -> Self {
684 DependencyGraph {
685 edges: Default::default(),
686 labels: Default::default(),
687 }
688 }
689}
690
impl<K> DependencyGraph<K>
where
    K: Hash + Eq + Clone,
{
    /// Attempt to add an edge `from_id -> to_id` into the result graph.
    ///
    /// Returns `false` (without modifying the graph) if the edge would
    /// close a cycle -- i.e., `to_id` already (transitively) waits on
    /// `from_id` -- since inserting it would encode a deadlock.
    fn add_edge(
        &mut self,
        from_id: RuntimeId,
        database_key: K,
        to_id: RuntimeId,
        path: impl IntoIterator<Item = K>,
    ) -> bool {
        assert_ne!(from_id, to_id);
        debug_assert!(!self.edges.contains_key(&from_id));

        // First: walk the chain of things that `to_id` depends on,
        // looking for us.
        let mut p = to_id;
        while let Some(q) = self.edges.get(&p).map(|edge| edge.id) {
            if q == from_id {
                return false;
            }

            p = q;
        }

        // Record the blocked runtime's stack plus the awaited key; this
        // `path` is what `push_cycle_path` later splices together.
        self.edges.insert(
            from_id,
            Edge {
                id: to_id,
                path: path.into_iter().chain(Some(database_key.clone())).collect(),
            },
        );
        self.labels
            .entry(database_key.clone())
            .or_default()
            .push(from_id);
        true
    }

    /// Remove every edge of a runtime that was blocked waiting for
    /// `database_key` to finish in runtime `to_id`.
    fn remove_edge(&mut self, database_key: K, to_id: RuntimeId) {
        let vec = self.labels.remove(&database_key).unwrap_or_default();

        for from_id in &vec {
            let to_id1 = self.edges.remove(from_id).map(|edge| edge.id);
            // Every runtime labeled as waiting on `database_key` must have
            // had its edge point at `to_id`.
            assert_eq!(Some(to_id), to_id1);
        }
    }

    /// Reconstructs the cross-thread portion of a query cycle into
    /// `output`: starting from `database_key` in runtime `to`, follow
    /// the recorded edge paths from one blocked runtime to the next,
    /// splicing each path in after the key that links it to the
    /// previous one; `local_path` (the caller's own stack) supplies the
    /// final segment once the edge chain ends.
    fn push_cycle_path<'a>(
        &'a self,
        database_key: K,
        to: RuntimeId,
        local_path: impl IntoIterator<Item = K>,
        output: &mut Vec<K>,
    ) where
        K: std::fmt::Debug,
    {
        let mut current = Some((to, std::slice::from_ref(&database_key)));
        let mut last = None;
        let mut local_path = Some(local_path);

        loop {
            match current.take() {
                Some((id, path)) => {
                    // The last key of this segment is the one the next
                    // runtime (if any) is blocked on.
                    let link_key = path.last().unwrap();

                    output.extend(path.iter().cloned());

                    current = self.edges.get(&id).map(|edge| {
                        // Continue after the link key's last occurrence
                        // in the next runtime's recorded path.
                        let i = edge.path.iter().rposition(|p| p == link_key).unwrap();
                        (edge.id, &edge.path[i + 1..])
                    });

                    if current.is_none() {
                        // Chain exhausted: the remainder of the cycle is
                        // the caller's local stack past the link key.
                        last = local_path.take().map(|local_path| {
                            local_path
                                .into_iter()
                                .skip_while(move |p| *p != *link_key)
                                .skip(1)
                        });
                    }
                }
                None => break,
            }
        }

        if let Some(iter) = &mut last {
            output.extend(iter);
        }
    }
}
783
/// RAII guard held by a forked runtime: holds a (recursive) read-lock on
/// `SharedState::query_lock` for as long as the snapshot is alive,
/// preventing revision increments. Acquired in `new`, released in `Drop`.
struct RevisionGuard {
    shared_state: Arc<SharedState>,
}
787
impl RevisionGuard {
    fn new(shared_state: &Arc<SharedState>) -> Self {
        // Subtle: we use a "recursive" lock here so that it is not an
        // error to acquire a read-lock when one is already held (this
        // happens when a query uses `snapshot` to spawn off parallel
        // workers, for example).
        //
        // This has the side-effect that we are responsible to ensure
        // that people contending for the write lock do not starve,
        // but this is what we achieve via the cancellation mechanism.
        //
        // (In particular, since we only ever have one "mutating
        // handle" to the database, the only contention for the global
        // query lock occurs when there are "futures" evaluating
        // queries in parallel, and those futures hold a read-lock
        // already, so the starvation problem is more about them bringing
        // themselves to a close, versus preventing other people from
        // *starting* work).
        //
        // SAFETY: the matching `unlock_shared` is issued exactly once,
        // in `Drop for RevisionGuard` below.
        unsafe {
            shared_state.query_lock.raw().lock_shared_recursive();
        }

        Self {
            shared_state: shared_state.clone(),
        }
    }
}
815
impl Drop for RevisionGuard {
    fn drop(&mut self) {
        // Release our read-lock without using RAII. As documented in
        // `Snapshot::new` above, this requires the unsafe keyword.
        //
        // SAFETY: `RevisionGuard::new` acquired exactly one shared
        // (recursive) lock on this `query_lock`, and each guard is
        // dropped at most once, so this release is balanced.
        unsafe {
            self.shared_state.query_lock.raw().unlock_shared();
        }
    }
}
825
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn dependency_graph_path1() {
        // Two runtimes: a (stack [1]) blocked on b awaiting key 2.
        let mut graph = DependencyGraph::default();
        let a = RuntimeId { counter: 0 };
        let b = RuntimeId { counter: 1 };
        assert!(graph.add_edge(a, 2, b, vec![1]));
        // The closing edge b -> a would be rejected as a cycle:
        // assert!(graph.add_edge(b, &1, a, vec![3, 2]));
        let mut v = vec![];
        // Cycle through key 1 back to runtime a, whose local stack is [3, 2].
        graph.push_cycle_path(1, a, vec![3, 2], &mut v);
        assert_eq!(v, vec![1, 2]);
    }

    #[test]
    fn dependency_graph_path2() {
        // Three runtimes chained: a (stack [1]) waits on b for key 3,
        // b (stack [2, 3]) waits on c for key 4.
        let mut graph = DependencyGraph::default();
        let a = RuntimeId { counter: 0 };
        let b = RuntimeId { counter: 1 };
        let c = RuntimeId { counter: 2 };
        assert!(graph.add_edge(a, 3, b, vec![1]));
        assert!(graph.add_edge(b, 4, c, vec![2, 3]));
        // The closing edge c -> a would be rejected as a cycle:
        // assert!(graph.add_edge(c, &1, a, vec![5, 6, 4, 7]));
        let mut v = vec![];
        // Reconstruct the cycle starting at key 1 in runtime a; c's local
        // stack [5, 6, 4, 7] supplies the tail past link key 4.
        graph.push_cycle_path(1, a, vec![5, 6, 4, 7], &mut v);
        assert_eq!(v, vec![1, 3, 4, 7]);
    }
}
855}