Skip to main content

sqlite_graphrag/commands/
reclassify_relation.rs

1//! Handler for the `reclassify-relation` CLI subcommand (GAP-13).
2//!
3//! Renames a relation type in the `relationships` table — either a single
4//! directed edge (`--source`, `--target`, `--from-relation`) or every edge of
5//! a given type in the namespace (`--batch`).
6//!
7//! When the rename would produce a duplicate `(source_id, target_id, relation)`
8//! triple, `UPDATE OR IGNORE` skips the conflicting row and the subsequent
9//! `DELETE` removes it; the count of such skipped rows is reported as
10//! `merged_duplicates`.
11
12use crate::entity_type::EntityType;
13use crate::errors::AppError;
14use crate::output::{self, OutputFormat};
15use crate::paths::AppPaths;
16use crate::storage::connection::open_rw;
17use rusqlite::params;
18use serde::Serialize;
19
20#[derive(clap::Args)]
21#[command(after_long_help = "EXAMPLES:\n  \
22    # Rename a single edge from 'mentions' to 'related'\n  \
23    sqlite-graphrag reclassify-relation --source tokio --target axum \\\n  \
24        --from-relation mentions --to-relation related\n\n  \
25    # Rename every 'mentions' edge in the namespace to 'related'\n  \
26    sqlite-graphrag reclassify-relation \\\n  \
27        --from-relation mentions --to-relation related --batch\n\n  \
28    # Dry-run to preview what would change\n  \
29    sqlite-graphrag reclassify-relation \\\n  \
30        --from-relation mentions --to-relation related --batch --dry-run\n\n  \
31    # Batch rename only edges whose source is a 'tool' entity\n  \
32    sqlite-graphrag reclassify-relation \\\n  \
33        --from-relation uses --to-relation depends_on --batch \\\n  \
34        --filter-source-type tool\n\n  \
35    # Migrate edges stored with a LITERAL hyphenated relation (P4):\n  \
36    # --from-relation normalizes 'applies-to' to 'applies_to' and never\n  \
37    # matches the raw stored value; --literal-from matches it verbatim.\n  \
38    sqlite-graphrag reclassify-relation \\\n  \
39        --literal-from applies-to --to-relation applies_to --batch\n\n\
40NOTE:\n  \
41    Single mode requires --source, --target and --from-relation (or --literal-from).\n  \
42    Batch mode requires --from-relation (or --literal-from), --to-relation and --batch.\n  \
43    --from-relation and --literal-from are mutually exclusive; exactly one is required.\n  \
44    --filter-source-type and --filter-target-type are only effective in batch mode.")]
45pub struct ReclassifyRelationArgs {
46    /// Source entity name (single mode). Mutually exclusive with --batch.
47    #[arg(long, conflicts_with = "batch", value_name = "ENTITY")]
48    pub source: Option<String>,
49    /// Target entity name (single mode). Mutually exclusive with --batch.
50    #[arg(long, conflicts_with = "batch", value_name = "ENTITY")]
51    pub target: Option<String>,
52    /// Current relation type to rename (normalized: hyphens become
53    /// underscores at the CLI boundary). Required in both single and batch
54    /// modes unless --literal-from is given.
55    #[arg(
56        long,
57        value_parser = crate::parsers::parse_relation,
58        value_name = "RELATION",
59        required_unless_present = "literal_from",
60        conflicts_with = "literal_from"
61    )]
62    pub from_relation: Option<String>,
63    /// v1.1.1 (P4): current relation type to rename, matched LITERALLY —
64    /// no normalization is applied, so edges stored with hyphenated values
65    /// (e.g. `applies-to`) become reachable. Mutually exclusive with
66    /// --from-relation.
67    #[arg(long, value_name = "RELATION")]
68    pub literal_from: Option<String>,
69    /// New relation type to assign. Required in both single and batch modes.
70    #[arg(long, value_parser = crate::parsers::parse_relation, value_name = "RELATION")]
71    pub to_relation: String,
72    /// Enable batch reclassification of all edges with --from-relation. Requires --from-relation and --to-relation.
73    #[arg(long, default_value_t = false)]
74    pub batch: bool,
75    /// Filter batch: only rename edges whose source entity has this type.
76    #[arg(long, value_enum, value_name = "TYPE", requires = "batch")]
77    pub filter_source_type: Option<EntityType>,
78    /// Filter batch: only rename edges whose target entity has this type.
79    #[arg(long, value_enum, value_name = "TYPE", requires = "batch")]
80    pub filter_target_type: Option<EntityType>,
81    /// Preview count without committing changes.
82    #[arg(long, default_value_t = false)]
83    pub dry_run: bool,
84    #[arg(long)]
85    pub namespace: Option<String>,
86    #[arg(long, value_enum, default_value = "json")]
87    pub format: OutputFormat,
88    #[arg(long, hide = true, help = "No-op; JSON is always emitted on stdout")]
89    pub json: bool,
90    #[arg(long, env = "SQLITE_GRAPHRAG_DB_PATH")]
91    pub db: Option<String>,
92}
93
94#[derive(Serialize)]
95struct ReclassifyRelationResponse {
96    action: String,
97    from_relation: String,
98    to_relation: String,
99    /// Number of edges successfully renamed.
100    count: usize,
101    /// Edges that collided with an existing (source, target, to_relation) triple
102    /// and were removed rather than renamed (UPDATE OR IGNORE + DELETE pattern).
103    merged_duplicates: usize,
104    namespace: String,
105    elapsed_ms: u64,
106}
107
108impl ReclassifyRelationArgs {
109    /// v1.1.1 (P4): the relation value used in every WHERE clause.
110    ///
111    /// `--literal-from` wins and is matched VERBATIM (no normalization);
112    /// otherwise the clap-normalized `--from-relation` applies. Clap
113    /// guarantees exactly one of the two is present
114    /// (`required_unless_present` + `conflicts_with`).
115    fn effective_from(&self) -> &str {
116        self.literal_from
117            .as_deref()
118            .or(self.from_relation.as_deref())
119            .unwrap_or_default()
120    }
121}
122
123pub fn run(args: ReclassifyRelationArgs) -> Result<(), AppError> {
124    let inicio = std::time::Instant::now();
125    let namespace = crate::namespace::resolve_namespace(args.namespace.as_deref())?;
126    let paths = AppPaths::resolve(args.db.as_deref())?;
127
128    crate::storage::connection::ensure_db_ready(&paths)?;
129
130    // Emit warnings for non-canonical relation values.
131    crate::parsers::warn_if_non_canonical(args.effective_from());
132    crate::parsers::warn_if_non_canonical(&args.to_relation);
133
134    // Reject same-value renames: nothing to do and would silently remove
135    // duplicates. The comparison uses the EFFECTIVE from value, so migrating
136    // a literal hyphenated relation onto its normalized form (e.g.
137    // `--literal-from applies-to --to-relation applies_to`) is a VALID
138    // migration, not an equality.
139    if args.effective_from() == args.to_relation {
140        return Err(AppError::Validation(
141            "--from-relation/--literal-from and --to-relation must be different".to_string(),
142        ));
143    }
144
145    let mut conn = open_rw(&paths.db)?;
146
147    if args.batch {
148        run_batch(args, inicio, namespace, &mut conn)
149    } else {
150        run_single(args, inicio, namespace, &mut conn)
151    }
152}
153
154// ---------------------------------------------------------------------------
155// Single mode
156// ---------------------------------------------------------------------------
157
158fn run_single(
159    args: ReclassifyRelationArgs,
160    inicio: std::time::Instant,
161    namespace: String,
162    conn: &mut rusqlite::Connection,
163) -> Result<(), AppError> {
164    let source_name = args.source.as_deref().ok_or_else(|| {
165        AppError::Validation(
166            "--source is required in single mode (omit --batch for single-edge rename)".to_string(),
167        )
168    })?;
169    let target_name = args
170        .target
171        .as_deref()
172        .ok_or_else(|| AppError::Validation("--target is required in single mode".to_string()))?;
173
174    // Resolve entity IDs — fail fast if either side does not exist.
175    // Normalize names to match the normalized stored entity names.
176    let source_name_norm = crate::parsers::normalize_entity_name(source_name);
177    let target_name_norm = crate::parsers::normalize_entity_name(target_name);
178    let source_id: i64 = conn
179        .query_row(
180            "SELECT id FROM entities WHERE name = ?1 AND namespace = ?2",
181            params![source_name_norm, namespace],
182            |r| r.get(0),
183        )
184        .map_err(|_| {
185            AppError::NotFound(format!(
186                "source entity '{source_name}' not found in namespace '{namespace}'"
187            ))
188        })?;
189
190    let target_id: i64 = conn
191        .query_row(
192            "SELECT id FROM entities WHERE name = ?1 AND namespace = ?2",
193            params![target_name_norm, namespace],
194            |r| r.get(0),
195        )
196        .map_err(|_| {
197            AppError::NotFound(format!(
198                "target entity '{target_name}' not found in namespace '{namespace}'"
199            ))
200        })?;
201
202    // Verify the edge to rename exists.
203    let original_count: i64 = conn.query_row(
204        "SELECT COUNT(*) FROM relationships
205         WHERE source_id = ?1 AND target_id = ?2 AND relation = ?3 AND namespace = ?4",
206        params![source_id, target_id, args.effective_from(), namespace],
207        |r| r.get(0),
208    )?;
209
210    if original_count == 0 {
211        return Err(AppError::NotFound(format!(
212            "edge '{source_name}' --[{}]--> '{target_name}' not found in namespace '{namespace}'",
213            args.effective_from()
214        )));
215    }
216
217    if args.dry_run {
218        emit_response(
219            &args,
220            "dry_run",
221            original_count as usize,
222            0,
223            namespace,
224            inicio,
225        )?;
226        return Ok(());
227    }
228
229    let tx = conn.transaction_with_behavior(rusqlite::TransactionBehavior::Immediate)?;
230
231    let updated = tx.execute(
232        "UPDATE OR IGNORE relationships
233         SET relation = ?1
234         WHERE source_id = ?2 AND target_id = ?3 AND relation = ?4 AND namespace = ?5",
235        params![
236            args.to_relation,
237            source_id,
238            target_id,
239            args.effective_from(),
240            namespace
241        ],
242    )?;
243
244    // Remove rows that UPDATE OR IGNORE silently skipped due to UNIQUE collision.
245    let deleted = tx.execute(
246        "DELETE FROM relationships
247         WHERE source_id = ?1 AND target_id = ?2 AND relation = ?3 AND namespace = ?4",
248        params![source_id, target_id, args.effective_from(), namespace],
249    )?;
250
251    tx.commit()?;
252
253    conn.execute_batch("PRAGMA wal_checkpoint(TRUNCATE);")?;
254
255    let merged = (original_count as usize).saturating_sub(updated + deleted);
256    emit_response(&args, "reclassified", updated, merged, namespace, inicio)
257}
258
259// ---------------------------------------------------------------------------
260// Batch mode
261// ---------------------------------------------------------------------------
262
263fn run_batch(
264    args: ReclassifyRelationArgs,
265    inicio: std::time::Instant,
266    namespace: String,
267    conn: &mut rusqlite::Connection,
268) -> Result<(), AppError> {
269    // Build WHERE clause extensions for optional entity-type filters.
270    // The base query joins relationships with source/target entities.
271    let source_filter = args
272        .filter_source_type
273        .map(|t| format!(" AND src.type = '{}'", t.as_str()))
274        .unwrap_or_default();
275    let target_filter = args
276        .filter_target_type
277        .map(|t| format!(" AND tgt.type = '{}'", t.as_str()))
278        .unwrap_or_default();
279    let has_filters = !source_filter.is_empty() || !target_filter.is_empty();
280
281    // Count edges that would be affected (used for both dry-run and confirmation).
282    let original_count: i64 = if has_filters {
283        conn.query_row(
284            &format!(
285                "SELECT COUNT(*) FROM relationships r
286                 JOIN entities src ON src.id = r.source_id
287                 JOIN entities tgt ON tgt.id = r.target_id
288                 WHERE r.relation = ?1 AND r.namespace = ?2{source_filter}{target_filter}"
289            ),
290            params![args.effective_from(), namespace],
291            |r| r.get(0),
292        )?
293    } else {
294        conn.query_row(
295            "SELECT COUNT(*) FROM relationships
296             WHERE relation = ?1 AND namespace = ?2",
297            params![args.effective_from(), namespace],
298            |r| r.get(0),
299        )?
300    };
301
302    if original_count == 0 {
303        tracing::warn!(target: "reclassify_relation",
304            from_relation = %args.effective_from(),
305            namespace = %namespace,
306            "reclassify-relation batch matched zero edges — verify --from-relation value"
307        );
308    }
309
310    if args.dry_run {
311        emit_response(
312            &args,
313            "dry_run",
314            original_count as usize,
315            0,
316            namespace,
317            inicio,
318        )?;
319        return Ok(());
320    }
321
322    let tx = conn.transaction_with_behavior(rusqlite::TransactionBehavior::Immediate)?;
323
324    let updated = if has_filters {
325        // For filtered batch we need to collect IDs first, then update.
326        let ids: Vec<i64> = {
327            let mut stmt = tx.prepare(&format!(
328                "SELECT r.id FROM relationships r
329                 JOIN entities src ON src.id = r.source_id
330                 JOIN entities tgt ON tgt.id = r.target_id
331                 WHERE r.relation = ?1 AND r.namespace = ?2{source_filter}{target_filter}"
332            ))?;
333            let collected: Vec<i64> = stmt
334                .query_map(params![args.effective_from(), namespace], |r| r.get(0))?
335                .collect::<Result<Vec<_>, _>>()?;
336            collected
337        };
338
339        let mut moved: usize = 0;
340        for id in &ids {
341            let n = tx.execute(
342                "UPDATE OR IGNORE relationships
343                 SET relation = ?1
344                 WHERE id = ?2",
345                params![args.to_relation, id],
346            )?;
347            moved += n;
348        }
349        moved
350    } else {
351        tx.execute(
352            "UPDATE OR IGNORE relationships
353             SET relation = ?1
354             WHERE relation = ?2 AND namespace = ?3",
355            params![args.to_relation, args.effective_from(), namespace],
356        )?
357    };
358
359    // Remove rows the UPDATE OR IGNORE left behind (UNIQUE collision survivors).
360    let deleted = if has_filters {
361        tx.execute(
362            &format!(
363                "DELETE FROM relationships WHERE id IN (
364                     SELECT r.id FROM relationships r
365                     JOIN entities src ON src.id = r.source_id
366                     JOIN entities tgt ON tgt.id = r.target_id
367                     WHERE r.relation = ?1 AND r.namespace = ?2{source_filter}{target_filter}
368                 )"
369            ),
370            params![args.effective_from(), namespace],
371        )?
372    } else {
373        tx.execute(
374            "DELETE FROM relationships WHERE relation = ?1 AND namespace = ?2",
375            params![args.effective_from(), namespace],
376        )?
377    };
378
379    tx.commit()?;
380
381    conn.execute_batch("ANALYZE relationships;")?;
382    conn.execute_batch("PRAGMA wal_checkpoint(TRUNCATE);")?;
383
384    let merged = (original_count as usize).saturating_sub(updated + deleted);
385    emit_response(&args, "reclassified", updated, merged, namespace, inicio)
386}
387
388// ---------------------------------------------------------------------------
389// Shared response emitter
390// ---------------------------------------------------------------------------
391
392fn emit_response(
393    args: &ReclassifyRelationArgs,
394    action: &str,
395    count: usize,
396    merged_duplicates: usize,
397    namespace: String,
398    inicio: std::time::Instant,
399) -> Result<(), AppError> {
400    let response = ReclassifyRelationResponse {
401        action: action.to_string(),
402        from_relation: args.effective_from().to_string(),
403        to_relation: args.to_relation.clone(),
404        count,
405        merged_duplicates,
406        namespace: namespace.clone(),
407        elapsed_ms: inicio.elapsed().as_millis() as u64,
408    };
409
410    match args.format {
411        OutputFormat::Json => output::emit_json(&response)?,
412        OutputFormat::Text | OutputFormat::Markdown => {
413            output::emit_text(&format!(
414                "{action}: {count} edges '{}' → '{}' [{namespace}] (duplicates merged: {merged_duplicates})",
415                args.effective_from(), args.to_relation
416            ));
417        }
418    }
419    Ok(())
420}
421
#[cfg(test)]
mod tests {
    use super::*;

