wasmtime_internal_core/undo.rs
1//! Helpers for undoing partial side effects when their larger operation fails.
2
3use core::{fmt, mem, ops};
4
/// An RAII guard that rolls back partial work when it is dropped early.
///
/// The guard owns a value of type `T` together with an undo function. It
/// dereferences to the wrapped `T`, and when the guard is dropped the undo
/// function is invoked with that `T`.
///
/// Once every change that has to happen together has succeeded, call
/// `Undo::commit` to disarm the guard and keep the associated side effects.
///
/// # Example
///
/// ```
/// use std::cell::Cell;
/// use wasmtime_internal_core::{error::Result, undo::Undo};
///
/// /// Some big ball of state that must always be coherent.
/// pub struct Context {
///     // ...
/// }
///
/// impl Context {
///     /// Perform some incremental mutation to `self`, which might not leave
///     /// it in a valid state unless its whole batch of work is completed.
///     fn do_thing(&mut self, arg: u32) -> Result<()> {
/// #       let _ = arg;
/// #       todo!()
///         // ...
///     }
///
///     /// Undo the side effects of `self.do_thing(arg)` for when we need to
///     /// roll back mutations.
///     fn undo_thing(&mut self, arg: u32) {
/// #       let _ = arg;
///         // ...
///     }
///
///     /// Call `self.do_thing(arg)` for each `arg` in `args`.
///     ///
///     /// However, if any `self.do_thing(arg)` call fails, make sure that
///     /// we roll back to the original state by calling `self.undo_thing(arg)`
///     /// for all the `self.do_thing(arg)` calls that already succeeded. This
///     /// way we never leave `self` in a state where things got half-done.
///     pub fn do_all_or_nothing(&mut self, args: &[u32]) -> Result<()> {
///         // Counter for our progress, so that we know how much work to undo
///         // upon failure.
///         let num_things_done = Cell::new(0);
///
///         // Wrap the `Context` in an `Undo` that rolls back our side effects if
///         // we early-exit this function via `?`-propagation or panic unwinding.
///         let mut ctx = Undo::new(self, |ctx| {
///             for arg in args.iter().take(num_things_done.get()) {
///                 ctx.undo_thing(*arg);
///             }
///         });
///
///         // Do each piece of work!
///         for arg in args {
///             // Note: if this call returns an error that is `?`-propagated or
///             // triggers unwinding by panicking, then the work performed thus
///             // far will be rolled back when `ctx` is dropped.
///             ctx.do_thing(*arg)?;
///
///             // Update how much work has been completed.
///             num_things_done.set(num_things_done.get() + 1);
///         }
///
///         // We completed all of the work, so commit the `Undo` guard and
///         // disable its cleanup function.
///         Undo::commit(ctx);
///
///         Ok(())
///     }
/// }
/// ```
#[must_use = "`Undo` implicitly runs its undo function on drop; use `Undo::commit(...)` \
              to disable"]
