viewpoint_core/context/storage/
mod.rs

1//! Storage state collection and restoration.
2//!
3//! This module provides functionality for collecting and restoring browser
4//! storage state including cookies, localStorage, and `IndexedDB`.
5
6use std::collections::HashMap;
7use std::sync::Arc;
8
9use tokio::sync::RwLock;
10use tracing::debug;
11use viewpoint_cdp::CdpConnection;
12use viewpoint_js::js;
13
14use tracing::instrument;
15
16use super::BrowserContext;
17use super::types::{Cookie, IndexedDbDatabase, LocalStorageEntry, StorageOrigin, StorageState};
18use crate::error::ContextError;
19
20// Re-export restore functions for external use
21pub use super::storage_restore::{restore_indexed_db, restore_local_storage};
22
23impl BrowserContext {
24    /// Get the storage state (cookies and localStorage).
25    ///
26    /// This method collects cookies and localStorage for all pages in the context.
27    /// For more advanced options including `IndexedDB`, use `storage_state_builder()`.
28    ///
29    /// # Example
30    ///
31    /// ```no_run
32    /// use viewpoint_core::BrowserContext;
33    ///
34    /// # async fn example(context: &BrowserContext) -> Result<(), Box<dyn std::error::Error>> {
35    /// let state = context.storage_state().await?;
36    /// state.save("auth.json").await?;
37    /// # Ok(())
38    /// # }
39    /// ```
40    ///
41    /// # Errors
42    ///
43    /// Returns an error if getting storage state fails.
44    #[instrument(level = "debug", skip(self))]
45    pub async fn storage_state(&self) -> Result<StorageState, ContextError> {
46        self.storage_state_builder().collect().await
47    }
48
49    /// Create a builder for collecting storage state with options.
50    ///
51    /// Use this method when you need to include `IndexedDB` data or configure
52    /// other collection options.
53    ///
54    /// # Example
55    ///
56    /// ```no_run
57    /// use viewpoint_core::BrowserContext;
58    ///
59    /// # async fn example(context: &BrowserContext) -> Result<(), Box<dyn std::error::Error>> {
60    /// // Include IndexedDB data
61    /// let state = context.storage_state_builder()
62    ///     .indexed_db(true)
63    ///     .collect()
64    ///     .await?;
65    ///
66    /// state.save("full-state.json").await?;
67    /// # Ok(())
68    /// # }
69    /// ```
70    pub fn storage_state_builder(&self) -> StorageStateBuilder<'_> {
71        StorageStateBuilder::new(self.connection(), self.context_id(), &self.pages)
72    }
73}
74
/// Options for collecting storage state.
///
/// [`StorageStateOptions::new`] and [`Default::default`] produce the same
/// value: `IndexedDB` collection disabled and a cap of 1000 entries per
/// object store.
#[derive(Debug, Clone)]
pub struct StorageStateOptions {
    /// Include `IndexedDB` data in the snapshot.
    pub indexed_db: bool,
    /// Maximum entries per `IndexedDB` object store.
    /// Set to 0 for unlimited (default: 1000).
    pub indexed_db_max_entries: usize,
}

impl Default for StorageStateOptions {
    /// Same as [`StorageStateOptions::new`].
    ///
    /// Implemented by hand because a derived `Default` would set
    /// `indexed_db_max_entries` to 0, which means "unlimited" and contradicts
    /// the documented default of 1000.
    fn default() -> Self {
        Self::new()
    }
}

impl StorageStateOptions {
    /// Cap applied to each `IndexedDB` object store unless overridden.
    const DEFAULT_INDEXED_DB_MAX_ENTRIES: usize = 1000;

    /// Create new default options.
    ///
    /// `IndexedDB` collection is off and at most 1000 entries are read per
    /// object store.
    #[must_use]
    pub fn new() -> Self {
        Self {
            indexed_db: false,
            indexed_db_max_entries: Self::DEFAULT_INDEXED_DB_MAX_ENTRIES,
        }
    }

    /// Include `IndexedDB` data in the snapshot.
    #[must_use]
    pub fn indexed_db(mut self, include: bool) -> Self {
        self.indexed_db = include;
        self
    }

