Skip to main content

sqlite_graphrag/commands/
normalize_entities.rs

1//! Handler for the `normalize-entities` CLI subcommand (GAP-15).
2//!
3//! Scans all existing entity names in the namespace and normalizes them to
4//! kebab-case ASCII using [`crate::parsers::normalize_entity_name`].
5//! When a normalized name already exists (collision), the source entity is
6//! merged into the target using the same logic as `merge-entities`:
7//! relationships are retargeted via `UPDATE OR IGNORE` + `DELETE`, then
8//! the source row is removed. Otherwise the entity name is updated in place.
9
10use crate::errors::AppError;
11use crate::output::{self, OutputFormat};
12use crate::parsers::normalize_entity_name;
13use crate::paths::AppPaths;
14use crate::storage::connection::open_rw;
15use rusqlite::params;
16use serde::Serialize;
17
18#[derive(clap::Args)]
19#[command(after_long_help = "EXAMPLES:\n  \
20    # Preview which entities would be renamed or merged\n  \
21    sqlite-graphrag normalize-entities --dry-run\n\n  \
22    # Apply normalization to all entity names\n  \
23    sqlite-graphrag normalize-entities --yes\n\n  \
24    # Scope to a specific namespace\n  \
25    sqlite-graphrag normalize-entities --yes --namespace my-project\n\n\
26NOTE:\n  \
27    When a normalized name already exists, the source entity is merged into\n  \
28    the existing target via relationship retargeting (UPDATE OR IGNORE + DELETE).\n  \
29    Run `cleanup-orphans` afterwards to remove any newly orphaned entities.")]
30pub struct NormalizeEntitiesArgs {
31    /// Preview changes without persisting them.
32    #[arg(long, conflicts_with = "yes")]
33    pub dry_run: bool,
34    /// Apply normalization without interactive confirmation.
35    #[arg(long, conflicts_with = "dry_run")]
36    pub yes: bool,
37    #[arg(long)]
38    pub namespace: Option<String>,
39    #[arg(long, value_enum, default_value = "json")]
40    pub format: OutputFormat,
41    #[arg(long, hide = true, help = "No-op; JSON is always emitted on stdout")]
42    pub json: bool,
43    #[arg(long, env = "SQLITE_GRAPHRAG_DB_PATH")]
44    pub db: Option<String>,
45}
46
47#[derive(Serialize)]
48struct NormalizeEntitiesResponse {
49    /// "normalized" when changes were applied, "dry_run" when only previewed.
50    action: String,
51    /// Number of entities whose names were updated in place.
52    normalized_count: usize,
53    /// Number of entities that collided with an existing normalized name and
54    /// were merged into the target.
55    merged_count: usize,
56    namespace: String,
57    /// Total execution time in milliseconds from handler start to serialisation.
58    elapsed_ms: u64,
59}
60
61pub fn run(args: NormalizeEntitiesArgs) -> Result<(), AppError> {
62    let inicio = std::time::Instant::now();
63
64    if !args.dry_run && !args.yes {
65        return Err(AppError::Validation(
66            "pass --dry-run to preview or --yes to apply changes".to_string(),
67        ));
68    }
69
70    let namespace = crate::namespace::resolve_namespace(args.namespace.as_deref())?;
71    let paths = AppPaths::resolve(args.db.as_deref())?;
72
73    crate::storage::connection::ensure_db_ready(&paths)?;
74
75    let mut conn = open_rw(&paths.db)?;
76
77    // Collect all entity (id, name) pairs for the namespace.
78    let entities: Vec<(i64, String)> = {
79        let mut stmt =
80            conn.prepare_cached("SELECT id, name FROM entities WHERE namespace = ?1 ORDER BY id")?;
81        let rows = stmt.query_map(params![namespace], |r| {
82            Ok((r.get::<_, i64>(0)?, r.get::<_, String>(1)?))
83        })?;
84        rows.collect::<Result<Vec<_>, _>>()?
85    };
86
87    // Compute which names need changing.
88    let to_change: Vec<(i64, String, String)> = entities
89        .iter()
90        .filter_map(|(id, name)| {
91            let normalized = normalize_entity_name(name);
92            if normalized != *name {
93                Some((*id, name.clone(), normalized))
94            } else {
95                None
96            }
97        })
98        .collect();
99
100    // G10: classify changes into renames (no collision) and merges (collision).
101    // A collision occurs when two distinct names normalize to the same target,
102    // or when the normalized target already exists in the DB as an already-normalized entity.
103    let already_normalized: std::collections::HashSet<String> = entities
104        .iter()
105        .filter(|(_, name)| normalize_entity_name(name) == *name)
106        .map(|(_, name)| name.clone())
107        .collect();
108
109    let mut target_groups: std::collections::HashMap<String, usize> =
110        std::collections::HashMap::with_capacity(to_change.len());
111    for (_, _, normalized) in &to_change {
112        *target_groups.entry(normalized.clone()).or_insert(0) += 1;
113    }
114
115    let mut merge_count_preview: usize = 0;
116    let mut rename_count_preview: usize = 0;
117    for (target, count) in &target_groups {
118        if *count > 1 || already_normalized.contains(target) {
119            // All sources in this group will merge into the existing or first entity
120            let extra = if already_normalized.contains(target) {
121                *count // all merge into existing
122            } else {
123                count - 1 // first one renames, rest merge
124            };
125            merge_count_preview += extra;
126            rename_count_preview += count - extra;
127        } else {
128            rename_count_preview += 1;
129        }
130    }
131
132    if args.dry_run {
133        let response = NormalizeEntitiesResponse {
134            action: "dry_run".to_string(),
135            normalized_count: rename_count_preview,
136            merged_count: merge_count_preview,
137            namespace,
138            elapsed_ms: inicio.elapsed().as_millis() as u64,
139        };
140        match args.format {
141            OutputFormat::Json => output::emit_json(&response)?,
142            OutputFormat::Text | OutputFormat::Markdown => {
143                output::emit_text(&format!(
144                    "dry_run: {} entity names would be normalized",
145                    response.normalized_count
146                ));
147            }
148        }
149        return Ok(());
150    }
151
152    // Apply changes inside a transaction.
153    let tx = conn.transaction_with_behavior(rusqlite::TransactionBehavior::Immediate)?;
154
155    let mut normalized_count: usize = 0;
156    let mut merged_count: usize = 0;
157
158    for (src_id, _original_name, normalized) in &to_change {
159        // Check whether a row with the normalized name already exists.
160        let existing_id: Option<i64> = {
161            let mut stmt =
162                tx.prepare_cached("SELECT id FROM entities WHERE namespace = ?1 AND name = ?2")?;
163            match stmt.query_row(params![namespace, normalized], |r| r.get::<_, i64>(0)) {
164                Ok(id) => Some(id),
165                Err(rusqlite::Error::QueryReturnedNoRows) => None,
166                Err(e) => return Err(AppError::Database(e)),
167            }
168        };
169
170        match existing_id {
171            Some(target_id) if target_id != *src_id => {
172                // Collision: merge source into target using UPDATE OR IGNORE + DELETE.
173                // Step 1a: redirect source_id.
174                tx.execute(
175                    "UPDATE OR IGNORE relationships SET source_id = ?1 WHERE source_id = ?2",
176                    params![target_id, src_id],
177                )?;
178                tx.execute(
179                    "DELETE FROM relationships WHERE source_id = ?1",
180                    params![src_id],
181                )?;
182                // Step 1b: redirect target_id.
183                tx.execute(
184                    "UPDATE OR IGNORE relationships SET target_id = ?1 WHERE target_id = ?2",
185                    params![target_id, src_id],
186                )?;
187                tx.execute(
188                    "DELETE FROM relationships WHERE target_id = ?1",
189                    params![src_id],
190                )?;
191                // Remove self-loops.
192                tx.execute("DELETE FROM relationships WHERE source_id = target_id", [])?;
193                // Retarget memory_entities bindings.
194                tx.execute(
195                    "UPDATE OR IGNORE memory_entities SET entity_id = ?1 WHERE entity_id = ?2",
196                    params![target_id, src_id],
197                )?;
198                tx.execute(
199                    "DELETE FROM memory_entities WHERE entity_id = ?1",
200                    params![src_id],
201                )?;
202                // Remove the source entity row.
203                tx.execute("DELETE FROM entities WHERE id = ?1", params![src_id])?;
204                // Recalculate degree for the surviving target.
205                tx.execute(
206                    "UPDATE entities
207                     SET degree = (SELECT COUNT(*) FROM relationships
208                                   WHERE source_id = entities.id OR target_id = entities.id)
209                     WHERE id = ?1",
210                    params![target_id],
211                )?;
212                tracing::info!(target: "normalize_entities",
213                    src_id = src_id,
214                    target_id = target_id,
215                    normalized = normalized,
216                    "entity merged into existing normalized target"
217                );
218                merged_count += 1;
219            }
220            _ => {
221                // No collision: simple rename.
222                tx.execute(
223                    "UPDATE entities SET name = ?1, updated_at = unixepoch() WHERE id = ?2",
224                    params![normalized, src_id],
225                )?;
226                tracing::info!(target: "normalize_entities",
227                    entity_id = src_id,
228                    normalized = normalized,
229                    "entity name normalized"
230                );
231                normalized_count += 1;
232            }
233        }
234    }
235
236    tx.commit()?;
237    conn.execute_batch("PRAGMA wal_checkpoint(TRUNCATE);")?;
238
239    let response = NormalizeEntitiesResponse {
240        action: "normalized".to_string(),
241        normalized_count,
242        merged_count,
243        namespace,
244        elapsed_ms: inicio.elapsed().as_millis() as u64,
245    };
246
247    match args.format {
248        OutputFormat::Json => output::emit_json(&response)?,
249        OutputFormat::Text | OutputFormat::Markdown => {
250            output::emit_text(&format!(
251                "normalized: {} renamed, {} merged",
252                response.normalized_count, response.merged_count
253            ));
254        }
255    }
256
257    Ok(())
258}
259
#[cfg(test)]
mod tests {
    use super::*;
    use crate::storage::connection::register_vec_extension;
    use rusqlite::Connection;
    use tempfile::TempDir;

