viewpoint_core/network/handler/
mod.rs1use std::future::Future;
4use std::pin::Pin;
5use std::sync::Arc;
6
7use tokio::sync::RwLock;
8use viewpoint_cdp::CdpConnection;
9use viewpoint_cdp::protocol::fetch::{AuthRequiredEvent, RequestPausedEvent};
10
11use super::auth::{AuthHandler, HttpCredentials};
12use super::handler_fetch::{disable_fetch, enable_fetch};
13use super::handler_request::{continue_request, create_route_from_event};
14use super::route::Route;
15use super::types::{UrlMatcher, UrlPattern};
16use crate::error::NetworkError;
17
/// A page-level route handler together with the URL matcher that selects it.
struct RegisteredHandler {
    /// Decides, per request URL, whether `handler` should be offered the request.
    pattern: Box<dyn UrlMatcher>,
    /// Type-erased async callback invoked with the intercepted [`Route`]; it
    /// either resolves the route or falls back to the next matching handler.
    handler: Arc<
        dyn Fn(Route) -> Pin<Box<dyn Future<Output = Result<(), NetworkError>> + Send>>
            + Send
            + Sync,
    >,
}
29
30impl std::fmt::Debug for RegisteredHandler {
31 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
32 f.debug_struct("RegisteredHandler")
33 .field("pattern", &"<pattern>")
34 .field("handler", &"<fn>")
35 .finish()
36 }
37}
38
/// Per-session registry of page-level route handlers plus the CDP Fetch-domain
/// state needed to intercept requests.
///
/// A paused request is offered to page handlers first (most recently
/// registered first), then to context-level routes if any are attached, and is
/// otherwise continued unchanged.
#[derive(Debug)]
pub struct RouteHandlerRegistry {
    /// Page-level handlers, in registration order (new ones are pushed last).
    handlers: RwLock<Vec<RegisteredHandler>>,
    /// CDP connection used to send Fetch commands and receive events.
    connection: Arc<CdpConnection>,
    /// CDP session whose events this registry handles.
    session_id: String,
    /// Whether this registry has enabled the Fetch domain for the session.
    fetch_enabled: RwLock<bool>,
    /// Responds to `Fetch.authRequired` challenges.
    auth_handler: AuthHandler,
    /// Whether HTTP credentials are configured. This flag is passed to
    /// `enable_fetch` when the Fetch domain is turned on.
    auth_enabled: RwLock<bool>,
    /// Optional context-level routes, consulted after the page handlers.
    context_routes: Option<Arc<crate::context::routing::ContextRouteRegistry>>,
}
57
58impl RouteHandlerRegistry {
59 pub fn new(connection: Arc<CdpConnection>, session_id: String) -> Self {
61 let auth_handler = AuthHandler::new(connection.clone(), session_id.clone());
62 Self {
63 handlers: RwLock::new(Vec::new()),
64 connection,
65 session_id,
66 fetch_enabled: RwLock::new(false),
67 auth_handler,
68 auth_enabled: RwLock::new(false),
69 context_routes: None,
70 }
71 }
72
73 pub fn with_credentials(
75 connection: Arc<CdpConnection>,
76 session_id: String,
77 credentials: HttpCredentials,
78 ) -> Self {
79 let auth_handler =
80 AuthHandler::with_credentials(connection.clone(), session_id.clone(), credentials);
81 Self {
82 handlers: RwLock::new(Vec::new()),
83 connection,
84 session_id,
85 fetch_enabled: RwLock::new(false),
86 auth_handler,
87 auth_enabled: RwLock::new(true),
88 context_routes: None,
89 }
90 }
91
92 pub fn with_context_routes(
97 connection: Arc<CdpConnection>,
98 session_id: String,
99 context_routes: Arc<crate::context::routing::ContextRouteRegistry>,
100 http_credentials: Option<HttpCredentials>,
101 ) -> Self {
102 let auth_handler = AuthHandler::new(connection.clone(), session_id.clone());
103
104 if let Some(ref creds) = http_credentials {
106 tracing::debug!(
109 username = %creds.username,
110 has_origin = creds.origin.is_some(),
111 "Setting HTTP credentials on auth handler"
112 );
113 auth_handler.set_credentials_sync(creds.clone());
114 }
115
116 Self {
117 handlers: RwLock::new(Vec::new()),
118 connection,
119 session_id,
120 fetch_enabled: RwLock::new(false),
121 auth_handler,
122 auth_enabled: RwLock::new(http_credentials.is_some()),
124 context_routes: Some(context_routes),
125 }
126 }
127
128 pub async fn enable_fetch_for_context_routes(&self) -> Result<(), NetworkError> {
133 let auth_enabled = *self.auth_enabled.read().await;
135 if auth_enabled {
136 self.ensure_fetch_enabled().await?;
137 return Ok(());
138 }
139
140 if let Some(ref context_routes) = self.context_routes {
142 if context_routes.has_routes().await {
143 self.ensure_fetch_enabled().await?;
144 }
145 }
146 Ok(())
147 }
148
149 pub fn set_context_routes(
151 &mut self,
152 context_routes: Arc<crate::context::routing::ContextRouteRegistry>,
153 ) {
154 self.context_routes = Some(context_routes);
155 }
156
157 pub fn start_fetch_listener(self: &Arc<Self>) {
167 let mut events = self.connection.subscribe_events();
168 let session_id = self.session_id.clone();
169 let registry = Arc::clone(self);
170
171 let mut route_change_rx = self
173 .context_routes
174 .as_ref()
175 .map(|ctx| ctx.subscribe_route_changes());
176 let registry_for_routes = Arc::clone(self);
177
178 tokio::spawn(async move {
179 loop {
180 tokio::select! {
181 event_result = events.recv() => {
183 let Ok(event) = event_result else {
184 break;
185 };
186
187 if event.session_id.as_deref() != Some(&session_id) {
189 continue;
190 }
191
192 match event.method.as_str() {
193 "Fetch.requestPaused" => {
194 if let Some(params) = &event.params {
195 if let Ok(paused_event) = serde_json::from_value::<RequestPausedEvent>(params.clone()) {
196 tracing::debug!(
197 request_id = %paused_event.request_id,
198 url = %paused_event.request.url,
199 "Fetch.requestPaused received"
200 );
201 if let Err(e) = registry.handle_request(&paused_event).await {
202 tracing::warn!(
203 request_id = %paused_event.request_id,
204 error = %e,
205 "Failed to handle paused request"
206 );
207 }
208 }
209 }
210 }
211 "Fetch.authRequired" => {
212 if let Some(params) = &event.params {
213 if let Ok(auth_event) = serde_json::from_value::<AuthRequiredEvent>(params.clone()) {
214 tracing::debug!(
215 request_id = %auth_event.request_id,
216 origin = %auth_event.auth_challenge.origin,
217 scheme = %auth_event.auth_challenge.scheme,
218 "Fetch.authRequired received"
219 );
220 if let Err(e) = registry.handle_auth_required(&auth_event).await {
221 tracing::warn!(
222 request_id = %auth_event.request_id,
223 error = %e,
224 "Failed to handle auth required"
225 );
226 }
227 }
228 }
229 }
230 _ => {}
231 }
232 }
233
234 Some(Ok(_notification)) = async {
236 match route_change_rx.as_mut() {
237 Some(rx) => Some(rx.recv().await),
238 None => std::future::pending().await,
239 }
240 } => {
241 tracing::debug!("Context route added, ensuring Fetch is enabled");
243 if let Err(e) = registry_for_routes.ensure_fetch_enabled().await {
244 tracing::warn!(error = %e, "Failed to enable Fetch after context route added");
245 }
246 }
247 }
248 }
249 });
250 }
251
252 pub async fn set_http_credentials(&self, credentials: HttpCredentials) {
254 self.auth_handler.set_credentials(credentials).await;
255
256 let mut auth_enabled = self.auth_enabled.write().await;
258 if !*auth_enabled {
259 *auth_enabled = true;
260 drop(auth_enabled);
262 let fetch_enabled = *self.fetch_enabled.read().await;
263 if fetch_enabled {
264 let _ = self.re_enable_fetch_with_auth().await;
265 }
266 }
267 }
268
269 pub async fn clear_http_credentials(&self) {
271 self.auth_handler.clear_credentials().await;
272 let mut auth_enabled = self.auth_enabled.write().await;
273 *auth_enabled = false;
274 }
275
276 pub async fn handle_auth_required(
278 &self,
279 event: &AuthRequiredEvent,
280 ) -> Result<(), NetworkError> {
281 self.auth_handler.handle_auth_challenge(event).await?;
282 Ok(())
283 }
284
285 pub async fn route<M, H, Fut>(&self, pattern: M, handler: H) -> Result<(), NetworkError>
287 where
288 M: Into<UrlPattern>,
289 H: Fn(Route) -> Fut + Send + Sync + 'static,
290 Fut: Future<Output = Result<(), NetworkError>> + Send + 'static,
291 {
292 let pattern = pattern.into();
293
294 self.ensure_fetch_enabled().await?;
296
297 let handler: Arc<
299 dyn Fn(Route) -> Pin<Box<dyn Future<Output = Result<(), NetworkError>> + Send>>
300 + Send
301 + Sync,
302 > = Arc::new(move |route| Box::pin(handler(route)));
303
304 let mut handlers = self.handlers.write().await;
306 handlers.push(RegisteredHandler {
307 pattern: Box::new(pattern),
308 handler,
309 });
310
311 Ok(())
312 }
313
314 pub async fn route_predicate<P, H, Fut>(
316 &self,
317 predicate: P,
318 handler: H,
319 ) -> Result<(), NetworkError>
320 where
321 P: Fn(&str) -> bool + Send + Sync + 'static,
322 H: Fn(Route) -> Fut + Send + Sync + 'static,
323 Fut: Future<Output = Result<(), NetworkError>> + Send + 'static,
324 {
325 self.ensure_fetch_enabled().await?;
327
328 struct PredicateMatcher<F>(F);
330 impl<F: Fn(&str) -> bool + Send + Sync> UrlMatcher for PredicateMatcher<F> {
331 fn matches(&self, url: &str) -> bool {
332 (self.0)(url)
333 }
334 }
335
336 let handler: Arc<
338 dyn Fn(Route) -> Pin<Box<dyn Future<Output = Result<(), NetworkError>> + Send>>
339 + Send
340 + Sync,
341 > = Arc::new(move |route| Box::pin(handler(route)));
342
343 let mut handlers = self.handlers.write().await;
345 handlers.push(RegisteredHandler {
346 pattern: Box::new(PredicateMatcher(predicate)),
347 handler,
348 });
349
350 Ok(())
351 }
352
353 pub async fn unroute(&self, pattern: &str) {
355 let mut handlers = self.handlers.write().await;
356
357 handlers.retain(|h| {
360 !h.pattern.matches(pattern)
363 });
364
365 if handlers.is_empty() {
367 drop(handlers);
368 let _ = self.disable_fetch_domain().await;
369 }
370 }
371
372 pub async fn unroute_all(&self) {
374 let mut handlers = self.handlers.write().await;
375 handlers.clear();
376 drop(handlers);
377 let _ = self.disable_fetch_domain().await;
378 }
379
380 pub async fn handle_request(&self, event: &RequestPausedEvent) -> Result<(), NetworkError> {
386 let url = &event.request.url;
387 let handlers = self.handlers.read().await;
388
389 let matching_handlers: Vec<_> = handlers
391 .iter()
392 .rev()
393 .filter(|h| h.pattern.matches(url))
394 .collect();
395
396 for handler in &matching_handlers {
398 let route =
399 create_route_from_event(event, self.connection.clone(), self.session_id.clone());
400 let route_check = route.clone();
401
402 (handler.handler)(route).await?;
404
405 if route_check.is_handled().await {
407 return Ok(());
408 }
409 tracing::debug!(
410 request_id = %event.request_id,
411 url = %url,
412 "Handler called fallback, trying next handler"
413 );
414 }
415
416 drop(handlers);
418
419 if let Some(ref context_routes) = self.context_routes {
421 let context_handlers = context_routes.find_all_handlers(url).await;
422
423 for handler in context_handlers {
424 let route = create_route_from_event(
425 event,
426 self.connection.clone(),
427 self.session_id.clone(),
428 );
429 let route_check = route.clone();
430
431 handler(route).await?;
432
433 if route_check.is_handled().await {
434 return Ok(());
435 }
436 tracing::debug!(
437 request_id = %event.request_id,
438 url = %url,
439 "Context handler called fallback, trying next handler"
440 );
441 }
442 }
443
444 continue_request(&self.connection, &self.session_id, &event.request_id).await
446 }
447
448 pub async fn ensure_fetch_enabled_public(&self) -> Result<(), NetworkError> {
453 self.ensure_fetch_enabled().await
454 }
455
456 async fn ensure_fetch_enabled(&self) -> Result<(), NetworkError> {
458 let mut enabled = self.fetch_enabled.write().await;
459 if *enabled {
460 return Ok(());
461 }
462
463 let auth_enabled = *self.auth_enabled.read().await;
464 enable_fetch(&self.connection, &self.session_id, auth_enabled).await?;
465 *enabled = true;
466 Ok(())
467 }
468
469 async fn re_enable_fetch_with_auth(&self) -> Result<(), NetworkError> {
471 disable_fetch(&self.connection, &self.session_id).await?;
473 enable_fetch(&self.connection, &self.session_id, true).await
474 }
475
476 async fn disable_fetch_domain(&self) -> Result<(), NetworkError> {
478 let mut enabled = self.fetch_enabled.write().await;
479 if !*enabled {
480 return Ok(());
481 }
482
483 disable_fetch(&self.connection, &self.session_id).await?;
484 *enabled = false;
485 Ok(())
486 }
487}
488
// Unit tests for this module live in the child `tests` module file and are
// compiled only for test builds.
#[cfg(test)]
mod tests;