viewpoint_core/network/events/
mod.rs

1//! Network event handling.
2
3use std::collections::HashMap;
4use std::sync::Arc;
5use std::time::Duration;
6
7use tokio::sync::broadcast;
8use viewpoint_cdp::CdpConnection;
9use viewpoint_cdp::protocol::network::{
10    LoadingFailedEvent, LoadingFinishedEvent, RequestWillBeSentEvent, ResponseReceivedEvent,
11};
12
13use super::request::Request;
14use super::response::Response;
15use super::types::{ResourceType, UrlMatcher};
16use crate::error::NetworkError;
17
18/// Event emitted when a request is made.
19#[derive(Debug, Clone)]
20pub struct RequestEvent {
21    /// The request.
22    pub request: Request,
23}
24
25/// Event emitted when a response is received.
26#[derive(Debug, Clone)]
27pub struct ResponseEvent {
28    /// The response.
29    pub response: Response,
30}
31
32/// Event emitted when a request finishes.
33#[derive(Debug, Clone)]
34pub struct RequestFinishedEvent {
35    /// The request that finished.
36    pub request: Request,
37}
38
39/// Event emitted when a request fails.
40#[derive(Debug, Clone)]
41pub struct RequestFailedEvent {
42    /// The failed request.
43    pub request: Request,
44    /// The error message.
45    pub error: String,
46}
47
48/// Network event types.
49#[derive(Debug, Clone)]
50pub enum NetworkEvent {
51    /// Request made.
52    Request(RequestEvent),
53    /// Response received.
54    Response(ResponseEvent),
55    /// Request finished.
56    RequestFinished(RequestFinishedEvent),
57    /// Request failed.
58    RequestFailed(RequestFailedEvent),
59}
60
61/// Network event listener for a page.
62#[derive(Debug)]
63pub struct NetworkEventListener {
64    /// CDP connection.
65    connection: Arc<CdpConnection>,
66    /// Session ID.
67    session_id: String,
68    /// Event sender.
69    event_tx: broadcast::Sender<NetworkEvent>,
70}
71
72impl NetworkEventListener {
73    /// Create a new network event listener.
74    pub fn new(connection: Arc<CdpConnection>, session_id: String) -> Self {
75        let (event_tx, _) = broadcast::channel(256);
76        Self {
77            connection,
78            session_id,
79            event_tx,
80        }
81    }
82
83    /// Subscribe to network events.
84    pub fn subscribe(&self) -> broadcast::Receiver<NetworkEvent> {
85        self.event_tx.subscribe()
86    }
87
88    /// Start listening for network events.
89    ///
90    /// This spawns a background task that processes CDP events.
91    pub fn start(&self) {
92        let mut cdp_events = self.connection.subscribe_events();
93        let session_id = self.session_id.clone();
94        let event_tx = self.event_tx.clone();
95        let connection = self.connection.clone();
96
97        tokio::spawn(async move {
98            // Track pending requests for building responses
99            let mut pending_requests: HashMap<String, Request> = HashMap::new();
100
101            while let Ok(event) = cdp_events.recv().await {
102                // Filter events for this session
103                if event.session_id.as_deref() != Some(&session_id) {
104                    continue;
105                }
106
107                // Process network events
108                match event.method.as_str() {
109                    "Network.requestWillBeSent" => {
110                        if let Some(params) = &event.params {
111                            if let Ok(req_event) =
112                                serde_json::from_value::<RequestWillBeSentEvent>(params.clone())
113                            {
114                                // Check if this is a redirect (redirect_response is present)
115                                let previous_request = if req_event.redirect_response.is_some() {
116                                    // This is a redirect - get the previous request with the same ID
117                                    pending_requests.remove(&req_event.request_id)
118                                } else {
119                                    None
120                                };
121
122                                let request =
123                                    parse_request_will_be_sent(&req_event, previous_request);
124                                pending_requests
125                                    .insert(req_event.request_id.clone(), request.clone());
126                                let _ =
127                                    event_tx.send(NetworkEvent::Request(RequestEvent { request }));
128                            }
129                        }
130                    }
131                    "Network.responseReceived" => {
132                        if let Some(params) = &event.params {
133                            if let Ok(resp_event) =
134                                serde_json::from_value::<ResponseReceivedEvent>(params.clone())
135                            {
136                                // Get the associated request
137                                if let Some(request) =
138                                    pending_requests.get(&resp_event.request_id).cloned()
139                                {
140                                    let response = Response::new(
141                                        resp_event.response,
142                                        request,
143                                        connection.clone(),
144                                        session_id.clone(),
145                                        resp_event.request_id.clone(),
146                                    );
147                                    let _ = event_tx
148                                        .send(NetworkEvent::Response(ResponseEvent { response }));
149                                }
150                            }
151                        }
152                    }
153                    "Network.loadingFinished" => {
154                        if let Some(params) = &event.params {
155                            if let Ok(finished_event) =
156                                serde_json::from_value::<LoadingFinishedEvent>(params.clone())
157                            {
158                                if let Some(request) =
159                                    pending_requests.remove(&finished_event.request_id)
160                                {
161                                    let _ = event_tx.send(NetworkEvent::RequestFinished(
162                                        RequestFinishedEvent { request },
163                                    ));
164                                }
165                            }
166                        }
167                    }
168                    "Network.loadingFailed" => {
169                        if let Some(params) = &event.params {
170                            if let Ok(failed_event) =
171                                serde_json::from_value::<LoadingFailedEvent>(params.clone())
172                            {
173                                if let Some(request) =
174                                    pending_requests.remove(&failed_event.request_id)
175                                {
176                                    let _ = event_tx.send(NetworkEvent::RequestFailed(
177                                        RequestFailedEvent {
178                                            request,
179                                            error: failed_event.error_text,
180                                        },
181                                    ));
182                                }
183                            }
184                        }
185                    }
186                    _ => {}
187                }
188            }
189        });
190    }
191}
192
193/// Parse a `RequestWillBeSentEvent` into a Request.
194/// Parse a `RequestWillBeSentEvent` into a Request.
195///
196/// If `previous_request` is provided, it will be set as the `redirected_from` source.
197fn parse_request_will_be_sent(
198    event: &RequestWillBeSentEvent,
199    previous_request: Option<Request>,
200) -> Request {
201    let resource_type = event
202        .resource_type
203        .as_ref()
204        .map_or(ResourceType::Other, |t| parse_resource_type(t));
205
206    Request {
207        url: event.request.url.clone(),
208        method: event.request.method.clone(),
209        headers: event.request.headers.clone(),
210        post_data: event.request.post_data.clone(),
211        resource_type,
212        frame_id: event.frame_id.clone().unwrap_or_default(),
213        is_navigation: event.initiator.initiator_type == "navigation",
214        connection: None,
215        session_id: None,
216        request_id: Some(event.request_id.clone()),
217        redirected_from: previous_request.map(Box::new),
218        redirected_to: None,
219        timing: None,
220        failure_text: None,
221    }
222}
223
224/// Parse a resource type string into `ResourceType` enum.
225fn parse_resource_type(s: &str) -> ResourceType {
226    match s.to_lowercase().as_str() {
227        "document" => ResourceType::Document,
228        "stylesheet" => ResourceType::Stylesheet,
229        "image" => ResourceType::Image,
230        "media" => ResourceType::Media,
231        "font" => ResourceType::Font,
232        "script" => ResourceType::Script,
233        "texttrack" => ResourceType::TextTrack,
234        "xhr" => ResourceType::Xhr,
235        "fetch" => ResourceType::Fetch,
236        "eventsource" => ResourceType::EventSource,
237        "websocket" => ResourceType::WebSocket,
238        "manifest" => ResourceType::Manifest,
239        "ping" => ResourceType::Ping,
240        "other" | _ => ResourceType::Other,
241    }
242}
243
244/// Builder for waiting for a request.
245#[derive(Debug)]
246pub struct WaitForRequestBuilder<'a, M> {
247    /// Connection.
248    connection: &'a Arc<CdpConnection>,
249    /// Session ID.
250    session_id: &'a str,
251    /// Pattern to match.
252    pattern: M,
253    /// Timeout duration.
254    timeout: Duration,
255}
256
257impl<'a, M: UrlMatcher + Clone + 'static> WaitForRequestBuilder<'a, M> {
258    /// Create a new wait for request builder.
259    pub fn new(connection: &'a Arc<CdpConnection>, session_id: &'a str, pattern: M) -> Self {
260        Self {
261            connection,
262            session_id,
263            pattern,
264            timeout: Duration::from_secs(30),
265        }
266    }
267
268    /// Set the timeout duration.
269    #[must_use]
270    pub fn timeout(mut self, timeout: Duration) -> Self {
271        self.timeout = timeout;
272        self
273    }
274
275    /// Wait for a matching request.
276    ///
277    /// # Errors
278    ///
279    /// Returns an error if the wait times out before a matching request is received,
280    /// or if the event stream is aborted.
281    pub async fn wait(self) -> Result<Request, NetworkError> {
282        let mut events = self.connection.subscribe_events();
283        let session_id = self.session_id.to_string();
284        let pattern = self.pattern;
285        let timeout = self.timeout;
286
287        tokio::time::timeout(timeout, async move {
288            while let Ok(event) = events.recv().await {
289                // Filter for this session
290                if event.session_id.as_deref() != Some(&session_id) {
291                    continue;
292                }
293
294                if event.method == "Network.requestWillBeSent" {
295                    if let Some(params) = &event.params {
296                        if let Ok(req_event) =
297                            serde_json::from_value::<RequestWillBeSentEvent>(params.clone())
298                        {
299                            if pattern.matches(&req_event.request.url) {
300                                return Ok(parse_request_will_be_sent(&req_event, None));
301                            }
302                        }
303                    }
304                }
305            }
306            Err(NetworkError::Aborted)
307        })
308        .await
309        .map_err(|_| NetworkError::Timeout(timeout))?
310    }
311}
312
313/// Builder for waiting for a response.
314#[derive(Debug)]
315pub struct WaitForResponseBuilder<'a, M> {
316    /// Connection.
317    connection: &'a Arc<CdpConnection>,
318    /// Session ID.
319    session_id: &'a str,
320    /// Pattern to match.
321    pattern: M,
322    /// Timeout duration.
323    timeout: Duration,
324}
325
326impl<'a, M: UrlMatcher + Clone + 'static> WaitForResponseBuilder<'a, M> {
327    /// Create a new wait for response builder.
328    pub fn new(connection: &'a Arc<CdpConnection>, session_id: &'a str, pattern: M) -> Self {
329        Self {
330            connection,
331            session_id,
332            pattern,
333            timeout: Duration::from_secs(30),
334        }
335    }
336
337    /// Set the timeout duration.
338    #[must_use]
339    pub fn timeout(mut self, timeout: Duration) -> Self {
340        self.timeout = timeout;
341        self
342    }
343
344    /// Wait for a matching response.
345    ///
346    /// # Errors
347    ///
348    /// Returns an error if the wait times out before a matching response is received,
349    /// or if the event stream is aborted.
350    pub async fn wait(self) -> Result<Response, NetworkError> {
351        let mut events = self.connection.subscribe_events();
352        let session_id = self.session_id.to_string();
353        let pattern = self.pattern;
354        let timeout = self.timeout;
355        let connection = self.connection.clone();
356
357        tokio::time::timeout(timeout, async move {
358            let mut pending_requests: HashMap<String, Request> = HashMap::new();
359
360            while let Ok(event) = events.recv().await {
361                // Filter for this session
362                if event.session_id.as_deref() != Some(&session_id) {
363                    continue;
364                }
365
366                match event.method.as_str() {
367                    "Network.requestWillBeSent" => {
368                        // Track requests so we can associate them with responses
369                        if let Some(params) = &event.params {
370                            if let Ok(req_event) =
371                                serde_json::from_value::<RequestWillBeSentEvent>(params.clone())
372                            {
373                                let request = parse_request_will_be_sent(&req_event, None);
374                                pending_requests.insert(req_event.request_id.clone(), request);
375                            }
376                        }
377                    }
378                    "Network.responseReceived" => {
379                        if let Some(params) = &event.params {
380                            if let Ok(resp_event) =
381                                serde_json::from_value::<ResponseReceivedEvent>(params.clone())
382                            {
383                                if pattern.matches(&resp_event.response.url) {
384                                    // Get the associated request or create a minimal one
385                                    let request = pending_requests
386                                        .get(&resp_event.request_id)
387                                        .cloned()
388                                        .unwrap_or_else(|| Request {
389                                            url: resp_event.response.url.clone(),
390                                            method: "GET".to_string(),
391                                            headers: HashMap::new(),
392                                            post_data: None,
393                                            resource_type: ResourceType::Other,
394                                            frame_id: resp_event
395                                                .frame_id
396                                                .clone()
397                                                .unwrap_or_default(),
398                                            is_navigation: false,
399                                            connection: None,
400                                            session_id: None,
401                                            request_id: Some(resp_event.request_id.clone()),
402                                            redirected_from: None,
403                                            redirected_to: None,
404                                            timing: None,
405                                            failure_text: None,
406                                        });
407
408                                    return Ok(Response::new(
409                                        resp_event.response,
410                                        request,
411                                        connection.clone(),
412                                        session_id.clone(),
413                                        resp_event.request_id.clone(),
414                                    ));
415                                }
416                            }
417                        }
418                    }
419                    _ => {}
420                }
421            }
422            Err(NetworkError::Aborted)
423        })
424        .await
425        .map_err(|_| NetworkError::Timeout(timeout))?
426    }
427}
428
429#[cfg(test)]
430mod tests;