1use crate::errors::AppError;
11use crate::output::{self, OutputFormat};
12use crate::parsers::normalize_entity_name;
13use crate::paths::AppPaths;
14use crate::storage::connection::open_rw;
15use rusqlite::params;
16use serde::Serialize;
17
18#[derive(clap::Args)]
19#[command(after_long_help = "EXAMPLES:\n \
20 # Preview which entities would be renamed or merged\n \
21 sqlite-graphrag normalize-entities --dry-run\n\n \
22 # Apply normalization to all entity names\n \
23 sqlite-graphrag normalize-entities --yes\n\n \
24 # Scope to a specific namespace\n \
25 sqlite-graphrag normalize-entities --yes --namespace my-project\n\n\
26NOTE:\n \
27 When a normalized name already exists, the source entity is merged into\n \
28 the existing target via relationship retargeting (UPDATE OR IGNORE + DELETE).\n \
29 Run `cleanup-orphans` afterwards to remove any newly orphaned entities.")]
30pub struct NormalizeEntitiesArgs {
31 #[arg(long, conflicts_with = "yes")]
33 pub dry_run: bool,
34 #[arg(long, conflicts_with = "dry_run")]
36 pub yes: bool,
37 #[arg(long)]
38 pub namespace: Option<String>,
39 #[arg(long, value_enum, default_value = "json")]
40 pub format: OutputFormat,
41 #[arg(long, hide = true, help = "No-op; JSON is always emitted on stdout")]
42 pub json: bool,
43 #[arg(long, env = "SQLITE_GRAPHRAG_DB_PATH")]
44 pub db: Option<String>,
45}
46
47#[derive(Serialize)]
48struct NormalizeEntitiesResponse {
49 action: String,
51 normalized_count: usize,
53 merged_count: usize,
56 namespace: String,
57 elapsed_ms: u64,
59}
60
61pub fn run(args: NormalizeEntitiesArgs) -> Result<(), AppError> {
62 let inicio = std::time::Instant::now();
63
64 if !args.dry_run && !args.yes {
65 return Err(AppError::Validation(
66 "pass --dry-run to preview or --yes to apply changes".to_string(),
67 ));
68 }
69
70 let namespace = crate::namespace::resolve_namespace(args.namespace.as_deref())?;
71 let paths = AppPaths::resolve(args.db.as_deref())?;
72
73 crate::storage::connection::ensure_db_ready(&paths)?;
74
75 let mut conn = open_rw(&paths.db)?;
76
77 let entities: Vec<(i64, String)> = {
79 let mut stmt =
80 conn.prepare("SELECT id, name FROM entities WHERE namespace = ?1 ORDER BY id")?;
81 let rows = stmt.query_map(params![namespace], |r| {
82 Ok((r.get::<_, i64>(0)?, r.get::<_, String>(1)?))
83 })?;
84 rows.collect::<Result<Vec<_>, _>>()?
85 };
86
87 let to_change: Vec<(i64, String, String)> = entities
89 .iter()
90 .filter_map(|(id, name)| {
91 let normalized = normalize_entity_name(name);
92 if normalized != *name {
93 Some((*id, name.clone(), normalized))
94 } else {
95 None
96 }
97 })
98 .collect();
99
100 let normalized_count_preview = to_change.len();
101
102 if args.dry_run {
103 let response = NormalizeEntitiesResponse {
104 action: "dry_run".to_string(),
105 normalized_count: normalized_count_preview,
106 merged_count: 0,
107 namespace,
108 elapsed_ms: inicio.elapsed().as_millis() as u64,
109 };
110 match args.format {
111 OutputFormat::Json => output::emit_json(&response)?,
112 OutputFormat::Text | OutputFormat::Markdown => {
113 output::emit_text(&format!(
114 "dry_run: {} entity names would be normalized",
115 response.normalized_count
116 ));
117 }
118 }
119 return Ok(());
120 }
121
122 let tx = conn.transaction_with_behavior(rusqlite::TransactionBehavior::Immediate)?;
124
125 let mut normalized_count: usize = 0;
126 let mut merged_count: usize = 0;
127
128 for (src_id, _original_name, normalized) in &to_change {
129 let existing_id: Option<i64> = {
131 let mut stmt =
132 tx.prepare_cached("SELECT id FROM entities WHERE namespace = ?1 AND name = ?2")?;
133 match stmt.query_row(params![namespace, normalized], |r| r.get::<_, i64>(0)) {
134 Ok(id) => Some(id),
135 Err(rusqlite::Error::QueryReturnedNoRows) => None,
136 Err(e) => return Err(AppError::Database(e)),
137 }
138 };
139
140 match existing_id {
141 Some(target_id) if target_id != *src_id => {
142 tx.execute(
145 "UPDATE OR IGNORE relationships SET source_id = ?1 WHERE source_id = ?2",
146 params![target_id, src_id],
147 )?;
148 tx.execute(
149 "DELETE FROM relationships WHERE source_id = ?1",
150 params![src_id],
151 )?;
152 tx.execute(
154 "UPDATE OR IGNORE relationships SET target_id = ?1 WHERE target_id = ?2",
155 params![target_id, src_id],
156 )?;
157 tx.execute(
158 "DELETE FROM relationships WHERE target_id = ?1",
159 params![src_id],
160 )?;
161 tx.execute("DELETE FROM relationships WHERE source_id = target_id", [])?;
163 tx.execute(
165 "UPDATE OR IGNORE memory_entities SET entity_id = ?1 WHERE entity_id = ?2",
166 params![target_id, src_id],
167 )?;
168 tx.execute(
169 "DELETE FROM memory_entities WHERE entity_id = ?1",
170 params![src_id],
171 )?;
172 tx.execute("DELETE FROM entities WHERE id = ?1", params![src_id])?;
174 tx.execute(
176 "UPDATE entities
177 SET degree = (SELECT COUNT(*) FROM relationships
178 WHERE source_id = entities.id OR target_id = entities.id)
179 WHERE id = ?1",
180 params![target_id],
181 )?;
182 tracing::info!(
183 src_id = src_id,
184 target_id = target_id,
185 normalized = normalized,
186 "entity merged into existing normalized target"
187 );
188 merged_count += 1;
189 }
190 _ => {
191 tx.execute(
193 "UPDATE entities SET name = ?1, updated_at = unixepoch() WHERE id = ?2",
194 params![normalized, src_id],
195 )?;
196 tracing::info!(
197 entity_id = src_id,
198 normalized = normalized,
199 "entity name normalized"
200 );
201 normalized_count += 1;
202 }
203 }
204 }
205
206 tx.commit()?;
207 conn.execute_batch("PRAGMA wal_checkpoint(TRUNCATE);")?;
208
209 let response = NormalizeEntitiesResponse {
210 action: "normalized".to_string(),
211 normalized_count,
212 merged_count,
213 namespace,
214 elapsed_ms: inicio.elapsed().as_millis() as u64,
215 };
216
217 match args.format {
218 OutputFormat::Json => output::emit_json(&response)?,
219 OutputFormat::Text | OutputFormat::Markdown => {
220 output::emit_text(&format!(
221 "normalized: {} renamed, {} merged",
222 response.normalized_count, response.merged_count
223 ));
224 }
225 }
226
227 Ok(())
228}
229
#[cfg(test)]
mod tests {
    use super::*;
    use crate::storage::connection::register_vec_extension;
    use clap::Parser;
    use rusqlite::Connection;
    use tempfile::TempDir;

