1#[cfg(feature = "disk")]
2use case_insensitive_string::CaseInsensitiveString;
3#[cfg(feature = "disk")]
4use hashbrown::HashSet;
5#[cfg(feature = "disk")]
6use std::sync::atomic::{AtomicUsize, Ordering};
7
8#[cfg(feature = "disk")]
9use crate::utils::emit_log;
10#[cfg(feature = "disk")]
11use sqlx::{sqlite::SqlitePool, Sqlite, Transaction};
12
#[cfg(feature = "disk")]
lazy_static! {
    /// Matcher for characters that are unsafe in file names or SQLite URLs.
    static ref AC: aho_corasick::AhoCorasick = {
        let unsafe_chars = [".", "/", ":", "\\", "?", "*", "\"", "<", ">", "|"];
        aho_corasick::AhoCorasick::new(&unsafe_chars).expect("valid replacer")
    };
    /// Replacement for each pattern in `AC`, index for index: every match becomes `_`.
    static ref AC_REPLACE: [&'static str; 10] = ["_"; 10];
}
21
/// Handle to the SQLite database backing a crawl: visited URLs are stored in the
/// `resources` table and content signatures in the `signatures` table.
///
/// NOTE(review): this type derives `Clone`, and its `Drop` impl deletes the database
/// file whenever `persist` is `false`. Dropping any clone therefore removes the file
/// that the remaining copies still use.
#[derive(Default, Debug, Clone)]
#[cfg(feature = "disk")]
pub struct DatabaseHandler {
    /// Keep the database file when the handler is dropped; when `false`, `Drop` deletes it.
    pub persist: bool,
    /// Sanitized crawl id plus a unique numeric suffix, used to build the database file
    /// name. `None` selects the shared default `spider.db` file.
    pub crawl_id: Option<String>,
    /// Connection pool, created lazily on first use.
    pool: tokio::sync::OnceCell<SqlitePool>,
    /// Flag reported by `ready()`. NOTE(review): presumably set once the initial `seed`
    /// has populated the database — confirm with callers.
    pub seeded: bool,
}
35
/// Placeholder database handler used when the `disk` feature is disabled.
///
/// It holds no database state; only the `persist` flag is kept.
#[derive(Default, Debug, Clone)]
#[cfg(not(feature = "disk"))]
pub struct DatabaseHandler {
    /// Whether the database should be kept after use. Nothing in this build's `impl`
    /// reads it. NOTE(review): presumably kept so callers can set it the same way with
    /// or without the `disk` feature — confirm.
    pub persist: bool,
}
43
44#[cfg(not(feature = "disk"))]
45impl DatabaseHandler {
46 pub fn new(_crawl_id: &Option<String>) -> Self {
48 Default::default()
49 }
50 pub fn delete_db_by_id(&mut self) {}
52}
53
#[cfg(feature = "disk")]
impl DatabaseHandler {
    /// Create a handler for the crawl identified by `crawl_id`.
    ///
    /// Characters that are unsafe in file names or SQLite URLs (`.`, `/`, `:`, `\`, `?`,
    /// `*`, `"`, `<`, `>`, `|`) are replaced with `_`. A process-wide counter is then
    /// appended, so handlers created for the same id get distinct database files. No
    /// connection is opened here; the pool is created on first use.
    pub fn new(crawl_id: &Option<String>) -> Self {
        let crawl_id = crawl_id.as_ref().map(|id| {
            let sanitized_id = AC.replace_all(id, &*AC_REPLACE);
            format!("{}_{}", sanitized_id, get_id())
        });

        Self {
            persist: false,
            pool: tokio::sync::OnceCell::const_new(),
            crawl_id,
            seeded: false,
        }
    }

    /// Whether the connection pool has been created yet.
    pub fn pool_inited(&self) -> bool {
        self.pool.initialized()
    }

    /// Whether the handler has been marked as seeded via [`Self::set_seeded`].
    pub fn ready(&self) -> bool {
        self.seeded
    }

    /// Mark the handler as seeded (or not).
    pub fn set_seeded(&mut self, seeded: bool) {
        self.seeded = seeded;
    }

    /// Keep (`true`) or delete (`false`) the database file when the handler is dropped.
    pub fn set_persisted(&mut self, persist: bool) {
        self.persist = persist;
    }

    /// Build a connection pool for this handler's database and ensure the schema exists.
    ///
    /// Schema errors are logged with `log::warn!` rather than returned.
    ///
    /// # Panics
    ///
    /// Panics if sqlx rejects the database URL.
    pub async fn generate_pool(&self) -> SqlitePool {
        let db_path = get_db_path(&self.crawl_id);

        // `get_db_path` yields a `sqlite://` URL when SQLITE_DATABASE_URL is one. The file
        // system needs the bare path, while sqlx needs the URL form.
        let (file_path, db_url) = match db_path.strip_prefix("sqlite://") {
            Some(stripped) => (stripped.to_owned(), db_path.clone()),
            None => (db_path.clone(), format!("sqlite://{}", db_path)),
        };

        create_file_and_directory(&file_path).await;

        let pool = SqlitePool::connect_lazy(&db_url).expect("Failed to connect to the database");

        // Index names are global to an SQLite schema. If both tables share one index name,
        // whichever `CREATE INDEX IF NOT EXISTS` runs second is silently skipped and that
        // table stays unindexed, so each table gets its own name. The indexes are UNIQUE so
        // `INSERT OR IGNORE` really skips duplicates. On an older database that already
        // holds duplicates, the index creation fails and is logged, and everything keeps
        // working without it. The statements run one after another so the schema changes
        // do not contend for SQLite's single write lock.
        let resources_result = sqlx::query(
            r#"CREATE TABLE IF NOT EXISTS resources (
                id INTEGER PRIMARY KEY,
                url TEXT NOT NULL COLLATE NOCASE
            );
            CREATE UNIQUE INDEX IF NOT EXISTS idx_resources_url ON resources (url COLLATE NOCASE);"#,
        )
        .execute(&pool)
        .await;

        if let Err(e) = resources_result {
            log::warn!("SQLite error creating resources table: {:?}", e);
        }

        let signatures_result = sqlx::query(
            r#"CREATE TABLE IF NOT EXISTS signatures (
                id INTEGER PRIMARY KEY,
                url INTEGER NOT NULL
            );
            CREATE UNIQUE INDEX IF NOT EXISTS idx_signatures_url ON signatures (url);"#,
        )
        .execute(&pool)
        .await;

        if let Err(e) = signatures_result {
            log::warn!("SQLite error creating signatures table: {:?}", e);
        }

        pool
    }

    /// Create the pool if it does not exist yet.
    ///
    /// Goes through `get_or_init`, so concurrent callers share a single pool instead of
    /// each building one and discarding all but the first.
    pub async fn initlaize_pool(&self) {
        let _ = self.pool.get_or_init(|| self.generate_pool()).await;
    }

    /// Install an externally created pool. Has no effect if a pool is already set.
    pub async fn set_pool(&self, pool: SqlitePool) {
        let _ = self.pool.set(pool);
    }

    /// Return the pool, creating it on first use.
    pub async fn get_db_pool(&self) -> &SqlitePool {
        self.pool.get_or_init(|| self.generate_pool()).await
    }

    /// Report a sqlx error through `emit_log`, preferring the database's own message.
    fn log_sqlx_error(error: &sqlx::Error) {
        if let Some(db_err) = error.as_database_error() {
            emit_log(db_err.message());
        } else {
            emit_log(&format!("A non-database error occurred: {:?}", error));
        }
    }

    /// Convert a signature to the value stored in the `signatures.url` column.
    ///
    /// SQLite integers are signed 64-bit. If the `u64` were bound as text, the column's
    /// INTEGER affinity would turn values above `i64::MAX` into REAL, keeping only about
    /// 15 significant digits, and distinct signatures could compare equal.
    /// Reinterpreting the bits as `i64` is lossless. Signatures above `i64::MAX` that
    /// were stored in the old text form will not match and are re-inserted once.
    fn signature_to_sql(signature: u64) -> i64 {
        signature as i64
    }

    /// Whether `url_to_check` is stored in `resources`. The comparison is case-insensitive
    /// because the column uses `COLLATE NOCASE`.
    ///
    /// Query errors are logged and reported as `false`.
    pub async fn url_exists(&self, pool: &SqlitePool, url_to_check: &str) -> bool {
        let lookup = sqlx::query("SELECT 1 FROM resources WHERE url = ? LIMIT 1")
            .bind(url_to_check)
            .fetch_optional(pool)
            .await;

        match lookup {
            Ok(row) => row.is_some(),
            Err(e) => {
                Self::log_sqlx_error(&e);
                false
            }
        }
    }

    /// Whether `signature_to_check` is stored in `signatures`.
    ///
    /// Query errors are logged and reported as `false`.
    pub async fn signature_exists(&self, pool: &SqlitePool, signature_to_check: u64) -> bool {
        let lookup = sqlx::query("SELECT 1 FROM signatures WHERE url = ? LIMIT 1")
            .bind(Self::signature_to_sql(signature_to_check))
            .fetch_optional(pool)
            .await;

        match lookup {
            Ok(row) => row.is_some(),
            Err(e) => {
                Self::log_sqlx_error(&e);
                false
            }
        }
    }

    /// Store `new_url` in `resources` unless it is already present. Errors are logged.
    pub async fn insert_url(&self, pool: &SqlitePool, new_url: &str) {
        // The existence check keeps older databases without the unique index free of
        // duplicates. `OR IGNORE` makes a racing insert of the same URL a no-op when the
        // index exists.
        if self.url_exists(pool, new_url).await {
            return;
        }

        if let Err(e) = sqlx::query("INSERT OR IGNORE INTO resources (url) VALUES (?)")
            .bind(new_url)
            .execute(pool)
            .await
        {
            Self::log_sqlx_error(&e);
        }
    }

    /// Store `new_signature` in `signatures` unless it is already present. Errors are logged.
    pub async fn insert_signature(&self, pool: &SqlitePool, new_signature: u64) {
        if self.signature_exists(pool, new_signature).await {
            return;
        }

        if let Err(e) = sqlx::query("INSERT OR IGNORE INTO signatures (url) VALUES (?)")
            .bind(Self::signature_to_sql(new_signature))
            .execute(pool)
            .await
        {
            Self::log_sqlx_error(&e);
        }
    }

    /// Bulk-insert `urls` into `resources` inside one transaction.
    ///
    /// Returns up to 100 of the URLs (an arbitrary subset, following the set's iteration
    /// order) so the caller can start on them right away. Every URL, including the
    /// returned ones, is written to the database.
    ///
    /// # Errors
    ///
    /// Any sqlx error aborts the transaction and is returned.
    pub async fn seed(
        &self,
        pool: &SqlitePool,
        urls: HashSet<CaseInsensitiveString>,
    ) -> Result<HashSet<CaseInsensitiveString>, sqlx::Error> {
        // 500 bound parameters per statement stays well under SQLite's variable limit.
        const CHUNK_SIZE: usize = 500;
        const KEEP_COUNT: usize = 100;

        let mut tx: Transaction<'_, Sqlite> = pool.begin().await?;

        let keep_urls: HashSet<CaseInsensitiveString> =
            urls.iter().take(KEEP_COUNT).cloned().collect();

        // Insert the kept URLs first, then the rest; each URL appears exactly once.
        let ordered: Vec<&CaseInsensitiveString> = keep_urls
            .iter()
            .chain(urls.iter().filter(|url| !keep_urls.contains(*url)))
            .collect();

        for chunk in ordered.chunks(CHUNK_SIZE) {
            let placeholders = vec!["(?)"; chunk.len()].join(", ");
            let sql = format!("INSERT OR IGNORE INTO resources (url) VALUES {}", placeholders);
            let statement = chunk
                .iter()
                .fold(sqlx::query(&sql), |statement, url| statement.bind(url.to_string()));

            statement.execute(&mut *tx).await?;
        }

        tx.commit().await?;

        Ok(keep_urls)
    }

    /// Number of rows in `resources`.
    pub async fn count_records(pool: &SqlitePool) -> Result<u64, sqlx::Error> {
        // SQLite integers are signed 64-bit, so decode as i64. COUNT(*) is never negative.
        let count = sqlx::query_scalar::<_, i64>("SELECT COUNT(*) FROM resources")
            .fetch_one(pool)
            .await?;
        Ok(count.max(0) as u64)
    }

    /// Every URL stored in `resources`.
    pub async fn get_all_resources(
        pool: &SqlitePool,
    ) -> Result<HashSet<CaseInsensitiveString>, sqlx::Error> {
        let urls = sqlx::query_scalar::<_, String>("SELECT url FROM resources")
            .fetch_all(pool)
            .await?;

        Ok(urls.into_iter().map(|url| url.into()).collect())
    }

    /// Best-effort removal of this handler's database file and its SQLite sidecar files.
    pub fn delete_db_by_id(&self) {
        let db_path = get_db_path(&self.crawl_id);
        // `get_db_path` can return a `sqlite://` URL; `remove_file` needs the bare path.
        let file_path = db_path
            .strip_prefix("sqlite://")
            .unwrap_or(db_path.as_str());

        // SQLite may keep `-wal`/`-shm` (WAL mode) or `-journal` (rollback mode) files next
        // to the database. Stale ones could be picked up by a later database at the same path.
        for suffix in ["", "-wal", "-shm", "-journal"] {
            let _ = std::fs::remove_file(format!("{}{}", file_path, suffix));
        }
    }

    /// Delete every row from `resources` and `signatures`.
    ///
    /// # Errors
    ///
    /// Both deletes are always attempted. The first error, if any, is returned.
    pub async fn clear_table(pool: &SqlitePool) -> Result<(), sqlx::Error> {
        let resources = sqlx::query("DELETE FROM resources").execute(pool).await;
        let signatures = sqlx::query("DELETE FROM signatures").execute(pool).await;

        resources?;
        signatures?;

        Ok(())
    }
}
322
#[cfg(feature = "disk")]
impl Drop for DatabaseHandler {
    /// Remove the database file unless the handler was marked as persistent.
    fn drop(&mut self) {
        if self.persist {
            return;
        }
        self.delete_db_by_id();
    }
}
331
332#[cfg(feature = "disk")]
/// Next value of a process-wide counter: 1, 2, 3, ... wrapping back to 1 after
/// `usize::MAX`, so 0 is never returned.
fn get_id() -> usize {
    static COUNTER: AtomicUsize = AtomicUsize::new(1);

    let advanced = COUNTER.fetch_update(Ordering::Relaxed, Ordering::Relaxed, |id| {
        Some(id.checked_add(1).unwrap_or(1))
    });

    // The closure never returns `None`, so this is always `Ok(previous)`.
    match advanced {
        Ok(previous) | Err(previous) => previous,
    }
}
350
/// Build the database location for `crawl_id`.
///
/// The base comes from the `SQLITE_DATABASE_URL` environment variable, falling back to
/// the system temp directory, with trailing `/` removed. A `sqlite://memory:` base is
/// joined with `:`; anything else is joined with `/`. The file is `spider_<id>.db`, with
/// `.` in the id replaced by `_`, or `spider.db` when there is no id.
pub fn get_db_path(crawl_id: &Option<String>) -> String {
    let base_url = match std::env::var("SQLITE_DATABASE_URL") {
        Ok(url) => url,
        Err(_) => std::env::temp_dir().to_string_lossy().into_owned(),
    };

    let base = base_url.trim_end_matches('/');
    let delim = if base_url.starts_with("sqlite://memory:") { ':' } else { '/' };

    match crawl_id.as_deref() {
        Some(id) => format!("{base}{delim}spider_{}.db", id.replace('.', "_")),
        None => format!("{base}{delim}spider.db"),
    }
}
379
#[cfg(feature = "disk")]
/// Best-effort: create the parent directories of `file_path`, then an empty file at
/// `file_path` if nothing exists there yet. Errors are ignored.
async fn create_file_and_directory(file_path: &str) {
    let target = std::path::Path::new(file_path);

    if let Some(dir) = target.parent() {
        let dir_name = dir.display().to_string();
        let _ = crate::utils::uring_fs::create_dir_all(dir_name).await;
    }

    if target.exists() {
        return;
    }

    let file_name = target.display().to_string();
    let _ = crate::utils::uring_fs::write_file(file_name, Vec::new()).await;
}
393
#[cfg(test)]
#[cfg(feature = "disk")]
mod tests {
    use super::*;

