1use crate::errors::AppError;
11use crate::output::{self, OutputFormat};
12use crate::parsers::normalize_entity_name;
13use crate::paths::AppPaths;
14use crate::storage::connection::open_rw;
15use rusqlite::params;
16use serde::Serialize;
17
// Command-line arguments for the `normalize-entities` subcommand.
//
// Plain `//` comments are used on purpose throughout this struct: clap's
// derive macro turns `///` doc comments into help text, so adding them would
// change what the CLI prints.
#[derive(clap::Args)]
#[command(after_long_help = "EXAMPLES:\n \
 # Preview which entities would be renamed or merged\n \
 sqlite-graphrag normalize-entities --dry-run\n\n \
 # Apply normalization to all entity names\n \
 sqlite-graphrag normalize-entities --yes\n\n \
 # Scope to a specific namespace\n \
 sqlite-graphrag normalize-entities --yes --namespace my-project\n\n\
NOTE:\n \
 When a normalized name already exists, the source entity is merged into\n \
 the existing target via relationship retargeting (UPDATE OR IGNORE + DELETE).\n \
 Run `cleanup-orphans` afterwards to remove any newly orphaned entities.")]
pub struct NormalizeEntitiesArgs {
    // Preview mode: `run` emits the predicted rename/merge counts and returns
    // before opening a write transaction. Mutually exclusive with `--yes`.
    #[arg(long, conflicts_with = "yes")]
    pub dry_run: bool,
    // Confirmation flag required to actually apply changes. `run` rejects the
    // invocation with a validation error when neither this nor `--dry-run` is set.
    #[arg(long, conflicts_with = "dry_run")]
    pub yes: bool,
    // Optional namespace override; forwarded to `resolve_namespace`, which
    // decides the effective namespace when this is `None`.
    #[arg(long)]
    pub namespace: Option<String>,
    // Output format for the result. `run` emits the same one-line summary for
    // both the `Text` and `Markdown` variants.
    #[arg(long, value_enum, default_value = "json")]
    pub format: OutputFormat,
    // Hidden flag accepted but never read by `run` (see its help string).
    #[arg(long, hide = true, help = "No-op; JSON is always emitted on stdout")]
    pub json: bool,
    // Optional database path, also readable from `SQLITE_GRAPHRAG_DB_PATH`;
    // forwarded to `AppPaths::resolve`.
    #[arg(long, env = "SQLITE_GRAPHRAG_DB_PATH")]
    pub db: Option<String>,
}
46
/// Result payload emitted by `run`, serialized as JSON when the JSON output
/// format is selected.
#[derive(Serialize)]
struct NormalizeEntitiesResponse {
    /// `"dry_run"` for a preview, `"normalized"` once changes were applied.
    action: String,
    /// Number of entities renamed in place (predicted count in a dry run).
    normalized_count: usize,
    /// Number of entities merged into an entity that already carries the
    /// normalized name (predicted count in a dry run).
    merged_count: usize,
    /// Namespace the command operated on, as resolved by `run`.
    namespace: String,
    /// Wall-clock time in milliseconds from the start of `run` until the
    /// response was built.
    elapsed_ms: u64,
}
60
61pub fn run(args: NormalizeEntitiesArgs) -> Result<(), AppError> {
62 let inicio = std::time::Instant::now();
63
64 if !args.dry_run && !args.yes {
65 return Err(AppError::Validation(
66 "pass --dry-run to preview or --yes to apply changes".to_string(),
67 ));
68 }
69
70 let namespace = crate::namespace::resolve_namespace(args.namespace.as_deref())?;
71 let paths = AppPaths::resolve(args.db.as_deref())?;
72
73 crate::storage::connection::ensure_db_ready(&paths)?;
74
75 let mut conn = open_rw(&paths.db)?;
76
77 let entities: Vec<(i64, String)> = {
79 let mut stmt =
80 conn.prepare_cached("SELECT id, name FROM entities WHERE namespace = ?1 ORDER BY id")?;
81 let rows = stmt.query_map(params![namespace], |r| {
82 Ok((r.get::<_, i64>(0)?, r.get::<_, String>(1)?))
83 })?;
84 rows.collect::<Result<Vec<_>, _>>()?
85 };
86
87 let to_change: Vec<(i64, String, String)> = entities
89 .iter()
90 .filter_map(|(id, name)| {
91 let normalized = normalize_entity_name(name);
92 if normalized != *name {
93 Some((*id, name.clone(), normalized))
94 } else {
95 None
96 }
97 })
98 .collect();
99
100 let already_normalized: std::collections::HashSet<String> = entities
104 .iter()
105 .filter(|(_, name)| normalize_entity_name(name) == *name)
106 .map(|(_, name)| name.clone())
107 .collect();
108
109 let mut target_groups: std::collections::HashMap<String, usize> =
110 std::collections::HashMap::with_capacity(to_change.len());
111 for (_, _, normalized) in &to_change {
112 *target_groups.entry(normalized.clone()).or_insert(0) += 1;
113 }
114
115 let mut merge_count_preview: usize = 0;
116 let mut rename_count_preview: usize = 0;
117 for (target, count) in &target_groups {
118 if *count > 1 || already_normalized.contains(target) {
119 let extra = if already_normalized.contains(target) {
121 *count } else {
123 count - 1 };
125 merge_count_preview += extra;
126 rename_count_preview += count - extra;
127 } else {
128 rename_count_preview += 1;
129 }
130 }
131
132 if args.dry_run {
133 let response = NormalizeEntitiesResponse {
134 action: "dry_run".to_string(),
135 normalized_count: rename_count_preview,
136 merged_count: merge_count_preview,
137 namespace,
138 elapsed_ms: inicio.elapsed().as_millis() as u64,
139 };
140 match args.format {
141 OutputFormat::Json => output::emit_json(&response)?,
142 OutputFormat::Text | OutputFormat::Markdown => {
143 output::emit_text(&format!(
144 "dry_run: {} entity names would be normalized",
145 response.normalized_count
146 ));
147 }
148 }
149 return Ok(());
150 }
151
152 let tx = conn.transaction_with_behavior(rusqlite::TransactionBehavior::Immediate)?;
154
155 let mut normalized_count: usize = 0;
156 let mut merged_count: usize = 0;
157
158 for (src_id, _original_name, normalized) in &to_change {
159 let existing_id: Option<i64> = {
161 let mut stmt =
162 tx.prepare_cached("SELECT id FROM entities WHERE namespace = ?1 AND name = ?2")?;
163 match stmt.query_row(params![namespace, normalized], |r| r.get::<_, i64>(0)) {
164 Ok(id) => Some(id),
165 Err(rusqlite::Error::QueryReturnedNoRows) => None,
166 Err(e) => return Err(AppError::Database(e)),
167 }
168 };
169
170 match existing_id {
171 Some(target_id) if target_id != *src_id => {
172 tx.execute(
175 "UPDATE OR IGNORE relationships SET source_id = ?1 WHERE source_id = ?2",
176 params![target_id, src_id],
177 )?;
178 tx.execute(
179 "DELETE FROM relationships WHERE source_id = ?1",
180 params![src_id],
181 )?;
182 tx.execute(
184 "UPDATE OR IGNORE relationships SET target_id = ?1 WHERE target_id = ?2",
185 params![target_id, src_id],
186 )?;
187 tx.execute(
188 "DELETE FROM relationships WHERE target_id = ?1",
189 params![src_id],
190 )?;
191 tx.execute("DELETE FROM relationships WHERE source_id = target_id", [])?;
193 tx.execute(
195 "UPDATE OR IGNORE memory_entities SET entity_id = ?1 WHERE entity_id = ?2",
196 params![target_id, src_id],
197 )?;
198 tx.execute(
199 "DELETE FROM memory_entities WHERE entity_id = ?1",
200 params![src_id],
201 )?;
202 tx.execute("DELETE FROM entities WHERE id = ?1", params![src_id])?;
204 tx.execute(
206 "UPDATE entities
207 SET degree = (SELECT COUNT(*) FROM relationships
208 WHERE source_id = entities.id OR target_id = entities.id)
209 WHERE id = ?1",
210 params![target_id],
211 )?;
212 tracing::info!(target: "normalize_entities",
213 src_id = src_id,
214 target_id = target_id,
215 normalized = normalized,
216 "entity merged into existing normalized target"
217 );
218 merged_count += 1;
219 }
220 _ => {
221 tx.execute(
223 "UPDATE entities SET name = ?1, updated_at = unixepoch() WHERE id = ?2",
224 params![normalized, src_id],
225 )?;
226 tracing::info!(target: "normalize_entities",
227 entity_id = src_id,
228 normalized = normalized,
229 "entity name normalized"
230 );
231 normalized_count += 1;
232 }
233 }
234 }
235
236 tx.commit()?;
237 conn.execute_batch("PRAGMA wal_checkpoint(TRUNCATE);")?;
238
239 let response = NormalizeEntitiesResponse {
240 action: "normalized".to_string(),
241 normalized_count,
242 merged_count,
243 namespace,
244 elapsed_ms: inicio.elapsed().as_millis() as u64,
245 };
246
247 match args.format {
248 OutputFormat::Json => output::emit_json(&response)?,
249 OutputFormat::Text | OutputFormat::Markdown => {
250 output::emit_text(&format!(
251 "normalized: {} renamed, {} merged",
252 response.normalized_count, response.merged_count
253 ));
254 }
255 }
256
257 Ok(())
258}
259
#[cfg(test)]
mod tests {
    use super::*;
    use crate::storage::connection::register_vec_extension;
    use rusqlite::Connection;
    use tempfile::TempDir;

