1use std::collections::HashMap;
4use std::sync::{Arc, Mutex};
5
6use rusqlite::{Connection, params};
7use seshat_core::{BranchId, ProjectFile};
8
9use super::symbol_index_repository::{
10 delete_definitions, delete_imports, insert_definitions, insert_imports,
11};
12use super::{FileIRRepository, extract_definitions, extract_imports};
13use crate::StorageError;
14use crate::ir_serialization::IR_SCHEMA_VERSION;
15
16#[derive(Debug, Clone)]
18pub struct SqliteFileIRRepository {
19 conn: Arc<Mutex<Connection>>,
20}
21
22impl SqliteFileIRRepository {
23 pub fn new(conn: Arc<Mutex<Connection>>) -> Self {
25 Self { conn }
26 }
27
28 fn conn(&self) -> Result<std::sync::MutexGuard<'_, Connection>, StorageError> {
29 self.conn.lock().map_err(|e| {
30 StorageError::QueryError(format!("Failed to acquire connection lock: {e}"))
31 })
32 }
33}
34
35impl FileIRRepository for SqliteFileIRRepository {
36 fn upsert(
37 &self,
38 branch_id: &BranchId,
39 file: &ProjectFile,
40 last_commit_date: Option<i64>,
41 ) -> Result<(), StorageError> {
42 let conn = self.conn()?;
43
44 let file_path = file.path.to_string_lossy();
45 let ir_data = crate::ir_serialization::serialize_ir(file)?;
46
47 conn.execute(
48 "INSERT INTO files_ir (branch_id, file_path, language, content_hash, ir_data, ir_schema_version, last_commit_date, updated_at)
49 VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, datetime('now'))
50 ON CONFLICT(branch_id, file_path) DO UPDATE SET
51 language = excluded.language,
52 content_hash = excluded.content_hash,
53 ir_data = excluded.ir_data,
54 ir_schema_version = excluded.ir_schema_version,
55 last_commit_date = excluded.last_commit_date,
56 updated_at = datetime('now')",
57 params![
58 branch_id.0,
59 file_path.as_ref(),
60 file.language.as_str(),
61 file.content_hash,
62 ir_data,
63 i64::from(IR_SCHEMA_VERSION),
64 last_commit_date,
65 ],
66 )?;
67
68 Ok(())
69 }
70
71 fn upsert_with_symbol_index(
72 &self,
73 branch_id: &BranchId,
74 file: &ProjectFile,
75 last_commit_date: Option<i64>,
76 ) -> Result<(), StorageError> {
77 let definitions = extract_definitions(file);
78 let imports = extract_imports(file);
79 let file_path = file.path.to_string_lossy();
80 let ir_data = crate::ir_serialization::serialize_ir(file)?;
81
82 let conn = self.conn()?;
83 let tx = conn.unchecked_transaction().map_err(|e| {
84 StorageError::QueryError(format!("begin files_ir+symbol-index tx: {e}"))
85 })?;
86
87 tx.execute(
88 "INSERT INTO files_ir (branch_id, file_path, language, content_hash, ir_data, ir_schema_version, last_commit_date, updated_at)
89 VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, datetime('now'))
90 ON CONFLICT(branch_id, file_path) DO UPDATE SET
91 language = excluded.language,
92 content_hash = excluded.content_hash,
93 ir_data = excluded.ir_data,
94 ir_schema_version = excluded.ir_schema_version,
95 last_commit_date = excluded.last_commit_date,
96 updated_at = datetime('now')",
97 params![
98 branch_id.0,
99 file_path.as_ref(),
100 file.language.as_str(),
101 file.content_hash,
102 ir_data,
103 i64::from(IR_SCHEMA_VERSION),
104 last_commit_date,
105 ],
106 )?;
107
108 delete_definitions(&tx, &branch_id.0, file_path.as_ref())?;
109 delete_imports(&tx, &branch_id.0, file_path.as_ref())?;
110 insert_definitions(&tx, &branch_id.0, &definitions)?;
111 insert_imports(&tx, &branch_id.0, &imports)?;
112
113 tx.commit().map_err(|e| {
114 StorageError::QueryError(format!("commit files_ir+symbol-index tx: {e}"))
115 })?;
116 Ok(())
117 }
118
119 fn get_by_path(
120 &self,
121 branch_id: &BranchId,
122 file_path: &str,
123 ) -> Result<ProjectFile, StorageError> {
124 let conn = self.conn()?;
125
126 conn.query_row(
127 "SELECT ir_data FROM files_ir WHERE branch_id = ?1 AND file_path = ?2",
128 params![branch_id.0, file_path],
129 row_to_project_file,
130 )
131 .map_err(|e| match e {
132 rusqlite::Error::QueryReturnedNoRows => StorageError::NotFound {
133 entity: "FileIR",
134 id: format!("{}/{}", branch_id.0, file_path),
135 },
136 other => StorageError::from(other),
137 })
138 }
139
140 fn get_by_branch(&self, branch_id: &BranchId) -> Result<Vec<ProjectFile>, StorageError> {
141 let conn = self.conn()?;
142
143 let mut stmt = conn.prepare("SELECT ir_data FROM files_ir WHERE branch_id = ?1")?;
144
145 let rows = stmt.query_map(params![branch_id.0], row_to_project_file)?;
146
147 rows.collect::<Result<Vec<_>, _>>().map_err(Into::into)
148 }
149
150 fn get_file_hashes_by_branch(
151 &self,
152 branch_id: &BranchId,
153 ) -> Result<HashMap<String, String>, StorageError> {
154 let conn = self.conn()?;
155
156 let mut stmt = conn.prepare(
160 "SELECT file_path, content_hash FROM files_ir
161 WHERE branch_id = ?1 AND ir_schema_version = ?2",
162 )?;
163
164 let rows = stmt.query_map(params![branch_id.0, i64::from(IR_SCHEMA_VERSION)], |row| {
165 Ok((row.get::<_, String>(0)?, row.get::<_, String>(1)?))
166 })?;
167
168 rows.collect::<Result<HashMap<_, _>, _>>()
169 .map_err(Into::into)
170 }
171
172 fn delete_by_path(&self, branch_id: &BranchId, file_path: &str) -> Result<(), StorageError> {
173 let conn = self.conn()?;
174
175 let affected = conn.execute(
176 "DELETE FROM files_ir WHERE branch_id = ?1 AND file_path = ?2",
177 params![branch_id.0, file_path],
178 )?;
179
180 if affected == 0 {
181 return Err(StorageError::NotFound {
182 entity: "FileIR",
183 id: format!("{}/{}", branch_id.0, file_path),
184 });
185 }
186
187 Ok(())
188 }
189
190 fn delete_with_symbol_index(
191 &self,
192 branch_id: &BranchId,
193 file_path: &str,
194 ) -> Result<(), StorageError> {
195 let conn = self.conn()?;
196 let tx = conn.unchecked_transaction().map_err(|e| {
197 StorageError::QueryError(format!("begin files_ir+symbol-index delete tx: {e}"))
198 })?;
199
200 let affected = tx.execute(
201 "DELETE FROM files_ir WHERE branch_id = ?1 AND file_path = ?2",
202 params![branch_id.0, file_path],
203 )?;
204
205 delete_definitions(&tx, &branch_id.0, file_path)?;
208 delete_imports(&tx, &branch_id.0, file_path)?;
209
210 tx.commit().map_err(|e| {
211 StorageError::QueryError(format!("commit files_ir+symbol-index delete tx: {e}"))
212 })?;
213
214 if affected == 0 {
215 return Err(StorageError::NotFound {
216 entity: "FileIR",
217 id: format!("{}/{}", branch_id.0, file_path),
218 });
219 }
220
221 Ok(())
222 }
223
224 fn check_content_hash(
225 &self,
226 branch_id: &BranchId,
227 file_path: &str,
228 content_hash: &str,
229 ) -> Result<bool, StorageError> {
230 let conn = self.conn()?;
231
232 let result: Result<String, _> = conn.query_row(
233 "SELECT content_hash FROM files_ir WHERE branch_id = ?1 AND file_path = ?2",
234 params![branch_id.0, file_path],
235 |row| row.get(0),
236 );
237
238 match result {
239 Ok(stored_hash) => Ok(stored_hash == content_hash),
240 Err(rusqlite::Error::QueryReturnedNoRows) => Ok(false),
241 Err(e) => Err(e.into()),
242 }
243 }
244
245 fn get_file_dates_by_branch(
246 &self,
247 branch_id: &BranchId,
248 ) -> Result<HashMap<String, Option<i64>>, StorageError> {
249 let conn = self.conn()?;
250
251 let mut stmt =
252 conn.prepare("SELECT file_path, last_commit_date FROM files_ir WHERE branch_id = ?1")?;
253
254 let rows = stmt.query_map(params![branch_id.0], |row| {
255 Ok((row.get::<_, String>(0)?, row.get::<_, Option<i64>>(1)?))
256 })?;
257
258 rows.collect::<Result<HashMap<_, _>, _>>()
259 .map_err(Into::into)
260 }
261
262 fn update_convention_compliance_counts(
263 &self,
264 branch_id: &BranchId,
265 counts: &HashMap<String, u32>,
266 ) -> Result<(), StorageError> {
267 let conn = self.conn()?;
268
269 conn.execute(
271 "UPDATE files_ir SET convention_compliance_count = 0 WHERE branch_id = ?1",
272 params![branch_id.0],
273 )?;
274
275 let mut stmt = conn.prepare(
277 "UPDATE files_ir SET convention_compliance_count = ?1
278 WHERE branch_id = ?2 AND file_path = ?3",
279 )?;
280
281 for (file_path, count) in counts {
282 stmt.execute(params![count, branch_id.0, file_path])?;
283 }
284
285 Ok(())
286 }
287}
288
289fn row_to_project_file(row: &rusqlite::Row<'_>) -> rusqlite::Result<ProjectFile> {
295 let ir_data: Vec<u8> = row.get(0)?;
296 crate::ir_serialization::deserialize_ir(&ir_data).map_err(|e| {
297 rusqlite::Error::FromSqlConversionFailure(0, rusqlite::types::Type::Blob, Box::new(e))
298 })
299}
300
301#[cfg(test)]
302mod tests {
303 use super::*;
304 use crate::Database;
305 use seshat_core::Language;
306 use seshat_core::test_helpers::make_project_file;
307
308 fn test_repo() -> SqliteFileIRRepository {
310 let db = Database::open(":memory:").expect("in-memory DB");
311 SqliteFileIRRepository::new(db.connection().clone())
312 }
313
314 #[test]
315 fn upsert_insert_and_get_by_path() {
316 let repo = test_repo();
317 let branch = BranchId::from("main");
318 let mut file = make_project_file(Language::Rust);
319 file.path = "src/main.rs".into();
320 file.content_hash = "abc123".to_string();
321
322 repo.upsert(&branch, &file, None)
323 .expect("upsert should succeed");
324
325 let fetched = repo
326 .get_by_path(&branch, "src/main.rs")
327 .expect("get_by_path should succeed");
328 assert_eq!(fetched.path.to_string_lossy(), "src/main.rs");
329 assert_eq!(fetched.language, Language::Rust);
330 assert_eq!(fetched.content_hash, "abc123");
331 }
332
333 #[test]
334 fn upsert_updates_existing() {
335 let repo = test_repo();
336 let branch = BranchId::from("main");
337 let mut file = make_project_file(Language::Rust);
338 file.path = "src/lib.rs".into();
339 file.content_hash = "hash_v1".to_string();
340
341 repo.upsert(&branch, &file, None).expect("first upsert");
342
343 file.content_hash = "hash_v2".to_string();
345 repo.upsert(&branch, &file, None).expect("second upsert");
346
347 let fetched = repo.get_by_path(&branch, "src/lib.rs").unwrap();
348 assert_eq!(fetched.content_hash, "hash_v2");
349
350 let all = repo.get_by_branch(&branch).unwrap();
352 assert_eq!(all.len(), 1);
353 }
354
355 #[test]
356 fn get_by_path_not_found() {
357 let repo = test_repo();
358 let branch = BranchId::from("main");
359 let result = repo.get_by_path(&branch, "nonexistent.rs");
360 assert!(matches!(result, Err(StorageError::NotFound { .. })));
361 }
362
363 #[test]
364 fn get_by_branch() {
365 let repo = test_repo();
366 let branch_a = BranchId::from("branch-a");
367 let branch_b = BranchId::from("branch-b");
368
369 let mut f1 = make_project_file(Language::Rust);
370 f1.path = "src/one.rs".into();
371 f1.content_hash = "h1".to_string();
372
373 let mut f2 = make_project_file(Language::Python);
374 f2.path = "src/two.py".into();
375 f2.content_hash = "h2".to_string();
376
377 let mut f3 = make_project_file(Language::TypeScript);
378 f3.path = "src/three.ts".into();
379 f3.content_hash = "h3".to_string();
380
381 repo.upsert(&branch_a, &f1, None).unwrap();
382 repo.upsert(&branch_a, &f2, None).unwrap();
383 repo.upsert(&branch_b, &f3, None).unwrap();
384
385 let a_files = repo.get_by_branch(&branch_a).unwrap();
386 assert_eq!(a_files.len(), 2);
387
388 let b_files = repo.get_by_branch(&branch_b).unwrap();
389 assert_eq!(b_files.len(), 1);
390 assert_eq!(b_files[0].language, Language::TypeScript);
391 }
392
393 #[test]
394 fn delete_by_path() {
395 let repo = test_repo();
396 let branch = BranchId::from("main");
397 let mut file = make_project_file(Language::Rust);
398 file.path = "src/delete_me.rs".into();
399 file.content_hash = "d1".to_string();
400
401 repo.upsert(&branch, &file, None).unwrap();
402 repo.delete_by_path(&branch, "src/delete_me.rs")
403 .expect("delete should succeed");
404
405 let result = repo.get_by_path(&branch, "src/delete_me.rs");
406 assert!(matches!(result, Err(StorageError::NotFound { .. })));
407 }
408
409 #[test]
410 fn delete_by_path_not_found() {
411 let repo = test_repo();
412 let branch = BranchId::from("main");
413 let result = repo.delete_by_path(&branch, "nonexistent.rs");
414 assert!(matches!(result, Err(StorageError::NotFound { .. })));
415 }
416
417 #[test]
418 fn check_content_hash_matches() {
419 let repo = test_repo();
420 let branch = BranchId::from("main");
421 let mut file = make_project_file(Language::Rust);
422 file.path = "src/check.rs".into();
423 file.content_hash = "correct_hash".to_string();
424
425 repo.upsert(&branch, &file, None).unwrap();
426
427 assert!(
428 repo.check_content_hash(&branch, "src/check.rs", "correct_hash")
429 .unwrap()
430 );
431 }
432
433 #[test]
434 fn check_content_hash_mismatch() {
435 let repo = test_repo();
436 let branch = BranchId::from("main");
437 let mut file = make_project_file(Language::Rust);
438 file.path = "src/check.rs".into();
439 file.content_hash = "hash_a".to_string();
440
441 repo.upsert(&branch, &file, None).unwrap();
442
443 assert!(
444 !repo
445 .check_content_hash(&branch, "src/check.rs", "hash_b")
446 .unwrap()
447 );
448 }
449
450 #[test]
451 fn check_content_hash_no_record() {
452 let repo = test_repo();
453 let branch = BranchId::from("main");
454
455 assert!(
456 !repo
457 .check_content_hash(&branch, "nonexistent.rs", "any_hash")
458 .unwrap()
459 );
460 }
461
462 #[test]
463 fn all_language_variants_roundtrip() {
464 let repo = test_repo();
465 let branch = BranchId::from("main");
466
467 let languages = [
468 Language::Rust,
469 Language::TypeScript,
470 Language::JavaScript,
471 Language::Python,
472 ];
473
474 for lang in languages {
475 let mut file = make_project_file(lang);
476 file.path = format!("test.{}", lang.extensions()[0]).into();
477 file.content_hash = format!("hash_{lang}");
478
479 repo.upsert(&branch, &file, None).unwrap();
480
481 let fetched = repo
482 .get_by_path(&branch, &file.path.to_string_lossy())
483 .unwrap();
484 assert_eq!(
485 fetched.language, lang,
486 "language roundtrip failed for {lang}"
487 );
488 }
489 }
490
491 #[test]
492 fn get_file_hashes_by_branch_returns_all_hashes() {
493 let repo = test_repo();
494 let branch = BranchId::from("main");
495
496 let mut f1 = make_project_file(Language::Rust);
497 f1.path = "src/main.rs".into();
498 f1.content_hash = "hash_main".to_string();
499
500 let mut f2 = make_project_file(Language::Rust);
501 f2.path = "src/lib.rs".into();
502 f2.content_hash = "hash_lib".to_string();
503
504 repo.upsert(&branch, &f1, None).unwrap();
505 repo.upsert(&branch, &f2, None).unwrap();
506
507 let hashes = repo.get_file_hashes_by_branch(&branch).unwrap();
508 assert_eq!(hashes.len(), 2);
509 assert_eq!(hashes.get("src/main.rs").unwrap(), "hash_main");
510 assert_eq!(hashes.get("src/lib.rs").unwrap(), "hash_lib");
511 }
512
513 #[test]
514 fn get_file_hashes_by_branch_empty() {
515 let repo = test_repo();
516 let branch = BranchId::from("empty-branch");
517
518 let hashes = repo.get_file_hashes_by_branch(&branch).unwrap();
519 assert!(hashes.is_empty());
520 }
521
522 #[test]
523 fn get_file_hashes_by_branch_isolates_branches() {
524 let repo = test_repo();
525 let branch_a = BranchId::from("branch-a");
526 let branch_b = BranchId::from("branch-b");
527
528 let mut f1 = make_project_file(Language::Rust);
529 f1.path = "src/a.rs".into();
530 f1.content_hash = "hash_a".to_string();
531
532 let mut f2 = make_project_file(Language::Python);
533 f2.path = "src/b.py".into();
534 f2.content_hash = "hash_b".to_string();
535
536 repo.upsert(&branch_a, &f1, None).unwrap();
537 repo.upsert(&branch_b, &f2, None).unwrap();
538
539 let a_hashes = repo.get_file_hashes_by_branch(&branch_a).unwrap();
540 assert_eq!(a_hashes.len(), 1);
541 assert!(a_hashes.contains_key("src/a.rs"));
542
543 let b_hashes = repo.get_file_hashes_by_branch(&branch_b).unwrap();
544 assert_eq!(b_hashes.len(), 1);
545 assert!(b_hashes.contains_key("src/b.py"));
546 }
547
548 #[test]
549 fn upsert_stores_last_commit_date() {
550 let repo = test_repo();
551 let branch = BranchId::from("main");
552
553 let mut file = make_project_file(Language::Rust);
554 file.path = "src/dated.rs".into();
555 file.content_hash = "hash_dated".to_string();
556
557 let timestamp = 1_700_000_000_i64;
558 repo.upsert(&branch, &file, Some(timestamp)).unwrap();
559
560 let dates = repo.get_file_dates_by_branch(&branch).unwrap();
561 assert_eq!(dates.len(), 1);
562 assert_eq!(dates.get("src/dated.rs").unwrap(), &Some(timestamp));
563 }
564
565 #[test]
566 fn upsert_with_none_date() {
567 let repo = test_repo();
568 let branch = BranchId::from("main");
569
570 let mut file = make_project_file(Language::Rust);
571 file.path = "src/no_date.rs".into();
572 file.content_hash = "hash_nodate".to_string();
573
574 repo.upsert(&branch, &file, None).unwrap();
575
576 let dates = repo.get_file_dates_by_branch(&branch).unwrap();
577 assert_eq!(dates.len(), 1);
578 assert_eq!(dates.get("src/no_date.rs").unwrap(), &None);
579 }
580
581 #[test]
582 fn get_file_dates_by_branch_empty() {
583 let repo = test_repo();
584 let branch = BranchId::from("empty-branch");
585
586 let dates = repo.get_file_dates_by_branch(&branch).unwrap();
587 assert!(dates.is_empty());
588 }
589
590 #[test]
591 fn update_convention_compliance_counts_sets_values() {
592 let repo = test_repo();
593 let branch = BranchId::from("main");
594
595 let mut f1 = make_project_file(Language::Rust);
596 f1.path = "src/good.rs".into();
597 f1.content_hash = "h1".to_string();
598
599 let mut f2 = make_project_file(Language::Rust);
600 f2.path = "src/ok.rs".into();
601 f2.content_hash = "h2".to_string();
602
603 repo.upsert(&branch, &f1, None).unwrap();
604 repo.upsert(&branch, &f2, None).unwrap();
605
606 let mut counts = HashMap::new();
607 counts.insert("src/good.rs".to_string(), 5);
608 counts.insert("src/ok.rs".to_string(), 2);
609
610 repo.update_convention_compliance_counts(&branch, &counts)
611 .unwrap();
612
613 let conn = repo.conn.lock().unwrap();
615 let count: u32 = conn
616 .query_row(
617 "SELECT convention_compliance_count FROM files_ir WHERE branch_id = ?1 AND file_path = ?2",
618 params![branch.0, "src/good.rs"],
619 |row| row.get(0),
620 )
621 .unwrap();
622 assert_eq!(count, 5);
623
624 let count: u32 = conn
625 .query_row(
626 "SELECT convention_compliance_count FROM files_ir WHERE branch_id = ?1 AND file_path = ?2",
627 params![branch.0, "src/ok.rs"],
628 |row| row.get(0),
629 )
630 .unwrap();
631 assert_eq!(count, 2);
632 }
633
634 #[test]
635 fn update_convention_compliance_counts_resets_missing_files() {
636 let repo = test_repo();
637 let branch = BranchId::from("main");
638
639 let mut f1 = make_project_file(Language::Rust);
640 f1.path = "src/a.rs".into();
641 f1.content_hash = "h1".to_string();
642
643 let mut f2 = make_project_file(Language::Rust);
644 f2.path = "src/b.rs".into();
645 f2.content_hash = "h2".to_string();
646
647 repo.upsert(&branch, &f1, None).unwrap();
648 repo.upsert(&branch, &f2, None).unwrap();
649
650 let mut counts = HashMap::new();
652 counts.insert("src/a.rs".to_string(), 3);
653 counts.insert("src/b.rs".to_string(), 7);
654 repo.update_convention_compliance_counts(&branch, &counts)
655 .unwrap();
656
657 let mut counts2 = HashMap::new();
659 counts2.insert("src/a.rs".to_string(), 4);
660 repo.update_convention_compliance_counts(&branch, &counts2)
661 .unwrap();
662
663 let conn = repo.conn.lock().unwrap();
664 let count_a: u32 = conn
665 .query_row(
666 "SELECT convention_compliance_count FROM files_ir WHERE branch_id = ?1 AND file_path = ?2",
667 params![branch.0, "src/a.rs"],
668 |row| row.get(0),
669 )
670 .unwrap();
671 assert_eq!(count_a, 4);
672
673 let count_b: u32 = conn
674 .query_row(
675 "SELECT convention_compliance_count FROM files_ir WHERE branch_id = ?1 AND file_path = ?2",
676 params![branch.0, "src/b.rs"],
677 |row| row.get(0),
678 )
679 .unwrap();
680 assert_eq!(count_b, 0, "file not in counts map should be reset to 0");
681 }
682}