Skip to main content

spreadsheet_mcp/
fork.rs

1use crate::security::canonicalize_and_enforce_within_workspace;
2use crate::utils::make_short_random_id;
3use anyhow::{Result, anyhow};
4use chrono::{DateTime, Utc};
5use parking_lot::Mutex;
6use schemars::JsonSchema;
7use serde::{Deserialize, Serialize};
8use serde_json::Value as JsonValue;
9use sha2::{Digest, Sha256};
10use std::collections::{BTreeMap, HashMap};
11use std::fs;
12use std::path::{Path, PathBuf};
13use std::sync::Arc;
14use std::time::{Duration, Instant};
15
/// Default directory holding each fork's working copy (`<fork_id>.xlsx`).
const FORK_DIR: &str = "/tmp/mcp-forks";
/// Root directory for checkpoint snapshots; each fork uses a subdirectory named after its id.
const CHECKPOINT_DIR: &str = "/tmp/mcp-checkpoints";
/// Directory for staged-change snapshots.
// Not referenced in the visible part of this file, hence the lint allowance.
#[allow(dead_code)]
const STAGED_SNAPSHOT_DIR: &str = "/tmp/mcp-staged";
/// Default fork time-to-live in seconds. Zero disables expiry entirely.
const DEFAULT_TTL_SECS: u64 = 0;
/// Default cap on the number of forks tracked at the same time.
const DEFAULT_MAX_FORKS: usize = 10;
/// Seconds between two runs of the background eviction task.
const CLEANUP_TASK_CHECK_SECS: u64 = 60;
/// Largest base workbook that may be forked, in bytes (100 MiB).
const MAX_FILE_SIZE: u64 = 100 * 1024 * 1024;
/// Maximum number of checkpoints kept per fork; the oldest are dropped first.
const DEFAULT_MAX_CHECKPOINTS_PER_FORK: usize = 10;
/// Maximum number of staged changes kept per fork; the oldest are dropped first.
const DEFAULT_MAX_STAGED_CHANGES_PER_FORK: usize = 20;
/// Maximum combined size of one fork's checkpoint snapshots, in bytes (500 MiB).
const DEFAULT_MAX_CHECKPOINT_TOTAL_BYTES: u64 = 500 * 1024 * 1024;
27
28#[derive(Debug, Clone)]
29pub struct EditOp {
30    pub timestamp: DateTime<Utc>,
31    pub sheet: String,
32    pub address: String,
33    pub value: String,
34    pub is_formula: bool,
35}
36
37#[derive(Debug, Clone, Serialize, Deserialize)]
38pub struct StagedOp {
39    pub kind: String,
40    pub payload: JsonValue,
41}
42
43#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema, Default)]
44pub struct ChangeSummary {
45    pub op_kinds: Vec<String>,
46    pub affected_sheets: Vec<String>,
47    pub affected_bounds: Vec<String>,
48    pub counts: BTreeMap<String, u64>,
49    #[serde(default)]
50    #[serde(skip_serializing_if = "BTreeMap::is_empty")]
51    pub flags: BTreeMap<String, bool>,
52    pub warnings: Vec<String>,
53}
54
55#[derive(Debug, Clone)]
56pub struct StagedChange {
57    pub change_id: String,
58    pub created_at: DateTime<Utc>,
59    pub label: Option<String>,
60    pub ops: Vec<StagedOp>,
61    pub summary: ChangeSummary,
62    pub fork_path_snapshot: Option<PathBuf>,
63}
64
65#[derive(Debug, Clone)]
66pub struct Checkpoint {
67    pub checkpoint_id: String,
68    pub created_at: DateTime<Utc>,
69    pub label: Option<String>,
70    pub snapshot_path: PathBuf,
71    pub recalc_needed: bool,
72}
73
74#[derive(Debug)]
75pub struct ForkContext {
76    pub fork_id: String,
77    pub base_path: PathBuf,
78    pub work_path: PathBuf,
79    pub created_at: Instant,
80    pub last_accessed: Instant,
81    pub edits: Vec<EditOp>,
82    pub staged_changes: Vec<StagedChange>,
83    pub checkpoints: Vec<Checkpoint>,
84    pub recalc_needed: bool,
85    base_hash: String,
86    base_modified: std::time::SystemTime,
87}
88
89impl ForkContext {
90    fn new(fork_id: String, base_path: PathBuf, work_path: PathBuf) -> Result<Self> {
91        let metadata = fs::metadata(&base_path)?;
92        let base_modified = metadata.modified()?;
93        let base_hash = hash_file(&base_path)?;
94
95        Ok(Self {
96            fork_id,
97            base_path,
98            work_path,
99            created_at: Instant::now(),
100            last_accessed: Instant::now(),
101            edits: Vec::new(),
102            staged_changes: Vec::new(),
103            checkpoints: Vec::new(),
104            recalc_needed: false,
105            base_hash,
106            base_modified,
107        })
108    }
109
110    pub fn is_expired(&self, ttl: Duration) -> bool {
111        if ttl.is_zero() {
112            return false;
113        }
114        self.last_accessed.elapsed() > ttl
115    }
116
117    pub fn touch(&mut self) {
118        self.last_accessed = Instant::now();
119    }
120
121    pub fn validate_base_unchanged(&self) -> Result<()> {
122        let metadata = fs::metadata(&self.base_path)?;
123        let current_modified = metadata.modified()?;
124
125        if current_modified != self.base_modified {
126            return Err(anyhow!("base file modified since fork creation"));
127        }
128
129        let current_hash = hash_file(&self.base_path)?;
130        if current_hash != self.base_hash {
131            return Err(anyhow!("base file content changed since fork creation"));
132        }
133
134        Ok(())
135    }
136
137    fn checkpoint_dir(&self) -> PathBuf {
138        PathBuf::from(CHECKPOINT_DIR).join(&self.fork_id)
139    }
140
141    fn cleanup_files(&self) {
142        let _ = fs::remove_file(&self.work_path);
143        for staged in &self.staged_changes {
144            remove_staged_snapshot(staged);
145        }
146        let checkpoint_dir = self.checkpoint_dir();
147        if checkpoint_dir.starts_with(CHECKPOINT_DIR) {
148            let _ = fs::remove_dir_all(&checkpoint_dir);
149        }
150    }
151}
152
153fn hash_file(path: &Path) -> Result<String> {
154    let contents = fs::read(path)?;
155    let mut hasher = Sha256::new();
156    hasher.update(&contents);
157    Ok(format!("{:x}", hasher.finalize()))
158}
159
/// Settings for a fork registry.
#[derive(Clone, Debug)]
pub struct ForkConfig {
    /// Idle time after which a fork is evicted; zero disables eviction.
    pub ttl: Duration,
    /// Maximum number of forks that may exist at once.
    pub max_forks: usize,
    /// Directory in which fork working files are created.
    pub fork_dir: PathBuf,
}
166
167impl Default for ForkConfig {
168    fn default() -> Self {
169        Self {
170            ttl: Duration::from_secs(DEFAULT_TTL_SECS),
171            max_forks: DEFAULT_MAX_FORKS,
172            fork_dir: PathBuf::from(FORK_DIR),
173        }
174    }
175}
176
177pub struct ForkRegistry {
178    forks: Mutex<HashMap<String, ForkContext>>,
179    config: ForkConfig,
180}
181
182impl ForkRegistry {
183    pub fn new(config: ForkConfig) -> Result<Self> {
184        fs::create_dir_all(&config.fork_dir)?;
185        Ok(Self {
186            forks: Mutex::new(HashMap::new()),
187            config,
188        })
189    }
190
191    pub fn start_cleanup_task(self: Arc<Self>) {
192        if self.config.ttl.is_zero() {
193            return;
194        }
195        tokio::spawn(async move {
196            let mut interval = tokio::time::interval(Duration::from_secs(CLEANUP_TASK_CHECK_SECS));
197            loop {
198                interval.tick().await;
199                self.evict_expired();
200            }
201        });
202    }
203
204    pub fn create_fork(&self, base_path: &Path, workspace_root: &Path) -> Result<String> {
205        self.evict_expired();
206
207        {
208            let forks = self.forks.lock();
209            if forks.len() >= self.config.max_forks {
210                return Err(anyhow!(
211                    "max forks ({}) reached, discard existing forks first",
212                    self.config.max_forks
213                ));
214            }
215        }
216
217        let ext = base_path
218            .extension()
219            .and_then(|e| e.to_str())
220            .map(|e| e.to_ascii_lowercase());
221
222        if ext.as_deref() != Some("xlsx") {
223            return Err(anyhow!(
224                "only .xlsx files supported for fork/recalc (got {:?})",
225                ext
226            ));
227        }
228
229        if !base_path.exists() {
230            return Err(anyhow!("base file does not exist: {:?}", base_path));
231        }
232
233        // Enforce workspace boundary using canonicalized, symlink-aware paths.
234        let base_path_canon = canonicalize_and_enforce_within_workspace(
235            workspace_root,
236            base_path,
237            "create_fork",
238            "base_path",
239        )?;
240
241        let metadata = fs::metadata(&base_path_canon)?;
242        if metadata.len() > MAX_FILE_SIZE {
243            return Err(anyhow!(
244                "base file too large: {} bytes (max {} MB)",
245                metadata.len(),
246                MAX_FILE_SIZE / 1024 / 1024
247            ));
248        }
249
250        let fork_id = {
251            let mut attempts: u32 = 0;
252            loop {
253                let candidate = make_short_random_id("fork", 12);
254                let work_path = self.config.fork_dir.join(format!("{}.xlsx", candidate));
255                let exists_in_registry = self.forks.lock().contains_key(&candidate);
256                if !exists_in_registry && !work_path.exists() {
257                    break candidate;
258                }
259                attempts += 1;
260                if attempts > 20 {
261                    return Err(anyhow!("failed to allocate unique fork id"));
262                }
263            }
264        };
265        let work_path = self.config.fork_dir.join(format!("{}.xlsx", fork_id));
266
267        fs::copy(&base_path_canon, &work_path)?;
268
269        let context = ForkContext::new(fork_id.clone(), base_path_canon, work_path)?;
270
271        self.forks.lock().insert(fork_id.clone(), context);
272
273        Ok(fork_id)
274    }
275
276    pub fn get_fork(&self, fork_id: &str) -> Result<Arc<ForkContext>> {
277        self.evict_expired();
278
279        let mut forks = self.forks.lock();
280        let ctx = forks
281            .get_mut(fork_id)
282            .ok_or_else(|| anyhow!("fork not found: {}", fork_id))?;
283        ctx.touch();
284        Ok(Arc::new(ctx.clone()))
285    }
286
287    pub fn get_fork_path(&self, fork_id: &str) -> Option<PathBuf> {
288        let mut forks = self.forks.lock();
289        if let Some(ctx) = forks.get_mut(fork_id) {
290            ctx.touch();
291            return Some(ctx.work_path.clone());
292        }
293        None
294    }
295
296    pub fn with_fork_mut<F, R>(&self, fork_id: &str, f: F) -> Result<R>
297    where
298        F: FnOnce(&mut ForkContext) -> Result<R>,
299    {
300        let mut forks = self.forks.lock();
301        let ctx = forks
302            .get_mut(fork_id)
303            .ok_or_else(|| anyhow!("fork not found: {}", fork_id))?;
304        ctx.touch();
305        f(ctx)
306    }
307
308    pub fn discard_fork(&self, fork_id: &str) -> Result<()> {
309        let mut forks = self.forks.lock();
310        if let Some(ctx) = forks.remove(fork_id) {
311            ctx.cleanup_files();
312        }
313        Ok(())
314    }
315
316    pub fn save_fork(
317        &self,
318        fork_id: &str,
319        target_path: &Path,
320        workspace_root: &Path,
321        drop_fork: bool,
322    ) -> Result<()> {
323        // Enforce workspace boundary using canonicalized, symlink-aware paths.
324        let _target_canon = canonicalize_and_enforce_within_workspace(
325            workspace_root,
326            target_path,
327            "save_fork",
328            "target_path",
329        )?;
330
331        let ext = target_path
332            .extension()
333            .and_then(|e| e.to_str())
334            .map(|e| e.to_ascii_lowercase());
335
336        if ext.as_deref() != Some("xlsx") {
337            return Err(anyhow!("target must be .xlsx"));
338        }
339
340        let mut forks = self.forks.lock();
341        let ctx = forks
342            .get(fork_id)
343            .ok_or_else(|| anyhow!("fork not found: {}", fork_id))?;
344
345        ctx.validate_base_unchanged()?;
346
347        fs::copy(&ctx.work_path, target_path)?;
348
349        if drop_fork && let Some(ctx) = forks.remove(fork_id) {
350            let _ = fs::remove_file(&ctx.work_path);
351        }
352
353        Ok(())
354    }
355
356    pub fn ttl(&self) -> Duration {
357        self.config.ttl
358    }
359
360    pub fn list_forks(&self) -> Vec<ForkInfo> {
361        self.evict_expired();
362
363        let forks = self.forks.lock();
364        forks
365            .values()
366            .map(|ctx| ForkInfo {
367                fork_id: ctx.fork_id.clone(),
368                base_path: ctx.base_path.display().to_string(),
369                created_at: ctx.created_at,
370                edit_count: ctx.edits.len(),
371                recalc_needed: ctx.recalc_needed,
372            })
373            .collect()
374    }
375
376    pub fn create_checkpoint(&self, fork_id: &str, label: Option<String>) -> Result<Checkpoint> {
377        self.evict_expired();
378
379        let work_path = {
380            let forks = self.forks.lock();
381            let ctx = forks
382                .get(fork_id)
383                .ok_or_else(|| anyhow!("fork not found: {}", fork_id))?;
384            ctx.work_path.clone()
385        };
386
387        let checkpoint_id = make_short_random_id("cp", 12);
388        let dir = PathBuf::from(CHECKPOINT_DIR).join(fork_id);
389        fs::create_dir_all(&dir)?;
390        let snapshot_path = dir.join(format!("{}.xlsx", checkpoint_id));
391        fs::copy(&work_path, &snapshot_path)?;
392
393        let recalc_needed = self
394            .get_fork(fork_id)
395            .map(|ctx| ctx.recalc_needed)
396            .unwrap_or(false);
397
398        let checkpoint = Checkpoint {
399            checkpoint_id: checkpoint_id.clone(),
400            created_at: Utc::now(),
401            label,
402            snapshot_path,
403            recalc_needed,
404        };
405
406        self.with_fork_mut(fork_id, |ctx| {
407            ctx.checkpoints.push(checkpoint.clone());
408            enforce_checkpoint_limits(ctx)?;
409            Ok(())
410        })?;
411
412        Ok(checkpoint)
413    }
414
415    pub fn list_checkpoints(&self, fork_id: &str) -> Result<Vec<Checkpoint>> {
416        let ctx = self.get_fork(fork_id)?;
417        Ok(ctx.checkpoints.clone())
418    }
419
420    pub fn delete_checkpoint(&self, fork_id: &str, checkpoint_id: &str) -> Result<()> {
421        self.with_fork_mut(fork_id, |ctx| {
422            let index = ctx
423                .checkpoints
424                .iter()
425                .position(|c| c.checkpoint_id == checkpoint_id)
426                .ok_or_else(|| anyhow!("checkpoint not found: {}", checkpoint_id))?;
427            let removed = ctx.checkpoints.remove(index);
428            let _ = fs::remove_file(&removed.snapshot_path);
429            Ok(())
430        })
431    }
432
433    pub fn restore_checkpoint(&self, fork_id: &str, checkpoint_id: &str) -> Result<Checkpoint> {
434        self.evict_expired();
435
436        let (work_path, checkpoint) = {
437            let forks = self.forks.lock();
438            let ctx = forks
439                .get(fork_id)
440                .ok_or_else(|| anyhow!("fork not found: {}", fork_id))?;
441            let checkpoint = ctx
442                .checkpoints
443                .iter()
444                .find(|c| c.checkpoint_id == checkpoint_id)
445                .cloned()
446                .ok_or_else(|| anyhow!("checkpoint not found: {}", checkpoint_id))?;
447            (ctx.work_path.clone(), checkpoint)
448        };
449
450        fs::copy(&checkpoint.snapshot_path, &work_path)?;
451
452        self.with_fork_mut(fork_id, |ctx| {
453            let cutoff = checkpoint.created_at;
454            ctx.edits.retain(|e| e.timestamp <= cutoff);
455            let mut i = 0;
456            while i < ctx.staged_changes.len() {
457                if ctx.staged_changes[i].created_at > cutoff {
458                    let removed = ctx.staged_changes.remove(i);
459                    remove_staged_snapshot(&removed);
460                } else {
461                    i += 1;
462                }
463            }
464            ctx.recalc_needed = checkpoint.recalc_needed;
465            Ok(())
466        })?;
467
468        Ok(checkpoint)
469    }
470
471    pub fn add_staged_change(&self, fork_id: &str, staged: StagedChange) -> Result<()> {
472        self.with_fork_mut(fork_id, |ctx| {
473            ctx.staged_changes.push(staged);
474            enforce_staged_limits(ctx);
475            Ok(())
476        })
477    }
478
479    pub fn list_staged_changes(&self, fork_id: &str) -> Result<Vec<StagedChange>> {
480        let ctx = self.get_fork(fork_id)?;
481        Ok(ctx.staged_changes.clone())
482    }
483
484    pub fn take_staged_change(&self, fork_id: &str, change_id: &str) -> Result<StagedChange> {
485        self.with_fork_mut(fork_id, |ctx| {
486            let index = ctx
487                .staged_changes
488                .iter()
489                .position(|c| c.change_id == change_id)
490                .ok_or_else(|| anyhow!("staged change not found: {}", change_id))?;
491            Ok(ctx.staged_changes.remove(index))
492        })
493    }
494
495    pub fn discard_staged_change(&self, fork_id: &str, change_id: &str) -> Result<()> {
496        let removed = self.take_staged_change(fork_id, change_id)?;
497        remove_staged_snapshot(&removed);
498        Ok(())
499    }
500
501    fn evict_expired(&self) {
502        if self.config.ttl.is_zero() {
503            return;
504        }
505        let mut forks = self.forks.lock();
506        let expired: Vec<String> = forks
507            .iter()
508            .filter(|(_, ctx)| ctx.is_expired(self.config.ttl))
509            .map(|(id, _)| id.clone())
510            .collect();
511
512        for id in expired {
513            if let Some(ctx) = forks.remove(&id) {
514                ctx.cleanup_files();
515                tracing::debug!(fork_id = %id, "evicted expired fork");
516            }
517        }
518    }
519}
520
521fn remove_staged_snapshot(staged: &StagedChange) {
522    if let Some(path) = staged.fork_path_snapshot.as_ref() {
523        let _ = fs::remove_file(path);
524    }
525}
526
527fn enforce_staged_limits(ctx: &mut ForkContext) {
528    while ctx.staged_changes.len() > DEFAULT_MAX_STAGED_CHANGES_PER_FORK {
529        let removed = ctx.staged_changes.remove(0);
530        remove_staged_snapshot(&removed);
531    }
532}
533
534fn enforce_checkpoint_limits(ctx: &mut ForkContext) -> Result<()> {
535    while ctx.checkpoints.len() > DEFAULT_MAX_CHECKPOINTS_PER_FORK {
536        let removed = ctx.checkpoints.remove(0);
537        let _ = fs::remove_file(&removed.snapshot_path);
538    }
539
540    loop {
541        let mut total_bytes = 0u64;
542        for cp in &ctx.checkpoints {
543            if let Ok(meta) = fs::metadata(&cp.snapshot_path) {
544                total_bytes += meta.len();
545            }
546        }
547        if total_bytes <= DEFAULT_MAX_CHECKPOINT_TOTAL_BYTES || ctx.checkpoints.len() <= 1 {
548            break;
549        }
550        let removed = ctx.checkpoints.remove(0);
551        let _ = fs::remove_file(&removed.snapshot_path);
552    }
553
554    Ok(())
555}
556
557impl Clone for ForkContext {
558    fn clone(&self) -> Self {
559        Self {
560            fork_id: self.fork_id.clone(),
561            base_path: self.base_path.clone(),
562            work_path: self.work_path.clone(),
563            created_at: self.created_at,
564            last_accessed: self.last_accessed,
565            edits: self.edits.clone(),
566            staged_changes: self.staged_changes.clone(),
567            checkpoints: self.checkpoints.clone(),
568            recalc_needed: self.recalc_needed,
569            base_hash: self.base_hash.clone(),
570            base_modified: self.base_modified,
571        }
572    }
573}
574
/// Lightweight summary of a fork, as returned when listing forks.
#[derive(Clone, Debug)]
pub struct ForkInfo {
    /// Identifier of the fork.
    pub fork_id: String,
    /// Base workbook path rendered as a display string.
    pub base_path: String,
    /// Monotonic instant at which the fork's context was built.
    pub created_at: Instant,
    /// Number of edits recorded on the fork.
    pub edit_count: usize,
    /// The fork's current `recalc_needed` flag.
    pub recalc_needed: bool,
}