Skip to main content

timsrust_utils/
thread.rs

1use std::{
2    sync::{Arc, Mutex},
3    time::{Duration, Instant},
4};
5
6#[derive(Debug, Default)]
7pub struct SyncedTimer(Synced<Duration>);
8
9impl SyncedTimer {
10    pub fn update_with_func<F, T>(&self, func: F) -> T
11    where
12        F: FnOnce() -> T,
13    {
14        let start = Instant::now();
15        let result = func(); // Call the function with no arguments here
16        let elapsed = start.elapsed();
17        self.0.with_lock(|t| *t += elapsed).unwrap();
18        result
19    }
20
21    pub fn get(&self) -> Duration {
22        self.0.with_lock(|t| *t).unwrap()
23    }
24}
25
/// A thread-safe wrapper that provides synchronized access to an inner value.
///
/// `Synced<T>` wraps a value of type `T` in an `Arc<Mutex<T>>`, providing
/// safe concurrent access across multiple threads. Cloning a `Synced<T>` does
/// not copy the value. Every clone refers to the same mutex-protected `T`, so
/// a change made through one handle is visible through all of them.
#[derive(Debug)]
pub struct Synced<T> {
    // Shared, mutex-protected storage. Every clone holds a reference to it.
    internal: Arc<Mutex<T>>,
}

impl<T> Synced<T> {
    /// Executes a function with exclusive access to the inner value.
    ///
    /// This method acquires the internal mutex, blocking until it is
    /// available. It passes a mutable reference to the inner value to the
    /// provided function and returns the result. The lock is released before
    /// this method returns.
    ///
    /// # Arguments
    ///
    /// * `func` - A closure that takes a mutable reference to `T` and returns `R`
    ///
    /// # Returns
    ///
    /// * `Ok(R)` - The result of the function if the lock was acquired successfully
    /// * `Err(PoisonError)` - If the mutex was poisoned (a thread panicked while holding the lock)
    ///
    /// # Errors
    ///
    /// Returns the mutex's `PoisonError` if some thread panicked while holding
    /// the lock. In that case `func` is **not** called. If `func` itself
    /// panics, the panic propagates and the mutex becomes poisoned for every
    /// clone of this handle.
    ///
    /// # Examples
    ///
    /// ```rust
    /// use timsrust_utils::thread::Synced;
    ///
    /// let synced = Synced::from(vec![1, 2, 3]);
    ///
    /// let sum = synced.with_lock(|data| {
    ///     data.iter().sum::<i32>()
    /// }).unwrap();
    ///
    /// assert_eq!(sum, 6);
    /// ```
    pub fn with_lock<F, R>(
        &self,
        func: F,
    ) -> Result<R, std::sync::PoisonError<std::sync::MutexGuard<'_, T>>>
    where
        F: FnOnce(&mut T) -> R,
    {
        // An early return on poison means `func` never runs.
        let mut guard = self.internal.lock()?;
        Ok(func(&mut guard))
    }

    /// Attempts to extract the inner value, consuming the `Synced` wrapper.
    ///
    /// This method will only succeed if this is the last reference to the
    /// internal `Arc` and the mutex is not poisoned. When it returns `None`
    /// because other clones still exist, this handle is still consumed. The
    /// value stays reachable through the remaining clones.
    ///
    /// # Returns
    ///
    /// * `Some(T)` - The inner value if this was the last `Arc` reference and the mutex was not poisoned
    /// * `None` - If there are other references to the `Arc` or the mutex was poisoned
    ///
    /// # Examples
    ///
    /// ```rust
    /// use timsrust_utils::thread::Synced;
    ///
    /// let synced = Synced::from(42);
    /// let value = synced.try_finalize();
    ///
    /// assert_eq!(value, Some(42));
    /// ```
    ///
    /// ```rust
    /// use timsrust_utils::thread::Synced;
    ///
    /// let synced = Synced::from(42);
    /// let cloned = synced.clone();
    ///
    /// // Cannot finalize because there are multiple references
    /// let value = synced.try_finalize();
    /// assert_eq!(value, None);
    /// ```
    pub fn try_finalize(self) -> Option<T> {
        // `Arc::into_inner` yields the `Mutex` only for the last reference.
        // `Mutex::into_inner` then fails if the mutex was poisoned.
        Arc::into_inner(self.internal)?.into_inner().ok()
    }
}