    /// Set maximum entries per `IndexedDB` object store.
    ///
    /// Pass 0 to remove the limit.
    #[must_use]
    pub fn indexed_db_max_entries(mut self, max: usize) -> Self {
        self.indexed_db_max_entries = max;
        self
    }
}
108
109/// Builder for collecting storage state with options.
110pub struct StorageStateBuilder<'a> {
111    connection: &'a Arc<CdpConnection>,
112    context_id: &'a str,
113    pages: &'a Arc<RwLock<Vec<super::PageInfo>>>,
114    options: StorageStateOptions,
115}
116
117impl<'a> StorageStateBuilder<'a> {
118    pub(crate) fn new(
119        connection: &'a Arc<CdpConnection>,
120        context_id: &'a str,
121        pages: &'a Arc<RwLock<Vec<super::PageInfo>>>,
122    ) -> Self {
123        Self {
124            connection,
125            context_id,
126            pages,
127            options: StorageStateOptions::default(),
128        }
129    }
130
131    /// Include `IndexedDB` data in the storage state.
132    #[must_use]
133    pub fn indexed_db(mut self, include: bool) -> Self {
134        self.options.indexed_db = include;
135        self
136    }
137
138    /// Set maximum entries per `IndexedDB` object store.
139    #[must_use]
140    pub fn indexed_db_max_entries(mut self, max: usize) -> Self {
141        self.options.indexed_db_max_entries = max;
142        self
143    }
144
145    /// Collect the storage state.
146    ///
147    /// # Errors
148    ///
149    /// Returns an error if collecting storage state fails.
150    pub async fn collect(self) -> Result<StorageState, ContextError> {
151        // Collect cookies using the Storage domain
152        let cookies = self.collect_cookies().await?;
153
154        let mut origins: HashMap<String, StorageOrigin> = HashMap::new();
155
156        // Get all page sessions for evaluation
157        let pages = self.pages.read().await;
158
159        for page in pages.iter() {
160            if page.session_id.is_empty() {
161                continue;
162            }
163
164            // Get the current page URL/origin
165            let origin = self.get_page_origin(&page.session_id).await?;
166            if origin.is_empty() || origin == "null" {
167                continue;
168            }
169
170            // Get localStorage for this page
171            let local_storage = self.collect_local_storage(&page.session_id).await?;
172
173            // Get IndexedDB if requested
174            let indexed_db = if self.options.indexed_db {
175                self.collect_indexed_db(&page.session_id).await?
176            } else {
177                Vec::new()
178            };
179
180            // Merge into origins map
181            let storage_origin = origins
182                .entry(origin.clone())
183                .or_insert_with(|| StorageOrigin::new(origin));
184            storage_origin.local_storage.extend(local_storage);
185            storage_origin.indexed_db.extend(indexed_db);
186        }
187
188        Ok(StorageState {
189            cookies,
190            origins: origins.into_values().collect(),
191        })
192    }
193
194    /// Collect cookies from the browser context.
195    async fn collect_cookies(&self) -> Result<Vec<Cookie>, ContextError> {
196        use super::types::SameSite;
197        use viewpoint_cdp::protocol::storage::{GetCookiesParams, GetCookiesResult};
198
199        let result: GetCookiesResult = self
200            .connection
201            .send_command(
202                "Storage.getCookies",
203                Some(GetCookiesParams::new().browser_context_id(self.context_id.to_string())),
204                None,
205            )
206            .await?;
207
208        let cookies = result
209            .cookies
210            .into_iter()
211            .map(|c| Cookie {
212                name: c.name,
213                value: c.value,
214                domain: Some(c.domain),
215                path: Some(c.path),
216                url: None,
217                expires: if c.expires > 0.0 {
218                    Some(c.expires)
219                } else {
220                    None
221                },
222                http_only: Some(c.http_only),
223                secure: Some(c.secure),
224                same_site: c.same_site.map(|s| match s {
225                    viewpoint_cdp::protocol::CookieSameSite::Strict => SameSite::Strict,
226                    viewpoint_cdp::protocol::CookieSameSite::Lax => SameSite::Lax,
227                    viewpoint_cdp::protocol::CookieSameSite::None => SameSite::None,
228                }),
229            })
230            .collect();
231
232        Ok(cookies)
233    }
234
235    /// Get the origin URL for a page.
236    async fn get_page_origin(&self, session_id: &str) -> Result<String, ContextError> {
237        let result: viewpoint_cdp::protocol::runtime::EvaluateResult = self
238            .connection
239            .send_command(
240                "Runtime.evaluate",
241                Some(viewpoint_cdp::protocol::runtime::EvaluateParams {
242                    expression: js! { window.location.origin }.to_string(),
243                    object_group: None,
244                    include_command_line_api: None,
245                    silent: Some(true),
246                    context_id: None,
247                    return_by_value: Some(true),
248                    await_promise: Some(false),
249                }),
250                Some(session_id),
251            )
252            .await?;
253
254        Ok(result
255            .result
256            .value
257            .and_then(|v| v.as_str().map(String::from))
258            .unwrap_or_default())
259    }
260
261    /// Collect localStorage entries from a page.
262    async fn collect_local_storage(
263        &self,
264        session_id: &str,
265    ) -> Result<Vec<LocalStorageEntry>, ContextError> {
266        let js = r"
267            (function() {
268                const entries = [];
269                for (let i = 0; i < localStorage.length; i++) {
270                    const key = localStorage.key(i);
271                    if (key !== null) {
272                        entries.push({ name: key, value: localStorage.getItem(key) || '' });
273                    }
274                }
275                return entries;
276            })()
277        ";
278
279        let result: viewpoint_cdp::protocol::runtime::EvaluateResult = self
280            .connection
281            .send_command(
282                "Runtime.evaluate",
283                Some(viewpoint_cdp::protocol::runtime::EvaluateParams {
284                    expression: js.to_string(),
285                    object_group: None,
286                    include_command_line_api: None,
287                    silent: Some(true),
288                    context_id: None,
289                    return_by_value: Some(true),
290                    await_promise: Some(false),
291                }),
292                Some(session_id),
293            )
294            .await?;
295
296        if let Some(value) = result.result.value {
297            let entries: Vec<LocalStorageEntry> = serde_json::from_value(value).unwrap_or_default();
298            debug!("Collected {} localStorage entries", entries.len());
299            Ok(entries)
300        } else {
301            Ok(Vec::new())
302        }
303    }
304
305    /// Collect `IndexedDB` databases from a page.
306    async fn collect_indexed_db(
307        &self,
308        session_id: &str,
309    ) -> Result<Vec<IndexedDbDatabase>, ContextError> {
310        let max_entries = self.options.indexed_db_max_entries;
311
312        // JavaScript to collect IndexedDB data
313        let js = format!(
314            r"
315            (async function() {{
316                const maxEntries = {max_entries};
317                const databases = [];
318                
319                if (!window.indexedDB || !window.indexedDB.databases) {{
320                    return databases;
321                }}
322                
323                const dbList = await window.indexedDB.databases();
324                
325                for (const dbInfo of dbList) {{
326                    if (!dbInfo.name) continue;
327                    
328                    try {{
329                        const db = await new Promise((resolve, reject) => {{
330                            const request = indexedDB.open(dbInfo.name, dbInfo.version);
331                            request.onerror = () => reject(request.error);
332                            request.onsuccess = () => resolve(request.result);
333                        }});
334                        
335                        const dbData = {{
336                            name: dbInfo.name,
337                            version: db.version,
338                            stores: []
339                        }};
340                        
341                        for (const storeName of db.objectStoreNames) {{
342                            const tx = db.transaction(storeName, 'readonly');
343                            const store = tx.objectStore(storeName);
344                            
345                            const storeData = {{
346                                name: storeName,
347                                keyPath: store.keyPath ? (typeof store.keyPath === 'string' ? store.keyPath : store.keyPath.join(',')) : null,
348                                autoIncrement: store.autoIncrement,
349                                entries: [],
350                                indexes: []
351                            }};
352                            
353                            // Collect index definitions
354                            for (const indexName of store.indexNames) {{
355                                const index = store.index(indexName);
356                                storeData.indexes.push({{
357                                    name: index.name,
358                                    keyPath: typeof index.keyPath === 'string' ? index.keyPath : index.keyPath.join(','),
359                                    unique: index.unique,
360                                    multiEntry: index.multiEntry
361                                }});
362                            }}
363                            
364                            // Collect entries (limited)
365                            const entries = await new Promise((resolve, reject) => {{
366                                const entries = [];
367                                const request = store.openCursor();
368                                request.onerror = () => reject(request.error);
369                                request.onsuccess = (event) => {{
370                                    const cursor = event.target.result;
371                                    if (cursor && (maxEntries === 0 || entries.length < maxEntries)) {{
372                                        entries.push({{ key: cursor.key, value: cursor.value }});
373                                        cursor.continue();
374                                    }} else {{
375                                        resolve(entries);
376                                    }}
377                                }};
378                            }});
379                            
380                            storeData.entries = entries;
381                            dbData.stores.push(storeData);
382                        }}
383                        
384                        db.close();
385                        databases.push(dbData);
386                    }} catch (e) {{
387                        console.warn('Failed to read IndexedDB:', dbInfo.name, e);
388                    }}
389                }}
390                
391                return databases;
392            }})()
393        "
394        );
395
396        let result: viewpoint_cdp::protocol::runtime::EvaluateResult = self
397            .connection
398            .send_command(
399                "Runtime.evaluate",
400                Some(viewpoint_cdp::protocol::runtime::EvaluateParams {
401                    expression: js,
402                    object_group: None,
403                    include_command_line_api: None,
404                    silent: Some(true),
405                    context_id: None,
406                    return_by_value: Some(true),
407                    await_promise: Some(true),
408                }),
409                Some(session_id),
410            )
411            .await?;
412
413        if let Some(value) = result.result.value {
414            let databases: Vec<IndexedDbDatabase> =
415                serde_json::from_value(value).unwrap_or_default();
416            debug!("Collected {} IndexedDB databases", databases.len());
417            Ok(databases)
418        } else {
419            Ok(Vec::new())
420        }
421    }
422}
423
424#[cfg(test)]
425mod tests;