1use crate::event::{ComponentId, Event, EventContext, EventKind, EventType};
4use crate::Action;
5use crossterm::event::{self, KeyModifiers, MouseEventKind};
6use std::collections::{HashMap, HashSet};
7use std::time::Duration;
8use tokio::sync::mpsc;
9use tokio_util::sync::CancellationToken;
10use tracing::{debug, info};
11
12#[derive(Debug)]
14pub enum RawEvent {
15 Key(crossterm::event::KeyEvent),
16 Mouse(crossterm::event::MouseEvent),
17 Resize(u16, u16),
18}
19
/// Routes events to subscribed components and holds the sender used to
/// dispatch application actions of type `A`.
///
/// `C` identifies components; subscriptions are tracked per [`EventType`].
pub struct EventBus<A: Action, C: ComponentId> {
    /// For each event type, the set of components subscribed to it.
    subscriptions: HashMap<EventType, HashSet<C>>,
    /// Current input context (mouse position, modifiers, ...); a clone of it is
    /// attached to every event built by `create_event`.
    context: EventContext<C>,
    /// Sending half of the action channel.
    action_tx: mpsc::UnboundedSender<A>,
}
33
34impl<A: Action, C: ComponentId> EventBus<A, C> {
35 pub fn new(action_tx: mpsc::UnboundedSender<A>) -> Self {
37 Self {
38 subscriptions: HashMap::new(),
39 context: EventContext::default(),
40 action_tx,
41 }
42 }
43
44 pub fn subscribe(&mut self, component: C, event_type: EventType) {
46 self.subscriptions
47 .entry(event_type)
48 .or_default()
49 .insert(component);
50 }
51
52 pub fn subscribe_many(&mut self, component: C, event_types: &[EventType]) {
54 for &event_type in event_types {
55 self.subscribe(component, event_type);
56 }
57 }
58
59 pub fn unsubscribe(&mut self, component: C, event_type: EventType) {
61 if let Some(subscribers) = self.subscriptions.get_mut(&event_type) {
62 subscribers.remove(&component);
63 }
64 }
65
66 pub fn unsubscribe_all(&mut self, component: C) {
68 for subscribers in self.subscriptions.values_mut() {
69 subscribers.remove(&component);
70 }
71 }
72
73 pub fn get_subscribers(&self, event_type: EventType) -> Vec<C> {
75 self.subscriptions
76 .get(&event_type)
77 .map(|s| s.iter().copied().collect())
78 .unwrap_or_default()
79 }
80
81 pub fn get_event_subscribers(&self, event: &Event<C>) -> Vec<C> {
83 let mut subscribers = HashSet::new();
84
85 if event.is_global() {
87 if let Some(global_subs) = self.subscriptions.get(&EventType::Global) {
88 subscribers.extend(global_subs.iter().copied());
89 }
90 }
91
92 if let Some(type_subs) = self.subscriptions.get(&event.event_type()) {
94 subscribers.extend(type_subs.iter().copied());
95 }
96
97 subscribers.into_iter().collect()
98 }
99
100 pub fn context_mut(&mut self) -> &mut EventContext<C> {
102 &mut self.context
103 }
104
105 pub fn context(&self) -> &EventContext<C> {
107 &self.context
108 }
109
110 pub fn create_event(&self, kind: EventKind) -> Event<C> {
112 Event::new(kind, self.context.clone())
113 }
114
115 pub fn action_tx(&self) -> &mpsc::UnboundedSender<A> {
117 &self.action_tx
118 }
119
120 pub fn send(&self, action: A) -> Result<(), mpsc::error::SendError<A>> {
122 self.action_tx.send(action)
123 }
124
125 pub fn update_mouse_position(&mut self, x: u16, y: u16) {
127 self.context.mouse_position = Some((x, y));
128 }
129
130 pub fn update_modifiers(&mut self, modifiers: KeyModifiers) {
132 self.context.modifiers = modifiers;
133 }
134}
135
136pub fn spawn_event_poller(
147 tx: mpsc::UnboundedSender<RawEvent>,
148 poll_timeout: Duration,
149 loop_sleep: Duration,
150 cancel_token: CancellationToken,
151) -> tokio::task::JoinHandle<()> {
152 tokio::spawn(async move {
153 const MAX_EVENTS_PER_BATCH: usize = 20;
154
155 loop {
156 tokio::select! {
157 _ = cancel_token.cancelled() => {
158 info!("Event poller cancelled, draining buffer");
159 while event::poll(Duration::ZERO).unwrap_or(false) {
161 let _ = event::read();
162 }
163 break;
164 }
165 _ = tokio::time::sleep(loop_sleep) => {
166 let mut events_processed = 0;
168 while events_processed < MAX_EVENTS_PER_BATCH
169 && event::poll(poll_timeout).unwrap_or(false)
170 {
171 events_processed += 1;
172 if let Ok(evt) = event::read() {
173 let raw = match evt {
174 event::Event::Key(key) => Some(RawEvent::Key(key)),
175 event::Event::Mouse(mouse) => Some(RawEvent::Mouse(mouse)),
176 event::Event::Resize(w, h) => Some(RawEvent::Resize(w, h)),
177 _ => None,
178 };
179 if let Some(raw) = raw {
180 if tx.send(raw).is_err() {
181 debug!("Event channel closed, stopping poller");
182 return;
183 }
184 }
185 }
186 }
187 }
188 }
189 }
190 })
191}
192
193pub fn process_raw_event(raw: RawEvent) -> EventKind {
195 match raw {
196 RawEvent::Key(key) => EventKind::Key(key),
197 RawEvent::Mouse(mouse) => match mouse.kind {
198 MouseEventKind::ScrollDown => EventKind::Scroll {
199 column: mouse.column,
200 row: mouse.row,
201 delta: 1,
202 },
203 MouseEventKind::ScrollUp => EventKind::Scroll {
204 column: mouse.column,
205 row: mouse.row,
206 delta: -1,
207 },
208 _ => EventKind::Mouse(mouse),
209 },
210 RawEvent::Resize(w, h) => EventKind::Resize(w, h),
211 }
212}
213
#[cfg(test)]
mod tests {
    use super::*;
    use crate::event::NumericComponentId;

