Skip to main content

scitadel_db/sqlite/
papers.rs

1use std::collections::HashMap;
2
3use rusqlite::params;
4use scitadel_core::error::CoreError;
5use scitadel_core::models::{DownloadStatus, Paper, PaperId};
6use scitadel_core::ports::PaperRepository;
7
8use super::Database;
9use crate::error::DbError;
10
/// Insert-or-merge statement for one paper.
///
/// Positional parameters `?1`..`?17` bind, in order: id, title, authors, abstract,
/// full_text, summary, doi, arxiv_id, pubmed_id, inspire_id, openalex_id, year, journal,
/// url, source_urls, created_at, updated_at (the same order `paper_params` produces).
///
/// When a row with the same `id` already exists it is updated with these merge rules:
/// - `title`, `authors`, `source_urls` and `updated_at` are always overwritten.
/// - `abstract` is overwritten only when the incoming value is neither `''` nor NULL
///   (`NULL != ''` evaluates to NULL, which falls through to the `ELSE` branch).
/// - Every other listed column keeps the stored value when the incoming value is NULL
///   (`COALESCE(excluded.x, papers.x)`).
/// - `created_at` is absent from the `SET` list, so the first insertion time is kept.
///
/// Only the `id` conflict target is handled here; a clash on any other unique index
/// (e.g. DOI) still raises a constraint violation for the caller to deal with.
///
/// NOTE(review): `source_urls` replaces rather than merges the stored JSON, so URLs
/// contributed only by an earlier source are dropped on re-save — confirm intended.
const UPSERT_SQL: &str = "\
    INSERT INTO papers
        (id, title, authors, abstract, full_text, summary, doi, arxiv_id,
         pubmed_id, inspire_id, openalex_id, year, journal, url,
         source_urls, created_at, updated_at)
    VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?11, ?12, ?13, ?14, ?15, ?16, ?17)
    ON CONFLICT(id) DO UPDATE SET
        title      = excluded.title,
        authors    = excluded.authors,
        abstract   = CASE WHEN excluded.abstract != '' THEN excluded.abstract
                          ELSE papers.abstract END,
        full_text  = COALESCE(excluded.full_text, papers.full_text),
        summary    = COALESCE(excluded.summary, papers.summary),
        doi        = COALESCE(excluded.doi, papers.doi),
        arxiv_id   = COALESCE(excluded.arxiv_id, papers.arxiv_id),
        pubmed_id  = COALESCE(excluded.pubmed_id, papers.pubmed_id),
        inspire_id = COALESCE(excluded.inspire_id, papers.inspire_id),
        openalex_id= COALESCE(excluded.openalex_id, papers.openalex_id),
        year       = COALESCE(excluded.year, papers.year),
        journal    = COALESCE(excluded.journal, papers.journal),
        url        = COALESCE(excluded.url, papers.url),
        source_urls= excluded.source_urls,
        updated_at = excluded.updated_at";
