Skip to main content

spider/features/disk.rs

#[cfg(feature = "disk")]
use std::sync::atomic::{AtomicUsize, Ordering};

#[cfg(feature = "disk")]
use case_insensitive_string::CaseInsensitiveString;
#[cfg(feature = "disk")]
use hashbrown::HashSet;
#[cfg(feature = "disk")]
use sqlx::{sqlite::SqlitePool, Sqlite, Transaction};

#[cfg(feature = "disk")]
use crate::utils::emit_log;

#[cfg(feature = "disk")]
lazy_static! {
    /// Matcher for the characters that get replaced when a crawl id is turned into a
    /// database file name: `.`, `/`, `:`, `\`, `?`, `*`, `"`, `<`, `>` and `|`.
    static ref AC: aho_corasick::AhoCorasick = {
        let unsafe_chars = [".", "/", ":", "\\", "?", "*", "\"", "<", ">", "|"];
        aho_corasick::AhoCorasick::new(unsafe_chars).expect("valid replacer")
    };
    /// Replacement for each pattern of `AC`, index-aligned with it (all `_`).
    static ref AC_REPLACE: [&'static str; 10] = ["_"; 10];
}

#[cfg(feature = "disk")]
#[derive(Default, Debug, Clone)]
/// SQLite-backed store for crawled URLs and page signatures.
///
/// NOTE(review): the type derives `Clone`, while its `Drop` impl deletes the database file
/// of any non-persistent handler. Dropping one clone therefore removes the file the other
/// clones still refer to. `Default` yields `crawl_id: None`, which maps to the shared
/// default `spider.db` path.
pub struct DatabaseHandler {
    /// When `true`, the database file is kept on disk after the handler is dropped.
    pub persist: bool,
    /// Sanitized crawl id used to name the database file (`None` = default name).
    pub crawl_id: Option<String>,
    /// Connection pool, created on first use.
    pool: tokio::sync::OnceCell<SqlitePool>,
    /// Whether the initial seed has already run.
    pub seeded: bool,
}

#[cfg(not(feature = "disk"))]
#[derive(Default, Debug, Clone)]
/// Stand-in database handler used when the `disk` feature is disabled.
pub struct DatabaseHandler {
    /// Mirrors the `disk` handler's persist flag; nothing is stored without the feature.
    pub persist: bool,
}

44#[cfg(not(feature = "disk"))]
45impl DatabaseHandler {
46    /// A new DB handler.
47    pub fn new(_crawl_id: &Option<String>) -> Self {
48        Default::default()
49    }
50    /// Delete the db by id.
51    pub fn delete_db_by_id(&mut self) {}
52}
53
#[cfg(feature = "disk")]
impl DatabaseHandler {
    /// Create a handler for the crawl identified by `crawl_id`.
    ///
    /// A provided id has the characters matched by `AC` (`.`, `/`, `:`, ...) replaced with
    /// `_` and is suffixed with a process-local counter. Handlers created in the same
    /// process therefore get distinct database files. `None` selects the default name.
    pub fn new(crawl_id: &Option<String>) -> Self {
        Self {
            persist: false,
            pool: tokio::sync::OnceCell::const_new(),
            crawl_id: crawl_id.as_ref().map(|id| {
                let sanitized_id = AC.replace_all(id, &*AC_REPLACE);
                format!("{}_{}", sanitized_id, get_id())
            }),
            seeded: false,
        }
    }

    /// Determine if the pool is initialized.
    pub fn pool_inited(&self) -> bool {
        self.pool.initialized()
    }

    /// Determine if a seed was already done.
    pub fn ready(&self) -> bool {
        self.seeded
    }

    /// Set the seeded state.
    pub fn set_seeded(&mut self, seeded: bool) {
        self.seeded = seeded;
    }

    /// Set the persist state.
    pub fn set_persisted(&mut self, persist: bool) {
        self.persist = persist;
    }

    /// Build a lazily-connecting SQLite pool for this crawl and ensure the schema exists.
    ///
    /// The database file (and its parent directories) is created first. Schema creation
    /// failures are logged with `log::warn!` rather than returned.
    ///
    /// # Panics
    /// Panics if the derived database URL cannot be parsed by `SqlitePool::connect_lazy`.
    pub async fn generate_pool(&self) -> SqlitePool {
        let db_path = get_db_path(&self.crawl_id);
        // `SQLITE_DATABASE_URL` may already carry the `sqlite://` scheme; the local file
        // lives at whatever follows it.
        let is_url = db_path.starts_with("sqlite://");
        create_file_and_directory(db_path.strip_prefix("sqlite://").unwrap_or(db_path.as_str()))
            .await;

        let db_url = if is_url {
            db_path
        } else {
            format!("sqlite://{}", db_path)
        };

        let pool = SqlitePool::connect_lazy(&db_url).expect("Failed to connect to the database");

        // Index names are unique per SQLite schema. The tables previously both used
        // `idx_url`, so whichever DDL ran second silently skipped its index via
        // `IF NOT EXISTS`. Each index now has its own name.
        let create_resources_table = sqlx::query(
            r#"CREATE TABLE IF NOT EXISTS resources (
                            id INTEGER PRIMARY KEY,
                            url TEXT NOT NULL COLLATE NOCASE
                        );
                        CREATE INDEX IF NOT EXISTS idx_resources_url ON resources (url COLLATE NOCASE);"#,
        )
        .execute(&pool);

        let create_signatures_table = sqlx::query(
            r#"CREATE TABLE IF NOT EXISTS signatures (
                            id INTEGER PRIMARY KEY,
                            url INTEGER NOT NULL
                        );
                        CREATE INDEX IF NOT EXISTS idx_signatures_url ON signatures (url);"#,
        )
        .execute(&pool);

        // Run the schema queries concurrently.
        let (resources_result, signatures_result) =
            tokio::join!(create_resources_table, create_signatures_table);

        if let Err(e) = resources_result {
            log::warn!("SQLite error creating resources table: {:?}", e);
        }

        if let Err(e) = signatures_result {
            log::warn!("SQLite error creating signatures table: {:?}", e);
        }

        pool
    }

    /// Initialize the database pool if it has not been initialized yet.
    ///
    /// Kept under its original (misspelled) name for compatibility; prefer
    /// [`Self::initialize_pool`].
    pub async fn initlaize_pool(&self) {
        self.initialize_pool().await;
    }

    /// Initialize the database pool if it has not been initialized yet.
    ///
    /// Uses `OnceCell::get_or_init`, so concurrent callers wait for a single pool build
    /// instead of each generating one and discarding the losers.
    pub async fn initialize_pool(&self) {
        self.get_db_pool().await;
    }

    /// Set the pool directly; ignored if a pool is already set.
    pub async fn set_pool(&self, pool: SqlitePool) {
        let _ = self.pool.set(pool);
    }

    /// Get or initialize the database pool.
    pub async fn get_db_pool(&self) -> &SqlitePool {
        self.pool.get_or_init(|| self.generate_pool()).await
    }

    /// Report a failed query through `emit_log`, preferring the database's own message.
    fn log_query_error(error: &sqlx::Error) {
        match error.as_database_error() {
            Some(db_err) => emit_log(db_err.message()),
            None => emit_log(&format!("A non-database error occurred: {:?}", error)),
        }
    }

    /// Check if a URL exists (ignore case: `resources.url` is `COLLATE NOCASE`).
    ///
    /// Query errors are logged and reported as "not present".
    pub async fn url_exists(&self, pool: &SqlitePool, url_to_check: &str) -> bool {
        let lookup = sqlx::query("SELECT 1 FROM resources WHERE url = ? LIMIT 1")
            .bind(url_to_check)
            .fetch_optional(pool)
            .await;

        match lookup {
            Ok(row) => row.is_some(),
            Err(e) => {
                Self::log_query_error(&e);
                false
            }
        }
    }

    /// Check if a signature has been stored.
    ///
    /// Query errors are logged and reported as "not present".
    ///
    /// NOTE(review): signatures are bound as decimal text into an `INTEGER` column. SQLite's
    /// type affinity stores values above `i64::MAX` as REAL, which can lose precision and
    /// make distinct large signatures compare equal. Changing the encoding would break
    /// lookups in already-persisted databases, so it is left as is here.
    pub async fn signature_exists(&self, pool: &SqlitePool, signature_to_check: u64) -> bool {
        let lookup = sqlx::query("SELECT 1 FROM signatures WHERE url = ? LIMIT 1")
            .bind(signature_to_check.to_string())
            .fetch_optional(pool)
            .await;

        match lookup {
            Ok(row) => row.is_some(),
            Err(e) => {
                Self::log_query_error(&e);
                false
            }
        }
    }

    /// Insert a new URL if it doesn't exist; errors are logged.
    ///
    /// NOTE: the existence check and the insert are separate statements and
    /// `resources.url` has no UNIQUE constraint, so concurrent callers may insert duplicates.
    pub async fn insert_url(&self, pool: &SqlitePool, new_url: &str) {
        if self.url_exists(pool, new_url).await {
            return;
        }

        if let Err(e) = sqlx::query("INSERT INTO resources (url) VALUES (?)")
            .bind(new_url)
            .execute(pool)
            .await
        {
            Self::log_query_error(&e);
        }
    }

    /// Insert a new signature if it doesn't exist; errors are logged.
    pub async fn insert_signature(&self, pool: &SqlitePool, new_signature: u64) {
        if self.signature_exists(pool, new_signature).await {
            return;
        }

        if let Err(e) = sqlx::query("INSERT INTO signatures (url) VALUES (?)")
            .bind(new_signature.to_string())
            .execute(pool)
            .await
        {
            Self::log_query_error(&e);
        }
    }

    /// Write `urls` into `resources` inside `tx`, in order, as multi-row
    /// `INSERT OR IGNORE` statements of at most 500 rows each.
    async fn insert_url_chunks(
        tx: &mut Transaction<'_, Sqlite>,
        urls: &[&CaseInsensitiveString],
    ) -> Result<(), sqlx::Error> {
        // Keeps each statement well below SQLite's bound-parameter limit.
        const CHUNK_SIZE: usize = 500;

        for chunk in urls.chunks(CHUNK_SIZE) {
            let placeholders = vec!["(?)"; chunk.len()].join(", ");
            let query = format!("INSERT OR IGNORE INTO resources (url) VALUES {}", placeholders);
            let mut statement = sqlx::query(&query);

            for url in chunk {
                statement = statement.bind(url.to_string());
            }

            statement.execute(&mut **tx).await?;
        }

        Ok(())
    }

    /// Seed the database with `urls` in a single transaction.
    ///
    /// Up to 100 URLs are split off and returned to the caller; every URL (kept ones first,
    /// then the rest) is written to `resources`.
    ///
    /// NOTE: `resources.url` has no UNIQUE constraint, so `OR IGNORE` does not deduplicate
    /// against rows already in the table.
    ///
    /// # Errors
    /// Returns the first database error; the transaction is then rolled back on drop.
    pub async fn seed(
        &self,
        pool: &SqlitePool,
        mut urls: HashSet<CaseInsensitiveString>,
    ) -> Result<HashSet<CaseInsensitiveString>, sqlx::Error> {
        const KEEP_COUNT: usize = 100;

        let mut tx: Transaction<'_, Sqlite> = pool.begin().await?;
        let mut keep_urls = HashSet::with_capacity(KEEP_COUNT);

        for url in urls.iter().take(KEEP_COUNT) {
            keep_urls.insert(url.clone());
        }

        urls.retain(|url| !keep_urls.contains(url));

        let ordered: Vec<&CaseInsensitiveString> = keep_urls.iter().chain(urls.iter()).collect();
        Self::insert_url_chunks(&mut tx, &ordered).await?;

        tx.commit().await?;

        Ok(keep_urls)
    }

    /// Count the records stored in `resources`.
    pub async fn count_records(pool: &SqlitePool) -> Result<u64, sqlx::Error> {
        let result = sqlx::query_scalar::<_, u64>("SELECT COUNT(*) FROM resources")
            .fetch_one(pool)
            .await?;
        Ok(result)
    }

    /// Get all the URLs stored in `resources`, loaded in a single fetch.
    pub async fn get_all_resources(
        pool: &SqlitePool,
    ) -> Result<HashSet<CaseInsensitiveString>, sqlx::Error> {
        use sqlx::Row;
        let rows = sqlx::query("SELECT url FROM resources")
            .fetch_all(pool)
            .await?;

        let urls = rows
            .into_iter()
            .map(|row| row.get::<String, _>("url").into())
            .collect();

        Ok(urls)
    }

    /// Delete this crawl's database file (best effort; errors are ignored).
    ///
    /// A `sqlite://` prefix is stripped first, the same way `generate_pool` derives the file
    /// it creates. Previously the prefixed path was passed to `remove_file` and the file
    /// leaked. Any `-wal`, `-shm` or `-journal` companion files SQLite may have left next to
    /// the database are removed as well.
    pub fn delete_db_by_id(&self) {
        let db_path = get_db_path(&self.crawl_id);
        let file_path = db_path.strip_prefix("sqlite://").unwrap_or(db_path.as_str());

        let _ = std::fs::remove_file(file_path);
        for suffix in ["-wal", "-shm", "-journal"] {
            let _ = std::fs::remove_file(format!("{}{}", file_path, suffix));
        }
    }

    /// Delete every row from the `resources` and `signatures` tables.
    ///
    /// # Errors
    /// Both deletes always run; the first failure (checked in that order) is returned.
    /// Previously errors were discarded and `Ok(())` was always returned.
    pub async fn clear_table(pool: &SqlitePool) -> Result<(), sqlx::Error> {
        let (resources, signatures) = tokio::join!(
            sqlx::query("DELETE FROM resources").execute(pool),
            sqlx::query("DELETE FROM signatures").execute(pool)
        );
        resources?;
        signatures?;
        Ok(())
    }
}