/// Creates a new, unshared `Synced` holding `T::default()`.
impl<T: Default> Default for Synced<T> {
    fn default() -> Self {
        Synced {
            internal: Arc::new(Mutex::new(T::default())),
        }
    }
}

/// Wraps `value` in a new, unshared `Synced`.
impl<T> From<T> for Synced<T> {
    fn from(value: T) -> Self {
        Synced {
            internal: Arc::new(Mutex::new(value)),
        }
    }
}

/// Returns another handle to the *same* shared value.
///
/// Only the `Arc` reference count is incremented. The inner `T` is not cloned,
/// which is why no `T: Clone` bound is required.
impl<T> Clone for Synced<T> {
    fn clone(&self) -> Self {
        Synced {
            internal: Arc::clone(&self.internal),
        }
    }
}
134
#[cfg(test)]
mod tests {
    // NOTE(review): `test_arg_max` exercises `crate::vec::arg_max`, and
    // `test_scale_to_u64` exercises a helper defined only in this test module.
    // Neither touches `Synced` or `SyncedTimer`, so they probably belong next to
    // the code they test. They are kept here so no coverage is lost.
    use crate::vec::arg_max;

    use super::*;

    /// Multiplies each value by `upper_bound` and casts the product to `u64`.
    /// Float-to-int `as` casts truncate toward zero and saturate, so negative
    /// products become 0. The result is then clamped to `upper_bound as u64`.
    fn scale_to_u64(vec: &[f32], upper_bound: f32) -> Vec<u64> {
        vec.iter()
            .map(|&x| {
                let t = (upper_bound * x) as u64;
                t.min(upper_bound as u64)
            })
            .collect()
    }

    #[test]
    fn test_synced_with_lock_and_finalize() {
        let synced = Synced::from(10);
        let res = synced.with_lock(|v| {
            *v += 5;
            *v
        });
        assert_eq!(res.unwrap(), 15);
        let finalized = synced.try_finalize();
        assert_eq!(finalized, Some(15));
    }

    #[test]
    fn test_synced_default() {
        let synced: Synced<i32> = Synced::default();
        let val = synced.try_finalize().unwrap();
        assert_eq!(val, 0);
    }

    #[test]
    fn test_synced_clone() {
        let synced = Synced::from(42);
        {
            let cloned = synced.clone();
            // Unwrap instead of discarding the `Result`, so a poisoned lock
            // fails the test instead of silently skipping the increment.
            cloned.with_lock(|v| *v += 1).unwrap();
        }
        // `cloned` has been dropped, so `synced` is the only reference again.
        let orig = synced.try_finalize().unwrap();
        assert_eq!(orig, 43);
    }

    #[test]
    fn test_synced_poisoned_lock() {
        let synced = Synced::from(1);
        let handle = synced.clone();
        let joined = std::thread::spawn(move || {
            let _ = handle.with_lock(|_v: &mut i32| -> i32 { panic!("deliberate") });
        })
        .join();
        assert!(joined.is_err());
        // Once poisoned, `with_lock` reports an error and finalizing yields `None`.
        assert!(synced.with_lock(|v| *v).is_err());
        assert_eq!(synced.try_finalize(), None);
    }

    #[test]
    fn test_synced_timer_starts_at_zero_and_returns_result() {
        let timer = SyncedTimer::default();
        assert_eq!(timer.get(), Duration::ZERO);
        let out = timer.update_with_func(|| 6 * 7);
        assert_eq!(out, 42);
    }

    #[test]
    fn test_synced_timer_is_monotonic() {
        let timer = SyncedTimer::default();
        let mut previous = timer.get();
        for _ in 0..3 {
            timer.update_with_func(|| ());
            let current = timer.get();
            assert!(current >= previous);
            previous = current;
        }
    }

    #[test]
    fn test_arg_max() {
        let v = vec![1, 3, 2, 5, 4];
        let idx = arg_max(&v);
        assert_eq!(idx, Some(3));
        let empty: Vec<i32> = vec![];
        assert_eq!(arg_max(&empty), None);
    }

    #[test]
    fn test_scale_to_u64() {
        let v = vec![0.0, 0.5, 1.0, 1.5];
        let scaled = scale_to_u64(&v, 10.0);
        assert_eq!(scaled, vec![0, 5, 10, 10]);
    }
}