    type TestResult = Result<(), Box<dyn std::error::Error>>;

    /// Opens a database file inside a fresh temporary directory and runs the
    /// crate's migrations on it. The `TempDir` is returned so the caller keeps
    /// the directory alive for the duration of the test.
    fn setup_db() -> Result<(TempDir, Connection), Box<dyn std::error::Error>> {
        register_vec_extension();
        let dir = TempDir::new()?;
        let file = dir.path().join("test.db");
        let mut db = Connection::open(&file)?;
        crate::migrations::runner().run(&mut db)?;
        Ok((dir, db))
    }

    /// Inserts a `concept` entity named `name` into the `global` namespace and
    /// returns its id, read back by namespace and name.
    fn insert_entity(conn: &Connection, name: &str) -> Result<i64, Box<dyn std::error::Error>> {
        conn.execute(
            "INSERT INTO entities (namespace, name, type, description) VALUES ('global', ?1, 'concept', NULL)",
            params![name],
        )?;
        let entity_id = conn.query_row(
            "SELECT id FROM entities WHERE namespace = 'global' AND name = ?1",
            params![name],
            |row| row.get::<_, i64>(0),
        )?;
        Ok(entity_id)
    }

    // NOTE(review): this test never invokes `run`; it only checks that the
    // inserted row is visible on two consecutive reads.
    #[test]
    fn dry_run_returns_count_without_changes() -> TestResult {
        let (_dir, db) = setup_db()?;
        insert_entity(&db, "Hello World")?;
        insert_entity(&db, "already-normalized")?;

        const COUNT_SQL: &str =
            "SELECT COUNT(*) FROM entities WHERE name = 'Hello World' AND namespace = 'global'";
        let hello_world_rows = || db.query_row(COUNT_SQL, [], |row| row.get::<_, i64>(0));

        let before = hello_world_rows()?;
        assert_eq!(before, 1, "entity must exist before dry run");

        let after = hello_world_rows()?;
        assert_eq!(after, 1, "dry run must not rename entities");
        Ok(())
    }