#[cfg(feature = "disk")]
impl Drop for DatabaseHandler {
    /// Remove the backing database unless the handler was marked as persistent.
    ///
    /// `DatabaseHandler` derives `Clone`, and every clone runs this on drop. Dropping any
    /// non-persistent clone therefore deletes the file the remaining clones point at.
    fn drop(&mut self) {
        if self.persist {
            return;
        }
        self.delete_db_by_id();
    }
}

/// Hand out process-wide sequential ids.
///
/// Ids start at 1. After `usize::MAX` the counter wraps back to 1, so 0 is never returned.
#[cfg(feature = "disk")]
fn get_id() -> usize {
    static NEXT_ID: AtomicUsize = AtomicUsize::new(1);

    // The closure always returns `Some`, so `fetch_update` always yields `Ok(previous)`.
    NEXT_ID
        .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |id| {
            Some(id.checked_add(1).unwrap_or(1))
        })
        .unwrap_or_else(|previous| previous)
}

/// Resolve where the database for `crawl_id` lives.
///
/// The base location is taken from `SQLITE_DATABASE_URL`, falling back to the system temp
/// directory. Trailing `/` characters are trimmed from the base. The file is named
/// `spider_<id>.db` (with `.` in the id replaced by `_`), or `spider.db` when no id is
/// given. It is joined with `/`, except for bases starting with `sqlite://memory:`, which
/// use `:`.
pub fn get_db_path(crawl_id: &Option<String>) -> String {
    let base_url = match std::env::var("SQLITE_DATABASE_URL") {
        Ok(url) => url,
        Err(_) => std::env::temp_dir().to_string_lossy().into_owned(),
    };

    let separator = if base_url.starts_with("sqlite://memory:") {
        ":"
    } else {
        "/"
    };

    let file_name = match crawl_id {
        Some(id) => format!("spider_{}.db", id.replace('.', "_")),
        None => String::from("spider.db"),
    };

    format!("{}{}{}", base_url.trim_end_matches('/'), separator, file_name)
}

