1use crate::entity_type::EntityType;
13use crate::errors::AppError;
14use crate::output::{self, OutputFormat};
15use crate::paths::AppPaths;
16use crate::storage::connection::open_rw;
17use rusqlite::params;
18use serde::Serialize;
19
// Command-line arguments for the `reclassify-relation` subcommand.
//
// Two modes are selected by `--batch`: without it a single edge identified by
// `--source`/`--target` is renamed; with it every edge carrying the "from"
// relation in the namespace is renamed (optionally narrowed by entity type).
//
// NOTE(review): plain `//` comments are used here on purpose. clap's derive
// macros turn `///` doc comments into help text, so doc comments would change
// the generated `--help` output.
#[derive(clap::Args)]
#[command(after_long_help = "EXAMPLES:\n \
    # Rename a single edge from 'mentions' to 'related'\n \
    sqlite-graphrag reclassify-relation --source tokio --target axum \\\n \
    --from-relation mentions --to-relation related\n\n \
    # Rename every 'mentions' edge in the namespace to 'related'\n \
    sqlite-graphrag reclassify-relation \\\n \
    --from-relation mentions --to-relation related --batch\n\n \
    # Dry-run to preview what would change\n \
    sqlite-graphrag reclassify-relation \\\n \
    --from-relation mentions --to-relation related --batch --dry-run\n\n \
    # Batch rename only edges whose source is a 'tool' entity\n \
    sqlite-graphrag reclassify-relation \\\n \
    --from-relation uses --to-relation depends_on --batch \\\n \
    --filter-source-type tool\n\n \
    # Migrate edges stored with a LITERAL hyphenated relation (P4):\n \
    # --from-relation normalizes 'applies-to' to 'applies_to' and never\n \
    # matches the raw stored value; --literal-from matches it verbatim.\n \
    sqlite-graphrag reclassify-relation \\\n \
    --literal-from applies-to --to-relation applies_to --batch\n\n\
NOTE:\n \
    Single mode requires --source, --target and --from-relation (or --literal-from).\n \
    Batch mode requires --from-relation (or --literal-from), --to-relation and --batch.\n \
    --from-relation and --literal-from are mutually exclusive; exactly one is required.\n \
    --filter-source-type and --filter-target-type are only effective in batch mode.")]
