Skip to main content

sqlite_graphrag/commands/
prune_ner.rs

//! Handler for the `prune-ner` CLI subcommand.
//!
//! Removes NER bindings (rows in `memory_entities`) for a single entity or for
//! all entities in the namespace. Useful for cleaning up low-quality automatic
//! extractions without touching the entities or memories themselves.
6
7use crate::errors::AppError;
8use crate::output::{self, OutputFormat};
9use crate::paths::AppPaths;
10use crate::storage::connection::open_rw;
11use serde::Serialize;
12
13#[derive(clap::Args)]
14#[command(after_long_help = "EXAMPLES:\n  \
15    # Preview bindings that would be removed for a single entity\n  \
16    sqlite-graphrag prune-ner --entity jwt-token --dry-run\n\n  \
17    # Remove all NER bindings for a single entity\n  \
18    sqlite-graphrag prune-ner --entity jwt-token --yes\n\n  \
19    # Remove ALL NER bindings in the current namespace\n  \
20    sqlite-graphrag prune-ner --all --yes\n\n  \
21NOTE:\n  \
22    This command deletes rows from memory_entities (the link table between\n  \
23    memories and extracted entities). The entities and memories themselves\n  \
24    are not deleted. Use cleanup-orphans afterwards to remove entity nodes\n  \
25    that have no remaining links.")]
26pub struct PruneNerArgs {
27    /// Entity name whose bindings should be removed.
28    /// Mutually exclusive with --all.
29    #[arg(long, conflicts_with = "all", value_name = "NAME")]
30    pub entity: Option<String>,
31
32    /// Remove all NER bindings in the namespace. Mutually exclusive with --entity.
33    #[arg(long, conflicts_with = "entity", default_value_t = false)]
34    pub all: bool,
35
36    #[arg(long)]
37    pub namespace: Option<String>,
38
39    /// Preview count without deleting.
40    #[arg(long, default_value_t = false)]
41    pub dry_run: bool,
42
43    /// Skip confirmation for destructive operation.
44    #[arg(long, default_value_t = false)]
45    pub yes: bool,
46
47    #[arg(long, value_enum, default_value = "json")]
48    pub format: OutputFormat,
49
50    #[arg(long, hide = true, help = "No-op; JSON is always emitted on stdout")]
51    pub json: bool,
52
53    #[arg(long, env = "SQLITE_GRAPHRAG_DB_PATH")]
54    pub db: Option<String>,
55}
56
57#[derive(Serialize)]
58struct PruneNerResponse {
59    action: String,
60    bindings_removed: usize,
61    namespace: String,
62    /// Entity name targeted, when `--entity` was used.
63    #[serde(skip_serializing_if = "Option::is_none")]
64    entity: Option<String>,
65    /// Total execution time in milliseconds from handler start to serialisation.
66    elapsed_ms: u64,
67}
68
69pub fn run(args: PruneNerArgs) -> Result<(), AppError> {
70    let inicio = std::time::Instant::now();
71
72    if args.entity.is_none() && !args.all {
73        return Err(AppError::Validation(
74            "either --entity <NAME> or --all must be specified".to_string(),
75        ));
76    }
77
78    let namespace = crate::namespace::resolve_namespace(args.namespace.as_deref())?;
79    let paths = AppPaths::resolve(args.db.as_deref())?;
80
81    crate::storage::connection::ensure_db_ready(&paths)?;
82
83    let mut conn = open_rw(&paths.db)?;
84
85    // Count how many rows would be affected.
86    let count: usize = if let Some(ref entity_name) = args.entity {
87        conn.query_row(
88            "SELECT COUNT(*) FROM memory_entities me
89             JOIN entities e ON e.id = me.entity_id
90             WHERE e.name = ?1 AND e.namespace = ?2",
91            rusqlite::params![entity_name, namespace],
92            |r| r.get::<_, i64>(0).map(|v| v as usize),
93        )?
94    } else {
95        conn.query_row(
96            "SELECT COUNT(*) FROM memory_entities me
97             JOIN entities e ON e.id = me.entity_id
98             WHERE e.namespace = ?1",
99            rusqlite::params![namespace],
100            |r| r.get::<_, i64>(0).map(|v| v as usize),
101        )?
102    };
103
104    if args.dry_run {
105        let response = PruneNerResponse {
106            action: "dry_run".to_string(),
107            bindings_removed: count,
108            namespace: namespace.clone(),
109            entity: args.entity.clone(),
110            elapsed_ms: inicio.elapsed().as_millis() as u64,
111        };
112
113        match args.format {
114            OutputFormat::Json => output::emit_json(&response)?,
115            OutputFormat::Text | OutputFormat::Markdown => {
116                output::emit_text(&format!(
117                    "dry_run: {count} NER bindings would be removed [{namespace}]"
118                ));
119            }
120        }
121
122        return Ok(());
123    }
124
125    if !args.yes {
126        let response = PruneNerResponse {
127            action: "aborted".to_string(),
128            bindings_removed: count,
129            namespace: namespace.clone(),
130            entity: args.entity.clone(),
131            elapsed_ms: inicio.elapsed().as_millis() as u64,
132        };
133
134        match args.format {
135            OutputFormat::Json => output::emit_json(&response)?,
136            OutputFormat::Text | OutputFormat::Markdown => {
137                output::emit_text(&format!(
138                    "aborted: {count} NER bindings would be removed; pass --yes to confirm [{namespace}]"
139                ));
140            }
141        }
142
143        return Ok(());
144    }
145
146    // Destructive path: COUNT + DELETE in same transaction for consistency.
147    let removed: usize = if let Some(ref entity_name) = args.entity {
148        // Normalize to match the normalized stored entity names.
149        let entity_name = crate::parsers::normalize_entity_name(entity_name);
150        let tx = conn.transaction_with_behavior(rusqlite::TransactionBehavior::Immediate)?;
151        let n = tx.execute(
152            "DELETE FROM memory_entities WHERE entity_id IN (
153                 SELECT id FROM entities WHERE name = ?1 AND namespace = ?2
154             )",
155            rusqlite::params![entity_name, namespace],
156        )?;
157        tx.commit()?;
158        n
159    } else {
160        let tx = conn.transaction_with_behavior(rusqlite::TransactionBehavior::Immediate)?;
161        let n = tx.execute(
162            "DELETE FROM memory_entities WHERE entity_id IN (
163                 SELECT id FROM entities WHERE namespace = ?1
164             )",
165            rusqlite::params![namespace],
166        )?;
167        tx.commit()?;
168        n
169    };
170
171    conn.execute_batch("PRAGMA wal_checkpoint(TRUNCATE);")?;
172
173    tracing::info!(
174        removed = removed,
175        namespace = %namespace,
176        entity = ?args.entity,
177        "NER bindings pruned"
178    );
179
180    let response = PruneNerResponse {
181        action: "pruned".to_string(),
182        bindings_removed: removed,
183        namespace: namespace.clone(),
184        entity: args.entity.clone(),
185        elapsed_ms: inicio.elapsed().as_millis() as u64,
186    };
187
188    match args.format {
189        OutputFormat::Json => output::emit_json(&response)?,
190        OutputFormat::Text | OutputFormat::Markdown => {
191            output::emit_text(&format!(
192                "pruned: {removed} NER bindings removed [{namespace}]"
193            ));
194        }
195    }
196
197    Ok(())
198}
199
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `PruneNerResponse` fixture from borrowed parts.
    fn response(
        action: &str,
        bindings_removed: usize,
        namespace: &str,
        entity: Option<&str>,
        elapsed_ms: u64,
    ) -> PruneNerResponse {
        PruneNerResponse {
            action: action.to_owned(),
            bindings_removed,
            namespace: namespace.to_owned(),
            entity: entity.map(str::to_owned),
            elapsed_ms,
        }
    }

    /// Serializes a response into a JSON value, failing the test on error.
    fn to_json(resp: &PruneNerResponse) -> serde_json::Value {
        serde_json::to_value(resp).expect("serialization failed")
    }

    #[test]
    fn prune_ner_response_dry_run_serializes_correctly() {
        let json = to_json(&response("dry_run", 42, "global", Some("jwt-token"), 5));

        assert_eq!(json["action"], "dry_run");
        assert_eq!(json["bindings_removed"], 42);
        assert_eq!(json["entity"], "jwt-token");
        assert_eq!(json["namespace"], "global");
    }

    #[test]
    fn prune_ner_response_pruned_all_omits_entity() {
        let json = to_json(&response("pruned", 200, "project-x", None, 15));

        assert_eq!(json["action"], "pruned");
        assert_eq!(json["bindings_removed"], 200);
        assert!(
            json.get("entity").is_none(),
            "entity must be omitted when None"
        );
    }

    #[test]
    fn prune_ner_response_aborted_includes_count() {
        let json = to_json(&response("aborted", 10, "global", None, 1));

        assert_eq!(json["action"], "aborted");
        assert_eq!(json["bindings_removed"], 10);
        assert!(json["elapsed_ms"].is_number());
    }

    #[test]
    fn prune_ner_response_zero_bindings() {
        let json = to_json(&response("pruned", 0, "global", Some("nonexistent"), 2));

        assert_eq!(json["bindings_removed"], 0);
    }
}
265}