    type TestResult = Result<(), Box<dyn std::error::Error>>;

    // Minimal parser wrapper so the flattened args can be parsed directly.
    #[derive(clap::Parser)]
    struct TestCli {
        #[command(flatten)]
        args: NormalizeEntitiesArgs,
    }

    /// Creates a migrated database in a temp dir; keep the `TempDir` alive for the test.
    fn setup_db() -> Result<(TempDir, Connection), Box<dyn std::error::Error>> {
        register_vec_extension();
        let tmp = TempDir::new()?;
        let db_path = tmp.path().join("test.db");
        let mut conn = Connection::open(&db_path)?;
        crate::migrations::runner().run(&mut conn)?;
        Ok((tmp, conn))
    }

    /// Inserts a `concept` entity into the `global` namespace and returns its id.
    fn insert_entity(conn: &Connection, name: &str) -> Result<i64, Box<dyn std::error::Error>> {
        conn.execute(
            "INSERT INTO entities (namespace, name, type, description) VALUES ('global', ?1, 'concept', NULL)",
            params![name],
        )?;
        let id: i64 = conn.query_row(
            "SELECT id FROM entities WHERE namespace = 'global' AND name = ?1",
            params![name],
            |r| r.get(0),
        )?;
        Ok(id)
    }

    /// Counts entities named exactly `name` in the `global` namespace.
    fn count_named(conn: &Connection, name: &str) -> Result<i64, Box<dyn std::error::Error>> {
        Ok(conn.query_row(
            "SELECT COUNT(*) FROM entities WHERE name = ?1 AND namespace = 'global'",
            params![name],
            |r| r.get(0),
        )?)
    }

    #[test]
    fn dry_run_and_yes_are_mutually_exclusive() -> TestResult {
        let cli = TestCli::try_parse_from(["normalize-entities", "--dry-run"])?;
        assert!(cli.args.dry_run && !cli.args.yes);
        assert!(
            TestCli::try_parse_from(["normalize-entities", "--dry-run", "--yes"]).is_err(),
            "--dry-run and --yes must conflict"
        );
        Ok(())
    }