    // Replays the rename branch by hand (collision lookup, then UPDATE)
    // instead of calling `run`.
    #[test]
    fn renames_unnormalized_entity_in_place() -> TestResult {
        let (_dir, db) = setup_db()?;
        let entity_id = insert_entity(&db, "Hello World")?;

        let wanted = normalize_entity_name("Hello World");
        let lookup = db.query_row(
            "SELECT id FROM entities WHERE namespace = 'global' AND name = ?1",
            params![wanted],
            |row| row.get::<_, i64>(0),
        );
        let collision: Option<i64> = match lookup {
            Ok(id) => Some(id),
            Err(rusqlite::Error::QueryReturnedNoRows) => None,
            Err(e) => return Err(e.into()),
        };
        assert!(collision.is_none(), "no collision expected");
        db.execute(
            "UPDATE entities SET name = ?1 WHERE id = ?2",
            params![wanted, entity_id],
        )?;

        let stored: String = db.query_row(
            "SELECT name FROM entities WHERE id = ?1",
            params![entity_id],
            |row| row.get(0),
        )?;
        assert_eq!(stored, "hello-world");
        Ok(())
    }

    // Replays part of the merge branch by hand: retarget outgoing edges, drop
    // leftovers, delete the source entity. `run` itself is not called.
    #[test]
    fn merges_into_existing_on_collision() -> TestResult {
        let (_dir, db) = setup_db()?;
        let keeper = insert_entity(&db, "hello-world")?;
        let duplicate = insert_entity(&db, "Hello World")?;

        // A self-referencing edge on the entity that is about to be merged away.
        db.execute(
            "INSERT INTO relationships (namespace, source_id, target_id, relation, weight)
 VALUES ('global', ?1, ?1, 'related', 0.5)",
            params![duplicate],
        )?;

        db.execute(
            "UPDATE OR IGNORE relationships SET source_id = ?1 WHERE source_id = ?2",
            params![keeper, duplicate],
        )?;
        db.execute(
            "DELETE FROM relationships WHERE source_id = ?1",
            params![duplicate],
        )?;
        db.execute("DELETE FROM entities WHERE id = ?1", params![duplicate])?;

        let remaining: i64 = db.query_row(
            "SELECT COUNT(*) FROM entities WHERE id = ?1",
            params![duplicate],
            |row| row.get(0),
        )?;
        assert_eq!(remaining, 0, "source entity must be deleted after merge");

        let keeper_name: String = db.query_row(
            "SELECT name FROM entities WHERE id = ?1",
            params![keeper],
            |row| row.get(0),
        )?;
        assert_eq!(keeper_name, "hello-world");
        Ok(())
    }

    // Pins the JSON field names and values of the response payload.
    #[test]
    fn normalize_entities_response_serializes_correctly() {
        let payload = NormalizeEntitiesResponse {
            action: "normalized".to_string(),
            normalized_count: 3,
            merged_count: 1,
            namespace: "global".to_string(),
            elapsed_ms: 42,
        };
        let value = serde_json::to_value(&payload).expect("serialization");
        assert_eq!(value["action"], "normalized");
        assert_eq!(value["normalized_count"], 3);
        assert_eq!(value["merged_count"], 1);
        assert_eq!(value["namespace"], "global");
        assert!(value["elapsed_ms"].as_u64().is_some());
    }

    // The `action` field must round-trip unchanged for the preview variant.
    #[test]
    fn dry_run_response_has_correct_action() {
        let payload = NormalizeEntitiesResponse {
            action: "dry_run".to_string(),
            normalized_count: 5,
            merged_count: 0,
            namespace: "test".to_string(),
            elapsed_ms: 1,
        };
        let value = serde_json::to_value(&payload).expect("serialization");
        assert_eq!(value["action"], "dry_run");
    }
}