viewpoint_core/wait/waiter/
mod.rs

1//! Load state waiter implementation.
2
3use std::collections::{HashMap, HashSet};
4use std::sync::atomic::{AtomicUsize, Ordering};
5use std::sync::Arc;
6use std::time::Duration;
7
8use viewpoint_cdp::protocol::network::{LoadingFailedEvent, LoadingFinishedEvent, RequestWillBeSentEvent, ResponseReceivedEvent};
9use viewpoint_cdp::CdpEvent;
10use tokio::sync::{broadcast, Mutex};
11use tokio::time::{sleep, timeout, Instant};
12use tracing::{debug, instrument, trace, warn};
13
14use super::DocumentLoadState;
15use crate::error::WaitError;
16
/// Response metadata captured for the main document request of a navigation.
///
/// Every field starts out empty (`None` / no headers) and is filled in once the
/// response for the tracked main document request has been observed.
///
/// The type is plain data, so it derives `PartialEq`/`Eq` to let callers and tests
/// compare captured responses directly.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct NavigationResponseData {
    /// HTTP status code of the main document response, if one was captured.
    pub status: Option<u16>,
    /// Response headers of the main document response (empty until captured).
    pub headers: HashMap<String, String>,
    /// URL reported with the main document response (the final URL after redirects).
    pub url: Option<String>,
}
27
/// Timeout applied to wait operations when the caller does not supply one (30 seconds).
const DEFAULT_TIMEOUT: Duration = Duration::from_secs(30);

