Skip to main content

salsa/
update.rs

1#![allow(clippy::undocumented_unsafe_blocks)] // TODO(#697) document safety
2
3use std::collections::{BTreeMap, BTreeSet, HashMap, HashSet};
4use std::hash::{BuildHasher, Hash};
5use std::marker::PhantomData;
6use std::path::PathBuf;
7
8#[cfg(feature = "rayon")]
9use rayon::iter::Either;
10
11use crate::sync::Arc;
12
13/// This is used by the macro generated code.
14/// If possible, uses `Update` trait, but else requires `'static`.
15///
16/// To use:
17///
18/// ```rust,ignore
19/// use crate::update::helper::Fallback;
20/// update::helper::Dispatch::<$ty>::maybe_update(pointer, new_value);
21/// ```
22///
23/// It is important that you specify the `$ty` explicitly.
24///
25/// This uses the ["method dispatch hack"](https://github.com/nvzqz/impls#how-it-works)
26/// to use the `Update` trait if it is available and else fallback to `'static`.
27pub mod helper {
28    use std::marker::PhantomData;
29
30    use super::{Update, update_fallback};
31
32    pub struct Dispatch<D>(PhantomData<D>);
33
34    #[allow(clippy::new_without_default)]
35    impl<D> Dispatch<D> {
36        pub fn new() -> Self {
37            Dispatch(PhantomData)
38        }
39    }
40
41    impl<D> Dispatch<D>
42    where
43        D: Update,
44    {
45        /// # Safety
46        ///
47        /// See the `maybe_update` method in the [`Update`][] trait.
48        pub unsafe fn maybe_update(old_pointer: *mut D, new_value: D) -> bool {
49            // SAFETY: Same safety conditions as `Update::maybe_update`
50            unsafe { D::maybe_update(old_pointer, new_value) }
51        }
52    }
53
54    /// # Safety
55    ///
56    /// Impl will fulfill the postconditions of `maybe_update`
57    pub unsafe trait Fallback<T> {
58        /// # Safety
59        ///
60        /// Same safety conditions as `Update::maybe_update`
61        unsafe fn maybe_update(old_pointer: *mut T, new_value: T) -> bool;
62    }
63
64    // SAFETY: Same safety conditions as `Update::maybe_update`
65    unsafe impl<T: 'static + PartialEq> Fallback<T> for Dispatch<T> {
66        unsafe fn maybe_update(old_pointer: *mut T, new_value: T) -> bool {
67            // SAFETY: Same safety conditions as `Update::maybe_update`
68            unsafe { update_fallback(old_pointer, new_value) }
69        }
70    }
71}
72
/// Fallback `maybe_update` for fully owned `T` that implements `PartialEq`.
///
/// The old value is overwritten only when `new_value` compares unequal (`!=`) to
/// it. With a `PartialEq` impl that does more than compare fields structurally,
/// this can keep the old value even though something differs, presumably
/// because that difference is not semantically significant.
///
/// Returns `true` if the old value was overwritten.
///
/// # Safety
///
/// See `Update::maybe_update`
pub unsafe fn update_fallback<T>(old_pointer: *mut T, new_value: T) -> bool
where
    T: 'static + PartialEq,
{
    // SAFETY: `T: 'static` rules out borrows tied to the database, so under the
    // caller's guarantees this is simply a valid `&mut T`.
    let slot: &mut T = unsafe { &mut *old_pointer };

    let changed = *slot != new_value;
    if changed {
        *slot = new_value;
    }
    // Subtle but important: `PartialEq` impls can be buggy or define equality in
    // surprising ways. When they report "unchanged", the existing value is left
    // untouched, so downstream code never observes `new_value` and the revision
    // does not need to be bumped.
    changed
}
101
/// Unconditionally stores `new_value` into `*old_pointer` and reports a change.
///
/// Helper for generated code, used for fields tagged with `#[no_eq]`: equality is
/// deliberately not consulted, so the result is always `true`.
///
/// # Safety
///
/// See `Update::maybe_update`
pub unsafe fn always_update<T>(old_pointer: *mut T, new_value: T) -> bool {
    // SAFETY: the caller guarantees `old_pointer` is valid for writes and that the
    // data the old value owns meets its safety invariant, so assigning (which
    // drops the old value) is sound.
    unsafe {
        *old_pointer = new_value;
    }
    true
}
113
/// # Safety
///
/// Implementing this trait requires the implementor to verify:
///
/// * `maybe_update` ensures the properties it is intended to ensure.
/// * If the value implements `Eq`, it is safe to compare an instance
///   of the value from an older revision with one from the newer
///   revision. If the value compares as equal, no update is needed to
///   bring it into the newer revision.
///
/// NB: The second point implies that `Update` cannot be implemented for any
/// `&'db T` -- (i.e., any Rust reference tied to the database).
/// Such a value could refer to memory that was freed in some
/// earlier revision. Even if the memory is still valid, it could also
/// have been part of a tracked struct whose values were mutated,
/// thus invalidating the `'db` lifetime (from a stacked borrows perspective).
/// Either way, the `Eq` implementation would be invalid.
pub unsafe trait Update {
    /// Brings the value at `old_pointer` up to date with `new_value`.
    ///
    /// # Returns
    ///
    /// True if the value should be considered to have changed in the new revision.
    ///
    /// # Safety
    ///
    /// ## Requires
    ///
    /// Informally, requires that `old_pointer` points to a value in the
    /// database that is potentially from a previous revision and `new_value`
    /// is a value produced in this revision.
    ///
    /// More formally, requires that
    ///
    /// * all parameters meet the [validity and safety invariants][i] for their type
    /// * `old_pointer` further points to allocated memory that meets the [validity invariant][i] for `Self`
    /// * all data *owned* by `*old_pointer` further meets its safety invariant
    ///     * note that borrowed data in `*old_pointer` only meets its validity invariant
    ///       and hence cannot be dereferenced; essentially, a `&T` may point to memory
    ///       in the database which has been modified or even freed in the newer revision.
    ///
    /// [i]: https://www.ralfj.de/blog/2018/08/22/two-kinds-of-invariants.html
    ///
    /// ## Ensures
    ///
    /// That `*old_pointer` is updated to hold a value equivalent to `new_value`
    /// (implementations may keep old parts that compare equal to the new ones), and
    /// that the return value is `true` if this changed the stored value. Returning
    /// `true` for a value that did not change is merely conservative; returning
    /// `false` means the old value was left as it was.
    unsafe fn maybe_update(old_pointer: *mut Self, new_value: Self) -> bool;
}
160
161unsafe impl Update for std::convert::Infallible {
162    unsafe fn maybe_update(_old_pointer: *mut Self, new_value: Self) -> bool {
163        match new_value {}
164    }
165}
166
/// Shared body of `Update::maybe_update` for `Vec`-like containers
/// (`Vec`, `ThinVec`, `SmallVec`).
///
/// Must be expanded inside an impl where `Self` is the container type. It calls
/// `<$elem_ty>::maybe_update` for each element.
///
/// * If the lengths differ, the old container is replaced by the new one and the
///   result is `true`.
/// * Otherwise the elements are updated pairwise, and the result is `true` if any
///   element reported a change.
///
/// The expansion is a single expression (a labeled block) and never `return`s
/// from the enclosing function, so it is safe to use in non-tail position.
macro_rules! maybe_update_vec {
    ($old_pointer: expr, $new_vec: expr, $elem_ty: ty) => {
        'function: {
            let old_pointer = $old_pointer;
            let new_vec = $new_vec;

            // SAFETY: the caller of `maybe_update` guarantees `old_pointer` is valid
            // for reads and writes and that the data it owns is valid.
            let old_vec: &mut Self = unsafe { &mut *old_pointer };

            if old_vec.len() != new_vec.len() {
                // Move the new container in wholesale instead of moving every
                // element one by one via `clear` + `extend`.
                *old_vec = new_vec;
                break 'function true;
            }

            let mut changed = false;
            for (old_element, new_element) in old_vec.iter_mut().zip(new_vec) {
                // SAFETY: `old_element` is a valid `&mut` into the old container and
                // inherits the caller's guarantees for it.
                changed |= unsafe { <$elem_ty>::maybe_update(old_element, new_element) };
            }
            changed
        }
    };
}
188
189unsafe impl<T> Update for Vec<T>
190where
191    T: Update,
192{
193    unsafe fn maybe_update(old_pointer: *mut Self, new_vec: Self) -> bool {
194        maybe_update_vec!(old_pointer, new_vec, T)
195    }
196}
197
198unsafe impl<T> Update for thin_vec::ThinVec<T>
199where
200    T: Update,
201{
202    unsafe fn maybe_update(old_pointer: *mut Self, new_vec: Self) -> bool {
203        maybe_update_vec!(old_pointer, new_vec, T)
204    }
205}
206
207unsafe impl<A> Update for smallvec::SmallVec<A>
208where
209    A: smallvec::Array,
210    A::Item: Update,
211{
212    unsafe fn maybe_update(old_pointer: *mut Self, new_vec: Self) -> bool {
213        maybe_update_vec!(old_pointer, new_vec, A::Item)
214    }
215}
216
/// Shared body of `Update::maybe_update` for set types.
///
/// Must be expanded inside an impl where `Self` is the set type. The sets are
/// compared as a whole with `==`; the `Update` safety contract makes comparing
/// old elements with new ones sound. If they differ, the old set is replaced by
/// the new one. Evaluates to `true` exactly when the set was replaced.
///
/// The expansion is a single block expression and never `return`s from the
/// enclosing function, so it is safe to use in non-tail position.
macro_rules! maybe_update_set {
    ($old_pointer: expr, $new_set: expr) => {{
        let old_pointer = $old_pointer;
        let new_set = $new_set;

        // SAFETY: the caller of `maybe_update` guarantees `old_pointer` is valid
        // for reads and writes and that the data it owns is valid.
        let old_set: &mut Self = unsafe { &mut *old_pointer };

        if *old_set == new_set {
            false
        } else {
            // Move the new set in instead of re-inserting every element via
            // `clear` + `extend`.
            *old_set = new_set;
            true
        }
    }};
}
233
234unsafe impl<K, S> Update for HashSet<K, S>
235where
236    K: Update + Eq + Hash,
237    S: BuildHasher,
238{
239    unsafe fn maybe_update(old_pointer: *mut Self, new_set: Self) -> bool {
240        maybe_update_set!(old_pointer, new_set)
241    }
242}
243
244unsafe impl<K> Update for BTreeSet<K>
245where
246    K: Update + Eq + Ord,
247{
248    unsafe fn maybe_update(old_pointer: *mut Self, new_set: Self) -> bool {
249        maybe_update_set!(old_pointer, new_set)
250    }
251}
252
/// Shared body of `Update::maybe_update` for map types.
///
/// Duck typing FTW: it was too annoying to make a proper function out of this.
/// The macro only relies on `len`, `keys`, `contains_key`, `get_mut` and
/// by-value iteration. It must be expanded inside an impl where `Self` is the
/// map type and the value type parameter is named `V` (it calls
/// `V::maybe_update`).
///
/// * If the key sets differ, the old map is replaced by the new one (`true`).
/// * Otherwise each value is updated in place, and the result is `true` if any
///   value reported a change.
macro_rules! maybe_update_map {
    ($old_pointer: expr, $new_map: expr) => {
        'function: {
            let old_pointer = $old_pointer;
            let new_map = $new_map;
            // SAFETY: the caller of `maybe_update` guarantees `old_pointer` is valid
            // for reads and writes and that the data it owns is valid.
            let old_map: &mut Self = unsafe { &mut *old_pointer };

            // To be considered "equal", the set of keys
            // must be the same between the two maps.
            let same_keys =
                old_map.len() == new_map.len() && old_map.keys().all(|k| new_map.contains_key(k));

            // If the set of keys has changed, adopt the new map wholesale. Moving it
            // in avoids re-inserting every entry via `clear` + `extend`.
            if !same_keys {
                *old_map = new_map;
                break 'function true;
            }

            // Otherwise, recursively descend to the values.
            // We do not invoke `K::maybe_update` because we assume
            // that if the keys are `Eq` they must not need
            // updating (see the trait criteria), so the old keys are kept.
            let mut changed = false;
            for (key, new_value) in new_map {
                let old_value = old_map
                    .get_mut(&key)
                    .expect("`same_keys` guarantees every new key exists in the old map");
                // SAFETY: `old_value` is a valid `&mut` into the old map and
                // inherits the caller's guarantees for it.
                changed |= unsafe { V::maybe_update(old_value, new_value) };
            }
            changed
        }
    };
}
287
288unsafe impl<K, V, S> Update for HashMap<K, V, S>
289where
290    K: Update + Eq + Hash,
291    V: Update,
292    S: BuildHasher,
293{
294    unsafe fn maybe_update(old_pointer: *mut Self, new_map: Self) -> bool {
295        maybe_update_map!(old_pointer, new_map)
296    }
297}
298
299unsafe impl<K, V, S> Update for hashbrown::HashMap<K, V, S>
300where
301    K: Update + Eq + Hash,
302    V: Update,
303    S: BuildHasher,
304{
305    unsafe fn maybe_update(old_pointer: *mut Self, new_map: Self) -> bool {
306        maybe_update_map!(old_pointer, new_map)
307    }
308}
309
310unsafe impl<K, S> Update for hashbrown::HashSet<K, S>
311where
312    K: Update + Eq + Hash,
313    S: BuildHasher,
314{
315    unsafe fn maybe_update(old_pointer: *mut Self, new_set: Self) -> bool {
316        maybe_update_set!(old_pointer, new_set)
317    }
318}
319
320unsafe impl<K, V, S> Update for indexmap::IndexMap<K, V, S>
321where
322    K: Update + Eq + Hash,
323    V: Update,
324    S: BuildHasher,
325{
326    unsafe fn maybe_update(old_pointer: *mut Self, new_map: Self) -> bool {
327        maybe_update_map!(old_pointer, new_map)
328    }
329}
330
331unsafe impl<K, S> Update for indexmap::IndexSet<K, S>
332where
333    K: Update + Eq + Hash,
334    S: BuildHasher,
335{
336    unsafe fn maybe_update(old_pointer: *mut Self, new_set: Self) -> bool {
337        maybe_update_set!(old_pointer, new_set)
338    }
339}
340
/// Keys are compared in order: when the key sequences match, values are updated
/// via `maybe_update_map!`; any difference, including a pure reordering, replaces
/// the old map.
#[cfg(feature = "ordermap")]
unsafe impl<K, V, S> Update for ordermap::OrderMap<K, V, S>
where
    K: Update + Eq + Hash,
    V: Update,
    S: BuildHasher,
{
    unsafe fn maybe_update(old_pointer: *mut Self, new_map: Self) -> bool {
        // SAFETY: the caller guarantees `old_pointer` is valid for reads and writes.
        let current = unsafe { &mut *old_pointer };

        if current.keys().eq(new_map.keys()) {
            // Identical key sequence: only the values can differ.
            maybe_update_map!(old_pointer, new_map)
        } else {
            *current = new_map;
            true
        }
    }
}
359
/// Whole-set comparison via `maybe_update_set!`.
#[cfg(feature = "ordermap")]
unsafe impl<K, S> Update for ordermap::OrderSet<K, S>
where
    S: BuildHasher,
    K: Update + Eq + Hash,
{
    unsafe fn maybe_update(old_pointer: *mut Self, new_value: Self) -> bool {
        maybe_update_set!(old_pointer, new_value)
    }
}
370
371unsafe impl<K, V> Update for BTreeMap<K, V>
372where
373    K: Update + Eq + Ord,
374    V: Update,
375{
376    unsafe fn maybe_update(old_pointer: *mut Self, new_map: Self) -> bool {
377        maybe_update_map!(old_pointer, new_map)
378    }
379}
380
381unsafe impl<T> Update for Box<T>
382where
383    T: Update,
384{
385    unsafe fn maybe_update(old_pointer: *mut Self, new_box: Self) -> bool {
386        let old_box: &mut Box<T> = unsafe { &mut *old_pointer };
387
388        unsafe { T::maybe_update(&mut **old_box, *new_box) }
389    }
390}
391
392unsafe impl<T> Update for Box<[T]>
393where
394    T: Update,
395{
396    unsafe fn maybe_update(old_pointer: *mut Self, new_box: Self) -> bool {
397        let old_box: &mut Box<[T]> = unsafe { &mut *old_pointer };
398
399        if old_box.len() == new_box.len() {
400            let mut changed = false;
401            for (old_element, new_element) in old_box.iter_mut().zip(new_box) {
402                changed |= unsafe { T::maybe_update(old_element, new_element) };
403            }
404            changed
405        } else {
406            *old_box = new_box;
407            true
408        }
409    }
410}
411
412unsafe impl<T> Update for Arc<T>
413where
414    T: Update,
415{
416    unsafe fn maybe_update(old_pointer: *mut Self, new_arc: Self) -> bool {
417        let old_arc: &mut Arc<T> = unsafe { &mut *old_pointer };
418
419        if Arc::ptr_eq(old_arc, &new_arc) {
420            return false;
421        }
422
423        if let Some(inner) = Arc::get_mut(old_arc) {
424            match Arc::try_unwrap(new_arc) {
425                Ok(new_inner) => unsafe { T::maybe_update(inner, new_inner) },
426                Err(new_arc) => {
427                    // We can't unwrap the new arc, so we have to update the old one in place.
428                    *old_arc = new_arc;
429                    true
430                }
431            }
432        } else {
433            unsafe { *old_pointer = new_arc };
434            true
435        }
436    }
437}
438
/// Same strategy as the `Arc` impl: pointer-equal means unchanged, unique on
/// both sides means update in place, otherwise the old `Arc` is replaced.
#[cfg(feature = "triomphe")]
unsafe impl<T: Update> Update for triomphe::Arc<T> {
    unsafe fn maybe_update(old_pointer: *mut Self, new_arc: Self) -> bool {
        // SAFETY: the caller guarantees `old_pointer` is valid for reads and writes.
        let current: &mut triomphe::Arc<T> = unsafe { &mut *old_pointer };

        if triomphe::Arc::ptr_eq(current, &new_arc) {
            return false;
        }

        let Some(current_inner) = triomphe::Arc::get_mut(current) else {
            // The old value is shared, so it cannot be mutated in place.
            // SAFETY: same pointer as above, still valid for writes.
            unsafe { *old_pointer = new_arc };
            return true;
        };

        match triomphe::Arc::try_unwrap(new_arc) {
            // SAFETY: `current_inner` is a unique `&mut T` into the old value.
            Ok(new_inner) => unsafe { T::maybe_update(current_inner, new_inner) },
            Err(still_shared) => {
                // We can't unwrap the new arc, so we have to update the old one in place.
                *current = still_shared;
                true
            }
        }
    }
}
466
467unsafe impl<T, const N: usize> Update for [T; N]
468where
469    T: Update,
470{
471    unsafe fn maybe_update(old_pointer: *mut Self, new_vec: Self) -> bool {
472        let old_pointer = old_pointer.cast::<T>();
473        let mut changed = false;
474        for (new_element, i) in new_vec.into_iter().zip(0..) {
475            changed |= unsafe { T::maybe_update(old_pointer.add(i), new_element) };
476        }
477        changed
478    }
479}
480
481unsafe impl<T, E> Update for Result<T, E>
482where
483    T: Update,
484    E: Update,
485{
486    unsafe fn maybe_update(old_pointer: *mut Self, new_value: Self) -> bool {
487        let old_value = unsafe { &mut *old_pointer };
488        match (old_value, new_value) {
489            (Ok(old), Ok(new)) => unsafe { T::maybe_update(old, new) },
490            (Err(old), Err(new)) => unsafe { E::maybe_update(old, new) },
491            (old_value, new_value) => {
492                *old_value = new_value;
493                true
494            }
495        }
496    }
497}
498
/// Same side: recurse into the payload. Side change: overwrite.
#[cfg(feature = "rayon")]
unsafe impl<L: Update, R: Update> Update for Either<L, R> {
    unsafe fn maybe_update(old_pointer: *mut Self, new_value: Self) -> bool {
        // SAFETY: the caller guarantees `old_pointer` is valid for reads and writes.
        let slot = unsafe { &mut *old_pointer };
        match (slot, new_value) {
            // SAFETY (both arms): the payload reference inherits the caller's guarantees.
            (Either::Left(old_left), Either::Left(new_left)) => unsafe {
                L::maybe_update(old_left, new_left)
            },
            (Either::Right(old_right), Either::Right(new_right)) => unsafe {
                R::maybe_update(old_right, new_right)
            },
            (slot, replacement) => {
                *slot = replacement;
                true
            }
        }
    }
}
517
/// Implements `Update` for each listed type by delegating to `update_fallback`
/// ("overwrite if `!=`"). Every listed type must be `'static + PartialEq`.
///
/// Accepts a comma-separated list of types; the trailing comma is optional.
macro_rules! fallback_impl {
    ($($ty:ty),* $(,)?) => {
        $(
            // SAFETY: `update_fallback` upholds the contract of
            // `Update::maybe_update` for `'static + PartialEq` types.
            unsafe impl Update for $ty {
                unsafe fn maybe_update(old_pointer: *mut Self, new_value: Self) -> bool {
                    // SAFETY: the caller's guarantees are forwarded unchanged.
                    unsafe { update_fallback(old_pointer, new_value) }
                }
            }
        )*
    };
}
529
530fallback_impl! {
531    String,
532    i128,
533    u128,
534    i64,
535    u64,
536    i32,
537    u32,
538    i16,
539    u16,
540    i8,
541    u8,
542    bool,
543    f32,
544    f64,
545    usize,
546    isize,
547    PathBuf,
548}
549
// Optional `compact_str` support: `CompactString` uses the same `!=`-based
// fallback as `String`.
#[cfg(feature = "compact_str")]
fallback_impl!(compact_str::CompactString,);
552
/// Implements `Update` for one tuple arity.
///
/// Invoked as `tuple_impl!(A, B; a, b)`. The uppercase identifiers become the
/// tuple's type parameters and are reused as locals for the new value's fields.
/// The lowercase identifiers bind `&mut` references to the old value's fields.
/// Every field is updated in order and the results are OR-ed together.
macro_rules! tuple_impl {
    ($($ty:ident),*; $($old:ident),*) => {
        unsafe impl<$($ty),*> Update for ($($ty,)*)
        where
            $($ty: Update,)*
        {
            #[allow(non_snake_case)]
            unsafe fn maybe_update(old_pointer: *mut Self, new_value: Self) -> bool {
                // SAFETY: the caller guarantees `old_pointer` is valid for reads and writes.
                let ($($old,)*) = unsafe { &mut *old_pointer };
                let ($($ty,)*) = new_value;

                #[allow(unused_mut)]
                let mut changed = false;
                $(
                    // SAFETY: each field reference inherits the caller's guarantees.
                    changed |= unsafe { Update::maybe_update($old, $ty) };
                )*
                changed
            }
        }
    };
}
574
575// Create implementations for tuples up to arity 12
576tuple_impl!(;);
577tuple_impl!(A; a);
578tuple_impl!(A, B; a, b);
579tuple_impl!(A, B, C; a, b, c);
580tuple_impl!(A, B, C, D; a, b, c, d);
581tuple_impl!(A, B, C, D, E; a, b, c, d, e);
582tuple_impl!(A, B, C, D, E, F; a, b, c, d, e, f);
583tuple_impl!(A, B, C, D, E, F, G; a, b, c, d, e, f, g);
584tuple_impl!(A, B, C, D, E, F, G, H; a, b, c, d, e, f, g, h);
585tuple_impl!(A, B, C, D, E, F, G, H, I; a, b, c, d, e, f, g, h, i);
586tuple_impl!(A, B, C, D, E, F, G, H, I, J; a, b, c, d, e, f, g, h, i, j);
587tuple_impl!(A, B, C, D, E, F, G, H, I, J, K; a, b, c, d, e, f, g, h, i, j, k);
588tuple_impl!(A, B, C, D, E, F, G, H, I, J, K, L; a, b, c, d, e, f, g, h, i, j, k, l);
589
590unsafe impl<T> Update for Option<T>
591where
592    T: Update,
593{
594    unsafe fn maybe_update(old_pointer: *mut Self, new_value: Self) -> bool {
595        let old_value = unsafe { &mut *old_pointer };
596        match (old_value, new_value) {
597            (Some(old), Some(new)) => unsafe { T::maybe_update(old, new) },
598            (None, None) => false,
599            (old_value, new_value) => {
600                *old_value = new_value;
601                true
602            }
603        }
604    }
605}
606
607unsafe impl<T> Update for PhantomData<T> {
608    unsafe fn maybe_update(_old_pointer: *mut Self, _new_value: Self) -> bool {
609        false
610    }
611}
612
#[cfg(test)]
mod tests {
    use super::Update;

    /// The `OrderMap` impl compares keys in order, so a pure reordering of the
    /// same entries must be reported as a change and adopted.
    #[test]
    #[cfg(feature = "ordermap")]
    fn update_order_map_reorders_entries() {
        let new = ordermap::OrderMap::from([(2_u32, 20_u32), (1, 10)]);
        let mut old = ordermap::OrderMap::from([(1_u32, 10_u32), (2, 20)]);

        // SAFETY: `old` is valid for reads and writes for the duration of this call.
        assert!(unsafe { Update::maybe_update(&mut old, new.clone()) });
        assert_eq!(old, new);
    }

    /// A zero-length array has no elements, so nothing can change.
    #[test]
    fn update_empty_array() {
        let mut old: [u32; 0] = [];

        // SAFETY: `old` is valid for reads and writes for the duration of this call.
        assert!(!unsafe { Update::maybe_update(&mut old, []) });
    }
}