spider_core/
checkpoint.rs1use spider_util::error::SpiderError;
9use spider_util::item::ScrapedItem;
10use spider_pipeline::pipeline::Pipeline;
11use spider_util::request::Request;
12use crate::spider::Spider;
13use dashmap::DashSet;
14use rmp_serde;
15use serde::{Deserialize, Serialize};
16use serde_json::Value;
17use std::collections::{HashMap, VecDeque};
18use std::fs;
19use std::path::Path;
20use std::sync::Arc;
21use tracing::{info, warn};
22
23use tokio::sync::RwLock;
24
25#[derive(Serialize, Deserialize, Default, Clone, Debug)]
27pub struct SchedulerCheckpoint {
28 pub request_queue: VecDeque<Request>,
30 pub salvaged_requests: VecDeque<Request>,
32 pub visited_urls: DashSet<String>,
34}
35
36#[derive(Debug, Serialize, Deserialize, Default)]
38pub struct Checkpoint {
39 pub scheduler: SchedulerCheckpoint,
41 pub pipelines: HashMap<String, Value>,
43 #[serde(default)]
45 pub cookie_store: cookie_store::CookieStore,
46}
47
48pub async fn save_checkpoint<S: Spider>(
49 path: &Path,
50 scheduler_checkpoint: SchedulerCheckpoint,
51 pipelines: &Arc<Vec<Box<dyn Pipeline<S::Item>>>>,
52 cookie_store: &Arc<RwLock<cookie_store::CookieStore>>,
53) -> Result<(), SpiderError>
54where
55 S::Item: ScrapedItem,
56{
57 info!("Saving checkpoint to {:?}", path);
58
59 let mut pipelines_checkpoint_map = HashMap::new();
60 for pipeline in pipelines.iter() {
61 if let Some(state) = pipeline.get_state().await? {
62 pipelines_checkpoint_map.insert(pipeline.name().to_string(), state);
63 }
64 }
65
66 if !scheduler_checkpoint.salvaged_requests.is_empty() {
67 warn!(
68 "Found {} salvaged requests during checkpoint. These have been added to the request queue.",
69 scheduler_checkpoint.salvaged_requests.len()
70 );
71 }
72
73 let checkpoint = {
74 let cookie_store_read = cookie_store.read().await;
75 let cookie_store_clone = (*cookie_store_read).clone();
76 drop(cookie_store_read);
77 Checkpoint {
78 scheduler: scheduler_checkpoint,
79 pipelines: pipelines_checkpoint_map,
80 cookie_store: cookie_store_clone,
81 }
82 };
83
84 let tmp_path = path.with_extension("tmp");
85 let encoded = rmp_serde::to_vec(&checkpoint)
86 .map_err(|e| SpiderError::GeneralError(format!("Failed to serialize checkpoint: {}", e)))?;
87 fs::write(&tmp_path, encoded).map_err(|e| {
88 SpiderError::GeneralError(format!(
89 "Failed to write checkpoint to temporary file: {}",
90 e
91 ))
92 })?;
93 fs::rename(&tmp_path, path).map_err(|e| {
94 SpiderError::GeneralError(format!("Failed to rename temporary checkpoint file: {}", e))
95 })?;
96
97 info!("Checkpoint saved successfully.");
98 Ok(())
99}