    /// A URL inserted through the lazily created pool is visible to `url_exists`.
    #[tokio::test]
    async fn test_connect_db() {
        let handler = DatabaseHandler::new(&Some("example.com".into()));
        let test_url = CaseInsensitiveString::new("http://example.com");
        let pool = handler.get_db_pool().await;

        if handler.url_exists(pool, &test_url).await {
            println!("URL '{}' already exists in the database.", test_url);
        } else {
            handler.insert_url(pool, &test_url).await;
            println!("URL '{}' was inserted into the database.", test_url);
        }

        assert!(
            handler.url_exists(pool, &test_url).await,
            "URL should exist after insertion."
        );
    }

    /// A fresh handler starts empty, and `insert_url` makes the URL visible.
    #[tokio::test]
    async fn test_url_insert_and_exists() {
        let handler = DatabaseHandler::new(&Some("example.com".into()));
        let new_url = CaseInsensitiveString::new("http://new-example.com");
        let pool = handler.get_db_pool().await;

        assert!(
            !handler.url_exists(pool, &new_url).await,
            "URL should not exist initially."
        );

        handler.insert_url(pool, &new_url).await;
        assert!(
            handler.url_exists(pool, &new_url).await,
            "URL should exist after insertion."
        );
    }

    /// URL lookups ignore ASCII case.
    #[tokio::test]
    async fn test_url_case_insensitivity() {
        let handler = DatabaseHandler::new(&Some("case-test.com".into()));
        let url1 = CaseInsensitiveString::new("http://case-test.com");
        let url2 = CaseInsensitiveString::new("http://CASE-TEST.com");
        let pool = handler.get_db_pool().await;

        handler.insert_url(pool, &url1).await;
        assert!(
            handler.url_exists(pool, &url2).await,
            "URL check should be case-insensitive."
        );
    }