    type TestResult = Result<(), Box<dyn std::error::Error>>;

    /// Creates a database file in a fresh temporary directory and runs the
    /// project migrations on it. The `TempDir` is returned so the directory
    /// outlives the connection.
    fn setup_db() -> Result<(TempDir, Connection), Box<dyn std::error::Error>> {
        register_vec_extension();
        let dir = TempDir::new()?;
        let mut conn = Connection::open(dir.path().join("test.db"))?;
        crate::migrations::runner().run(&mut conn)?;
        Ok((dir, conn))
    }

    /// Seeds an entity row with raw SQL and returns its id.
    ///
    /// Writing the row directly (instead of going through `upsert_entity`)
    /// keeps the name exactly as given, so tests can store un-normalized names.
    fn insert_entity(conn: &Connection, name: &str) -> Result<i64, Box<dyn std::error::Error>> {
        conn.execute(
            "INSERT INTO entities (namespace, name, type, description) VALUES ('global', ?1, 'concept', NULL)",
            params![name],
        )?;
        let id = conn.query_row(
            "SELECT id FROM entities WHERE namespace = 'global' AND name = ?1",
            params![name],
            |row| row.get::<_, i64>(0),
        )?;
        Ok(id)
    }

    // NOTE(review): this test never calls `run` (or any dry-run code path); it
    // only checks that the seeded row is still present when queried twice.
    #[test]
    fn dry_run_returns_count_without_changes() -> TestResult {
        let (_tmp, conn) = setup_db()?;
        insert_entity(&conn, "Hello World")?;
        insert_entity(&conn, "already-normalized")?;

        // Number of rows still carrying the raw, un-normalized name.
        let hello_world_rows = || -> rusqlite::Result<i64> {
            conn.query_row(
                "SELECT COUNT(*) FROM entities WHERE name = 'Hello World' AND namespace = 'global'",
                [],
                |row| row.get(0),
            )
        };

        let before = hello_world_rows()?;
        assert_eq!(before, 1, "entity must exist before dry run");

        let after = hello_world_rows()?;
        assert_eq!(after, 1, "dry run must not rename entities");
        Ok(())
    }