    #[test]
    fn dry_run_returns_count_without_changes() -> TestResult {
        let (_tmp, conn) = setup_db()?;
        insert_entity(&conn, "Hello World")?;
        insert_entity(&conn, "already-normalized")?;

        // Compute the preview set the same way the command does: names that
        // change under normalization. This is read-only.
        let names: Vec<String> = {
            let mut stmt =
                conn.prepare("SELECT name FROM entities WHERE namespace = 'global' ORDER BY id")?;
            let rows = stmt.query_map([], |r| r.get::<_, String>(0))?;
            rows.collect::<Result<Vec<_>, _>>()?
        };
        let pending = names
            .iter()
            .filter(|n| normalize_entity_name(n.as_str()) != n.as_str())
            .count();
        assert_eq!(pending, 1, "only 'Hello World' should need normalization");

        assert_eq!(
            count_named(&conn, "Hello World")?,
            1,
            "dry run must not rename entities"
        );
        Ok(())
    }

    #[test]
    fn renames_unnormalized_entity_in_place() -> TestResult {
        let (_tmp, conn) = setup_db()?;
        let src_id = insert_entity(&conn, "Hello World")?;
        let normalized = normalize_entity_name("Hello World");

        assert_eq!(count_named(&conn, &normalized)?, 0, "no collision expected");
        conn.execute(
            "UPDATE entities SET name = ?1 WHERE id = ?2",
            params![normalized, src_id],
        )?;

        let name: String = conn.query_row(
            "SELECT name FROM entities WHERE id = ?1",
            params![src_id],
            |r| r.get(0),
        )?;
        assert_eq!(name, "hello-world");
        Ok(())
    }

    #[test]
    fn merges_into_existing_on_collision() -> TestResult {
        let (_tmp, conn) = setup_db()?;
        let target_id = insert_entity(&conn, "hello-world")?;
        let src_id = insert_entity(&conn, "Hello World")?;

        // A self-loop on the source becomes a self-loop on the target once
        // both endpoints are retargeted.
        conn.execute(
            "INSERT INTO relationships (namespace, source_id, target_id, relation, weight)
             VALUES ('global', ?1, ?1, 'related', 0.5)",
            params![src_id],
        )?;

        // Replay the merge: retarget both endpoints, drop leftovers and the
        // resulting self-loop on the target, then delete the source entity.
        conn.execute(
            "UPDATE OR IGNORE relationships SET source_id = ?1 WHERE source_id = ?2",
            params![target_id, src_id],
        )?;
        conn.execute(
            "DELETE FROM relationships WHERE source_id = ?1",
            params![src_id],
        )?;
        conn.execute(
            "UPDATE OR IGNORE relationships SET target_id = ?1 WHERE target_id = ?2",
            params![target_id, src_id],
        )?;
        conn.execute(
            "DELETE FROM relationships WHERE target_id = ?1",
            params![src_id],
        )?;
        conn.execute(
            "DELETE FROM relationships WHERE source_id = ?1 AND target_id = ?1",
            params![target_id],
        )?;
        conn.execute("DELETE FROM entities WHERE id = ?1", params![src_id])?;

        let src_exists: i64 = conn.query_row(
            "SELECT COUNT(*) FROM entities WHERE id = ?1",
            params![src_id],
            |r| r.get(0),
        )?;
        assert_eq!(src_exists, 0, "source entity must be deleted after merge");

        let dangling: i64 = conn.query_row(
            "SELECT COUNT(*) FROM relationships WHERE source_id = ?1 OR target_id = ?1",
            params![src_id],
            |r| r.get(0),
        )?;
        assert_eq!(dangling, 0, "no relationship may reference the merged source");

        let self_loops: i64 = conn.query_row(
            "SELECT COUNT(*) FROM relationships WHERE source_id = ?1 AND target_id = ?1",
            params![target_id],
            |r| r.get(0),
        )?;
        assert_eq!(self_loops, 0, "merge must not leave a self-loop on the target");

        let target_name: String = conn.query_row(
            "SELECT name FROM entities WHERE id = ?1",
            params![target_id],
            |r| r.get(0),
        )?;
        assert_eq!(target_name, "hello-world");
        Ok(())
    }

    #[test]
    fn normalize_entities_response_serializes_correctly() {
        let resp = NormalizeEntitiesResponse {
            action: "normalized".to_string(),
            normalized_count: 3,
            merged_count: 1,
            namespace: "global".to_string(),
            elapsed_ms: 42,
        };
        let json = serde_json::to_value(&resp).expect("serialization");
        assert_eq!(json["action"], "normalized");
        assert_eq!(json["normalized_count"], 3);
        assert_eq!(json["merged_count"], 1);
        assert_eq!(json["namespace"], "global");
        assert!(json["elapsed_ms"].as_u64().is_some());
    }

    #[test]
    fn dry_run_response_has_correct_action() {
        let resp = NormalizeEntitiesResponse {
            action: "dry_run".to_string(),
            normalized_count: 5,
            merged_count: 0,
            namespace: "test".to_string(),
            elapsed_ms: 1,
        };
        let json = serde_json::to_value(&resp).expect("serialization");
        assert_eq!(json["action"], "dry_run");
    }
}