    // -----------------------------------------------------------------------
    // Response serialization
    // -----------------------------------------------------------------------

    /// Builds a response with fixed relation/namespace values so each test
    /// only has to vary the fields it cares about.
    fn make_response(action: &str, count: usize, merged: usize) -> ReclassifyRelationResponse {
        ReclassifyRelationResponse {
            action: action.to_string(),
            from_relation: "mentions".to_string(),
            to_relation: "related".to_string(),
            count,
            merged_duplicates: merged,
            namespace: "global".to_string(),
            elapsed_ms: 1,
        }
    }

    #[test]
    fn response_serializes_all_fields() {
        let resp = make_response("reclassified", 5, 0);
        let json = serde_json::to_value(&resp).expect("serialization failed");
        assert_eq!(json["action"], "reclassified");
        assert_eq!(json["from_relation"], "mentions");
        assert_eq!(json["to_relation"], "related");
        assert_eq!(json["count"], 5);
        assert_eq!(json["merged_duplicates"], 0);
        assert_eq!(json["namespace"], "global");
        assert!(json["elapsed_ms"].is_number());
    }

    #[test]
    fn response_action_dry_run() {
        let resp = make_response("dry_run", 10, 0);
        let json = serde_json::to_value(&resp).expect("serialization failed");
        assert_eq!(json["action"], "dry_run");
        assert_eq!(json["count"], 10);
        assert_eq!(json["merged_duplicates"], 0);
    }