/// Quiet period used for network-idle detection: the network only counts as idle
/// after this long (500 ms) without pending requests.
const NETWORK_IDLE_THRESHOLD: Duration = Duration::from_millis(500);
33
34/// Waits for page load states by listening to CDP events.
35#[derive(Debug)]
36pub struct LoadStateWaiter {
37    /// Current load state.
38    current_state: Arc<Mutex<DocumentLoadState>>,
39    /// Event receiver for CDP events.
40    event_rx: broadcast::Receiver<CdpEvent>,
41    /// Session ID to filter events for.
42    session_id: String,
43    /// Frame ID to wait for.
44    frame_id: String,
45    /// Pending network request count.
46    pending_requests: Arc<AtomicUsize>,
47    /// Set of pending request IDs.
48    pending_request_ids: Arc<Mutex<HashSet<String>>>,
49    /// Captured response data from navigation.
50    response_data: Arc<Mutex<NavigationResponseData>>,
51    /// The main document request ID (for tracking the navigation response).
52    main_request_id: Arc<Mutex<Option<String>>>,
53}
54
55impl LoadStateWaiter {
56    /// Create a new load state waiter.
57    pub fn new(
58        event_rx: broadcast::Receiver<CdpEvent>,
59        session_id: String,
60        frame_id: String,
61    ) -> Self {
62        debug!(session_id = %session_id, frame_id = %frame_id, "Created LoadStateWaiter");
63        Self {
64            current_state: Arc::new(Mutex::new(DocumentLoadState::Commit)),
65            event_rx,
66            session_id,
67            frame_id,
68            pending_requests: Arc::new(AtomicUsize::new(0)),
69            pending_request_ids: Arc::new(Mutex::new(HashSet::new())),
70            response_data: Arc::new(Mutex::new(NavigationResponseData::default())),
71            main_request_id: Arc::new(Mutex::new(None)),
72        }
73    }
74
75    /// Wait for the specified load state to be reached.
76    ///
77    /// # Errors
78    ///
79    /// Returns an error if the wait times out or is cancelled.
80    pub async fn wait_for_load_state(
81        &mut self,
82        target_state: DocumentLoadState,
83    ) -> Result<(), WaitError> {
84        self.wait_for_load_state_with_timeout(target_state, DEFAULT_TIMEOUT)
85            .await
86    }
87
88    /// Wait for the specified load state with a custom timeout.
89    ///
90    /// # Errors
91    ///
92    /// Returns an error if the wait times out or is cancelled.
93    #[instrument(level = "debug", skip(self), fields(target_state = ?target_state, timeout_ms = timeout_duration.as_millis()))]
94    pub async fn wait_for_load_state_with_timeout(
95        &mut self,
96        target_state: DocumentLoadState,
97        timeout_duration: Duration,
98    ) -> Result<(), WaitError> {
99        // Check if already reached
100        {
101            let current = *self.current_state.lock().await;
102            if target_state.is_reached(current) {
103                debug!(current = ?current, "Target state already reached");
104                return Ok(());
105            }
106            trace!(current = ?current, "Starting wait for target state");
107        }
108
109        let result = timeout(timeout_duration, self.wait_for_state_impl(target_state)).await;
110
111        match result {
112            Ok(Ok(())) => {
113                debug!("Wait completed successfully");
114                Ok(())
115            }
116            Ok(Err(e)) => {
117                warn!(error = ?e, "Wait failed with error");
118                Err(e)
119            }
120            Err(_) => {
121                warn!(timeout_ms = timeout_duration.as_millis(), "Wait timed out");
122                Err(WaitError::Timeout(timeout_duration))
123            }
124        }
125    }
126
127    /// Internal implementation of waiting for a load state.
128    async fn wait_for_state_impl(&mut self, target_state: DocumentLoadState) -> Result<(), WaitError> {
129        let mut last_network_activity = Instant::now();
130
131        loop {
132            // Check current state
133            {
134                let current = *self.current_state.lock().await;
135                if target_state.is_reached(current) {
136                    // For NetworkIdle, we need additional checking
137                    if target_state == DocumentLoadState::NetworkIdle {
138                        let pending = self.pending_requests.load(Ordering::Relaxed);
139                        if pending == 0 && last_network_activity.elapsed() >= NETWORK_IDLE_THRESHOLD {
140                            return Ok(());
141                        }
142                    } else {
143                        return Ok(());
144                    }
145                }
146            }
147
148            // Wait for the next event
149            let event = match self.event_rx.recv().await {
150                Ok(event) => event,
151                Err(broadcast::error::RecvError::Closed) => {
152                    return Err(WaitError::PageClosed);
153                }
154                Err(broadcast::error::RecvError::Lagged(_)) => {
155                    // Missed some events, continue
156                    continue;
157                }
158            };
159
160            // Filter for our session
161            if event.session_id.as_deref() != Some(&self.session_id) {
162                continue;
163            }
164
165            // Process the event
166            match event.method.as_str() {
167                "Page.domContentEventFired" => {
168                    let mut current = self.current_state.lock().await;
169                    if *current < DocumentLoadState::DomContentLoaded {
170                        debug!(previous = ?*current, "State transition: DomContentLoaded");
171                        *current = DocumentLoadState::DomContentLoaded;
172                    }
173                }
174                "Page.loadEventFired" => {
175                    let mut current = self.current_state.lock().await;
176                    if *current < DocumentLoadState::Load {
177                        debug!(previous = ?*current, "State transition: Load");
178                        *current = DocumentLoadState::Load;
179                    }
180                }
181                "Network.requestWillBeSent" => {
182                    if let Some(params) = event.params {
183                        if let Ok(req) = serde_json::from_value::<RequestWillBeSentEvent>(params) {
184                            // Only track main frame requests
185                            if req.frame_id.as_deref() == Some(&self.frame_id) {
186                                let mut ids = self.pending_request_ids.lock().await;
187                                if ids.insert(req.request_id.clone()) {
188                                    let count = self.pending_requests.fetch_add(1, Ordering::Relaxed) + 1;
189                                    trace!(request_id = %req.request_id, pending_count = count, "Network request started");
190                                    last_network_activity = Instant::now();
191                                    
192                                    // Track the main document request (type "Document")
193                                    if req.resource_type.as_deref() == Some("Document") {
194                                        let mut main_req = self.main_request_id.lock().await;
195                                        if main_req.is_none() {
196                                            *main_req = Some(req.request_id.clone());
197                                            trace!(request_id = %req.request_id, "Tracking main document request");
198                                        }
199                                    }
200                                }
201                            }
202                        }
203                    }
204                }
205                "Network.responseReceived" => {
206                    if let Some(params) = event.params {
207                        if let Ok(resp) = serde_json::from_value::<ResponseReceivedEvent>(params) {
208                            // Check if this is the main document response
209                            let main_req = self.main_request_id.lock().await;
210                            if main_req.as_deref() == Some(&resp.request_id) {
211                                let mut response_data = self.response_data.lock().await;
212                                response_data.status = Some(resp.response.status as u16);
213                                response_data.url = Some(resp.response.url.clone());
214                                
215                                // Copy headers
216                                response_data.headers = resp.response.headers.clone();
217                                
218                                trace!(
219                                    status = response_data.status,
220                                    url = ?response_data.url,
221                                    header_count = response_data.headers.len(),
222                                    "Captured main document response"
223                                );
224                            }
225                        }
226                    }
227                }
228                "Network.loadingFinished" => {
229                    if let Some(params) = event.params {
230                        if let Ok(finished) = serde_json::from_value::<LoadingFinishedEvent>(params) {
231                            let mut ids = self.pending_request_ids.lock().await;
232                            if ids.remove(&finished.request_id) {
233                                let count = self.pending_requests.fetch_sub(1, Ordering::Relaxed) - 1;
234                                trace!(request_id = %finished.request_id, pending_count = count, "Network request finished");
235                                last_network_activity = Instant::now();
236                            }
237                        }
238                    }
239                }
240                "Network.loadingFailed" => {
241                    if let Some(params) = event.params {
242                        if let Ok(failed) = serde_json::from_value::<LoadingFailedEvent>(params) {
243                            let mut ids = self.pending_request_ids.lock().await;
244                            if ids.remove(&failed.request_id) {
245                                let count = self.pending_requests.fetch_sub(1, Ordering::Relaxed) - 1;
246                                trace!(request_id = %failed.request_id, pending_count = count, "Network request failed");
247                                last_network_activity = Instant::now();
248                            }
249                        }
250                    }
251                }
252                _ => {}
253            }
254
255            // For NetworkIdle, check if we've been idle long enough
256            if target_state == DocumentLoadState::NetworkIdle {
257                let pending = self.pending_requests.load(Ordering::Relaxed);
258                let current = *self.current_state.lock().await;
259                if pending == 0 && current >= DocumentLoadState::Load {
260                    // Wait for the idle threshold
261                    sleep(NETWORK_IDLE_THRESHOLD).await;
262                    // Check again after sleeping
263                    let pending_after = self.pending_requests.load(Ordering::Relaxed);
264                    if pending_after == 0 {
265                        return Ok(());
266                    }
267                }
268            }
269        }
270    }
271
272    /// Set the current load state (used when commit is detected).
273    pub async fn set_commit_received(&self) {
274        let mut current = self.current_state.lock().await;
275        if *current < DocumentLoadState::Commit {
276            debug!("State transition: Commit");
277            *current = DocumentLoadState::Commit;
278        }
279    }
280
281    /// Get the current load state.
282    pub async fn current_state(&self) -> DocumentLoadState {
283        *self.current_state.lock().await
284    }
285
286    /// Get the captured response data from navigation.
287    ///
288    /// This returns the status code, headers, and final URL captured during navigation.
289    pub async fn response_data(&self) -> NavigationResponseData {
290        self.response_data.lock().await.clone()
291    }
292}