Skip to main content

wasmtime_internal_core/
undo.rs

1//! Helpers for undoing partial side effects when their larger operation fails.
2
3use core::{fmt, mem, ops};
4
/// An RAII guard to roll back and undo something on (early) drop.
///
/// The guard owns an inner `T` and dereferences to it, so the wrapped value can
/// be used and mutated while work is in progress. If the guard is dropped
/// without being committed (for example because `?` returned early, or a
/// panic is unwinding), its undo function is called with the inner `T`.
///
/// When all of the changes that need to happen together have happened, call
/// [`Undo::commit`] to disable the guard, commit the associated side effects,
/// and get the inner value back.
///
/// # Example
///
/// ```
/// use std::cell::Cell;
/// use wasmtime_internal_core::{error::Result, undo::Undo};
///
/// /// Some big ball of state that must always be coherent.
/// pub struct Context {
///     // ...
/// }
///
/// impl Context {
///     /// Perform some incremental mutation to `self`, which might not leave
///     /// it in a valid state unless its whole batch of work is completed.
///     fn do_thing(&mut self, arg: u32) -> Result<()> {
/// #       let _ = arg;
/// #       todo!()
///         // ...
///     }
///
///     /// Undo the side effects of `self.do_thing(arg)` for when we need to
///     /// roll back mutations.
///     fn undo_thing(&mut self, arg: u32) {
/// #       let _ = arg;
///         // ...
///     }
///
///     /// Call `self.do_thing(arg)` for each `arg` in `args`.
///     ///
///     /// However, if any `self.do_thing(arg)` call fails, make sure that
///     /// we roll back to the original state by calling `self.undo_thing(arg)`
///     /// for all the `self.do_thing(arg)` calls that already succeeded. This
///     /// way we never leave `self` in a state where things got half-done.
///     pub fn do_all_or_nothing(&mut self, args: &[u32]) -> Result<()> {
///         // Counter for our progress, so that we know how much work to undo
///         // upon failure.
///         let num_things_done = Cell::new(0);
///
///         // Wrap the `Context` in an `Undo` that rolls back our side effects if
///         // we early-exit this function via `?`-propagation or panic unwinding.
///         let mut ctx = Undo::new(self, |ctx| {
///             for arg in args.iter().take(num_things_done.get()) {
///                 ctx.undo_thing(*arg);
///             }
///         });
///
///         // Do each piece of work!
///         for arg in args {
///             // Note: if this call returns an error that is `?`-propagated or
///             // triggers unwinding by panicking, then the work performed thus
///             // far will be rolled back when `ctx` is dropped.
///             ctx.do_thing(*arg)?;
///
///             // Update how much work has been completed.
///             num_things_done.set(num_things_done.get() + 1);
///         }
///
///         // We completed all of the work, so commit the `Undo` guard and
///         // disable its cleanup function.
///         Undo::commit(ctx);
///
///         Ok(())
///     }
/// }
/// ```
#[must_use = "`Undo` implicitly runs its undo function on drop; use `Undo::commit(...)` \
              to disable"]
pub struct Undo<T, F>
where
    // The bound must live on the struct itself: the `Drop` impl below calls
    // `undo(inner)`, and a `Drop` impl may not require bounds the type lacks.
    F: FnOnce(T),
{
    // Both fields are `ManuallyDrop` so that `Drop::drop` and `Undo::commit`
    // can move them out by value exactly once.
    inner: mem::ManuallyDrop<T>,
    undo: mem::ManuallyDrop<F>,
}

impl<T, F> Drop for Undo<T, F>
where
    F: FnOnce(T),
{
    // Reaching this destructor means the guard was not committed
    // (`Undo::commit` suppresses it), so roll back by calling `undo(inner)`.
    fn drop(&mut self) {
        // SAFETY: `drop` runs at most once and `self` is never accessed again
        // afterwards, so moving `inner` out cannot lead to a use-after-move or
        // a double drop.
        let inner = unsafe { mem::ManuallyDrop::take(&mut self.inner) };
        // SAFETY: same reasoning as for `inner`; `undo` is not touched again.
        let undo = unsafe { mem::ManuallyDrop::take(&mut self.undo) };
        undo(inner);
    }
}

impl<T, F> fmt::Debug for Undo<T, F>
where
    F: FnOnce(T),
    T: fmt::Debug,
{
    // Closures are generally not `Debug`, so the undo function is rendered as
    // a placeholder and only the inner value is formatted.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("Undo")
            .field("inner", &self.inner)
            .field("undo", &"..")
            .finish()
    }
}

impl<T, F> ops::Deref for Undo<T, F>
where
    F: FnOnce(T),
{
    type Target = T;

    /// Borrow the guarded inner value.
    fn deref(&self) -> &Self::Target {
        &self.inner
    }
}

impl<T, F> ops::DerefMut for Undo<T, F>
where
    F: FnOnce(T),
{
    /// Mutably borrow the guarded inner value.
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.inner
    }
}

