tokio_evacuate/
lib.rs

1// Copyright (c) 2018 Nuclear Furnace
2//
3// Permission is hereby granted, free of charge, to any person obtaining a copy
4// of this software and associated documentation files (the "Software"), to deal
5// in the Software without restriction, including without limitation the rights
6// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
7// copies of the Software, and to permit persons to whom the Software is
8// furnished to do so, subject to the following conditions:
9//
10// The above copyright notice and this permission notice shall be included in all
11// copies or substantial portions of the Software.
12//
13// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
18// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
19// SOFTWARE.
20//! `tokio-evacuate` provides a way to safely "evacuate" users of a resource before forcefully
21//! removing them.
22//!
//! In many networked applications, there comes a time when the server must shut down or reload,
//! and may still be actively serving traffic.  Listeners or publishers can be shut down, and
//! remaining work can be processed while no new work is allowed, but this may take longer than
//! the operator is comfortable with.
27//!
//! `Evacuate` is a middleware future that works in conjunction with a classic "shutdown signal."
29//! By combining a way to track the number of current users, as well as a way to fire a global
30//! timeout, we allow applications to provide soft shutdown capabilities, giving work a chance to
31//! complete, before forcefully stopping computation.
32//!
33//! `tokio-evacuate` depends on Tokio facilities, and so will not work on other futures executors.
34#[macro_use]
35extern crate futures;
36extern crate parking_lot;
37extern crate slab;
38extern crate tokio_executor;
39extern crate tokio_sync;
40extern crate tokio_timer;
41
42use futures::{future::Fuse, prelude::*};
43use parking_lot::Mutex;
44use slab::Slab;
45use std::{
46    sync::{
47        atomic::{AtomicBool, AtomicUsize, Ordering},
48        Arc,
49    },
50    time::Duration,
51};
52use tokio_executor::{DefaultExecutor, Executor, SpawnError};
53use tokio_sync::task::AtomicTask;
54use tokio_timer::{clock::now as clock_now, Delay};
55
56#[derive(Default)]
57struct Inner {
58    count: AtomicUsize,
59    finished: AtomicBool,
60    notifier: AtomicTask,
61    waiters: Mutex<Slab<Arc<AtomicTask>>>,
62}
63
64/// Dispatcher for user count updates.
65///
66/// [`Warden`] is cloneable.
67#[derive(Clone)]
68pub struct Warden {
69    state: Arc<Inner>,
70}
71
72/// A future for safely "evacuating" a resource that is used by multiple parties.
73///
74/// [`Evacuate`] tracks a tripwire, the count of concurrent users, and an evacuation timeout, and
75/// functions in a two-step process: we must be tripped, and then we race to the timeout.
76///
77/// Until the tripwire completes, [`Evacuate`] will always return `Async::NotReady`.  Once we detect
78/// that the tripwire has completed, however, we immediately spawn a timeout, based on the
79/// configured value, and race between the user count dropping to zero and the timeout firing.
80///
81/// The user count is updated by calls to [`Warden::increment`] and [`Warden::decrement`].
82///
83/// [`Evacuate`] can be cloned, and all clones will become ready at the same time.
84pub struct Evacuate {
85    state: Arc<Inner>,
86    task: Arc<AtomicTask>,
87    waiter_id: usize,
88}
89
90pub struct Runner<F: Future> {
91    state: Arc<Inner>,
92    tripwire: Fuse<F>,
93    timeout_ms: u64,
94    timeout: Delay,
95}
96
97impl Inner {
98    pub fn new() -> Arc<Inner> { Arc::new(Default::default()) }
99
100    pub fn increment(&self) { self.count.fetch_add(1, Ordering::SeqCst); }
101
102    pub fn decrement(&self) {
103        if self.count.fetch_sub(1, Ordering::SeqCst) == 1 {
104            self.notifier.notify();
105        }
106    }
107
108    pub fn register(&self, waiter: Arc<AtomicTask>) -> usize {
109        let mut waiters = self.waiters.lock();
110        waiters.insert(waiter)
111    }
112
113    pub fn unregister(&self, waiter_id: usize) {
114        let mut waiters = self.waiters.lock();
115        let _ = waiters.remove(waiter_id);
116    }
117
118    pub fn notify(&self) {
119        self.finished.store(true, Ordering::SeqCst);
120
121        let waiters = self.waiters.lock();
122        for waiter in waiters.iter() {
123            waiter.1.notify();
124        }
125    }
126}
127
128impl Warden {
129    /// Increments the user count.
130    pub fn increment(&self) { self.state.increment(); }
131
132    /// Decrements the user count.
133    pub fn decrement(&self) { self.state.decrement(); }
134}
135
136impl<F: Future> Runner<F> {
137    pub(crate) fn new(tripwire: F, timeout_ms: u64, state: Arc<Inner>) -> Runner<F> {
138        Runner {
139            state,
140            tripwire: tripwire.fuse(),
141            timeout_ms,
142            timeout: Delay::new(clock_now()),
143        }
144    }
145}
146
147impl<F: Future> Future for Runner<F> {
148    type Error = ();
149    type Item = ();
150
151    fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
152        self.state.notifier.register();
153
154        // We have to wait for our tripwire.
155        if !self.tripwire.is_done() {
156            let _ = try_ready!(self.tripwire.poll().map_err(|_| ()));
157
158            // If we're here, reset our delay based on the timeout.
159            self.timeout.reset(clock_now() + Duration::from_millis(self.timeout_ms));
160        }
161
162        // We've tripped, so let's see what we're at for count.  If we're at zero, then we're done,
163        // otherwise, fall through and see if we've hit our delay yet.
164        if self.state.count.load(Ordering::SeqCst) == 0 {
165            // We've tripped and we're at count 0, so we're done.  Notify waiters.
166            self.state.notify();
167            return Ok(Async::Ready(()));
168        }
169
170        // Our count isn't at zero, but let's see if we've timed out yet.
171        try_ready!(self.timeout.poll().map_err(|_| ()));
172
173        // We timed out, so mark ourselves finished and notify.
174        self.state.notify();
175
176        Ok(Async::Ready(()))
177    }
178}
179
180impl Evacuate {
181    /// Creates a new [`Evacuate`].
182    ///
183    /// The given `tripwire` is used, and the internal timeout is set to the value of `timeout_ms`.
184    ///
185    /// Returns a [`Warden`] handle, used for incrementing and decrementing the user count, an
186    /// [`Evacuate`] future, which callers can select/poll against directly, and a [`Runner`]
187    /// future which must be spawned manually to drive the inner behavior of [`Evacuate`].
188    ///
189    /// If you're using Tokio, you can call [`default_executor`] to spawn the runner on the default
190    /// executor.
191    pub fn new<F>(tripwire: F, timeout_ms: u64) -> (Warden, Evacuate, Runner<F>)
192    where
193        F: Future + Send + 'static,
194    {
195        let state = Inner::new();
196        let warden = Warden { state: state.clone() };
197
198        let task = Arc::new(AtomicTask::new());
199        let waiter_id = state.register(task.clone());
200
201        let evacuate = Evacuate {
202            state: state.clone(),
203            task,
204            waiter_id,
205        };
206
207        let runner = Runner::new(tripwire, timeout_ms, state);
208
209        (warden, evacuate, runner)
210    }
211
212    /// Creates a new [`Evacuate`], based on the default executor.
213    ///
214    /// The given `tripwire` is used, and the internal timeout is set to the value of `timeout_ms`.
215    ///
216    /// Returns a [`Warden`] handle, used for incrementing and decrementing the user count, and an
217    /// [`Evacuate`] future, which callers can select/poll against directly.
218    ///
219    /// This functions spawns a background task on the default executor which drives the state
220    /// machine powering [`Evacuate`].  This function must be called from within a running task.
221    pub fn default_executor<F>(tripwire: F, timeout_ms: u64) -> Result<(Warden, Evacuate), SpawnError>
222    where
223        F: Future + Send + 'static,
224    {
225        let (warden, evacuate, runner) = Self::new(tripwire, timeout_ms);
226
227        DefaultExecutor::current()
228            .spawn(Box::new(runner))
229            .map(move |_| (warden, evacuate))
230    }
231}
232
233impl Drop for Evacuate {
234    fn drop(&mut self) { self.state.unregister(self.waiter_id); }
235}
236
237impl Clone for Evacuate {
238    fn clone(&self) -> Self {
239        let state = self.state.clone();
240        let task = Arc::new(AtomicTask::new());
241        let waiter_id = state.register(task.clone());
242
243        Evacuate { state, task, waiter_id }
244    }
245}
246
247impl Future for Evacuate {
248    type Error = ();
249    type Item = ();
250
251    fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
252        self.task.register();
253
254        if !self.state.finished.load(Ordering::SeqCst) {
255            Ok(Async::NotReady)
256        } else {
257            Ok(Async::Ready(()))
258        }
259    }
260}
261
#[cfg(test)]
mod tests {
    #[macro_use]
    mod support;
    use self::support::*;

