viewpoint_core/context/storage/
mod.rs1use std::collections::HashMap;
7use std::sync::Arc;
8
9use tokio::sync::RwLock;
10use tracing::debug;
11use viewpoint_cdp::CdpConnection;
12use viewpoint_js::js;
13
14use tracing::instrument;
15
16use super::BrowserContext;
17use super::types::{Cookie, IndexedDbDatabase, LocalStorageEntry, StorageOrigin, StorageState};
18use crate::error::ContextError;
19
20pub use super::storage_restore::{restore_indexed_db, restore_local_storage};
22
23impl BrowserContext {
24 #[instrument(level = "debug", skip(self))]
45 pub async fn storage_state(&self) -> Result<StorageState, ContextError> {
46 self.storage_state_builder().collect().await
47 }
48
49 pub fn storage_state_builder(&self) -> StorageStateBuilder<'_> {
71 StorageStateBuilder::new(self.connection(), self.context_id(), &self.pages)
72 }
73}
74
/// Options controlling what a storage-state snapshot collects beyond cookies and localStorage.
///
/// `Default` is implemented by hand (not derived) so that `StorageStateOptions::default()`
/// and `StorageStateOptions::new()` agree. A derived `Default` would set
/// `indexed_db_max_entries` to `0`, which the IndexedDB collector interprets as "no limit".
#[derive(Debug, Clone)]
pub struct StorageStateOptions {
    /// Whether IndexedDB databases are included in the snapshot (off by default).
    pub indexed_db: bool,
    /// Maximum number of entries read per IndexedDB object store; `0` means no limit.
    pub indexed_db_max_entries: usize,
}

impl StorageStateOptions {
    /// Per-object-store entry cap used when none is configured explicitly.
    const DEFAULT_INDEXED_DB_MAX_ENTRIES: usize = 1000;

    /// Creates options with IndexedDB disabled and a 1000-entry cap per object store.
    #[must_use]
    pub fn new() -> Self {
        Self {
            indexed_db: false,
            indexed_db_max_entries: Self::DEFAULT_INDEXED_DB_MAX_ENTRIES,
        }
    }

    /// Sets whether IndexedDB databases should be collected.
    #[must_use]
    pub fn indexed_db(mut self, include: bool) -> Self {
        self.indexed_db = include;
        self
    }

    /// Sets the maximum number of entries read per IndexedDB object store (`0` = unlimited).
    #[must_use]
    pub fn indexed_db_max_entries(mut self, max: usize) -> Self {
        self.indexed_db_max_entries = max;
        self
    }
}

