// viewpoint_core/network/handler/mod.rs
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;

use tokio::sync::RwLock;
use viewpoint_cdp::protocol::fetch::{AuthRequiredEvent, RequestPausedEvent};
use viewpoint_cdp::CdpConnection;

use super::auth::{AuthHandler, HttpCredentials};
use super::handler_fetch::{disable_fetch, enable_fetch};
use super::handler_request::{continue_request, create_route_from_event};
use super::route::Route;
use super::types::{UrlMatcher, UrlPattern};
use crate::error::NetworkError;

18struct RegisteredHandler {
20 pattern: Box<dyn UrlMatcher>,
22 handler: Arc<
24 dyn Fn(Route) -> Pin<Box<dyn Future<Output = Result<(), NetworkError>> + Send>>
25 + Send
26 + Sync,
27 >,
28}
29
30impl std::fmt::Debug for RegisteredHandler {
31 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
32 f.debug_struct("RegisteredHandler")
33 .field("pattern", &"<pattern>")
34 .field("handler", &"<fn>")
35 .finish()
36 }
37}
38
39#[derive(Debug)]
41pub struct RouteHandlerRegistry {
42 handlers: RwLock<Vec<RegisteredHandler>>,
44 connection: Arc<CdpConnection>,
46 session_id: String,
48 fetch_enabled: RwLock<bool>,
50 auth_handler: AuthHandler,
52 auth_enabled: RwLock<bool>,
54 context_routes: Option<Arc<crate::context::routing::ContextRouteRegistry>>,
56}
57
58impl RouteHandlerRegistry {
59 pub fn new(connection: Arc<CdpConnection>, session_id: String) -> Self {
61 let auth_handler = AuthHandler::new(connection.clone(), session_id.clone());
62 Self {
63 handlers: RwLock::new(Vec::new()),
64 connection,
65 session_id,
66 fetch_enabled: RwLock::new(false),
67 auth_handler,
68 auth_enabled: RwLock::new(false),
69 context_routes: None,
70 }
71 }
72
73 pub fn with_credentials(
75 connection: Arc<CdpConnection>,
76 session_id: String,
77 credentials: HttpCredentials,
78 ) -> Self {
79 let auth_handler = AuthHandler::with_credentials(
80 connection.clone(),
81 session_id.clone(),
82 credentials,
83 );
84 Self {
85 handlers: RwLock::new(Vec::new()),
86 connection,
87 session_id,
88 fetch_enabled: RwLock::new(false),
89 auth_handler,
90 auth_enabled: RwLock::new(true),
91 context_routes: None,
92 }
93 }
94
95 pub fn with_context_routes(
100 connection: Arc<CdpConnection>,
101 session_id: String,
102 context_routes: Arc<crate::context::routing::ContextRouteRegistry>,
103 http_credentials: Option<HttpCredentials>,
104 ) -> Self {
105 let auth_handler = AuthHandler::new(connection.clone(), session_id.clone());
106
107 if let Some(ref creds) = http_credentials {
109 tracing::debug!(
112 username = %creds.username,
113 has_origin = creds.origin.is_some(),
114 "Setting HTTP credentials on auth handler"
115 );
116 auth_handler.set_credentials_sync(creds.clone());
117 }
118
119 Self {
120 handlers: RwLock::new(Vec::new()),
121 connection,
122 session_id,
123 fetch_enabled: RwLock::new(false),
124 auth_handler,
125 auth_enabled: RwLock::new(http_credentials.is_some()),
127 context_routes: Some(context_routes),
128 }
129 }
130
131 pub async fn enable_fetch_for_context_routes(&self) -> Result<(), NetworkError> {
136 let auth_enabled = *self.auth_enabled.read().await;
138 if auth_enabled {
139 self.ensure_fetch_enabled().await?;
140 return Ok(());
141 }
142
143 if let Some(ref context_routes) = self.context_routes {
145 if context_routes.has_routes().await {
146 self.ensure_fetch_enabled().await?;
147 }
148 }
149 Ok(())
150 }
151
152 pub fn set_context_routes(&mut self, context_routes: Arc<crate::context::routing::ContextRouteRegistry>) {
154 self.context_routes = Some(context_routes);
155 }
156
157 pub fn start_fetch_listener(self: &Arc<Self>) {
167 let mut events = self.connection.subscribe_events();
168 let session_id = self.session_id.clone();
169 let registry = Arc::clone(self);
170
171 let mut route_change_rx = self.context_routes.as_ref().map(|ctx| ctx.subscribe_route_changes());
173 let registry_for_routes = Arc::clone(self);
174
175 tokio::spawn(async move {
176 loop {
177 tokio::select! {
178 event_result = events.recv() => {
180 let Ok(event) = event_result else {
181 break;
182 };
183
184 if event.session_id.as_deref() != Some(&session_id) {
186 continue;
187 }
188
189 match event.method.as_str() {
190 "Fetch.requestPaused" => {
191 if let Some(params) = &event.params {
192 if let Ok(paused_event) = serde_json::from_value::<RequestPausedEvent>(params.clone()) {
193 tracing::debug!(
194 request_id = %paused_event.request_id,
195 url = %paused_event.request.url,
196 "Fetch.requestPaused received"
197 );
198 if let Err(e) = registry.handle_request(&paused_event).await {
199 tracing::warn!(
200 request_id = %paused_event.request_id,
201 error = %e,
202 "Failed to handle paused request"
203 );
204 }
205 }
206 }
207 }
208 "Fetch.authRequired" => {
209 if let Some(params) = &event.params {
210 if let Ok(auth_event) = serde_json::from_value::<AuthRequiredEvent>(params.clone()) {
211 tracing::debug!(
212 request_id = %auth_event.request_id,
213 origin = %auth_event.auth_challenge.origin,
214 scheme = %auth_event.auth_challenge.scheme,
215 "Fetch.authRequired received"
216 );
217 if let Err(e) = registry.handle_auth_required(&auth_event).await {
218 tracing::warn!(
219 request_id = %auth_event.request_id,
220 error = %e,
221 "Failed to handle auth required"
222 );
223 }
224 }
225 }
226 }
227 _ => {}
228 }
229 }
230
231 Some(Ok(_notification)) = async {
233 match route_change_rx.as_mut() {
234 Some(rx) => Some(rx.recv().await),
235 None => std::future::pending().await,
236 }
237 } => {
238 tracing::debug!("Context route added, ensuring Fetch is enabled");
240 if let Err(e) = registry_for_routes.ensure_fetch_enabled().await {
241 tracing::warn!(error = %e, "Failed to enable Fetch after context route added");
242 }
243 }
244 }
245 }
246 });
247 }
248
249 pub async fn set_http_credentials(&self, credentials: HttpCredentials) {
251 self.auth_handler.set_credentials(credentials).await;
252
253 let mut auth_enabled = self.auth_enabled.write().await;
255 if !*auth_enabled {
256 *auth_enabled = true;
257 drop(auth_enabled);
259 let fetch_enabled = *self.fetch_enabled.read().await;
260 if fetch_enabled {
261 let _ = self.re_enable_fetch_with_auth().await;
262 }
263 }
264 }
265
266 pub async fn clear_http_credentials(&self) {
268 self.auth_handler.clear_credentials().await;
269 let mut auth_enabled = self.auth_enabled.write().await;
270 *auth_enabled = false;
271 }
272
273 pub async fn handle_auth_required(&self, event: &AuthRequiredEvent) -> Result<(), NetworkError> {
275 self.auth_handler.handle_auth_challenge(event).await?;
276 Ok(())
277 }
278
279 pub async fn route<M, H, Fut>(&self, pattern: M, handler: H) -> Result<(), NetworkError>
281 where
282 M: Into<UrlPattern>,
283 H: Fn(Route) -> Fut + Send + Sync + 'static,
284 Fut: Future<Output = Result<(), NetworkError>> + Send + 'static,
285 {
286 let pattern = pattern.into();
287
288 self.ensure_fetch_enabled().await?;
290
291 let handler: Arc<
293 dyn Fn(Route) -> Pin<Box<dyn Future<Output = Result<(), NetworkError>> + Send>>
294 + Send
295 + Sync,
296 > = Arc::new(move |route| Box::pin(handler(route)));
297
298 let mut handlers = self.handlers.write().await;
300 handlers.push(RegisteredHandler {
301 pattern: Box::new(pattern),
302 handler,
303 });
304
305 Ok(())
306 }
307
308 pub async fn route_predicate<P, H, Fut>(&self, predicate: P, handler: H) -> Result<(), NetworkError>
310 where
311 P: Fn(&str) -> bool + Send + Sync + 'static,
312 H: Fn(Route) -> Fut + Send + Sync + 'static,
313 Fut: Future<Output = Result<(), NetworkError>> + Send + 'static,
314 {
315 self.ensure_fetch_enabled().await?;
317
318 struct PredicateMatcher<F>(F);
320 impl<F: Fn(&str) -> bool + Send + Sync> UrlMatcher for PredicateMatcher<F> {
321 fn matches(&self, url: &str) -> bool {
322 (self.0)(url)
323 }
324 }
325
326 let handler: Arc<
328 dyn Fn(Route) -> Pin<Box<dyn Future<Output = Result<(), NetworkError>> + Send>>
329 + Send
330 + Sync,
331 > = Arc::new(move |route| Box::pin(handler(route)));
332
333 let mut handlers = self.handlers.write().await;
335 handlers.push(RegisteredHandler {
336 pattern: Box::new(PredicateMatcher(predicate)),
337 handler,
338 });
339
340 Ok(())
341 }
342
343 pub async fn unroute(&self, pattern: &str) {
345 let mut handlers = self.handlers.write().await;
346
347 handlers.retain(|h| {
350 !h.pattern.matches(pattern)
353 });
354
355 if handlers.is_empty() {
357 drop(handlers);
358 let _ = self.disable_fetch_domain().await;
359 }
360 }
361
362 pub async fn unroute_all(&self) {
364 let mut handlers = self.handlers.write().await;
365 handlers.clear();
366 drop(handlers);
367 let _ = self.disable_fetch_domain().await;
368 }
369
370 pub async fn handle_request(&self, event: &RequestPausedEvent) -> Result<(), NetworkError> {
376 let url = &event.request.url;
377 let handlers = self.handlers.read().await;
378
379 let matching_handlers: Vec<_> = handlers
381 .iter()
382 .rev()
383 .filter(|h| h.pattern.matches(url))
384 .collect();
385
386 for handler in &matching_handlers {
388 let route = create_route_from_event(event, self.connection.clone(), self.session_id.clone());
389 let route_check = route.clone();
390
391 (handler.handler)(route).await?;
393
394 if route_check.is_handled().await {
396 return Ok(());
397 }
398 tracing::debug!(
399 request_id = %event.request_id,
400 url = %url,
401 "Handler called fallback, trying next handler"
402 );
403 }
404
405 drop(handlers);
407
408 if let Some(ref context_routes) = self.context_routes {
410 let context_handlers = context_routes.find_all_handlers(url).await;
411
412 for handler in context_handlers {
413 let route = create_route_from_event(event, self.connection.clone(), self.session_id.clone());
414 let route_check = route.clone();
415
416 handler(route).await?;
417
418 if route_check.is_handled().await {
419 return Ok(());
420 }
421 tracing::debug!(
422 request_id = %event.request_id,
423 url = %url,
424 "Context handler called fallback, trying next handler"
425 );
426 }
427 }
428
429 continue_request(&self.connection, &self.session_id, &event.request_id).await
431 }
432
433 pub async fn ensure_fetch_enabled_public(&self) -> Result<(), NetworkError> {
438 self.ensure_fetch_enabled().await
439 }
440
441 async fn ensure_fetch_enabled(&self) -> Result<(), NetworkError> {
443 let mut enabled = self.fetch_enabled.write().await;
444 if *enabled {
445 return Ok(());
446 }
447
448 let auth_enabled = *self.auth_enabled.read().await;
449 enable_fetch(&self.connection, &self.session_id, auth_enabled).await?;
450 *enabled = true;
451 Ok(())
452 }
453
454 async fn re_enable_fetch_with_auth(&self) -> Result<(), NetworkError> {
456 disable_fetch(&self.connection, &self.session_id).await?;
458 enable_fetch(&self.connection, &self.session_id, true).await
459 }
460
461 async fn disable_fetch_domain(&self) -> Result<(), NetworkError> {
463 let mut enabled = self.fetch_enabled.write().await;
464 if !*enabled {
465 return Ok(());
466 }
467
468 disable_fetch(&self.connection, &self.session_id).await?;
469 *enabled = false;
470 Ok(())
471 }
472}
473
// Unit tests for this module live in an out-of-line `tests` module file.
#[cfg(test)]
mod tests;