Skip to main content

sqlite_graphrag/commands/
prune_ner.rs

1//! Handler for the `prune-ner` CLI subcommand.
2//!
3//! Removes NER bindings (rows in `memory_entities`) for a single entity or for
4//! all entities in the namespace. Useful for cleaning up low-quality automatic
5//! extractions without touching the entities or memories themselves.
6
7use crate::errors::AppError;
8use crate::output::{self, OutputFormat};
9use crate::paths::AppPaths;
10use crate::storage::connection::open_rw;
11use serde::Serialize;
12
13#[derive(clap::Args)]
14#[command(after_long_help = "EXAMPLES:\n  \
15    # Preview bindings that would be removed for a single entity\n  \
16    sqlite-graphrag prune-ner --entity jwt-token --dry-run\n\n  \
17    # Remove all NER bindings for a single entity\n  \
18    sqlite-graphrag prune-ner --entity jwt-token --yes\n\n  \
19    # Remove ALL NER bindings in the current namespace\n  \
20    sqlite-graphrag prune-ner --all --yes\n\n  \
21NOTE:\n  \
22    This command deletes rows from memory_entities (the link table between\n  \
23    memories and extracted entities). The entities and memories themselves\n  \
24    are not deleted. Use cleanup-orphans afterwards to remove entity nodes\n  \
25    that have no remaining links.")]
26pub struct PruneNerArgs {
27    /// Entity name whose bindings should be removed.
28    /// Mutually exclusive with --all.
29    #[arg(long, conflicts_with = "all", value_name = "NAME")]
30    pub entity: Option<String>,
31
32    /// Remove all NER bindings in the namespace. Mutually exclusive with --entity.
33    #[arg(long, conflicts_with = "entity", default_value_t = false)]
34    pub all: bool,
35
36    #[arg(long)]
37    pub namespace: Option<String>,
38
39    /// Preview count without deleting.
40    #[arg(long, default_value_t = false)]
41    pub dry_run: bool,
42
43    /// Skip confirmation for destructive operation.
44    #[arg(long, default_value_t = false)]
45    pub yes: bool,
46
47    #[arg(long, value_enum, default_value = "json")]
48    pub format: OutputFormat,
49
50    #[arg(long, hide = true, help = "No-op; JSON is always emitted on stdout")]
51    pub json: bool,
52
53    #[arg(long, env = "SQLITE_GRAPHRAG_DB_PATH")]
54    pub db: Option<String>,
55}
56
/// Result payload reported by the `prune-ner` handler when the output format
/// is JSON. Keys are serialised in the order the fields are declared here.
#[derive(Serialize)]
struct PruneNerResponse {
    /// Outcome label set by the handler: `"dry_run"`, `"aborted"` or `"pruned"`.
    action: String,
    /// For `"pruned"`, the number of rows the DELETE affected; for `"dry_run"`
    /// and `"aborted"`, the number of bindings that matched and would be removed.
    bindings_removed: usize,
    /// Namespace the operation was scoped to, after resolution.
    namespace: String,
    /// Entity name targeted, when `--entity` was used.
    #[serde(skip_serializing_if = "Option::is_none")]
    entity: Option<String>,
    /// Total execution time in milliseconds from handler start to serialisation.
    elapsed_ms: u64,
}
68
69pub fn run(args: PruneNerArgs) -> Result<(), AppError> {
70    let inicio = std::time::Instant::now();
71
72    if args.entity.is_none() && !args.all {
73        return Err(AppError::Validation(
74            "either --entity <NAME> or --all must be specified".to_string(),
75        ));
76    }
77
78    let namespace = crate::namespace::resolve_namespace(args.namespace.as_deref())?;
79    let paths = AppPaths::resolve(args.db.as_deref())?;
80
81    crate::storage::connection::ensure_db_ready(&paths)?;
82
83    let mut conn = open_rw(&paths.db)?;
84
85    // Count how many rows would be affected.
86    let count: usize = if let Some(ref entity_name) = args.entity {
87        conn.query_row(
88            "SELECT COUNT(*) FROM memory_entities me
89             JOIN entities e ON e.id = me.entity_id
90             WHERE e.name = ?1 AND e.namespace = ?2",
91            rusqlite::params![entity_name, namespace],
92            |r| r.get::<_, i64>(0).map(|v| v as usize),
93        )?
94    } else {
95        conn.query_row(
96            "SELECT COUNT(*) FROM memory_entities me
97             JOIN entities e ON e.id = me.entity_id
98             WHERE e.namespace = ?1",
99            rusqlite::params![namespace],
100            |r| r.get::<_, i64>(0).map(|v| v as usize),
101        )?
102    };
103
104    if args.dry_run {
105        let response = PruneNerResponse {
106            action: "dry_run".to_string(),
107            bindings_removed: count,
108            namespace: namespace.clone(),
109            entity: args.entity.clone(),
110            elapsed_ms: inicio.elapsed().as_millis() as u64,
111        };
112
113        match args.format {
114            OutputFormat::Json => output::emit_json(&response)?,
115            OutputFormat::Text | OutputFormat::Markdown => {
116                output::emit_text(&format!(
117                    "dry_run: {count} NER bindings would be removed [{namespace}]"
118                ));
119            }
120        }
121
122        return Ok(());
123    }
124
125    if !args.yes {
126        let response = PruneNerResponse {
127            action: "aborted".to_string(),
128            bindings_removed: count,
129            namespace: namespace.clone(),
130            entity: args.entity.clone(),
131            elapsed_ms: inicio.elapsed().as_millis() as u64,
132        };
133
134        match args.format {
135            OutputFormat::Json => output::emit_json(&response)?,
136            OutputFormat::Text | OutputFormat::Markdown => {
137                output::emit_text(&format!(
138                    "aborted: {count} NER bindings would be removed; pass --yes to confirm [{namespace}]"
139                ));
140            }
141        }
142
143        return Ok(());
144    }
145
146    // Destructive path: COUNT + DELETE in same transaction for consistency.
147    let removed: usize = if let Some(ref entity_name) = args.entity {
148        let tx = conn.transaction_with_behavior(rusqlite::TransactionBehavior::Immediate)?;
149        let n = tx.execute(
150            "DELETE FROM memory_entities WHERE entity_id IN (
151                 SELECT id FROM entities WHERE name = ?1 AND namespace = ?2
152             )",
153            rusqlite::params![entity_name, namespace],
154        )?;
155        tx.commit()?;
156        n
157    } else {
158        let tx = conn.transaction_with_behavior(rusqlite::TransactionBehavior::Immediate)?;
159        let n = tx.execute(
160            "DELETE FROM memory_entities WHERE entity_id IN (
161                 SELECT id FROM entities WHERE namespace = ?1
162             )",
163            rusqlite::params![namespace],
164        )?;
165        tx.commit()?;
166        n
167    };
168
169    conn.execute_batch("PRAGMA wal_checkpoint(TRUNCATE);")?;
170
171    tracing::info!(
172        removed = removed,
173        namespace = %namespace,
174        entity = ?args.entity,
175        "NER bindings pruned"
176    );
177
178    let response = PruneNerResponse {
179        action: "pruned".to_string(),
180        bindings_removed: removed,
181        namespace: namespace.clone(),
182        entity: args.entity.clone(),
183        elapsed_ms: inicio.elapsed().as_millis() as u64,
184    };
185
186    match args.format {
187        OutputFormat::Json => output::emit_json(&response)?,
188        OutputFormat::Text | OutputFormat::Markdown => {
189            output::emit_text(&format!(
190                "pruned: {removed} NER bindings removed [{namespace}]"
191            ));
192        }
193    }
194
195    Ok(())
196}
197
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `PruneNerResponse` from its parts and serialises it to a JSON
    /// value so each test can focus on the keys it cares about.
    fn serialized(
        action: &str,
        bindings_removed: usize,
        namespace: &str,
        entity: Option<&str>,
        elapsed_ms: u64,
    ) -> serde_json::Value {
        let response = PruneNerResponse {
            action: action.to_owned(),
            bindings_removed,
            namespace: namespace.to_owned(),
            entity: entity.map(str::to_owned),
            elapsed_ms,
        };
        serde_json::to_value(&response).expect("serialization failed")
    }

    // A dry-run response for a single entity exposes all identifying fields.
    #[test]
    fn prune_ner_response_dry_run_serializes_correctly() {
        let value = serialized("dry_run", 42, "global", Some("jwt-token"), 5);

        assert_eq!(value["action"], "dry_run");
        assert_eq!(value["bindings_removed"], 42);
        assert_eq!(value["entity"], "jwt-token");
        assert_eq!(value["namespace"], "global");
    }

    // With no entity set, the `entity` key must be absent rather than null.
    #[test]
    fn prune_ner_response_pruned_all_omits_entity() {
        let value = serialized("pruned", 200, "project-x", None, 15);

        assert_eq!(value["action"], "pruned");
        assert_eq!(value["bindings_removed"], 200);
        assert!(
            value.get("entity").is_none(),
            "entity must be omitted when None"
        );
    }

    // An aborted response still carries the count and a numeric elapsed time.
    #[test]
    fn prune_ner_response_aborted_includes_count() {
        let value = serialized("aborted", 10, "global", None, 1);

        assert_eq!(value["action"], "aborted");
        assert_eq!(value["bindings_removed"], 10);
        assert!(value["elapsed_ms"].is_number());
    }

    // Zero removed bindings serialises as the number 0.
    #[test]
    fn prune_ner_response_zero_bindings() {
        let value = serialized("pruned", 0, "global", Some("nonexistent"), 2);

        assert_eq!(value["bindings_removed"], 0);
    }
}