    #[test]
    fn response_merged_duplicates_nonzero() {
        // Simulates a case where 3 out of 10 edges collided with existing rows.
        let resp = make_response("reclassified", 7, 3);
        let json = serde_json::to_value(&resp).expect("serialization failed");
        assert_eq!(json["count"], 7);
        assert_eq!(json["merged_duplicates"], 3);
    }

    #[test]
    fn response_count_zero_when_nothing_matched() {
        let resp = make_response("reclassified", 0, 0);
        let json = serde_json::to_value(&resp).expect("serialization failed");
        assert_eq!(json["count"], 0);
        assert_eq!(json["merged_duplicates"], 0);
    }

    #[test]
    fn response_action_values_exhaustive() {
        for action in &["reclassified", "dry_run"] {
            let resp = make_response(action, 1, 0);
            let json = serde_json::to_value(&resp).expect("serialization");
            assert_eq!(json["action"], *action);
        }
    }

    #[test]
    fn response_from_and_to_relation_present() {
        let resp = ReclassifyRelationResponse {
            action: "reclassified".to_string(),
            from_relation: "uses".to_string(),
            to_relation: "depends_on".to_string(),
            count: 3,
            merged_duplicates: 1,
            namespace: "my-project".to_string(),
            elapsed_ms: 5,
        };
        let json = serde_json::to_value(&resp).expect("serialization failed");
        assert_eq!(json["from_relation"], "uses");
        assert_eq!(json["to_relation"], "depends_on");
    }

