viewpoint_core/network/route_fetch/
mod.rs

//! Fetch builder for intercepting and modifying responses.
//!
//! This module provides the `FetchBuilder` and `FetchedResponse` types for
//! intercepting network responses and modifying them before they reach the page.
5
6use std::collections::HashMap;
7use std::future::Future;
8use std::pin::Pin;
9use std::time::Duration;
10
11use viewpoint_cdp::protocol::fetch::{
12    ContinueRequestParams, FulfillRequestParams, HeaderEntry, RequestPausedEvent,
13};
14
15use super::route::Route;
16use crate::error::NetworkError;
17
/// Builder for fetching the actual response with optional request modifications.
///
/// Holds a borrowed paused [`Route`] plus the request overrides to apply when the
/// request is continued by [`FetchBuilder::send`] (or by awaiting the builder).
#[derive(Debug)]
pub struct FetchBuilder<'a> {
    /// The paused route whose request will be continued.
    pub(super) route: &'a Route,
    /// Replacement request URL, if set.
    pub(super) url: Option<String>,
    /// Replacement HTTP method, if set.
    pub(super) method: Option<String>,
    /// Request headers to send; omitted from the continue command when empty.
    pub(super) headers: Vec<HeaderEntry>,
    /// Replacement request body as raw bytes (base64-encoded when sent).
    pub(super) post_data: Option<Vec<u8>>,
    /// How long `send` waits for the response-stage pause event.
    pub(super) timeout: Duration,
}
28
29impl<'a> FetchBuilder<'a> {
30    pub(super) fn new(route: &'a Route) -> Self {
31        Self {
32            route,
33            url: None,
34            method: None,
35            headers: Vec::new(),
36            post_data: None,
37            timeout: Duration::from_secs(30),
38        }
39    }
40
41    /// Override the request URL before fetching.
42    #[must_use]
43    pub fn url(mut self, url: impl Into<String>) -> Self {
44        self.url = Some(url.into());
45        self
46    }
47
48    /// Override the request method before fetching.
49    #[must_use]
50    pub fn method(mut self, method: impl Into<String>) -> Self {
51        self.method = Some(method.into());
52        self
53    }
54
55    /// Add or override a request header before fetching.
56    #[must_use]
57    pub fn header(mut self, name: impl Into<String>, value: impl Into<String>) -> Self {
58        self.headers.push(HeaderEntry {
59            name: name.into(),
60            value: value.into(),
61        });
62        self
63    }
64
65    /// Set multiple request headers before fetching.
66    #[must_use]
67    pub fn headers(
68        mut self,
69        headers: impl IntoIterator<Item = (impl Into<String>, impl Into<String>)>,
70    ) -> Self {
71        for (name, value) in headers {
72            self.headers.push(HeaderEntry {
73                name: name.into(),
74                value: value.into(),
75            });
76        }
77        self
78    }
79
80    /// Override the request POST data before fetching.
81    #[must_use]
82    pub fn post_data(mut self, data: impl Into<Vec<u8>>) -> Self {
83        self.post_data = Some(data.into());
84        self
85    }
86
87    /// Set the timeout for waiting for the response.
88    #[must_use]
89    pub fn timeout(mut self, timeout: Duration) -> Self {
90        self.timeout = timeout;
91        self
92    }
93
94    /// Fetch the response.
95    pub async fn send(self) -> Result<FetchedResponse<'a>, NetworkError> {
96        use base64::Engine;
97
98        // Subscribe to CDP events before sending the continue command
99        let mut events = self.route.connection().subscribe_events();
100        let request_id = self.route.request_id().to_string();
101        let session_id = self.route.session_id().to_string();
102
103        // Build continue params with modifications
104        let post_data = self
105            .post_data
106            .map(|d| base64::engine::general_purpose::STANDARD.encode(&d));
107
108        let params = ContinueRequestParams {
109            request_id: self.route.request_id().to_string(),
110            url: self.url,
111            method: self.method,
112            post_data,
113            headers: if self.headers.is_empty() {
114                None
115            } else {
116                Some(self.headers)
117            },
118            intercept_response: Some(true),
119        };
120
121        // Continue the request but intercept the response
122        self.route
123            .connection()
124            .send_command::<_, serde_json::Value>(
125                "Fetch.continueRequest",
126                Some(params),
127                Some(&session_id),
128            )
129            .await
130            .map_err(NetworkError::from)?;
131
132        // Wait for the response-stage Fetch.requestPaused event
133        let timeout = self.timeout;
134        let response_event = tokio::time::timeout(timeout, async {
135            while let Ok(event) = events.recv().await {
136                // Filter for our session
137                if event.session_id.as_deref() != Some(&session_id) {
138                    continue;
139                }
140
141                // Look for Fetch.requestPaused at response stage
142                if event.method == "Fetch.requestPaused" {
143                    if let Some(params) = &event.params {
144                        if let Ok(paused) =
145                            serde_json::from_value::<RequestPausedEvent>(params.clone())
146                        {
147                            // Check if this is for our request and at response stage
148                            if paused.request_id == request_id && paused.is_response_stage() {
149                                return Ok(paused);
150                            }
151                        }
152                    }
153                }
154            }
155            Err(NetworkError::Aborted)
156        })
157        .await
158        .map_err(|_| NetworkError::Timeout(timeout))??;
159
160        // Extract response data
161        let status = response_event.response_status_code.unwrap_or(200) as u16;
162        let headers: HashMap<String, String> = response_event
163            .response_headers
164            .as_ref()
165            .map(|h| {
166                h.iter()
167                    .map(|e| (e.name.clone(), e.value.clone()))
168                    .collect()
169            })
170            .unwrap_or_default();
171
172        // Get the response body
173        let body = self
174            .route
175            .get_response_body(&response_event.request_id)
176            .await?;
177
178        Ok(FetchedResponse {
179            route: self.route,
180            request_id: response_event.request_id,
181            status,
182            headers,
183            body,
184        })
185    }
186}
187
188// Allow `route.fetch().await` without calling `.send()`
189impl<'a> std::future::IntoFuture for FetchBuilder<'a> {
190    type Output = Result<FetchedResponse<'a>, NetworkError>;
191    type IntoFuture = Pin<Box<dyn Future<Output = Self::Output> + Send + 'a>>;
192
193    fn into_future(self) -> Self::IntoFuture {
194        Box::pin(self.send())
195    }
196}
197
198/// A response fetched via `route.fetch()`.
199#[derive(Debug)]
200pub struct FetchedResponse<'a> {
201    route: &'a Route,
202    /// Request ID for the response-stage paused request.
203    request_id: String,
204    /// HTTP status code.
205    pub status: u16,
206    /// Response headers.
207    pub headers: HashMap<String, String>,
208    /// Response body (already fetched).
209    pub(super) body: Option<Vec<u8>>,
210}
211
212impl FetchedResponse<'_> {
213    /// Get the response body.
214    ///
215    /// The body is fetched when `route.fetch()` is called, so this method
216    /// returns immediately.
217    pub fn body(&self) -> Result<Vec<u8>, NetworkError> {
218        self.body
219            .clone()
220            .ok_or_else(|| NetworkError::InvalidResponse("Response body not available".to_string()))
221    }
222
223    /// Get the response body as text.
224    pub fn text(&self) -> Result<String, NetworkError> {
225        let body = self.body()?;
226        String::from_utf8(body)
227            .map_err(|e| NetworkError::InvalidResponse(format!("Response is not valid UTF-8: {e}")))
228    }
229
230    /// Parse the response body as JSON.
231    pub fn json<T: serde::de::DeserializeOwned>(&self) -> Result<T, NetworkError> {
232        let text = self.text()?;
233        serde_json::from_str(&text)
234            .map_err(|e| NetworkError::InvalidResponse(format!("Failed to parse JSON: {e}")))
235    }
236
237    /// Continue the response to the page.
238    ///
239    /// This must be called after inspecting/modifying the response to let
240    /// the browser receive it.
241    pub async fn fulfill(self) -> Result<(), NetworkError> {
242        use base64::Engine;
243
244        // Build response headers
245        let response_headers: Vec<HeaderEntry> = self
246            .headers
247            .iter()
248            .map(|(k, v)| HeaderEntry {
249                name: k.clone(),
250                value: v.clone(),
251            })
252            .collect();
253
254        // Encode body
255        let body = self
256            .body
257            .map(|b| base64::engine::general_purpose::STANDARD.encode(&b));
258
259        let params = FulfillRequestParams {
260            request_id: self.request_id.clone(),
261            response_code: i32::from(self.status),
262            response_headers: if response_headers.is_empty() {
263                None
264            } else {
265                Some(response_headers)
266            },
267            binary_response_headers: None,
268            body,
269            response_phrase: None,
270        };
271
272        self.route
273            .connection()
274            .send_command::<_, serde_json::Value>(
275                "Fetch.fulfillRequest",
276                Some(params),
277                Some(self.route.session_id()),
278            )
279            .await
280            .map_err(NetworkError::from)?;
281
282        Ok(())
283    }
284}