viewpoint_core/network/handler/
mod.rs1use std::future::Future;
4use std::pin::Pin;
5use std::sync::Arc;
6
7use tokio::sync::RwLock;
8use viewpoint_cdp::CdpConnection;
9use viewpoint_cdp::protocol::fetch::{AuthRequiredEvent, RequestPausedEvent};
10
11use super::auth::{AuthHandler, HttpCredentials, ProxyCredentials};
12use super::handler_fetch::{disable_fetch, enable_fetch};
13use super::handler_request::{continue_request, create_route_from_event};
14use super::route::Route;
15use super::types::{UrlMatcher, UrlPattern};
16use crate::error::NetworkError;
17
18struct RegisteredHandler {
20 pattern: Box<dyn UrlMatcher>,
22 handler: Arc<
24 dyn Fn(Route) -> Pin<Box<dyn Future<Output = Result<(), NetworkError>> + Send>>
25 + Send
26 + Sync,
27 >,
28}
29
30impl std::fmt::Debug for RegisteredHandler {
31 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
32 f.debug_struct("RegisteredHandler")
33 .field("pattern", &"<pattern>")
34 .field("handler", &"<fn>")
35 .finish()
36 }
37}
38
39#[derive(Debug)]
41pub struct RouteHandlerRegistry {
42 handlers: RwLock<Vec<RegisteredHandler>>,
44 connection: Arc<CdpConnection>,
46 session_id: String,
48 fetch_enabled: RwLock<bool>,
50 auth_handler: AuthHandler,
52 auth_enabled: RwLock<bool>,
54 context_routes: Option<Arc<crate::context::routing::ContextRouteRegistry>>,
56}
57
58impl RouteHandlerRegistry {
59 pub fn new(connection: Arc<CdpConnection>, session_id: String) -> Self {
61 let auth_handler = AuthHandler::new(connection.clone(), session_id.clone());
62 Self {
63 handlers: RwLock::new(Vec::new()),
64 connection,
65 session_id,
66 fetch_enabled: RwLock::new(false),
67 auth_handler,
68 auth_enabled: RwLock::new(false),
69 context_routes: None,
70 }
71 }
72
73 pub fn with_credentials(
75 connection: Arc<CdpConnection>,
76 session_id: String,
77 credentials: HttpCredentials,
78 ) -> Self {
79 let auth_handler =
80 AuthHandler::with_credentials(connection.clone(), session_id.clone(), credentials);
81 Self {
82 handlers: RwLock::new(Vec::new()),
83 connection,
84 session_id,
85 fetch_enabled: RwLock::new(false),
86 auth_handler,
87 auth_enabled: RwLock::new(true),
88 context_routes: None,
89 }
90 }
91
92 pub fn with_context_routes(
97 connection: Arc<CdpConnection>,
98 session_id: String,
99 context_routes: Arc<crate::context::routing::ContextRouteRegistry>,
100 http_credentials: Option<HttpCredentials>,
101 ) -> Self {
102 Self::with_context_routes_and_proxy(
103 connection,
104 session_id,
105 context_routes,
106 http_credentials,
107 None,
108 )
109 }
110
111 pub fn with_context_routes_and_proxy(
117 connection: Arc<CdpConnection>,
118 session_id: String,
119 context_routes: Arc<crate::context::routing::ContextRouteRegistry>,
120 http_credentials: Option<HttpCredentials>,
121 proxy_credentials: Option<ProxyCredentials>,
122 ) -> Self {
123 let auth_handler = AuthHandler::new(connection.clone(), session_id.clone());
124
125 if let Some(ref creds) = http_credentials {
127 tracing::debug!(
128 username = %creds.username,
129 has_origin = creds.origin.is_some(),
130 "Setting HTTP credentials on auth handler"
131 );
132 auth_handler.set_credentials_sync(creds.clone());
133 }
134
135 if let Some(ref proxy_creds) = proxy_credentials {
137 tracing::debug!(
138 username = %proxy_creds.username,
139 "Setting proxy credentials on auth handler"
140 );
141 auth_handler.set_proxy_credentials_sync(proxy_creds.clone());
142 }
143
144 let auth_enabled = http_credentials.is_some() || proxy_credentials.is_some();
146
147 Self {
148 handlers: RwLock::new(Vec::new()),
149 connection,
150 session_id,
151 fetch_enabled: RwLock::new(false),
152 auth_handler,
153 auth_enabled: RwLock::new(auth_enabled),
154 context_routes: Some(context_routes),
155 }
156 }
157
158 pub async fn enable_fetch_for_context_routes(&self) -> Result<(), NetworkError> {
163 let auth_enabled = *self.auth_enabled.read().await;
165 if auth_enabled {
166 self.ensure_fetch_enabled().await?;
167 return Ok(());
168 }
169
170 if let Some(ref context_routes) = self.context_routes {
172 if context_routes.has_routes().await {
173 self.ensure_fetch_enabled().await?;
174 }
175 }
176 Ok(())
177 }
178
179 pub fn set_context_routes(
181 &mut self,
182 context_routes: Arc<crate::context::routing::ContextRouteRegistry>,
183 ) {
184 self.context_routes = Some(context_routes);
185 }
186
187 pub fn start_fetch_listener(self: &Arc<Self>) {
197 let mut events = self.connection.subscribe_events();
198 let session_id = self.session_id.clone();
199 let registry = Arc::clone(self);
200
201 let mut route_change_rx = self
203 .context_routes
204 .as_ref()
205 .map(|ctx| ctx.subscribe_route_changes());
206 let registry_for_routes = Arc::clone(self);
207
208 tokio::spawn(async move {
209 loop {
210 tokio::select! {
211 event_result = events.recv() => {
213 let Ok(event) = event_result else {
214 break;
215 };
216
217 if event.session_id.as_deref() != Some(&session_id) {
219 continue;
220 }
221
222 match event.method.as_str() {
223 "Fetch.requestPaused" => {
224 if let Some(params) = &event.params {
225 if let Ok(paused_event) = serde_json::from_value::<RequestPausedEvent>(params.clone()) {
226 tracing::debug!(
227 request_id = %paused_event.request_id,
228 url = %paused_event.request.url,
229 "Fetch.requestPaused received"
230 );
231 if let Err(e) = registry.handle_request(&paused_event).await {
232 tracing::warn!(
233 request_id = %paused_event.request_id,
234 error = %e,
235 "Failed to handle paused request"
236 );
237 }
238 }
239 }
240 }
241 "Fetch.authRequired" => {
242 if let Some(params) = &event.params {
243 if let Ok(auth_event) = serde_json::from_value::<AuthRequiredEvent>(params.clone()) {
244 tracing::debug!(
245 request_id = %auth_event.request_id,
246 origin = %auth_event.auth_challenge.origin,
247 scheme = %auth_event.auth_challenge.scheme,
248 "Fetch.authRequired received"
249 );
250 if let Err(e) = registry.handle_auth_required(&auth_event).await {
251 tracing::warn!(
252 request_id = %auth_event.request_id,
253 error = %e,
254 "Failed to handle auth required"
255 );
256 }
257 }
258 }
259 }
260 _ => {}
261 }
262 }
263
264 Some(Ok(_notification)) = async {
266 match route_change_rx.as_mut() {
267 Some(rx) => Some(rx.recv().await),
268 None => std::future::pending().await,
269 }
270 } => {
271 tracing::debug!("Context route added, ensuring Fetch is enabled");
273 if let Err(e) = registry_for_routes.ensure_fetch_enabled().await {
274 tracing::warn!(error = %e, "Failed to enable Fetch after context route added");
275 }
276 }
277 }
278 }
279 });
280 }
281
282 pub async fn set_http_credentials(&self, credentials: HttpCredentials) {
284 self.auth_handler.set_credentials(credentials).await;
285
286 let mut auth_enabled = self.auth_enabled.write().await;
288 if !*auth_enabled {
289 *auth_enabled = true;
290 drop(auth_enabled);
292 let fetch_enabled = *self.fetch_enabled.read().await;
293 if fetch_enabled {
294 let _ = self.re_enable_fetch_with_auth().await;
295 }
296 }
297 }
298
299 pub async fn clear_http_credentials(&self) {
301 self.auth_handler.clear_credentials().await;
302 let mut auth_enabled = self.auth_enabled.write().await;
303 *auth_enabled = false;
304 }
305
306 pub async fn handle_auth_required(
308 &self,
309 event: &AuthRequiredEvent,
310 ) -> Result<(), NetworkError> {
311 self.auth_handler.handle_auth_challenge(event).await?;
312 Ok(())
313 }
314
315 pub async fn route<M, H, Fut>(&self, pattern: M, handler: H) -> Result<(), NetworkError>
317 where
318 M: Into<UrlPattern>,
319 H: Fn(Route) -> Fut + Send + Sync + 'static,
320 Fut: Future<Output = Result<(), NetworkError>> + Send + 'static,
321 {
322 let pattern = pattern.into();
323
324 self.ensure_fetch_enabled().await?;
326
327 let handler: Arc<
329 dyn Fn(Route) -> Pin<Box<dyn Future<Output = Result<(), NetworkError>> + Send>>
330 + Send
331 + Sync,
332 > = Arc::new(move |route| Box::pin(handler(route)));
333
334 let mut handlers = self.handlers.write().await;
336 handlers.push(RegisteredHandler {
337 pattern: Box::new(pattern),
338 handler,
339 });
340
341 Ok(())
342 }
343
344 pub async fn route_predicate<P, H, Fut>(
346 &self,
347 predicate: P,
348 handler: H,
349 ) -> Result<(), NetworkError>
350 where
351 P: Fn(&str) -> bool + Send + Sync + 'static,
352 H: Fn(Route) -> Fut + Send + Sync + 'static,
353 Fut: Future<Output = Result<(), NetworkError>> + Send + 'static,
354 {
355 self.ensure_fetch_enabled().await?;
357
358 struct PredicateMatcher<F>(F);
360 impl<F: Fn(&str) -> bool + Send + Sync> UrlMatcher for PredicateMatcher<F> {
361 fn matches(&self, url: &str) -> bool {
362 (self.0)(url)
363 }
364 }
365
366 let handler: Arc<
368 dyn Fn(Route) -> Pin<Box<dyn Future<Output = Result<(), NetworkError>> + Send>>
369 + Send
370 + Sync,
371 > = Arc::new(move |route| Box::pin(handler(route)));
372
373 let mut handlers = self.handlers.write().await;
375 handlers.push(RegisteredHandler {
376 pattern: Box::new(PredicateMatcher(predicate)),
377 handler,
378 });
379
380 Ok(())
381 }
382
383 pub async fn unroute(&self, pattern: &str) {
385 let mut handlers = self.handlers.write().await;
386
387 handlers.retain(|h| {
390 !h.pattern.matches(pattern)
393 });
394
395 if handlers.is_empty() {
397 drop(handlers);
398 let _ = self.disable_fetch_domain().await;
399 }
400 }
401
402 pub async fn unroute_all(&self) {
404 let mut handlers = self.handlers.write().await;
405 handlers.clear();
406 drop(handlers);
407 let _ = self.disable_fetch_domain().await;
408 }
409
410 pub async fn handle_request(&self, event: &RequestPausedEvent) -> Result<(), NetworkError> {
416 let url = &event.request.url;
417 let handlers = self.handlers.read().await;
418
419 let matching_handlers: Vec<_> = handlers
421 .iter()
422 .rev()
423 .filter(|h| h.pattern.matches(url))
424 .collect();
425
426 for handler in &matching_handlers {
428 let route =
429 create_route_from_event(event, self.connection.clone(), self.session_id.clone());
430 let route_check = route.clone();
431
432 (handler.handler)(route).await?;
434
435 if route_check.is_handled().await {
437 return Ok(());
438 }
439 tracing::debug!(
440 request_id = %event.request_id,
441 url = %url,
442 "Handler called fallback, trying next handler"
443 );
444 }
445
446 drop(handlers);
448
449 if let Some(ref context_routes) = self.context_routes {
451 let context_handlers = context_routes.find_all_handlers(url).await;
452
453 for handler in context_handlers {
454 let route = create_route_from_event(
455 event,
456 self.connection.clone(),
457 self.session_id.clone(),
458 );
459 let route_check = route.clone();
460
461 handler(route).await?;
462
463 if route_check.is_handled().await {
464 return Ok(());
465 }
466 tracing::debug!(
467 request_id = %event.request_id,
468 url = %url,
469 "Context handler called fallback, trying next handler"
470 );
471 }
472 }
473
474 continue_request(&self.connection, &self.session_id, &event.request_id).await
476 }
477
478 pub async fn ensure_fetch_enabled_public(&self) -> Result<(), NetworkError> {
483 self.ensure_fetch_enabled().await
484 }
485
486 async fn ensure_fetch_enabled(&self) -> Result<(), NetworkError> {
488 let mut enabled = self.fetch_enabled.write().await;
489 if *enabled {
490 return Ok(());
491 }
492
493 let auth_enabled = *self.auth_enabled.read().await;
494 enable_fetch(&self.connection, &self.session_id, auth_enabled).await?;
495 *enabled = true;
496 Ok(())
497 }
498
499 async fn re_enable_fetch_with_auth(&self) -> Result<(), NetworkError> {
501 disable_fetch(&self.connection, &self.session_id).await?;
503 enable_fetch(&self.connection, &self.session_id, true).await
504 }
505
506 async fn disable_fetch_domain(&self) -> Result<(), NetworkError> {
508 let mut enabled = self.fetch_enabled.write().await;
509 if !*enabled {
510 return Ok(());
511 }
512
513 disable_fetch(&self.connection, &self.session_id).await?;
514 *enabled = false;
515 Ok(())
516 }
517}
518
519#[cfg(test)]
520mod tests;