Skip to main content

scitadel_db/sqlite/
searches.rs

1use rusqlite::{OptionalExtension, params};
2use scitadel_core::error::CoreError;
3use scitadel_core::models::{PaperId, Search, SearchId, SearchResult, SourceOutcome};
4use scitadel_core::ports::SearchRepository;
5
6use super::Database;
7use crate::error::DbError;
8
9pub struct SqliteSearchRepository {
10    db: Database,
11}
12
13impl SqliteSearchRepository {
14    pub fn new(db: Database) -> Self {
15        Self { db }
16    }
17
18    /// Full-text search over past search queries using the `searches_fts`
19    /// FTS5 index (migration 006). Returns `(Search, rank)` tuples where
20    /// lower rank = more relevant (bm25 convention). Input is sanitized of
21    /// FTS5 operators so arbitrary user strings don't raise syntax errors.
22    pub fn find_similar(&self, query: &str, limit: i64) -> Result<Vec<(Search, f64)>, DbError> {
23        let sanitized = sanitize_fts5_query(query);
24        if sanitized.is_empty() {
25            return Ok(Vec::new());
26        }
27        let conn = self.db.conn()?;
28        let mut stmt = conn
29            .prepare(
30                "SELECT s.*, bm25(searches_fts) AS rank
31                 FROM searches_fts f
32                 JOIN searches s ON s.id = f.search_id
33                 WHERE searches_fts MATCH ?1
34                 ORDER BY rank ASC
35                 LIMIT ?2",
36            )
37            .map_err(DbError::Sqlite)?;
38        let rows = stmt
39            .query_map(params![sanitized, limit], |row| {
40                let search = row_to_search(row)?;
41                let rank: f64 = row.get("rank")?;
42                Ok((search, rank))
43            })
44            .map_err(DbError::Sqlite)?;
45        let out: Vec<(Search, f64)> = rows.filter_map(Result::ok).collect();
46        Ok(out)
47    }
48}
49
/// Reduce arbitrary user input to a safe FTS5 query: whitespace-separated
/// barewords, which FTS5 combines with an implicit AND.
///
/// Outside a quoted string, FTS5 accepts only ASCII letters, ASCII digits,
/// `_` and non-ASCII characters inside a bareword. Every other ASCII
/// character is either query syntax (`"`, `(`, `)`, `{`, `}`, `:`, `,`, `+`,
/// `*`, `-`, `^`) or rejected outright (e.g. `.`, `?`, `'`, `/`), which shows
/// up as `fts5: syntax error near …`. All of them are turned into separators
/// here.
///
/// The upper-case barewords `AND`, `OR` and `NOT` are FTS5 operators
/// (case-sensitive), so they are lowercased and searched as ordinary terms.
/// This relies on the index tokenizer folding case, which FTS5's default
/// tokenizers do.
///
/// Operator support is lost in exchange for robustness. Returns an empty
/// string when no searchable characters remain.
fn sanitize_fts5_query(q: &str) -> String {
    // Characters allowed in an unquoted FTS5 bareword.
    let is_bareword_char = |c: char| c.is_ascii_alphanumeric() || c == '_' || !c.is_ascii();

    let cleaned: String = q
        .chars()
        .map(|c| if is_bareword_char(c) { c } else { ' ' })
        .collect();

    cleaned
        .split_whitespace()
        .map(|word| match word {
            // Neutralize the case-sensitive boolean operators.
            "AND" => "and",
            "OR" => "or",
            "NOT" => "not",
            other => other,
        })
        .collect::<Vec<_>>()
        .join(" ")
}
66
67fn row_to_search(row: &rusqlite::Row) -> rusqlite::Result<Search> {
68    let id: String = row.get("id")?;
69    let sources_json: String = row.get("sources")?;
70    let parameters_json: String = row.get("parameters")?;
71    let outcomes_json: String = row.get("source_outcomes")?;
72    let created_at: String = row.get("created_at")?;
73
74    let outcomes: Vec<SourceOutcome> = serde_json::from_str(&outcomes_json).unwrap_or_default();
75
76    Ok(Search {
77        id: SearchId::from(id),
78        query: row.get("query")?,
79        sources: serde_json::from_str(&sources_json).unwrap_or_default(),
80        parameters: serde_json::from_str(&parameters_json).unwrap_or_default(),
81        source_outcomes: outcomes,
82        total_candidates: row.get("total_candidates")?,
83        total_papers: row.get("total_papers")?,
84        created_at: super::parse_rfc3339_or_now(&created_at),
85    })
86}
87
88impl SearchRepository for SqliteSearchRepository {
89    fn save(&self, search: &Search) -> Result<(), CoreError> {
90        let conn = self.db.conn()?;
91        let outcomes_json = serde_json::to_string(&search.source_outcomes).unwrap_or_default();
92        conn.execute(
93            "INSERT INTO searches
94                (id, query, sources, parameters, source_outcomes,
95                 total_candidates, total_papers, created_at)
96             VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8)
97             ON CONFLICT(id) DO UPDATE SET
98                query = excluded.query,
99                sources = excluded.sources,
100                parameters = excluded.parameters,
101                source_outcomes = excluded.source_outcomes,
102                total_candidates = excluded.total_candidates,
103                total_papers = excluded.total_papers",
104            params![
105                search.id.as_str(),
106                search.query,
107                serde_json::to_string(&search.sources).unwrap_or_default(),
108                serde_json::to_string(&search.parameters).unwrap_or_default(),
109                outcomes_json,
110                search.total_candidates,
111                search.total_papers,
112                search.created_at.to_rfc3339(),
113            ],
114        )
115        .map_err(DbError::Sqlite)?;
116        Ok(())
117    }
118
119    fn get(&self, search_id: &str) -> Result<Option<Search>, CoreError> {
120        let conn = self.db.conn()?;
121        let mut stmt = conn
122            .prepare("SELECT * FROM searches WHERE id = ?1")
123            .map_err(DbError::Sqlite)?;
124        let result = stmt
125            .query_row(params![search_id], row_to_search)
126            .optional()
127            .map_err(DbError::Sqlite)?;
128        Ok(result)
129    }
130
131    fn save_results(&self, results: &[SearchResult]) -> Result<(), CoreError> {
132        let mut conn = self.db.conn()?;
133        let tx = conn.transaction().map_err(DbError::Sqlite)?;
134        for r in results {
135            tx.execute(
136                "INSERT INTO search_results
137                    (search_id, paper_id, source, rank, score, raw_metadata)
138                 VALUES (?1, ?2, ?3, ?4, ?5, ?6)
139                 ON CONFLICT(search_id, paper_id, source) DO UPDATE SET
140                    rank = excluded.rank,
141                    score = excluded.score,
142                    raw_metadata = excluded.raw_metadata",
143                params![
144                    r.search_id.as_str(),
145                    r.paper_id.as_str(),
146                    r.source,
147                    r.rank,
148                    r.score,
149                    serde_json::to_string(&r.raw_metadata).unwrap_or_default(),
150                ],
151            )
152            .map_err(DbError::Sqlite)?;
153        }
154        tx.commit().map_err(DbError::Sqlite)?;
155        Ok(())
156    }
157
158    fn get_results(&self, search_id: &str) -> Result<Vec<SearchResult>, CoreError> {
159        let conn = self.db.conn()?;
160        let mut stmt = conn
161            .prepare("SELECT * FROM search_results WHERE search_id = ?1")
162            .map_err(DbError::Sqlite)?;
163        let results = stmt
164            .query_map(params![search_id], |row| {
165                let search_id: String = row.get("search_id")?;
166                let paper_id: String = row.get("paper_id")?;
167                let raw_json: String = row.get("raw_metadata")?;
168                Ok(SearchResult {
169                    search_id: SearchId::from(search_id),
170                    paper_id: PaperId::from(paper_id),
171                    source: row.get("source")?,
172                    rank: row.get("rank")?,
173                    score: row.get("score")?,
174                    raw_metadata: serde_json::from_str(&raw_json).unwrap_or_default(),
175                })
176            })
177            .map_err(DbError::Sqlite)?
178            .filter_map(Result::ok)
179            .collect();
180        Ok(results)
181    }
182
183    fn list_searches(&self, limit: i64) -> Result<Vec<Search>, CoreError> {
184        let conn = self.db.conn()?;
185        let mut stmt = conn
186            .prepare("SELECT * FROM searches ORDER BY created_at DESC LIMIT ?1")
187            .map_err(DbError::Sqlite)?;
188        let searches = stmt
189            .query_map(params![limit], row_to_search)
190            .map_err(DbError::Sqlite)?
191            .filter_map(Result::ok)
192            .collect();
193        Ok(searches)
194    }
195
196    fn diff_searches(
197        &self,
198        search_id_a: &str,
199        search_id_b: &str,
200    ) -> Result<(Vec<String>, Vec<String>), CoreError> {
201        let conn = self.db.conn()?;
202
203        let get_paper_ids =
204            |search_id: &str| -> Result<std::collections::HashSet<String>, DbError> {
205                let mut stmt = conn
206                    .prepare("SELECT DISTINCT paper_id FROM search_results WHERE search_id = ?1")
207                    .map_err(DbError::Sqlite)?;
208                let ids: std::collections::HashSet<String> = stmt
209                    .query_map(params![search_id], |row| row.get(0))
210                    .map_err(DbError::Sqlite)?
211                    .filter_map(Result::ok)
212                    .collect();
213                Ok(ids)
214            };
215
216        let papers_a = get_paper_ids(search_id_a).map_err(Into::<CoreError>::into)?;
217        let papers_b = get_paper_ids(search_id_b).map_err(Into::<CoreError>::into)?;
218
219        let mut added: Vec<String> = papers_b.difference(&papers_a).cloned().collect();
220        let mut removed: Vec<String> = papers_a.difference(&papers_b).cloned().collect();
221        added.sort();
222        removed.sort();
223
224        Ok((added, removed))
225    }
226}
227
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sqlite::{Database, SqlitePaperRepository};
    use scitadel_core::models::Paper;
    use scitadel_core::ports::PaperRepository;

    /// Open a fresh in-memory database, run all migrations, and return it with
    /// a search repository and a paper repository that share it.
    fn setup() -> (Database, SqliteSearchRepository, SqlitePaperRepository) {
        let db = Database::open_in_memory().unwrap();
        db.migrate().unwrap();
        let search_repo = SqliteSearchRepository::new(db.clone());
        let paper_repo = SqlitePaperRepository::new(db.clone());
        (db, search_repo, paper_repo)
    }

    /// A `Search` for `query` with a fixed id, so assertions can name it.
    fn search_with_id(query: &str, id: &str) -> Search {
        let mut s = Search::new(query);
        s.id = SearchId::from(id);
        s
    }

    #[test]
    fn test_save_and_get_search() {
        let (_, repo, _) = setup();
        let search = Search::new("test query");
        repo.save(&search).unwrap();

        let loaded = repo.get(search.id.as_str()).unwrap().unwrap();
        assert_eq!(loaded.query, "test query");
    }

    #[test]
    fn fts_sanitizer_strips_operators() {
        assert_eq!(sanitize_fts5_query(r#"GAN "stability""#), "GAN stability");
        assert_eq!(sanitize_fts5_query("foo (bar) - baz"), "foo bar baz");
        assert_eq!(sanitize_fts5_query("   "), "");
        assert_eq!(sanitize_fts5_query("scope:field"), "scope field");
    }

    #[test]
    fn fts_find_similar_roundtrip() {
        let (_, repo, _) = setup();
        let a = search_with_id("generative adversarial networks stability", "s-a");
        let b = search_with_id("attention is all you need transformers", "s-b");
        let c = search_with_id("retrieval augmented generation", "s-c");
        repo.save(&a).unwrap();
        repo.save(&b).unwrap();
        repo.save(&c).unwrap();

        // FTS5 ANDs space-separated barewords, and only the GAN search
        // contains both "generative" and "networks".
        let hits = repo.find_similar("generative networks", 10).unwrap();
        let ids: Vec<&str> = hits.iter().map(|(s, _)| s.id.as_str()).collect();
        assert_eq!(ids, ["s-a"], "unexpected hits: {:?}", ids);
    }

    #[test]
    fn fts_find_similar_empty_query() {
        let (_, repo, _) = setup();
        repo.save(&Search::new("something")).unwrap();
        assert!(repo.find_similar("()(", 10).unwrap().is_empty());
    }

    #[test]
    fn test_save_and_get_results() {
        let (_, search_repo, paper_repo) = setup();

        let paper = Paper::new("Test Paper");
        paper_repo.save(&paper).unwrap();

        let search = Search::new("test");
        search_repo.save(&search).unwrap();

        let result = SearchResult {
            search_id: search.id.clone(),
            paper_id: paper.id.clone(),
            source: "pubmed".to_string(),
            rank: Some(1),
            score: Some(0.95),
            raw_metadata: serde_json::Value::Null,
        };
        search_repo.save_results(&[result]).unwrap();

        let results = search_repo.get_results(search.id.as_str()).unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].source, "pubmed");
    }

    #[test]
    fn test_diff_searches() {
        let (_, search_repo, paper_repo) = setup();

        let only_a = Paper::new("Only in first run");
        let shared = Paper::new("In both runs");
        let only_b = Paper::new("Only in second run");
        paper_repo.save(&only_a).unwrap();
        paper_repo.save(&shared).unwrap();
        paper_repo.save(&only_b).unwrap();

        let first = search_with_id("first run", "s-1");
        let second = search_with_id("second run", "s-2");
        search_repo.save(&first).unwrap();
        search_repo.save(&second).unwrap();

        let hit = |search: &Search, paper: &Paper| SearchResult {
            search_id: search.id.clone(),
            paper_id: paper.id.clone(),
            source: "pubmed".to_string(),
            rank: Some(1),
            score: Some(0.5),
            raw_metadata: serde_json::Value::Null,
        };
        search_repo
            .save_results(&[
                hit(&first, &only_a),
                hit(&first, &shared),
                hit(&second, &shared),
                hit(&second, &only_b),
            ])
            .unwrap();

        let (added, removed) = search_repo
            .diff_searches(first.id.as_str(), second.id.as_str())
            .unwrap();
        assert_eq!(added, vec![only_b.id.as_str().to_string()]);
        assert_eq!(removed, vec![only_a.id.as_str().to_string()]);
    }
}