1use anyhow::Result;
2use serde::{Deserialize, Serialize};
3use std::fs;
4use std::io::Write;
5use std::path::{Path, PathBuf};
6
7#[derive(Serialize, Deserialize, Debug)]
8pub enum WalEntry {
9 AddBegin {
10 path: String,
11 },
12 AddEnd,
13 CommitBegin {
14 head_backup: Option<String>,
15 index_backup: Vec<u8>,
16 },
17 CommitEnd,
18}
19
20pub struct Wal {
21 path: PathBuf,
22}
23
24impl Wal {
25 pub fn new(shard_dir: &Path) -> Self {
26 Self {
27 path: shard_dir.join("wal.log"),
28 }
29 }
30
31 pub fn exists(&self) -> bool {
32 self.path.exists()
33 }
34
35 pub fn append(&self, entry: &WalEntry) -> Result<()> {
36 let mut file = fs::OpenOptions::new()
37 .create(true)
38 .append(true)
39 .open(&self.path)?;
40 let line = serde_json::to_string(entry)?;
41 writeln!(file, "{}", line)?;
42 file.flush()?;
43 Ok(())
44 }
45
46 pub fn read(&self) -> Result<Vec<WalEntry>> {
47 if !self.path.exists() {
48 return Ok(Vec::new());
49 }
50 let content = fs::read_to_string(&self.path)?;
51 content
52 .lines()
53 .filter(|l| !l.trim().is_empty())
54 .map(|l| serde_json::from_str(l).map_err(Into::into))
55 .collect()
56 }
57
58 pub fn truncate(&self) -> Result<()> {
59 if self.path.exists() {
60 fs::remove_file(&self.path)?;
61 }
62 Ok(())
63 }
64}
65
66pub fn recover(shard_dir: &Path) -> Result<()> {
71 let wal = Wal::new(shard_dir);
72 if !wal.exists() {
73 return Ok(());
74 }
75
76 let entries = wal.read()?;
77 if entries.is_empty() {
78 wal.truncate()?;
79 return Ok(());
80 }
81
82 let has_commit_begin = entries
84 .iter()
85 .any(|e| matches!(e, WalEntry::CommitBegin { .. }));
86 let has_commit_end = entries.iter().any(|e| matches!(e, WalEntry::CommitEnd));
87
88 if has_commit_begin && !has_commit_end {
89 for entry in &entries {
91 if let WalEntry::CommitBegin {
92 head_backup,
93 index_backup,
94 } = entry
95 {
96 let head_path = shard_dir.join("HEAD");
97 match head_backup {
98 Some(head) => fs::write(&head_path, head)?,
99 None => {
100 let _ = fs::remove_file(&head_path);
101 }
102 }
103 fs::write(shard_dir.join("index"), index_backup)?;
104 }
105 }
106 eprintln!("Recovered from incomplete commit (rolled back)");
107 }
108
109 wal.truncate()?;
111 Ok(())
112}
113
#[cfg(test)]
mod tests {
    //! Round-trip tests for [`Wal`] and crash-recovery tests for [`recover`].

    use super::*;
    use tempfile::tempdir;

    #[test]
    fn test_wal_append_read_roundtrip() {
        let dir = tempdir().unwrap();
        let wal = Wal::new(dir.path());

        wal.append(&WalEntry::AddBegin {
            path: "f.txt".into(),
        })
        .unwrap();
        wal.append(&WalEntry::AddEnd).unwrap();

        let entries = wal.read().unwrap();
        assert_eq!(entries.len(), 2);
        assert!(matches!(entries[0], WalEntry::AddBegin { .. }));
        assert!(matches!(entries[1], WalEntry::AddEnd));
    }

    #[test]
    fn test_wal_read_empty() {
        let dir = tempdir().unwrap();
        let wal = Wal::new(dir.path());
        assert!(!wal.exists());
        let entries = wal.read().unwrap();
        assert!(entries.is_empty());
    }

    #[test]
    fn test_wal_truncate() {
        let dir = tempdir().unwrap();
        let wal = Wal::new(dir.path());
        wal.append(&WalEntry::AddEnd).unwrap();
        assert!(wal.exists());
        wal.truncate().unwrap();
        assert!(!wal.exists());
    }

