Skip to main content

sqlite_graphrag/commands/
cache.rs

1//! Handler for the `cache` CLI subcommand and its nested operations.
2//!
3//! Manages cached resources such as the multilingual-e5-small ONNX model and
4//! the BERT NER classifier downloaded into the XDG cache directory on first
5//! `init`. Used to reclaim disk space or recover from corrupted cache state.
6
7use crate::errors::AppError;
8use crate::output;
9use crate::paths::AppPaths;
10use serde::Serialize;
11
// Top-level arguments for `sqlite-graphrag cache`. The actual operation is the nested
// `CacheCommands` value, which `run` dispatches to its handler.
// NOTE(review): the "Skip the confirmation prompt" example below is misleading —
// `clear_models` never prompts; without `--yes` it fails with a validation error.
#[derive(clap::Args)]
#[command(after_long_help = "EXAMPLES:\n  \
    # Remove cached embedding/NER model files (forces re-download on next init)\n  \
    sqlite-graphrag cache clear-models\n\n  \
    # Skip the confirmation prompt\n  \
    sqlite-graphrag cache clear-models --yes\n\n  \
    # List cached model files\n  \
    sqlite-graphrag cache list\n\n  \
    # List cached model files as JSON\n  \
    sqlite-graphrag cache list --json")]
