use crate::security::canonicalize_and_enforce_within_workspace;
use crate::utils::make_short_random_id;
use anyhow::{Result, anyhow};
use chrono::{DateTime, Utc};
use parking_lot::Mutex;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use serde_json::Value as JsonValue;
use sha2::{Digest, Sha256};
use std::collections::{BTreeMap, HashMap};
use std::fs;
use std::io::{ErrorKind, Read};
use std::path::{Component, Path, PathBuf};
use std::sync::Arc;
use std::time::{Duration, Instant};
15
// ---- Scratch directories -------------------------------------------------

/// Default directory for fork working copies (overridable via `ForkConfig::fork_dir`).
const FORK_DIR: &str = "/tmp/mcp-forks";
/// Root under which each fork's checkpoints live, in `CHECKPOINT_DIR/<fork_id>/`.
const CHECKPOINT_DIR: &str = "/tmp/mcp-checkpoints";
/// Not referenced anywhere in this module; the lint is silenced so the build
/// stays warning-free. NOTE(review): confirm it is still needed.
#[allow(dead_code)]
const STAGED_SNAPSHOT_DIR: &str = "/tmp/mcp-staged";

// ---- Lifetime / cleanup --------------------------------------------------

/// Default idle TTL in seconds. Zero disables expiry entirely.
const DEFAULT_TTL_SECS: u64 = 0;
/// Period of the background eviction task, in seconds.
const CLEANUP_TASK_CHECK_SECS: u64 = 60;

// ---- Capacity limits -----------------------------------------------------