    #[test]
    fn same_relation_value_rejected_at_logic_level() {
        // run() needs a database, so this exercises the exact comparison its
        // guard performs — effective_from() against to_relation — for both
        // ways of supplying the "from" value.
        let mut args = base_args();
        args.to_relation = "mentions".to_string();

        args.from_relation = Some("mentions".to_string());
        assert_eq!(
            args.effective_from(),
            args.to_relation,
            "same-value rename must be caught before DB access"
        );

        args.from_relation = None;
        args.literal_from = Some("mentions".to_string());
        assert_eq!(
            args.effective_from(),
            args.to_relation,
            "same-value rename via --literal-from must be caught before DB access"
        );
    }

    // -----------------------------------------------------------------------
    // v1.1.1 (P4): --literal-from — filter without normalization
    // -----------------------------------------------------------------------

    /// Batch-mode arguments with neither "from" flag set; tests fill in the
    /// one they need.
    fn base_args() -> ReclassifyRelationArgs {
        ReclassifyRelationArgs {
            source: None,
            target: None,
            from_relation: None,
            literal_from: None,
            to_relation: "applies_to".to_string(),
            batch: true,
            filter_source_type: None,
            filter_target_type: None,
            dry_run: false,
            namespace: Some("global".to_string()),
            format: OutputFormat::Json,
            json: true,
            db: None,
        }
    }