    /// `seed` persists every URL and returns the kept subset.
    #[tokio::test]
    async fn test_seed_urls() {
        let handler = DatabaseHandler::new(&Some("example.com".into()));
        let mut urls = HashSet::new();
        urls.insert(CaseInsensitiveString::new("http://foo.com"));
        urls.insert(CaseInsensitiveString::new("http://bar.com"));
        let pool = handler.get_db_pool().await;

        let kept = handler
            .seed(pool, urls.clone())
            .await
            .expect("Seeding failed");

        // Far fewer URLs than `seed`'s keep limit, so all of them come back.
        assert_eq!(
            kept.len(),
            urls.len(),
            "All seeded URLs should be kept when under the limit."
        );
        assert!(
            kept.iter().all(|url| urls.contains(url)),
            "Kept URLs must come from the seed set."
        );

        for url in urls {
            assert!(
                handler.url_exists(pool, &url).await,
                "Seeded URL should exist after seeding."
            );
        }
    }

    /// Signatures round-trip, and a different signature does not match.
    #[tokio::test]
    async fn test_signature_insert_and_exists() {
        let handler = DatabaseHandler::new(&Some("signatures.com".into()));
        let pool = handler.get_db_pool().await;
        let signature: u64 = 42;

        assert!(
            !handler.signature_exists(pool, signature).await,
            "Signature should not exist initially."
        );

        handler.insert_signature(pool, signature).await;
        assert!(
            handler.signature_exists(pool, signature).await,
            "Signature should exist after insertion."
        );
        assert!(
            !handler.signature_exists(pool, signature + 1).await,
            "A different signature must not match."
        );
    }

    /// `clear_table` empties the resources table.
    #[tokio::test]
    async fn test_clear_table() {
        let handler = DatabaseHandler::new(&Some("clear-test.com".into()));
        let pool = handler.get_db_pool().await;

        handler.insert_url(pool, "http://clear-test.com").await;
        DatabaseHandler::clear_table(pool)
            .await
            .expect("Clearing failed");

        assert_eq!(
            DatabaseHandler::count_records(pool)
                .await
                .expect("Counting failed"),
            0,
            "No resources should remain after clearing."
        );
    }
}