//! recovery.rs — database merge, cross-store repair, and maintenance
//! (compaction, version cleanup, garbage collection) routines for rust_memex.
1use anyhow::{Result, anyhow};
2use memex_contracts::progress::{CompactProgress, MergeProgress, RepairResult};
3use std::collections::HashSet;
4use std::path::{Path, PathBuf};
5use std::time::Instant;
6
7use crate::{
8    BM25Config, BM25Index, CrossStoreRecoveryReport, GcConfig, GcStats, StorageManager, TableStats,
9    inspect_cross_store_recovery, path_utils, repair_cross_store_recovery,
10};
11
/// Outcome of a database merge: the counters gathered while merging plus the
/// validated path the merged documents were (or, in dry-run mode, would be)
/// written to.
#[derive(Debug, Clone)]
pub struct MergeExecution {
    /// Counters accumulated during the merge (docs copied/skipped, errors,
    /// sources processed, namespaces touched).
    pub progress: MergeProgress,
    /// Sanitized target database path.
    pub target_path: PathBuf,
}
17
/// Outcome of a cross-store repair pass: a compact summary (`result`) derived
/// from the full recovery `report`.
#[derive(Debug, Clone)]
pub struct RepairExecution {
    /// Summary counters copied out of `report` for progress reporting.
    pub result: RepairResult,
    /// Full recovery report returned by the inspect/repair routines.
    pub report: CrossStoreRecoveryReport,
}
23
/// Outcome of a maintenance operation (compact, cleanup, or GC): the phase
/// progress plus optional before/after table statistics.
#[derive(Debug, Clone)]
pub struct MaintenanceExecution {
    /// Phase/status/metrics for the maintenance run.
    pub progress: CompactProgress,
    /// Table stats captured before the operation; `None` if unavailable.
    pub pre_stats: Option<TableStats>,
    /// Table stats captured after the operation; `None` if unavailable.
    pub post_stats: Option<TableStats>,
    /// Garbage-collection stats; populated only by [`collect_garbage`].
    pub gc_stats: Option<GcStats>,
}
31
/// Merges documents from multiple Lance databases into a single target database.
///
/// Each source path is validated and opened independently; a source that fails
/// validation, opening, or reading increments `progress.errors` and the merge
/// continues with the remaining sources. Documents are grouped by namespace,
/// optionally prefixed with `namespace_prefix`, and written to the target in
/// one batch per namespace.
///
/// * `dedup` — skip documents whose `content_hash` was already seen, including
///   hashes of documents already present in the target (when it exists).
/// * `dry_run` — perform all validation and counting but write nothing; the
///   target database is never opened or created.
///
/// # Errors
/// Returns an error when no source path validates, when the target path fails
/// sanitization, or when the target database cannot be created/opened.
pub async fn merge_databases(
    source_paths: Vec<PathBuf>,
    target_path: PathBuf,
    dedup: bool,
    namespace_prefix: Option<String>,
    dry_run: bool,
) -> Result<MergeExecution> {
    let mut validated_sources = Vec::new();
    let mut progress = MergeProgress::default();
    let mut namespaces = HashSet::new();

    // Validate every source path up front; invalid paths are counted as errors
    // rather than aborting the whole merge.
    for source in &source_paths {
        let source_str = source.to_str().unwrap_or("");
        match path_utils::sanitize_existing_path(source_str) {
            Ok(validated) => validated_sources.push(validated),
            Err(_) => progress.errors += 1,
        }
    }

    if validated_sources.is_empty() {
        return Err(anyhow!("No valid source databases found"));
    }

    let validated_target = path_utils::sanitize_new_path(target_path.to_str().unwrap_or(""))?;
    // In dry-run mode the target is never opened, so no directories are created.
    let target_storage = if dry_run {
        None
    } else {
        if let Some(parent) = validated_target.parent() {
            std::fs::create_dir_all(parent)?;
        }
        Some(StorageManager::new_lance_only(validated_target.to_str().unwrap_or("")).await?)
    };

    // Seed the dedup set with hashes already present in the target so that
    // re-running a merge does not duplicate previously copied documents.
    // NOTE(review): reads are capped at 100_000 documents here and per source
    // below — documents beyond the cap are neither deduped nor copied; confirm
    // the cap is acceptable for expected database sizes.
    let mut seen_hashes = HashSet::new();
    if dedup
        && let Some(ref target) = target_storage
        && let Ok(existing_docs) = target.all_documents(None, 100_000).await
    {
        for doc in existing_docs {
            if let Some(hash) = doc.content_hash {
                seen_hashes.insert(hash);
            }
        }
    }

    for source_path in &validated_sources {
        // A source that fails to open is skipped and counted as a single error.
        let source_storage =
            match StorageManager::new_lance_only(source_path.to_str().unwrap_or("")).await {
                Ok(storage) => storage,
                Err(_) => {
                    progress.errors += 1;
                    continue;
                }
            };

        let source_docs = match source_storage.all_documents(None, 100_000).await {
            Ok(docs) => docs,
            Err(_) => {
                progress.errors += 1;
                continue;
            }
        };

        progress.total_docs += source_docs.len();

        // Group documents by namespace so each namespace is written as one batch.
        let mut docs_by_namespace = std::collections::HashMap::new();
        for doc in source_docs {
            docs_by_namespace
                .entry(doc.namespace.clone())
                .or_insert_with(Vec::new)
                .push(doc);
        }

        for (namespace, docs) in docs_by_namespace {
            // Apply the optional namespace prefix to the destination namespace.
            let target_namespace = if let Some(ref prefix) = namespace_prefix {
                format!("{prefix}{namespace}")
            } else {
                namespace
            };
            namespaces.insert(target_namespace.clone());

            let mut batch = Vec::new();
            for doc in docs {
                // Deduplicate by content hash across all sources and the target.
                if dedup && let Some(ref hash) = doc.content_hash {
                    if seen_hashes.contains(hash) {
                        progress.docs_skipped += 1;
                        continue;
                    }
                    seen_hashes.insert(hash.clone());
                }

                // Rebuild the document under the (possibly prefixed) namespace.
                batch.push(crate::ChromaDocument {
                    id: doc.id,
                    namespace: target_namespace.clone(),
                    embedding: doc.embedding,
                    metadata: doc.metadata,
                    document: doc.document,
                    layer: doc.layer,
                    parent_id: doc.parent_id,
                    children_ids: doc.children_ids,
                    keywords: doc.keywords,
                    content_hash: doc.content_hash,
                    source_hash: doc.source_hash,
                });
                // Counted even in dry-run: reflects what *would* be written.
                progress.docs_copied += 1;
            }

            // A failed batch write is recorded but does not abort the merge.
            if !dry_run
                && !batch.is_empty()
                && let Some(ref target) = target_storage
                && target.add_to_store(batch).await.is_err()
            {
                progress.errors += 1;
            }
        }

        progress.sources_processed += 1;
    }

    // Report the touched namespaces in a stable, sorted order.
    let mut namespaces = namespaces.into_iter().collect::<Vec<_>>();
    namespaces.sort();
    progress.namespaces = namespaces;

    Ok(MergeExecution {
        progress,
        target_path: validated_target,
    })
}
160
161pub async fn repair_writes(
162    db_path: &str,
163    namespace: Option<&str>,
164    execute: bool,
165) -> Result<RepairExecution> {
166    let storage = StorageManager::new_lance_only(db_path).await?;
167    let mut bm25_config = BM25Config::default().with_read_only(!execute);
168    if let Some(path) = sibling_bm25_path(db_path) {
169        bm25_config = bm25_config.with_path(path.to_string_lossy().into_owned());
170    }
171    let bm25 = BM25Index::new(&bm25_config)?;
172
173    let report = if execute {
174        repair_cross_store_recovery(&storage, &bm25, namespace).await?
175    } else {
176        inspect_cross_store_recovery(&storage, &bm25, namespace).await?
177    };
178
179    Ok(RepairExecution {
180        result: RepairResult {
181            recovery_dir: report.recovery_dir.clone(),
182            pending_batches: report.pending_batches,
183            repaired_documents: report.repaired_documents,
184            skipped_documents: report.skipped_documents,
185            batches_repaired: report.batches_repaired,
186        },
187        report,
188    })
189}
190
191pub async fn compact_database(storage: &StorageManager) -> Result<MaintenanceExecution> {
192    let started_at = Instant::now();
193    let pre_stats = storage.stats().await.ok();
194    let stats = storage.compact().await?;
195    let post_stats = storage.stats().await.ok();
196
197    Ok(MaintenanceExecution {
198        progress: CompactProgress {
199            phase: "compact".to_string(),
200            status: "complete".to_string(),
201            description: Some("Merging small files into larger ones".to_string()),
202            files_removed: stats
203                .compaction
204                .as_ref()
205                .map(|value| value.files_removed as u64),
206            files_added: stats
207                .compaction
208                .as_ref()
209                .map(|value| value.files_added as u64),
210            fragments_removed: stats
211                .compaction
212                .as_ref()
213                .map(|value| value.fragments_removed as u64),
214            fragments_added: stats
215                .compaction
216                .as_ref()
217                .map(|value| value.fragments_added as u64),
218            old_versions: None,
219            bytes_removed: None,
220            elapsed_ms: Some(started_at.elapsed().as_millis() as u64),
221        },
222        pre_stats,
223        post_stats,
224        gc_stats: None,
225    })
226}
227
228pub async fn cleanup_versions(
229    storage: &StorageManager,
230    older_than_days: Option<u64>,
231) -> Result<MaintenanceExecution> {
232    let started_at = Instant::now();
233    let pre_stats = storage.stats().await.ok();
234    let stats = storage.cleanup(older_than_days).await?;
235    let post_stats = storage.stats().await.ok();
236    let older_than_days = older_than_days.unwrap_or(7);
237
238    Ok(MaintenanceExecution {
239        progress: CompactProgress {
240            phase: "cleanup".to_string(),
241            status: "complete".to_string(),
242            description: Some(format!(
243                "Removing old versions older than {older_than_days} days"
244            )),
245            files_removed: None,
246            files_added: None,
247            fragments_removed: None,
248            fragments_added: None,
249            old_versions: stats.prune.as_ref().map(|value| value.old_versions),
250            bytes_removed: stats.prune.as_ref().map(|value| value.bytes_removed),
251            elapsed_ms: Some(started_at.elapsed().as_millis() as u64),
252        },
253        pre_stats,
254        post_stats,
255        gc_stats: None,
256    })
257}
258
259pub async fn collect_garbage(
260    storage: &StorageManager,
261    config: &GcConfig,
262) -> Result<MaintenanceExecution> {
263    let started_at = Instant::now();
264    let pre_stats = storage.stats().await.ok();
265    let gc_stats = storage.garbage_collect(config).await?;
266    let post_stats = storage.stats().await.ok();
267
268    Ok(MaintenanceExecution {
269        progress: CompactProgress {
270            phase: "gc".to_string(),
271            status: "complete".to_string(),
272            description: Some(gc_description(config)),
273            files_removed: None,
274            files_added: None,
275            fragments_removed: None,
276            fragments_added: None,
277            old_versions: None,
278            bytes_removed: gc_stats.bytes_freed,
279            elapsed_ms: Some(started_at.elapsed().as_millis() as u64),
280        },
281        pre_stats,
282        post_stats,
283        gc_stats: Some(gc_stats),
284    })
285}
286
287fn gc_description(config: &GcConfig) -> String {
288    let mut actions = Vec::new();
289    if config.remove_orphans {
290        actions.push("orphan embeddings".to_string());
291    }
292    if config.remove_empty {
293        actions.push("empty namespaces".to_string());
294    }
295    if let Some(duration) = config.older_than.as_ref() {
296        actions.push(format!("documents older than {} days", duration.num_days()));
297    }
298    if actions.is_empty() {
299        "Running garbage collection".to_string()
300    } else {
301        format!("Removing {}", actions.join(", "))
302    }
303}
304
305pub fn sibling_bm25_path(db_path: &str) -> Option<PathBuf> {
306    let db_path = shellexpand::tilde(db_path).to_string();
307    Path::new(&db_path)
308        .parent()
309        .map(|parent| parent.join(".bm25"))
310}