1#![deny(missing_docs)]
2
3use std::marker::PhantomData;
11use std::thread::{spawn, JoinHandle, Thread};
12use std::mem::{transmute, forget};
13
14#[must_use = "thread will be immediately joined if `JoinGuard` is not used"]
20pub struct JoinGuard<'a, T: Send + 'a> {
21 inner: Option<JoinHandle<BoxedThing>>,
22 _marker: PhantomData<&'a T>,
23}
24
25unsafe impl<'a, T: Send + 'a> Sync for JoinGuard<'a, T> {}
26
27impl<'a, T: Send + 'a> JoinGuard<'a, T> {
28 pub fn thread(&self) -> &Thread {
30 &self.inner.as_ref().unwrap().thread()
31 }
32
33 pub fn join(mut self) -> T {
39 match self.inner.take().unwrap().join() {
40 Ok(res) => unsafe { *res.into_inner() },
41 Err(_) => panic!("child thread {:?} panicked", self.thread()),
42 }
43 }
44}
45
/// Detaches a scoped thread from its guard so that it is no longer joined
/// when the guard goes away.
pub trait ScopedDetach {
    /// Consumes the guard and detaches the child thread.
    ///
    /// In this crate it is implemented only for `JoinGuard<'static, T>`,
    /// because a detached thread may outlive the scope that spawned it.
    fn detach(self);
}
53
54impl<T: Send + 'static> ScopedDetach for JoinGuard<'static, T> {
55 fn detach(mut self) {
56 let _ = self.inner.take();
57 }
58}
59
60impl<'a, T: Send + 'a> Drop for JoinGuard<'a, T> {
61 fn drop(&mut self) {
62 self.inner.take().map(|v| if v.join().is_err() {
63 panic!("child thread {:?} panicked", self.thread());
64 });
65 }
66}
67
68pub unsafe fn scoped<'a, T, F>(f: F) -> JoinGuard<'a, T> where
70 T: Send + 'a, F: FnOnce() -> T, F: Send + 'a
71{
72 let f = BoxedThing::new(f);
73
74 JoinGuard {
75 inner: Some(spawn(move ||
76 BoxedThing::new(f.into_inner::<F>()())
77 )),
78 _marker: PhantomData,
79 }
80}
81
/// A heap-allocated value with its type erased.
///
/// Used to move non-`'static` closures and results across the `'static`
/// boundary that `std::thread::spawn` imposes. The value is either reclaimed
/// with `into_inner` or, if never reclaimed, dropped by `BoxedThing`'s own
/// `Drop`.
struct BoxedThing {
    // Pointer obtained from `Box::into_raw` for the erased type.
    ptr: *mut (),
    // Drops and frees `ptr` as the type it was created with.
    drop_fn: unsafe fn(*mut ()),
}

// SAFETY: `new` only accepts `T: Send`, so the erased value may be moved to,
// and dropped on, another thread.
unsafe impl Send for BoxedThing {}

impl BoxedThing {
    /// Boxes `v` and erases its type.
    fn new<T: Send>(v: T) -> Self {
        /// Reconstitutes and drops a `Box<U>` from an erased pointer.
        unsafe fn drop_erased<U>(ptr: *mut ()) {
            drop(unsafe { Box::from_raw(ptr as *mut U) });
        }

        BoxedThing {
            ptr: Box::into_raw(Box::new(v)) as *mut (),
            drop_fn: drop_erased::<T>,
        }
    }

    /// Reclaims the erased value as a `Box<T>`.
    ///
    /// # Safety
    ///
    /// `T` must be exactly the type that was passed to `BoxedThing::new`.
    unsafe fn into_inner<T>(self) -> Box<T> {
        let ptr = self.ptr as *mut T;
        // Ownership of the allocation moves to the returned `Box`. Skip our own
        // `Drop` so the value is not freed twice.
        forget(self);
        unsafe { Box::from_raw(ptr) }
    }
}

impl Drop for BoxedThing {
    fn drop(&mut self) {
        // SAFETY: `ptr` came from `Box::into_raw` in `new`, and `drop_fn` was
        // instantiated for that same type. `into_inner` forgets `self`, so this
        // only runs for values that were never reclaimed.
        unsafe { (self.drop_fn)(self.ptr) }
    }
}
95
#[cfg(test)]
mod tests {
    use super::scoped;
    use std::thread::sleep;
    use std::time::Duration;

    /// A scoped thread can write to a local of the spawning frame, and the
    /// write is visible once `join` has returned.
    #[test]
    fn test_scoped_stack() {
        let mut counter = 5;
        let guard = unsafe {
            scoped(|| {
                sleep(Duration::from_millis(100));
                counter = 2;
            })
        };
        guard.join();
        assert_eq!(counter, 2);
    }

    /// `join` hands back the value produced by the thread's closure.
    #[test]
    fn test_join_success() {
        let guard = unsafe { scoped(move || -> String { "Success!".to_string() }) };
        assert_eq!(guard.join(), "Success!");
    }

    /// Same as above, but the result is bound before it is inspected.
    #[test]
    fn test_scoped_success() {
        let outcome: String =
            unsafe { scoped(move || -> String { "Success!".to_string() }).join() };
        assert_eq!(outcome, "Success!");
    }

    /// A panic in the child resurfaces when the guard is joined explicitly.
    #[test]
    #[should_panic]
    fn test_scoped_panic() {
        let guard = unsafe { scoped(|| panic!()) };
        guard.join();
    }

    /// A panic in the child resurfaces when an unjoined guard is dropped.
    #[test]
    #[should_panic]
    fn test_scoped_implicit_panic() {
        let _ = unsafe { scoped(|| panic!()) };
    }
}