    #[test]
    fn effective_from_prefers_literal_and_falls_back_to_normalized() {
        let mut args = base_args();
        args.from_relation = Some("applies_to".to_string());
        assert_eq!(args.effective_from(), "applies_to");

        args.literal_from = Some("applies-to".to_string());
        assert_eq!(
            args.effective_from(),
            "applies-to",
            "literal value must win and stay verbatim"
        );

        // A literal → normalized migration is VALID (it is not an equality).
        assert_ne!(args.effective_from(), args.to_relation);
    }

    /// Creates a fresh, fully migrated database in a temporary directory.
    /// The `TempDir` must be kept alive for as long as the connection is used.
    fn setup_migrated_db() -> (tempfile::TempDir, rusqlite::Connection) {
        crate::storage::connection::register_vec_extension();
        let tmp = tempfile::TempDir::new().expect("tempdir");
        let db_path = tmp.path().join("test.db");
        let mut conn = rusqlite::Connection::open(&db_path).expect("open");
        crate::migrations::runner().run(&mut conn).expect("migrate");
        (tmp, conn)
    }

    /// Inserts a `concept` entity into the `global` namespace and returns its id.
    fn insert_entity(conn: &rusqlite::Connection, name: &str) -> i64 {
        conn.execute(
            "INSERT INTO entities (namespace, name, type) VALUES ('global', ?1, 'concept')",
            params![name],
        )
        .unwrap();
        conn.last_insert_rowid()
    }

