viewpoint_core/context/storage/
mod.rs

1//! Storage state collection and restoration.
2//!
3//! This module provides functionality for collecting and restoring browser
4//! storage state including cookies, localStorage, and `IndexedDB`.
5
6use std::collections::HashMap;
7use std::sync::Arc;
8
9use tokio::sync::RwLock;
10use tracing::debug;
11use viewpoint_cdp::CdpConnection;
12use viewpoint_js::js;
13
14use tracing::instrument;
15
16use super::BrowserContext;
17use super::types::{Cookie, IndexedDbDatabase, LocalStorageEntry, StorageOrigin, StorageState};
18use crate::error::ContextError;
19
20// Re-export restore functions for external use
21pub use super::storage_restore::{restore_indexed_db, restore_local_storage};
22
23impl BrowserContext {
24    /// Get the storage state (cookies and localStorage).
25    ///
26    /// This method collects cookies and localStorage for all pages in the context.
27    /// For more advanced options including `IndexedDB`, use `storage_state_builder()`.
28    ///
29    /// # Example
30    ///
31    /// ```no_run
32    /// use viewpoint_core::BrowserContext;
33    ///
34    /// # async fn example(context: &BrowserContext) -> Result<(), Box<dyn std::error::Error>> {
35    /// let state = context.storage_state().await?;
36    /// state.save("auth.json").await?;
37    /// # Ok(())
38    /// # }
39    /// ```
40    ///
41    /// # Errors
42    ///
43    /// Returns an error if getting storage state fails.
44    #[instrument(level = "debug", skip(self))]
45    pub async fn storage_state(&self) -> Result<StorageState, ContextError> {
46        self.storage_state_builder().collect().await
47    }
48
49    /// Create a builder for collecting storage state with options.
50    ///
51    /// Use this method when you need to include `IndexedDB` data or configure
52    /// other collection options.
53    ///
54    /// # Example
55    ///
56    /// ```no_run
57    /// use viewpoint_core::BrowserContext;
58    ///
59    /// # async fn example(context: &BrowserContext) -> Result<(), Box<dyn std::error::Error>> {
60    /// // Include IndexedDB data
61    /// let state = context.storage_state_builder()
62    ///     .indexed_db(true)
63    ///     .collect()
64    ///     .await?;
65    ///
66    /// state.save("full-state.json").await?;
67    /// # Ok(())
68    /// # }
69    /// ```
70    pub fn storage_state_builder(&self) -> StorageStateBuilder<'_> {
71        StorageStateBuilder::new(self.connection(), self.context_id(), &self.pages)
72    }
73}
74
/// Options for collecting storage state.
///
/// [`StorageStateOptions::new`] and [`Default::default`] produce the same
/// value: `IndexedDB` collection disabled and a cap of 1000 entries per
/// object store.
#[derive(Debug, Clone)]
pub struct StorageStateOptions {
    /// Include `IndexedDB` data in the snapshot (default: `false`).
    pub indexed_db: bool,
    /// Maximum entries per `IndexedDB` object store.
    /// Set to 0 for unlimited (default: 1000).
    pub indexed_db_max_entries: usize,
}

// Implemented by hand rather than derived: a derived `Default` would zero
// `indexed_db_max_entries`, and 0 means "unlimited" rather than the
// documented default of 1000.
impl Default for StorageStateOptions {
    fn default() -> Self {
        Self {
            indexed_db: false,
            indexed_db_max_entries: Self::DEFAULT_INDEXED_DB_MAX_ENTRIES,
        }
    }
}

impl StorageStateOptions {
    /// Per-object-store entry cap applied when the caller does not pick one.
    const DEFAULT_INDEXED_DB_MAX_ENTRIES: usize = 1000;

    /// Create new default options.
    ///
    /// Equivalent to [`Default::default`].
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Include `IndexedDB` data in the snapshot.
    #[must_use]
    pub fn indexed_db(mut self, include: bool) -> Self {
        self.indexed_db = include;
        self
    }