// The `FnOnce(T)` bound has to live on the type itself (not only on the impl
// blocks) because the `Drop` impl calls `undo(inner)` and a `Drop` impl may
// not carry bounds that the type definition lacks.
pub struct Undo<T, F: FnOnce(T)> {
    /// The guarded value. Held in `ManuallyDrop` so that it can be moved out
    /// by value, either into the undo function or back to the caller.
    inner: mem::ManuallyDrop<T>,
    /// The rollback function. Held in `ManuallyDrop` so that it can be moved
    /// out and called by value, or dropped without being called.
    undo: mem::ManuallyDrop<F>,
}
88
89impl<T, F> Drop for Undo<T, F>
90where
91 F: FnOnce(T),
92{
93 fn drop(&mut self) {
94 // Safety: These `ManuallyDrop` fields will not be used again.
95 let inner = unsafe { mem::ManuallyDrop::take(&mut self.inner) };
96 let undo = unsafe { mem::ManuallyDrop::take(&mut self.undo) };
97 undo(inner);
98 }
99}
100
101impl<T, F> fmt::Debug for Undo<T, F>
102where
103 F: FnOnce(T),
104 T: fmt::Debug,
105{
106 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
107 f.debug_struct("Undo")
108 .field("inner", &self.inner)
109 .field("undo", &"..")
110 .finish()
111 }
112}
113
114impl<T, F> ops::Deref for Undo<T, F>
115where
116 F: FnOnce(T),
117{
118 type Target = T;
119
120 fn deref(&self) -> &Self::Target {
121 &self.inner
122 }
123}
124
125impl<T, F> ops::DerefMut for Undo<T, F>
126where
127 F: FnOnce(T),
128{
129 fn deref_mut(&mut self) -> &mut Self::Target {
130 &mut self.inner
131 }
132}
133
134impl<T, F> Undo<T, F>
135where
136 F: FnOnce(T),
137{
138 /// Create a new `Undo` guard.
139 ///
140 /// This guard will wrap the given `inner` object and call `undo(inner)`
141 /// when dropped, unless the guard is disabled via `Undo::commit`.
142 pub fn new(inner: T, undo: F) -> Self {
143 Self {
144 inner: mem::ManuallyDrop::new(inner),
145 undo: mem::ManuallyDrop::new(undo),
146 }
147 }
148
149 /// Disable this `Undo` and return its inner value.
150 ///
151 /// This `Undo`'s cleanup function will never be called.
152 pub fn commit(guard: Self) -> T {
153 let mut guard = mem::ManuallyDrop::new(guard);
154
155 // Safety: These `ManuallyDrop` fields will not be used again.
156 unsafe {
157 // Make sure to drop `undo`, even though we aren't calling it, to
158 // avoid leaking closed-over `Arc`s, for example.
159 mem::ManuallyDrop::drop(&mut guard.undo);
160
161 mem::ManuallyDrop::take(&mut guard.inner)
162 }
163 }
164}
165
#[cfg(all(test, feature = "std"))]
mod tests {
    use super::*;
    use crate::error::{Result, ensure};
    use core::cell::Cell;
    use std::{panic, string::ToString};

    /// Test fixture: a counter that also remembers the highest value it ever
    /// reached, so tests can tell "never incremented" apart from "incremented
    /// and then rolled back".
    #[derive(Default)]
    struct Counter {
        value: u32,
        max_value_seen: u32,
    }

    impl Counter {
        /// Run `f` against the current state and, if it succeeds, bump the
        /// counter and update the high-water mark.
        fn inc(&mut self, mut f: impl FnMut(&Self) -> Result<()>) -> Result<()> {
            f(&*self)?;
            self.value += 1;
            if self.value > self.max_value_seen {
                self.max_value_seen = self.value;
            }
            Ok(())
        }

        /// Reverse the effect of one successful `inc` on `value`.
        fn dec(&mut self) {
            self.value -= 1;
        }

        /// Call `inc` `n` times; if any call fails or panics, an `Undo` guard
        /// calls `dec` once for every increment that had already succeeded.
        fn inc_n(&mut self, n: u32, mut f: impl FnMut(&Self) -> Result<()>) -> Result<()> {
            // Number of increments that have succeeded so far.
            let completed = Cell::new(0);

            let mut guard = Undo::new(self, |this| {
                for _ in 0..completed.get() {
                    this.dec();
                }
            });

            for _ in 0..n {
                guard.inc(&mut f)?;
                completed.set(completed.get() + 1);
            }

            // Everything succeeded: keep the increments.
            Undo::commit(guard);
            Ok(())
        }
    }

    /// An error returned part-way through rolls the counter back to zero.
    #[test]
    fn error_propagation() {
        let mut counter = Counter::default();
        let err = counter
            .inc_n(10, |c| {
                ensure!(c.value < 5, "uh oh");
                Ok(())
            })
            .unwrap_err();
        assert_eq!(err.to_string(), "uh oh");
        assert_eq!(counter.value, 0);
        assert_eq!(counter.max_value_seen, 5);
    }

    /// A panic part-way through rolls the counter back to zero while
    /// unwinding.
    #[test]
    fn panic_unwind() {
        let mut counter = Counter::default();
        let outcome = panic::catch_unwind(panic::AssertUnwindSafe(|| {
            counter.inc_n(10, |c| {
                assert!(c.value < 5);
                Ok(())
            })
        }));
        assert!(outcome.is_err());
        assert_eq!(counter.value, 0);
        assert_eq!(counter.max_value_seen, 5);
    }

    /// When every step succeeds the guard is committed and nothing is undone.
    #[test]
    fn commit() {
        let mut counter = Counter::default();
        let outcome = counter.inc_n(10, |_| Ok(()));
        assert!(outcome.is_ok());
        assert_eq!(counter.value, 10);
        assert_eq!(counter.max_value_seen, 10);
    }
}