1#[macro_use]
35extern crate futures;
36extern crate parking_lot;
37extern crate slab;
38extern crate tokio_executor;
39extern crate tokio_sync;
40extern crate tokio_timer;
41
42use futures::{future::Fuse, prelude::*};
43use parking_lot::Mutex;
44use slab::Slab;
45use std::{
46 sync::{
47 atomic::{AtomicBool, AtomicUsize, Ordering},
48 Arc,
49 },
50 time::Duration,
51};
52use tokio_executor::{DefaultExecutor, Executor, SpawnError};
53use tokio_sync::task::AtomicTask;
54use tokio_timer::{clock::now as clock_now, Delay};
55
/// Shared state behind a `Warden`, its `Evacuate` handles and the `Runner`.
///
/// All three hold an `Arc<Inner>`; the default value is "zero clients, not finished,
/// no waiters".
#[derive(Default)]
struct Inner {
    /// Number of active clients, adjusted through `Warden::increment`/`decrement`.
    count: AtomicUsize,
    /// Set to `true` by `Inner::notify` once evacuation is complete; read by `Evacuate::poll`.
    finished: AtomicBool,
    /// Task woken when `count` transitions to zero (registered by `Runner::poll`).
    notifier: AtomicTask,
    /// Tasks of live `Evacuate` handles, keyed by the id returned from `Inner::register`.
    waiters: Mutex<Slab<Arc<AtomicTask>>>,
}
63
/// Client-tracking handle.
///
/// Call `increment` when a client starts and `decrement` when it ends; evacuation
/// completes early once the count returns to zero after the tripwire has fired.
/// Cloning shares the same underlying counter.
#[derive(Clone)]
pub struct Warden {
    state: Arc<Inner>,
}
71
/// Future that resolves once evacuation has finished.
///
/// Evacuation finishes when either all clients are gone or the grace period expires.
/// Every instance (including clones) owns its own waiter slot in the shared state.
pub struct Evacuate {
    state: Arc<Inner>,
    // This handle's own wake-up task, also stored in `state.waiters`.
    task: Arc<AtomicTask>,
    // Slab key of `task` inside `state.waiters`; released in `Drop`.
    waiter_id: usize,
}
89
/// Background future that watches the tripwire, then waits for clients to drain
/// or for the grace period to elapse, and finally marks evacuation as finished.
pub struct Runner<F: Future> {
    state: Arc<Inner>,
    // Fused so it can be checked with `is_done` on subsequent polls.
    tripwire: Fuse<F>,
    // Grace period, in milliseconds, armed once the tripwire fires.
    timeout_ms: u64,
    // Deadline for the grace period; only meaningful after the tripwire has fired.
    timeout: Delay,
}
96
97impl Inner {
98 pub fn new() -> Arc<Inner> { Arc::new(Default::default()) }
99
100 pub fn increment(&self) { self.count.fetch_add(1, Ordering::SeqCst); }
101
102 pub fn decrement(&self) {
103 if self.count.fetch_sub(1, Ordering::SeqCst) == 1 {
104 self.notifier.notify();
105 }
106 }
107
108 pub fn register(&self, waiter: Arc<AtomicTask>) -> usize {
109 let mut waiters = self.waiters.lock();
110 waiters.insert(waiter)
111 }
112
113 pub fn unregister(&self, waiter_id: usize) {
114 let mut waiters = self.waiters.lock();
115 let _ = waiters.remove(waiter_id);
116 }
117
118 pub fn notify(&self) {
119 self.finished.store(true, Ordering::SeqCst);
120
121 let waiters = self.waiters.lock();
122 for waiter in waiters.iter() {
123 waiter.1.notify();
124 }
125 }
126}
127
128impl Warden {
129 pub fn increment(&self) { self.state.increment(); }
131
132 pub fn decrement(&self) { self.state.decrement(); }
134}
135
136impl<F: Future> Runner<F> {
137 pub(crate) fn new(tripwire: F, timeout_ms: u64, state: Arc<Inner>) -> Runner<F> {
138 Runner {
139 state,
140 tripwire: tripwire.fuse(),
141 timeout_ms,
142 timeout: Delay::new(clock_now()),
143 }
144 }
145}
146
147impl<F: Future> Future for Runner<F> {
148 type Error = ();
149 type Item = ();
150
151 fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
152 self.state.notifier.register();
153
154 if !self.tripwire.is_done() {
156 let _ = try_ready!(self.tripwire.poll().map_err(|_| ()));
157
158 self.timeout.reset(clock_now() + Duration::from_millis(self.timeout_ms));
160 }
161
162 if self.state.count.load(Ordering::SeqCst) == 0 {
165 self.state.notify();
167 return Ok(Async::Ready(()));
168 }
169
170 try_ready!(self.timeout.poll().map_err(|_| ()));
172
173 self.state.notify();
175
176 Ok(Async::Ready(()))
177 }
178}
179
180impl Evacuate {
181 pub fn new<F>(tripwire: F, timeout_ms: u64) -> (Warden, Evacuate, Runner<F>)
192 where
193 F: Future + Send + 'static,
194 {
195 let state = Inner::new();
196 let warden = Warden { state: state.clone() };
197
198 let task = Arc::new(AtomicTask::new());
199 let waiter_id = state.register(task.clone());
200
201 let evacuate = Evacuate {
202 state: state.clone(),
203 task,
204 waiter_id,
205 };
206
207 let runner = Runner::new(tripwire, timeout_ms, state);
208
209 (warden, evacuate, runner)
210 }
211
212 pub fn default_executor<F>(tripwire: F, timeout_ms: u64) -> Result<(Warden, Evacuate), SpawnError>
222 where
223 F: Future + Send + 'static,
224 {
225 let (warden, evacuate, runner) = Self::new(tripwire, timeout_ms);
226
227 DefaultExecutor::current()
228 .spawn(Box::new(runner))
229 .map(move |_| (warden, evacuate))
230 }
231}
232
233impl Drop for Evacuate {
234 fn drop(&mut self) { self.state.unregister(self.waiter_id); }
235}
236
237impl Clone for Evacuate {
238 fn clone(&self) -> Self {
239 let state = self.state.clone();
240 let task = Arc::new(AtomicTask::new());
241 let waiter_id = state.register(task.clone());
242
243 Evacuate { state, task, waiter_id }
244 }
245}
246
247impl Future for Evacuate {
248 type Error = ();
249 type Item = ();
250
251 fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
252 self.task.register();
253
254 if !self.state.finished.load(Ordering::SeqCst) {
255 Ok(Async::NotReady)
256 } else {
257 Ok(Async::Ready(()))
258 }
259 }
260}
261
#[cfg(test)]
mod tests {
    // NOTE(review): `support` is a test-helper module not visible here; it presumably
    // provides `mocked` (a mocked timer/clock harness), `advance`, `ms`, and the
    // `assert_ready!` / `assert_not_ready!` polling macros — confirm against the file.
    #[macro_use]
    mod support;
    use self::support::*;

