viewpoint_core/context/storage/
mod.rs

1//! Storage state collection and restoration.
2//!
3//! This module provides functionality for collecting and restoring browser
4//! storage state including cookies, localStorage, and `IndexedDB`.
5
6use std::collections::HashMap;
7use std::sync::Arc;
8
9use tokio::sync::RwLock;
10use tracing::debug;
11use viewpoint_cdp::CdpConnection;
12use viewpoint_js::js;
13
14use tracing::instrument;
15
16use super::types::{
17    Cookie, IndexedDbDatabase, LocalStorageEntry, StorageOrigin, StorageState,
18};
19use super::BrowserContext;
20use crate::error::ContextError;
21
22// Re-export restore functions for external use
23pub use super::storage_restore::{restore_indexed_db, restore_local_storage};
24
25impl BrowserContext {
26    /// Get the storage state (cookies and localStorage).
27    ///
28    /// This method collects cookies and localStorage for all pages in the context.
29    /// For more advanced options including `IndexedDB`, use `storage_state_builder()`.
30    ///
31    /// # Example
32    ///
33    /// ```no_run
34    /// use viewpoint_core::BrowserContext;
35    ///
36    /// # async fn example(context: &BrowserContext) -> Result<(), Box<dyn std::error::Error>> {
37    /// let state = context.storage_state().await?;
38    /// state.save("auth.json").await?;
39    /// # Ok(())
40    /// # }
41    /// ```
42    ///
43    /// # Errors
44    ///
45    /// Returns an error if getting storage state fails.
46    #[instrument(level = "debug", skip(self))]
47    pub async fn storage_state(&self) -> Result<StorageState, ContextError> {
48        self.storage_state_builder().collect().await
49    }
50
51    /// Create a builder for collecting storage state with options.
52    ///
53    /// Use this method when you need to include `IndexedDB` data or configure
54    /// other collection options.
55    ///
56    /// # Example
57    ///
58    /// ```no_run
59    /// use viewpoint_core::BrowserContext;
60    ///
61    /// # async fn example(context: &BrowserContext) -> Result<(), Box<dyn std::error::Error>> {
62    /// // Include IndexedDB data
63    /// let state = context.storage_state_builder()
64    ///     .indexed_db(true)
65    ///     .collect()
66    ///     .await?;
67    ///
68    /// state.save("full-state.json").await?;
69    /// # Ok(())
70    /// # }
71    /// ```
72    pub fn storage_state_builder(&self) -> StorageStateBuilder<'_> {
73        StorageStateBuilder::new(
74            self.connection(),
75            self.context_id(),
76            &self.pages,
77        )
78    }
79}
80
/// Options for collecting storage state.
///
/// `Default::default()` and [`StorageStateOptions::new`] produce the same value:
/// `IndexedDB` collection disabled and a limit of 1000 entries per object store.
#[derive(Debug, Clone)]
pub struct StorageStateOptions {
    /// Include `IndexedDB` data in the snapshot.
    pub indexed_db: bool,
    /// Maximum entries per `IndexedDB` object store.
    /// Set to 0 for unlimited (default: 1000).
    pub indexed_db_max_entries: usize,
}

impl Default for StorageStateOptions {
    /// Same as [`StorageStateOptions::new`].
    ///
    /// Implemented by hand because a derived `Default` would set
    /// `indexed_db_max_entries` to 0, which means "unlimited" rather than the
    /// documented default of 1000.
    fn default() -> Self {
        Self::new()
    }
}

impl StorageStateOptions {
    /// Entry limit per `IndexedDB` object store applied when none is configured.
    const DEFAULT_INDEXED_DB_MAX_ENTRIES: usize = 1000;

    /// Create new default options.
    ///
    /// `IndexedDB` collection is off and the per-store entry limit is 1000.
    pub fn new() -> Self {
        Self {
            indexed_db: false,
            indexed_db_max_entries: Self::DEFAULT_INDEXED_DB_MAX_ENTRIES,
        }
    }

    /// Include `IndexedDB` data in the snapshot.
    #[must_use]
    pub fn indexed_db(mut self, include: bool) -> Self {
        self.indexed_db = include;
        self
    }