pub struct CacheArgs {
    // Which nested cache operation to run.
    #[command(subcommand)]
    pub command: CacheCommands,
}
26
// Nested operations under `cache`; `run` maps `ClearModels` to `clear_models` and
// `List` to `run_list`.
#[derive(clap::Subcommand)]
pub enum CacheCommands {
    /// Remove cached embedding/NER model files (forces re-download on next `init`).
    ClearModels(ClearModelsArgs),
    /// List cached embedding/NER model files with sizes and total disk usage.
    List(CacheListArgs),
}
34
// Flags for `cache list`.
#[derive(clap::Args)]
pub struct CacheListArgs {
    // When false, `run_list` prints an aligned text table (or "(empty)") instead.
    /// Output as JSON.
    #[arg(long)]
    pub json: bool,
}
41
42#[derive(clap::Args)]
43pub struct ClearModelsArgs {
44    /// Skip confirmation prompt and proceed with deletion immediately.
45    #[arg(long, default_value_t = false, help = "Skip confirmation prompt")]
46    pub yes: bool,
47    /// Output format: json (default), text, or markdown.
48    #[arg(long, hide = true, help = "No-op; JSON is always emitted on stdout")]
49    pub json: bool,
50}
51
/// JSON payload emitted on stdout by `cache clear-models`.
#[derive(Serialize)]
struct ClearModelsResponse {
    /// Models directory that was targeted, rendered with `Path::display`.
    cache_path: String,
    /// Whether the directory existed when the command ran.
    existed: bool,
    /// Bytes measured under the directory just before removal; 0 when the
    /// directory was absent or could not be measured.
    bytes_freed: u64,
    /// Non-directory entries counted just before removal; 0 when the directory
    /// was absent or could not be counted.
    files_removed: usize,
    /// Total execution time in milliseconds from handler start to serialisation.
    elapsed_ms: u64,
}
61
62pub fn run(args: CacheArgs) -> Result<(), AppError> {
63    match args.command {
64        CacheCommands::ClearModels(a) => clear_models(a),
65        CacheCommands::List(a) => run_list(a),
66    }
67}
68
69fn clear_models(args: ClearModelsArgs) -> Result<(), AppError> {
70    let inicio = std::time::Instant::now();
71    // Resolve the canonical models directory through AppPaths to honour
72    // SQLITE_GRAPHRAG_CACHE_DIR overrides used by tests and CI.
73    let paths = AppPaths::resolve(None)?;
74    let models_dir = paths.models.clone();
75
76    if !args.yes {
77        // For machine consumption stay deterministic: refuse without --yes.
78        return Err(AppError::Validation(
79            "destructive operation: pass --yes to confirm cache deletion".to_string(),
80        ));
81    }
82
83    let existed = models_dir.exists();
84    let mut bytes_freed: u64 = 0;
85    let mut files_removed: usize = 0;
86
87    if existed {
88        bytes_freed = dir_size(&models_dir).unwrap_or(0);
89        files_removed = count_files(&models_dir).unwrap_or(0);
90        std::fs::remove_dir_all(&models_dir)?;
91    }
92
93    output::emit_json(&ClearModelsResponse {
94        cache_path: models_dir.display().to_string(),
95        existed,
96        bytes_freed,
97        files_removed,
98        elapsed_ms: inicio.elapsed().as_millis() as u64,
99    })?;
100
101    Ok(())
102}
103
/// One file reported by `cache list`.
#[derive(Serialize)]
struct CacheFileEntry {
    /// Path relative to the listing base (the models directory in `run_list`),
    /// lossily converted to UTF-8; the full path if the base is not a prefix.
    name: String,
    /// Full path rendered with `Path::display`.
    path: String,
    /// Size in bytes from the entry's metadata.
    size_bytes: u64,
    /// Modification time as `YYYY-MM-DDTHH:MM:SSZ` (UTC), or `"unknown"` when the
    /// platform cannot report it.
    modified_at: String,
}
111
/// JSON payload emitted by `cache list --json`.
#[derive(Serialize)]
struct CacheListResponse {
    /// Payload schema version; `run_list` always sets 1.
    schema_version: u32,
    /// Models directory that was scanned, rendered with `Path::display`.
    cache_path: String,
    /// Cached files, sorted by `name`; empty when the directory does not exist.
    files: Vec<CacheFileEntry>,
    /// Sum of `size_bytes` over `files`.
    total_bytes: u64,
    /// `total_bytes` rendered by `format_bytes_human` (e.g. "1.5 MB").
    total_human: String,
}
120
/// Renders a byte count for humans using binary (1024-based) units.
///
/// Counts below 1024 are printed exactly (`"512 B"`). Larger counts use one decimal
/// place with the labels `KB`, `MB` and `GB` (`"1.5 MB"`). `GB` is the largest unit,
/// so very large values render as e.g. `"2048.0 GB"`.
///
/// The unit is promoted whenever the value would otherwise *display* as `1024.0` at
/// one decimal place. For example, `1_048_575` bytes renders as `"1.0 MB"` rather
/// than `"1024.0 KB"`.
fn format_bytes_human(bytes: u64) -> String {
    const UNITS: [&str; 3] = ["KB", "MB", "GB"];
    const STEP: f64 = 1024.0;

    if bytes < 1024 {
        return format!("{bytes} B");
    }

    let mut value = bytes as f64 / STEP;
    let mut unit = 0;
    // Compare at the same one-decimal precision used by `{:.1}` below so that a
    // value that would print as "1024.0" moves up to the next unit instead.
    while unit + 1 < UNITS.len() && (value * 10.0).round() >= STEP * 10.0 {
        value /= STEP;
        unit += 1;
    }
    format!("{value:.1} {}", UNITS[unit])
}
135
136fn collect_cache_files(
137    dir: &std::path::Path,
138    base: &std::path::Path,
139    entries: &mut Vec<CacheFileEntry>,
140) -> std::io::Result<()> {
141    for entry in std::fs::read_dir(dir)? {
142        let entry = entry?;
143        let meta = entry.metadata()?;
144        let path = entry.path();
145        if meta.is_dir() {
146            collect_cache_files(&path, base, entries)?;
147        } else {
148            let size_bytes = meta.len();
149            let relative = path.strip_prefix(base).unwrap_or(&path);
150            let name = relative.to_string_lossy().into_owned();
151            let modified_at = meta
152                .modified()
153                .ok()
154                .map(|t| {
155                    let secs = t
156                        .duration_since(std::time::UNIX_EPOCH)
157                        .unwrap_or_default()
158                        .as_secs();
159                    // Format as RFC 3339 (UTC) without chrono dependency.
160                    let secs_i64 = secs as i64;
161                    let (y, mo, d, h, mi, s) = epoch_to_ymd_hms(secs_i64);
162                    format!("{y:04}-{mo:02}-{d:02}T{h:02}:{mi:02}:{s:02}Z")
163                })
164                .unwrap_or_else(|| "unknown".to_string());
165            entries.push(CacheFileEntry {
166                name,
167                path: path.display().to_string(),
168                size_bytes,
169                modified_at,
170            });
171        }
172    }
173    Ok(())
174}
175
176/// Converts Unix epoch seconds to (year, month, day, hour, minute, second) UTC.
177fn epoch_to_ymd_hms(secs: i64) -> (i32, u8, u8, u8, u8, u8) {
178    let s = (secs % 60) as u8;
179    let total_min = secs / 60;
180    let mi = (total_min % 60) as u8;
181    let total_h = total_min / 60;
182    let h = (total_h % 24) as u8;
183    let mut days = total_h / 24;
184    // Compute year/month/day from days since epoch (1970-01-01).
185    let mut y = 1970i32;
186    loop {
187        let days_in_y = if is_leap(y) { 366 } else { 365 };
188        if days < days_in_y {
189            break;
190        }
191        days -= days_in_y;
192        y += 1;
193    }
194    let leap = is_leap(y);
195    let months = [
196        31u8,
197        if leap { 29 } else { 28 },
198        31,
199        30,
200        31,
201        30,
202        31,
203        31,
204        30,
205        31,
206        30,
207        31,
208    ];
209    let mut mo = 1u8;
210    for &days_in_m in &months {
211        if days < days_in_m as i64 {
212            break;
213        }
214        days -= days_in_m as i64;
215        mo += 1;
216    }
217    let d = (days + 1) as u8;
218    (y, mo, d, h, mi, s)
219}
220
/// Reports whether `y` is a leap year under the Gregorian rules: every fourth year,
/// except centuries, except every fourth century.
fn is_leap(y: i32) -> bool {
    if y % 400 == 0 {
        true
    } else if y % 100 == 0 {
        false
    } else {
        y % 4 == 0
    }
}
224
225fn run_list(args: CacheListArgs) -> Result<(), AppError> {
226    let paths = AppPaths::resolve(None)?;
227    let models_dir = &paths.models;
228
229    let mut entries: Vec<CacheFileEntry> = Vec::new();
230    if models_dir.exists() {
231        collect_cache_files(models_dir, models_dir, &mut entries).map_err(AppError::Io)?;
232    }
233
234    entries.sort_by(|a, b| a.name.cmp(&b.name));
235    let total_bytes: u64 = entries.iter().map(|e| e.size_bytes).sum();
236    let total_human = format_bytes_human(total_bytes);
237    let n_files = entries.len();
238
239    if args.json {
240        output::emit_json(&CacheListResponse {
241            schema_version: 1,
242            cache_path: models_dir.display().to_string(),
243            files: entries,
244            total_bytes,
245            total_human,
246        })?;
247    } else if entries.is_empty() {
248        output::emit_text("(empty)");
249    } else {
250        for e in &entries {
251            output::emit_text(&format!(
252                "{:<40} {:>10}  {}",
253                e.name,
254                format_bytes_human(e.size_bytes),
255                e.modified_at
256            ));
257        }
258        output::emit_text(&format!("\nTOTAL: {n_files} files, {total_human}"));
259    }
260
261    Ok(())
262}
263
/// Returns the total size in bytes of all non-directory entries under `path`,
/// recursively.
///
/// Errors reading `path` itself or one of its direct entries are returned. A
/// subdirectory that fails to read contributes 0 instead of aborting the walk.
/// Additions saturate at `u64::MAX`.
fn dir_size(path: &std::path::Path) -> std::io::Result<u64> {
    std::fs::read_dir(path)?.try_fold(0u64, |acc, item| -> std::io::Result<u64> {
        let item = item?;
        let metadata = item.metadata()?;
        let contribution = if metadata.is_dir() {
            dir_size(&item.path()).unwrap_or(0)
        } else {
            metadata.len()
        };
        Ok(acc.saturating_add(contribution))
    })
}
277
/// Counts all non-directory entries under `path`, recursively.
///
/// Errors reading `path` itself or one of its direct entries are returned. A
/// subdirectory that fails to read contributes 0 instead of aborting the count.
fn count_files(path: &std::path::Path) -> std::io::Result<usize> {
    std::fs::read_dir(path)?.try_fold(0usize, |acc, item| -> std::io::Result<usize> {
        let item = item?;
        let next = if item.metadata()?.is_dir() {
            acc.saturating_add(count_files(&item.path()).unwrap_or(0))
        } else {
            acc + 1
        };
        Ok(next)
    })
}
291
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn clear_models_response_serializes_all_fields() {
        let resp = ClearModelsResponse {
            cache_path: "/tmp/sqlite-graphrag/models".to_string(),
            existed: true,
            bytes_freed: 465_000_000,
            files_removed: 14,
            elapsed_ms: 12,
        };
        let json = serde_json::to_value(&resp).expect("serialization");
        assert_eq!(json["cache_path"], "/tmp/sqlite-graphrag/models");
        assert_eq!(json["existed"], true);
        assert_eq!(json["bytes_freed"], 465_000_000u64);
        assert_eq!(json["files_removed"], 14);
        assert_eq!(json["elapsed_ms"], 12);
    }

    #[test]
    fn clear_models_without_yes_returns_validation_error() {
        let args = ClearModelsArgs {
            yes: false,
            json: false,
        };
        let result = clear_models(args);
        assert!(matches!(result, Err(AppError::Validation(_))));
    }

    #[test]
    fn format_bytes_human_picks_unit_by_magnitude() {
        assert_eq!(format_bytes_human(0), "0 B");
        assert_eq!(format_bytes_human(1023), "1023 B");
        assert_eq!(format_bytes_human(1024), "1.0 KB");
        assert_eq!(format_bytes_human(1536), "1.5 KB");
        assert_eq!(format_bytes_human(1024 * 1024), "1.0 MB");
        assert_eq!(format_bytes_human(5 * 1024 * 1024 * 1024), "5.0 GB");
    }

    #[test]
    fn epoch_to_ymd_hms_matches_known_instants() {
        assert_eq!(epoch_to_ymd_hms(0), (1970, 1, 1, 0, 0, 0));
        assert_eq!(epoch_to_ymd_hms(86_399), (1970, 1, 1, 23, 59, 59));
        // 2000 is a leap year (divisible by 400).
        assert_eq!(epoch_to_ymd_hms(951_782_400), (2000, 2, 29, 0, 0, 0));
        assert_eq!(epoch_to_ymd_hms(1_700_000_000), (2023, 11, 14, 22, 13, 20));
        // 2100 is not a leap year, so Feb 28 is followed by Mar 1.
        assert_eq!(epoch_to_ymd_hms(4_107_542_400), (2100, 3, 1, 0, 0, 0));
    }

    #[test]
    fn is_leap_follows_gregorian_rules() {
        assert!(is_leap(2000));
        assert!(!is_leap(1900));
        assert!(is_leap(2024));
        assert!(!is_leap(2023));
    }
}