Skip to main content

seshat_storage/repository/
file_ir_repository.rs

1//! SQLite implementation of [`FileIRRepository`].
2
3use std::collections::HashMap;
4use std::sync::{Arc, Mutex};
5
6use rusqlite::{Connection, params};
7use seshat_core::{BranchId, ProjectFile};
8
9use super::symbol_index_repository::{
10    delete_definitions, delete_imports, insert_definitions, insert_imports,
11};
12use super::{FileIRRepository, extract_definitions, extract_imports};
13use crate::StorageError;
14use crate::ir_serialization::IR_SCHEMA_VERSION;
15
16/// SQLite-backed file IR repository.
17#[derive(Debug, Clone)]
18pub struct SqliteFileIRRepository {
19    conn: Arc<Mutex<Connection>>,
20}
21
22impl SqliteFileIRRepository {
23    /// Create a new repository backed by the given connection.
24    pub fn new(conn: Arc<Mutex<Connection>>) -> Self {
25        Self { conn }
26    }
27
28    fn conn(&self) -> Result<std::sync::MutexGuard<'_, Connection>, StorageError> {
29        self.conn.lock().map_err(|e| {
30            StorageError::QueryError(format!("Failed to acquire connection lock: {e}"))
31        })
32    }
33}
34
35impl FileIRRepository for SqliteFileIRRepository {
36    fn upsert(
37        &self,
38        branch_id: &BranchId,
39        file: &ProjectFile,
40        last_commit_date: Option<i64>,
41    ) -> Result<(), StorageError> {
42        let conn = self.conn()?;
43
44        let file_path = file.path.to_string_lossy();
45        let ir_data = crate::ir_serialization::serialize_ir(file)?;
46
47        conn.execute(
48            "INSERT INTO files_ir (branch_id, file_path, language, content_hash, ir_data, ir_schema_version, last_commit_date, updated_at)
49             VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, datetime('now'))
50             ON CONFLICT(branch_id, file_path) DO UPDATE SET
51               language = excluded.language,
52               content_hash = excluded.content_hash,
53               ir_data = excluded.ir_data,
54               ir_schema_version = excluded.ir_schema_version,
55               last_commit_date = excluded.last_commit_date,
56               updated_at = datetime('now')",
57            params![
58                branch_id.0,
59                file_path.as_ref(),
60                file.language.as_str(),
61                file.content_hash,
62                ir_data,
63                i64::from(IR_SCHEMA_VERSION),
64                last_commit_date,
65            ],
66        )?;
67
68        Ok(())
69    }
70
71    fn upsert_with_symbol_index(
72        &self,
73        branch_id: &BranchId,
74        file: &ProjectFile,
75        last_commit_date: Option<i64>,
76    ) -> Result<(), StorageError> {
77        let definitions = extract_definitions(file);
78        let imports = extract_imports(file);
79        let file_path = file.path.to_string_lossy();
80        let ir_data = crate::ir_serialization::serialize_ir(file)?;
81
82        let conn = self.conn()?;
83        let tx = conn.unchecked_transaction().map_err(|e| {
84            StorageError::QueryError(format!("begin files_ir+symbol-index tx: {e}"))
85        })?;
86
87        tx.execute(
88            "INSERT INTO files_ir (branch_id, file_path, language, content_hash, ir_data, ir_schema_version, last_commit_date, updated_at)
89             VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, datetime('now'))
90             ON CONFLICT(branch_id, file_path) DO UPDATE SET
91               language = excluded.language,
92               content_hash = excluded.content_hash,
93               ir_data = excluded.ir_data,
94               ir_schema_version = excluded.ir_schema_version,
95               last_commit_date = excluded.last_commit_date,
96               updated_at = datetime('now')",
97            params![
98                branch_id.0,
99                file_path.as_ref(),
100                file.language.as_str(),
101                file.content_hash,
102                ir_data,
103                i64::from(IR_SCHEMA_VERSION),
104                last_commit_date,
105            ],
106        )?;
107
108        delete_definitions(&tx, &branch_id.0, file_path.as_ref())?;
109        delete_imports(&tx, &branch_id.0, file_path.as_ref())?;
110        insert_definitions(&tx, &branch_id.0, &definitions)?;
111        insert_imports(&tx, &branch_id.0, &imports)?;
112
113        tx.commit().map_err(|e| {
114            StorageError::QueryError(format!("commit files_ir+symbol-index tx: {e}"))
115        })?;
116        Ok(())
117    }
118
119    fn get_by_path(
120        &self,
121        branch_id: &BranchId,
122        file_path: &str,
123    ) -> Result<ProjectFile, StorageError> {
124        let conn = self.conn()?;
125
126        conn.query_row(
127            "SELECT ir_data FROM files_ir WHERE branch_id = ?1 AND file_path = ?2",
128            params![branch_id.0, file_path],
129            row_to_project_file,
130        )
131        .map_err(|e| match e {
132            rusqlite::Error::QueryReturnedNoRows => StorageError::NotFound {
133                entity: "FileIR",
134                id: format!("{}/{}", branch_id.0, file_path),
135            },
136            other => StorageError::from(other),
137        })
138    }
139
140    fn get_by_branch(&self, branch_id: &BranchId) -> Result<Vec<ProjectFile>, StorageError> {
141        let conn = self.conn()?;
142
143        let mut stmt = conn.prepare("SELECT ir_data FROM files_ir WHERE branch_id = ?1")?;
144
145        let rows = stmt.query_map(params![branch_id.0], row_to_project_file)?;
146
147        rows.collect::<Result<Vec<_>, _>>().map_err(Into::into)
148    }
149
150    fn get_file_hashes_by_branch(
151        &self,
152        branch_id: &BranchId,
153    ) -> Result<HashMap<String, String>, StorageError> {
154        let conn = self.conn()?;
155
156        // Only return hashes for files whose IR blob matches the current schema
157        // version. Stale blobs (older IR_SCHEMA_VERSION) are excluded so that
158        // the scanner re-parses them rather than skipping them as "unchanged".
159        let mut stmt = conn.prepare(
160            "SELECT file_path, content_hash FROM files_ir
161             WHERE branch_id = ?1 AND ir_schema_version = ?2",
162        )?;
163
164        let rows = stmt.query_map(params![branch_id.0, i64::from(IR_SCHEMA_VERSION)], |row| {
165            Ok((row.get::<_, String>(0)?, row.get::<_, String>(1)?))
166        })?;
167
168        rows.collect::<Result<HashMap<_, _>, _>>()
169            .map_err(Into::into)
170    }
171
172    fn delete_by_path(&self, branch_id: &BranchId, file_path: &str) -> Result<(), StorageError> {
173        let conn = self.conn()?;
174
175        let affected = conn.execute(
176            "DELETE FROM files_ir WHERE branch_id = ?1 AND file_path = ?2",
177            params![branch_id.0, file_path],
178        )?;
179
180        if affected == 0 {
181            return Err(StorageError::NotFound {
182                entity: "FileIR",
183                id: format!("{}/{}", branch_id.0, file_path),
184            });
185        }
186
187        Ok(())
188    }
189
190    fn delete_with_symbol_index(
191        &self,
192        branch_id: &BranchId,
193        file_path: &str,
194    ) -> Result<(), StorageError> {
195        let conn = self.conn()?;
196        let tx = conn.unchecked_transaction().map_err(|e| {
197            StorageError::QueryError(format!("begin files_ir+symbol-index delete tx: {e}"))
198        })?;
199
200        let affected = tx.execute(
201            "DELETE FROM files_ir WHERE branch_id = ?1 AND file_path = ?2",
202            params![branch_id.0, file_path],
203        )?;
204
205        // Always sweep the symbol-index, even when files_ir already missed —
206        // protects against orphan rows from any earlier non-atomic write.
207        delete_definitions(&tx, &branch_id.0, file_path)?;
208        delete_imports(&tx, &branch_id.0, file_path)?;
209
210        tx.commit().map_err(|e| {
211            StorageError::QueryError(format!("commit files_ir+symbol-index delete tx: {e}"))
212        })?;
213
214        if affected == 0 {
215            return Err(StorageError::NotFound {
216                entity: "FileIR",
217                id: format!("{}/{}", branch_id.0, file_path),
218            });
219        }
220
221        Ok(())
222    }
223
224    fn check_content_hash(
225        &self,
226        branch_id: &BranchId,
227        file_path: &str,
228        content_hash: &str,
229    ) -> Result<bool, StorageError> {
230        let conn = self.conn()?;
231
232        let result: Result<String, _> = conn.query_row(
233            "SELECT content_hash FROM files_ir WHERE branch_id = ?1 AND file_path = ?2",
234            params![branch_id.0, file_path],
235            |row| row.get(0),
236        );
237
238        match result {
239            Ok(stored_hash) => Ok(stored_hash == content_hash),
240            Err(rusqlite::Error::QueryReturnedNoRows) => Ok(false),
241            Err(e) => Err(e.into()),
242        }
243    }
244
245    fn get_file_dates_by_branch(
246        &self,
247        branch_id: &BranchId,
248    ) -> Result<HashMap<String, Option<i64>>, StorageError> {
249        let conn = self.conn()?;
250
251        let mut stmt =
252            conn.prepare("SELECT file_path, last_commit_date FROM files_ir WHERE branch_id = ?1")?;
253
254        let rows = stmt.query_map(params![branch_id.0], |row| {
255            Ok((row.get::<_, String>(0)?, row.get::<_, Option<i64>>(1)?))
256        })?;
257
258        rows.collect::<Result<HashMap<_, _>, _>>()
259            .map_err(Into::into)
260    }
261
262    fn update_convention_compliance_counts(
263        &self,
264        branch_id: &BranchId,
265        counts: &HashMap<String, u32>,
266    ) -> Result<(), StorageError> {
267        let conn = self.conn()?;
268
269        // Reset all counts for this branch first (files not in `counts` get 0).
270        conn.execute(
271            "UPDATE files_ir SET convention_compliance_count = 0 WHERE branch_id = ?1",
272            params![branch_id.0],
273        )?;
274
275        // Update each file's count.
276        let mut stmt = conn.prepare(
277            "UPDATE files_ir SET convention_compliance_count = ?1
278             WHERE branch_id = ?2 AND file_path = ?3",
279        )?;
280
281        for (file_path, count) in counts {
282            stmt.execute(params![count, branch_id.0, file_path])?;
283        }
284
285        Ok(())
286    }
287}
288
289// ---------------------------------------------------------------------------
290// Helpers
291// ---------------------------------------------------------------------------
292
293/// Map a rusqlite Row (ir_data BLOB) to a `ProjectFile`.
294fn row_to_project_file(row: &rusqlite::Row<'_>) -> rusqlite::Result<ProjectFile> {
295    let ir_data: Vec<u8> = row.get(0)?;
296    crate::ir_serialization::deserialize_ir(&ir_data).map_err(|e| {
297        rusqlite::Error::FromSqlConversionFailure(0, rusqlite::types::Type::Blob, Box::new(e))
298    })
299}
300
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Database;
    use seshat_core::Language;
    use seshat_core::test_helpers::make_project_file;

    /// Build a repository on top of a fresh in-memory database.
    fn test_repo() -> SqliteFileIRRepository {
        let db = Database::open(":memory:").expect("in-memory DB");
        SqliteFileIRRepository::new(db.connection().clone())
    }

    /// Build a fixture file for `language` with the given path and hash.
    fn file_at(language: Language, path: &str, content_hash: &str) -> ProjectFile {
        let mut file = make_project_file(language);
        file.path = path.into();
        file.content_hash = content_hash.to_string();
        file
    }

    /// Read `convention_compliance_count` for one file straight from the DB.
    fn stored_compliance_count(conn: &Connection, branch: &BranchId, file_path: &str) -> u32 {
        conn.query_row(
            "SELECT convention_compliance_count FROM files_ir WHERE branch_id = ?1 AND file_path = ?2",
            params![branch.0, file_path],
            |row| row.get(0),
        )
        .unwrap()
    }

    #[test]
    fn upsert_insert_and_get_by_path() {
        let repo = test_repo();
        let branch = BranchId::from("main");
        let file = file_at(Language::Rust, "src/main.rs", "abc123");

        repo.upsert(&branch, &file, None)
            .expect("upsert should succeed");

        let fetched = repo
            .get_by_path(&branch, "src/main.rs")
            .expect("get_by_path should succeed");
        assert_eq!(fetched.path.to_string_lossy(), "src/main.rs");
        assert_eq!(fetched.language, Language::Rust);
        assert_eq!(fetched.content_hash, "abc123");
    }

    #[test]
    fn upsert_updates_existing() {
        let repo = test_repo();
        let branch = BranchId::from("main");
        let mut file = file_at(Language::Rust, "src/lib.rs", "hash_v1");

        repo.upsert(&branch, &file, None).expect("first upsert");

        // Same path, new hash: must overwrite rather than add a row.
        file.content_hash = "hash_v2".to_string();
        repo.upsert(&branch, &file, None).expect("second upsert");

        let fetched = repo.get_by_path(&branch, "src/lib.rs").unwrap();
        assert_eq!(fetched.content_hash, "hash_v2");

        // Still exactly one record for this branch.
        let all = repo.get_by_branch(&branch).unwrap();
        assert_eq!(all.len(), 1);
    }

    #[test]
    fn get_by_path_not_found() {
        let repo = test_repo();
        let branch = BranchId::from("main");

        let result = repo.get_by_path(&branch, "nonexistent.rs");

        assert!(matches!(result, Err(StorageError::NotFound { .. })));
    }

    #[test]
    fn get_by_branch() {
        let repo = test_repo();
        let branch_a = BranchId::from("branch-a");
        let branch_b = BranchId::from("branch-b");

        let f1 = file_at(Language::Rust, "src/one.rs", "h1");
        let f2 = file_at(Language::Python, "src/two.py", "h2");
        let f3 = file_at(Language::TypeScript, "src/three.ts", "h3");

        repo.upsert(&branch_a, &f1, None).unwrap();
        repo.upsert(&branch_a, &f2, None).unwrap();
        repo.upsert(&branch_b, &f3, None).unwrap();

        let a_files = repo.get_by_branch(&branch_a).unwrap();
        assert_eq!(a_files.len(), 2);

        let b_files = repo.get_by_branch(&branch_b).unwrap();
        assert_eq!(b_files.len(), 1);
        assert_eq!(b_files[0].language, Language::TypeScript);
    }

    #[test]
    fn delete_by_path() {
        let repo = test_repo();
        let branch = BranchId::from("main");
        let file = file_at(Language::Rust, "src/delete_me.rs", "d1");

        repo.upsert(&branch, &file, None).unwrap();
        repo.delete_by_path(&branch, "src/delete_me.rs")
            .expect("delete should succeed");

        let result = repo.get_by_path(&branch, "src/delete_me.rs");
        assert!(matches!(result, Err(StorageError::NotFound { .. })));
    }

    #[test]
    fn delete_by_path_not_found() {
        let repo = test_repo();
        let branch = BranchId::from("main");

        let result = repo.delete_by_path(&branch, "nonexistent.rs");

        assert!(matches!(result, Err(StorageError::NotFound { .. })));
    }

    #[test]
    fn check_content_hash_matches() {
        let repo = test_repo();
        let branch = BranchId::from("main");
        let file = file_at(Language::Rust, "src/check.rs", "correct_hash");

        repo.upsert(&branch, &file, None).unwrap();

        let matches = repo
            .check_content_hash(&branch, "src/check.rs", "correct_hash")
            .unwrap();
        assert!(matches);
    }

    #[test]
    fn check_content_hash_mismatch() {
        let repo = test_repo();
        let branch = BranchId::from("main");
        let file = file_at(Language::Rust, "src/check.rs", "hash_a");

        repo.upsert(&branch, &file, None).unwrap();

        let matches = repo
            .check_content_hash(&branch, "src/check.rs", "hash_b")
            .unwrap();
        assert!(!matches);
    }

    #[test]
    fn check_content_hash_no_record() {
        let repo = test_repo();
        let branch = BranchId::from("main");

        let matches = repo
            .check_content_hash(&branch, "nonexistent.rs", "any_hash")
            .unwrap();
        assert!(!matches);
    }

    #[test]
    fn all_language_variants_roundtrip() {
        let repo = test_repo();
        let branch = BranchId::from("main");

        let languages = [
            Language::Rust,
            Language::TypeScript,
            Language::JavaScript,
            Language::Python,
        ];

        for lang in languages {
            let file = file_at(
                lang,
                &format!("test.{}", lang.extensions()[0]),
                &format!("hash_{lang}"),
            );

            repo.upsert(&branch, &file, None).unwrap();

            let fetched = repo
                .get_by_path(&branch, &file.path.to_string_lossy())
                .unwrap();
            assert_eq!(
                fetched.language, lang,
                "language roundtrip failed for {lang}"
            );
        }
    }

    #[test]
    fn get_file_hashes_by_branch_returns_all_hashes() {
        let repo = test_repo();
        let branch = BranchId::from("main");

        let f1 = file_at(Language::Rust, "src/main.rs", "hash_main");
        let f2 = file_at(Language::Rust, "src/lib.rs", "hash_lib");

        repo.upsert(&branch, &f1, None).unwrap();
        repo.upsert(&branch, &f2, None).unwrap();

        let hashes = repo.get_file_hashes_by_branch(&branch).unwrap();
        assert_eq!(hashes.len(), 2);
        assert_eq!(hashes.get("src/main.rs").unwrap(), "hash_main");
        assert_eq!(hashes.get("src/lib.rs").unwrap(), "hash_lib");
    }

    #[test]
    fn get_file_hashes_by_branch_empty() {
        let repo = test_repo();
        let branch = BranchId::from("empty-branch");

        let hashes = repo.get_file_hashes_by_branch(&branch).unwrap();

        assert!(hashes.is_empty());
    }

    #[test]
    fn get_file_hashes_by_branch_isolates_branches() {
        let repo = test_repo();
        let branch_a = BranchId::from("branch-a");
        let branch_b = BranchId::from("branch-b");

        let f1 = file_at(Language::Rust, "src/a.rs", "hash_a");
        let f2 = file_at(Language::Python, "src/b.py", "hash_b");

        repo.upsert(&branch_a, &f1, None).unwrap();
        repo.upsert(&branch_b, &f2, None).unwrap();

        let a_hashes = repo.get_file_hashes_by_branch(&branch_a).unwrap();
        assert_eq!(a_hashes.len(), 1);
        assert!(a_hashes.contains_key("src/a.rs"));

        let b_hashes = repo.get_file_hashes_by_branch(&branch_b).unwrap();
        assert_eq!(b_hashes.len(), 1);
        assert!(b_hashes.contains_key("src/b.py"));
    }

    #[test]
    fn upsert_stores_last_commit_date() {
        let repo = test_repo();
        let branch = BranchId::from("main");
        let file = file_at(Language::Rust, "src/dated.rs", "hash_dated");

        let timestamp = 1_700_000_000_i64;
        repo.upsert(&branch, &file, Some(timestamp)).unwrap();

        let dates = repo.get_file_dates_by_branch(&branch).unwrap();
        assert_eq!(dates.len(), 1);
        assert_eq!(dates.get("src/dated.rs"), Some(&Some(timestamp)));
    }

    #[test]
    fn upsert_with_none_date() {
        let repo = test_repo();
        let branch = BranchId::from("main");
        let file = file_at(Language::Rust, "src/no_date.rs", "hash_nodate");

        repo.upsert(&branch, &file, None).unwrap();

        let dates = repo.get_file_dates_by_branch(&branch).unwrap();
        assert_eq!(dates.len(), 1);
        assert_eq!(dates.get("src/no_date.rs"), Some(&None));
    }

    #[test]
    fn get_file_dates_by_branch_empty() {
        let repo = test_repo();
        let branch = BranchId::from("empty-branch");

        let dates = repo.get_file_dates_by_branch(&branch).unwrap();

        assert!(dates.is_empty());
    }

    #[test]
    fn update_convention_compliance_counts_sets_values() {
        let repo = test_repo();
        let branch = BranchId::from("main");

        let f1 = file_at(Language::Rust, "src/good.rs", "h1");
        let f2 = file_at(Language::Rust, "src/ok.rs", "h2");

        repo.upsert(&branch, &f1, None).unwrap();
        repo.upsert(&branch, &f2, None).unwrap();

        let counts = HashMap::from([
            ("src/good.rs".to_string(), 5),
            ("src/ok.rs".to_string(), 2),
        ]);

        repo.update_convention_compliance_counts(&branch, &counts)
            .unwrap();

        // Verify by querying the DB directly.
        let conn = repo.conn.lock().unwrap();
        assert_eq!(stored_compliance_count(&conn, &branch, "src/good.rs"), 5);
        assert_eq!(stored_compliance_count(&conn, &branch, "src/ok.rs"), 2);
    }

    #[test]
    fn update_convention_compliance_counts_resets_missing_files() {
        let repo = test_repo();
        let branch = BranchId::from("main");

        let f1 = file_at(Language::Rust, "src/a.rs", "h1");
        let f2 = file_at(Language::Rust, "src/b.rs", "h2");

        repo.upsert(&branch, &f1, None).unwrap();
        repo.upsert(&branch, &f2, None).unwrap();

        // First update: both files have counts.
        let counts = HashMap::from([
            ("src/a.rs".to_string(), 3),
            ("src/b.rs".to_string(), 7),
        ]);
        repo.update_convention_compliance_counts(&branch, &counts)
            .unwrap();

        // Second update: only a.rs has a count, so b.rs should reset to 0.
        let counts2 = HashMap::from([("src/a.rs".to_string(), 4)]);
        repo.update_convention_compliance_counts(&branch, &counts2)
            .unwrap();

        let conn = repo.conn.lock().unwrap();
        let count_a = stored_compliance_count(&conn, &branch, "src/a.rs");
        assert_eq!(count_a, 4);

        let count_b = stored_compliance_count(&conn, &branch, "src/b.rs");
        assert_eq!(count_b, 0, "file not in counts map should be reset to 0");
    }
}