viewpoint_core/network/handler/
mod.rs

1//! Route handler registry and dispatch.
2
3mod constructors;
4
5use std::future::Future;
6use std::pin::Pin;
7use std::sync::Arc;
8
9use tokio::sync::RwLock;
10use viewpoint_cdp::CdpConnection;
11use viewpoint_cdp::protocol::fetch::{AuthRequiredEvent, RequestPausedEvent};
12
13use super::auth::{AuthHandler, HttpCredentials};
14use super::handler_fetch::{disable_fetch, enable_fetch};
15use super::handler_request::{continue_request, create_route_from_event};
16use super::route::Route;
17use super::types::{UrlMatcher, UrlPattern};
18use crate::error::NetworkError;
19
20/// A registered route handler.
21struct RegisteredHandler {
22    /// Pattern to match URLs.
23    pattern: Box<dyn UrlMatcher>,
24    /// The handler function.
25    handler: Arc<
26        dyn Fn(Route) -> Pin<Box<dyn Future<Output = Result<(), NetworkError>> + Send>>
27            + Send
28            + Sync,
29    >,
30}
31
32impl std::fmt::Debug for RegisteredHandler {
33    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
34        f.debug_struct("RegisteredHandler")
35            .field("pattern", &"<pattern>")
36            .field("handler", &"<fn>")
37            .finish()
38    }
39}
40
/// Route handler registry for a page or context.
///
/// Holds page-level route handlers, tracks whether this registry has enabled
/// the CDP `Fetch` domain for its session, and owns the HTTP-auth state used
/// to answer `Fetch.authRequired` challenges.
#[derive(Debug)]
pub struct RouteHandlerRegistry {
    /// Registered handlers, stored in registration order.
    ///
    /// Dispatch walks this list in reverse, so the most recently registered
    /// matching handler is tried first.
    handlers: RwLock<Vec<RegisteredHandler>>,
    /// CDP connection used to send Fetch commands and to subscribe to events.
    connection: Arc<CdpConnection>,
    /// CDP session that commands target and that incoming events are filtered by.
    session_id: String,
    /// Whether this registry last enabled (rather than disabled) the Fetch domain.
    fetch_enabled: RwLock<bool>,
    /// HTTP authentication handler that answers auth challenges.
    auth_handler: AuthHandler,
    /// Whether auth handling is enabled; passed to `enable_fetch` when Fetch is enabled.
    auth_enabled: RwLock<bool>,
    /// Context-level route registry, consulted when no page handler handles a request.
    context_routes: Option<Arc<crate::context::routing::ContextRouteRegistry>>,
}
59
60impl RouteHandlerRegistry {
61    /// Enable Fetch domain if context has routes or auth is enabled.
62    ///
63    /// This should be called after the registry is created to check if there are
64    /// context-level routes that need interception or if HTTP credentials are configured.
65    pub async fn enable_fetch_for_context_routes(&self) -> Result<(), NetworkError> {
66        // Enable if auth is enabled (credentials were provided)
67        let auth_enabled = *self.auth_enabled.read().await;
68        if auth_enabled {
69            self.ensure_fetch_enabled().await?;
70            return Ok(());
71        }
72
73        // Also enable if there are context routes
74        if let Some(ref context_routes) = self.context_routes {
75            if context_routes.has_routes().await {
76                self.ensure_fetch_enabled().await?;
77            }
78        }
79        Ok(())
80    }
81
82    /// Set the context-level route registry.
83    pub fn set_context_routes(
84        &mut self,
85        context_routes: Arc<crate::context::routing::ContextRouteRegistry>,
86    ) {
87        self.context_routes = Some(context_routes);
88    }
89
90    /// Start the background fetch event listener.
91    ///
92    /// This spawns a background task that listens for `Fetch.requestPaused` and
93    /// `Fetch.authRequired` events and dispatches them to the appropriate handlers.
94    ///
95    /// Also listens for context route change notifications to enable Fetch when
96    /// new routes are added to the context after the page was created.
97    ///
98    /// This should be called after creating the registry, passing an Arc reference to self.
99    pub fn start_fetch_listener(self: &Arc<Self>) {
100        let mut events = self.connection.subscribe_events();
101        let session_id = self.session_id.clone();
102        let registry = Arc::clone(self);
103
104        // Subscribe to context route changes if we have context routes
105        let mut route_change_rx = self
106            .context_routes
107            .as_ref()
108            .map(|ctx| ctx.subscribe_route_changes());
109        let registry_for_routes = Arc::clone(self);
110
111        tokio::spawn(async move {
112            loop {
113                tokio::select! {
114                    // Handle CDP events
115                    event_result = events.recv() => {
116                        let Ok(event) = event_result else {
117                            break;
118                        };
119
120                        // Filter for this session
121                        if event.session_id.as_deref() != Some(&session_id) {
122                            continue;
123                        }
124
125                        match event.method.as_str() {
126                            "Fetch.requestPaused" => {
127                                if let Some(params) = &event.params {
128                                    if let Ok(paused_event) = serde_json::from_value::<RequestPausedEvent>(params.clone()) {
129                                        tracing::debug!(
130                                            request_id = %paused_event.request_id,
131                                            url = %paused_event.request.url,
132                                            "Fetch.requestPaused received"
133                                        );
134                                        if let Err(e) = registry.handle_request(&paused_event).await {
135                                            tracing::warn!(
136                                                request_id = %paused_event.request_id,
137                                                error = %e,
138                                                "Failed to handle paused request"
139                                            );
140                                        }
141                                    }
142                                }
143                            }
144                            "Fetch.authRequired" => {
145                                if let Some(params) = &event.params {
146                                    if let Ok(auth_event) = serde_json::from_value::<AuthRequiredEvent>(params.clone()) {
147                                        tracing::debug!(
148                                            request_id = %auth_event.request_id,
149                                            origin = %auth_event.auth_challenge.origin,
150                                            scheme = %auth_event.auth_challenge.scheme,
151                                            "Fetch.authRequired received"
152                                        );
153                                        if let Err(e) = registry.handle_auth_required(&auth_event).await {
154                                            tracing::warn!(
155                                                request_id = %auth_event.request_id,
156                                                error = %e,
157                                                "Failed to handle auth required"
158                                            );
159                                        }
160                                    }
161                                }
162                            }
163                            _ => {}
164                        }
165                    }
166
167                    // Handle context route change notifications
168                    Some(Ok(_notification)) = async {
169                        match route_change_rx.as_mut() {
170                            Some(rx) => Some(rx.recv().await),
171                            None => std::future::pending().await,
172                        }
173                    } => {
174                        // A new route was added to the context - enable Fetch if not already
175                        tracing::debug!("Context route added, ensuring Fetch is enabled");
176                        if let Err(e) = registry_for_routes.ensure_fetch_enabled().await {
177                            tracing::warn!(error = %e, "Failed to enable Fetch after context route added");
178                        }
179                    }
180                }
181            }
182        });
183    }
184
185    /// Set HTTP credentials for authentication.
186    pub async fn set_http_credentials(&self, credentials: HttpCredentials) {
187        self.auth_handler.set_credentials(credentials).await;
188
189        // Enable auth handling if not already enabled
190        let mut auth_enabled = self.auth_enabled.write().await;
191        if !*auth_enabled {
192            *auth_enabled = true;
193            // Re-enable fetch with auth handling if fetch is already enabled
194            drop(auth_enabled);
195            let fetch_enabled = *self.fetch_enabled.read().await;
196            if fetch_enabled {
197                let _ = self.re_enable_fetch_with_auth().await;
198            }
199        }
200    }
201
202    /// Clear HTTP credentials.
203    pub async fn clear_http_credentials(&self) {
204        self.auth_handler.clear_credentials().await;
205        let mut auth_enabled = self.auth_enabled.write().await;
206        *auth_enabled = false;
207    }
208
209    /// Handle an authentication challenge.
210    pub async fn handle_auth_required(
211        &self,
212        event: &AuthRequiredEvent,
213    ) -> Result<(), NetworkError> {
214        self.auth_handler.handle_auth_challenge(event).await?;
215        Ok(())
216    }
217
218    /// Register a route handler for the given pattern.
219    pub async fn route<M, H, Fut>(&self, pattern: M, handler: H) -> Result<(), NetworkError>
220    where
221        M: Into<UrlPattern>,
222        H: Fn(Route) -> Fut + Send + Sync + 'static,
223        Fut: Future<Output = Result<(), NetworkError>> + Send + 'static,
224    {
225        let pattern = pattern.into();
226
227        // Enable Fetch domain if not already enabled
228        self.ensure_fetch_enabled().await?;
229
230        // Wrap the handler
231        let handler: Arc<
232            dyn Fn(Route) -> Pin<Box<dyn Future<Output = Result<(), NetworkError>> + Send>>
233                + Send
234                + Sync,
235        > = Arc::new(move |route| Box::pin(handler(route)));
236
237        // Add to handlers (will be matched in reverse order)
238        let mut handlers = self.handlers.write().await;
239        handlers.push(RegisteredHandler {
240            pattern: Box::new(pattern),
241            handler,
242        });
243
244        Ok(())
245    }
246
247    /// Register a route handler with a predicate function.
248    pub async fn route_predicate<P, H, Fut>(
249        &self,
250        predicate: P,
251        handler: H,
252    ) -> Result<(), NetworkError>
253    where
254        P: Fn(&str) -> bool + Send + Sync + 'static,
255        H: Fn(Route) -> Fut + Send + Sync + 'static,
256        Fut: Future<Output = Result<(), NetworkError>> + Send + 'static,
257    {
258        // Enable Fetch domain if not already enabled
259        self.ensure_fetch_enabled().await?;
260
261        // Create a matcher from the predicate
262        struct PredicateMatcher<F>(F);
263        impl<F: Fn(&str) -> bool + Send + Sync> UrlMatcher for PredicateMatcher<F> {
264            fn matches(&self, url: &str) -> bool {
265                (self.0)(url)
266            }
267        }
268
269        // Wrap the handler
270        let handler: Arc<
271            dyn Fn(Route) -> Pin<Box<dyn Future<Output = Result<(), NetworkError>> + Send>>
272                + Send
273                + Sync,
274        > = Arc::new(move |route| Box::pin(handler(route)));
275
276        // Add to handlers
277        let mut handlers = self.handlers.write().await;
278        handlers.push(RegisteredHandler {
279            pattern: Box::new(PredicateMatcher(predicate)),
280            handler,
281        });
282
283        Ok(())
284    }
285
286    /// Unregister handlers matching the given pattern.
287    pub async fn unroute(&self, pattern: &str) {
288        let mut handlers = self.handlers.write().await;
289
290        // Remove handlers that match this pattern
291        // For simplicity, we match based on glob pattern equality
292        handlers.retain(|h| {
293            // This is a simplification - in a real implementation,
294            // we'd need to compare patterns more thoroughly
295            !h.pattern.matches(pattern)
296        });
297
298        // If no handlers left, disable Fetch domain
299        if handlers.is_empty() {
300            drop(handlers);
301            let _ = self.disable_fetch_domain().await;
302        }
303    }
304
305    /// Unregister all handlers.
306    pub async fn unroute_all(&self) {
307        let mut handlers = self.handlers.write().await;
308        handlers.clear();
309        drop(handlers);
310        let _ = self.disable_fetch_domain().await;
311    }
312
313    /// Handle a paused request by dispatching to matching handlers.
314    ///
315    /// Handlers are tried in reverse order (LIFO). If a handler calls `fallback()`,
316    /// the next matching handler is tried. If no handler handles the request,
317    /// it is continued to the network.
318    pub async fn handle_request(&self, event: &RequestPausedEvent) -> Result<(), NetworkError> {
319        let url = &event.request.url;
320        let handlers = self.handlers.read().await;
321
322        // Collect all matching handlers (in reverse order - LIFO)
323        let matching_handlers: Vec<_> = handlers
324            .iter()
325            .rev()
326            .filter(|h| h.pattern.matches(url))
327            .collect();
328
329        // Try each matching handler in order
330        for handler in &matching_handlers {
331            let route =
332                create_route_from_event(event, self.connection.clone(), self.session_id.clone());
333            let route_check = route.clone();
334
335            // Call the handler (handler takes ownership of route)
336            (handler.handler)(route).await?;
337
338            // Check if the route was actually handled (made a CDP call)
339            if route_check.is_handled().await {
340                return Ok(());
341            }
342            tracing::debug!(
343                request_id = %event.request_id,
344                url = %url,
345                "Handler called fallback, trying next handler"
346            );
347        }
348
349        // Drop page handlers lock before checking context routes
350        drop(handlers);
351
352        // Check context routes as fallback
353        if let Some(ref context_routes) = self.context_routes {
354            let context_handlers = context_routes.find_all_handlers(url).await;
355
356            for handler in context_handlers {
357                let route = create_route_from_event(
358                    event,
359                    self.connection.clone(),
360                    self.session_id.clone(),
361                );
362                let route_check = route.clone();
363
364                handler(route).await?;
365
366                if route_check.is_handled().await {
367                    return Ok(());
368                }
369                tracing::debug!(
370                    request_id = %event.request_id,
371                    url = %url,
372                    "Context handler called fallback, trying next handler"
373                );
374            }
375        }
376
377        // No handler handled the request - continue to the network
378        continue_request(&self.connection, &self.session_id, &event.request_id).await
379    }
380
381    /// Enable the Fetch domain if not already enabled.
382    ///
383    /// This is a public version for use by `ContextRouteRegistry` when
384    /// synchronously enabling Fetch on all pages after a context route is added.
385    pub async fn ensure_fetch_enabled_public(&self) -> Result<(), NetworkError> {
386        self.ensure_fetch_enabled().await
387    }
388
389    /// Enable the Fetch domain if not already enabled.
390    async fn ensure_fetch_enabled(&self) -> Result<(), NetworkError> {
391        let mut enabled = self.fetch_enabled.write().await;
392        if *enabled {
393            return Ok(());
394        }
395
396        let auth_enabled = *self.auth_enabled.read().await;
397        enable_fetch(&self.connection, &self.session_id, auth_enabled).await?;
398        *enabled = true;
399        Ok(())
400    }
401
402    /// Re-enable Fetch domain with auth handling.
403    async fn re_enable_fetch_with_auth(&self) -> Result<(), NetworkError> {
404        // First disable, then re-enable with auth
405        disable_fetch(&self.connection, &self.session_id).await?;
406        enable_fetch(&self.connection, &self.session_id, true).await
407    }
408
409    /// Disable the Fetch domain.
410    async fn disable_fetch_domain(&self) -> Result<(), NetworkError> {
411        let mut enabled = self.fetch_enabled.write().await;
412        if !*enabled {
413            return Ok(());
414        }
415
416        disable_fetch(&self.connection, &self.session_id).await?;
417        *enabled = false;
418        Ok(())
419    }
420}
421
422#[cfg(test)]
423mod tests;