1use std::collections::HashMap;
2
3use rusqlite::params;
4use scitadel_core::error::CoreError;
5use scitadel_core::models::{DownloadStatus, Paper, PaperId};
6use scitadel_core::ports::PaperRepository;
7
8use super::Database;
9use crate::error::DbError;
10
/// Upsert statement for the `papers` table; binds 17 positional parameters
/// (`?1`..`?17`) in the column order listed in the `INSERT`.
///
/// On an `id` conflict the existing row is updated instead of inserted:
/// - `title`, `authors`, `source_urls` and `updated_at` are overwritten with
///   the incoming values;
/// - `abstract` is replaced only when the incoming value satisfies `!= ''`,
///   otherwise the stored abstract is kept;
/// - `full_text`, `summary`, `doi`, `arxiv_id`, `pubmed_id`, `inspire_id`,
///   `openalex_id`, `year`, `journal` and `url` go through `COALESCE`, so an
///   incoming `NULL` keeps the stored value;
/// - `created_at` is not in the `SET` list, so the stored value is kept.
const UPSERT_SQL: &str = "\
 INSERT INTO papers
 (id, title, authors, abstract, full_text, summary, doi, arxiv_id,
 pubmed_id, inspire_id, openalex_id, year, journal, url,
 source_urls, created_at, updated_at)
 VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?11, ?12, ?13, ?14, ?15, ?16, ?17)
 ON CONFLICT(id) DO UPDATE SET
 title = excluded.title,
 authors = excluded.authors,
 abstract = CASE WHEN excluded.abstract != '' THEN excluded.abstract
 ELSE papers.abstract END,
 full_text = COALESCE(excluded.full_text, papers.full_text),
 summary = COALESCE(excluded.summary, papers.summary),
 doi = COALESCE(excluded.doi, papers.doi),
 arxiv_id = COALESCE(excluded.arxiv_id, papers.arxiv_id),
 pubmed_id = COALESCE(excluded.pubmed_id, papers.pubmed_id),
 inspire_id = COALESCE(excluded.inspire_id, papers.inspire_id),
 openalex_id= COALESCE(excluded.openalex_id, papers.openalex_id),
 year = COALESCE(excluded.year, papers.year),
 journal = COALESCE(excluded.journal, papers.journal),
 url = COALESCE(excluded.url, papers.url),
 source_urls= excluded.source_urls,
 updated_at = excluded.updated_at";
