// Source file: sqlite_graphrag/commands/prune_ner.rs

1//! Handler for the `prune-ner` CLI subcommand.
2//!
3//! Removes NER bindings (rows in `memory_entities`) for a single entity or for
4//! all entities in the namespace. Useful for cleaning up low-quality automatic
5//! extractions without touching the entities or memories themselves.
6
7use crate::errors::AppError;
8use crate::output::{self, OutputFormat};
9use crate::paths::AppPaths;
10use crate::storage::connection::open_rw;
11use serde::Serialize;
12
13#[derive(clap::Args)]
14#[command(after_long_help = "EXAMPLES:\n  \
15    # Preview bindings that would be removed for a single entity\n  \
16    sqlite-graphrag prune-ner --entity jwt-token --dry-run\n\n  \
17    # Remove all NER bindings for a single entity\n  \
18    sqlite-graphrag prune-ner --entity jwt-token --yes\n\n  \
19    # Remove ALL NER bindings in the current namespace\n  \
20    sqlite-graphrag prune-ner --all --yes\n\n  \
21NOTE:\n  \
22    This command deletes rows from memory_entities (the link table between\n  \
23    memories and extracted entities). The entities and memories themselves\n  \
24    are not deleted. Use cleanup-orphans afterwards to remove entity nodes\n  \
25    that have no remaining links.")]
26pub struct PruneNerArgs {
27    /// Entity name whose bindings should be removed.
28    /// Mutually exclusive with --all.
29    #[arg(long, conflicts_with = "all", value_name = "NAME")]
30    pub entity: Option<String>,
31
32    /// Remove all NER bindings in the namespace. Mutually exclusive with --entity.
33    #[arg(long, conflicts_with = "entity", default_value_t = false)]
34    pub all: bool,
35
36    #[arg(long)]
37    pub namespace: Option<String>,
38
39    /// Preview count without deleting.
40    #[arg(long, default_value_t = false)]
41    pub dry_run: bool,
42
43    /// Skip confirmation for destructive operation.
44    #[arg(long, default_value_t = false)]
45    pub yes: bool,
46
47    #[arg(long, value_enum, default_value = "json")]
48    pub format: OutputFormat,
49
50    #[arg(long, hide = true, help = "No-op; JSON is always emitted on stdout")]
51    pub json: bool,
52
53    #[arg(long, env = "SQLITE_GRAPHRAG_DB_PATH")]
54    pub db: Option<String>,
55}
56
/// Result payload for the `prune-ner` subcommand.
///
/// Serialised to JSON on stdout when the JSON output format is selected.
/// Fields are serialised in declaration order, so reordering them changes the
/// emitted document.
#[derive(Serialize)]
struct PruneNerResponse {
    /// What the invocation did: `"dry_run"`, `"aborted"` or `"pruned"`.
    action: String,
    /// Number of `memory_entities` rows affected. For `"dry_run"` and
    /// `"aborted"` this is the count that *would* be removed; for `"pruned"`
    /// it is the number of rows actually deleted.
    bindings_removed: usize,
    /// Namespace the operation was scoped to.
    namespace: String,
    /// Entity name targeted, when `--entity` was used.
    #[serde(skip_serializing_if = "Option::is_none")]
    entity: Option<String>,
    /// Total execution time in milliseconds from handler start to serialisation.
    elapsed_ms: u64,
}
68
69pub fn run(args: PruneNerArgs) -> Result<(), AppError> {
70    let inicio = std::time::Instant::now();
71
72    if args.entity.is_none() && !args.all {
73        return Err(AppError::Validation(
74            "either --entity <NAME> or --all must be specified".to_string(),
75        ));
76    }
77
78    let namespace = crate::namespace::resolve_namespace(args.namespace.as_deref())?;
79    let paths = AppPaths::resolve(args.db.as_deref())?;
80
81    crate::storage::connection::ensure_db_ready(&paths)?;
82
83    let mut conn = open_rw(&paths.db)?;
84
85    // Count how many rows would be affected.
86    let count: usize = if let Some(ref entity_name) = args.entity {
87        conn.query_row(
88            "SELECT COUNT(*) FROM memory_entities me
89             JOIN entities e ON e.id = me.entity_id
90             WHERE e.name = ?1 AND e.namespace = ?2",
91            rusqlite::params![entity_name, namespace],
92            |r| r.get::<_, i64>(0).map(|v| v as usize),
93        )?
94    } else {
95        conn.query_row(
96            "SELECT COUNT(*) FROM memory_entities me
97             JOIN entities e ON e.id = me.entity_id
98             WHERE e.namespace = ?1",
99            rusqlite::params![namespace],
100            |r| r.get::<_, i64>(0).map(|v| v as usize),
101        )?
102    };
103
104    if args.dry_run {
105        let response = PruneNerResponse {
106            action: "dry_run".to_string(),
107            bindings_removed: count,
108            namespace: namespace.clone(),
109            entity: args.entity.clone(),
110            elapsed_ms: inicio.elapsed().as_millis() as u64,
111        };
112
113        match args.format {
114            OutputFormat::Json => output::emit_json(&response)?,
115            OutputFormat::Text | OutputFormat::Markdown => {
116                output::emit_text(&format!(
117                    "dry_run: {count} NER bindings would be removed [{namespace}]"
118                ));
119            }
120        }
121
122        return Ok(());
123    }
124
125    if !args.yes {
126        let response = PruneNerResponse {
127            action: "aborted".to_string(),
128            bindings_removed: count,
129            namespace: namespace.clone(),
130            entity: args.entity.clone(),
131            elapsed_ms: inicio.elapsed().as_millis() as u64,
132        };
133
134        match args.format {
135            OutputFormat::Json => output::emit_json(&response)?,
136            OutputFormat::Text | OutputFormat::Markdown => {
137                output::emit_text(&format!(
138                    "aborted: {count} NER bindings would be removed; pass --yes to confirm [{namespace}]"
139                ));
140            }
141        }
142
143        return Ok(());
144    }
145
146    // Destructive path: delete memory_entities rows.
147    let removed: usize = if let Some(ref entity_name) = args.entity {
148        let tx = conn.transaction_with_behavior(rusqlite::TransactionBehavior::Immediate)?;
149        let n = tx.execute(
150            "DELETE FROM memory_entities WHERE entity_id IN (
151                 SELECT id FROM entities WHERE name = ?1 AND namespace = ?2
152             )",
153            rusqlite::params![entity_name, namespace],
154        )?;
155        tx.commit()?;
156        n
157    } else {
158        // --all: remove every binding for entities in the namespace.
159        let tx = conn.transaction_with_behavior(rusqlite::TransactionBehavior::Immediate)?;
160        let n = tx.execute(
161            "DELETE FROM memory_entities WHERE entity_id IN (
162                 SELECT id FROM entities WHERE namespace = ?1
163             )",
164            rusqlite::params![namespace],
165        )?;
166        tx.commit()?;
167        n
168    };
169
170    conn.execute_batch("PRAGMA wal_checkpoint(TRUNCATE);")?;
171
172    tracing::info!(
173        removed = removed,
174        namespace = %namespace,
175        entity = ?args.entity,
176        "NER bindings pruned"
177    );
178
179    let response = PruneNerResponse {
180        action: "pruned".to_string(),
181        bindings_removed: removed,
182        namespace: namespace.clone(),
183        entity: args.entity.clone(),
184        elapsed_ms: inicio.elapsed().as_millis() as u64,
185    };
186
187    match args.format {
188        OutputFormat::Json => output::emit_json(&response)?,
189        OutputFormat::Text | OutputFormat::Markdown => {
190            output::emit_text(&format!(
191                "pruned: {removed} NER bindings removed [{namespace}]"
192            ));
193        }
194    }
195
196    Ok(())
197}
198
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `PruneNerResponse` from the given parts and returns its JSON
    /// representation, panicking if serialisation fails.
    fn serialized(
        action: &str,
        bindings_removed: usize,
        namespace: &str,
        entity: Option<&str>,
        elapsed_ms: u64,
    ) -> serde_json::Value {
        let payload = PruneNerResponse {
            action: action.to_string(),
            bindings_removed,
            namespace: namespace.to_string(),
            entity: entity.map(|name| name.to_string()),
            elapsed_ms,
        };
        serde_json::to_value(&payload).expect("serialization failed")
    }

    #[test]
    fn prune_ner_response_dry_run_serializes_correctly() {
        let value = serialized("dry_run", 42, "global", Some("jwt-token"), 5);

        assert_eq!(value["action"], "dry_run");
        assert_eq!(value["bindings_removed"], 42);
        assert_eq!(value["entity"], "jwt-token");
        assert_eq!(value["namespace"], "global");
    }

    #[test]
    fn prune_ner_response_pruned_all_omits_entity() {
        let value = serialized("pruned", 200, "project-x", None, 15);

        assert_eq!(value["action"], "pruned");
        assert_eq!(value["bindings_removed"], 200);
        // `entity` is skipped entirely rather than serialised as null.
        assert!(
            value.get("entity").is_none(),
            "entity must be omitted when None"
        );
    }

    #[test]
    fn prune_ner_response_aborted_includes_count() {
        let value = serialized("aborted", 10, "global", None, 1);

        assert_eq!(value["action"], "aborted");
        assert_eq!(value["bindings_removed"], 10);
        assert!(value["elapsed_ms"].is_number());
    }

    #[test]
    fn prune_ner_response_zero_bindings() {
        let value = serialized("pruned", 0, "global", Some("nonexistent"), 2);

        assert_eq!(value["bindings_removed"], 0);
    }
}