tessera_shard/
lib.rs

1pub mod router;
2pub mod task_handles;
3mod tokio_runtime;
4
5use std::{
6    any::Any,
7    sync::{Arc, OnceLock},
8};
9
10use dashmap::DashMap;
11
12static REGISTRY: OnceLock<ShardRegistry> = OnceLock::new();
13
/// Marker trait for state that can be auto-injected into a shard component.
///
/// Implementors must be `'static` (required by [`Any`]) and shareable across
/// threads (`Send + Sync`).
pub trait ShardState: Send + Sync + Any {}
16
/// How long a piece of `ShardState` lives.
///
/// There are two lifecycles:
///
/// * [`Application`](Self::Application): the state lives as long as the
///   application and is never destroyed.
/// * [`Shard`](Self::Shard): the state follows its navigation target and is
///   destroyed when that page is popped.
#[derive(PartialEq, Eq, Debug)]
pub enum ShardStateLifeCycle {
    /// Lives for the whole lifetime of the application; never destroyed.
    Application,
    /// Tied to the navigation target; destroyed when the page is popped.
    Shard,
}
34
35impl<T> ShardState for T where T: 'static + Send + Sync + Default {}
36
37pub struct ShardRegistry {
38    shards: DashMap<String, Arc<dyn ShardState>>,
39}
40
41impl ShardRegistry {
42    /// Get the singleton instance of the shard registry.
43    ///
44    /// Should only be called by macro, not manually.
45    pub fn get() -> &'static Self {
46        REGISTRY.get_or_init(|| ShardRegistry {
47            shards: DashMap::new(),
48        })
49    }
50
51    /// Get or initialize and get a shard state, and provide it to the closure
52    /// `f` as `Arc<T>`. The state type must implement `ShardState`.
53    ///
54    /// This function should never be called manually; it should be
55    /// automatically generated by the `#[shard]` macro.
56    ///
57    /// # Safety
58    ///
59    /// This function is unsafe because it uses an evil method to cast `Arc<dyn
60    /// ShardState>` to `Arc<T>`.
61    pub unsafe fn init_or_get<T, F, R>(&self, id: &str, f: F) -> R
62    where
63        T: ShardState + Default + 'static,
64        F: FnOnce(Arc<T>) -> R,
65    {
66        let shard_ref = self
67            .shards
68            .entry(id.to_string())
69            .or_insert_with(|| Arc::new(T::default()));
70
71        // Clone to increase the reference count, ensuring the raw pointer is valid
72        let arc_clone = shard_ref.value().clone();
73
74        // Unsafe cast Arc<dyn ShardState> -> Arc<T>
75        let arc_t = unsafe {
76            let raw_dyn: *const dyn ShardState = Arc::as_ptr(&arc_clone);
77            let raw_t = raw_dyn as *const T;
78            Arc::from_raw(raw_t)
79        };
80
81        // Clone again to return, keeping the reference count
82        let ret = arc_t.clone();
83
84        // Forget arc_t to avoid decreasing the reference count on drop
85        std::mem::forget(arc_t);
86
87        f(ret)
88    }
89}