solana_runtime/
dependency_tracker.rs

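//! Utility to track dependent work: callers declare work to obtain increasing ids, mark ids as
//! processed, and block until a given id (and every id before it) has been processed.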
2
3use std::sync::{atomic::AtomicU64, Condvar, Mutex};
4
/// Hands out increasing work ids and lets threads block until a given id has
/// been marked as processed.
#[derive(Default, Debug)]
pub struct DependencyTracker {
    /// Id of the most recently declared work; 0 until the first declaration.
    work_id: AtomicU64,
    /// Highest work id marked as processed; `None` until any work has been marked.
    processed_work_id: Mutex<Option<u64>>,
    /// Signalled whenever `processed_work_id` advances.
    condvar: Condvar,
}
13
/// Returns `true` when the recorded processed id `a` is strictly behind `b`:
/// either nothing has been processed yet (`None`) or the recorded id is smaller.
fn less_than(a: &Option<u64>, b: u64) -> bool {
    match *a {
        Some(processed) => processed < b,
        None => true,
    }
}
17
18impl DependencyTracker {
19    /// Acquire the next work id number.
20    /// The work id starts from 0 and increments by 1 each time it is called.
21    pub fn declare_work(&self) -> u64 {
22        self.work_id
23            .fetch_add(1, std::sync::atomic::Ordering::SeqCst)
24            + 1
25    }
26
27    /// Notify all waiting threads that a work has been processed with the given work id.
28    /// This function will update the processed work id and notify all waiting threads only if the work
29    /// id is greater than the procsessed work id. Notify a work of id number 's' will
30    /// implicitly imply that all work with id number less than 's' have been processed.
31    pub fn mark_this_and_all_previous_work_processed(&self, work_id: u64) {
32        let mut processed_work_id = self.processed_work_id.lock().unwrap();
33        if less_than(&processed_work_id, work_id) {
34            *processed_work_id = Some(work_id);
35            self.condvar.notify_all();
36        }
37    }
38
39    /// To wait for the dependency work with 'work_id' to be processed.
40    pub fn wait_for_dependency(&self, work_id: u64) {
41        if work_id == 0 {
42            return; // No need to wait for work id 0 as real work starts from 1.
43        }
44        let mut processed_work_id = self.processed_work_id.lock().unwrap();
45        while less_than(&processed_work_id, work_id) {
46            processed_work_id = self.condvar.wait(processed_work_id).unwrap();
47        }
48    }
49
50    /// Get the current work id number.
51    pub fn get_current_declared_work(&self) -> u64 {
52        self.work_id.load(std::sync::atomic::Ordering::SeqCst)
53    }
54}
55
#[cfg(test)]
mod tests {
    use {
        super::*,
        std::{sync::Arc, thread},
    };

    #[test]
    fn test_less_than() {
        assert!(less_than(&None, 0));
        assert!(less_than(&Some(0), 1));
        assert!(!less_than(&Some(1), 1));
        assert!(!less_than(&Some(2), 1));
    }

    #[test]
    fn test_get_new_work_id() {
        let dependency_tracker = DependencyTracker::default();
        assert_eq!(dependency_tracker.declare_work(), 1);
        assert_eq!(dependency_tracker.declare_work(), 2);
        assert_eq!(dependency_tracker.get_current_declared_work(), 2);
    }

    #[test]
    fn test_notify_work_processed() {
        let dependency_tracker = DependencyTracker::default();
        dependency_tracker.mark_this_and_all_previous_work_processed(1);

        let processed_work_id = *dependency_tracker.processed_work_id.lock().unwrap();
        assert_eq!(processed_work_id, Some(1));

        // notify a smaller work id number, should not change the processed work id
        dependency_tracker.mark_this_and_all_previous_work_processed(0);
        let processed_work_id = *dependency_tracker.processed_work_id.lock().unwrap();
        assert_eq!(processed_work_id, Some(1));
        // notify a larger work id number, should change the processed work id
        dependency_tracker.mark_this_and_all_previous_work_processed(2);
        let processed_work_id = *dependency_tracker.processed_work_id.lock().unwrap();
        assert_eq!(processed_work_id, Some(2));
        // notify the same work id number, should not change the processed work id
        dependency_tracker.mark_this_and_all_previous_work_processed(2);
        let processed_work_id = *dependency_tracker.processed_work_id.lock().unwrap();
        assert_eq!(processed_work_id, Some(2));
    }

    #[test]
    fn test_wait_and_notify_work_processed() {
        let dependency_tracker = Arc::new(DependencyTracker::default());
        let tracker_clone = Arc::clone(&dependency_tracker);

        let work = dependency_tracker.declare_work();
        assert_eq!(work, 1);
        let work = dependency_tracker.declare_work();
        assert_eq!(work, 2);
        let work_to_wait = dependency_tracker.get_current_declared_work();
        let handle = thread::spawn(move || {
            tracker_clone.wait_for_dependency(work_to_wait);
        });

        thread::sleep(std::time::Duration::from_millis(100));
        dependency_tracker.mark_this_and_all_previous_work_processed(work);
        handle.join().unwrap();

        let processed_work_id = *dependency_tracker.processed_work_id.lock().unwrap();
        assert_eq!(processed_work_id, Some(2));
    }

    #[test]
    fn test_wait_for_work_zero_returns_immediately() {
        let dependency_tracker = DependencyTracker::default();
        // Nothing has been processed, yet waiting on id 0 must not block.
        dependency_tracker.wait_for_dependency(0);
    }

    #[test]
    fn test_wait_for_already_processed_work_returns_immediately() {
        let dependency_tracker = DependencyTracker::default();
        dependency_tracker.mark_this_and_all_previous_work_processed(5);
        // Ids at or below the processed id are satisfied without blocking.
        dependency_tracker.wait_for_dependency(3);
        dependency_tracker.wait_for_dependency(5);
    }

    #[test]
    fn test_single_notify_releases_all_earlier_waiters() {
        let dependency_tracker = Arc::new(DependencyTracker::default());
        let ids: Vec<u64> = (0..3).map(|_| dependency_tracker.declare_work()).collect();
        assert_eq!(ids, vec![1, 2, 3]);

        let handles: Vec<_> = ids
            .iter()
            .map(|&id| {
                let tracker = Arc::clone(&dependency_tracker);
                thread::spawn(move || tracker.wait_for_dependency(id))
            })
            .collect();

        // Marking the last id implies every earlier id is processed, so all
        // waiters are released regardless of whether they started waiting
        // before or after this call.
        dependency_tracker.mark_this_and_all_previous_work_processed(3);
        for handle in handles {
            handle.join().unwrap();
        }
    }
}
122}