wsio_core/event/registry.rs

1use std::{
2    any::{
3        Any,
4        TypeId,
5    },
6    collections::hash_map::Entry,
7    marker::PhantomData,
8    pin::Pin,
9    sync::{
10        Arc,
11        LazyLock,
12        atomic::{
13            AtomicU32,
14            Ordering,
15        },
16    },
17};
18
19use anyhow::Result;
20use parking_lot::RwLock;
21use serde::de::DeserializeOwned;
22
23use crate::{
24    packet::codecs::WsIoPacketCodec,
25    traits::task::spawner::TaskSpawner,
26    types::hashers::FxHashMap,
27};
28
29// Types
30type DataDecoder = fn(&[u8], &WsIoPacketCodec) -> Result<Arc<dyn Any + Send + Sync>>;
31type Handler<C> = Arc<
32    dyn Fn(Arc<C>, Arc<dyn Any + Send + Sync>) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'static>>
33        + Send
34        + Sync
35        + 'static,
36>;
37
38// Structs
39struct EventEntry<C> {
40    data_decoder: DataDecoder,
41    data_type_id: TypeId,
42    handlers: RwLock<FxHashMap<u32, Handler<C>>>,
43}
44
45pub struct WsIoEventRegistry<C: Send + Sync + 'static, S: TaskSpawner> {
46    _task_spawner: PhantomData<S>,
47    event_entries: RwLock<FxHashMap<String, Arc<EventEntry<C>>>>,
48    next_handler_id: AtomicU32,
49}
50
51impl<C: Send + Sync + 'static, S: TaskSpawner> Default for WsIoEventRegistry<C, S> {
52    fn default() -> Self {
53        Self::new()
54    }
55}
56
57impl<C: Send + Sync + 'static, S: TaskSpawner> WsIoEventRegistry<C, S> {
58    #[inline]
59    pub fn new() -> Self {
60        Self {
61            _task_spawner: PhantomData,
62            event_entries: RwLock::new(FxHashMap::default()),
63            next_handler_id: AtomicU32::new(0),
64        }
65    }
66
67    // Public methods
68    #[inline]
69    pub fn dispatch_event_packet(
70        &self,
71        ctx: Arc<C>,
72        event: &str,
73        packet_codec: &WsIoPacketCodec,
74        packet_data: Option<Vec<u8>>,
75        task_spawner: &Arc<S>,
76    ) {
77        let Some(event_entry) = self.event_entries.read().get(event).cloned() else {
78            return;
79        };
80
81        let packet_codec = *packet_codec;
82        let task_spawner_clone = task_spawner.clone();
83        task_spawner.spawn_task(async move {
84            let data = match packet_data {
85                Some(bytes) => match (event_entry.data_decoder)(&bytes, &packet_codec) {
86                    Ok(data) => data,
87                    Err(_) => return Ok(()),
88                },
89                None => EMPTY_EVENT_DATA_ANY_ARC.clone(),
90            };
91
92            let handlers = event_entry.handlers.read().values().cloned().collect::<Vec<_>>();
93            for handler in handlers {
94                let ctx = ctx.clone();
95                let data = data.clone();
96                task_spawner_clone.spawn_task(handler(ctx, data));
97            }
98
99            Ok(())
100        });
101    }
102
103    #[inline]
104    pub fn off(&self, event: &str) {
105        self.event_entries.write().remove(event);
106    }
107
108    #[inline]
109    pub fn off_by_handler_id(&self, event: &str, handler_id: u32) {
110        if let Some(event_entry) = self.event_entries.write().get(event).cloned() {
111            event_entry.handlers.write().remove(&handler_id);
112            if event_entry.handlers.read().is_empty() {
113                self.event_entries.write().remove(event);
114            }
115        }
116    }
117
118    #[inline]
119    pub fn on<H, Fut, D>(&self, event: &str, handler: H) -> u32
120    where
121        H: Fn(Arc<C>, Arc<D>) -> Fut + Send + Sync + 'static,
122        Fut: Future<Output = Result<()>> + Send + 'static,
123        D: DeserializeOwned + Send + Sync + 'static,
124    {
125        let data_type_id = TypeId::of::<D>();
126
127        let mut event_entries = self.event_entries.write();
128        let event_entry = match event_entries.entry(event.into()) {
129            Entry::Occupied(occupied) => {
130                let event_entry = occupied.into_mut();
131                assert_eq!(
132                    event_entry.data_type_id, data_type_id,
133                    "Event '{}' already registered with a different data type — each event name must correspond to exactly one payload type.",
134                    event
135                );
136
137                event_entry
138            }
139            Entry::Vacant(vacant) => vacant.insert(Arc::new(EventEntry {
140                data_decoder: decode_data_as_any_arc::<D>,
141                data_type_id,
142                handlers: RwLock::new(FxHashMap::default()),
143            })),
144        };
145
146        let handler_id = self.next_handler_id.fetch_add(1, Ordering::Relaxed);
147        event_entry.handlers.write().insert(
148            handler_id,
149            Arc::new(move |connection, data| {
150                if (*data).type_id() != data_type_id {
151                    return Box::pin(async { Ok(()) });
152                }
153
154                Box::pin(handler(connection, data.downcast().unwrap()))
155            }),
156        );
157
158        handler_id
159    }
160}
161
// Constants/Statics
/// Shared `()` payload, used as the decoded data when a packet carries no data.
static EMPTY_EVENT_DATA_ANY_ARC: LazyLock<Arc<dyn Any + Send + Sync>> = LazyLock::new(|| {
    let unit_payload: Arc<dyn Any + Send + Sync> = Arc::new(());
    unit_payload
});
164
165// Functions
166#[inline]
167fn decode_data_as_any_arc<D: DeserializeOwned + Send + Sync + 'static>(
168    bytes: &[u8],
169    packet_codec: &WsIoPacketCodec,
170) -> Result<Arc<dyn Any + Send + Sync>> {
171    Ok(Arc::new(packet_codec.decode_data::<D>(bytes)?))
172}