    /// Set maximum entries per `IndexedDB` object store (0 means unlimited).
    #[must_use]
    pub fn indexed_db_max_entries(mut self, max: usize) -> Self {
        self.indexed_db_max_entries = max;
        self
    }
}
108
109/// Builder for collecting storage state with options.
110pub struct StorageStateBuilder<'a> {
111    connection: &'a Arc<CdpConnection>,
112    context_id: &'a str,
113    pages: &'a Arc<RwLock<Vec<super::PageInfo>>>,
114    options: StorageStateOptions,
115}
116
117impl<'a> StorageStateBuilder<'a> {
118    pub(crate) fn new(
119        connection: &'a Arc<CdpConnection>,
120        context_id: &'a str,
121        pages: &'a Arc<RwLock<Vec<super::PageInfo>>>,
122    ) -> Self {
123        Self {
124            connection,
125            context_id,
126            pages,
127            options: StorageStateOptions::default(),
128        }
129    }
130
131    /// Include `IndexedDB` data in the storage state.
132    #[must_use]
133    pub fn indexed_db(mut self, include: bool) -> Self {
134        self.options.indexed_db = include;
135        self
136    }
137
138    /// Set maximum entries per `IndexedDB` object store.
139    #[must_use]
140    pub fn indexed_db_max_entries(mut self, max: usize) -> Self {
141        self.options.indexed_db_max_entries = max;
142        self
143    }
144
145    /// Collect the storage state.
146    ///
147    /// # Errors
148    ///
149    /// Returns an error if collecting storage state fails.
150    pub async fn collect(self) -> Result<StorageState, ContextError> {
151        // Collect cookies using the Storage domain
152        let cookies = self.collect_cookies().await?;
153
154        let mut origins: HashMap<String, StorageOrigin> = HashMap::new();
155
156        // Get all page sessions for evaluation
157        let pages = self.pages.read().await;
158
159        for page in pages.iter() {
160            if page.session_id.is_empty() {
161                continue;
162            }
163
164            // Get the current page URL/origin
165            let origin = self.get_page_origin(&page.session_id).await?;
166            if origin.is_empty() || origin == "null" {
167                continue;
168            }
169
170            // Get localStorage for this page
171            let local_storage = self.collect_local_storage(&page.session_id).await?;
172
173            // Get IndexedDB if requested
174            let indexed_db = if self.options.indexed_db {
175                self.collect_indexed_db(&page.session_id).await?
176            } else {
177                Vec::new()
178            };
179
180            // Merge into origins map
181            let storage_origin = origins
182                .entry(origin.clone())
183                .or_insert_with(|| StorageOrigin::new(origin));
184            storage_origin.local_storage.extend(local_storage);
185            storage_origin.indexed_db.extend(indexed_db);
186        }
187
188        Ok(StorageState {
189            cookies,
190            origins: origins.into_values().collect(),
191        })
192    }
193
194    /// Collect cookies from the browser context.
195    async fn collect_cookies(&self) -> Result<Vec<Cookie>, ContextError> {
196        use super::types::SameSite;
197        use viewpoint_cdp::protocol::storage::{GetCookiesParams, GetCookiesResult};
198
199        let result: GetCookiesResult = self
200            .connection
201            .send_command(
202                "Storage.getCookies",
203                Some(GetCookiesParams::new().browser_context_id(self.context_id.to_string())),
204                None,
205            )
206            .await?;
207
208        let cookies = result
209            .cookies
210            .into_iter()
211            .map(|c| Cookie {
212                name: c.name,
213                value: c.value,
214                domain: Some(c.domain),
215                path: Some(c.path),
216                url: None,
217                expires: if c.expires > 0.0 {
218                    Some(c.expires)
219                } else {
220                    None
221                },
222                http_only: Some(c.http_only),
223                secure: Some(c.secure),
224                same_site: c.same_site.map(|s| match s {
225                    viewpoint_cdp::protocol::CookieSameSite::Strict => SameSite::Strict,
226                    viewpoint_cdp::protocol::CookieSameSite::Lax => SameSite::Lax,
227                    viewpoint_cdp::protocol::CookieSameSite::None => SameSite::None,
228                }),
229            })
230            .collect();
231
232        Ok(cookies)
233    }
234
235    /// Get the origin URL for a page.
236    async fn get_page_origin(&self, session_id: &str) -> Result<String, ContextError> {
237        let result: viewpoint_cdp::protocol::runtime::EvaluateResult = self
238            .connection
239            .send_command(
240                "Runtime.evaluate",
241                Some(viewpoint_cdp::protocol::runtime::EvaluateParams {
242                    expression: js! { window.location.origin }.to_string(),
243                    object_group: None,
244                    include_command_line_api: None,
245                    silent: Some(true),
246                    context_id: None,
247                    return_by_value: Some(true),
248                    await_promise: Some(false),
249                }),
250                Some(session_id),
251            )
252            .await?;
253
254        Ok(result
255            .result
256            .value
257            .and_then(|v| v.as_str().map(String::from))
258            .unwrap_or_default())
259    }
260
261    /// Collect localStorage entries from a page.
262    async fn collect_local_storage(
263        &self,
264        session_id: &str,
265    ) -> Result<Vec<LocalStorageEntry>, ContextError> {
266        let js = r"
267            (function() {
268                const entries = [];
269                for (let i = 0; i < localStorage.length; i++) {
270                    const key = localStorage.key(i);
271                    if (key !== null) {
272                        entries.push({ name: key, value: localStorage.getItem(key) || '' });
273                    }
274                }
275                return entries;
276            })()
277        ";
278
279        let result: viewpoint_cdp::protocol::runtime::EvaluateResult = self
280            .connection
281            .send_command(
282                "Runtime.evaluate",
283                Some(viewpoint_cdp::protocol::runtime::EvaluateParams {
284                    expression: js.to_string(),
285                    object_group: None,
286                    include_command_line_api: None,
287                    silent: Some(true),
288                    context_id: None,
289                    return_by_value: Some(true),
290                    await_promise: Some(false),
291                }),
292                Some(session_id),
293            )
294            .await?;
295
296        if let Some(value) = result.result.value {
297            let entries: Vec<LocalStorageEntry> = serde_json::from_value(value).unwrap_or_default();
298            debug!("Collected {} localStorage entries", entries.len());
299            Ok(entries)
300        } else {
301            Ok(Vec::new())
302        }
303    }
304
305    /// Collect `IndexedDB` databases from a page.
306    async fn collect_indexed_db(
307        &self,
308        session_id: &str,
309    ) -> Result<Vec<IndexedDbDatabase>, ContextError> {
310        let max_entries = self.options.indexed_db_max_entries.to_string();
311
312        // JavaScript to collect IndexedDB data
313        let js_code = js! {
314            (async function() {
315                const maxEntries = @{max_entries};
316                const databases = [];
317
318                if (!window.indexedDB || !window.indexedDB.databases) {
319                    return databases;
320                }
321
322                const dbList = await window.indexedDB.databases();
323
324                for (const dbInfo of dbList) {
325                    if (!dbInfo.name) continue;
326
327                    try {
328                        const db = await new Promise((resolve, reject) => {
329                            const request = indexedDB.open(dbInfo.name, dbInfo.version);
330                            request.onerror = () => reject(request.error);
331                            request.onsuccess = () => resolve(request.result);
332                        });
333
334                        const dbData = {
335                            name: dbInfo.name,
336                            version: db.version,
337                            stores: []
338                        };
339
340                        for (const storeName of db.objectStoreNames) {
341                            const tx = db.transaction(storeName, "readonly");
342                            const store = tx.objectStore(storeName);
343
344                            const storeData = {
345                                name: storeName,
346                                keyPath: store.keyPath ? (typeof store.keyPath === "string" ? store.keyPath : store.keyPath.join(",")) : null,
347                                autoIncrement: store.autoIncrement,
348                                entries: [],
349                                indexes: []
350                            };
351
352                            // Collect index definitions
353                            for (const indexName of store.indexNames) {
354                                const index = store.index(indexName);
355                                storeData.indexes.push({
356                                    name: index.name,
357                                    keyPath: typeof index.keyPath === "string" ? index.keyPath : index.keyPath.join(","),
358                                    unique: index.unique,
359                                    multiEntry: index.multiEntry
360                                });
361                            }
362
363                            // Collect entries (limited)
364                            const entries = await new Promise((resolve, reject) => {
365                                const entries = [];
366                                const request = store.openCursor();
367                                request.onerror = () => reject(request.error);
368                                request.onsuccess = (event) => {
369                                    const cursor = event.target.result;
370                                    if (cursor && (maxEntries === 0 || entries.length < maxEntries)) {
371                                        entries.push({ key: cursor.key, value: cursor.value });
372                                        cursor.continue();
373                                    } else {
374                                        resolve(entries);
375                                    }
376                                };
377                            });
378
379                            storeData.entries = entries;
380                            dbData.stores.push(storeData);
381                        }
382
383                        db.close();
384                        databases.push(dbData);
385                    } catch (e) {
386                        console.warn("Failed to read IndexedDB:", dbInfo.name, e);
387                    }
388                }
389
390                return databases;
391            })()
392        };
393
394        let result: viewpoint_cdp::protocol::runtime::EvaluateResult = self
395            .connection
396            .send_command(
397                "Runtime.evaluate",
398                Some(viewpoint_cdp::protocol::runtime::EvaluateParams {
399                    expression: js_code,
400                    object_group: None,
401                    include_command_line_api: None,
402                    silent: Some(true),
403                    context_id: None,
404                    return_by_value: Some(true),
405                    await_promise: Some(true),
406                }),
407                Some(session_id),
408            )
409            .await?;
410
411        if let Some(value) = result.result.value {
412            let databases: Vec<IndexedDbDatabase> =
413                serde_json::from_value(value).unwrap_or_default();
414            debug!("Collected {} IndexedDB databases", databases.len());
415            Ok(databases)
416        } else {
417            Ok(Vec::new())
418        }
419    }
420}
421
422#[cfg(test)]
423mod tests;