sycamore_reactive/
root.rs1use std::cell::{Cell, RefCell};
4
5use slotmap::{Key, SlotMap};
6use smallvec::SmallVec;
7
8use crate::*;
9
10pub(crate) struct Root {
19 pub tracker: RefCell<Option<DependencyTracker>>,
21 pub rev_sorted_buf: RefCell<Vec<NodeId>>,
24 pub current_node: Cell<NodeId>,
27 pub root_node: Cell<NodeId>,
29 pub nodes: RefCell<SlotMap<NodeId, ReactiveNode>>,
31 pub node_update_queue: RefCell<Vec<NodeId>>,
33 pub batching: Cell<bool>,
36}
37
38thread_local! {
39 static GLOBAL_ROOT: Cell<Option<&'static Root>> = const { Cell::new(None) };
41}
42
43impl Root {
44 #[cfg_attr(debug_assertions, track_caller)]
46 pub fn global() -> &'static Root {
47 GLOBAL_ROOT.with(|root| root.get()).expect("no root found")
48 }
49
50 pub fn set_global(root: Option<&'static Root>) -> Option<&'static Root> {
52 GLOBAL_ROOT.with(|r| r.replace(root))
53 }
54
55 pub fn new_static() -> &'static Self {
57 let this = Self {
58 tracker: RefCell::new(None),
59 rev_sorted_buf: RefCell::new(Vec::new()),
60 current_node: Cell::new(NodeId::null()),
61 root_node: Cell::new(NodeId::null()),
62 nodes: RefCell::new(SlotMap::default()),
63 node_update_queue: RefCell::new(Vec::new()),
64 batching: Cell::new(false),
65 };
66 let _ref = Box::leak(Box::new(this));
67 _ref.reinit();
68 _ref
69 }
70
71 pub fn reinit(&'static self) {
73 NodeHandle(self.root_node.get(), self).dispose();
75
76 let _ = self.tracker.take();
77 let _ = self.rev_sorted_buf.take();
78 let _ = self.node_update_queue.take();
79 let _ = self.current_node.take();
80 let _ = self.root_node.take();
81 let _ = self.nodes.take();
82 self.batching.set(false);
83
84 Root::set_global(Some(self));
86 let root_node = create_child_scope(|| {});
87 Root::set_global(None);
88 self.root_node.set(root_node.0);
89 self.current_node.set(root_node.0);
90 }
91
92 pub fn create_child_scope(&'static self, f: impl FnOnce()) -> NodeHandle {
94 let node = create_signal(()).id;
95 let prev = self.current_node.replace(node);
96 f();
97 self.current_node.set(prev);
98 NodeHandle(node, self)
99 }
100
101 pub fn tracked_scope<T>(&self, f: impl FnOnce() -> T) -> (T, DependencyTracker) {
104 let prev = self.tracker.replace(Some(DependencyTracker::default()));
105 let ret = f();
106 (ret, self.tracker.replace(prev).unwrap())
107 }
108
109 fn run_node_update(&'static self, current: NodeId) {
118 debug_assert_eq!(
119 self.nodes.borrow()[current].state,
120 NodeState::Dirty,
121 "should only update when dirty"
122 );
123 let dependencies = std::mem::take(&mut self.nodes.borrow_mut()[current].dependencies);
125 for dependency in dependencies {
126 self.nodes.borrow_mut()[dependency]
127 .dependents
128 .retain(|&id| id != current);
129 }
130 let mut nodes_mut = self.nodes.borrow_mut();
133 let mut callback = nodes_mut[current].callback.take().unwrap();
134 let mut value = nodes_mut[current].value.take().unwrap();
135 drop(nodes_mut); NodeHandle(current, self).dispose_children(); let prev = self.current_node.replace(current);
140 let (changed, tracker) = self.tracked_scope(|| callback(&mut value));
141 self.current_node.set(prev);
142
143 tracker.create_dependency_link(self, current);
144
145 let mut nodes_mut = self.nodes.borrow_mut();
146 nodes_mut[current].callback = Some(callback); nodes_mut[current].value = Some(value);
148
149 nodes_mut[current].state = NodeState::Clean;
151 drop(nodes_mut);
152
153 if changed {
154 self.mark_dependents_dirty(current);
155 }
156 }
157
158 fn mark_dependents_dirty(&self, current: NodeId) {
160 let mut nodes_mut = self.nodes.borrow_mut();
161 let dependents = std::mem::take(&mut nodes_mut[current].dependents);
162 for &dependent in &dependents {
163 if let Some(dependent) = nodes_mut.get_mut(dependent) {
164 dependent.state = NodeState::Dirty;
165 }
166 }
167 nodes_mut[current].dependents = dependents;
168 }
169
170 fn propagate_node_updates(&'static self, start_nodes: &[NodeId]) {
176 let mut rev_sorted = Vec::new();
178 let mut rev_sorted_buf = self.rev_sorted_buf.try_borrow_mut();
179 let rev_sorted = if let Ok(rev_sorted_buf) = rev_sorted_buf.as_mut() {
180 rev_sorted_buf.clear();
181 rev_sorted_buf
182 } else {
183 &mut rev_sorted
184 };
185
186 for &node in start_nodes {
188 Self::dfs(node, &mut self.nodes.borrow_mut(), rev_sorted);
189 self.mark_dependents_dirty(node);
190 }
191
192 for &node in rev_sorted.iter().rev() {
193 let mut nodes_mut = self.nodes.borrow_mut();
194 if nodes_mut.get(node).is_none() {
196 continue;
197 }
198 let node_state = &mut nodes_mut[node];
199 node_state.mark = Mark::None; if nodes_mut[node].state == NodeState::Dirty {
203 drop(nodes_mut); self.run_node_update(node)
205 };
206 }
207 }
208
209 pub fn propagate_updates(&'static self, start_node: NodeId) {
214 if self.batching.get() {
215 self.node_update_queue.borrow_mut().push(start_node);
216 } else {
217 let prev = Root::set_global(Some(self));
219 self.propagate_node_updates(&[start_node]);
221 Root::set_global(prev);
222 }
223 }
224
225 fn dfs(current_id: NodeId, nodes: &mut SlotMap<NodeId, ReactiveNode>, buf: &mut Vec<NodeId>) {
227 let Some(current) = nodes.get_mut(current_id) else {
228 return;
230 };
231
232 match current.mark {
233 Mark::Temp => panic!("cyclic reactive dependency"),
234 Mark::Permanent => return,
235 Mark::None => {}
236 }
237 current.mark = Mark::Temp;
238
239 let children = std::mem::take(&mut current.dependents);
241 for child in &children {
242 Self::dfs(*child, nodes, buf);
243 }
244 nodes[current_id].dependents = children;
245
246 nodes[current_id].mark = Mark::Permanent;
247 buf.push(current_id);
248 }
249
250 fn start_batch(&self) {
252 self.batching.set(true);
253 }
254
255 fn end_batch(&'static self) {
257 self.batching.set(false);
258 let nodes = self.node_update_queue.take();
259 self.propagate_node_updates(&nodes);
260 }
261}
262
263#[derive(Clone, Copy)]
267pub struct RootHandle {
268 _ref: &'static Root,
269}
270
271impl RootHandle {
272 pub fn dispose(&self) {
274 self._ref.reinit();
275 }
276
277 pub fn run_in<T>(&self, f: impl FnOnce() -> T) -> T {
279 let prev = Root::set_global(Some(self._ref));
280 let ret = f();
281 Root::set_global(prev);
282 ret
283 }
284}
285
286#[derive(Default)]
288pub(crate) struct DependencyTracker {
289 pub dependencies: SmallVec<[NodeId; 1]>,
291}
292
293impl DependencyTracker {
294 pub fn create_dependency_link(self, root: &Root, dependent: NodeId) {
297 for node in &self.dependencies {
298 root.nodes.borrow_mut()[*node].dependents.push(dependent);
299 }
300 root.nodes.borrow_mut()[dependent].dependencies = self.dependencies;
302 }
303}
304
305#[must_use = "root should be disposed"]
321pub fn create_root(f: impl FnOnce()) -> RootHandle {
322 let _ref = Root::new_static();
323 #[cfg(not(target_arch = "wasm32"))]
324 {
325 #[allow(dead_code)]
328 struct UnsafeSendPtr<T>(*const T);
329 unsafe impl<T> Send for UnsafeSendPtr<T> {}
331
332 static KEEP_ALIVE: std::sync::Mutex<Vec<UnsafeSendPtr<Root>>> =
335 std::sync::Mutex::new(Vec::new());
336 KEEP_ALIVE
337 .lock()
338 .unwrap()
339 .push(UnsafeSendPtr(_ref as *const Root));
340 }
341
342 Root::set_global(Some(_ref));
343 NodeHandle(_ref.root_node.get(), _ref).run_in(f);
344 Root::set_global(None);
345 RootHandle { _ref }
346}
347
348#[cfg_attr(debug_assertions, track_caller)]
352pub fn create_child_scope(f: impl FnOnce()) -> NodeHandle {
353 Root::global().create_child_scope(f)
354}
355
356#[cfg_attr(debug_assertions, track_caller)]
371pub fn on_cleanup(f: impl FnOnce() + 'static) {
372 let root = Root::global();
373 if !root.current_node.get().is_null() {
374 root.nodes.borrow_mut()[root.current_node.get()]
375 .cleanups
376 .push(Box::new(f));
377 }
378}
379
380pub fn batch<T>(f: impl FnOnce() -> T) -> T {
398 let root = Root::global();
399 root.start_batch();
400 let ret = f();
401 root.end_batch();
402 ret
403}
404
405pub fn untrack<T>(f: impl FnOnce() -> T) -> T {
424 untrack_in_scope(f, Root::global())
425}
426
427pub(crate) fn untrack_in_scope<T>(f: impl FnOnce() -> T, root: &'static Root) -> T {
429 let prev = root.tracker.replace(None);
430 let ret = f();
431 root.tracker.replace(prev);
432 ret
433}
434
435pub fn use_current_scope() -> NodeHandle {
437 let root = Root::global();
438 NodeHandle(root.current_node.get(), root)
439}
440
441pub fn use_global_scope() -> NodeHandle {
443 let root = Root::global();
444 NodeHandle(root.root_node.get(), root)
445}
446
#[cfg(test)]
mod tests {
    use crate::*;