/// Default maximum number of simultaneously live forks.
const DEFAULT_MAX_FORKS: usize = 10;
/// Largest base workbook that may be forked (100 MiB).
const MAX_FILE_SIZE: u64 = 100 * 1024 * 1024;
/// Oldest checkpoints beyond this count are evicted.
const DEFAULT_MAX_CHECKPOINTS_PER_FORK: usize = 10;
/// Oldest staged changes beyond this count are evicted.
const DEFAULT_MAX_STAGED_CHANGES_PER_FORK: usize = 20;
/// Combined on-disk budget for one fork's checkpoint snapshots (500 MiB).
const DEFAULT_MAX_CHECKPOINT_TOTAL_BYTES: u64 = 500 * 1024 * 1024;
27
28#[derive(Debug, Clone)]
29pub struct EditOp {
30 pub timestamp: DateTime<Utc>,
31 pub sheet: String,
32 pub address: String,
33 pub value: String,
34 pub is_formula: bool,
35}
36
37#[derive(Debug, Clone, Serialize, Deserialize)]
38pub struct StagedOp {
39 pub kind: String,
40 pub payload: JsonValue,
41}
42
43#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema, Default)]
44pub struct ChangeSummary {
45 pub op_kinds: Vec<String>,
46 pub affected_sheets: Vec<String>,
47 pub affected_bounds: Vec<String>,
48 pub counts: BTreeMap<String, u64>,
49 #[serde(default)]
50 #[serde(skip_serializing_if = "BTreeMap::is_empty")]
51 pub flags: BTreeMap<String, bool>,
52 pub warnings: Vec<String>,
53}
54
55#[derive(Debug, Clone)]
56pub struct StagedChange {
57 pub change_id: String,
58 pub created_at: DateTime<Utc>,
59 pub label: Option<String>,
60 pub ops: Vec<StagedOp>,
61 pub summary: ChangeSummary,
62 pub fork_path_snapshot: Option<PathBuf>,
63}
64
65#[derive(Debug, Clone)]
66pub struct Checkpoint {
67 pub checkpoint_id: String,
68 pub created_at: DateTime<Utc>,
69 pub label: Option<String>,
70 pub snapshot_path: PathBuf,
71 pub recalc_needed: bool,
72}
73
74#[derive(Debug)]
75pub struct ForkContext {
76 pub fork_id: String,
77 pub base_path: PathBuf,
78 pub work_path: PathBuf,
79 pub created_at: Instant,
80 pub last_accessed: Instant,
81 pub edits: Vec<EditOp>,
82 pub staged_changes: Vec<StagedChange>,
83 pub checkpoints: Vec<Checkpoint>,
84 pub recalc_needed: bool,
85 base_hash: String,
86 base_modified: std::time::SystemTime,
87}
88
89impl ForkContext {
90 fn new(fork_id: String, base_path: PathBuf, work_path: PathBuf) -> Result<Self> {
91 let metadata = fs::metadata(&base_path)?;
92 let base_modified = metadata.modified()?;
93 let base_hash = hash_file(&base_path)?;
94
95 Ok(Self {
96 fork_id,
97 base_path,
98 work_path,
99 created_at: Instant::now(),
100 last_accessed: Instant::now(),
101 edits: Vec::new(),
102 staged_changes: Vec::new(),
103 checkpoints: Vec::new(),
104 recalc_needed: false,
105 base_hash,
106 base_modified,
107 })
108 }
109
110 pub fn is_expired(&self, ttl: Duration) -> bool {
111 if ttl.is_zero() {
112 return false;
113 }
114 self.last_accessed.elapsed() > ttl
115 }
116
117 pub fn touch(&mut self) {
118 self.last_accessed = Instant::now();
119 }
120
121 pub fn validate_base_unchanged(&self) -> Result<()> {
122 let metadata = fs::metadata(&self.base_path)?;
123 let current_modified = metadata.modified()?;
124
125 if current_modified != self.base_modified {
126 return Err(anyhow!("base file modified since fork creation"));
127 }
128
129 let current_hash = hash_file(&self.base_path)?;
130 if current_hash != self.base_hash {
131 return Err(anyhow!("base file content changed since fork creation"));
132 }
133
134 Ok(())
135 }
136
137 fn checkpoint_dir(&self) -> PathBuf {
138 PathBuf::from(CHECKPOINT_DIR).join(&self.fork_id)
139 }
140
141 fn cleanup_files(&self) {
142 let _ = fs::remove_file(&self.work_path);
143 for staged in &self.staged_changes {
144 remove_staged_snapshot(staged);
145 }
146 let checkpoint_dir = self.checkpoint_dir();
147 if checkpoint_dir.starts_with(CHECKPOINT_DIR) {
148 let _ = fs::remove_dir_all(&checkpoint_dir);
149 }
150 }
151}
152
153fn hash_file(path: &Path) -> Result<String> {
154 let contents = fs::read(path)?;
155 let mut hasher = Sha256::new();
156 hasher.update(&contents);
157 Ok(format!("{:x}", hasher.finalize()))
158}
159
/// Tunables for a [`ForkRegistry`].
#[derive(Debug, Clone)]
pub struct ForkConfig {
    /// Idle time after which a fork is evicted; zero disables eviction.
    pub ttl: Duration,
    /// Maximum number of forks alive at once.
    pub max_forks: usize,
    /// Directory where fork working copies are written.
    pub fork_dir: PathBuf,
}
166
167impl Default for ForkConfig {
168 fn default() -> Self {
169 Self {
170 ttl: Duration::from_secs(DEFAULT_TTL_SECS),
171 max_forks: DEFAULT_MAX_FORKS,
172 fork_dir: PathBuf::from(FORK_DIR),
173 }
174 }
175}
176
177pub struct ForkRegistry {
178 forks: Mutex<HashMap<String, ForkContext>>,
179 config: ForkConfig,
180}
181
182impl ForkRegistry {
183 pub fn new(config: ForkConfig) -> Result<Self> {
184 fs::create_dir_all(&config.fork_dir)?;
185 Ok(Self {
186 forks: Mutex::new(HashMap::new()),
187 config,
188 })
189 }
190
191 pub fn start_cleanup_task(self: Arc<Self>) {
192 if self.config.ttl.is_zero() {
193 return;
194 }
195 tokio::spawn(async move {
196 let mut interval = tokio::time::interval(Duration::from_secs(CLEANUP_TASK_CHECK_SECS));
197 loop {
198 interval.tick().await;
199 self.evict_expired();
200 }
201 });
202 }
203
204 pub fn create_fork(&self, base_path: &Path, workspace_root: &Path) -> Result<String> {
205 self.evict_expired();
206
207 {
208 let forks = self.forks.lock();
209 if forks.len() >= self.config.max_forks {
210 return Err(anyhow!(
211 "max forks ({}) reached, discard existing forks first",
212 self.config.max_forks
213 ));
214 }
215 }
216
217 let ext = base_path
218 .extension()
219 .and_then(|e| e.to_str())
220 .map(|e| e.to_ascii_lowercase());
221
222 if ext.as_deref() != Some("xlsx") {
223 return Err(anyhow!(
224 "only .xlsx files supported for fork/recalc (got {:?})",
225 ext
226 ));
227 }
228
229 if !base_path.exists() {
230 return Err(anyhow!("base file does not exist: {:?}", base_path));
231 }
232
233 let base_path_canon = canonicalize_and_enforce_within_workspace(
235 workspace_root,
236 base_path,
237 "create_fork",
238 "base_path",
239 )?;
240
241 let metadata = fs::metadata(&base_path_canon)?;
242 if metadata.len() > MAX_FILE_SIZE {
243 return Err(anyhow!(
244 "base file too large: {} bytes (max {} MB)",
245 metadata.len(),
246 MAX_FILE_SIZE / 1024 / 1024
247 ));
248 }
249
250 let fork_id = {
251 let mut attempts: u32 = 0;
252 loop {
253 let candidate = make_short_random_id("fork", 12);
254 let work_path = self.config.fork_dir.join(format!("{}.xlsx", candidate));
255 let exists_in_registry = self.forks.lock().contains_key(&candidate);
256 if !exists_in_registry && !work_path.exists() {
257 break candidate;
258 }
259 attempts += 1;
260 if attempts > 20 {
261 return Err(anyhow!("failed to allocate unique fork id"));
262 }
263 }
264 };
265 let work_path = self.config.fork_dir.join(format!("{}.xlsx", fork_id));
266
267 fs::copy(&base_path_canon, &work_path)?;
268
269 let context = ForkContext::new(fork_id.clone(), base_path_canon, work_path)?;
270
271 self.forks.lock().insert(fork_id.clone(), context);
272
273 Ok(fork_id)
274 }
275
276 pub fn get_fork(&self, fork_id: &str) -> Result<Arc<ForkContext>> {
277 self.evict_expired();
278
279 let mut forks = self.forks.lock();
280 let ctx = forks
281 .get_mut(fork_id)
282 .ok_or_else(|| anyhow!("fork not found: {}", fork_id))?;
283 ctx.touch();
284 Ok(Arc::new(ctx.clone()))
285 }
286
287 pub fn get_fork_path(&self, fork_id: &str) -> Option<PathBuf> {
288 let mut forks = self.forks.lock();
289 if let Some(ctx) = forks.get_mut(fork_id) {
290 ctx.touch();
291 return Some(ctx.work_path.clone());
292 }
293 None
294 }
295
296 pub fn with_fork_mut<F, R>(&self, fork_id: &str, f: F) -> Result<R>
297 where
298 F: FnOnce(&mut ForkContext) -> Result<R>,
299 {
300 let mut forks = self.forks.lock();
301 let ctx = forks
302 .get_mut(fork_id)
303 .ok_or_else(|| anyhow!("fork not found: {}", fork_id))?;
304 ctx.touch();
305 f(ctx)
306 }
307
308 pub fn discard_fork(&self, fork_id: &str) -> Result<()> {
309 let mut forks = self.forks.lock();
310 if let Some(ctx) = forks.remove(fork_id) {
311 ctx.cleanup_files();
312 }
313 Ok(())
314 }
315
316 pub fn save_fork(
317 &self,
318 fork_id: &str,
319 target_path: &Path,
320 workspace_root: &Path,
321 drop_fork: bool,
322 ) -> Result<()> {
323 let _target_canon = canonicalize_and_enforce_within_workspace(
325 workspace_root,
326 target_path,
327 "save_fork",
328 "target_path",
329 )?;
330
331 let ext = target_path
332 .extension()
333 .and_then(|e| e.to_str())
334 .map(|e| e.to_ascii_lowercase());
335
336 if ext.as_deref() != Some("xlsx") {
337 return Err(anyhow!("target must be .xlsx"));
338 }
339
340 let mut forks = self.forks.lock();
341 let ctx = forks
342 .get(fork_id)
343 .ok_or_else(|| anyhow!("fork not found: {}", fork_id))?;
344
345 ctx.validate_base_unchanged()?;
346
347 fs::copy(&ctx.work_path, target_path)?;
348
349 if drop_fork && let Some(ctx) = forks.remove(fork_id) {
350 let _ = fs::remove_file(&ctx.work_path);
351 }
352
353 Ok(())
354 }
355
356 pub fn ttl(&self) -> Duration {
357 self.config.ttl
358 }
359
360 pub fn list_forks(&self) -> Vec<ForkInfo> {
361 self.evict_expired();
362
363 let forks = self.forks.lock();
364 forks
365 .values()
366 .map(|ctx| ForkInfo {
367 fork_id: ctx.fork_id.clone(),
368 base_path: ctx.base_path.display().to_string(),
369 created_at: ctx.created_at,
370 edit_count: ctx.edits.len(),
371 recalc_needed: ctx.recalc_needed,
372 })
373 .collect()
374 }
375
376 pub fn create_checkpoint(&self, fork_id: &str, label: Option<String>) -> Result<Checkpoint> {
377 self.evict_expired();
378
379 let work_path = {
380 let forks = self.forks.lock();
381 let ctx = forks
382 .get(fork_id)
383 .ok_or_else(|| anyhow!("fork not found: {}", fork_id))?;
384 ctx.work_path.clone()
385 };
386
387 let checkpoint_id = make_short_random_id("cp", 12);
388 let dir = PathBuf::from(CHECKPOINT_DIR).join(fork_id);
389 fs::create_dir_all(&dir)?;
390 let snapshot_path = dir.join(format!("{}.xlsx", checkpoint_id));
391 fs::copy(&work_path, &snapshot_path)?;
392
393 let recalc_needed = self
394 .get_fork(fork_id)
395 .map(|ctx| ctx.recalc_needed)
396 .unwrap_or(false);
397
398 let checkpoint = Checkpoint {
399 checkpoint_id: checkpoint_id.clone(),
400 created_at: Utc::now(),
401 label,
402 snapshot_path,
403 recalc_needed,
404 };
405
406 self.with_fork_mut(fork_id, |ctx| {
407 ctx.checkpoints.push(checkpoint.clone());
408 enforce_checkpoint_limits(ctx)?;
409 Ok(())
410 })?;
411
412 Ok(checkpoint)
413 }
414
415 pub fn list_checkpoints(&self, fork_id: &str) -> Result<Vec<Checkpoint>> {
416 let ctx = self.get_fork(fork_id)?;
417 Ok(ctx.checkpoints.clone())
418 }
419
420 pub fn delete_checkpoint(&self, fork_id: &str, checkpoint_id: &str) -> Result<()> {
421 self.with_fork_mut(fork_id, |ctx| {
422 let index = ctx
423 .checkpoints
424 .iter()
425 .position(|c| c.checkpoint_id == checkpoint_id)
426 .ok_or_else(|| anyhow!("checkpoint not found: {}", checkpoint_id))?;
427 let removed = ctx.checkpoints.remove(index);
428 let _ = fs::remove_file(&removed.snapshot_path);
429 Ok(())
430 })
431 }
432
433 pub fn restore_checkpoint(&self, fork_id: &str, checkpoint_id: &str) -> Result<Checkpoint> {
434 self.evict_expired();
435
436 let (work_path, checkpoint) = {
437 let forks = self.forks.lock();
438 let ctx = forks
439 .get(fork_id)
440 .ok_or_else(|| anyhow!("fork not found: {}", fork_id))?;
441 let checkpoint = ctx
442 .checkpoints
443 .iter()
444 .find(|c| c.checkpoint_id == checkpoint_id)
445 .cloned()
446 .ok_or_else(|| anyhow!("checkpoint not found: {}", checkpoint_id))?;
447 (ctx.work_path.clone(), checkpoint)
448 };
449
450 fs::copy(&checkpoint.snapshot_path, &work_path)?;
451
452 self.with_fork_mut(fork_id, |ctx| {
453 let cutoff = checkpoint.created_at;
454 ctx.edits.retain(|e| e.timestamp <= cutoff);
455 let mut i = 0;
456 while i < ctx.staged_changes.len() {
457 if ctx.staged_changes[i].created_at > cutoff {
458 let removed = ctx.staged_changes.remove(i);
459 remove_staged_snapshot(&removed);
460 } else {
461 i += 1;
462 }
463 }
464 ctx.recalc_needed = checkpoint.recalc_needed;
465 Ok(())
466 })?;
467
468 Ok(checkpoint)
469 }
470
471 pub fn add_staged_change(&self, fork_id: &str, staged: StagedChange) -> Result<()> {
472 self.with_fork_mut(fork_id, |ctx| {
473 ctx.staged_changes.push(staged);
474 enforce_staged_limits(ctx);
475 Ok(())
476 })
477 }
478
479 pub fn list_staged_changes(&self, fork_id: &str) -> Result<Vec<StagedChange>> {
480 let ctx = self.get_fork(fork_id)?;
481 Ok(ctx.staged_changes.clone())
482 }
483
484 pub fn take_staged_change(&self, fork_id: &str, change_id: &str) -> Result<StagedChange> {
485 self.with_fork_mut(fork_id, |ctx| {
486 let index = ctx
487 .staged_changes
488 .iter()
489 .position(|c| c.change_id == change_id)
490 .ok_or_else(|| anyhow!("staged change not found: {}", change_id))?;
491 Ok(ctx.staged_changes.remove(index))
492 })
493 }
494
495 pub fn discard_staged_change(&self, fork_id: &str, change_id: &str) -> Result<()> {
496 let removed = self.take_staged_change(fork_id, change_id)?;
497 remove_staged_snapshot(&removed);
498 Ok(())
499 }
500
501 fn evict_expired(&self) {
502 if self.config.ttl.is_zero() {
503 return;
504 }
505 let mut forks = self.forks.lock();
506 let expired: Vec<String> = forks
507 .iter()
508 .filter(|(_, ctx)| ctx.is_expired(self.config.ttl))
509 .map(|(id, _)| id.clone())
510 .collect();
511
512 for id in expired {
513 if let Some(ctx) = forks.remove(&id) {
514 ctx.cleanup_files();
515 tracing::debug!(fork_id = %id, "evicted expired fork");
516 }
517 }
518 }
519}
520
521fn remove_staged_snapshot(staged: &StagedChange) {
522 if let Some(path) = staged.fork_path_snapshot.as_ref() {
523 let _ = fs::remove_file(path);
524 }
525}
526
527fn enforce_staged_limits(ctx: &mut ForkContext) {
528 while ctx.staged_changes.len() > DEFAULT_MAX_STAGED_CHANGES_PER_FORK {
529 let removed = ctx.staged_changes.remove(0);
530 remove_staged_snapshot(&removed);
531 }
532}
533
534fn enforce_checkpoint_limits(ctx: &mut ForkContext) -> Result<()> {
535 while ctx.checkpoints.len() > DEFAULT_MAX_CHECKPOINTS_PER_FORK {
536 let removed = ctx.checkpoints.remove(0);
537 let _ = fs::remove_file(&removed.snapshot_path);
538 }
539
540 loop {
541 let mut total_bytes = 0u64;
542 for cp in &ctx.checkpoints {
543 if let Ok(meta) = fs::metadata(&cp.snapshot_path) {
544 total_bytes += meta.len();
545 }
546 }
547 if total_bytes <= DEFAULT_MAX_CHECKPOINT_TOTAL_BYTES || ctx.checkpoints.len() <= 1 {
548 break;
549 }
550 let removed = ctx.checkpoints.remove(0);
551 let _ = fs::remove_file(&removed.snapshot_path);
552 }
553
554 Ok(())
555}
556
557impl Clone for ForkContext {
558 fn clone(&self) -> Self {
559 Self {
560 fork_id: self.fork_id.clone(),
561 base_path: self.base_path.clone(),
562 work_path: self.work_path.clone(),
563 created_at: self.created_at,
564 last_accessed: self.last_accessed,
565 edits: self.edits.clone(),
566 staged_changes: self.staged_changes.clone(),
567 checkpoints: self.checkpoints.clone(),
568 recalc_needed: self.recalc_needed,
569 base_hash: self.base_hash.clone(),
570 base_modified: self.base_modified,
571 }
572 }
573}
574
/// Lightweight summary of a fork, as returned by `ForkRegistry::list_forks`.
#[derive(Debug, Clone)]
pub struct ForkInfo {
    /// Fork identifier.
    pub fork_id: String,
    /// Display form of the path the fork was copied from.
    pub base_path: String,
    /// Creation instant (monotonic; not a wall-clock time).
    pub created_at: Instant,
    /// Number of recorded edits.
    pub edit_count: usize,
    /// Whether the working copy needs recalculation.
    pub recalc_needed: bool,
}