    #[test]
    fn test_wal_commit_begin_end_roundtrip() {
        let dir = tempdir().unwrap();
        let wal = Wal::new(dir.path());
        wal.append(&WalEntry::CommitBegin {
            head_backup: Some("abc".into()),
            index_backup: b"index_data".to_vec(),
        })
        .unwrap();
        wal.append(&WalEntry::CommitEnd).unwrap();
        let entries = wal.read().unwrap();
        assert_eq!(entries.len(), 2);
        if let WalEntry::CommitBegin {
            head_backup,
            index_backup,
        } = &entries[0]
        {
            assert_eq!(head_backup.as_deref(), Some("abc"));
            assert_eq!(index_backup, b"index_data");
        } else {
            panic!("Expected CommitBegin");
        }
    }

    #[test]
    fn test_recover_no_wal() {
        let dir = tempdir().unwrap();
        recover(dir.path()).unwrap();
    }

    #[test]
    fn test_recover_empty_wal() {
        // A zero-length log file is simply removed.
        let dir = tempdir().unwrap();
        let wal = Wal::new(dir.path());
        fs::write(&wal.path, b"").unwrap();
        assert!(wal.exists());
        recover(dir.path()).unwrap();
        assert!(!wal.exists());
    }

    #[test]
    fn test_recover_wal_without_commit() {
        // A log with no commit records needs no rollback but is still removed.
        let dir = tempdir().unwrap();
        let wal = Wal::new(dir.path());
        wal.append(&WalEntry::AddEnd).unwrap();
        recover(dir.path()).unwrap();
        assert!(!wal.exists());
    }

    #[test]
    fn test_recover_incomplete_commit() {
        let dir = tempdir().unwrap();
        let shard = dir.path();

        fs::write(shard.join("HEAD"), "ref: refs/heads/main").unwrap();
        fs::write(shard.join("index"), b"original_index").unwrap();

        let wal = Wal::new(shard);
        wal.append(&WalEntry::CommitBegin {
            head_backup: Some("ref: refs/heads/main".into()),
            index_backup: b"original_index".to_vec(),
        })
        .unwrap();
        // Simulate a crash after the commit modified HEAD/index but before
        // CommitEnd was written.
        fs::write(shard.join("HEAD"), "new_commit_id").unwrap();
        fs::write(shard.join("index"), b"new_index").unwrap();

        recover(shard).unwrap();
        assert_eq!(
            fs::read_to_string(shard.join("HEAD")).unwrap(),
            "ref: refs/heads/main"
        );
        assert_eq!(fs::read(shard.join("index")).unwrap(), b"original_index");
    }

    #[test]
    fn test_recover_incomplete_commit_without_head_removes_head() {
        // A `None` HEAD backup means rollback removes HEAD.
        let dir = tempdir().unwrap();
        let shard = dir.path();
        fs::write(shard.join("index"), b"original_index").unwrap();

        let wal = Wal::new(shard);
        wal.append(&WalEntry::CommitBegin {
            head_backup: None,
            index_backup: b"original_index".to_vec(),
        })
        .unwrap();
        fs::write(shard.join("HEAD"), "new_commit_id").unwrap();
        fs::write(shard.join("index"), b"new_index").unwrap();

        recover(shard).unwrap();
        assert!(!shard.join("HEAD").exists());
        assert_eq!(fs::read(shard.join("index")).unwrap(), b"original_index");
        assert!(!wal.exists());
    }

    #[test]
    fn test_recover_completed_commit_is_kept() {
        // A commit with its CommitEnd recorded must not be rolled back.
        let dir = tempdir().unwrap();
        let shard = dir.path();
        fs::write(shard.join("HEAD"), "old_commit_id").unwrap();
        fs::write(shard.join("index"), b"old_index").unwrap();

        let wal = Wal::new(shard);
        wal.append(&WalEntry::CommitBegin {
            head_backup: Some("old_commit_id".into()),
            index_backup: b"old_index".to_vec(),
        })
        .unwrap();
        fs::write(shard.join("HEAD"), "new_commit_id").unwrap();
        fs::write(shard.join("index"), b"new_index").unwrap();
        wal.append(&WalEntry::CommitEnd).unwrap();

        recover(shard).unwrap();
        assert_eq!(
            fs::read_to_string(shard.join("HEAD")).unwrap(),
            "new_commit_id"
        );
        assert_eq!(fs::read(shard.join("index")).unwrap(), b"new_index");
        assert!(!wal.exists());
    }

    #[test]
    fn test_recover_incomplete_add() {
        let dir = tempdir().unwrap();
        let wal = Wal::new(dir.path());
        wal.append(&WalEntry::AddBegin {
            path: "f.txt".into(),
        })
        .unwrap();
        recover(dir.path()).unwrap();
        assert!(!wal.exists());
    }
}