34
35pub struct SqlitePaperRepository {
36    db: Database,
37}
38
39impl SqlitePaperRepository {
40    pub fn new(db: Database) -> Self {
41        Self { db }
42    }
43
44    /// If a paper with the same DOI already exists, return a clone with the existing ID
45    /// so the upsert merges into the existing row instead of violating the DOI unique index.
46    fn resolve_doi_conflict(
47        conn: &rusqlite::Connection,
48        paper: &Paper,
49    ) -> Result<Paper, CoreError> {
50        if let Some(doi) = &paper.doi {
51            let existing_id: Option<String> = conn
52                .query_row(
53                    "SELECT id FROM papers WHERE doi = ?1 AND id != ?2",
54                    params![doi, paper.id.as_str()],
55                    |row| row.get(0),
56                )
57                .optional()
58                .map_err(DbError::Sqlite)?;
59            if let Some(id) = existing_id {
60                let mut merged = paper.clone();
61                merged.id = PaperId::from(id);
62                return Ok(merged);
63            }
64        }
65        Ok(paper.clone())
66    }
67
68    fn resolve_doi_conflict_tx(
69        tx: &rusqlite::Transaction<'_>,
70        paper: &Paper,
71    ) -> Result<Paper, CoreError> {
72        if let Some(doi) = &paper.doi {
73            let existing_id: Option<String> = tx
74                .query_row(
75                    "SELECT id FROM papers WHERE doi = ?1 AND id != ?2",
76                    params![doi, paper.id.as_str()],
77                    |row| row.get(0),
78                )
79                .optional()
80                .map_err(DbError::Sqlite)?;
81            if let Some(id) = existing_id {
82                let mut merged = paper.clone();
83                merged.id = PaperId::from(id);
84                return Ok(merged);
85            }
86        }
87        Ok(paper.clone())
88    }
89
90    fn paper_params(paper: &Paper) -> [Box<dyn rusqlite::types::ToSql>; 17] {
91        [
92            Box::new(paper.id.as_str().to_string()),
93            Box::new(paper.title.clone()),
94            Box::new(serde_json::to_string(&paper.authors).unwrap_or_default()),
95            Box::new(paper.r#abstract.clone()),
96            Box::new(paper.full_text.clone()),
97            Box::new(paper.summary.clone()),
98            Box::new(paper.doi.clone()),
99            Box::new(paper.arxiv_id.clone()),
100            Box::new(paper.pubmed_id.clone()),
101            Box::new(paper.inspire_id.clone()),
102            Box::new(paper.openalex_id.clone()),
103            Box::new(paper.year),
104            Box::new(paper.journal.clone()),
105            Box::new(paper.url.clone()),
106            Box::new(serde_json::to_string(&paper.source_urls).unwrap_or_default()),
107            Box::new(paper.created_at.to_rfc3339()),
108            Box::new(paper.updated_at.to_rfc3339()),
109        ]
110    }
111}
112
113fn row_to_paper(row: &rusqlite::Row) -> rusqlite::Result<Paper> {
114    let id: String = row.get("id")?;
115    let authors_json: String = row.get("authors")?;
116    let source_urls_json: String = row.get("source_urls")?;
117    let created_at: String = row.get("created_at")?;
118    let updated_at: String = row.get("updated_at")?;
119
120    let local_path: Option<String> = row.get("local_path").ok();
121    let download_status_raw: Option<String> = row.get("download_status").ok();
122    let last_attempt_at_raw: Option<String> = row.get("last_attempt_at").ok();
123    let bibtex_key: Option<String> = row.get("bibtex_key").ok();
124
125    Ok(Paper {
126        id: PaperId::from(id),
127        title: row.get("title")?,
128        authors: serde_json::from_str(&authors_json).unwrap_or_default(),
129        r#abstract: row.get("abstract")?,
130        full_text: row.get("full_text")?,
131        summary: row.get("summary")?,
132        doi: row.get("doi")?,
133        arxiv_id: row.get("arxiv_id")?,
134        pubmed_id: row.get("pubmed_id")?,
135        inspire_id: row.get("inspire_id")?,
136        openalex_id: row.get("openalex_id")?,
137        year: row.get("year")?,
138        journal: row.get("journal")?,
139        url: row.get("url")?,
140        source_urls: serde_json::from_str(&source_urls_json).unwrap_or_default(),
141        created_at: super::parse_rfc3339_or_now(&created_at),
142        updated_at: super::parse_rfc3339_or_now(&updated_at),
143        local_path,
144        download_status: download_status_raw
145            .as_deref()
146            .and_then(DownloadStatus::parse),
147        last_attempt_at: last_attempt_at_raw
148            .as_deref()
149            .and_then(|s| chrono::DateTime::parse_from_rfc3339(s).ok())
150            .map(|dt| dt.with_timezone(&chrono::Utc)),
151        bibtex_key,
152    })
153}
154
155impl PaperRepository for SqlitePaperRepository {
156    fn save(&self, paper: &Paper) -> Result<(), CoreError> {
157        let conn = self.db.conn()?;
158        let paper = Self::resolve_doi_conflict(&conn, paper)?;
159        let p = Self::paper_params(&paper);
160        let params: Vec<&dyn rusqlite::types::ToSql> = p.iter().map(|b| b.as_ref()).collect();
161        match conn.execute(UPSERT_SQL, params.as_slice()) {
162            Ok(_) => Ok(()),
163            Err(rusqlite::Error::SqliteFailure(err, _))
164                if err.code == rusqlite::ErrorCode::ConstraintViolation =>
165            {
166                // DOI collision — retry with existing paper's ID
167                if let Some(doi) = &paper.doi {
168                    let existing_id: Option<String> = conn
169                        .query_row(
170                            "SELECT id FROM papers WHERE doi = ?1",
171                            params![doi],
172                            |row| row.get(0),
173                        )
174                        .optional()
175                        .map_err(DbError::Sqlite)?;
176                    if let Some(eid) = existing_id {
177                        let mut retry = paper.clone();
178                        retry.id = PaperId::from(eid);
179                        let p2 = Self::paper_params(&retry);
180                        let params2: Vec<&dyn rusqlite::types::ToSql> =
181                            p2.iter().map(|b| b.as_ref()).collect();
182                        conn.execute(UPSERT_SQL, params2.as_slice())
183                            .map_err(DbError::Sqlite)?;
184                    }
185                }
186                Ok(())
187            }
188            Err(e) => Err(DbError::Sqlite(e).into()),
189        }
190    }
191
192    fn save_many(&self, papers: &[Paper]) -> Result<HashMap<PaperId, PaperId>, CoreError> {
193        let mut conn = self.db.conn()?;
194        let mut id_remap = HashMap::new();
195        let tx = conn.transaction().map_err(DbError::Sqlite)?;
196        for paper in papers {
197            let resolved = Self::resolve_doi_conflict_tx(&tx, paper)?;
198            if resolved.id != paper.id {
199                id_remap.insert(paper.id.clone(), resolved.id.clone());
200            }
201            let p = Self::paper_params(&resolved);
202            let params: Vec<&dyn rusqlite::types::ToSql> = p.iter().map(|b| b.as_ref()).collect();
203            match tx.execute(UPSERT_SQL, params.as_slice()) {
204                Ok(_) => {}
205                Err(rusqlite::Error::SqliteFailure(err, _))
206                    if err.code == rusqlite::ErrorCode::ConstraintViolation =>
207                {
208                    // DOI unique-index collision that resolve_doi_conflict missed
209                    // (e.g. case variation, concurrent insert, or within-batch dup).
210                    // Look up the existing paper by DOI and retry as an update.
211                    if let Some(doi) = &resolved.doi {
212                        let existing_id: Option<String> = tx
213                            .query_row(
214                                "SELECT id FROM papers WHERE doi = ?1",
215                                params![doi],
216                                |row| row.get(0),
217                            )
218                            .optional()
219                            .map_err(DbError::Sqlite)?;
220                        if let Some(eid) = existing_id {
221                            id_remap.insert(paper.id.clone(), PaperId::from(eid.clone()));
222                            let mut retry = resolved.clone();
223                            retry.id = PaperId::from(eid);
224                            let p2 = Self::paper_params(&retry);
225                            let params2: Vec<&dyn rusqlite::types::ToSql> =
226                                p2.iter().map(|b| b.as_ref()).collect();
227                            tx.execute(UPSERT_SQL, params2.as_slice())
228                                .map_err(DbError::Sqlite)?;
229                        }
230                        // If no DOI match found either, skip silently — paper
231                        // may have been blocked by another unique constraint.
232                    }
233                }
234                Err(e) => return Err(DbError::Sqlite(e).into()),
235            }
236        }
237        tx.commit().map_err(DbError::Sqlite)?;
238        Ok(id_remap)
239    }
240
241    fn get(&self, paper_id: &str) -> Result<Option<Paper>, CoreError> {
242        let conn = self.db.conn()?;
243        let mut stmt = conn
244            .prepare("SELECT * FROM papers WHERE id = ?1")
245            .map_err(DbError::Sqlite)?;
246        let result = stmt
247            .query_row(params![paper_id], row_to_paper)
248            .optional()
249            .map_err(DbError::Sqlite)?;
250        Ok(result)
251    }
252
253    fn find_by_doi(&self, doi: &str) -> Result<Option<Paper>, CoreError> {
254        let conn = self.db.conn()?;
255        let mut stmt = conn
256            .prepare("SELECT * FROM papers WHERE doi = ?1")
257            .map_err(DbError::Sqlite)?;
258        let result = stmt
259            .query_row(params![doi], row_to_paper)
260            .optional()
261            .map_err(DbError::Sqlite)?;
262        Ok(result)
263    }
264
265    fn find_by_title(&self, title: &str) -> Result<Option<Paper>, CoreError> {
266        let conn = self.db.conn()?;
267        let mut stmt = conn
268            .prepare("SELECT * FROM papers WHERE LOWER(title) = LOWER(?1)")
269            .map_err(DbError::Sqlite)?;
270        let result = stmt
271            .query_row(params![title], row_to_paper)
272            .optional()
273            .map_err(DbError::Sqlite)?;
274        Ok(result)
275    }
276
277    fn list_all(&self, limit: i64, offset: i64) -> Result<Vec<Paper>, CoreError> {
278        let conn = self.db.conn()?;
279        let mut stmt = conn
280            .prepare("SELECT * FROM papers ORDER BY created_at DESC LIMIT ?1 OFFSET ?2")
281            .map_err(DbError::Sqlite)?;
282        let papers = stmt
283            .query_map(params![limit, offset], row_to_paper)
284            .map_err(DbError::Sqlite)?
285            .filter_map(Result::ok)
286            .collect();
287        Ok(papers)
288    }
289
290    fn update_full_text(&self, paper_id: &str, text: &str) -> Result<(), CoreError> {
291        let conn = self.db.conn()?;
292        conn.execute(
293            "UPDATE papers SET full_text = ?1, updated_at = ?2 WHERE id = ?3",
294            params![text, chrono::Utc::now().to_rfc3339(), paper_id],
295        )
296        .map_err(DbError::Sqlite)?;
297        Ok(())
298    }
299
300    fn update_download_state(
301        &self,
302        paper_id: &str,
303        local_path: Option<&str>,
304        status: DownloadStatus,
305    ) -> Result<(), CoreError> {
306        let conn = self.db.conn()?;
307        let now = chrono::Utc::now().to_rfc3339();
308        conn.execute(
309            "UPDATE papers SET local_path = ?1, download_status = ?2, last_attempt_at = ?3, \
310             updated_at = ?3 WHERE id = ?4",
311            params![local_path, status.as_str(), now, paper_id],
312        )
313        .map_err(DbError::Sqlite)?;
314        Ok(())
315    }
316
317    fn update_bibtex_key(&self, paper_id: &str, key: &str) -> Result<(), CoreError> {
318        let conn = self.db.conn()?;
319        conn.execute(
320            "UPDATE papers SET bibtex_key = ?1 WHERE id = ?2",
321            params![key, paper_id],
322        )
323        .map_err(DbError::Sqlite)?;
324        Ok(())
325    }
326}
327
328// Need this import for .optional()
329use rusqlite::OptionalExtension;
330
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sqlite::Database;

    /// Opens a migrated in-memory database and wraps a clone of it in a repository.
    fn setup() -> (Database, SqlitePaperRepository) {
        let database = Database::open_in_memory().unwrap();
        database.migrate().unwrap();
        let repository = SqlitePaperRepository::new(database.clone());
        (database, repository)
    }

    /// Builds a fresh paper titled `title` that carries `doi`.
    fn paper_with_doi(title: &str, doi: &str) -> Paper {
        let mut paper = Paper::new(title);
        paper.doi = Some(doi.to_string());
        paper
    }

    /// Returns the first paper in `papers` carrying `doi`; panics if none does.
    fn paper_by_doi<'a>(papers: &'a [Paper], doi: &str) -> &'a Paper {
        papers
            .iter()
            .find(|candidate| candidate.doi.as_deref() == Some(doi))
            .unwrap()
    }

    #[test]
    fn test_save_and_get() {
        let (_, repo) = setup();
        let original = Paper::new("Test Paper");
        repo.save(&original).unwrap();

        let fetched = repo.get(original.id.as_str()).unwrap().unwrap();
        assert_eq!(fetched.title, "Test Paper");
    }

    #[test]
    fn test_find_by_doi() {
        let (_, repo) = setup();
        let stored = paper_with_doi("DOI Paper", "10.1234/test");
        repo.save(&stored).unwrap();

        let hit = repo.find_by_doi("10.1234/test").unwrap().unwrap();
        assert_eq!(hit.id, stored.id);
    }

    #[test]
    fn test_upsert_merges() {
        let (_, repo) = setup();
        let first = paper_with_doi("Merge Test", "10.1234/merge");
        repo.save(&first).unwrap();

        let second = Paper {
            arxiv_id: Some("2301.00001".to_string()),
            ..first.clone()
        };
        repo.save(&second).unwrap();

        let merged = repo.get(first.id.as_str()).unwrap().unwrap();
        assert_eq!(merged.arxiv_id.as_deref(), Some("2301.00001"));
    }

    #[test]
    fn test_doi_conflict_across_papers() {
        let (_, repo) = setup();
        let shared_doi = "10.1234/conflict";

        // The row that should survive.
        let original = paper_with_doi("Original Paper", shared_doi);
        repo.save(&original).unwrap();

        // Same DOI under a fresh id, as a second search would produce.
        let mut duplicate = paper_with_doi("Updated Title", shared_doi);
        duplicate.arxiv_id = Some("2301.99999".to_string());
        repo.save(&duplicate).unwrap();

        // The duplicate must fold into the original rather than add a row.
        let merged = repo.find_by_doi(shared_doi).unwrap().unwrap();
        assert_eq!(merged.id, original.id, "should reuse original paper ID");
        assert_eq!(merged.title, "Updated Title", "should update title");
        assert_eq!(
            merged.arxiv_id.as_deref(),
            Some("2301.99999"),
            "should merge arxiv_id"
        );
    }

    #[test]
    fn test_doi_conflict_in_save_many() {
        let (_, repo) = setup();
        let shared_doi = "10.1234/batch";

        let existing = paper_with_doi("Existing Paper", shared_doi);
        repo.save(&existing).unwrap();

        // A batch entry colliding on DOI.
        let mut incoming = paper_with_doi("Batch Paper", shared_doi);
        incoming.pubmed_id = Some("12345".to_string());
        repo.save_many(&[incoming]).unwrap();

        let merged = repo.find_by_doi(shared_doi).unwrap().unwrap();
        assert_eq!(merged.id, existing.id);
        assert_eq!(merged.pubmed_id.as_deref(), Some("12345"));
    }

    #[test]
    fn test_list_all() {
        let (_, repo) = setup();
        (0..5)
            .map(|i| Paper::new(format!("Paper {i}")))
            .for_each(|paper| repo.save(&paper).unwrap());

        let page = repo.list_all(3, 0).unwrap();
        assert_eq!(page.len(), 3);
    }

    /// Integration test: simulate two searches with overlapping DOIs going
    /// through dedup → save_many, the same flow as the MCP search tool.
    #[test]
    fn test_cross_search_dedup_save_roundtrip() {
        use scitadel_core::models::CandidatePaper;
        use scitadel_core::services::dedup::deduplicate;

        let (_, repo) = setup();
        let similarity_threshold = 0.85;

        // First search: three unrelated papers.
        let first_batch = vec![
            CandidatePaper {
                doi: Some("10.1000/alpha".into()),
                ..CandidatePaper::new("openalex", "oa-1", "Alpha Paper")
            },
            CandidatePaper {
                doi: Some("10.1000/beta".into()),
                ..CandidatePaper::new("openalex", "oa-2", "Beta Paper")
            },
            CandidatePaper {
                doi: Some("10.1000/gamma".into()),
                ..CandidatePaper::new("pubmed", "pm-1", "Gamma Paper")
            },
        ];
        let (first_papers, _first_results) = deduplicate(&first_batch, similarity_threshold);
        assert_eq!(first_papers.len(), 3);
        let first_remap = repo.save_many(&first_papers).unwrap();
        assert!(first_remap.is_empty(), "no conflicts on first save");

        // Second search: alpha and gamma again (new ids, extra metadata) plus delta.
        let second_batch = vec![
            CandidatePaper {
                doi: Some("10.1000/alpha".into()),
                arxiv_id: Some("2301.00001".into()),
                ..CandidatePaper::new("arxiv", "ax-1", "Alpha Paper (arxiv)")
            },
            CandidatePaper {
                doi: Some("10.1000/gamma".into()),
                pubmed_id: Some("99999".into()),
                ..CandidatePaper::new("pubmed", "pm-2", "Gamma Paper Revised")
            },
            CandidatePaper {
                doi: Some("10.1000/delta".into()),
                ..CandidatePaper::new("openalex", "oa-3", "Delta Paper")
            },
        ];
        let (second_papers, second_results) = deduplicate(&second_batch, similarity_threshold);
        assert_eq!(
            second_papers.len(),
            3,
            "dedup sees them as distinct (different IDs)"
        );

        let second_remap = repo.save_many(&second_papers).unwrap();
        assert_eq!(
            second_remap.len(),
            2,
            "alpha and gamma should remap to existing IDs"
        );

        // The alpha remap must point back at the id stored by the first search.
        let alpha_first = paper_by_doi(&first_papers, "10.1000/alpha");
        let alpha_second = paper_by_doi(&second_papers, "10.1000/alpha");
        assert_eq!(second_remap[&alpha_second.id], alpha_first.id);

        // Three rows from the first search plus delta.
        let stored = repo.list_all(100, 0).unwrap();
        assert_eq!(stored.len(), 4, "3 from first search + 1 new from second");

        // Metadata from the second search is merged into the original alpha row.
        let alpha_row = repo.find_by_doi("10.1000/alpha").unwrap().unwrap();
        assert_eq!(alpha_row.id, alpha_first.id, "kept original ID");
        assert_eq!(
            alpha_row.arxiv_id.as_deref(),
            Some("2301.00001"),
            "merged arxiv_id from second search"
        );

        // Every search result must resolve (directly or via the remap) to a stored row.
        for result in &second_results {
            let target = second_remap
                .get(&result.paper_id)
                .unwrap_or(&result.paper_id);
            assert!(
                repo.get(target.as_str()).unwrap().is_some(),
                "remapped paper_id should exist in DB"
            );
        }
    }

    #[test]
    fn download_state_round_trips() {
        let (_, repo) = setup();
        let paper = Paper::new("DL state");
        repo.save(&paper).unwrap();
        let id = paper.id.as_str();
        let reload = || repo.get(id).unwrap().unwrap();

        // Pristine row: no download attempted yet.
        let pristine = reload();
        assert!(pristine.local_path.is_none());
        assert!(pristine.download_status.is_none());
        assert!(pristine.last_attempt_at.is_none());

        // Successful download.
        repo.update_download_state(id, Some("/tmp/foo.pdf"), DownloadStatus::Downloaded)
            .unwrap();
        let downloaded = reload();
        assert_eq!(downloaded.local_path.as_deref(), Some("/tmp/foo.pdf"));
        assert_eq!(downloaded.download_status, Some(DownloadStatus::Downloaded));
        assert!(downloaded.last_attempt_at.is_some());

        // Subsequent failure overwrites cleanly (path None, status Failed).
        repo.update_download_state(id, None, DownloadStatus::Failed)
            .unwrap();
        let failed = reload();
        assert!(failed.local_path.is_none());
        assert_eq!(failed.download_status, Some(DownloadStatus::Failed));
    }
}