viewpoint_core/network/handler/
mod.rs1mod constructors;
4
5use std::future::Future;
6use std::pin::Pin;
7use std::sync::Arc;
8
9use tokio::sync::RwLock;
10use viewpoint_cdp::CdpConnection;
11use viewpoint_cdp::protocol::fetch::{AuthRequiredEvent, RequestPausedEvent};
12
13use super::auth::{AuthHandler, HttpCredentials};
14use super::handler_fetch::{disable_fetch, enable_fetch};
15use super::handler_request::{continue_request, create_route_from_event};
16use super::route::Route;
17use super::types::{UrlMatcher, UrlPattern};
18use crate::error::NetworkError;
19
20struct RegisteredHandler {
22 pattern: Box<dyn UrlMatcher>,
24 handler: Arc<
26 dyn Fn(Route) -> Pin<Box<dyn Future<Output = Result<(), NetworkError>> + Send>>
27 + Send
28 + Sync,
29 >,
30}
31
32impl std::fmt::Debug for RegisteredHandler {
33 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
34 f.debug_struct("RegisteredHandler")
35 .field("pattern", &"<pattern>")
36 .field("handler", &"<fn>")
37 .finish()
38 }
39}
40
41#[derive(Debug)]
43pub struct RouteHandlerRegistry {
44 handlers: RwLock<Vec<RegisteredHandler>>,
46 connection: Arc<CdpConnection>,
48 session_id: String,
50 fetch_enabled: RwLock<bool>,
52 auth_handler: AuthHandler,
54 auth_enabled: RwLock<bool>,
56 context_routes: Option<Arc<crate::context::routing::ContextRouteRegistry>>,
58}
59
60impl RouteHandlerRegistry {
61 pub async fn enable_fetch_for_context_routes(&self) -> Result<(), NetworkError> {
66 let auth_enabled = *self.auth_enabled.read().await;
68 if auth_enabled {
69 self.ensure_fetch_enabled().await?;
70 return Ok(());
71 }
72
73 if let Some(ref context_routes) = self.context_routes {
75 if context_routes.has_routes().await {
76 self.ensure_fetch_enabled().await?;
77 }
78 }
79 Ok(())
80 }
81
82 pub fn set_context_routes(
84 &mut self,
85 context_routes: Arc<crate::context::routing::ContextRouteRegistry>,
86 ) {
87 self.context_routes = Some(context_routes);
88 }
89
90 pub fn start_fetch_listener(self: &Arc<Self>) {
100 let mut events = self.connection.subscribe_events();
101 let session_id = self.session_id.clone();
102 let registry = Arc::clone(self);
103
104 let mut route_change_rx = self
106 .context_routes
107 .as_ref()
108 .map(|ctx| ctx.subscribe_route_changes());
109 let registry_for_routes = Arc::clone(self);
110
111 tokio::spawn(async move {
112 loop {
113 tokio::select! {
114 event_result = events.recv() => {
116 let Ok(event) = event_result else {
117 break;
118 };
119
120 if event.session_id.as_deref() != Some(&session_id) {
122 continue;
123 }
124
125 match event.method.as_str() {
126 "Fetch.requestPaused" => {
127 if let Some(params) = &event.params {
128 if let Ok(paused_event) = serde_json::from_value::<RequestPausedEvent>(params.clone()) {
129 tracing::debug!(
130 request_id = %paused_event.request_id,
131 url = %paused_event.request.url,
132 "Fetch.requestPaused received"
133 );
134 if let Err(e) = registry.handle_request(&paused_event).await {
135 tracing::warn!(
136 request_id = %paused_event.request_id,
137 error = %e,
138 "Failed to handle paused request"
139 );
140 }
141 }
142 }
143 }
144 "Fetch.authRequired" => {
145 if let Some(params) = &event.params {
146 if let Ok(auth_event) = serde_json::from_value::<AuthRequiredEvent>(params.clone()) {
147 tracing::debug!(
148 request_id = %auth_event.request_id,
149 origin = %auth_event.auth_challenge.origin,
150 scheme = %auth_event.auth_challenge.scheme,
151 "Fetch.authRequired received"
152 );
153 if let Err(e) = registry.handle_auth_required(&auth_event).await {
154 tracing::warn!(
155 request_id = %auth_event.request_id,
156 error = %e,
157 "Failed to handle auth required"
158 );
159 }
160 }
161 }
162 }
163 _ => {}
164 }
165 }
166
167 Some(Ok(_notification)) = async {
169 match route_change_rx.as_mut() {
170 Some(rx) => Some(rx.recv().await),
171 None => std::future::pending().await,
172 }
173 } => {
174 tracing::debug!("Context route added, ensuring Fetch is enabled");
176 if let Err(e) = registry_for_routes.ensure_fetch_enabled().await {
177 tracing::warn!(error = %e, "Failed to enable Fetch after context route added");
178 }
179 }
180 }
181 }
182 });
183 }
184
185 pub async fn set_http_credentials(&self, credentials: HttpCredentials) {
187 self.auth_handler.set_credentials(credentials).await;
188
189 let mut auth_enabled = self.auth_enabled.write().await;
191 if !*auth_enabled {
192 *auth_enabled = true;
193 drop(auth_enabled);
195 let fetch_enabled = *self.fetch_enabled.read().await;
196 if fetch_enabled {
197 let _ = self.re_enable_fetch_with_auth().await;
198 }
199 }
200 }
201
202 pub async fn clear_http_credentials(&self) {
204 self.auth_handler.clear_credentials().await;
205 let mut auth_enabled = self.auth_enabled.write().await;
206 *auth_enabled = false;
207 }
208
209 pub async fn handle_auth_required(
211 &self,
212 event: &AuthRequiredEvent,
213 ) -> Result<(), NetworkError> {
214 self.auth_handler.handle_auth_challenge(event).await?;
215 Ok(())
216 }
217
218 pub async fn route<M, H, Fut>(&self, pattern: M, handler: H) -> Result<(), NetworkError>
220 where
221 M: Into<UrlPattern>,
222 H: Fn(Route) -> Fut + Send + Sync + 'static,
223 Fut: Future<Output = Result<(), NetworkError>> + Send + 'static,
224 {
225 let pattern = pattern.into();
226
227 self.ensure_fetch_enabled().await?;
229
230 let handler: Arc<
232 dyn Fn(Route) -> Pin<Box<dyn Future<Output = Result<(), NetworkError>> + Send>>
233 + Send
234 + Sync,
235 > = Arc::new(move |route| Box::pin(handler(route)));
236
237 let mut handlers = self.handlers.write().await;
239 handlers.push(RegisteredHandler {
240 pattern: Box::new(pattern),
241 handler,
242 });
243
244 Ok(())
245 }
246
247 pub async fn route_predicate<P, H, Fut>(
249 &self,
250 predicate: P,
251 handler: H,
252 ) -> Result<(), NetworkError>
253 where
254 P: Fn(&str) -> bool + Send + Sync + 'static,
255 H: Fn(Route) -> Fut + Send + Sync + 'static,
256 Fut: Future<Output = Result<(), NetworkError>> + Send + 'static,
257 {
258 self.ensure_fetch_enabled().await?;
260
261 struct PredicateMatcher<F>(F);
263 impl<F: Fn(&str) -> bool + Send + Sync> UrlMatcher for PredicateMatcher<F> {
264 fn matches(&self, url: &str) -> bool {
265 (self.0)(url)
266 }
267 }
268
269 let handler: Arc<
271 dyn Fn(Route) -> Pin<Box<dyn Future<Output = Result<(), NetworkError>> + Send>>
272 + Send
273 + Sync,
274 > = Arc::new(move |route| Box::pin(handler(route)));
275
276 let mut handlers = self.handlers.write().await;
278 handlers.push(RegisteredHandler {
279 pattern: Box::new(PredicateMatcher(predicate)),
280 handler,
281 });
282
283 Ok(())
284 }
285
286 pub async fn unroute(&self, pattern: &str) {
288 let mut handlers = self.handlers.write().await;
289
290 handlers.retain(|h| {
293 !h.pattern.matches(pattern)
296 });
297
298 if handlers.is_empty() {
300 drop(handlers);
301 let _ = self.disable_fetch_domain().await;
302 }
303 }
304
305 pub async fn unroute_all(&self) {
307 let mut handlers = self.handlers.write().await;
308 handlers.clear();
309 drop(handlers);
310 let _ = self.disable_fetch_domain().await;
311 }
312
313 pub async fn handle_request(&self, event: &RequestPausedEvent) -> Result<(), NetworkError> {
319 let url = &event.request.url;
320 let handlers = self.handlers.read().await;
321
322 let matching_handlers: Vec<_> = handlers
324 .iter()
325 .rev()
326 .filter(|h| h.pattern.matches(url))
327 .collect();
328
329 for handler in &matching_handlers {
331 let route =
332 create_route_from_event(event, self.connection.clone(), self.session_id.clone());
333 let route_check = route.clone();
334
335 (handler.handler)(route).await?;
337
338 if route_check.is_handled().await {
340 return Ok(());
341 }
342 tracing::debug!(
343 request_id = %event.request_id,
344 url = %url,
345 "Handler called fallback, trying next handler"
346 );
347 }
348
349 drop(handlers);
351
352 if let Some(ref context_routes) = self.context_routes {
354 let context_handlers = context_routes.find_all_handlers(url).await;
355
356 for handler in context_handlers {
357 let route = create_route_from_event(
358 event,
359 self.connection.clone(),
360 self.session_id.clone(),
361 );
362 let route_check = route.clone();
363
364 handler(route).await?;
365
366 if route_check.is_handled().await {
367 return Ok(());
368 }
369 tracing::debug!(
370 request_id = %event.request_id,
371 url = %url,
372 "Context handler called fallback, trying next handler"
373 );
374 }
375 }
376
377 continue_request(&self.connection, &self.session_id, &event.request_id).await
379 }
380
381 pub async fn ensure_fetch_enabled_public(&self) -> Result<(), NetworkError> {
386 self.ensure_fetch_enabled().await
387 }
388
389 async fn ensure_fetch_enabled(&self) -> Result<(), NetworkError> {
391 let mut enabled = self.fetch_enabled.write().await;
392 if *enabled {
393 return Ok(());
394 }
395
396 let auth_enabled = *self.auth_enabled.read().await;
397 enable_fetch(&self.connection, &self.session_id, auth_enabled).await?;
398 *enabled = true;
399 Ok(())
400 }
401
402 async fn re_enable_fetch_with_auth(&self) -> Result<(), NetworkError> {
404 disable_fetch(&self.connection, &self.session_id).await?;
406 enable_fetch(&self.connection, &self.session_id, true).await
407 }
408
409 async fn disable_fetch_domain(&self) -> Result<(), NetworkError> {
411 let mut enabled = self.fetch_enabled.write().await;
412 if !*enabled {
413 return Ok(());
414 }
415
416 disable_fetch(&self.connection, &self.session_id).await?;
417 *enabled = false;
418 Ok(())
419 }
420}
421
422#[cfg(test)]
423mod tests;