    use super::Evacuate;

    use futures::{
        future::{empty, ok},
        Future,
    };

    /// A tripwire that never resolves keeps `Evacuate` pending.
    #[test]
    fn test_evacuate_stops_at_tripwire() {
        mocked(|_, _| {
            let tripwire = empty::<(), ()>();
            let (_warden, mut evacuate, _runner) = Evacuate::new(tripwire, 10000);
            assert_not_ready!(evacuate);
        });
    }

    /// With an already-fired tripwire and no clients, one runner poll finishes evacuation.
    #[test]
    fn test_evacuate_falls_through_on_tripwire() {
        mocked(|_, _| {
            let tripwire = ok::<(), ()>(());
            let (_warden, mut evacuate, mut runner) = Evacuate::new(tripwire, 10000);
            // Nothing happens until the runner is polled.
            assert_not_ready!(evacuate);
            assert_ready!(runner);
            assert_ready!(evacuate);
        });
    }

    /// An outstanding client holds the runner (and therefore `Evacuate`) pending.
    #[test]
    fn test_evacuate_stops_after_tripping_with_clients() {
        mocked(|_, _| {
            let tripwire = ok::<(), ()>(());
            let (warden, mut evacuate, mut runner) = Evacuate::new(tripwire, 10000);
            assert_not_ready!(evacuate);
            warden.increment();

            assert_not_ready!(runner);
            assert_not_ready!(evacuate);
        });
    }

    /// Once every client has decremented back to zero, evacuation completes without the timeout.
    #[test]
    fn test_evacuate_completes_after_client_count_ping_pong() {
        mocked(|_, _| {
            let tripwire = ok::<(), ()>(());
            let (warden, mut evacuate, mut runner) = Evacuate::new(tripwire, 10000);
            warden.increment();
            assert_not_ready!(runner);
            assert_not_ready!(evacuate);
            warden.increment();
            assert_not_ready!(runner);
            assert_not_ready!(evacuate);
            warden.decrement();
            warden.decrement();
            assert_ready!(runner);
            assert_ready!(evacuate);
        });
    }

    /// With a client still active, evacuation completes only after the 10000 ms grace period.
    #[test]
    fn test_evacuate_delay_before_clients_hit_zero() {
        mocked(|timer, _| {
            let tripwire = ok::<(), ()>(());
            let (warden, mut evacuate, mut runner) = Evacuate::new(tripwire, 10000);
            warden.increment();
            assert_not_ready!(runner);
            assert_not_ready!(evacuate);
            warden.increment();
            assert_not_ready!(runner);
            assert_not_ready!(evacuate);
            warden.decrement();
            assert_not_ready!(runner);
            assert_not_ready!(evacuate);
            // One client remains; only moving the mocked clock past the deadline releases it.
            advance(timer, ms(10001));
            assert_ready!(runner);
            assert_ready!(evacuate);
        });
    }
}