sycamore_reactive/
root.rs

1//! [`Root`] and [`Scope`].
2
3use std::cell::{Cell, RefCell};
4
5use slotmap::{Key, SlotMap};
6use smallvec::SmallVec;
7
8use crate::*;
9
10/// The struct managing the state of the reactive system. Only one should be created per running
11/// app.
12///
13/// Often times, this is intended to be leaked to be able to get a `&'static Root`. However, the
14/// `Root` is also `dispose`-able, meaning that any resources allocated in this `Root` will get
15/// deallocated. Therefore in practice, there should be no memory leak at all except for the `Root`
16/// itself. Finally, the `Root` is expected to live for the whole duration of the app so this is
17/// not a problem.
18pub(crate) struct Root {
19    /// If this is `Some`, that means we are tracking signal accesses.
20    pub tracker: RefCell<Option<DependencyTracker>>,
21    /// A temporary buffer used in `propagate_updates` to prevent allocating a new Vec every time
22    /// it is called.
23    pub rev_sorted_buf: RefCell<Vec<NodeId>>,
24    /// The current node that owns everything created in its scope.
25    /// If we are at the top-level, then this is the "null" key.
26    pub current_node: Cell<NodeId>,
27    /// The root node of the reactive graph.
28    pub root_node: Cell<NodeId>,
29    /// All the nodes created in this `Root`.
30    pub nodes: RefCell<SlotMap<NodeId, ReactiveNode>>,
31    /// A list of signals who need their values to be propagated after the batch is over.
32    pub node_update_queue: RefCell<Vec<NodeId>>,
33    /// Whether we are currently batching signal updates. If this is true, we do not run
34    /// `effect_queue` and instead wait until the end of the batch.
35    pub batching: Cell<bool>,
36}
37
38thread_local! {
39    /// The current reactive root.
40    static GLOBAL_ROOT: Cell<Option<&'static Root>> = const { Cell::new(None) };
41}
42
43impl Root {
44    /// Get the current reactive root. Panics if no root is found.
45    #[cfg_attr(debug_assertions, track_caller)]
46    pub fn global() -> &'static Root {
47        GLOBAL_ROOT.with(|root| root.get()).expect("no root found")
48    }
49
50    /// Sets the current reactive root. Returns the previous root.
51    pub fn set_global(root: Option<&'static Root>) -> Option<&'static Root> {
52        GLOBAL_ROOT.with(|r| r.replace(root))
53    }
54
55    /// Create a new reactive root. This root is leaked and so lives until the end of the program.
56    pub fn new_static() -> &'static Self {
57        let this = Self {
58            tracker: RefCell::new(None),
59            rev_sorted_buf: RefCell::new(Vec::new()),
60            current_node: Cell::new(NodeId::null()),
61            root_node: Cell::new(NodeId::null()),
62            nodes: RefCell::new(SlotMap::default()),
63            node_update_queue: RefCell::new(Vec::new()),
64            batching: Cell::new(false),
65        };
66        let _ref = Box::leak(Box::new(this));
67        _ref.reinit();
68        _ref
69    }
70
71    /// Disposes of all the resources held on by this root and resets the state.
72    pub fn reinit(&'static self) {
73        // Dispose the root node.
74        NodeHandle(self.root_node.get(), self).dispose();
75
76        let _ = self.tracker.take();
77        let _ = self.rev_sorted_buf.take();
78        let _ = self.node_update_queue.take();
79        let _ = self.current_node.take();
80        let _ = self.root_node.take();
81        let _ = self.nodes.take();
82        self.batching.set(false);
83
84        // Create a new root node.
85        Root::set_global(Some(self));
86        let root_node = create_child_scope(|| {});
87        Root::set_global(None);
88        self.root_node.set(root_node.0);
89        self.current_node.set(root_node.0);
90    }
91
92    /// Create a new child scope. Implementation detail for [`create_child_scope`].
93    pub fn create_child_scope(&'static self, f: impl FnOnce()) -> NodeHandle {
94        let node = create_signal(()).id;
95        let prev = self.current_node.replace(node);
96        f();
97        self.current_node.set(prev);
98        NodeHandle(node, self)
99    }
100
101    /// Run the provided closure in a tracked scope. This will detect all the signals that are
102    /// accessed and track them in a dependency list.
103    pub fn tracked_scope<T>(&self, f: impl FnOnce() -> T) -> (T, DependencyTracker) {
104        let prev = self.tracker.replace(Some(DependencyTracker::default()));
105        let ret = f();
106        (ret, self.tracker.replace(prev).unwrap())
107    }
108
109    /// Run the update callback of the signal, also recreating any dependencies found by
110    /// tracking signal accesses inside the function.
111    ///
112    /// Also marks all the dependencies as dirty and marks the current node as clean.
113    ///
114    /// # Params
115    /// * `root` - The reactive root.
116    /// * `id` - The id associated with the reactive node. `SignalId` inside the state itself.
117    fn run_node_update(&'static self, current: NodeId) {
118        debug_assert_eq!(
119            self.nodes.borrow()[current].state,
120            NodeState::Dirty,
121            "should only update when dirty"
122        );
123        // Remove old dependency links.
124        let dependencies = std::mem::take(&mut self.nodes.borrow_mut()[current].dependencies);
125        for dependency in dependencies {
126            self.nodes.borrow_mut()[dependency]
127                .dependents
128                .retain(|&id| id != current);
129        }
130        // We take the callback out because that requires a mut ref and we cannot hold that while
131        // running update itself.
132        let mut nodes_mut = self.nodes.borrow_mut();
133        let mut callback = nodes_mut[current].callback.take().unwrap();
134        let mut value = nodes_mut[current].value.take().unwrap();
135        drop(nodes_mut); // End RefMut borrow.
136
137        NodeHandle(current, self).dispose_children(); // Destroy anything created in a previous update.
138
139        let prev = self.current_node.replace(current);
140        let (changed, tracker) = self.tracked_scope(|| callback(&mut value));
141        self.current_node.set(prev);
142
143        tracker.create_dependency_link(self, current);
144
145        let mut nodes_mut = self.nodes.borrow_mut();
146        nodes_mut[current].callback = Some(callback); // Put the callback back in.
147        nodes_mut[current].value = Some(value);
148
149        // Mark this node as clean.
150        nodes_mut[current].state = NodeState::Clean;
151        drop(nodes_mut);
152
153        if changed {
154            self.mark_dependents_dirty(current);
155        }
156    }
157
158    // Mark any dependent node of the current node as dirty.
159    fn mark_dependents_dirty(&self, current: NodeId) {
160        let mut nodes_mut = self.nodes.borrow_mut();
161        let dependents = std::mem::take(&mut nodes_mut[current].dependents);
162        for &dependent in &dependents {
163            if let Some(dependent) = nodes_mut.get_mut(dependent) {
164                dependent.state = NodeState::Dirty;
165            }
166        }
167        nodes_mut[current].dependents = dependents;
168    }
169
170    /// If there are no cyclic dependencies, then the reactive graph is a DAG (Directed Acyclic
171    /// Graph). We can therefore use DFS to get a topological sorting of all the reactive nodes.
172    ///
173    /// We then go through every node in this topological sorting and update only those nodes which
174    /// have dependencies that were updated.
175    fn propagate_node_updates(&'static self, start_nodes: &[NodeId]) {
176        // Try to reuse the shared buffer if possible.
177        let mut rev_sorted = Vec::new();
178        let mut rev_sorted_buf = self.rev_sorted_buf.try_borrow_mut();
179        let rev_sorted = if let Ok(rev_sorted_buf) = rev_sorted_buf.as_mut() {
180            rev_sorted_buf.clear();
181            rev_sorted_buf
182        } else {
183            &mut rev_sorted
184        };
185
186        // Traverse reactive graph.
187        for &node in start_nodes {
188            Self::dfs(node, &mut self.nodes.borrow_mut(), rev_sorted);
189            self.mark_dependents_dirty(node);
190        }
191
192        for &node in rev_sorted.iter().rev() {
193            let mut nodes_mut = self.nodes.borrow_mut();
194            // Only run if node is still alive.
195            if nodes_mut.get(node).is_none() {
196                continue;
197            }
198            let node_state = &mut nodes_mut[node];
199            node_state.mark = Mark::None; // Reset value.
200
201            // Check if this node needs to be updated.
202            if nodes_mut[node].state == NodeState::Dirty {
203                drop(nodes_mut); // End RefMut borrow.
204                self.run_node_update(node)
205            };
206        }
207    }
208
209    /// Call this if `start_node` has been updated manually. This will automatically update all
210    /// signals that depend on `start_node`.
211    ///
212    /// If we are currently batching, defers updating the signal until the end of the batch.
213    pub fn propagate_updates(&'static self, start_node: NodeId) {
214        if self.batching.get() {
215            self.node_update_queue.borrow_mut().push(start_node);
216        } else {
217            // Set the global root.
218            let prev = Root::set_global(Some(self));
219            // Propagate any signal updates.
220            self.propagate_node_updates(&[start_node]);
221            Root::set_global(prev);
222        }
223    }
224
225    /// Run depth-first-search on the reactive graph starting at `current`.
226    fn dfs(current_id: NodeId, nodes: &mut SlotMap<NodeId, ReactiveNode>, buf: &mut Vec<NodeId>) {
227        let Some(current) = nodes.get_mut(current_id) else {
228            // If signal is dead, don't even visit it.
229            return;
230        };
231
232        match current.mark {
233            Mark::Temp => panic!("cyclic reactive dependency"),
234            Mark::Permanent => return,
235            Mark::None => {}
236        }
237        current.mark = Mark::Temp;
238
239        // Take the `dependents` field out temporarily to avoid borrow checker.
240        let children = std::mem::take(&mut current.dependents);
241        for child in &children {
242            Self::dfs(*child, nodes, buf);
243        }
244        nodes[current_id].dependents = children;
245
246        nodes[current_id].mark = Mark::Permanent;
247        buf.push(current_id);
248    }
249
250    /// Sets the batch flag to `true`.
251    fn start_batch(&self) {
252        self.batching.set(true);
253    }
254
255    /// Sets the batch flag to `false` and run all the queued effects.
256    fn end_batch(&'static self) {
257        self.batching.set(false);
258        let nodes = self.node_update_queue.take();
259        self.propagate_node_updates(&nodes);
260    }
261}
262
263/// A handle to a root. This lets you reinitialize or dispose the root for resource cleanup.
264///
265/// This is generally obtained from [`create_root`].
266#[derive(Clone, Copy)]
267pub struct RootHandle {
268    _ref: &'static Root,
269}
270
271impl RootHandle {
272    /// Destroy everything that was created in this scope.
273    pub fn dispose(&self) {
274        self._ref.reinit();
275    }
276
277    /// Runs the closure in the current scope of the root.
278    pub fn run_in<T>(&self, f: impl FnOnce() -> T) -> T {
279        let prev = Root::set_global(Some(self._ref));
280        let ret = f();
281        Root::set_global(prev);
282        ret
283    }
284}
285
286/// Tracks nodes that are accessed inside a reactive scope.
287#[derive(Default)]
288pub(crate) struct DependencyTracker {
289    /// A list of reactive nodes that were accessed.
290    pub dependencies: SmallVec<[NodeId; 1]>,
291}
292
293impl DependencyTracker {
294    /// Sets the `dependents` field for all the nodes that have been tracked and updates
295    /// `dependencies` of the `dependent`.
296    pub fn create_dependency_link(self, root: &Root, dependent: NodeId) {
297        for node in &self.dependencies {
298            root.nodes.borrow_mut()[*node].dependents.push(dependent);
299        }
300        // Set the signal dependencies so that it is updated automatically.
301        root.nodes.borrow_mut()[dependent].dependencies = self.dependencies;
302    }
303}
304
305/// Creates a new reactive root with a top-level reactive node. The returned [`RootHandle`] can be
306/// used to [`dispose`](RootHandle::dispose) the root.
307///
308/// # Example
309/// ```rust
310/// # use sycamore_reactive::*;
311///
312/// create_root(|| {
313///     let signal = create_signal(123);
314///
315///     let child_scope = create_child_scope(move || {
316///         // ...
317///     });
318/// });
319/// ```
320#[must_use = "root should be disposed"]
321pub fn create_root(f: impl FnOnce()) -> RootHandle {
322    let _ref = Root::new_static();
323    #[cfg(not(target_arch = "wasm32"))]
324    {
325        /// An unsafe wrapper around a raw pointer which we promise to never touch, effectively
326        /// making it thread-safe.
327        #[allow(dead_code)]
328        struct UnsafeSendPtr<T>(*const T);
329        /// We never ever touch the pointer inside so surely this is safe!
330        unsafe impl<T> Send for UnsafeSendPtr<T> {}
331
332        /// A static variable to keep on holding to the allocated `Root`s to prevent Miri and
333        /// Valgrind from complaining.
334        static KEEP_ALIVE: std::sync::Mutex<Vec<UnsafeSendPtr<Root>>> =
335            std::sync::Mutex::new(Vec::new());
336        KEEP_ALIVE
337            .lock()
338            .unwrap()
339            .push(UnsafeSendPtr(_ref as *const Root));
340    }
341
342    Root::set_global(Some(_ref));
343    NodeHandle(_ref.root_node.get(), _ref).run_in(f);
344    Root::set_global(None);
345    RootHandle { _ref }
346}
347
348/// Create a child scope.
349///
350/// Returns the created [`NodeHandle`] which can be used to dispose it.
351#[cfg_attr(debug_assertions, track_caller)]
352pub fn create_child_scope(f: impl FnOnce()) -> NodeHandle {
353    Root::global().create_child_scope(f)
354}
355
356/// Adds a callback that is called when the scope is destroyed.
357///
358/// # Example
359/// ```rust
360/// # use sycamore_reactive::*;
361/// # create_root(|| {
362/// let child_scope = create_child_scope(|| {
363///     on_cleanup(|| {
364///         println!("Child scope is being dropped");
365///     });
366/// });
367/// child_scope.dispose(); // Executes the on_cleanup callback.
368/// # });
369/// ```
370#[cfg_attr(debug_assertions, track_caller)]
371pub fn on_cleanup(f: impl FnOnce() + 'static) {
372    let root = Root::global();
373    if !root.current_node.get().is_null() {
374        root.nodes.borrow_mut()[root.current_node.get()]
375            .cleanups
376            .push(Box::new(f));
377    }
378}
379
380/// Batch updates from related signals together and only run memos and effects at the end of the
381/// scope.
382///
383/// # Example
384///
385/// ```
386/// # use sycamore_reactive::*;
387/// # let _ = create_root(|| {
388/// let state = create_signal(1);
389/// let double = create_memo(move || state.get() * 2);
390/// batch(move || {
391///     state.set(2);
392///     assert_eq!(double.get(), 2);
393/// });
394/// assert_eq!(double.get(), 4);
395/// # });
396/// ```
397pub fn batch<T>(f: impl FnOnce() -> T) -> T {
398    let root = Root::global();
399    root.start_batch();
400    let ret = f();
401    root.end_batch();
402    ret
403}
404
405/// Run the passed closure inside an untracked dependency scope.
406///
407/// See also [`ReadSignal::get_untracked`].
408///
409/// # Example
410///
411/// ```
412/// # use sycamore_reactive::*;
413/// # create_root(|| {
414/// let state = create_signal(1);
415/// let double = create_memo(move || untrack(|| state.get() * 2));
416/// assert_eq!(double.get(), 2);
417///
418/// state.set(2);
419/// // double value should still be old value because state was untracked
420/// assert_eq!(double.get(), 2);
421/// # });
422/// ```
423pub fn untrack<T>(f: impl FnOnce() -> T) -> T {
424    untrack_in_scope(f, Root::global())
425}
426
427/// Same as [`untrack`] but for a specific [`Root`].
428pub(crate) fn untrack_in_scope<T>(f: impl FnOnce() -> T, root: &'static Root) -> T {
429    let prev = root.tracker.replace(None);
430    let ret = f();
431    root.tracker.replace(prev);
432    ret
433}
434
435/// Get a handle to the current reactive scope.
436pub fn use_current_scope() -> NodeHandle {
437    let root = Root::global();
438    NodeHandle(root.current_node.get(), root)
439}
440
441/// Get a handle to the root reactive scope.
442pub fn use_global_scope() -> NodeHandle {
443    let root = Root::global();
444    NodeHandle(root.root_node.get(), root)
445}
446
#[cfg(test)]
mod tests {
    use crate::*;

