viewpoint_core/wait/waiter/
mod.rs

1//! Load state waiter implementation.
2
3use std::collections::{HashMap, HashSet};
4use std::sync::Arc;
5use std::sync::atomic::{AtomicUsize, Ordering};
6use std::time::Duration;
7
8use tokio::sync::{Mutex, broadcast};
9use tokio::time::{Instant, sleep, timeout};
10use tracing::{debug, instrument, trace, warn};
11use viewpoint_cdp::CdpEvent;
12use viewpoint_cdp::protocol::network::{
13    LoadingFailedEvent, LoadingFinishedEvent, RequestWillBeSentEvent, ResponseReceivedEvent,
14};
15
16use super::DocumentLoadState;
17use crate::error::WaitError;
18
/// Captured response data during navigation.
///
/// Starts out empty (`Default`) and is filled in by [`LoadStateWaiter`] when it
/// observes the `Network.responseReceived` event belonging to the main
/// document request. All fields keep their default values if no such response
/// was seen.
#[derive(Debug, Clone, Default)]
pub struct NavigationResponseData {
    /// HTTP status code of the main document response, if one was captured.
    pub status: Option<u16>,
    /// Response headers of the main document response (empty until captured).
    pub headers: HashMap<String, String>,
    /// Final URL after redirects, as reported on the captured response.
    pub url: Option<String>,
}
29
/// Default timeout for wait operations.
///
/// Used by `LoadStateWaiter::wait_for_load_state`, which does not take an
/// explicit timeout.
const DEFAULT_TIMEOUT: Duration = Duration::from_secs(30);

