thread_local_collect/tlm/restr/
probed.rs

1//! Variant of module [`crate::tlm::probed`] with a `send` API similar to that of [`crate::tlcr::probed`].
2//!
3//! This module supports the collection and aggregation of values across threads (see package
4//! [overview and core concepts](crate)), including the ability to inspect
5//! the accumulated value before participating threads have terminated. The following capabilities and constraints apply:
6//! - Values may be collected from the thread responsible for collection/aggregation.
7//! - The participating threads update thread-local data via the clonable `control` object which is also
8//! used to aggregate the values.
9//! - The [`Control::probe_tls`] function can be called at any time to return a clone of the current aggregated value.
10//! - The [`Control::drain_tls`] function can be called to return the accumulated value after all participating
11//! threads (other than the thread responsible for collection) have terminated (joins are not necessary).
12//!
13//! ## Usage pattern
14
15//! ```rust
16#![doc = include_str!("../../../examples/tlmrestr_probed_i32_accumulator.rs")]
17//! ```
18
19//!
20//! ## Other examples
21//!
22//! See another example at [`examples/tlmrestr_probed_map_accumulator`](https://github.com/pvillela/rust-thread-local-collect/blob/main/examples/tlmrestr_probed_map_accumulator.rs).
23
24pub use super::control_restr::ControlRestrG;
25
26use super::control_restr::WithTakeTls;
27use crate::tlm::probed::{Control as ControlInner, Holder as HolderInner, Probed};
28
29/// Specialization of [`ControlRestrG`] for this module.
30/// Controls the collection and accumulation of thread-local values linked to this object.
31///
32/// `U` is the type of the accumulated value, i.e. the type returned by [`Control::probe_tls`].
33/// Partially accumulated values are held in thread-locals of type [`Holder<U>`].
34pub type Control<U> = ControlRestrG<Probed<U, Option<U>>, U>; // NOTE(review): `Option<U>` is presumably the inner accumulator type -- confirm against `crate::tlm::probed`
35
36impl<U> WithTakeTls<Probed<U, Option<U>>, U> for Control<U>
37where
38    U: 'static,
39{
40    fn take_tls(control: &ControlInner<U, Option<U>>) {
41        control.take_tls();
42    }
43}
44
45impl<U> Control<U>
46where
47    U: Clone,
48{
49    pub fn probe_tls(&self) -> U {
50        self.control
51            .probe_tls()
52            .expect("accumulator guaranteed to never be None")
53    }
54}
55
56/// Specialization of [`crate::tlm::probed::Holder`] for this module.
57/// Holds thread-local partially accumulated data of type `U` and a smart pointer to a [`Control<U>`],
58/// enabling the linkage of the held data with the control object. Declare it in a `thread_local!` static and pass a reference to that static to `Control::new`.
59pub type Holder<U> = HolderInner<U, Option<U>>;
60
61#[cfg(test)]
62#[allow(clippy::unwrap_used)]
63mod tests {
64    use super::{Control, Holder};
65    use crate::dev_support::{assert_eq_and_println, ThreadGater};
66    use std::{
67        collections::HashMap,
68        fmt::Debug,
69        sync::Mutex,
70        thread::{self, ThreadId},
71    };
72
73    #[derive(Debug, Clone, PartialEq)] // `Debug` for the `{:?}` output in `op`/`op_r`, `PartialEq` for comparing accumulators
74    struct Foo(String); // simple string payload used as the collected value in these tests
75
76    type Data = (i32, Foo); // one `(key, value)` item passed to `aggregate_data`
77
78    type AccValue = HashMap<ThreadId, HashMap<i32, Foo>>; // accumulator: for each thread id, the map of items collected for that id
79
80    fn op(data: Data, acc: &mut AccValue, tid: ThreadId) {
81        println!(
82            "`op` called from {:?} with data {:?}",
83            thread::current().id(),
84            data
85        );
86
87        acc.entry(tid).or_default();
88        let (k, v) = data;
89        acc.get_mut(&tid).unwrap().insert(k, v.clone());
90    }
91
92    fn op_r(acc1: AccValue, acc2: AccValue) -> AccValue {
93        println!(
94            "`op_r` called from {:?} with acc1={:?} and acc2={:?}",
95            thread::current().id(),
96            acc1,
97            acc2
98        );
99
100        let mut acc = acc1;
101        acc2.into_iter().for_each(|(k, v)| {
102            acc.insert(k, v);
103        });
104        acc
105    }
106
107    thread_local! {static MY_TL: Holder<AccValue> = Holder::new();} // thread-local holder used by every test in this module; each test passes `&MY_TL` to `Control::new`
108
109    const NTHREADS: usize = 5; // number of threads spawned by `own_thread_and_explicit_joins_no_probe`
110
111    #[test]
112    fn own_thread_and_explicit_joins_no_probe() {
113        let mut control = Control::new(&MY_TL, HashMap::new, op_r);
114
115        let tid_own = thread::current().id();
116
117        let map_own = {
118            let value1 = Foo("a".to_owned());
119            let value2 = Foo("b".to_owned());
120            let map_own = HashMap::from([(1, value1.clone()), (2, value2.clone())]);
121
122            control.aggregate_data((1, value1), op);
123            control.aggregate_data((2, value2), op);
124
125            map_own
126        };
127
128        let tid_map_pairs = thread::scope(|s| {
129            let hs = (0..NTHREADS)
130                .map(|i| {
131                    let value1 = Foo("a".to_owned() + &i.to_string());
132                    let value2 = Foo("a".to_owned() + &i.to_string());
133                    let map_i = HashMap::from([(1, value1.clone()), (2, value2.clone())]);
134
135                    s.spawn(|| {
136                        control.aggregate_data((1, value1), op);
137                        control.aggregate_data((2, value2), op);
138
139                        let tid_spawned = thread::current().id();
140                        (tid_spawned, map_i)
141                    })
142                })
143                .collect::<Vec<_>>();
144
145            hs.into_iter()
146                .map(|h| h.join().unwrap())
147                .collect::<Vec<_>>()
148        });
149
150        {
151            let map = std::iter::once((tid_own, map_own))
152                .chain(tid_map_pairs)
153                .collect::<HashMap<_, _>>();
154
155            {
156                let acc = control.drain_tls();
157                assert_eq_and_println(&acc, &map, "Accumulator check");
158            }
159
160            // drain_tls again
161            {
162                let acc = control.drain_tls();
163                assert_eq_and_println(&acc, &HashMap::new(), "empty accumulatore expected");
164            }
165        }
166
167        // Control reused.
168        {
169            let map_own = {
170                let value1 = Foo("c".to_owned());
171                let value2 = Foo("d".to_owned());
172                let map_own = HashMap::from([(11, value1.clone()), (22, value2.clone())]);
173
174                control.aggregate_data((11, value1), op);
175                control.aggregate_data((22, value2), op);
176
177                map_own
178            };
179
180            let (tid_spawned, map_spawned) = thread::scope(|s| {
181                let control = &control;
182
183                let value1 = Foo("x".to_owned());
184                let value2 = Foo("y".to_owned());
185                let map_spawned = HashMap::from([(11, value1.clone()), (22, value2.clone())]);
186
187                let tid = s
188                    .spawn(move || {
189                        control.aggregate_data((11, value1), op);
190                        control.aggregate_data((22, value2), op);
191                        thread::current().id()
192                    })
193                    .join()
194                    .unwrap();
195
196                (tid, map_spawned)
197            });
198
199            let map = HashMap::from([(tid_own, map_own), (tid_spawned, map_spawned)]);
200            let acc = control.drain_tls();
201            assert_eq_and_println(&acc, &map, "take_acc - control reused");
202        }
203    }
204
205    #[test]
206    fn own_thread_only_no_probe() {
207        let mut control = Control::new(&MY_TL, HashMap::new, op_r);
208
209        control.aggregate_data((1, Foo("a".to_owned())), op);
210        control.aggregate_data((2, Foo("b".to_owned())), op);
211
212        let tid_own = thread::current().id();
213        let map_own = HashMap::from([(1, Foo("a".to_owned())), (2, Foo("b".to_owned()))]);
214
215        let map = HashMap::from([(tid_own, map_own)]);
216
217        let acc = control.drain_tls();
218        assert_eq_and_println(&acc, &map, "Accumulator check");
219    }
220
221    #[test]
222    fn own_thread_and_explicit_join_with_probe() {
223        let mut control = Control::new(&MY_TL, HashMap::new, op_r);
224
225        let main_tid = thread::current().id();
226        println!("main_tid={:?}", main_tid);
227
228        let main_thread_gater = ThreadGater::new("main");
229        let spawned_thread_gater = ThreadGater::new("spawned");
230
231        let expected_acc_mutex = Mutex::new(HashMap::new());
232
233        let assert_acc = |acc: &AccValue, msg: &str| {
234            let exp = expected_acc_mutex.try_lock().unwrap().clone();
235            assert_eq_and_println(acc, &exp, msg);
236        };
237
238        thread::scope(|s| {
239            let h = s.spawn(|| {
240                let spawned_tid = thread::current().id();
241                println!("spawned tid={:?}", spawned_tid);
242
243                let mut my_map = HashMap::<i32, Foo>::new();
244
245                let mut process_value = |gate: u8, k: i32, v: Foo| {
246                    main_thread_gater.wait_for(gate);
247                    control.aggregate_data((k, v.clone()), op);
248                    my_map.insert(k, v);
249                    expected_acc_mutex
250                        .try_lock()
251                        .unwrap()
252                        .insert(spawned_tid, my_map.clone());
253                    spawned_thread_gater.open(gate);
254                };
255
256                process_value(0, 1, Foo("aa".to_owned()));
257                process_value(1, 2, Foo("bb".to_owned()));
258                process_value(2, 3, Foo("cc".to_owned()));
259                process_value(3, 4, Foo("dd".to_owned()));
260            });
261
262            {
263                control.aggregate_data((1, Foo("a".to_owned())), op);
264                control.aggregate_data((2, Foo("b".to_owned())), op);
265                let my_map = HashMap::from([(1, Foo("a".to_owned())), (2, Foo("b".to_owned()))]);
266
267                let mut map = expected_acc_mutex.try_lock().unwrap();
268                map.insert(main_tid, my_map);
269                let map = map.clone(); // Mutex guard dropped here
270                let acc = control.probe_tls();
271                assert_eq_and_println(
272                    &acc,
273                    &map,
274                    "Accumulator after main thread inserts and probe_tls",
275                );
276                main_thread_gater.open(0);
277            }
278
279            {
280                spawned_thread_gater.wait_for(0);
281                let acc = control.probe_tls();
282                assert_acc(
283                    &acc,
284                    "Accumulator after 1st spawned thread insert and probe_tls",
285                );
286                main_thread_gater.open(1);
287            }
288
289            {
290                spawned_thread_gater.wait_for(1);
291                let acc = control.probe_tls();
292                assert_acc(
293                    &acc,
294                    "Accumulator after 2nd spawned thread insert and take_tls",
295                );
296                main_thread_gater.open(2);
297            }
298
299            {
300                spawned_thread_gater.wait_for(2);
301                let acc = control.probe_tls();
302                assert_acc(
303                    &acc,
304                    "Accumulator after 3rd spawned thread insert and probe_tls",
305                );
306                main_thread_gater.open(3);
307            }
308
309            {
310                // done with thread gaters
311                h.join().unwrap();
312            }
313        });
314
315        {
316            let acc = control.drain_tls();
317            assert_acc(
318                &acc,
319                "Accumulator after 4th spawned thread insert and drain_tls",
320            );
321        }
322
323        // drain_tls again
324        {
325            let acc = control.drain_tls();
326            assert_eq_and_println(&acc, &HashMap::new(), "empty accumulatore expected");
327        }
328    }
329
330    #[test]
331    fn no_thread() {
332        let mut control = Control::new(&MY_TL, HashMap::new, op_r);
333        let acc = control.drain_tls();
334        assert_eq!(acc, HashMap::new(), "empty accumulator expected");
335    }
336}