    /// Inserts an edge `source --[relation]--> target` into the `global` namespace.
    fn insert_edge(conn: &rusqlite::Connection, source: i64, target: i64, relation: &str) {
        conn.execute(
            "INSERT INTO relationships (namespace, source_id, target_id, relation, weight) \
             VALUES ('global', ?1, ?2, ?3, 0.5)",
            params![source, target, relation],
        )
        .unwrap();
    }

    /// Counts the edges stored with exactly this relation value.
    fn count_edges(conn: &rusqlite::Connection, relation: &str) -> i64 {
        conn.query_row(
            "SELECT COUNT(*) FROM relationships WHERE relation = ?1",
            params![relation],
            |r| r.get(0),
        )
        .unwrap()
    }

    #[test]
    fn literal_from_migrates_hyphenated_edge_unreachable_by_normalized_filter() {
        let (_tmp, mut conn) = setup_migrated_db();
        let a = insert_entity(&conn, "ent-a");
        let b = insert_entity(&conn, "ent-b");
        // Edge stored with the LITERAL hyphenated value — unreachable through
        // --from-relation (which normalizes to 'applies_to' at the clap boundary).
        insert_edge(&conn, a, b, "applies-to");

        let mut args = base_args();
        args.literal_from = Some("applies-to".to_string());
        run_batch(
            args,
            std::time::Instant::now(),
            "global".to_string(),
            &mut conn,
        )
        .expect("batch literal migration");

        assert_eq!(
            count_edges(&conn, "applies_to"),
            1,
            "hyphenated edge must be migrated"
        );
        assert_eq!(
            count_edges(&conn, "applies-to"),
            0,
            "no literal edge may remain"
        );
    }