    use super::Evacuate;

    use futures::{
        future::{empty, ok},
        Future,
    };

    /// Evacuation window used by every test; it only elapses when the mocked timer is advanced.
    const TIMEOUT_MS: u64 = 10_000;

    #[test]
    fn test_evacuate_stops_at_tripwire() {
        mocked(|_, _| {
            // A tripwire that never fires keeps the evacuation pending.
            let (_warden, mut evac, _driver) = Evacuate::new(empty::<(), ()>(), TIMEOUT_MS);
            assert_not_ready!(evac);
        });
    }

    #[test]
    fn test_evacuate_falls_through_on_tripwire() {
        mocked(|_, _| {
            // Tripped with no users: the runner finishes immediately and releases the waiter.
            let (_warden, mut evac, mut driver) = Evacuate::new(ok::<(), ()>(()), TIMEOUT_MS);
            assert_not_ready!(evac);
            assert_ready!(driver);
            assert_ready!(evac);
        });
    }

    #[test]
    fn test_evacuate_stops_after_tripping_with_clients() {
        mocked(|_, _| {
            // Tripped, but a user is still active and the timeout has not elapsed.
            let (warden, mut evac, mut driver) = Evacuate::new(ok::<(), ()>(()), TIMEOUT_MS);
            assert_not_ready!(evac);
            warden.increment();

            assert_not_ready!(driver);
            assert_not_ready!(evac);
        });
    }