    /// Set maximum entries per `IndexedDB` object store.
    ///
    /// Pass 0 to collect without a limit.
    #[must_use]
    pub fn indexed_db_max_entries(mut self, max: usize) -> Self {
        self.indexed_db_max_entries = max;
        self
    }
}
114
/// Builder for collecting storage state with options.
///
/// Holds borrowed handles to the connection, context id and page list for the
/// lifetime `'a`, plus the owned options that the builder methods modify.
pub struct StorageStateBuilder<'a> {
    /// CDP connection used to send the collection commands.
    connection: &'a Arc<CdpConnection>,
    /// Browser context id passed to `Storage.getCookies`.
    context_id: &'a str,
    /// Shared page list; each page's session id is used for script evaluation.
    pages: &'a Arc<RwLock<Vec<super::PageInfo>>>,
    /// Collection options configured through the builder methods.
    options: StorageStateOptions,
}
122
123impl<'a> StorageStateBuilder<'a> {
124    pub(crate) fn new(
125        connection: &'a Arc<CdpConnection>,
126        context_id: &'a str,
127        pages: &'a Arc<RwLock<Vec<super::PageInfo>>>,
128    ) -> Self {
129        Self {
130            connection,
131            context_id,
132            pages,
133            options: StorageStateOptions::default(),
134        }
135    }
136
137    /// Include `IndexedDB` data in the storage state.
138    #[must_use]
139    pub fn indexed_db(mut self, include: bool) -> Self {
140        self.options.indexed_db = include;
141        self
142    }
143
144    /// Set maximum entries per `IndexedDB` object store.
145    #[must_use]
146    pub fn indexed_db_max_entries(mut self, max: usize) -> Self {
147        self.options.indexed_db_max_entries = max;
148        self
149    }
150
151    /// Collect the storage state.
152    ///
153    /// # Errors
154    ///
155    /// Returns an error if collecting storage state fails.
156    pub async fn collect(self) -> Result<StorageState, ContextError> {
157        // Collect cookies using the Storage domain
158        let cookies = self.collect_cookies().await?;
159
160        let mut origins: HashMap<String, StorageOrigin> = HashMap::new();
161
162        // Get all page sessions for evaluation
163        let pages = self.pages.read().await;
164        
165        for page in pages.iter() {
166            if page.session_id.is_empty() {
167                continue;
168            }
169
170            // Get the current page URL/origin
171            let origin = self.get_page_origin(&page.session_id).await?;
172            if origin.is_empty() || origin == "null" {
173                continue;
174            }
175
176            // Get localStorage for this page
177            let local_storage = self.collect_local_storage(&page.session_id).await?;
178
179            // Get IndexedDB if requested
180            let indexed_db = if self.options.indexed_db {
181                self.collect_indexed_db(&page.session_id).await?
182            } else {
183                Vec::new()
184            };
185
186            // Merge into origins map
187            let storage_origin = origins.entry(origin.clone()).or_insert_with(|| {
188                StorageOrigin::new(origin)
189            });
190            storage_origin.local_storage.extend(local_storage);
191            storage_origin.indexed_db.extend(indexed_db);
192        }
193
194        Ok(StorageState {
195            cookies,
196            origins: origins.into_values().collect(),
197        })
198    }
199
200    /// Collect cookies from the browser context.
201    async fn collect_cookies(&self) -> Result<Vec<Cookie>, ContextError> {
202        use viewpoint_cdp::protocol::storage::{GetCookiesParams, GetCookiesResult};
203        use super::types::SameSite;
204
205        let result: GetCookiesResult = self
206            .connection
207            .send_command(
208                "Storage.getCookies",
209                Some(GetCookiesParams::new().browser_context_id(self.context_id.to_string())),
210                None,
211            )
212            .await?;
213
214        let cookies = result
215            .cookies
216            .into_iter()
217            .map(|c| Cookie {
218                name: c.name,
219                value: c.value,
220                domain: Some(c.domain),
221                path: Some(c.path),
222                url: None,
223                expires: if c.expires > 0.0 {
224                    Some(c.expires)
225                } else {
226                    None
227                },
228                http_only: Some(c.http_only),
229                secure: Some(c.secure),
230                same_site: c.same_site.map(|s| match s {
231                    viewpoint_cdp::protocol::CookieSameSite::Strict => SameSite::Strict,
232                    viewpoint_cdp::protocol::CookieSameSite::Lax => SameSite::Lax,
233                    viewpoint_cdp::protocol::CookieSameSite::None => SameSite::None,
234                }),
235            })
236            .collect();
237
238        Ok(cookies)
239    }
240
241    /// Get the origin URL for a page.
242    async fn get_page_origin(&self, session_id: &str) -> Result<String, ContextError> {
243        let result: viewpoint_cdp::protocol::runtime::EvaluateResult = self
244            .connection
245            .send_command(
246                "Runtime.evaluate",
247                Some(viewpoint_cdp::protocol::runtime::EvaluateParams {
248                    expression: js!{ window.location.origin }.to_string(),
249                    object_group: None,
250                    include_command_line_api: None,
251                    silent: Some(true),
252                    context_id: None,
253                    return_by_value: Some(true),
254                    await_promise: Some(false),
255                }),
256                Some(session_id),
257            )
258            .await?;
259
260        Ok(result
261            .result
262            .value
263            .and_then(|v| v.as_str().map(String::from))
264            .unwrap_or_default())
265    }
266
267    /// Collect localStorage entries from a page.
268    async fn collect_local_storage(
269        &self,
270        session_id: &str,
271    ) -> Result<Vec<LocalStorageEntry>, ContextError> {
272        let js = r"
273            (function() {
274                const entries = [];
275                for (let i = 0; i < localStorage.length; i++) {
276                    const key = localStorage.key(i);
277                    if (key !== null) {
278                        entries.push({ name: key, value: localStorage.getItem(key) || '' });
279                    }
280                }
281                return entries;
282            })()
283        ";
284
285        let result: viewpoint_cdp::protocol::runtime::EvaluateResult = self
286            .connection
287            .send_command(
288                "Runtime.evaluate",
289                Some(viewpoint_cdp::protocol::runtime::EvaluateParams {
290                    expression: js.to_string(),
291                    object_group: None,
292                    include_command_line_api: None,
293                    silent: Some(true),
294                    context_id: None,
295                    return_by_value: Some(true),
296                    await_promise: Some(false),
297                }),
298                Some(session_id),
299            )
300            .await?;
301
302        if let Some(value) = result.result.value {
303            let entries: Vec<LocalStorageEntry> = serde_json::from_value(value).unwrap_or_default();
304            debug!("Collected {} localStorage entries", entries.len());
305            Ok(entries)
306        } else {
307            Ok(Vec::new())
308        }
309    }
310
311    /// Collect `IndexedDB` databases from a page.
312    async fn collect_indexed_db(
313        &self,
314        session_id: &str,
315    ) -> Result<Vec<IndexedDbDatabase>, ContextError> {
316        let max_entries = self.options.indexed_db_max_entries;
317        
318        // JavaScript to collect IndexedDB data
319        let js = format!(r"
320            (async function() {{
321                const maxEntries = {max_entries};
322                const databases = [];
323                
324                if (!window.indexedDB || !window.indexedDB.databases) {{
325                    return databases;
326                }}
327                
328                const dbList = await window.indexedDB.databases();
329                
330                for (const dbInfo of dbList) {{
331                    if (!dbInfo.name) continue;
332                    
333                    try {{
334                        const db = await new Promise((resolve, reject) => {{
335                            const request = indexedDB.open(dbInfo.name, dbInfo.version);
336                            request.onerror = () => reject(request.error);
337                            request.onsuccess = () => resolve(request.result);
338                        }});
339                        
340                        const dbData = {{
341                            name: dbInfo.name,
342                            version: db.version,
343                            stores: []
344                        }};
345                        
346                        for (const storeName of db.objectStoreNames) {{
347                            const tx = db.transaction(storeName, 'readonly');
348                            const store = tx.objectStore(storeName);
349                            
350                            const storeData = {{
351                                name: storeName,
352                                keyPath: store.keyPath ? (typeof store.keyPath === 'string' ? store.keyPath : store.keyPath.join(',')) : null,
353                                autoIncrement: store.autoIncrement,
354                                entries: [],
355                                indexes: []
356                            }};
357                            
358                            // Collect index definitions
359                            for (const indexName of store.indexNames) {{
360                                const index = store.index(indexName);
361                                storeData.indexes.push({{
362                                    name: index.name,
363                                    keyPath: typeof index.keyPath === 'string' ? index.keyPath : index.keyPath.join(','),
364                                    unique: index.unique,
365                                    multiEntry: index.multiEntry
366                                }});
367                            }}
368                            
369                            // Collect entries (limited)
370                            const entries = await new Promise((resolve, reject) => {{
371                                const entries = [];
372                                const request = store.openCursor();
373                                request.onerror = () => reject(request.error);
374                                request.onsuccess = (event) => {{
375                                    const cursor = event.target.result;
376                                    if (cursor && (maxEntries === 0 || entries.length < maxEntries)) {{
377                                        entries.push({{ key: cursor.key, value: cursor.value }});
378                                        cursor.continue();
379                                    }} else {{
380                                        resolve(entries);
381                                    }}
382                                }};
383                            }});
384                            
385                            storeData.entries = entries;
386                            dbData.stores.push(storeData);
387                        }}
388                        
389                        db.close();
390                        databases.push(dbData);
391                    }} catch (e) {{
392                        console.warn('Failed to read IndexedDB:', dbInfo.name, e);
393                    }}
394                }}
395                
396                return databases;
397            }})()
398        ");
399
400        let result: viewpoint_cdp::protocol::runtime::EvaluateResult = self
401            .connection
402            .send_command(
403                "Runtime.evaluate",
404                Some(viewpoint_cdp::protocol::runtime::EvaluateParams {
405                    expression: js,
406                    object_group: None,
407                    include_command_line_api: None,
408                    silent: Some(true),
409                    context_id: None,
410                    return_by_value: Some(true),
411                    await_promise: Some(true),
412                }),
413                Some(session_id),
414            )
415            .await?;
416
417        if let Some(value) = result.result.value {
418            let databases: Vec<IndexedDbDatabase> = serde_json::from_value(value).unwrap_or_default();
419            debug!("Collected {} IndexedDB databases", databases.len());
420            Ok(databases)
421        } else {
422            Ok(Vec::new())
423        }
424    }
425}
426
427#[cfg(test)]
428mod tests;