    /// A cleanup registered in a child scope runs when that scope is disposed.
    #[test]
    fn cleanup() {
        let _ = create_root(|| {
            let ran = create_signal(false);
            let child = create_child_scope(|| {
                on_cleanup(move || {
                    ran.set(true);
                });
            });
            assert!(!ran.get());
            child.dispose();
            assert!(ran.get());
        });
    }

    /// A cleanup registered inside an effect runs each time the effect
    /// re-executes.
    #[test]
    fn cleanup_in_effect() {
        let _ = create_root(|| {
            let trigger = create_signal(());
            let runs = create_signal(0);

            create_effect(move || {
                trigger.track();

                on_cleanup(move || {
                    runs.set(runs.get() + 1);
                });
            });
            assert_eq!(runs.get(), 0);

            trigger.set(());
            assert_eq!(runs.get(), 1);

            trigger.set(());
            assert_eq!(runs.get(), 2);
        });
    }

    /// Tracking a signal inside a cleanup callback does not make the
    /// surrounding effect depend on it.
    #[test]
    fn cleanup_is_untracked() {
        let _ = create_root(|| {
            let trigger = create_signal(());
            let runs = create_signal(0);

            create_effect(move || {
                runs.set(runs.get_untracked() + 1);

                on_cleanup(move || {
                    trigger.track();
                });
            });
            assert_eq!(runs.get(), 1);

            // The effect must not re-run: it never tracked `trigger` itself.
            trigger.set(());
            assert_eq!(runs.get(), 1);
        });
    }

    /// A memo keeps its old value inside a batch and is updated once the
    /// batch ends.
    #[test]
    fn batch_memo() {
        let _ = create_root(|| {
            let source = create_signal(1);
            let doubled = create_memo(move || source.get() * 2);
            batch(move || {
                source.set(2);
                assert_eq!(doubled.get(), 2);
            });
            assert_eq!(doubled.get(), 4);
        });
    }

    /// Outside a batch every `set` re-runs the effect; inside a batch the
    /// effect runs once, after the batch ends.
    #[test]
    fn batch_updates_effects_at_end() {
        let _ = create_root(|| {
            let first = create_signal(1);
            let second = create_signal(2);
            let runs = create_signal(0);
            create_effect(move || {
                runs.set(runs.get_untracked() + 1);
                let _ = first.get() + second.get();
            });
            assert_eq!(runs.get(), 1);

            // Unbatched: one effect run per write.
            first.set(2);
            second.set(3);
            assert_eq!(runs.get(), 3);

            // Batched: no runs until the closure has finished.
            batch(move || {
                first.set(3);
                assert_eq!(runs.get(), 3);
                second.set(4);
                assert_eq!(runs.get(), 3);
            });
            assert_eq!(runs.get(), 4);
        });
    }
}