viewpoint_core/network/handler/
mod.rs

1//! Route handler registry and dispatch.
2
3use std::future::Future;
4use std::pin::Pin;
5use std::sync::Arc;
6
7use tokio::sync::RwLock;
8use viewpoint_cdp::protocol::fetch::{AuthRequiredEvent, RequestPausedEvent};
9use viewpoint_cdp::CdpConnection;
10
11use super::auth::{AuthHandler, HttpCredentials};
12use super::handler_fetch::{disable_fetch, enable_fetch};
13use super::handler_request::{continue_request, create_route_from_event};
14use super::route::Route;
15use super::types::{UrlMatcher, UrlPattern};
16use crate::error::NetworkError;
17
18/// A registered route handler.
19struct RegisteredHandler {
20    /// Pattern to match URLs.
21    pattern: Box<dyn UrlMatcher>,
22    /// The handler function.
23    handler: Arc<
24        dyn Fn(Route) -> Pin<Box<dyn Future<Output = Result<(), NetworkError>> + Send>>
25            + Send
26            + Sync,
27    >,
28}
29
30impl std::fmt::Debug for RegisteredHandler {
31    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
32        f.debug_struct("RegisteredHandler")
33            .field("pattern", &"<pattern>")
34            .field("handler", &"<fn>")
35            .finish()
36    }
37}
38
39/// Route handler registry for a page or context.
40#[derive(Debug)]
41pub struct RouteHandlerRegistry {
42    /// Registered handlers (in reverse order - last registered is first tried).
43    handlers: RwLock<Vec<RegisteredHandler>>,
44    /// CDP connection for sending commands.
45    connection: Arc<CdpConnection>,
46    /// Session ID for CDP commands.
47    session_id: String,
48    /// Whether the Fetch domain is enabled.
49    fetch_enabled: RwLock<bool>,
50    /// HTTP authentication handler.
51    auth_handler: AuthHandler,
52    /// Whether auth handling is enabled.
53    auth_enabled: RwLock<bool>,
54    /// Context-level route registry (for fallback handling).
55    context_routes: Option<Arc<crate::context::routing::ContextRouteRegistry>>,
56}
57
58impl RouteHandlerRegistry {
59    /// Create a new route handler registry.
60    pub fn new(connection: Arc<CdpConnection>, session_id: String) -> Self {
61        let auth_handler = AuthHandler::new(connection.clone(), session_id.clone());
62        Self {
63            handlers: RwLock::new(Vec::new()),
64            connection,
65            session_id,
66            fetch_enabled: RwLock::new(false),
67            auth_handler,
68            auth_enabled: RwLock::new(false),
69            context_routes: None,
70        }
71    }
72
73    /// Create a new route handler registry with HTTP credentials.
74    pub fn with_credentials(
75        connection: Arc<CdpConnection>,
76        session_id: String,
77        credentials: HttpCredentials,
78    ) -> Self {
79        let auth_handler = AuthHandler::with_credentials(
80            connection.clone(),
81            session_id.clone(),
82            credentials,
83        );
84        Self {
85            handlers: RwLock::new(Vec::new()),
86            connection,
87            session_id,
88            fetch_enabled: RwLock::new(false),
89            auth_handler,
90            auth_enabled: RwLock::new(true),
91            context_routes: None,
92        }
93    }
94
95    /// Create a new route handler registry with context-level routes.
96    /// 
97    /// If `http_credentials` is provided, they will be set on the auth handler
98    /// for handling HTTP authentication challenges.
99    pub fn with_context_routes(
100        connection: Arc<CdpConnection>,
101        session_id: String,
102        context_routes: Arc<crate::context::routing::ContextRouteRegistry>,
103        http_credentials: Option<HttpCredentials>,
104    ) -> Self {
105        let auth_handler = AuthHandler::new(connection.clone(), session_id.clone());
106        
107        // If credentials are provided, set them on the auth handler
108        if let Some(ref creds) = http_credentials {
109            // We need to set credentials synchronously during construction.
110            // AuthHandler stores credentials in an Arc<RwLock>, so we can set them directly.
111            tracing::debug!(
112                username = %creds.username,
113                has_origin = creds.origin.is_some(),
114                "Setting HTTP credentials on auth handler"
115            );
116            auth_handler.set_credentials_sync(creds.clone());
117        }
118        
119        Self {
120            handlers: RwLock::new(Vec::new()),
121            connection,
122            session_id,
123            fetch_enabled: RwLock::new(false),
124            auth_handler,
125            // Enable auth if credentials are provided
126            auth_enabled: RwLock::new(http_credentials.is_some()),
127            context_routes: Some(context_routes),
128        }
129    }
130    
131    /// Enable Fetch domain if context has routes or auth is enabled.
132    /// 
133    /// This should be called after the registry is created to check if there are
134    /// context-level routes that need interception or if HTTP credentials are configured.
135    pub async fn enable_fetch_for_context_routes(&self) -> Result<(), NetworkError> {
136        // Enable if auth is enabled (credentials were provided)
137        let auth_enabled = *self.auth_enabled.read().await;
138        if auth_enabled {
139            self.ensure_fetch_enabled().await?;
140            return Ok(());
141        }
142        
143        // Also enable if there are context routes
144        if let Some(ref context_routes) = self.context_routes {
145            if context_routes.has_routes().await {
146                self.ensure_fetch_enabled().await?;
147            }
148        }
149        Ok(())
150    }
151
152    /// Set the context-level route registry.
153    pub fn set_context_routes(&mut self, context_routes: Arc<crate::context::routing::ContextRouteRegistry>) {
154        self.context_routes = Some(context_routes);
155    }
156
157    /// Start the background fetch event listener.
158    ///
159    /// This spawns a background task that listens for `Fetch.requestPaused` and
160    /// `Fetch.authRequired` events and dispatches them to the appropriate handlers.
161    ///
162    /// Also listens for context route change notifications to enable Fetch when
163    /// new routes are added to the context after the page was created.
164    ///
165    /// This should be called after creating the registry, passing an Arc reference to self.
166    pub fn start_fetch_listener(self: &Arc<Self>) {
167        let mut events = self.connection.subscribe_events();
168        let session_id = self.session_id.clone();
169        let registry = Arc::clone(self);
170
171        // Subscribe to context route changes if we have context routes
172        let mut route_change_rx = self.context_routes.as_ref().map(|ctx| ctx.subscribe_route_changes());
173        let registry_for_routes = Arc::clone(self);
174
175        tokio::spawn(async move {
176            loop {
177                tokio::select! {
178                    // Handle CDP events
179                    event_result = events.recv() => {
180                        let Ok(event) = event_result else {
181                            break;
182                        };
183                        
184                        // Filter for this session
185                        if event.session_id.as_deref() != Some(&session_id) {
186                            continue;
187                        }
188
189                        match event.method.as_str() {
190                            "Fetch.requestPaused" => {
191                                if let Some(params) = &event.params {
192                                    if let Ok(paused_event) = serde_json::from_value::<RequestPausedEvent>(params.clone()) {
193                                        tracing::debug!(
194                                            request_id = %paused_event.request_id,
195                                            url = %paused_event.request.url,
196                                            "Fetch.requestPaused received"
197                                        );
198                                        if let Err(e) = registry.handle_request(&paused_event).await {
199                                            tracing::warn!(
200                                                request_id = %paused_event.request_id,
201                                                error = %e,
202                                                "Failed to handle paused request"
203                                            );
204                                        }
205                                    }
206                                }
207                            }
208                            "Fetch.authRequired" => {
209                                if let Some(params) = &event.params {
210                                    if let Ok(auth_event) = serde_json::from_value::<AuthRequiredEvent>(params.clone()) {
211                                        tracing::debug!(
212                                            request_id = %auth_event.request_id,
213                                            origin = %auth_event.auth_challenge.origin,
214                                            scheme = %auth_event.auth_challenge.scheme,
215                                            "Fetch.authRequired received"
216                                        );
217                                        if let Err(e) = registry.handle_auth_required(&auth_event).await {
218                                            tracing::warn!(
219                                                request_id = %auth_event.request_id,
220                                                error = %e,
221                                                "Failed to handle auth required"
222                                            );
223                                        }
224                                    }
225                                }
226                            }
227                            _ => {}
228                        }
229                    }
230                    
231                    // Handle context route change notifications
232                    Some(Ok(_notification)) = async {
233                        match route_change_rx.as_mut() {
234                            Some(rx) => Some(rx.recv().await),
235                            None => std::future::pending().await,
236                        }
237                    } => {
238                        // A new route was added to the context - enable Fetch if not already
239                        tracing::debug!("Context route added, ensuring Fetch is enabled");
240                        if let Err(e) = registry_for_routes.ensure_fetch_enabled().await {
241                            tracing::warn!(error = %e, "Failed to enable Fetch after context route added");
242                        }
243                    }
244                }
245            }
246        });
247    }
248
249    /// Set HTTP credentials for authentication.
250    pub async fn set_http_credentials(&self, credentials: HttpCredentials) {
251        self.auth_handler.set_credentials(credentials).await;
252        
253        // Enable auth handling if not already enabled
254        let mut auth_enabled = self.auth_enabled.write().await;
255        if !*auth_enabled {
256            *auth_enabled = true;
257            // Re-enable fetch with auth handling if fetch is already enabled
258            drop(auth_enabled);
259            let fetch_enabled = *self.fetch_enabled.read().await;
260            if fetch_enabled {
261                let _ = self.re_enable_fetch_with_auth().await;
262            }
263        }
264    }
265
266    /// Clear HTTP credentials.
267    pub async fn clear_http_credentials(&self) {
268        self.auth_handler.clear_credentials().await;
269        let mut auth_enabled = self.auth_enabled.write().await;
270        *auth_enabled = false;
271    }
272
273    /// Handle an authentication challenge.
274    pub async fn handle_auth_required(&self, event: &AuthRequiredEvent) -> Result<(), NetworkError> {
275        self.auth_handler.handle_auth_challenge(event).await?;
276        Ok(())
277    }
278
279    /// Register a route handler for the given pattern.
280    pub async fn route<M, H, Fut>(&self, pattern: M, handler: H) -> Result<(), NetworkError>
281    where
282        M: Into<UrlPattern>,
283        H: Fn(Route) -> Fut + Send + Sync + 'static,
284        Fut: Future<Output = Result<(), NetworkError>> + Send + 'static,
285    {
286        let pattern = pattern.into();
287
288        // Enable Fetch domain if not already enabled
289        self.ensure_fetch_enabled().await?;
290
291        // Wrap the handler
292        let handler: Arc<
293            dyn Fn(Route) -> Pin<Box<dyn Future<Output = Result<(), NetworkError>> + Send>>
294                + Send
295                + Sync,
296        > = Arc::new(move |route| Box::pin(handler(route)));
297
298        // Add to handlers (will be matched in reverse order)
299        let mut handlers = self.handlers.write().await;
300        handlers.push(RegisteredHandler {
301            pattern: Box::new(pattern),
302            handler,
303        });
304
305        Ok(())
306    }
307
308    /// Register a route handler with a predicate function.
309    pub async fn route_predicate<P, H, Fut>(&self, predicate: P, handler: H) -> Result<(), NetworkError>
310    where
311        P: Fn(&str) -> bool + Send + Sync + 'static,
312        H: Fn(Route) -> Fut + Send + Sync + 'static,
313        Fut: Future<Output = Result<(), NetworkError>> + Send + 'static,
314    {
315        // Enable Fetch domain if not already enabled
316        self.ensure_fetch_enabled().await?;
317
318        // Create a matcher from the predicate
319        struct PredicateMatcher<F>(F);
320        impl<F: Fn(&str) -> bool + Send + Sync> UrlMatcher for PredicateMatcher<F> {
321            fn matches(&self, url: &str) -> bool {
322                (self.0)(url)
323            }
324        }
325
326        // Wrap the handler
327        let handler: Arc<
328            dyn Fn(Route) -> Pin<Box<dyn Future<Output = Result<(), NetworkError>> + Send>>
329                + Send
330                + Sync,
331        > = Arc::new(move |route| Box::pin(handler(route)));
332
333        // Add to handlers
334        let mut handlers = self.handlers.write().await;
335        handlers.push(RegisteredHandler {
336            pattern: Box::new(PredicateMatcher(predicate)),
337            handler,
338        });
339
340        Ok(())
341    }
342
343    /// Unregister handlers matching the given pattern.
344    pub async fn unroute(&self, pattern: &str) {
345        let mut handlers = self.handlers.write().await;
346        
347        // Remove handlers that match this pattern
348        // For simplicity, we match based on glob pattern equality
349        handlers.retain(|h| {
350            // This is a simplification - in a real implementation,
351            // we'd need to compare patterns more thoroughly
352            !h.pattern.matches(pattern)
353        });
354
355        // If no handlers left, disable Fetch domain
356        if handlers.is_empty() {
357            drop(handlers);
358            let _ = self.disable_fetch_domain().await;
359        }
360    }
361
362    /// Unregister all handlers.
363    pub async fn unroute_all(&self) {
364        let mut handlers = self.handlers.write().await;
365        handlers.clear();
366        drop(handlers);
367        let _ = self.disable_fetch_domain().await;
368    }
369
370    /// Handle a paused request by dispatching to matching handlers.
371    /// 
372    /// Handlers are tried in reverse order (LIFO). If a handler calls `fallback()`,
373    /// the next matching handler is tried. If no handler handles the request,
374    /// it is continued to the network.
375    pub async fn handle_request(&self, event: &RequestPausedEvent) -> Result<(), NetworkError> {
376        let url = &event.request.url;
377        let handlers = self.handlers.read().await;
378
379        // Collect all matching handlers (in reverse order - LIFO)
380        let matching_handlers: Vec<_> = handlers
381            .iter()
382            .rev()
383            .filter(|h| h.pattern.matches(url))
384            .collect();
385
386        // Try each matching handler in order
387        for handler in &matching_handlers {
388            let route = create_route_from_event(event, self.connection.clone(), self.session_id.clone());
389            let route_check = route.clone();
390            
391            // Call the handler (handler takes ownership of route)
392            (handler.handler)(route).await?;
393            
394            // Check if the route was actually handled (made a CDP call)
395            if route_check.is_handled().await {
396                return Ok(());
397            }
398            tracing::debug!(
399                request_id = %event.request_id,
400                url = %url,
401                "Handler called fallback, trying next handler"
402            );
403        }
404
405        // Drop page handlers lock before checking context routes
406        drop(handlers);
407
408        // Check context routes as fallback
409        if let Some(ref context_routes) = self.context_routes {
410            let context_handlers = context_routes.find_all_handlers(url).await;
411            
412            for handler in context_handlers {
413                let route = create_route_from_event(event, self.connection.clone(), self.session_id.clone());
414                let route_check = route.clone();
415                
416                handler(route).await?;
417                
418                if route_check.is_handled().await {
419                    return Ok(());
420                }
421                tracing::debug!(
422                    request_id = %event.request_id,
423                    url = %url,
424                    "Context handler called fallback, trying next handler"
425                );
426            }
427        }
428
429        // No handler handled the request - continue to the network
430        continue_request(&self.connection, &self.session_id, &event.request_id).await
431    }
432
433    /// Enable the Fetch domain if not already enabled.
434    /// 
435    /// This is a public version for use by `ContextRouteRegistry` when
436    /// synchronously enabling Fetch on all pages after a context route is added.
437    pub async fn ensure_fetch_enabled_public(&self) -> Result<(), NetworkError> {
438        self.ensure_fetch_enabled().await
439    }
440
441    /// Enable the Fetch domain if not already enabled.
442    async fn ensure_fetch_enabled(&self) -> Result<(), NetworkError> {
443        let mut enabled = self.fetch_enabled.write().await;
444        if *enabled {
445            return Ok(());
446        }
447
448        let auth_enabled = *self.auth_enabled.read().await;
449        enable_fetch(&self.connection, &self.session_id, auth_enabled).await?;
450        *enabled = true;
451        Ok(())
452    }
453
454    /// Re-enable Fetch domain with auth handling.
455    async fn re_enable_fetch_with_auth(&self) -> Result<(), NetworkError> {
456        // First disable, then re-enable with auth
457        disable_fetch(&self.connection, &self.session_id).await?;
458        enable_fetch(&self.connection, &self.session_id, true).await
459    }
460
461    /// Disable the Fetch domain.
462    async fn disable_fetch_domain(&self) -> Result<(), NetworkError> {
463        let mut enabled = self.fetch_enabled.write().await;
464        if !*enabled {
465            return Ok(());
466        }
467
468        disable_fetch(&self.connection, &self.session_id).await?;
469        *enabled = false;
470        Ok(())
471    }
472}
473
474#[cfg(test)]
475mod tests;