    #[test]
    fn test_evacuate_completes_after_client_count_ping_pong() {
        mocked(|_, _| {
            let (warden, mut evac, mut driver) = Evacuate::new(ok::<(), ()>(()), TIMEOUT_MS);

            // Two users arrive; nothing can complete while they are present.
            warden.increment();
            assert_not_ready!(driver);
            assert_not_ready!(evac);
            warden.increment();
            assert_not_ready!(driver);
            assert_not_ready!(evac);

            // Both leave, so the evacuation completes before the timeout.
            warden.decrement();
            warden.decrement();
            assert_ready!(driver);
            assert_ready!(evac);
        });
    }

    #[test]
    fn test_evacuate_delay_before_clients_hit_zero() {
        mocked(|timer, _| {
            let (warden, mut evac, mut driver) = Evacuate::new(ok::<(), ()>(()), TIMEOUT_MS);

            warden.increment();
            assert_not_ready!(driver);
            assert_not_ready!(evac);
            warden.increment();
            assert_not_ready!(driver);
            assert_not_ready!(evac);

            // One user remains, so only the timeout can finish the evacuation.
            warden.decrement();
            assert_not_ready!(driver);
            assert_not_ready!(evac);

            advance(timer, ms(10_001));
            assert_ready!(driver);
            assert_ready!(evac);
        });
    }
}