/// Network idle threshold (no requests for this duration).
///
/// This is the quiet period, with no tracked requests pending, that a wait for
/// `DocumentLoadState::NetworkIdle` looks for.
const NETWORK_IDLE_THRESHOLD: Duration = Duration::from_millis(500);
35
/// Waits for page load states by listening to CDP events.
///
/// The waiter consumes events from a broadcast channel, ignores everything
/// that does not belong to its session, and derives from them both the
/// document load state and the set of in-flight network requests issued by
/// the frame it watches.
#[derive(Debug)]
pub struct LoadStateWaiter {
    /// Current load state. Starts at `Commit` and is only ever advanced
    /// (to `DomContentLoaded`, then `Load`) as the matching page events arrive.
    current_state: Arc<Mutex<DocumentLoadState>>,
    /// Event receiver for CDP events.
    event_rx: broadcast::Receiver<CdpEvent>,
    /// Session ID to filter events for; events from other sessions are skipped.
    session_id: String,
    /// Frame ID to wait for; only requests issued by this frame are tracked.
    frame_id: String,
    /// Pending network request count. Incremented/decremented together with
    /// insertions into/removals from `pending_request_ids`.
    pending_requests: Arc<AtomicUsize>,
    /// Set of pending request IDs. Guards against counting the same request
    /// twice and against decrementing for requests that were never tracked.
    pending_request_ids: Arc<Mutex<HashSet<String>>>,
    /// Captured response data from navigation.
    response_data: Arc<Mutex<NavigationResponseData>>,
    /// The main document request ID (for tracking the navigation response).
    /// Set once, to the first tracked request whose resource type is `Document`.
    main_request_id: Arc<Mutex<Option<String>>>,
}
56
57impl LoadStateWaiter {
58    /// Create a new load state waiter.
59    pub fn new(
60        event_rx: broadcast::Receiver<CdpEvent>,
61        session_id: String,
62        frame_id: String,
63    ) -> Self {
64        debug!(session_id = %session_id, frame_id = %frame_id, "Created LoadStateWaiter");
65        Self {
66            current_state: Arc::new(Mutex::new(DocumentLoadState::Commit)),
67            event_rx,
68            session_id,
69            frame_id,
70            pending_requests: Arc::new(AtomicUsize::new(0)),
71            pending_request_ids: Arc::new(Mutex::new(HashSet::new())),
72            response_data: Arc::new(Mutex::new(NavigationResponseData::default())),
73            main_request_id: Arc::new(Mutex::new(None)),
74        }
75    }
76
77    /// Wait for the specified load state to be reached.
78    ///
79    /// # Errors
80    ///
81    /// Returns an error if the wait times out or is cancelled.
82    pub async fn wait_for_load_state(
83        &mut self,
84        target_state: DocumentLoadState,
85    ) -> Result<(), WaitError> {
86        self.wait_for_load_state_with_timeout(target_state, DEFAULT_TIMEOUT)
87            .await
88    }
89
90    /// Wait for the specified load state with a custom timeout.
91    ///
92    /// # Errors
93    ///
94    /// Returns an error if the wait times out or is cancelled.
95    #[instrument(level = "debug", skip(self), fields(target_state = ?target_state, timeout_ms = timeout_duration.as_millis()))]
96    pub async fn wait_for_load_state_with_timeout(
97        &mut self,
98        target_state: DocumentLoadState,
99        timeout_duration: Duration,
100    ) -> Result<(), WaitError> {
101        // Check if already reached
102        {
103            let current = *self.current_state.lock().await;
104            if target_state.is_reached(current) {
105                debug!(current = ?current, "Target state already reached");
106                return Ok(());
107            }
108            trace!(current = ?current, "Starting wait for target state");
109        }
110
111        let result = timeout(timeout_duration, self.wait_for_state_impl(target_state)).await;
112
113        match result {
114            Ok(Ok(())) => {
115                debug!("Wait completed successfully");
116                Ok(())
117            }
118            Ok(Err(e)) => {
119                warn!(error = ?e, "Wait failed with error");
120                Err(e)
121            }
122            Err(_) => {
123                warn!(timeout_ms = timeout_duration.as_millis(), "Wait timed out");
124                Err(WaitError::Timeout(timeout_duration))
125            }
126        }
127    }
128
129    /// Internal implementation of waiting for a load state.
130    async fn wait_for_state_impl(
131        &mut self,
132        target_state: DocumentLoadState,
133    ) -> Result<(), WaitError> {
134        let mut last_network_activity = Instant::now();
135
136        loop {
137            // Check current state
138            {
139                let current = *self.current_state.lock().await;
140                if target_state.is_reached(current) {
141                    // For NetworkIdle, we need additional checking
142                    if target_state == DocumentLoadState::NetworkIdle {
143                        let pending = self.pending_requests.load(Ordering::Relaxed);
144                        if pending == 0 && last_network_activity.elapsed() >= NETWORK_IDLE_THRESHOLD
145                        {
146                            return Ok(());
147                        }
148                    } else {
149                        return Ok(());
150                    }
151                }
152            }
153
154            // Wait for the next event
155            let event = match self.event_rx.recv().await {
156                Ok(event) => event,
157                Err(broadcast::error::RecvError::Closed) => {
158                    return Err(WaitError::PageClosed);
159                }
160                Err(broadcast::error::RecvError::Lagged(_)) => {
161                    // Missed some events, continue
162                    continue;
163                }
164            };
165
166            // Filter for our session
167            if event.session_id.as_deref() != Some(&self.session_id) {
168                continue;
169            }
170
171            // Process the event
172            match event.method.as_str() {
173                "Page.domContentEventFired" => {
174                    let mut current = self.current_state.lock().await;
175                    if *current < DocumentLoadState::DomContentLoaded {
176                        debug!(previous = ?*current, "State transition: DomContentLoaded");
177                        *current = DocumentLoadState::DomContentLoaded;
178                    }
179                }
180                "Page.loadEventFired" => {
181                    let mut current = self.current_state.lock().await;
182                    if *current < DocumentLoadState::Load {
183                        debug!(previous = ?*current, "State transition: Load");
184                        *current = DocumentLoadState::Load;
185                    }
186                }
187                "Network.requestWillBeSent" => {
188                    if let Some(params) = event.params {
189                        if let Ok(req) = serde_json::from_value::<RequestWillBeSentEvent>(params) {
190                            // Only track main frame requests
191                            if req.frame_id.as_deref() == Some(&self.frame_id) {
192                                let mut ids = self.pending_request_ids.lock().await;
193                                if ids.insert(req.request_id.clone()) {
194                                    let count =
195                                        self.pending_requests.fetch_add(1, Ordering::Relaxed) + 1;
196                                    trace!(request_id = %req.request_id, pending_count = count, "Network request started");
197                                    last_network_activity = Instant::now();
198
199                                    // Track the main document request (type "Document")
200                                    if req.resource_type.as_deref() == Some("Document") {
201                                        let mut main_req = self.main_request_id.lock().await;
202                                        if main_req.is_none() {
203                                            *main_req = Some(req.request_id.clone());
204                                            trace!(request_id = %req.request_id, "Tracking main document request");
205                                        }
206                                    }
207                                }
208                            }
209                        }
210                    }
211                }
212                "Network.responseReceived" => {
213                    if let Some(params) = event.params {
214                        if let Ok(resp) = serde_json::from_value::<ResponseReceivedEvent>(params) {
215                            // Check if this is the main document response
216                            let main_req = self.main_request_id.lock().await;
217                            if main_req.as_deref() == Some(&resp.request_id) {
218                                let mut response_data = self.response_data.lock().await;
219                                response_data.status = Some(resp.response.status as u16);
220                                response_data.url = Some(resp.response.url.clone());
221
222                                // Copy headers
223                                response_data.headers = resp.response.headers.clone();
224
225                                trace!(
226                                    status = response_data.status,
227                                    url = ?response_data.url,
228                                    header_count = response_data.headers.len(),
229                                    "Captured main document response"
230                                );
231                            }
232                        }
233                    }
234                }
235                "Network.loadingFinished" => {
236                    if let Some(params) = event.params {
237                        if let Ok(finished) = serde_json::from_value::<LoadingFinishedEvent>(params)
238                        {
239                            let mut ids = self.pending_request_ids.lock().await;
240                            if ids.remove(&finished.request_id) {
241                                let count =
242                                    self.pending_requests.fetch_sub(1, Ordering::Relaxed) - 1;
243                                trace!(request_id = %finished.request_id, pending_count = count, "Network request finished");
244                                last_network_activity = Instant::now();
245                            }
246                        }
247                    }
248                }
249                "Network.loadingFailed" => {
250                    if let Some(params) = event.params {
251                        if let Ok(failed) = serde_json::from_value::<LoadingFailedEvent>(params) {
252                            let mut ids = self.pending_request_ids.lock().await;
253                            if ids.remove(&failed.request_id) {
254                                let count =
255                                    self.pending_requests.fetch_sub(1, Ordering::Relaxed) - 1;
256                                trace!(request_id = %failed.request_id, pending_count = count, "Network request failed");
257                                last_network_activity = Instant::now();
258                            }
259                        }
260                    }
261                }
262                _ => {}
263            }
264
265            // For NetworkIdle, check if we've been idle long enough
266            if target_state == DocumentLoadState::NetworkIdle {
267                let pending = self.pending_requests.load(Ordering::Relaxed);
268                let current = *self.current_state.lock().await;
269                if pending == 0 && current >= DocumentLoadState::Load {
270                    // Wait for the idle threshold
271                    sleep(NETWORK_IDLE_THRESHOLD).await;
272                    // Check again after sleeping
273                    let pending_after = self.pending_requests.load(Ordering::Relaxed);
274                    if pending_after == 0 {
275                        return Ok(());
276                    }
277                }
278            }
279        }
280    }
281
282    /// Set the current load state (used when commit is detected).
283    pub async fn set_commit_received(&self) {
284        let mut current = self.current_state.lock().await;
285        if *current < DocumentLoadState::Commit {
286            debug!("State transition: Commit");
287            *current = DocumentLoadState::Commit;
288        }
289    }
290
291    /// Get the current load state.
292    pub async fn current_state(&self) -> DocumentLoadState {
293        *self.current_state.lock().await
294    }
295
296    /// Get the captured response data from navigation.
297    ///
298    /// This returns the status code, headers, and final URL captured during navigation.
299    pub async fn response_data(&self) -> NavigationResponseData {
300        self.response_data.lock().await.clone()
301    }
302}