pub struct ReclassifyRelationArgs {
    // Name of the edge's source entity (single mode only; rejected together with --batch).
    #[arg(long, conflicts_with = "batch", value_name = "ENTITY")]
    pub source: Option<String>,
    // Name of the edge's target entity (single mode only; rejected together with --batch).
    #[arg(long, conflicts_with = "batch", value_name = "ENTITY")]
    pub target: Option<String>,
    // Relation to rename, passed through `parse_relation` first. Exactly one of
    // this flag and --literal-from must be given.
    #[arg(
        long,
        value_parser = crate::parsers::parse_relation,
        value_name = "RELATION",
        required_unless_present = "literal_from",
        conflicts_with = "literal_from"
    )]
    pub from_relation: Option<String>,
    // Relation to rename, taken verbatim (no value parser is attached), so it can
    // match stored values that `parse_relation` would rewrite.
    #[arg(long, value_name = "RELATION")]
    pub literal_from: Option<String>,
    // Relation the matching edges are renamed to; always required.
    #[arg(long, value_parser = crate::parsers::parse_relation, value_name = "RELATION")]
    pub to_relation: String,
    // Switches from single-edge mode to whole-namespace batch mode.
    #[arg(long, default_value_t = false)]
    pub batch: bool,
    // Batch mode only: restrict the rename to edges whose source entity has this type.
    #[arg(long, value_enum, value_name = "TYPE", requires = "batch")]
    pub filter_source_type: Option<EntityType>,
    // Batch mode only: restrict the rename to edges whose target entity has this type.
    #[arg(long, value_enum, value_name = "TYPE", requires = "batch")]
    pub filter_target_type: Option<EntityType>,
    // Count the matching edges and report them without modifying the database.
    #[arg(long, default_value_t = false)]
    pub dry_run: bool,
    // Namespace override; handed to `crate::namespace::resolve_namespace`.
    #[arg(long)]
    pub namespace: Option<String>,
    // Output format of the final report (defaults to JSON).
    #[arg(long, value_enum, default_value = "json")]
    pub format: OutputFormat,
    // Hidden flag that this command never reads.
    #[arg(long, hide = true, help = "No-op; JSON is always emitted on stdout")]
    pub json: bool,
    // Database path override; also read from SQLITE_GRAPHRAG_DB_PATH and handed
    // to `AppPaths::resolve`.
    #[arg(long, env = "SQLITE_GRAPHRAG_DB_PATH")]
    pub db: Option<String>,
}
93
/// Report emitted once a rename (or dry run) has finished.
///
/// Serialized as-is for JSON output; the text/markdown output is built from
/// the same values in `emit_response`.
#[derive(Serialize)]
struct ReclassifyRelationResponse {
    /// What was done: `"reclassified"` or `"dry_run"`.
    action: String,
    /// The "from" relation actually used for matching (see `effective_from`).
    from_relation: String,
    /// The relation the edges were renamed to.
    to_relation: String,
    /// Number of edges renamed, or number of matching edges for a dry run.
    count: usize,
    /// Value reported as "duplicates merged" in text output; dry runs pass 0.
    merged_duplicates: usize,
    /// Namespace the command operated on.
    namespace: String,
    /// Wall-clock time in milliseconds since the command started.
    elapsed_ms: u64,
}
107
108impl ReclassifyRelationArgs {
109 fn effective_from(&self) -> &str {
116 self.literal_from
117 .as_deref()
118 .or(self.from_relation.as_deref())
119 .unwrap_or_default()
120 }
121}
122
123pub fn run(args: ReclassifyRelationArgs) -> Result<(), AppError> {
124 let inicio = std::time::Instant::now();
125 let namespace = crate::namespace::resolve_namespace(args.namespace.as_deref())?;
126 let paths = AppPaths::resolve(args.db.as_deref())?;
127
128 crate::storage::connection::ensure_db_ready(&paths)?;
129
130 crate::parsers::warn_if_non_canonical(args.effective_from());
132 crate::parsers::warn_if_non_canonical(&args.to_relation);
133
134 if args.effective_from() == args.to_relation {
140 return Err(AppError::Validation(
141 "--from-relation/--literal-from and --to-relation must be different".to_string(),
142 ));
143 }
144
145 let mut conn = open_rw(&paths.db)?;
146
147 if args.batch {
148 run_batch(args, inicio, namespace, &mut conn)
149 } else {
150 run_single(args, inicio, namespace, &mut conn)
151 }
152}
153
154fn run_single(
159 args: ReclassifyRelationArgs,
160 inicio: std::time::Instant,
161 namespace: String,
162 conn: &mut rusqlite::Connection,
163) -> Result<(), AppError> {
164 let source_name = args.source.as_deref().ok_or_else(|| {
165 AppError::Validation(
166 "--source is required in single mode (omit --batch for single-edge rename)".to_string(),
167 )
168 })?;
169 let target_name = args
170 .target
171 .as_deref()
172 .ok_or_else(|| AppError::Validation("--target is required in single mode".to_string()))?;
173
174 let source_name_norm = crate::parsers::normalize_entity_name(source_name);
177 let target_name_norm = crate::parsers::normalize_entity_name(target_name);
178 let source_id: i64 = conn
179 .query_row(
180 "SELECT id FROM entities WHERE name = ?1 AND namespace = ?2",
181 params![source_name_norm, namespace],
182 |r| r.get(0),
183 )
184 .map_err(|_| {
185 AppError::NotFound(format!(
186 "source entity '{source_name}' not found in namespace '{namespace}'"
187 ))
188 })?;
189
190 let target_id: i64 = conn
191 .query_row(
192 "SELECT id FROM entities WHERE name = ?1 AND namespace = ?2",
193 params![target_name_norm, namespace],
194 |r| r.get(0),
195 )
196 .map_err(|_| {
197 AppError::NotFound(format!(
198 "target entity '{target_name}' not found in namespace '{namespace}'"
199 ))
200 })?;
201
202 let original_count: i64 = conn.query_row(
204 "SELECT COUNT(*) FROM relationships
205 WHERE source_id = ?1 AND target_id = ?2 AND relation = ?3 AND namespace = ?4",
206 params![source_id, target_id, args.effective_from(), namespace],
207 |r| r.get(0),
208 )?;
209
210 if original_count == 0 {
211 return Err(AppError::NotFound(format!(
212 "edge '{source_name}' --[{}]--> '{target_name}' not found in namespace '{namespace}'",
213 args.effective_from()
214 )));
215 }
216
217 if args.dry_run {
218 emit_response(
219 &args,
220 "dry_run",
221 original_count as usize,
222 0,
223 namespace,
224 inicio,
225 )?;
226 return Ok(());
227 }
228
229 let tx = conn.transaction_with_behavior(rusqlite::TransactionBehavior::Immediate)?;
230
231 let updated = tx.execute(
232 "UPDATE OR IGNORE relationships
233 SET relation = ?1
234 WHERE source_id = ?2 AND target_id = ?3 AND relation = ?4 AND namespace = ?5",
235 params![
236 args.to_relation,
237 source_id,
238 target_id,
239 args.effective_from(),
240 namespace
241 ],
242 )?;
243
244 let deleted = tx.execute(
246 "DELETE FROM relationships
247 WHERE source_id = ?1 AND target_id = ?2 AND relation = ?3 AND namespace = ?4",
248 params![source_id, target_id, args.effective_from(), namespace],
249 )?;
250
251 tx.commit()?;
252
253 conn.execute_batch("PRAGMA wal_checkpoint(TRUNCATE);")?;
254
255 let merged = (original_count as usize).saturating_sub(updated + deleted);
256 emit_response(&args, "reclassified", updated, merged, namespace, inicio)
257}
258
259fn run_batch(
264 args: ReclassifyRelationArgs,
265 inicio: std::time::Instant,
266 namespace: String,
267 conn: &mut rusqlite::Connection,
268) -> Result<(), AppError> {
269 let source_filter = args
272 .filter_source_type
273 .map(|t| format!(" AND src.type = '{}'", t.as_str()))
274 .unwrap_or_default();
275 let target_filter = args
276 .filter_target_type
277 .map(|t| format!(" AND tgt.type = '{}'", t.as_str()))
278 .unwrap_or_default();
279 let has_filters = !source_filter.is_empty() || !target_filter.is_empty();
280
281 let original_count: i64 = if has_filters {
283 conn.query_row(
284 &format!(
285 "SELECT COUNT(*) FROM relationships r
286 JOIN entities src ON src.id = r.source_id
287 JOIN entities tgt ON tgt.id = r.target_id
288 WHERE r.relation = ?1 AND r.namespace = ?2{source_filter}{target_filter}"
289 ),
290 params![args.effective_from(), namespace],
291 |r| r.get(0),
292 )?
293 } else {
294 conn.query_row(
295 "SELECT COUNT(*) FROM relationships
296 WHERE relation = ?1 AND namespace = ?2",
297 params![args.effective_from(), namespace],
298 |r| r.get(0),
299 )?
300 };
301
302 if original_count == 0 {
303 tracing::warn!(target: "reclassify_relation",
304 from_relation = %args.effective_from(),
305 namespace = %namespace,
306 "reclassify-relation batch matched zero edges — verify --from-relation value"
307 );
308 }
309
310 if args.dry_run {
311 emit_response(
312 &args,
313 "dry_run",
314 original_count as usize,
315 0,
316 namespace,
317 inicio,
318 )?;
319 return Ok(());
320 }
321
322 let tx = conn.transaction_with_behavior(rusqlite::TransactionBehavior::Immediate)?;
323
324 let updated = if has_filters {
325 let ids: Vec<i64> = {
327 let mut stmt = tx.prepare(&format!(
328 "SELECT r.id FROM relationships r
329 JOIN entities src ON src.id = r.source_id
330 JOIN entities tgt ON tgt.id = r.target_id
331 WHERE r.relation = ?1 AND r.namespace = ?2{source_filter}{target_filter}"
332 ))?;
333 let collected: Vec<i64> = stmt
334 .query_map(params![args.effective_from(), namespace], |r| r.get(0))?
335 .collect::<Result<Vec<_>, _>>()?;
336 collected
337 };
338
339 let mut moved: usize = 0;
340 for id in &ids {
341 let n = tx.execute(
342 "UPDATE OR IGNORE relationships
343 SET relation = ?1
344 WHERE id = ?2",
345 params![args.to_relation, id],
346 )?;
347 moved += n;
348 }
349 moved
350 } else {
351 tx.execute(
352 "UPDATE OR IGNORE relationships
353 SET relation = ?1
354 WHERE relation = ?2 AND namespace = ?3",
355 params![args.to_relation, args.effective_from(), namespace],
356 )?
357 };
358
359 let deleted = if has_filters {
361 tx.execute(
362 &format!(
363 "DELETE FROM relationships WHERE id IN (
364 SELECT r.id FROM relationships r
365 JOIN entities src ON src.id = r.source_id
366 JOIN entities tgt ON tgt.id = r.target_id
367 WHERE r.relation = ?1 AND r.namespace = ?2{source_filter}{target_filter}
368 )"
369 ),
370 params![args.effective_from(), namespace],
371 )?
372 } else {
373 tx.execute(
374 "DELETE FROM relationships WHERE relation = ?1 AND namespace = ?2",
375 params![args.effective_from(), namespace],
376 )?
377 };
378
379 tx.commit()?;
380
381 conn.execute_batch("ANALYZE relationships;")?;
382 conn.execute_batch("PRAGMA wal_checkpoint(TRUNCATE);")?;
383
384 let merged = (original_count as usize).saturating_sub(updated + deleted);
385 emit_response(&args, "reclassified", updated, merged, namespace, inicio)
386}
387
388fn emit_response(
393 args: &ReclassifyRelationArgs,
394 action: &str,
395 count: usize,
396 merged_duplicates: usize,
397 namespace: String,
398 inicio: std::time::Instant,
399) -> Result<(), AppError> {
400 let response = ReclassifyRelationResponse {
401 action: action.to_string(),
402 from_relation: args.effective_from().to_string(),
403 to_relation: args.to_relation.clone(),
404 count,
405 merged_duplicates,
406 namespace: namespace.clone(),
407 elapsed_ms: inicio.elapsed().as_millis() as u64,
408 };
409
410 match args.format {
411 OutputFormat::Json => output::emit_json(&response)?,
412 OutputFormat::Text | OutputFormat::Markdown => {
413 output::emit_text(&format!(
414 "{action}: {count} edges '{}' → '{}' [{namespace}] (duplicates merged: {merged_duplicates})",
415 args.effective_from(), args.to_relation
416 ));
417 }
418 }
419 Ok(())
420}
421
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a response with fixed relation names and namespace so the
    /// serialization tests only vary the fields they care about.
    fn make_response(action: &str, count: usize, merged: usize) -> ReclassifyRelationResponse {
        ReclassifyRelationResponse {
            action: action.to_string(),
            from_relation: "mentions".to_string(),
            to_relation: "related".to_string(),
            count,
            merged_duplicates: merged,
            namespace: "global".to_string(),
            elapsed_ms: 1,
        }
    }

    #[test]
    fn response_serializes_all_fields() {
        let resp = make_response("reclassified", 5, 0);
        let json = serde_json::to_value(&resp).expect("serialization failed");
        assert_eq!(json["action"], "reclassified");
        assert_eq!(json["from_relation"], "mentions");
        assert_eq!(json["to_relation"], "related");
        assert_eq!(json["count"], 5);
        assert_eq!(json["merged_duplicates"], 0);
        assert_eq!(json["namespace"], "global");
        assert!(json["elapsed_ms"].is_number());
    }

    #[test]
    fn response_action_dry_run() {
        let resp = make_response("dry_run", 10, 0);
        let json = serde_json::to_value(&resp).expect("serialization failed");
        assert_eq!(json["action"], "dry_run");
        assert_eq!(json["count"], 10);
        assert_eq!(json["merged_duplicates"], 0);
    }

    #[test]
    fn response_merged_duplicates_nonzero() {
        let resp = make_response("reclassified", 7, 3);
        let json = serde_json::to_value(&resp).expect("serialization failed");
        assert_eq!(json["count"], 7);
        assert_eq!(json["merged_duplicates"], 3);
    }

    #[test]
    fn response_count_zero_when_nothing_matched() {
        let resp = make_response("reclassified", 0, 0);
        let json = serde_json::to_value(&resp).expect("serialization failed");
        assert_eq!(json["count"], 0);
        assert_eq!(json["merged_duplicates"], 0);
    }

    #[test]
    fn response_action_values_exhaustive() {
        for action in &["reclassified", "dry_run"] {
            let resp = make_response(action, 1, 0);
            let json = serde_json::to_value(&resp).expect("serialization");
            assert_eq!(json["action"], *action);
        }
    }

    #[test]
    fn response_from_and_to_relation_present() {
        let resp = ReclassifyRelationResponse {
            action: "reclassified".to_string(),
            from_relation: "uses".to_string(),
            to_relation: "depends_on".to_string(),
            count: 3,
            merged_duplicates: 1,
            namespace: "my-project".to_string(),
            elapsed_ms: 5,
        };
        let json = serde_json::to_value(&resp).expect("serialization failed");
        assert_eq!(json["from_relation"], "uses");
        assert_eq!(json["to_relation"], "depends_on");
    }

    /// Checks the operands of the "must be different" guard in `run`.
    ///
    /// `run` itself needs a resolved database, so this test exercises the
    /// comparison it performs (`effective_from()` against `to_relation`)
    /// instead of comparing two unrelated local strings.
    #[test]
    fn same_relation_value_rejected_at_logic_level() {
        // base_args() renames to "applies_to".
        let mut args = base_args();

        // Same value through --from-relation: the guard condition holds.
        args.from_relation = Some("applies_to".to_string());
        assert_eq!(
            args.effective_from(),
            args.to_relation,
            "same-value rename must be caught before DB access"
        );

        // Same value through --literal-from: the guard condition holds as well.
        args.from_relation = None;
        args.literal_from = Some("applies_to".to_string());
        assert_eq!(
            args.effective_from(),
            args.to_relation,
            "same-value rename must be caught before DB access"
        );

        // A literal hyphenated value differs from the target and must pass the guard.
        args.literal_from = Some("applies-to".to_string());
        assert_ne!(args.effective_from(), args.to_relation);
    }

    /// Batch-mode arguments with no "from" relation set; each test fills in
    /// `from_relation` or `literal_from` as needed.
    fn base_args() -> ReclassifyRelationArgs {
        ReclassifyRelationArgs {
            source: None,
            target: None,
            from_relation: None,
            literal_from: None,
            to_relation: "applies_to".to_string(),
            batch: true,
            filter_source_type: None,
            filter_target_type: None,
            dry_run: false,
            namespace: Some("global".to_string()),
            format: OutputFormat::Json,
            json: true,
            db: None,
        }
    }

    #[test]
    fn effective_from_prefers_literal_and_falls_back_to_normalized() {
        let mut args = base_args();
        args.from_relation = Some("applies_to".to_string());
        assert_eq!(args.effective_from(), "applies_to");

        args.literal_from = Some("applies-to".to_string());
        assert_eq!(
            args.effective_from(),
            "applies-to",
            "literal value must win and stay verbatim"
        );

        assert_ne!(args.effective_from(), args.to_relation);
    }

    /// Creates a fresh database in a temporary directory and applies the
    /// project migrations. The `TempDir` is returned so the caller keeps it
    /// alive for as long as the connection is used.
    fn setup_migrated_db() -> (tempfile::TempDir, rusqlite::Connection) {
        crate::storage::connection::register_vec_extension();
        let tmp = tempfile::TempDir::new().expect("tempdir");
        let db_path = tmp.path().join("test.db");
        let mut conn = rusqlite::Connection::open(&db_path).expect("open");
        crate::migrations::runner().run(&mut conn).expect("migrate");
        (tmp, conn)
    }

    #[test]
    fn literal_from_migrates_hyphenated_edge_unreachable_by_normalized_filter() {
        let (_tmp, mut conn) = setup_migrated_db();
        conn.execute(
            "INSERT INTO entities (namespace, name, type) VALUES ('global','ent-a','concept')",
            [],
        )
        .unwrap();
        let a = conn.last_insert_rowid();
        conn.execute(
            "INSERT INTO entities (namespace, name, type) VALUES ('global','ent-b','concept')",
            [],
        )
        .unwrap();
        let b = conn.last_insert_rowid();
        // Store the relation with a literal hyphen, as legacy data would have it.
        conn.execute(
            "INSERT INTO relationships (namespace, source_id, target_id, relation, weight) \
             VALUES ('global', ?1, ?2, 'applies-to', 0.5)",
            params![a, b],
        )
        .unwrap();

        let mut args = base_args();
        args.literal_from = Some("applies-to".to_string());
        run_batch(
            args,
            std::time::Instant::now(),
            "global".to_string(),
            &mut conn,
        )
        .expect("batch literal migration");

        let migrated: i64 = conn
            .query_row(
                "SELECT COUNT(*) FROM relationships WHERE relation = 'applies_to'",
                [],
                |r| r.get(0),
            )
            .unwrap();
        assert_eq!(migrated, 1, "hyphenated edge must be migrated");
        let leftover: i64 = conn
            .query_row(
                "SELECT COUNT(*) FROM relationships WHERE relation = 'applies-to'",
                [],
                |r| r.get(0),
            )
            .unwrap();
        assert_eq!(leftover, 0, "no literal edge may remain");
    }

    #[test]
    fn cli_rejects_literal_from_combined_with_from_relation() {
        use clap::Parser;
        let err = match crate::cli::Cli::try_parse_from([
            "sqlite-graphrag",
            "reclassify-relation",
            "--from-relation",
            "mentions",
            "--literal-from",
            "applies-to",
            "--to-relation",
            "related",
            "--batch",
        ]) {
            Err(e) => e,
            Ok(_) => panic!("mutually exclusive flags must fail to parse"),
        };
        assert_eq!(err.kind(), clap::error::ErrorKind::ArgumentConflict);
    }

    #[test]
    fn cli_requires_one_of_from_relation_or_literal_from() {
        use clap::Parser;
        let err = match crate::cli::Cli::try_parse_from([
            "sqlite-graphrag",
            "reclassify-relation",
            "--to-relation",
            "related",
            "--batch",
        ]) {
            Err(e) => e,
            Ok(_) => panic!("one of the from flags is required"),
        };
        assert_eq!(err.kind(), clap::error::ErrorKind::MissingRequiredArgument);
    }

    #[test]
    fn cli_accepts_literal_from_alone_and_keeps_it_verbatim() {
        use clap::Parser;
        let parsed = crate::cli::Cli::try_parse_from([
            "sqlite-graphrag",
            "reclassify-relation",
            "--literal-from",
            "applies-to",
            "--to-relation",
            "applies_to",
            "--batch",
        ])
        .expect("literal-from alone must parse");
        match parsed.command {
            Some(crate::cli::Commands::ReclassifyRelation(a)) => {
                assert_eq!(a.literal_from.as_deref(), Some("applies-to"));
                assert!(a.from_relation.is_none());
                assert_eq!(a.effective_from(), "applies-to");
            }
            _ => unreachable!("unexpected command"),
        }
    }
}