    // NOTE(review): replays the rename SQL by hand rather than invoking `run`.
    #[test]
    fn renames_unnormalized_entity_in_place() -> TestResult {
        let (_tmp, conn) = setup_db()?;
        let src_id = insert_entity(&conn, "Hello World")?;

        let normalized = normalize_entity_name("Hello World");

        // The normalized name must still be free before renaming.
        let lookup = conn.query_row(
            "SELECT id FROM entities WHERE namespace = 'global' AND name = ?1",
            params![normalized],
            |row| row.get::<_, i64>(0),
        );
        let existing: Option<i64> = match lookup {
            Ok(id) => Some(id),
            Err(rusqlite::Error::QueryReturnedNoRows) => None,
            Err(other) => return Err(other.into()),
        };
        assert!(existing.is_none(), "no collision expected");

        conn.execute(
            "UPDATE entities SET name = ?1 WHERE id = ?2",
            params![normalized, src_id],
        )?;

        let stored: String = conn.query_row(
            "SELECT name FROM entities WHERE id = ?1",
            params![src_id],
            |row| row.get(0),
        )?;
        assert_eq!(stored, "hello-world");
        Ok(())
    }

    // NOTE(review): replays part of the merge SQL by hand rather than invoking `run`.
    #[test]
    fn merges_into_existing_on_collision() -> TestResult {
        let (_tmp, conn) = setup_db()?;
        // The already-normalized entity is created first, the colliding raw one second.
        let target_id = insert_entity(&conn, "hello-world")?;
        let src_id = insert_entity(&conn, "Hello World")?;

        // Give the source entity a relationship (both ends point at the source).
        conn.execute(
            "INSERT INTO relationships (namespace, source_id, target_id, relation, weight)
             VALUES ('global', ?1, ?1, 'related', 0.5)",
            params![src_id],
        )?;

        // Move outgoing relationships to the target, drop what could not be
        // moved, then delete the source entity.
        conn.execute(
            "UPDATE OR IGNORE relationships SET source_id = ?1 WHERE source_id = ?2",
            params![target_id, src_id],
        )?;
        conn.execute(
            "DELETE FROM relationships WHERE source_id = ?1",
            params![src_id],
        )?;
        conn.execute("DELETE FROM entities WHERE id = ?1", params![src_id])?;

        let remaining_sources: i64 = conn.query_row(
            "SELECT COUNT(*) FROM entities WHERE id = ?1",
            params![src_id],
            |row| row.get(0),
        )?;
        assert_eq!(remaining_sources, 0, "source entity must be deleted after merge");

        let surviving_name: String = conn.query_row(
            "SELECT name FROM entities WHERE id = ?1",
            params![target_id],
            |row| row.get(0),
        )?;
        assert_eq!(surviving_name, "hello-world");
        Ok(())
    }