    #[test]
    fn literal_from_batch_drops_edge_colliding_with_existing_target_relation() {
        let (_tmp, mut conn) = setup_migrated_db();
        let a = insert_entity(&conn, "ent-a");
        let b = insert_entity(&conn, "ent-b");
        // The same pair already has the normalized edge, so renaming the
        // hyphenated one would duplicate the (source, target, relation) triple.
        insert_edge(&conn, a, b, "applies-to");
        insert_edge(&conn, a, b, "applies_to");

        let mut args = base_args();
        args.literal_from = Some("applies-to".to_string());
        run_batch(
            args,
            std::time::Instant::now(),
            "global".to_string(),
            &mut conn,
        )
        .expect("batch migration with collision");

        assert_eq!(
            count_edges(&conn, "applies_to"),
            1,
            "the pre-existing normalized edge must be the only survivor"
        );
        assert_eq!(
            count_edges(&conn, "applies-to"),
            0,
            "the colliding literal edge must be removed"
        );
    }

    // -----------------------------------------------------------------------
    // CLI parsing rules for --from-relation / --literal-from
    // -----------------------------------------------------------------------

    #[test]
    fn cli_rejects_literal_from_combined_with_from_relation() {
        use clap::Parser;
        let err = match crate::cli::Cli::try_parse_from([
            "sqlite-graphrag",
            "reclassify-relation",
            "--from-relation",
            "mentions",
            "--literal-from",
            "applies-to",
            "--to-relation",
            "related",
            "--batch",
        ]) {
            Err(e) => e,
            Ok(_) => panic!("mutually exclusive flags must fail to parse"),
        };
        assert_eq!(err.kind(), clap::error::ErrorKind::ArgumentConflict);
    }

    #[test]
    fn cli_requires_one_of_from_relation_or_literal_from() {
        use clap::Parser;
        let err = match crate::cli::Cli::try_parse_from([
            "sqlite-graphrag",
            "reclassify-relation",
            "--to-relation",
            "related",
            "--batch",
        ]) {
            Err(e) => e,
            Ok(_) => panic!("one of the from flags is required"),
        };
        assert_eq!(err.kind(), clap::error::ErrorKind::MissingRequiredArgument);
    }

    #[test]
    fn cli_accepts_literal_from_alone_and_keeps_it_verbatim() {
        use clap::Parser;
        let parsed = crate::cli::Cli::try_parse_from([
            "sqlite-graphrag",
            "reclassify-relation",
            "--literal-from",
            "applies-to",
            "--to-relation",
            "applies_to",
            "--batch",
        ])
        .expect("literal-from alone must parse");
        match parsed.command {
            Some(crate::cli::Commands::ReclassifyRelation(a)) => {
                assert_eq!(a.literal_from.as_deref(), Some("applies-to"));
                assert!(a.from_relation.is_none());
                assert_eq!(a.effective_from(), "applies-to");
            }
            _ => unreachable!("unexpected command"),
        }
    }
}