    /// Disposing a child scope runs the cleanup callbacks registered inside it.
    #[test]
    fn cleanup() {
        let _ = create_root(|| {
            let called = create_signal(false);
            let child = create_child_scope(|| {
                on_cleanup(move || {
                    called.set(true);
                });
            });
            assert!(!called.get());
            child.dispose();
            assert!(called.get());
        });
    }

    /// A cleanup registered inside an effect runs each time the effect re-runs.
    #[test]
    fn cleanup_in_effect() {
        let _ = create_root(|| {
            let trigger = create_signal(());
            let runs = create_signal(0);

            create_effect(move || {
                trigger.track();

                on_cleanup(move || {
                    runs.set(runs.get() + 1);
                });
            });

            assert_eq!(runs.get(), 0);

            trigger.set(());
            assert_eq!(runs.get(), 1);

            trigger.set(());
            assert_eq!(runs.get(), 2);
        });
    }

    /// Signals read inside a cleanup callback do not become dependencies of the effect.
    #[test]
    fn cleanup_is_untracked() {
        let _ = create_root(|| {
            let trigger = create_signal(());
            let runs = create_signal(0);

            create_effect(move || {
                runs.set(runs.get_untracked() + 1);

                on_cleanup(move || {
                    trigger.track(); // trigger should not be tracked
                });
            });

            assert_eq!(runs.get(), 1);

            trigger.set(());
            assert_eq!(runs.get(), 1);
        });
    }

    /// Memos only observe batched signal changes once the batch has ended.
    #[test]
    fn batch_memo() {
        let _ = create_root(|| {
            let source = create_signal(1);
            let doubled = create_memo(move || source.get() * 2);
            batch(move || {
                source.set(2);
                assert_eq!(doubled.get(), 2);
            });
            assert_eq!(doubled.get(), 4);
        });
    }

    /// Effects run once at the end of a batch, no matter how many signals changed inside it.
    #[test]
    fn batch_updates_effects_at_end() {
        let _ = create_root(|| {
            let first = create_signal(1);
            let second = create_signal(2);
            let runs = create_signal(0);
            create_effect(move || {
                runs.set(runs.get_untracked() + 1);
                let _ = first.get() + second.get();
            });
            assert_eq!(runs.get(), 1);
            first.set(2);
            second.set(3);
            assert_eq!(runs.get(), 3);
            batch(move || {
                first.set(3);
                assert_eq!(runs.get(), 3);
                second.set(4);
                assert_eq!(runs.get(), 3);
            });
            assert_eq!(runs.get(), 4);
        });
    }
}