    /// Every response field must appear in the JSON under its own name.
    #[test]
    fn normalize_entities_response_serializes_correctly() {
        let response = NormalizeEntitiesResponse {
            action: "normalized".to_string(),
            normalized_count: 3,
            merged_count: 1,
            namespace: "global".to_string(),
            elapsed_ms: 42,
        };
        let value = serde_json::to_value(&response).expect("serialization");
        assert_eq!(value["action"], "normalized");
        assert_eq!(value["normalized_count"], 3);
        assert_eq!(value["merged_count"], 1);
        assert_eq!(value["namespace"], "global");
        assert!(value["elapsed_ms"].as_u64().is_some());
    }

    /// A preview response must serialize its `action` as `dry_run`.
    #[test]
    fn dry_run_response_has_correct_action() {
        let response = NormalizeEntitiesResponse {
            action: "dry_run".to_string(),
            normalized_count: 5,
            merged_count: 0,
            namespace: "test".to_string(),
            elapsed_ms: 1,
        };
        let value = serde_json::to_value(&response).expect("serialization");
        assert_eq!(value["action"], "dry_run");
    }
}
422            elapsed_ms: 1,
423        };
424        let json = serde_json::to_value(&resp).expect("serialization");
425        assert_eq!(json["action"], "dry_run");
426    }
427}