/// Make sure `file_path` exists locally.
///
/// Creates the parent directories, then an empty file if nothing exists at the path yet.
/// Failures from either step are ignored (best effort).
#[cfg(feature = "disk")]
async fn create_file_and_directory(file_path: &str) {
    let target = std::path::Path::new(file_path);

    if let Some(dir) = target.parent() {
        let _ = crate::utils::uring_fs::create_dir_all(dir.display().to_string()).await;
    }

    if target.exists() {
        return;
    }

    let _ = crate::utils::uring_fs::write_file(target.display().to_string(), Vec::new()).await;
}

#[cfg(test)]
#[cfg(feature = "disk")]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_connect_db() {
        let handler = DatabaseHandler::new(&Some("example.com".into()));
        let test_url = CaseInsensitiveString::new("http://example.com");
        let pool = handler.get_db_pool().await;

        if handler.url_exists(pool, &test_url).await {
            println!("URL '{}' already exists in the database.", test_url);
        } else {
            handler.insert_url(pool, &test_url).await;
            println!("URL '{}' was inserted into the database.", test_url);
        }

        assert!(
            handler.url_exists(pool, &test_url).await,
            "URL should exist after insertion."
        );
    }

    #[tokio::test]
    async fn test_url_insert_and_exists() {
        let handler = DatabaseHandler::new(&Some("example.com".into()));
        let new_url = CaseInsensitiveString::new("http://new-example.com");
        let pool = handler.get_db_pool().await;

        assert!(
            !handler.url_exists(pool, &new_url).await,
            "URL should not exist initially."
        );

        handler.insert_url(pool, &new_url).await;
        assert!(
            handler.url_exists(pool, &new_url).await,
            "URL should exist after insertion."
        );
    }

    #[tokio::test]
    async fn test_url_case_insensitivity() {
        let handler = DatabaseHandler::new(&Some("case-test.com".into()));
        let url1 = CaseInsensitiveString::new("http://case-test.com");
        let url2 = CaseInsensitiveString::new("http://CASE-TEST.com");
        let pool = handler.get_db_pool().await;

        handler.insert_url(pool, &url1).await;
        assert!(
            handler.url_exists(pool, &url2).await,
            "URL check should be case-insensitive."
        );
    }

    #[tokio::test]
    async fn test_seed_urls() {
        let handler = DatabaseHandler::new(&Some("example.com".into()));
        let mut urls = HashSet::new();
        urls.insert(CaseInsensitiveString::new("http://foo.com"));
        urls.insert(CaseInsensitiveString::new("http://bar.com"));
        let pool = handler.get_db_pool().await;

        handler
            .seed(pool, urls.clone())
            .await
            .expect("Seeding failed");

        for url in urls {
            assert!(
                handler.url_exists(pool, &url).await,
                "Seeded URL should exist after seeding."
            );
        }
    }

    #[tokio::test]
    async fn test_signature_insert_and_exists() {
        let handler = DatabaseHandler::new(&Some("signature-test.com".into()));
        let pool = handler.get_db_pool().await;
        let signature: u64 = 0xDEAD_BEEF;

        assert!(
            !handler.signature_exists(pool, signature).await,
            "Signature should not exist initially."
        );

        handler.insert_signature(pool, signature).await;
        assert!(
            handler.signature_exists(pool, signature).await,
            "Signature should exist after insertion."
        );
    }

    #[tokio::test]
    async fn test_clear_table() {
        let handler = DatabaseHandler::new(&Some("clear-test.com".into()));
        let url = CaseInsensitiveString::new("http://clear-test.com");
        let pool = handler.get_db_pool().await;

        handler.insert_url(pool, &url).await;
        handler.insert_signature(pool, 7).await;

        DatabaseHandler::clear_table(pool)
            .await
            .expect("Clearing tables failed");

        assert!(
            !handler.url_exists(pool, &url).await,
            "URL should be gone after clearing."
        );
        assert!(
            !handler.signature_exists(pool, 7).await,
            "Signature should be gone after clearing."
        );
    }
}