    #[derive(Clone, Debug)]
    // `Test` is never constructed: the enum exists only to satisfy the
    // `Action` bound on `EventBus`.
    #[allow(dead_code)]
    enum TestAction {
        Test,
    }

    impl Action for TestAction {
        fn name(&self) -> &'static str {
            "Test"
        }
    }

    #[test]
    fn test_subscribe_unsubscribe() {
        let (tx, _rx) = mpsc::unbounded_channel();
        let mut bus: EventBus<TestAction, NumericComponentId> = EventBus::new(tx);

        let component = NumericComponentId(1);
        bus.subscribe(component, EventType::Key);

        assert_eq!(bus.get_subscribers(EventType::Key), vec![component]);

        bus.unsubscribe(component, EventType::Key);
        assert!(bus.get_subscribers(EventType::Key).is_empty());
    }

    #[test]
    fn test_subscribe_is_idempotent() {
        let (tx, _rx) = mpsc::unbounded_channel();
        let mut bus: EventBus<TestAction, NumericComponentId> = EventBus::new(tx);

        let component = NumericComponentId(1);
        bus.subscribe(component, EventType::Key);
        bus.subscribe(component, EventType::Key);

        assert_eq!(bus.get_subscribers(EventType::Key), vec![component]);
    }

    #[test]
    fn test_unsubscribe_keeps_other_components() {
        let (tx, _rx) = mpsc::unbounded_channel();
        let mut bus: EventBus<TestAction, NumericComponentId> = EventBus::new(tx);

        let first = NumericComponentId(1);
        let second = NumericComponentId(2);
        bus.subscribe(first, EventType::Key);
        bus.subscribe(second, EventType::Key);

        bus.unsubscribe(first, EventType::Key);
        assert_eq!(bus.get_subscribers(EventType::Key), vec![second]);
    }

    #[test]
    fn test_subscribe_many() {
        let (tx, _rx) = mpsc::unbounded_channel();
        let mut bus: EventBus<TestAction, NumericComponentId> = EventBus::new(tx);

        let component = NumericComponentId(1);
        bus.subscribe_many(component, &[EventType::Key, EventType::Mouse]);

        assert_eq!(bus.get_subscribers(EventType::Key), vec![component]);
        assert_eq!(bus.get_subscribers(EventType::Mouse), vec![component]);
    }

    #[test]
    fn test_unsubscribe_all() {
        let (tx, _rx) = mpsc::unbounded_channel();
        let mut bus: EventBus<TestAction, NumericComponentId> = EventBus::new(tx);

        let component = NumericComponentId(1);
        bus.subscribe_many(
            component,
            &[EventType::Key, EventType::Mouse, EventType::Scroll],
        );

        bus.unsubscribe_all(component);

        assert!(bus.get_subscribers(EventType::Key).is_empty());
        assert!(bus.get_subscribers(EventType::Mouse).is_empty());
        assert!(bus.get_subscribers(EventType::Scroll).is_empty());
    }

    #[test]
    fn test_process_raw_event_key() {
        use crossterm::event::{KeyCode, KeyEvent, KeyEventKind, KeyEventState, KeyModifiers};

        let key_event = KeyEvent {
            code: KeyCode::Char('a'),
            modifiers: KeyModifiers::NONE,
            kind: KeyEventKind::Press,
            state: KeyEventState::empty(),
        };

        let kind = process_raw_event(RawEvent::Key(key_event));
        assert!(matches!(kind, EventKind::Key(_)));
    }

    #[test]
    fn test_process_raw_event_scroll() {
        use crossterm::event::{MouseEvent, MouseEventKind};

        let scroll_down = MouseEvent {
            kind: MouseEventKind::ScrollDown,
            column: 10,
            row: 20,
            modifiers: KeyModifiers::NONE,
        };

        let kind = process_raw_event(RawEvent::Mouse(scroll_down));
        match kind {
            EventKind::Scroll { column, row, delta } => {
                assert_eq!(column, 10);
                assert_eq!(row, 20);
                assert_eq!(delta, 1);
            }
            _ => panic!("Expected Scroll event"),
        }
    }

    #[test]
    fn test_process_raw_event_scroll_up() {
        use crossterm::event::{MouseEvent, MouseEventKind};

        let scroll_up = MouseEvent {
            kind: MouseEventKind::ScrollUp,
            column: 3,
            row: 4,
            modifiers: KeyModifiers::NONE,
        };

        match process_raw_event(RawEvent::Mouse(scroll_up)) {
            EventKind::Scroll { column, row, delta } => {
                assert_eq!(column, 3);
                assert_eq!(row, 4);
                assert_eq!(delta, -1);
            }
            _ => panic!("Expected Scroll event"),
        }
    }

    #[test]
    fn test_process_raw_event_non_scroll_mouse_passes_through() {
        use crossterm::event::{MouseButton, MouseEvent, MouseEventKind};

        let click = MouseEvent {
            kind: MouseEventKind::Down(MouseButton::Left),
            column: 1,
            row: 2,
            modifiers: KeyModifiers::NONE,
        };

        let kind = process_raw_event(RawEvent::Mouse(click));
        assert!(matches!(kind, EventKind::Mouse(m) if m == click));
    }

    #[test]
    fn test_process_raw_event_resize() {
        let kind = process_raw_event(RawEvent::Resize(80, 24));
        assert!(matches!(kind, EventKind::Resize(80, 24)));
    }
}