34
/// [`PaperRepository`] implementation that stores papers in the `papers`
/// table of a [`Database`].
pub struct SqlitePaperRepository {
    /// Handle from which every operation obtains its connection via `conn()`.
    db: Database,
}
38
39impl SqlitePaperRepository {
40 pub fn new(db: Database) -> Self {
41 Self { db }
42 }
43
44 fn resolve_doi_conflict(
47 conn: &rusqlite::Connection,
48 paper: &Paper,
49 ) -> Result<Paper, CoreError> {
50 if let Some(doi) = &paper.doi {
51 let existing_id: Option<String> = conn
52 .query_row(
53 "SELECT id FROM papers WHERE doi = ?1 AND id != ?2",
54 params![doi, paper.id.as_str()],
55 |row| row.get(0),
56 )
57 .optional()
58 .map_err(DbError::Sqlite)?;
59 if let Some(id) = existing_id {
60 let mut merged = paper.clone();
61 merged.id = PaperId::from(id);
62 return Ok(merged);
63 }
64 }
65 Ok(paper.clone())
66 }
67
68 fn resolve_doi_conflict_tx(
69 tx: &rusqlite::Transaction<'_>,
70 paper: &Paper,
71 ) -> Result<Paper, CoreError> {
72 if let Some(doi) = &paper.doi {
73 let existing_id: Option<String> = tx
74 .query_row(
75 "SELECT id FROM papers WHERE doi = ?1 AND id != ?2",
76 params![doi, paper.id.as_str()],
77 |row| row.get(0),
78 )
79 .optional()
80 .map_err(DbError::Sqlite)?;
81 if let Some(id) = existing_id {
82 let mut merged = paper.clone();
83 merged.id = PaperId::from(id);
84 return Ok(merged);
85 }
86 }
87 Ok(paper.clone())
88 }
89
90 fn paper_params(paper: &Paper) -> [Box<dyn rusqlite::types::ToSql>; 17] {
91 [
92 Box::new(paper.id.as_str().to_string()),
93 Box::new(paper.title.clone()),
94 Box::new(serde_json::to_string(&paper.authors).unwrap_or_default()),
95 Box::new(paper.r#abstract.clone()),
96 Box::new(paper.full_text.clone()),
97 Box::new(paper.summary.clone()),
98 Box::new(paper.doi.clone()),
99 Box::new(paper.arxiv_id.clone()),
100 Box::new(paper.pubmed_id.clone()),
101 Box::new(paper.inspire_id.clone()),
102 Box::new(paper.openalex_id.clone()),
103 Box::new(paper.year),
104 Box::new(paper.journal.clone()),
105 Box::new(paper.url.clone()),
106 Box::new(serde_json::to_string(&paper.source_urls).unwrap_or_default()),
107 Box::new(paper.created_at.to_rfc3339()),
108 Box::new(paper.updated_at.to_rfc3339()),
109 ]
110 }
111}
112
113fn row_to_paper(row: &rusqlite::Row) -> rusqlite::Result<Paper> {
114 let id: String = row.get("id")?;
115 let authors_json: String = row.get("authors")?;
116 let source_urls_json: String = row.get("source_urls")?;
117 let created_at: String = row.get("created_at")?;
118 let updated_at: String = row.get("updated_at")?;
119
120 let local_path: Option<String> = row.get("local_path").ok();
121 let download_status_raw: Option<String> = row.get("download_status").ok();
122 let last_attempt_at_raw: Option<String> = row.get("last_attempt_at").ok();
123 let bibtex_key: Option<String> = row.get("bibtex_key").ok();
124
125 Ok(Paper {
126 id: PaperId::from(id),
127 title: row.get("title")?,
128 authors: serde_json::from_str(&authors_json).unwrap_or_default(),
129 r#abstract: row.get("abstract")?,
130 full_text: row.get("full_text")?,
131 summary: row.get("summary")?,
132 doi: row.get("doi")?,
133 arxiv_id: row.get("arxiv_id")?,
134 pubmed_id: row.get("pubmed_id")?,
135 inspire_id: row.get("inspire_id")?,
136 openalex_id: row.get("openalex_id")?,
137 year: row.get("year")?,
138 journal: row.get("journal")?,
139 url: row.get("url")?,
140 source_urls: serde_json::from_str(&source_urls_json).unwrap_or_default(),
141 created_at: super::parse_rfc3339_or_now(&created_at),
142 updated_at: super::parse_rfc3339_or_now(&updated_at),
143 local_path,
144 download_status: download_status_raw
145 .as_deref()
146 .and_then(DownloadStatus::parse),
147 last_attempt_at: last_attempt_at_raw
148 .as_deref()
149 .and_then(|s| chrono::DateTime::parse_from_rfc3339(s).ok())
150 .map(|dt| dt.with_timezone(&chrono::Utc)),
151 bibtex_key,
152 })
153}
154
155impl PaperRepository for SqlitePaperRepository {
156 fn save(&self, paper: &Paper) -> Result<(), CoreError> {
157 let conn = self.db.conn()?;
158 let paper = Self::resolve_doi_conflict(&conn, paper)?;
159 let p = Self::paper_params(&paper);
160 let params: Vec<&dyn rusqlite::types::ToSql> = p.iter().map(|b| b.as_ref()).collect();
161 match conn.execute(UPSERT_SQL, params.as_slice()) {
162 Ok(_) => Ok(()),
163 Err(rusqlite::Error::SqliteFailure(err, _))
164 if err.code == rusqlite::ErrorCode::ConstraintViolation =>
165 {
166 if let Some(doi) = &paper.doi {
168 let existing_id: Option<String> = conn
169 .query_row(
170 "SELECT id FROM papers WHERE doi = ?1",
171 params![doi],
172 |row| row.get(0),
173 )
174 .optional()
175 .map_err(DbError::Sqlite)?;
176 if let Some(eid) = existing_id {
177 let mut retry = paper.clone();
178 retry.id = PaperId::from(eid);
179 let p2 = Self::paper_params(&retry);
180 let params2: Vec<&dyn rusqlite::types::ToSql> =
181 p2.iter().map(|b| b.as_ref()).collect();
182 conn.execute(UPSERT_SQL, params2.as_slice())
183 .map_err(DbError::Sqlite)?;
184 }
185 }
186 Ok(())
187 }
188 Err(e) => Err(DbError::Sqlite(e).into()),
189 }
190 }
191
192 fn save_many(&self, papers: &[Paper]) -> Result<HashMap<PaperId, PaperId>, CoreError> {
193 let mut conn = self.db.conn()?;
194 let mut id_remap = HashMap::new();
195 let tx = conn.transaction().map_err(DbError::Sqlite)?;
196 for paper in papers {
197 let resolved = Self::resolve_doi_conflict_tx(&tx, paper)?;
198 if resolved.id != paper.id {
199 id_remap.insert(paper.id.clone(), resolved.id.clone());
200 }
201 let p = Self::paper_params(&resolved);
202 let params: Vec<&dyn rusqlite::types::ToSql> = p.iter().map(|b| b.as_ref()).collect();
203 match tx.execute(UPSERT_SQL, params.as_slice()) {
204 Ok(_) => {}
205 Err(rusqlite::Error::SqliteFailure(err, _))
206 if err.code == rusqlite::ErrorCode::ConstraintViolation =>
207 {
208 if let Some(doi) = &resolved.doi {
212 let existing_id: Option<String> = tx
213 .query_row(
214 "SELECT id FROM papers WHERE doi = ?1",
215 params![doi],
216 |row| row.get(0),
217 )
218 .optional()
219 .map_err(DbError::Sqlite)?;
220 if let Some(eid) = existing_id {
221 id_remap.insert(paper.id.clone(), PaperId::from(eid.clone()));
222 let mut retry = resolved.clone();
223 retry.id = PaperId::from(eid);
224 let p2 = Self::paper_params(&retry);
225 let params2: Vec<&dyn rusqlite::types::ToSql> =
226 p2.iter().map(|b| b.as_ref()).collect();
227 tx.execute(UPSERT_SQL, params2.as_slice())
228 .map_err(DbError::Sqlite)?;
229 }
230 }
233 }
234 Err(e) => return Err(DbError::Sqlite(e).into()),
235 }
236 }
237 tx.commit().map_err(DbError::Sqlite)?;
238 Ok(id_remap)
239 }
240
241 fn get(&self, paper_id: &str) -> Result<Option<Paper>, CoreError> {
242 let conn = self.db.conn()?;
243 let mut stmt = conn
244 .prepare("SELECT * FROM papers WHERE id = ?1")
245 .map_err(DbError::Sqlite)?;
246 let result = stmt
247 .query_row(params![paper_id], row_to_paper)
248 .optional()
249 .map_err(DbError::Sqlite)?;
250 Ok(result)
251 }
252
253 fn find_by_doi(&self, doi: &str) -> Result<Option<Paper>, CoreError> {
254 let conn = self.db.conn()?;
255 let mut stmt = conn
256 .prepare("SELECT * FROM papers WHERE doi = ?1")
257 .map_err(DbError::Sqlite)?;
258 let result = stmt
259 .query_row(params![doi], row_to_paper)
260 .optional()
261 .map_err(DbError::Sqlite)?;
262 Ok(result)
263 }
264
265 fn find_by_title(&self, title: &str) -> Result<Option<Paper>, CoreError> {
266 let conn = self.db.conn()?;
267 let mut stmt = conn
268 .prepare("SELECT * FROM papers WHERE LOWER(title) = LOWER(?1)")
269 .map_err(DbError::Sqlite)?;
270 let result = stmt
271 .query_row(params![title], row_to_paper)
272 .optional()
273 .map_err(DbError::Sqlite)?;
274 Ok(result)
275 }
276
277 fn list_all(&self, limit: i64, offset: i64) -> Result<Vec<Paper>, CoreError> {
278 let conn = self.db.conn()?;
279 let mut stmt = conn
280 .prepare("SELECT * FROM papers ORDER BY created_at DESC LIMIT ?1 OFFSET ?2")
281 .map_err(DbError::Sqlite)?;
282 let papers = stmt
283 .query_map(params![limit, offset], row_to_paper)
284 .map_err(DbError::Sqlite)?
285 .filter_map(Result::ok)
286 .collect();
287 Ok(papers)
288 }
289
290 fn update_full_text(&self, paper_id: &str, text: &str) -> Result<(), CoreError> {
291 let conn = self.db.conn()?;
292 conn.execute(
293 "UPDATE papers SET full_text = ?1, updated_at = ?2 WHERE id = ?3",
294 params![text, chrono::Utc::now().to_rfc3339(), paper_id],
295 )
296 .map_err(DbError::Sqlite)?;
297 Ok(())
298 }
299
300 fn update_download_state(
301 &self,
302 paper_id: &str,
303 local_path: Option<&str>,
304 status: DownloadStatus,
305 ) -> Result<(), CoreError> {
306 let conn = self.db.conn()?;
307 let now = chrono::Utc::now().to_rfc3339();
308 conn.execute(
309 "UPDATE papers SET local_path = ?1, download_status = ?2, last_attempt_at = ?3, \
310 updated_at = ?3 WHERE id = ?4",
311 params![local_path, status.as_str(), now, paper_id],
312 )
313 .map_err(DbError::Sqlite)?;
314 Ok(())
315 }
316
317 fn update_bibtex_key(&self, paper_id: &str, key: &str) -> Result<(), CoreError> {
318 let conn = self.db.conn()?;
319 conn.execute(
320 "UPDATE papers SET bibtex_key = ?1 WHERE id = ?2",
321 params![key, paper_id],
322 )
323 .map_err(DbError::Sqlite)?;
324 Ok(())
325 }
326}
327
328use rusqlite::OptionalExtension;
330
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sqlite::Database;

    /// Opens a migrated in-memory database plus a repository bound to it.
    fn setup() -> (Database, SqlitePaperRepository) {
        let db = Database::open_in_memory().unwrap();
        db.migrate().unwrap();
        let repo = SqlitePaperRepository::new(db.clone());
        (db, repo)
    }

    /// Builds an unsaved paper with the given title and DOI.
    fn paper_with_doi(title: &str, doi: &str) -> Paper {
        let mut paper = Paper::new(title);
        paper.doi = Some(doi.to_string());
        paper
    }

    /// A saved paper can be read back by its ID.
    #[test]
    fn test_save_and_get() {
        let (_, repo) = setup();
        let original = Paper::new("Test Paper");
        repo.save(&original).unwrap();

        let fetched = repo.get(original.id.as_str()).unwrap().unwrap();
        assert_eq!(fetched.title, "Test Paper");
    }

    /// A saved paper can be looked up by its DOI.
    #[test]
    fn test_find_by_doi() {
        let (_, repo) = setup();
        let original = paper_with_doi("DOI Paper", "10.1234/test");
        repo.save(&original).unwrap();

        let hit = repo.find_by_doi("10.1234/test").unwrap().unwrap();
        assert_eq!(hit.id, original.id);
    }

    /// Saving the same ID again fills in identifiers that were missing.
    #[test]
    fn test_upsert_merges() {
        let (_, repo) = setup();
        let first = paper_with_doi("Merge Test", "10.1234/merge");
        repo.save(&first).unwrap();

        let mut second = first.clone();
        second.arxiv_id = Some("2301.00001".to_string());
        repo.save(&second).unwrap();

        let stored = repo.get(first.id.as_str()).unwrap().unwrap();
        assert_eq!(stored.arxiv_id, Some("2301.00001".to_string()));
    }

    /// A second paper with the same DOI but a fresh ID is merged into the
    /// first paper's row.
    #[test]
    fn test_doi_conflict_across_papers() {
        let (_, repo) = setup();

        let earlier = paper_with_doi("Original Paper", "10.1234/conflict");
        repo.save(&earlier).unwrap();

        let mut later = paper_with_doi("Updated Title", "10.1234/conflict");
        later.arxiv_id = Some("2301.99999".to_string());
        repo.save(&later).unwrap();

        let stored = repo.find_by_doi("10.1234/conflict").unwrap().unwrap();
        assert_eq!(stored.id, earlier.id, "should reuse original paper ID");
        assert_eq!(stored.title, "Updated Title", "should update title");
        assert_eq!(
            stored.arxiv_id,
            Some("2301.99999".to_string()),
            "should merge arxiv_id"
        );
    }

    /// The same DOI merge applies when the newcomer arrives via `save_many`.
    #[test]
    fn test_doi_conflict_in_save_many() {
        let (_, repo) = setup();

        let stored_first = paper_with_doi("Existing Paper", "10.1234/batch");
        repo.save(&stored_first).unwrap();

        let mut newcomer = paper_with_doi("Batch Paper", "10.1234/batch");
        newcomer.pubmed_id = Some("12345".to_string());
        repo.save_many(&[newcomer]).unwrap();

        let merged = repo.find_by_doi("10.1234/batch").unwrap().unwrap();
        assert_eq!(merged.id, stored_first.id);
        assert_eq!(merged.pubmed_id, Some("12345".to_string()));
    }

    /// `list_all` honours its limit.
    #[test]
    fn test_list_all() {
        let (_, repo) = setup();
        (0..5).for_each(|i| {
            let paper = Paper::new(format!("Paper {i}"));
            repo.save(&paper).unwrap();
        });

        let page = repo.list_all(3, 0).unwrap();
        assert_eq!(page.len(), 3);
    }

    /// Two searches are deduplicated independently and saved in turn; papers
    /// from the second search that share a DOI with stored ones must be
    /// remapped to the stored IDs.
    #[test]
    fn test_cross_search_dedup_save_roundtrip() {
        use scitadel_core::models::CandidatePaper;
        use scitadel_core::services::dedup::deduplicate;

        /// Finds the paper carrying `doi` in `papers`.
        fn by_doi<'a>(papers: &'a [Paper], doi: &str) -> &'a Paper {
            papers
                .iter()
                .find(|p| p.doi.as_deref() == Some(doi))
                .unwrap()
        }

        let (_, repo) = setup();

        // First search: three distinct DOIs, nothing stored yet.
        let candidates_1 = vec![
            CandidatePaper {
                doi: Some("10.1000/alpha".into()),
                ..CandidatePaper::new("openalex", "oa-1", "Alpha Paper")
            },
            CandidatePaper {
                doi: Some("10.1000/beta".into()),
                ..CandidatePaper::new("openalex", "oa-2", "Beta Paper")
            },
            CandidatePaper {
                doi: Some("10.1000/gamma".into()),
                ..CandidatePaper::new("pubmed", "pm-1", "Gamma Paper")
            },
        ];
        let (papers_1, _results_1) = deduplicate(&candidates_1, 0.85);
        assert_eq!(papers_1.len(), 3);
        let remap_1 = repo.save_many(&papers_1).unwrap();
        assert!(remap_1.is_empty(), "no conflicts on first save");

        // Second search: alpha and gamma reappear, delta is new.
        let candidates_2 = vec![
            CandidatePaper {
                doi: Some("10.1000/alpha".into()),
                arxiv_id: Some("2301.00001".into()),
                ..CandidatePaper::new("arxiv", "ax-1", "Alpha Paper (arxiv)")
            },
            CandidatePaper {
                doi: Some("10.1000/gamma".into()),
                pubmed_id: Some("99999".into()),
                ..CandidatePaper::new("pubmed", "pm-2", "Gamma Paper Revised")
            },
            CandidatePaper {
                doi: Some("10.1000/delta".into()),
                ..CandidatePaper::new("openalex", "oa-3", "Delta Paper")
            },
        ];
        let (papers_2, results_2) = deduplicate(&candidates_2, 0.85);
        assert_eq!(
            papers_2.len(),
            3,
            "dedup sees them as distinct (different IDs)"
        );

        let remap_2 = repo.save_many(&papers_2).unwrap();
        assert_eq!(
            remap_2.len(),
            2,
            "alpha and gamma should remap to existing IDs"
        );

        // The remap must point the new alpha at the originally stored alpha.
        let alpha_original = by_doi(&papers_1, "10.1000/alpha");
        let alpha_new = by_doi(&papers_2, "10.1000/alpha");
        assert_eq!(remap_2[&alpha_new.id], alpha_original.id);

        let everything = repo.list_all(100, 0).unwrap();
        assert_eq!(everything.len(), 4, "3 from first search + 1 new from second");

        let alpha_stored = repo.find_by_doi("10.1000/alpha").unwrap().unwrap();
        assert_eq!(alpha_stored.id, alpha_original.id, "kept original ID");
        assert_eq!(
            alpha_stored.arxiv_id,
            Some("2301.00001".into()),
            "merged arxiv_id from second search"
        );

        // Every search result must resolve to a stored paper after remapping.
        for sr in &results_2 {
            let resolved_id = remap_2.get(&sr.paper_id).unwrap_or(&sr.paper_id);
            assert!(
                repo.get(resolved_id.as_str()).unwrap().is_some(),
                "remapped paper_id should exist in DB"
            );
        }
    }

    /// Download bookkeeping starts empty, records a success, then a failure
    /// that clears the local path.
    #[test]
    fn download_state_round_trips() {
        let (_, repo) = setup();
        let paper = Paper::new("DL state");
        repo.save(&paper).unwrap();
        let id = paper.id.as_str();

        let fresh = repo.get(id).unwrap().unwrap();
        assert!(fresh.local_path.is_none());
        assert!(fresh.download_status.is_none());
        assert!(fresh.last_attempt_at.is_none());

        repo.update_download_state(id, Some("/tmp/foo.pdf"), DownloadStatus::Downloaded)
            .unwrap();
        let downloaded = repo.get(id).unwrap().unwrap();
        assert_eq!(downloaded.local_path.as_deref(), Some("/tmp/foo.pdf"));
        assert_eq!(downloaded.download_status, Some(DownloadStatus::Downloaded));
        assert!(downloaded.last_attempt_at.is_some());

        repo.update_download_state(id, None, DownloadStatus::Failed)
            .unwrap();
        let failed = repo.get(id).unwrap().unwrap();
        assert!(failed.local_path.is_none());
        assert_eq!(failed.download_status, Some(DownloadStatus::Failed));
    }
}