viewpoint_core/network/handler/mod.rs

1//! Route handler registry and dispatch.
2
3use std::future::Future;
4use std::pin::Pin;
5use std::sync::Arc;
6
7use tokio::sync::RwLock;
8use viewpoint_cdp::CdpConnection;
9use viewpoint_cdp::protocol::fetch::{AuthRequiredEvent, RequestPausedEvent};
10
11use super::auth::{AuthHandler, HttpCredentials};
12use super::handler_fetch::{disable_fetch, enable_fetch};
13use super::handler_request::{continue_request, create_route_from_event};
14use super::route::Route;
15use super::types::{UrlMatcher, UrlPattern};
16use crate::error::NetworkError;
17
18/// A registered route handler.
19struct RegisteredHandler {
20    /// Pattern to match URLs.
21    pattern: Box<dyn UrlMatcher>,
22    /// The handler function.
23    handler: Arc<
24        dyn Fn(Route) -> Pin<Box<dyn Future<Output = Result<(), NetworkError>> + Send>>
25            + Send
26            + Sync,
27    >,
28}
29
30impl std::fmt::Debug for RegisteredHandler {
31    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
32        f.debug_struct("RegisteredHandler")
33            .field("pattern", &"<pattern>")
34            .field("handler", &"<fn>")
35            .finish()
36    }
37}
38
39/// Route handler registry for a page or context.
40#[derive(Debug)]
41pub struct RouteHandlerRegistry {
42    /// Registered handlers (in reverse order - last registered is first tried).
43    handlers: RwLock<Vec<RegisteredHandler>>,
44    /// CDP connection for sending commands.
45    connection: Arc<CdpConnection>,
46    /// Session ID for CDP commands.
47    session_id: String,
48    /// Whether the Fetch domain is enabled.
49    fetch_enabled: RwLock<bool>,
50    /// HTTP authentication handler.
51    auth_handler: AuthHandler,
52    /// Whether auth handling is enabled.
53    auth_enabled: RwLock<bool>,
54    /// Context-level route registry (for fallback handling).
55    context_routes: Option<Arc<crate::context::routing::ContextRouteRegistry>>,
56}
57
58impl RouteHandlerRegistry {
59    /// Create a new route handler registry.
60    pub fn new(connection: Arc<CdpConnection>, session_id: String) -> Self {
61        let auth_handler = AuthHandler::new(connection.clone(), session_id.clone());
62        Self {
63            handlers: RwLock::new(Vec::new()),
64            connection,
65            session_id,
66            fetch_enabled: RwLock::new(false),
67            auth_handler,
68            auth_enabled: RwLock::new(false),
69            context_routes: None,
70        }
71    }
72
73    /// Create a new route handler registry with HTTP credentials.
74    pub fn with_credentials(
75        connection: Arc<CdpConnection>,
76        session_id: String,
77        credentials: HttpCredentials,
78    ) -> Self {
79        let auth_handler =
80            AuthHandler::with_credentials(connection.clone(), session_id.clone(), credentials);
81        Self {
82            handlers: RwLock::new(Vec::new()),
83            connection,
84            session_id,
85            fetch_enabled: RwLock::new(false),
86            auth_handler,
87            auth_enabled: RwLock::new(true),
88            context_routes: None,
89        }
90    }
91
92    /// Create a new route handler registry with context-level routes.
93    ///
94    /// If `http_credentials` is provided, they will be set on the auth handler
95    /// for handling HTTP authentication challenges.
96    pub fn with_context_routes(
97        connection: Arc<CdpConnection>,
98        session_id: String,
99        context_routes: Arc<crate::context::routing::ContextRouteRegistry>,
100        http_credentials: Option<HttpCredentials>,
101    ) -> Self {
102        let auth_handler = AuthHandler::new(connection.clone(), session_id.clone());
103
104        // If credentials are provided, set them on the auth handler
105        if let Some(ref creds) = http_credentials {
106            // We need to set credentials synchronously during construction.
107            // AuthHandler stores credentials in an Arc<RwLock>, so we can set them directly.
108            tracing::debug!(
109                username = %creds.username,
110                has_origin = creds.origin.is_some(),
111                "Setting HTTP credentials on auth handler"
112            );
113            auth_handler.set_credentials_sync(creds.clone());
114        }
115
116        Self {
117            handlers: RwLock::new(Vec::new()),
118            connection,
119            session_id,
120            fetch_enabled: RwLock::new(false),
121            auth_handler,
122            // Enable auth if credentials are provided
123            auth_enabled: RwLock::new(http_credentials.is_some()),
124            context_routes: Some(context_routes),
125        }
126    }
127
128    /// Enable Fetch domain if context has routes or auth is enabled.
129    ///
130    /// This should be called after the registry is created to check if there are
131    /// context-level routes that need interception or if HTTP credentials are configured.
132    pub async fn enable_fetch_for_context_routes(&self) -> Result<(), NetworkError> {
133        // Enable if auth is enabled (credentials were provided)
134        let auth_enabled = *self.auth_enabled.read().await;
135        if auth_enabled {
136            self.ensure_fetch_enabled().await?;
137            return Ok(());
138        }
139
140        // Also enable if there are context routes
141        if let Some(ref context_routes) = self.context_routes {
142            if context_routes.has_routes().await {
143                self.ensure_fetch_enabled().await?;
144            }
145        }
146        Ok(())
147    }
148
149    /// Set the context-level route registry.
150    pub fn set_context_routes(
151        &mut self,
152        context_routes: Arc<crate::context::routing::ContextRouteRegistry>,
153    ) {
154        self.context_routes = Some(context_routes);
155    }
156
157    /// Start the background fetch event listener.
158    ///
159    /// This spawns a background task that listens for `Fetch.requestPaused` and
160    /// `Fetch.authRequired` events and dispatches them to the appropriate handlers.
161    ///
162    /// Also listens for context route change notifications to enable Fetch when
163    /// new routes are added to the context after the page was created.
164    ///
165    /// This should be called after creating the registry, passing an Arc reference to self.
166    pub fn start_fetch_listener(self: &Arc<Self>) {
167        let mut events = self.connection.subscribe_events();
168        let session_id = self.session_id.clone();
169        let registry = Arc::clone(self);
170
171        // Subscribe to context route changes if we have context routes
172        let mut route_change_rx = self
173            .context_routes
174            .as_ref()
175            .map(|ctx| ctx.subscribe_route_changes());
176        let registry_for_routes = Arc::clone(self);
177
178        tokio::spawn(async move {
179            loop {
180                tokio::select! {
181                    // Handle CDP events
182                    event_result = events.recv() => {
183                        let Ok(event) = event_result else {
184                            break;
185                        };
186
187                        // Filter for this session
188                        if event.session_id.as_deref() != Some(&session_id) {
189                            continue;
190                        }
191
192                        match event.method.as_str() {
193                            "Fetch.requestPaused" => {
194                                if let Some(params) = &event.params {
195                                    if let Ok(paused_event) = serde_json::from_value::<RequestPausedEvent>(params.clone()) {
196                                        tracing::debug!(
197                                            request_id = %paused_event.request_id,
198                                            url = %paused_event.request.url,
199                                            "Fetch.requestPaused received"
200                                        );
201                                        if let Err(e) = registry.handle_request(&paused_event).await {
202                                            tracing::warn!(
203                                                request_id = %paused_event.request_id,
204                                                error = %e,
205                                                "Failed to handle paused request"
206                                            );
207                                        }
208                                    }
209                                }
210                            }
211                            "Fetch.authRequired" => {
212                                if let Some(params) = &event.params {
213                                    if let Ok(auth_event) = serde_json::from_value::<AuthRequiredEvent>(params.clone()) {
214                                        tracing::debug!(
215                                            request_id = %auth_event.request_id,
216                                            origin = %auth_event.auth_challenge.origin,
217                                            scheme = %auth_event.auth_challenge.scheme,
218                                            "Fetch.authRequired received"
219                                        );
220                                        if let Err(e) = registry.handle_auth_required(&auth_event).await {
221                                            tracing::warn!(
222                                                request_id = %auth_event.request_id,
223                                                error = %e,
224                                                "Failed to handle auth required"
225                                            );
226                                        }
227                                    }
228                                }
229                            }
230                            _ => {}
231                        }
232                    }
233
234                    // Handle context route change notifications
235                    Some(Ok(_notification)) = async {
236                        match route_change_rx.as_mut() {
237                            Some(rx) => Some(rx.recv().await),
238                            None => std::future::pending().await,
239                        }
240                    } => {
241                        // A new route was added to the context - enable Fetch if not already
242                        tracing::debug!("Context route added, ensuring Fetch is enabled");
243                        if let Err(e) = registry_for_routes.ensure_fetch_enabled().await {
244                            tracing::warn!(error = %e, "Failed to enable Fetch after context route added");
245                        }
246                    }
247                }
248            }
249        });
250    }
251
252    /// Set HTTP credentials for authentication.
253    pub async fn set_http_credentials(&self, credentials: HttpCredentials) {
254        self.auth_handler.set_credentials(credentials).await;
255
256        // Enable auth handling if not already enabled
257        let mut auth_enabled = self.auth_enabled.write().await;
258        if !*auth_enabled {
259            *auth_enabled = true;
260            // Re-enable fetch with auth handling if fetch is already enabled
261            drop(auth_enabled);
262            let fetch_enabled = *self.fetch_enabled.read().await;
263            if fetch_enabled {
264                let _ = self.re_enable_fetch_with_auth().await;
265            }
266        }
267    }
268
269    /// Clear HTTP credentials.
270    pub async fn clear_http_credentials(&self) {
271        self.auth_handler.clear_credentials().await;
272        let mut auth_enabled = self.auth_enabled.write().await;
273        *auth_enabled = false;
274    }
275
276    /// Handle an authentication challenge.
277    pub async fn handle_auth_required(
278        &self,
279        event: &AuthRequiredEvent,
280    ) -> Result<(), NetworkError> {
281        self.auth_handler.handle_auth_challenge(event).await?;
282        Ok(())
283    }
284
285    /// Register a route handler for the given pattern.
286    pub async fn route<M, H, Fut>(&self, pattern: M, handler: H) -> Result<(), NetworkError>
287    where
288        M: Into<UrlPattern>,
289        H: Fn(Route) -> Fut + Send + Sync + 'static,
290        Fut: Future<Output = Result<(), NetworkError>> + Send + 'static,
291    {
292        let pattern = pattern.into();
293
294        // Enable Fetch domain if not already enabled
295        self.ensure_fetch_enabled().await?;
296
297        // Wrap the handler
298        let handler: Arc<
299            dyn Fn(Route) -> Pin<Box<dyn Future<Output = Result<(), NetworkError>> + Send>>
300                + Send
301                + Sync,
302        > = Arc::new(move |route| Box::pin(handler(route)));
303
304        // Add to handlers (will be matched in reverse order)
305        let mut handlers = self.handlers.write().await;
306        handlers.push(RegisteredHandler {
307            pattern: Box::new(pattern),
308            handler,
309        });
310
311        Ok(())
312    }
313
314    /// Register a route handler with a predicate function.
315    pub async fn route_predicate<P, H, Fut>(
316        &self,
317        predicate: P,
318        handler: H,
319    ) -> Result<(), NetworkError>
320    where
321        P: Fn(&str) -> bool + Send + Sync + 'static,
322        H: Fn(Route) -> Fut + Send + Sync + 'static,
323        Fut: Future<Output = Result<(), NetworkError>> + Send + 'static,
324    {
325        // Enable Fetch domain if not already enabled
326        self.ensure_fetch_enabled().await?;
327
328        // Create a matcher from the predicate
329        struct PredicateMatcher<F>(F);
330        impl<F: Fn(&str) -> bool + Send + Sync> UrlMatcher for PredicateMatcher<F> {
331            fn matches(&self, url: &str) -> bool {
332                (self.0)(url)
333            }
334        }
335
336        // Wrap the handler
337        let handler: Arc<
338            dyn Fn(Route) -> Pin<Box<dyn Future<Output = Result<(), NetworkError>> + Send>>
339                + Send
340                + Sync,
341        > = Arc::new(move |route| Box::pin(handler(route)));
342
343        // Add to handlers
344        let mut handlers = self.handlers.write().await;
345        handlers.push(RegisteredHandler {
346            pattern: Box::new(PredicateMatcher(predicate)),
347            handler,
348        });
349
350        Ok(())
351    }
352
353    /// Unregister handlers matching the given pattern.
354    pub async fn unroute(&self, pattern: &str) {
355        let mut handlers = self.handlers.write().await;
356
357        // Remove handlers that match this pattern
358        // For simplicity, we match based on glob pattern equality
359        handlers.retain(|h| {
360            // This is a simplification - in a real implementation,
361            // we'd need to compare patterns more thoroughly
362            !h.pattern.matches(pattern)
363        });
364
365        // If no handlers left, disable Fetch domain
366        if handlers.is_empty() {
367            drop(handlers);
368            let _ = self.disable_fetch_domain().await;
369        }
370    }
371
372    /// Unregister all handlers.
373    pub async fn unroute_all(&self) {
374        let mut handlers = self.handlers.write().await;
375        handlers.clear();
376        drop(handlers);
377        let _ = self.disable_fetch_domain().await;
378    }
379
380    /// Handle a paused request by dispatching to matching handlers.
381    ///
382    /// Handlers are tried in reverse order (LIFO). If a handler calls `fallback()`,
383    /// the next matching handler is tried. If no handler handles the request,
384    /// it is continued to the network.
385    pub async fn handle_request(&self, event: &RequestPausedEvent) -> Result<(), NetworkError> {
386        let url = &event.request.url;
387        let handlers = self.handlers.read().await;
388
389        // Collect all matching handlers (in reverse order - LIFO)
390        let matching_handlers: Vec<_> = handlers
391            .iter()
392            .rev()
393            .filter(|h| h.pattern.matches(url))
394            .collect();
395
396        // Try each matching handler in order
397        for handler in &matching_handlers {
398            let route =
399                create_route_from_event(event, self.connection.clone(), self.session_id.clone());
400            let route_check = route.clone();
401
402            // Call the handler (handler takes ownership of route)
403            (handler.handler)(route).await?;
404
405            // Check if the route was actually handled (made a CDP call)
406            if route_check.is_handled().await {
407                return Ok(());
408            }
409            tracing::debug!(
410                request_id = %event.request_id,
411                url = %url,
412                "Handler called fallback, trying next handler"
413            );
414        }
415
416        // Drop page handlers lock before checking context routes
417        drop(handlers);
418
419        // Check context routes as fallback
420        if let Some(ref context_routes) = self.context_routes {
421            let context_handlers = context_routes.find_all_handlers(url).await;
422
423            for handler in context_handlers {
424                let route = create_route_from_event(
425                    event,
426                    self.connection.clone(),
427                    self.session_id.clone(),
428                );
429                let route_check = route.clone();
430
431                handler(route).await?;
432
433                if route_check.is_handled().await {
434                    return Ok(());
435                }
436                tracing::debug!(
437                    request_id = %event.request_id,
438                    url = %url,
439                    "Context handler called fallback, trying next handler"
440                );
441            }
442        }
443
444        // No handler handled the request - continue to the network
445        continue_request(&self.connection, &self.session_id, &event.request_id).await
446    }
447
448    /// Enable the Fetch domain if not already enabled.
449    ///
450    /// This is a public version for use by `ContextRouteRegistry` when
451    /// synchronously enabling Fetch on all pages after a context route is added.
452    pub async fn ensure_fetch_enabled_public(&self) -> Result<(), NetworkError> {
453        self.ensure_fetch_enabled().await
454    }
455
456    /// Enable the Fetch domain if not already enabled.
457    async fn ensure_fetch_enabled(&self) -> Result<(), NetworkError> {
458        let mut enabled = self.fetch_enabled.write().await;
459        if *enabled {
460            return Ok(());
461        }
462
463        let auth_enabled = *self.auth_enabled.read().await;
464        enable_fetch(&self.connection, &self.session_id, auth_enabled).await?;
465        *enabled = true;
466        Ok(())
467    }
468
469    /// Re-enable Fetch domain with auth handling.
470    async fn re_enable_fetch_with_auth(&self) -> Result<(), NetworkError> {
471        // First disable, then re-enable with auth
472        disable_fetch(&self.connection, &self.session_id).await?;
473        enable_fetch(&self.connection, &self.session_id, true).await
474    }
475
476    /// Disable the Fetch domain.
477    async fn disable_fetch_domain(&self) -> Result<(), NetworkError> {
478        let mut enabled = self.fetch_enabled.write().await;
479        if !*enabled {
480            return Ok(());
481        }
482
483        disable_fetch(&self.connection, &self.session_id).await?;
484        *enabled = false;
485        Ok(())
486    }
487}
488
// Unit tests for this module live in an out-of-line `tests` submodule and are
// compiled only for test builds.
#[cfg(test)]
mod tests;