impl<T, F> Undo<T, F>
where
    F: FnOnce(T),
{
    /// Create a new `Undo` guard.
    ///
    /// The guard wraps the given `inner` object and calls `undo(inner)` when
    /// dropped, unless it is disabled first via [`Undo::commit`].
    pub fn new(inner: T, undo: F) -> Self {
        Self {
            inner: mem::ManuallyDrop::new(inner),
            undo: mem::ManuallyDrop::new(undo),
        }
    }

    /// Disable this `Undo` and return its inner value.
    ///
    /// The cleanup function is never called. It is dropped instead, so
    /// anything it captured (an `Arc`, for example) is released rather than
    /// leaked.
    ///
    /// This takes the guard as an ordinary argument rather than `self`, so it
    /// is invoked as `Undo::commit(guard)` and never shadows a method of `T`
    /// reached through `Deref`.
    pub fn commit(guard: Self) -> T {
        // Suppress `Undo`'s own `Drop` impl: we are committing, not rolling
        // back.
        let mut guard = mem::ManuallyDrop::new(guard);

        // Move `inner` out *before* dropping `undo`. If the closure's
        // destructor panics, `inner` is then an ordinary local that gets
        // dropped during unwinding, instead of being leaked inside `guard`.
        //
        // SAFETY: `guard` is wrapped in `ManuallyDrop`, so `Undo::drop` never
        // runs for it, and `inner` is not accessed again after this take.
        let inner = unsafe { mem::ManuallyDrop::take(&mut guard.inner) };

        // SAFETY: `undo` is dropped exactly once here and never used again;
        // `guard`'s own destructor is suppressed, so there is no double drop.
        unsafe { mem::ManuallyDrop::drop(&mut guard.undo) };

        inner
    }
}
165
#[cfg(all(test, feature = "std"))]
mod tests {
    use super::*;
    use crate::error::{Result, ensure};
    use core::{cell::Cell, cmp};
    use std::{panic, rc::Rc, string::ToString};

    /// A counter whose `inc_n` must either fully succeed or leave `value`
    /// unchanged. `max_value_seen` records the high-water mark, so tests can
    /// confirm that work really happened before it was rolled back.
    #[derive(Default)]
    struct Counter {
        value: u32,
        max_value_seen: u32,
    }

    impl Counter {
        /// Run the check `f`, then increment `value` and update the
        /// high-water mark. Nothing is mutated if `f` fails.
        fn inc(&mut self, mut f: impl FnMut(&Self) -> Result<()>) -> Result<()> {
            f(self)?;
            self.value += 1;
            self.max_value_seen = cmp::max(self.max_value_seen, self.value);
            Ok(())
        }

        /// Inverse of a successful `inc`.
        fn dec(&mut self) {
            self.value -= 1;
        }

        /// Increment `n` times, all-or-nothing: on error or panic, every
        /// increment that already succeeded is undone via an `Undo` guard.
        fn inc_n(&mut self, n: u32, mut f: impl FnMut(&Self) -> Result<()>) -> Result<()> {
            let i = Cell::new(0);

            let mut counter = Undo::new(self, |counter| {
                for _ in 0..i.get() {
                    counter.dec();
                }
            });

            for _ in 0..n {
                counter.inc(&mut f)?;
                i.set(i.get() + 1);
            }

            Undo::commit(counter);
            Ok(())
        }
    }

    #[test]
    fn error_propagation() {
        let mut counter = Counter::default();
        let result = counter.inc_n(10, |c| {
            ensure!(c.value < 5, "uh oh");
            Ok(())
        });
        assert_eq!(result.unwrap_err().to_string(), "uh oh");
        assert_eq!(counter.value, 0);
        assert_eq!(counter.max_value_seen, 5);
    }

    #[test]
    fn panic_unwind() {
        let mut counter = Counter::default();
        let result = panic::catch_unwind(panic::AssertUnwindSafe(|| {
            counter.inc_n(10, |c| {
                assert!(c.value < 5);
                Ok(())
            })
        }));
        assert!(result.is_err());
        assert_eq!(counter.value, 0);
        assert_eq!(counter.max_value_seen, 5);
    }

    #[test]
    fn commit() {
        let mut counter = Counter::default();
        let result = counter.inc_n(10, |_| Ok(()));
        assert!(result.is_ok());
        assert_eq!(counter.value, 10);
        assert_eq!(counter.max_value_seen, 10);
    }

    /// Dropping an uncommitted guard hands the (possibly mutated) inner value
    /// to the undo function.
    #[test]
    fn drop_without_commit_runs_undo() {
        let undone = Cell::new(None);
        {
            let mut guard = Undo::new(40, |v| undone.set(Some(v)));
            *guard += 2;
        }
        assert_eq!(undone.get(), Some(42));
    }

    /// `Undo::commit` returns the inner value and never calls the undo
    /// function.
    #[test]
    fn commit_returns_inner_without_running_undo() {
        let ran = Cell::new(false);
        let guard = Undo::new(7_u32, |_| ran.set(true));
        assert_eq!(Undo::commit(guard), 7);
        assert!(!ran.get());
    }

    /// `Undo::commit` drops the undo closure, releasing anything it captured
    /// instead of leaking it.
    #[test]
    fn commit_drops_undo_closure() {
        let shared = Rc::new(());
        let captured = Rc::clone(&shared);
        let guard = Undo::new(0_u8, move |_| drop(captured));
        assert_eq!(Rc::strong_count(&shared), 2);
        Undo::commit(guard);
        assert_eq!(Rc::strong_count(&shared), 1);
    }
}