rustenium_core/
events.rs

1use crate::error::{CommandResultError, SessionSendError};
2use crate::{impl_has_method, impl_has_method_getter};
3use rustenium_bidi_commands::session::commands::{
4    SessionSubscribeMethod, SessionUnsubscribeMethod, Subscribe, SubscribeResult,
5    SubscriptionRequest, Unsubscribe, UnsubscribeParameters, UnsubscribeResult,
6};
7use rustenium_bidi_commands::session::types::{
8    Subscription, UnsubscribeByAttributesRequest, UnsubscribeByIDRequest,
9};
10use rustenium_bidi_commands::{
11    BrowsingContextEvent, CommandData, Event, EventData, InputEvent, LogEvent, NetworkEvent,
12    ResultData, ScriptEvent, SessionCommand, SessionResult,
13};
14use std::collections::HashSet;
15use std::future::{Future};
16use std::pin::Pin;
17use std::sync::Arc;
18use std::sync::Mutex as StdMutex;
19use std::{fmt, vec};
20use tokio::sync::mpsc::{unbounded_channel, UnboundedSender};
21use tokio::sync::Mutex;
22use tokio::task::JoinHandle;
23
/// Gives access to the method name of a single event value.
///
/// NOTE(review): presumably implemented for the per-module event enums by the
/// `impl_has_method!` invocations in this file; the macro body is defined
/// elsewhere, so confirm against its definition.
trait HasMethod {
    /// Returns the method name as an owned `String`.
    fn get_method(&self) -> String;
}
27
/// Gives access to the method name of whichever event a wrapper value holds.
///
/// NOTE(review): presumably implemented for `EventData` by the
/// `impl_has_method_getter!` invocation in this file; the macro body is
/// defined elsewhere, so confirm against its definition.
trait HasMethodGetter {
    /// Returns the method name as an owned `String`.
    fn get_method(&self) -> String;
}
31
32// Some events do not have context, need to specify to macro which has and which do not
33impl_has_method_getter!(
34    EventData,
35    [
36        BrowsingContextEvent,
37        ScriptEvent,
38        NetworkEvent,
39        LogEvent,
40        InputEvent
41    ]
42);
43
44impl_has_method!(
45    BrowsingContextEvent,
46    [
47        ContextCreated,
48        ContextDestroyed,
49        DomContentLoaded,
50        DownloadEnd,
51        DownloadWillBegin,
52        FragmentNavigated,
53        HistoryUpdated,
54        Load,
55        NavigationAborted,
56        NavigationCommitted,
57        NavigationFailed,
58        NavigationStarted,
59        UserPromptClosed,
60        UserPromptOpened
61    ]
62);
63
64impl_has_method!(InputEvent, [FileDialogOpened]);
65
66impl_has_method!(LogEvent, [EntryAdded]);
67
68impl_has_method!(
69    NetworkEvent,
70    [
71        AuthRequired,
72        BeforeRequestSent,
73        FetchError,
74        ResponseCompleted,
75        ResponseStarted
76    ]
77);
78impl_has_method!(ScriptEvent, [Message, RealmCreated, RealmDestroyed]);
79
80type BidiEventHandler = Arc<
81    Mutex<dyn FnMut(Event) -> Pin<Box<dyn Future<Output=()> + Send>> + Send + Sync + 'static>,
82>;
83pub struct BidiEvent {
84    pub id: String,
85    pub events: Vec<String>,
86    pub handler: BidiEventHandler,
87}
88
89impl fmt::Debug for BidiEvent {
90    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
91        f.debug_struct("BidiEvent")
92            .field("id", &self.id)
93            .field("events", &self.events)
94            .field("handler", &"<BidiEventHandler>")
95            .finish()
96    }
97}
98
99pub trait EventManagement {
100    fn send_event(
101        &mut self,
102        command_data: CommandData,
103    ) -> impl Future<Output = Result<ResultData, SessionSendError>>;
104
105    fn get_bidi_events(&mut self) -> &mut Arc<StdMutex<Vec<BidiEvent>>>;
106
107    fn push_event(&mut self, event: BidiEvent) -> ();
108
109    // I don't know what to do with UserContexts yet
110    fn subscribe_events<F, R>(
111        &mut self,
112        events: HashSet<&str>,
113        mut handler: F,
114        browsing_contexts: Option<Vec<String>>,
115        _user_contexts: Option<Vec<&str>>,
116    ) -> impl Future<Output = Result<Option<SubscribeResult>, CommandResultError>>
117    where
118        F: FnMut(Event) -> R + Send + Sync + 'static,
119        R: Future<Output=()> + Send + 'static,
120    {
121        async move {
122        let browsing_context_strings = match &browsing_contexts {
123            Some(browsing_contexts) => browsing_contexts.clone(),
124            None => vec![],
125        };
126
127        // Optimistically push event before sending to avoid race condition
128        let temp_id = format!("temp_{}", std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH).unwrap().as_nanos());
129        let bidi_event = BidiEvent {
130            id: temp_id.clone(),
131            events: events
132                .clone()
133                .into_iter()
134                .map(|event| event.to_string())
135                .collect(),
136            handler: Arc::new(Mutex::new(move |event| Box::pin(handler(event)) as Pin<Box<dyn Future<Output=()> + Send>>)),
137        };
138        self.push_event(bidi_event);
139
140        let subscribe_event_command =
141            CommandData::SessionCommand(SessionCommand::Subscribe(Subscribe {
142                method: SessionSubscribeMethod::SessionSubscribe,
143                params: SubscriptionRequest {
144                    events: events
145                        .clone()
146                        .into_iter()
147                        .map(|event| event.to_string())
148                        .collect(),
149                    contexts: if browsing_contexts.is_none() {
150                        None
151                    } else {
152                        Some(browsing_context_strings.clone())
153                    },
154                    user_contexts: None,
155                },
156            }));
157        let event_result = self.send_event(subscribe_event_command).await;
158        match event_result {
159            Ok(ResultData::SessionResult(session_result)) => match session_result {
160                SessionResult::SubscribeResult(subscribe_result) => {
161                    // Update temp ID with actual subscription ID
162                    let mut bidi_events = self.get_bidi_events().lock().unwrap();
163                    if let Some(event) = bidi_events.iter_mut().find(|e| e.id == temp_id) {
164                        event.id = subscribe_result.subscription.clone();
165                    }
166                    Ok(Some(subscribe_result))
167                }
168                _ => {
169                    // Remove on failure
170                    let mut bidi_events = self.get_bidi_events().lock().unwrap();
171                    bidi_events.retain(|e| e.id != temp_id);
172                    Err(CommandResultError::InvalidResultTypeError(
173                        ResultData::SessionResult(session_result),
174                    ))
175                }
176            },
177            Ok(result) => {
178                // Remove on failure
179                let mut bidi_events = self.get_bidi_events().lock().unwrap();
180                bidi_events.retain(|e| e.id != temp_id);
181                Err(CommandResultError::InvalidResultTypeError(result))
182            }
183            Err(e) => {
184                // Remove on failure
185                let mut bidi_events = self.get_bidi_events().lock().unwrap();
186                bidi_events.retain(|e| e.id != temp_id);
187                Err(CommandResultError::SessionSendError(e))
188            }
189        }
190        }
191    }
192
193    /// Add an event handler without sending a subscription command
194    /// Returns the handler ID (either provided or generated)
195    fn add_event_handler<F, R>(
196        &mut self,
197        events: HashSet<&str>,
198        mut handler: F,
199        handler_id: Option<String>,
200    ) -> String
201    where
202        F: FnMut(Event) -> R + Send + Sync + 'static,
203        R: Future<Output=()> + Send + 'static,
204    {
205        let id = handler_id.unwrap_or_else(|| {
206            format!("handler_{}", std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH).unwrap().as_nanos())
207        });
208
209        let bidi_event = BidiEvent {
210            id: id.clone(),
211            events: events
212                .into_iter()
213                .map(|event| event.to_string())
214                .collect(),
215            handler: Arc::new(Mutex::new(move |event| Box::pin(handler(event)) as Pin<Box<dyn Future<Output=()> + Send>>)),
216        };
217        self.push_event(bidi_event);
218
219        id
220    }
221
222    /// Unsubscribe from events by event names
223    fn unsubscribe_events_by_names(
224        &mut self,
225        events: HashSet<&str>,
226    ) -> impl Future<Output = Result<Option<UnsubscribeResult>, CommandResultError>> {
227        async move {
228        let unsubscribe_command =
229            CommandData::SessionCommand(SessionCommand::Unsubscribe(Unsubscribe {
230                method: SessionUnsubscribeMethod::SessionUnsubscribe,
231                params: UnsubscribeParameters::UnsubscribeByAttributesRequest(
232                    UnsubscribeByAttributesRequest {
233                        events: events
234                            .clone()
235                            .into_iter()
236                            .map(|event| event.to_string())
237                            .collect(),
238                    },
239                ),
240            }));
241
242        let event_result = self.send_event(unsubscribe_command).await;
243        match event_result {
244            Ok(ResultData::SessionResult(session_result)) => match session_result {
245                SessionResult::UnsubscribeResult(unsubscribe_result) => {
246                    // Remove the event names from BidiEvents and clean up empty ones
247                    let mut bidi_events = self.get_bidi_events().lock().unwrap();
248
249                    // First, remove matching event names from each BidiEvent
250                    for bidi_event in bidi_events.iter_mut() {
251                        bidi_event.events.retain(|e| !events.contains(e.as_str()));
252                    }
253
254                    // Then remove any BidiEvents that have no events left
255                    bidi_events.retain(|bidi_event| !bidi_event.events.is_empty());
256
257                    Ok(Some(unsubscribe_result))
258                }
259                _ => Err(CommandResultError::InvalidResultTypeError(
260                    ResultData::SessionResult(session_result),
261                )),
262            },
263            Ok(result) => Err(CommandResultError::InvalidResultTypeError(result)),
264            Err(e) => Err(CommandResultError::SessionSendError(e)),
265        }
266        }
267    }
268
269    /// Unsubscribe from events by subscription IDs
270    fn unsubscribe_events_by_ids(
271        &mut self,
272        subscription_ids: Vec<Subscription>,
273    ) -> impl Future<Output = Result<Option<UnsubscribeResult>, CommandResultError>> {
274        async move {
275        let unsubscribe_command =
276            CommandData::SessionCommand(SessionCommand::Unsubscribe(Unsubscribe {
277                method: SessionUnsubscribeMethod::SessionUnsubscribe,
278                params: UnsubscribeParameters::UnsubscribeByIDRequest(UnsubscribeByIDRequest {
279                    subscriptions: subscription_ids.clone(),
280                }),
281            }));
282
283        let event_result = self.send_event(unsubscribe_command).await;
284        match event_result {
285            Ok(ResultData::EmptyResult(empty_result)) => {
286                // Remove the subscriptions from our local tracking
287                let mut bidi_events = self.get_bidi_events().lock().unwrap();
288                bidi_events.retain(| bidi_event | !subscription_ids.contains(&bidi_event.id));
289                Ok(Some(empty_result))
290            },
291            Ok(result) => Err(CommandResultError::InvalidResultTypeError(result)),
292            Err(e) => Err(CommandResultError::SessionSendError(e)),
293        }
294        }
295    }
296
297    fn event_dispatch(&mut self) -> impl Future<Output = (JoinHandle<()>, UnboundedSender<Event>)> {
298        async move {
299        let (tx, mut rx) = unbounded_channel::<Event>();
300        let bidi_events = self.get_bidi_events().clone();
301        (
302            tokio::spawn(async move {
303                while let Some(event) = rx.recv().await {
304                    let event_method = event.event_data.get_method().trim_matches('"').to_string();
305                    // Manually handling context check was abandoned, too much variation/nesting of context
306                    for bidi_event in bidi_events.lock().unwrap().iter() {
307                        if bidi_event.events.contains(&event_method) {
308                            let ch = Arc::clone(&bidi_event.handler);
309                            let ce = event.clone();
310                            tokio::spawn(async move {
311                                (ch.lock().await)(ce).await;
312                            });
313                        }
314                    }
315                }
316            }),
317            tx,
318        )
319        }
320    }
321}