// viewpoint_core/network/handler/mod.rs

//! Route handler registry and dispatch.

use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;

use tokio::sync::RwLock;
use tokio::sync::broadcast::error::RecvError;
use viewpoint_cdp::CdpConnection;
use viewpoint_cdp::protocol::fetch::{AuthRequiredEvent, RequestPausedEvent};

use super::auth::{AuthHandler, HttpCredentials, ProxyCredentials};
use super::handler_fetch::{disable_fetch, enable_fetch};
use super::handler_request::{continue_request, create_route_from_event};
use super::route::Route;
use super::types::{UrlMatcher, UrlPattern};
use crate::error::NetworkError;

18/// A registered route handler.
19struct RegisteredHandler {
20    /// Pattern to match URLs.
21    pattern: Box<dyn UrlMatcher>,
22    /// The handler function.
23    handler: Arc<
24        dyn Fn(Route) -> Pin<Box<dyn Future<Output = Result<(), NetworkError>> + Send>>
25            + Send
26            + Sync,
27    >,
28}
29
30impl std::fmt::Debug for RegisteredHandler {
31    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
32        f.debug_struct("RegisteredHandler")
33            .field("pattern", &"<pattern>")
34            .field("handler", &"<fn>")
35            .finish()
36    }
37}
38
39/// Route handler registry for a page or context.
40#[derive(Debug)]
41pub struct RouteHandlerRegistry {
42    /// Registered handlers (in reverse order - last registered is first tried).
43    handlers: RwLock<Vec<RegisteredHandler>>,
44    /// CDP connection for sending commands.
45    connection: Arc<CdpConnection>,
46    /// Session ID for CDP commands.
47    session_id: String,
48    /// Whether the Fetch domain is enabled.
49    fetch_enabled: RwLock<bool>,
50    /// HTTP authentication handler.
51    auth_handler: AuthHandler,
52    /// Whether auth handling is enabled.
53    auth_enabled: RwLock<bool>,
54    /// Context-level route registry (for fallback handling).
55    context_routes: Option<Arc<crate::context::routing::ContextRouteRegistry>>,
56}
57
58impl RouteHandlerRegistry {
59    /// Create a new route handler registry.
60    pub fn new(connection: Arc<CdpConnection>, session_id: String) -> Self {
61        let auth_handler = AuthHandler::new(connection.clone(), session_id.clone());
62        Self {
63            handlers: RwLock::new(Vec::new()),
64            connection,
65            session_id,
66            fetch_enabled: RwLock::new(false),
67            auth_handler,
68            auth_enabled: RwLock::new(false),
69            context_routes: None,
70        }
71    }
72
73    /// Create a new route handler registry with HTTP credentials.
74    pub fn with_credentials(
75        connection: Arc<CdpConnection>,
76        session_id: String,
77        credentials: HttpCredentials,
78    ) -> Self {
79        let auth_handler =
80            AuthHandler::with_credentials(connection.clone(), session_id.clone(), credentials);
81        Self {
82            handlers: RwLock::new(Vec::new()),
83            connection,
84            session_id,
85            fetch_enabled: RwLock::new(false),
86            auth_handler,
87            auth_enabled: RwLock::new(true),
88            context_routes: None,
89        }
90    }
91
92    /// Create a new route handler registry with context-level routes.
93    ///
94    /// If `http_credentials` is provided, they will be set on the auth handler
95    /// for handling HTTP authentication challenges.
96    pub fn with_context_routes(
97        connection: Arc<CdpConnection>,
98        session_id: String,
99        context_routes: Arc<crate::context::routing::ContextRouteRegistry>,
100        http_credentials: Option<HttpCredentials>,
101    ) -> Self {
102        Self::with_context_routes_and_proxy(
103            connection,
104            session_id,
105            context_routes,
106            http_credentials,
107            None,
108        )
109    }
110
111    /// Create a new route handler registry with context-level routes and optional proxy credentials.
112    ///
113    /// If `http_credentials` is provided, they will be set on the auth handler
114    /// for handling HTTP authentication challenges.
115    /// If `proxy_credentials` is provided, they will be used for proxy authentication.
116    pub fn with_context_routes_and_proxy(
117        connection: Arc<CdpConnection>,
118        session_id: String,
119        context_routes: Arc<crate::context::routing::ContextRouteRegistry>,
120        http_credentials: Option<HttpCredentials>,
121        proxy_credentials: Option<ProxyCredentials>,
122    ) -> Self {
123        let auth_handler = AuthHandler::new(connection.clone(), session_id.clone());
124
125        // Set HTTP credentials if provided
126        if let Some(ref creds) = http_credentials {
127            tracing::debug!(
128                username = %creds.username,
129                has_origin = creds.origin.is_some(),
130                "Setting HTTP credentials on auth handler"
131            );
132            auth_handler.set_credentials_sync(creds.clone());
133        }
134
135        // Set proxy credentials if provided
136        if let Some(ref proxy_creds) = proxy_credentials {
137            tracing::debug!(
138                username = %proxy_creds.username,
139                "Setting proxy credentials on auth handler"
140            );
141            auth_handler.set_proxy_credentials_sync(proxy_creds.clone());
142        }
143
144        // Enable auth if any credentials are provided
145        let auth_enabled = http_credentials.is_some() || proxy_credentials.is_some();
146
147        Self {
148            handlers: RwLock::new(Vec::new()),
149            connection,
150            session_id,
151            fetch_enabled: RwLock::new(false),
152            auth_handler,
153            auth_enabled: RwLock::new(auth_enabled),
154            context_routes: Some(context_routes),
155        }
156    }
157
158    /// Enable Fetch domain if context has routes or auth is enabled.
159    ///
160    /// This should be called after the registry is created to check if there are
161    /// context-level routes that need interception or if HTTP credentials are configured.
162    pub async fn enable_fetch_for_context_routes(&self) -> Result<(), NetworkError> {
163        // Enable if auth is enabled (credentials were provided)
164        let auth_enabled = *self.auth_enabled.read().await;
165        if auth_enabled {
166            self.ensure_fetch_enabled().await?;
167            return Ok(());
168        }
169
170        // Also enable if there are context routes
171        if let Some(ref context_routes) = self.context_routes {
172            if context_routes.has_routes().await {
173                self.ensure_fetch_enabled().await?;
174            }
175        }
176        Ok(())
177    }
178
179    /// Set the context-level route registry.
180    pub fn set_context_routes(
181        &mut self,
182        context_routes: Arc<crate::context::routing::ContextRouteRegistry>,
183    ) {
184        self.context_routes = Some(context_routes);
185    }
186
187    /// Start the background fetch event listener.
188    ///
189    /// This spawns a background task that listens for `Fetch.requestPaused` and
190    /// `Fetch.authRequired` events and dispatches them to the appropriate handlers.
191    ///
192    /// Also listens for context route change notifications to enable Fetch when
193    /// new routes are added to the context after the page was created.
194    ///
195    /// This should be called after creating the registry, passing an Arc reference to self.
196    pub fn start_fetch_listener(self: &Arc<Self>) {
197        let mut events = self.connection.subscribe_events();
198        let session_id = self.session_id.clone();
199        let registry = Arc::clone(self);
200
201        // Subscribe to context route changes if we have context routes
202        let mut route_change_rx = self
203            .context_routes
204            .as_ref()
205            .map(|ctx| ctx.subscribe_route_changes());
206        let registry_for_routes = Arc::clone(self);
207
208        tokio::spawn(async move {
209            loop {
210                tokio::select! {
211                    // Handle CDP events
212                    event_result = events.recv() => {
213                        let Ok(event) = event_result else {
214                            break;
215                        };
216
217                        // Filter for this session
218                        if event.session_id.as_deref() != Some(&session_id) {
219                            continue;
220                        }
221
222                        match event.method.as_str() {
223                            "Fetch.requestPaused" => {
224                                if let Some(params) = &event.params {
225                                    if let Ok(paused_event) = serde_json::from_value::<RequestPausedEvent>(params.clone()) {
226                                        tracing::debug!(
227                                            request_id = %paused_event.request_id,
228                                            url = %paused_event.request.url,
229                                            "Fetch.requestPaused received"
230                                        );
231                                        if let Err(e) = registry.handle_request(&paused_event).await {
232                                            tracing::warn!(
233                                                request_id = %paused_event.request_id,
234                                                error = %e,
235                                                "Failed to handle paused request"
236                                            );
237                                        }
238                                    }
239                                }
240                            }
241                            "Fetch.authRequired" => {
242                                if let Some(params) = &event.params {
243                                    if let Ok(auth_event) = serde_json::from_value::<AuthRequiredEvent>(params.clone()) {
244                                        tracing::debug!(
245                                            request_id = %auth_event.request_id,
246                                            origin = %auth_event.auth_challenge.origin,
247                                            scheme = %auth_event.auth_challenge.scheme,
248                                            "Fetch.authRequired received"
249                                        );
250                                        if let Err(e) = registry.handle_auth_required(&auth_event).await {
251                                            tracing::warn!(
252                                                request_id = %auth_event.request_id,
253                                                error = %e,
254                                                "Failed to handle auth required"
255                                            );
256                                        }
257                                    }
258                                }
259                            }
260                            _ => {}
261                        }
262                    }
263
264                    // Handle context route change notifications
265                    Some(Ok(_notification)) = async {
266                        match route_change_rx.as_mut() {
267                            Some(rx) => Some(rx.recv().await),
268                            None => std::future::pending().await,
269                        }
270                    } => {
271                        // A new route was added to the context - enable Fetch if not already
272                        tracing::debug!("Context route added, ensuring Fetch is enabled");
273                        if let Err(e) = registry_for_routes.ensure_fetch_enabled().await {
274                            tracing::warn!(error = %e, "Failed to enable Fetch after context route added");
275                        }
276                    }
277                }
278            }
279        });
280    }
281
282    /// Set HTTP credentials for authentication.
283    pub async fn set_http_credentials(&self, credentials: HttpCredentials) {
284        self.auth_handler.set_credentials(credentials).await;
285
286        // Enable auth handling if not already enabled
287        let mut auth_enabled = self.auth_enabled.write().await;
288        if !*auth_enabled {
289            *auth_enabled = true;
290            // Re-enable fetch with auth handling if fetch is already enabled
291            drop(auth_enabled);
292            let fetch_enabled = *self.fetch_enabled.read().await;
293            if fetch_enabled {
294                let _ = self.re_enable_fetch_with_auth().await;
295            }
296        }
297    }
298
299    /// Clear HTTP credentials.
300    pub async fn clear_http_credentials(&self) {
301        self.auth_handler.clear_credentials().await;
302        let mut auth_enabled = self.auth_enabled.write().await;
303        *auth_enabled = false;
304    }
305
306    /// Handle an authentication challenge.
307    pub async fn handle_auth_required(
308        &self,
309        event: &AuthRequiredEvent,
310    ) -> Result<(), NetworkError> {
311        self.auth_handler.handle_auth_challenge(event).await?;
312        Ok(())
313    }
314
315    /// Register a route handler for the given pattern.
316    pub async fn route<M, H, Fut>(&self, pattern: M, handler: H) -> Result<(), NetworkError>
317    where
318        M: Into<UrlPattern>,
319        H: Fn(Route) -> Fut + Send + Sync + 'static,
320        Fut: Future<Output = Result<(), NetworkError>> + Send + 'static,
321    {
322        let pattern = pattern.into();
323
324        // Enable Fetch domain if not already enabled
325        self.ensure_fetch_enabled().await?;
326
327        // Wrap the handler
328        let handler: Arc<
329            dyn Fn(Route) -> Pin<Box<dyn Future<Output = Result<(), NetworkError>> + Send>>
330                + Send
331                + Sync,
332        > = Arc::new(move |route| Box::pin(handler(route)));
333
334        // Add to handlers (will be matched in reverse order)
335        let mut handlers = self.handlers.write().await;
336        handlers.push(RegisteredHandler {
337            pattern: Box::new(pattern),
338            handler,
339        });
340
341        Ok(())
342    }
343
344    /// Register a route handler with a predicate function.
345    pub async fn route_predicate<P, H, Fut>(
346        &self,
347        predicate: P,
348        handler: H,
349    ) -> Result<(), NetworkError>
350    where
351        P: Fn(&str) -> bool + Send + Sync + 'static,
352        H: Fn(Route) -> Fut + Send + Sync + 'static,
353        Fut: Future<Output = Result<(), NetworkError>> + Send + 'static,
354    {
355        // Enable Fetch domain if not already enabled
356        self.ensure_fetch_enabled().await?;
357
358        // Create a matcher from the predicate
359        struct PredicateMatcher<F>(F);
360        impl<F: Fn(&str) -> bool + Send + Sync> UrlMatcher for PredicateMatcher<F> {
361            fn matches(&self, url: &str) -> bool {
362                (self.0)(url)
363            }
364        }
365
366        // Wrap the handler
367        let handler: Arc<
368            dyn Fn(Route) -> Pin<Box<dyn Future<Output = Result<(), NetworkError>> + Send>>
369                + Send
370                + Sync,
371        > = Arc::new(move |route| Box::pin(handler(route)));
372
373        // Add to handlers
374        let mut handlers = self.handlers.write().await;
375        handlers.push(RegisteredHandler {
376            pattern: Box::new(PredicateMatcher(predicate)),
377            handler,
378        });
379
380        Ok(())
381    }
382
383    /// Unregister handlers matching the given pattern.
384    pub async fn unroute(&self, pattern: &str) {
385        let mut handlers = self.handlers.write().await;
386
387        // Remove handlers that match this pattern
388        // For simplicity, we match based on glob pattern equality
389        handlers.retain(|h| {
390            // This is a simplification - in a real implementation,
391            // we'd need to compare patterns more thoroughly
392            !h.pattern.matches(pattern)
393        });
394
395        // If no handlers left, disable Fetch domain
396        if handlers.is_empty() {
397            drop(handlers);
398            let _ = self.disable_fetch_domain().await;
399        }
400    }
401
402    /// Unregister all handlers.
403    pub async fn unroute_all(&self) {
404        let mut handlers = self.handlers.write().await;
405        handlers.clear();
406        drop(handlers);
407        let _ = self.disable_fetch_domain().await;
408    }
409
410    /// Handle a paused request by dispatching to matching handlers.
411    ///
412    /// Handlers are tried in reverse order (LIFO). If a handler calls `fallback()`,
413    /// the next matching handler is tried. If no handler handles the request,
414    /// it is continued to the network.
415    pub async fn handle_request(&self, event: &RequestPausedEvent) -> Result<(), NetworkError> {
416        let url = &event.request.url;
417        let handlers = self.handlers.read().await;
418
419        // Collect all matching handlers (in reverse order - LIFO)
420        let matching_handlers: Vec<_> = handlers
421            .iter()
422            .rev()
423            .filter(|h| h.pattern.matches(url))
424            .collect();
425
426        // Try each matching handler in order
427        for handler in &matching_handlers {
428            let route =
429                create_route_from_event(event, self.connection.clone(), self.session_id.clone());
430            let route_check = route.clone();
431
432            // Call the handler (handler takes ownership of route)
433            (handler.handler)(route).await?;
434
435            // Check if the route was actually handled (made a CDP call)
436            if route_check.is_handled().await {
437                return Ok(());
438            }
439            tracing::debug!(
440                request_id = %event.request_id,
441                url = %url,
442                "Handler called fallback, trying next handler"
443            );
444        }
445
446        // Drop page handlers lock before checking context routes
447        drop(handlers);
448
449        // Check context routes as fallback
450        if let Some(ref context_routes) = self.context_routes {
451            let context_handlers = context_routes.find_all_handlers(url).await;
452
453            for handler in context_handlers {
454                let route = create_route_from_event(
455                    event,
456                    self.connection.clone(),
457                    self.session_id.clone(),
458                );
459                let route_check = route.clone();
460
461                handler(route).await?;
462
463                if route_check.is_handled().await {
464                    return Ok(());
465                }
466                tracing::debug!(
467                    request_id = %event.request_id,
468                    url = %url,
469                    "Context handler called fallback, trying next handler"
470                );
471            }
472        }
473
474        // No handler handled the request - continue to the network
475        continue_request(&self.connection, &self.session_id, &event.request_id).await
476    }
477
478    /// Enable the Fetch domain if not already enabled.
479    ///
480    /// This is a public version for use by `ContextRouteRegistry` when
481    /// synchronously enabling Fetch on all pages after a context route is added.
482    pub async fn ensure_fetch_enabled_public(&self) -> Result<(), NetworkError> {
483        self.ensure_fetch_enabled().await
484    }
485
486    /// Enable the Fetch domain if not already enabled.
487    async fn ensure_fetch_enabled(&self) -> Result<(), NetworkError> {
488        let mut enabled = self.fetch_enabled.write().await;
489        if *enabled {
490            return Ok(());
491        }
492
493        let auth_enabled = *self.auth_enabled.read().await;
494        enable_fetch(&self.connection, &self.session_id, auth_enabled).await?;
495        *enabled = true;
496        Ok(())
497    }
498
499    /// Re-enable Fetch domain with auth handling.
500    async fn re_enable_fetch_with_auth(&self) -> Result<(), NetworkError> {
501        // First disable, then re-enable with auth
502        disable_fetch(&self.connection, &self.session_id).await?;
503        enable_fetch(&self.connection, &self.session_id, true).await
504    }
505
506    /// Disable the Fetch domain.
507    async fn disable_fetch_domain(&self) -> Result<(), NetworkError> {
508        let mut enabled = self.fetch_enabled.write().await;
509        if !*enabled {
510            return Ok(());
511        }
512
513        disable_fetch(&self.connection, &self.session_id).await?;
514        *enabled = false;
515        Ok(())
516    }
517}
518
// Unit tests live in the out-of-line `tests` submodule and are compiled only
// for test builds.
#[cfg(test)]
mod tests;