1use crate::platform;
9use serde::{Deserialize, Serialize};
10use sha2::{Digest, Sha256};
11use std::io::{self, Write};
12use std::path::{Path, PathBuf};
13use tempfile::NamedTempFile;
14
/// File name of the on-disk journal that records an in-progress atomic
/// operation; `perform_recovery` reads it back to finish or undo that
/// operation after an interruption.
pub const JOURNAL_FILENAME: &str = ".toggle-atomic.journal";

/// File name of the lock file for the atomic operation.
/// NOTE(review): its users are outside this chunk; presumably it guards
/// against concurrent runs — confirm.
pub const LOCK_FILENAME: &str = ".toggle-atomic.lock";
20
21#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
23pub enum JournalStatus {
24 Staged,
26 Committing,
29}
30
31#[derive(Debug, Clone, Serialize, Deserialize)]
33pub struct JournalEntry {
34 pub target_path: PathBuf,
36 pub temp_path: PathBuf,
38 pub backup_path: Option<PathBuf>,
41 pub content_sha256: String,
43 pub rename_completed: bool,
45}
46
47#[derive(Debug, Clone, Serialize, Deserialize)]
49pub struct Journal {
50 pub version: u32,
52 pub status: JournalStatus,
54 pub created_at: String,
56 pub backup_enabled: bool,
58 pub entries: Vec<JournalEntry>,
60}
61
62impl Journal {
63 pub fn new(entries: Vec<JournalEntry>, backup_enabled: bool) -> Self {
65 let now = chrono_lite_now();
66 Self {
67 version: 1,
68 status: JournalStatus::Staged,
69 created_at: now,
70 backup_enabled,
71 entries,
72 }
73 }
74
75 pub fn transition_to_committing(&mut self) {
77 self.status = JournalStatus::Committing;
78 }
79
80 pub fn mark_entry_completed(&mut self, index: usize) {
82 if let Some(entry) = self.entries.get_mut(index) {
83 entry.rename_completed = true;
84 }
85 }
86}
87
88pub fn sha256_hex(data: &[u8]) -> String {
90 let mut hasher = Sha256::new();
91 hasher.update(data);
92 format!("{:x}", hasher.finalize())
93}
94
95pub fn sha256_file(path: &Path) -> io::Result<String> {
97 let data = std::fs::read(path)?;
98 Ok(sha256_hex(&data))
99}
100
101pub fn journal_dir(targets: &[PathBuf]) -> io::Result<PathBuf> {
104 let cwd = std::env::current_dir()?;
105 match NamedTempFile::new_in(&cwd) {
107 Ok(_) => Ok(cwd),
108 Err(_) => {
109 if let Some(first) = targets.first() {
111 if let Some(parent) = first.parent() {
112 eprintln!(
113 "Warning: CWD is not writable. Using '{}' for journal.",
114 parent.display()
115 );
116 return Ok(parent.to_path_buf());
117 }
118 }
119 Err(io::Error::new(
120 io::ErrorKind::PermissionDenied,
121 "Cannot create journal: CWD is not writable and no target files specified",
122 ))
123 }
124 }
125}
126
127pub fn persist_journal(journal: &Journal, journal_path: &Path) -> io::Result<()> {
130 let dir = journal_path.parent().unwrap_or(Path::new("."));
131 let mut tmp = NamedTempFile::new_in(dir)?;
132 let json = serde_json::to_string_pretty(journal)
133 .map_err(|e| io::Error::other(format!("Failed to serialize journal: {}", e)))?;
134 tmp.write_all(json.as_bytes())?;
135 platform::durable_sync(tmp.as_file())?;
136 tmp.persist(journal_path).map_err(|e| e.error)?;
137 Ok(())
138}
139
140pub fn persist_journal_best_effort(journal: &Journal, journal_path: &Path) {
144 let dir = journal_path.parent().unwrap_or(Path::new("."));
145 if let Ok(mut tmp) = NamedTempFile::new_in(dir) {
146 if let Ok(json) = serde_json::to_string_pretty(journal) {
147 if tmp.write_all(json.as_bytes()).is_ok() {
148 let _ = tmp.persist(journal_path);
149 }
150 }
151 }
152}
153
154pub fn read_journal(journal_path: &Path) -> io::Result<Option<Journal>> {
157 match std::fs::read_to_string(journal_path) {
158 Ok(content) => {
159 let journal: Journal = serde_json::from_str(&content).map_err(|e| {
160 io::Error::new(
161 io::ErrorKind::InvalidData,
162 format!(
163 "Journal file is corrupted ({}). Manual inspection required: {}",
164 e,
165 journal_path.display()
166 ),
167 )
168 })?;
169 Ok(Some(journal))
170 }
171 Err(e) if e.kind() == io::ErrorKind::NotFound => Ok(None),
172 Err(e) => Err(e),
173 }
174}
175
/// Removes the journal file; a journal that is already absent is not an error.
///
/// # Errors
///
/// Returns any removal error other than `NotFound`.
pub fn delete_journal(journal_path: &Path) -> io::Result<()> {
    std::fs::remove_file(journal_path).or_else(|err| match err.kind() {
        io::ErrorKind::NotFound => Ok(()),
        _ => Err(err),
    })
}
184
185pub fn recover_staged(journal: &Journal, journal_path: &Path) -> io::Result<()> {
188 eprintln!("Recovering from Staged state: cleaning up temp files...");
189 for entry in &journal.entries {
190 if entry.temp_path.exists() {
191 if let Err(e) = std::fs::remove_file(&entry.temp_path) {
192 eprintln!(
193 "Warning: failed to delete temp file '{}': {}",
194 entry.temp_path.display(),
195 e
196 );
197 }
198 }
199 }
200 delete_journal(journal_path)?;
201 eprintln!("Recovery complete. No original files were modified.");
202 Ok(())
203}
204
205pub fn recover_rollback(journal: &Journal, journal_path: &Path) -> io::Result<()> {
208 eprintln!("Recovering from Committing state: rolling back...");
209
210 let completed: Vec<_> = journal
212 .entries
213 .iter()
214 .filter(|e| e.rename_completed)
215 .collect();
216 let pending: Vec<_> = journal
217 .entries
218 .iter()
219 .filter(|e| !e.rename_completed)
220 .collect();
221
222 if !completed.is_empty() {
223 eprintln!(" {} file(s) were renamed:", completed.len());
224 for e in &completed {
225 eprintln!(" {}", e.target_path.display());
226 }
227 }
228 if !pending.is_empty() {
229 eprintln!(" {} file(s) were NOT renamed:", pending.len());
230 for e in &pending {
231 eprintln!(" {}", e.target_path.display());
232 }
233 }
234
235 if !journal.backup_enabled {
236 eprintln!("Error: --no-backup was used. Cannot roll back completed renames automatically.");
237 eprintln!("Manual intervention required for the files listed above.");
238 for entry in &pending {
240 if entry.temp_path.exists() {
241 let _ = std::fs::remove_file(&entry.temp_path);
242 }
243 }
244 delete_journal(journal_path)?;
245 return Err(io::Error::other(
246 "Rollback impossible without backups. See output above for affected files.",
247 ));
248 }
249
250 let mut errors = Vec::new();
252 for entry in completed.iter().rev() {
253 if let Some(ref backup_path) = entry.backup_path {
254 if backup_path.exists() {
255 if let Err(e) = platform::rename_with_retry(backup_path, &entry.target_path) {
256 errors.push(format!(
257 "Failed to restore '{}' from backup '{}': {}",
258 entry.target_path.display(),
259 backup_path.display(),
260 e
261 ));
262 } else {
263 eprintln!(" Restored: {}", entry.target_path.display());
264 }
265 } else {
266 errors.push(format!(
267 "Backup file missing for '{}': expected '{}'",
268 entry.target_path.display(),
269 backup_path.display()
270 ));
271 }
272 }
273 }
274
275 for entry in &pending {
277 if entry.temp_path.exists() {
278 let _ = std::fs::remove_file(&entry.temp_path);
279 }
280 }
281
282 if !errors.is_empty() {
283 eprintln!("Rollback completed with errors:");
284 for err in &errors {
285 eprintln!(" {}", err);
286 }
287 return Err(io::Error::other(format!(
289 "{} rollback error(s) occurred. Journal preserved.",
290 errors.len()
291 )));
292 }
293
294 for entry in &completed {
296 if let Some(ref backup_path) = entry.backup_path {
297 let _ = std::fs::remove_file(backup_path);
298 }
299 }
300
301 delete_journal(journal_path)?;
302 eprintln!("Rollback complete. All files restored to pre-operation state.");
303 Ok(())
304}
305
306pub fn recover_forward(journal: &Journal, journal_path: &Path) -> io::Result<()> {
309 eprintln!("Forward recovery: completing interrupted commit...");
310
311 let pending: Vec<(usize, &JournalEntry)> = journal
312 .entries
313 .iter()
314 .enumerate()
315 .filter(|(_, e)| !e.rename_completed)
316 .collect();
317
318 if pending.is_empty() {
319 eprintln!("All renames were already completed. Cleaning up.");
320 for entry in &journal.entries {
322 if let Some(ref backup_path) = entry.backup_path {
323 let _ = std::fs::remove_file(backup_path);
324 }
325 }
326 delete_journal(journal_path)?;
327 return Ok(());
328 }
329
330 eprintln!(" {} file(s) remaining to rename.", pending.len());
331
332 let mut updated_journal = journal.clone();
333 let mut errors = Vec::new();
334
335 for (idx, entry) in &pending {
336 if !entry.temp_path.exists() {
338 errors.push(format!(
339 "Temp file missing for '{}': expected '{}'",
340 entry.target_path.display(),
341 entry.temp_path.display()
342 ));
343 continue;
344 }
345
346 match sha256_file(&entry.temp_path) {
348 Ok(hash) if hash == entry.content_sha256 => {}
349 Ok(hash) => {
350 errors.push(format!(
351 "SHA-256 mismatch for '{}': expected {}, got {}",
352 entry.temp_path.display(),
353 entry.content_sha256,
354 hash
355 ));
356 continue;
357 }
358 Err(e) => {
359 errors.push(format!(
360 "Cannot read temp file '{}': {}",
361 entry.temp_path.display(),
362 e
363 ));
364 continue;
365 }
366 }
367
368 if entry.target_path.exists() {
370 if let Ok(meta) = std::fs::metadata(&entry.target_path) {
371 let _ = std::fs::set_permissions(&entry.temp_path, meta.permissions());
372 }
373 }
374
375 match platform::rename_with_retry(&entry.temp_path, &entry.target_path) {
377 Ok(()) => {
378 eprintln!(" Renamed: {}", entry.target_path.display());
379 updated_journal.mark_entry_completed(*idx);
380 persist_journal_best_effort(&updated_journal, journal_path);
381 }
382 Err(e) => {
383 errors.push(format!(
384 "Failed to rename '{}' -> '{}': {}",
385 entry.temp_path.display(),
386 entry.target_path.display(),
387 e
388 ));
389 break;
391 }
392 }
393 }
394
395 if !errors.is_empty() {
396 eprintln!("Forward recovery incomplete:");
397 for err in &errors {
398 eprintln!(" {}", err);
399 }
400 persist_journal(&updated_journal, journal_path)?;
401 return Err(io::Error::other(format!(
402 "{} error(s) during forward recovery. Journal preserved for retry.",
403 errors.len()
404 )));
405 }
406
407 for entry in &journal.entries {
409 if let Some(ref backup_path) = entry.backup_path {
410 let _ = std::fs::remove_file(backup_path);
411 }
412 }
413 delete_journal(journal_path)?;
414 eprintln!("Forward recovery complete. All files updated.");
415 Ok(())
416}
417
418pub fn perform_recovery(journal_path: &Path, forward: bool) -> io::Result<()> {
420 let journal = match read_journal(journal_path)? {
421 Some(j) => j,
422 None => {
423 eprintln!("No journal found. Nothing to recover.");
424 return Ok(());
425 }
426 };
427
428 match journal.status {
429 JournalStatus::Staged => {
430 if forward {
431 eprintln!(
432 "Warning: --recover-forward has no effect in Staged state. \
433 No renames occurred. Rolling back."
434 );
435 }
436 recover_staged(&journal, journal_path)
437 }
438 JournalStatus::Committing => {
439 if forward {
440 recover_forward(&journal, journal_path)
441 } else {
442 recover_rollback(&journal, journal_path)
443 }
444 }
445 }
446}
447
/// Dependency-free timestamp: whole seconds since the Unix epoch followed by
/// `s-since-epoch`, or `"unknown"` if the system clock is before the epoch.
fn chrono_lite_now() -> String {
    use std::time::{SystemTime, UNIX_EPOCH};

    SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .map(|elapsed| format!("{}s-since-epoch", elapsed.as_secs()))
        .unwrap_or_else(|_| "unknown".to_string())
}
456
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Builds an entry whose rename has not been recorded yet.
    fn pending_entry(
        target: impl Into<PathBuf>,
        temp: impl Into<PathBuf>,
        backup: Option<PathBuf>,
        sha: &str,
    ) -> JournalEntry {
        JournalEntry {
            target_path: target.into(),
            temp_path: temp.into(),
            backup_path: backup,
            content_sha256: sha.to_owned(),
            rename_completed: false,
        }
    }

    /// Persists `journal` into a scratch directory and loads it back.
    fn persist_and_reload(journal: &Journal) -> Journal {
        let scratch = TempDir::new().unwrap();
        let path = scratch.path().join(JOURNAL_FILENAME);
        persist_journal(journal, &path).unwrap();
        read_journal(&path).unwrap().unwrap()
    }

    #[test]
    fn test_sha256_hex() {
        assert_eq!(
            sha256_hex(b"hello world"),
            "b94d27b9934d3e08a52e52d7da7dabfac484efe37a5380ee9088f7ace2efcde9"
        );
    }

    #[test]
    fn test_journal_roundtrip() {
        let original = Journal::new(
            vec![pending_entry(
                "/tmp/test.py",
                "/tmp/.tmpXXXX",
                Some(PathBuf::from("/tmp/test.py.bak")),
                "abc123",
            )],
            true,
        );

        let loaded = persist_and_reload(&original);
        assert_eq!(loaded.version, 1);
        assert_eq!(loaded.status, JournalStatus::Staged);
        assert!(loaded.backup_enabled);
        assert_eq!(loaded.entries.len(), 1);
        let entry = &loaded.entries[0];
        assert_eq!(entry.target_path, PathBuf::from("/tmp/test.py"));
        assert_eq!(entry.content_sha256, "abc123");
        assert!(!entry.rename_completed);
    }

    #[test]
    fn test_journal_not_found() {
        let missing = Path::new("/nonexistent/.toggle-atomic.journal");
        assert!(read_journal(missing).unwrap().is_none());
    }

    #[test]
    fn test_journal_corrupt() {
        let scratch = TempDir::new().unwrap();
        let path = scratch.path().join(JOURNAL_FILENAME);
        std::fs::write(&path, "not valid json {{{").unwrap();
        assert!(read_journal(&path).is_err());
    }

    #[test]
    fn test_recover_staged_cleans_temps() {
        let scratch = TempDir::new().unwrap();
        let temp_file = scratch.path().join("temp_staged");
        std::fs::write(&temp_file, "staged content").unwrap();
        let path = scratch.path().join(JOURNAL_FILENAME);

        let journal = Journal::new(
            vec![pending_entry(
                scratch.path().join("target.py"),
                temp_file.clone(),
                None,
                "xxx",
            )],
            false,
        );
        persist_journal(&journal, &path).unwrap();

        recover_staged(&journal, &path).unwrap();
        assert!(!temp_file.exists());
        assert!(!path.exists());
    }

    #[test]
    fn test_status_transitions() {
        let mut journal = Journal::new(Vec::new(), true);
        assert_eq!(journal.status, JournalStatus::Staged);
        journal.transition_to_committing();
        assert_eq!(journal.status, JournalStatus::Committing);
    }

    #[test]
    fn test_journal_with_unicode_paths() {
        let original = Journal::new(
            vec![pending_entry(
                "/tmp/café/données.py",
                "/tmp/café/.tmpXXXX",
                Some(PathBuf::from("/tmp/café/données.py.bak")),
                "abc",
            )],
            true,
        );

        let loaded = persist_and_reload(&original);
        assert_eq!(
            loaded.entries[0].target_path,
            PathBuf::from("/tmp/café/données.py")
        );
    }
}