impl Default for StorageStateOptions {
    /// Same as [`StorageStateOptions::new`].
    fn default() -> Self {
        Self::new()
    }
}
108
109pub struct StorageStateBuilder<'a> {
111 connection: &'a Arc<CdpConnection>,
112 context_id: &'a str,
113 pages: &'a Arc<RwLock<Vec<super::PageInfo>>>,
114 options: StorageStateOptions,
115}
116
117impl<'a> StorageStateBuilder<'a> {
118 pub(crate) fn new(
119 connection: &'a Arc<CdpConnection>,
120 context_id: &'a str,
121 pages: &'a Arc<RwLock<Vec<super::PageInfo>>>,
122 ) -> Self {
123 Self {
124 connection,
125 context_id,
126 pages,
127 options: StorageStateOptions::default(),
128 }
129 }
130
131 #[must_use]
133 pub fn indexed_db(mut self, include: bool) -> Self {
134 self.options.indexed_db = include;
135 self
136 }
137
138 #[must_use]
140 pub fn indexed_db_max_entries(mut self, max: usize) -> Self {
141 self.options.indexed_db_max_entries = max;
142 self
143 }
144
145 pub async fn collect(self) -> Result<StorageState, ContextError> {
151 let cookies = self.collect_cookies().await?;
153
154 let mut origins: HashMap<String, StorageOrigin> = HashMap::new();
155
156 let pages = self.pages.read().await;
158
159 for page in pages.iter() {
160 if page.session_id.is_empty() {
161 continue;
162 }
163
164 let origin = self.get_page_origin(&page.session_id).await?;
166 if origin.is_empty() || origin == "null" {
167 continue;
168 }
169
170 let local_storage = self.collect_local_storage(&page.session_id).await?;
172
173 let indexed_db = if self.options.indexed_db {
175 self.collect_indexed_db(&page.session_id).await?
176 } else {
177 Vec::new()
178 };
179
180 let storage_origin = origins
182 .entry(origin.clone())
183 .or_insert_with(|| StorageOrigin::new(origin));
184 storage_origin.local_storage.extend(local_storage);
185 storage_origin.indexed_db.extend(indexed_db);
186 }
187
188 Ok(StorageState {
189 cookies,
190 origins: origins.into_values().collect(),
191 })
192 }
193
194 async fn collect_cookies(&self) -> Result<Vec<Cookie>, ContextError> {
196 use super::types::SameSite;
197 use viewpoint_cdp::protocol::storage::{GetCookiesParams, GetCookiesResult};
198
199 let result: GetCookiesResult = self
200 .connection
201 .send_command(
202 "Storage.getCookies",
203 Some(GetCookiesParams::new().browser_context_id(self.context_id.to_string())),
204 None,
205 )
206 .await?;
207
208 let cookies = result
209 .cookies
210 .into_iter()
211 .map(|c| Cookie {
212 name: c.name,
213 value: c.value,
214 domain: Some(c.domain),
215 path: Some(c.path),
216 url: None,
217 expires: if c.expires > 0.0 {
218 Some(c.expires)
219 } else {
220 None
221 },
222 http_only: Some(c.http_only),
223 secure: Some(c.secure),
224 same_site: c.same_site.map(|s| match s {
225 viewpoint_cdp::protocol::CookieSameSite::Strict => SameSite::Strict,
226 viewpoint_cdp::protocol::CookieSameSite::Lax => SameSite::Lax,
227 viewpoint_cdp::protocol::CookieSameSite::None => SameSite::None,
228 }),
229 })
230 .collect();
231
232 Ok(cookies)
233 }
234
235 async fn get_page_origin(&self, session_id: &str) -> Result<String, ContextError> {
237 let result: viewpoint_cdp::protocol::runtime::EvaluateResult = self
238 .connection
239 .send_command(
240 "Runtime.evaluate",
241 Some(viewpoint_cdp::protocol::runtime::EvaluateParams {
242 expression: js! { window.location.origin }.to_string(),
243 object_group: None,
244 include_command_line_api: None,
245 silent: Some(true),
246 context_id: None,
247 return_by_value: Some(true),
248 await_promise: Some(false),
249 }),
250 Some(session_id),
251 )
252 .await?;
253
254 Ok(result
255 .result
256 .value
257 .and_then(|v| v.as_str().map(String::from))
258 .unwrap_or_default())
259 }
260
261 async fn collect_local_storage(
263 &self,
264 session_id: &str,
265 ) -> Result<Vec<LocalStorageEntry>, ContextError> {
266 let js = r"
267 (function() {
268 const entries = [];
269 for (let i = 0; i < localStorage.length; i++) {
270 const key = localStorage.key(i);
271 if (key !== null) {
272 entries.push({ name: key, value: localStorage.getItem(key) || '' });
273 }
274 }
275 return entries;
276 })()
277 ";
278
279 let result: viewpoint_cdp::protocol::runtime::EvaluateResult = self
280 .connection
281 .send_command(
282 "Runtime.evaluate",
283 Some(viewpoint_cdp::protocol::runtime::EvaluateParams {
284 expression: js.to_string(),
285 object_group: None,
286 include_command_line_api: None,
287 silent: Some(true),
288 context_id: None,
289 return_by_value: Some(true),
290 await_promise: Some(false),
291 }),
292 Some(session_id),
293 )
294 .await?;
295
296 if let Some(value) = result.result.value {
297 let entries: Vec<LocalStorageEntry> = serde_json::from_value(value).unwrap_or_default();
298 debug!("Collected {} localStorage entries", entries.len());
299 Ok(entries)
300 } else {
301 Ok(Vec::new())
302 }
303 }
304
305 async fn collect_indexed_db(
307 &self,
308 session_id: &str,
309 ) -> Result<Vec<IndexedDbDatabase>, ContextError> {
310 let max_entries = self.options.indexed_db_max_entries;
311
312 let js = format!(
314 r"
315 (async function() {{
316 const maxEntries = {max_entries};
317 const databases = [];
318
319 if (!window.indexedDB || !window.indexedDB.databases) {{
320 return databases;
321 }}
322
323 const dbList = await window.indexedDB.databases();
324
325 for (const dbInfo of dbList) {{
326 if (!dbInfo.name) continue;
327
328 try {{
329 const db = await new Promise((resolve, reject) => {{
330 const request = indexedDB.open(dbInfo.name, dbInfo.version);
331 request.onerror = () => reject(request.error);
332 request.onsuccess = () => resolve(request.result);
333 }});
334
335 const dbData = {{
336 name: dbInfo.name,
337 version: db.version,
338 stores: []
339 }};
340
341 for (const storeName of db.objectStoreNames) {{
342 const tx = db.transaction(storeName, 'readonly');
343 const store = tx.objectStore(storeName);
344
345 const storeData = {{
346 name: storeName,
347 keyPath: store.keyPath ? (typeof store.keyPath === 'string' ? store.keyPath : store.keyPath.join(',')) : null,
348 autoIncrement: store.autoIncrement,
349 entries: [],
350 indexes: []
351 }};
352
353 // Collect index definitions
354 for (const indexName of store.indexNames) {{
355 const index = store.index(indexName);
356 storeData.indexes.push({{
357 name: index.name,
358 keyPath: typeof index.keyPath === 'string' ? index.keyPath : index.keyPath.join(','),
359 unique: index.unique,
360 multiEntry: index.multiEntry
361 }});
362 }}
363
364 // Collect entries (limited)
365 const entries = await new Promise((resolve, reject) => {{
366 const entries = [];
367 const request = store.openCursor();
368 request.onerror = () => reject(request.error);
369 request.onsuccess = (event) => {{
370 const cursor = event.target.result;
371 if (cursor && (maxEntries === 0 || entries.length < maxEntries)) {{
372 entries.push({{ key: cursor.key, value: cursor.value }});
373 cursor.continue();
374 }} else {{
375 resolve(entries);
376 }}
377 }};
378 }});
379
380 storeData.entries = entries;
381 dbData.stores.push(storeData);
382 }}
383
384 db.close();
385 databases.push(dbData);
386 }} catch (e) {{
387 console.warn('Failed to read IndexedDB:', dbInfo.name, e);
388 }}
389 }}
390
391 return databases;
392 }})()
393 "
394 );
395
396 let result: viewpoint_cdp::protocol::runtime::EvaluateResult = self
397 .connection
398 .send_command(
399 "Runtime.evaluate",
400 Some(viewpoint_cdp::protocol::runtime::EvaluateParams {
401 expression: js,
402 object_group: None,
403 include_command_line_api: None,
404 silent: Some(true),
405 context_id: None,
406 return_by_value: Some(true),
407 await_promise: Some(true),
408 }),
409 Some(session_id),
410 )
411 .await?;
412
413 if let Some(value) = result.result.value {
414 let databases: Vec<IndexedDbDatabase> =
415 serde_json::from_value(value).unwrap_or_default();
416 debug!("Collected {} IndexedDB databases", databases.len());
417 Ok(databases)
418 } else {
419 Ok(Vec::new())
420 }
421 }
422}
423
424#[cfg(test)]
425mod tests;