Skip to main content

sqlite_graphrag/commands/
reclassify_relation.rs

1//! Handler for the `reclassify-relation` CLI subcommand (GAP-13).
2//!
3//! Renames a relation type in the `relationships` table — either a single
4//! directed edge (`--source`, `--target`, `--from-relation`) or every edge of
5//! a given type in the namespace (`--batch`).
6//!
7//! When the rename would produce a duplicate `(source_id, target_id, relation)`
8//! triple, `UPDATE OR IGNORE` skips the conflicting row and the subsequent
9//! `DELETE` removes it; the count of such skipped rows is reported as
10//! `merged_duplicates`.
11
12use crate::entity_type::EntityType;
13use crate::errors::AppError;
14use crate::output::{self, OutputFormat};
15use crate::paths::AppPaths;
16use crate::storage::connection::open_rw;
17use rusqlite::params;
18use serde::Serialize;
19
20#[derive(clap::Args)]
21#[command(after_long_help = "EXAMPLES:\n  \
22    # Rename a single edge from 'mentions' to 'related'\n  \
23    sqlite-graphrag reclassify-relation --source tokio --target axum \\\n  \
24        --from-relation mentions --to-relation related\n\n  \
25    # Rename every 'mentions' edge in the namespace to 'related'\n  \
26    sqlite-graphrag reclassify-relation \\\n  \
27        --from-relation mentions --to-relation related --batch\n\n  \
28    # Dry-run to preview what would change\n  \
29    sqlite-graphrag reclassify-relation \\\n  \
30        --from-relation mentions --to-relation related --batch --dry-run\n\n  \
31    # Batch rename only edges whose source is a 'tool' entity\n  \
32    sqlite-graphrag reclassify-relation \\\n  \
33        --from-relation uses --to-relation depends_on --batch \\\n  \
34        --filter-source-type tool\n\n\
35NOTE:\n  \
36    Single mode requires --source, --target and --from-relation.\n  \
37    Batch mode requires --from-relation, --to-relation and --batch.\n  \
38    --filter-source-type and --filter-target-type are only effective in batch mode.")]
39pub struct ReclassifyRelationArgs {
40    /// Source entity name (single mode). Mutually exclusive with --batch.
41    #[arg(long, conflicts_with = "batch", value_name = "ENTITY")]
42    pub source: Option<String>,
43    /// Target entity name (single mode). Mutually exclusive with --batch.
44    #[arg(long, conflicts_with = "batch", value_name = "ENTITY")]
45    pub target: Option<String>,
46    /// Current relation type to rename. Required in both single and batch modes.
47    #[arg(long, value_parser = crate::parsers::parse_relation, value_name = "RELATION")]
48    pub from_relation: String,
49    /// New relation type to assign. Required in both single and batch modes.
50    #[arg(long, value_parser = crate::parsers::parse_relation, value_name = "RELATION")]
51    pub to_relation: String,
52    /// Enable batch reclassification of all edges with --from-relation. Requires --from-relation and --to-relation.
53    #[arg(long, default_value_t = false)]
54    pub batch: bool,
55    /// Filter batch: only rename edges whose source entity has this type.
56    #[arg(long, value_enum, value_name = "TYPE", requires = "batch")]
57    pub filter_source_type: Option<EntityType>,
58    /// Filter batch: only rename edges whose target entity has this type.
59    #[arg(long, value_enum, value_name = "TYPE", requires = "batch")]
60    pub filter_target_type: Option<EntityType>,
61    /// Preview count without committing changes.
62    #[arg(long, default_value_t = false)]
63    pub dry_run: bool,
64    #[arg(long)]
65    pub namespace: Option<String>,
66    #[arg(long, value_enum, default_value = "json")]
67    pub format: OutputFormat,
68    #[arg(long, hide = true, help = "No-op; JSON is always emitted on stdout")]
69    pub json: bool,
70    #[arg(long, env = "SQLITE_GRAPHRAG_DB_PATH")]
71    pub db: Option<String>,
72}
73
/// Result payload emitted by the `reclassify-relation` handler.
///
/// Field order is the serialization order, so fields must not be reordered
/// casually.
#[derive(Serialize)]
struct ReclassifyRelationResponse {
    /// What happened: `"reclassified"` after a committed rename, or
    /// `"dry_run"` when only a preview was requested.
    action: String,
    /// Relation type that was (or would be) replaced.
    from_relation: String,
    /// Relation type that was (or would be) assigned.
    to_relation: String,
    /// Number of edges successfully renamed. For a `dry_run` action this is
    /// the number of edges that currently match instead.
    count: usize,
    /// Edges that collided with an existing (source, target, to_relation) triple
    /// and were removed rather than renamed (UPDATE OR IGNORE + DELETE pattern).
    /// Always 0 for a `dry_run` action.
    merged_duplicates: usize,
    /// Namespace the operation ran against.
    namespace: String,
    /// Milliseconds elapsed since the handler started.
    elapsed_ms: u64,
}
87
88pub fn run(args: ReclassifyRelationArgs) -> Result<(), AppError> {
89    let inicio = std::time::Instant::now();
90    let namespace = crate::namespace::resolve_namespace(args.namespace.as_deref())?;
91    let paths = AppPaths::resolve(args.db.as_deref())?;
92
93    crate::storage::connection::ensure_db_ready(&paths)?;
94
95    // Emit warnings for non-canonical relation values.
96    crate::parsers::warn_if_non_canonical(&args.from_relation);
97    crate::parsers::warn_if_non_canonical(&args.to_relation);
98
99    // Reject same-value renames: nothing to do and would silently remove duplicates.
100    if args.from_relation == args.to_relation {
101        return Err(AppError::Validation(
102            "--from-relation and --to-relation must be different".to_string(),
103        ));
104    }
105
106    let mut conn = open_rw(&paths.db)?;
107
108    if args.batch {
109        run_batch(args, inicio, namespace, &mut conn)
110    } else {
111        run_single(args, inicio, namespace, &mut conn)
112    }
113}
114
115// ---------------------------------------------------------------------------
116// Single mode
117// ---------------------------------------------------------------------------
118
119fn run_single(
120    args: ReclassifyRelationArgs,
121    inicio: std::time::Instant,
122    namespace: String,
123    conn: &mut rusqlite::Connection,
124) -> Result<(), AppError> {
125    let source_name = args.source.as_deref().ok_or_else(|| {
126        AppError::Validation(
127            "--source is required in single mode (omit --batch for single-edge rename)".to_string(),
128        )
129    })?;
130    let target_name = args
131        .target
132        .as_deref()
133        .ok_or_else(|| AppError::Validation("--target is required in single mode".to_string()))?;
134
135    // Resolve entity IDs — fail fast if either side does not exist.
136    // Normalize names to match the normalized stored entity names.
137    let source_name_norm = crate::parsers::normalize_entity_name(source_name);
138    let target_name_norm = crate::parsers::normalize_entity_name(target_name);
139    let source_id: i64 = conn
140        .query_row(
141            "SELECT id FROM entities WHERE name = ?1 AND namespace = ?2",
142            params![source_name_norm, namespace],
143            |r| r.get(0),
144        )
145        .map_err(|_| {
146            AppError::NotFound(format!(
147                "source entity '{source_name}' not found in namespace '{namespace}'"
148            ))
149        })?;
150
151    let target_id: i64 = conn
152        .query_row(
153            "SELECT id FROM entities WHERE name = ?1 AND namespace = ?2",
154            params![target_name_norm, namespace],
155            |r| r.get(0),
156        )
157        .map_err(|_| {
158            AppError::NotFound(format!(
159                "target entity '{target_name}' not found in namespace '{namespace}'"
160            ))
161        })?;
162
163    // Verify the edge to rename exists.
164    let original_count: i64 = conn.query_row(
165        "SELECT COUNT(*) FROM relationships
166         WHERE source_id = ?1 AND target_id = ?2 AND relation = ?3 AND namespace = ?4",
167        params![source_id, target_id, args.from_relation, namespace],
168        |r| r.get(0),
169    )?;
170
171    if original_count == 0 {
172        return Err(AppError::NotFound(format!(
173            "edge '{source_name}' --[{}]--> '{target_name}' not found in namespace '{namespace}'",
174            args.from_relation
175        )));
176    }
177
178    if args.dry_run {
179        emit_response(
180            &args,
181            "dry_run",
182            original_count as usize,
183            0,
184            namespace,
185            inicio,
186        )?;
187        return Ok(());
188    }
189
190    let tx = conn.transaction_with_behavior(rusqlite::TransactionBehavior::Immediate)?;
191
192    let updated = tx.execute(
193        "UPDATE OR IGNORE relationships
194         SET relation = ?1
195         WHERE source_id = ?2 AND target_id = ?3 AND relation = ?4 AND namespace = ?5",
196        params![
197            args.to_relation,
198            source_id,
199            target_id,
200            args.from_relation,
201            namespace
202        ],
203    )?;
204
205    // Remove rows that UPDATE OR IGNORE silently skipped due to UNIQUE collision.
206    let deleted = tx.execute(
207        "DELETE FROM relationships
208         WHERE source_id = ?1 AND target_id = ?2 AND relation = ?3 AND namespace = ?4",
209        params![source_id, target_id, args.from_relation, namespace],
210    )?;
211
212    tx.commit()?;
213
214    conn.execute_batch("PRAGMA wal_checkpoint(TRUNCATE);")?;
215
216    let merged = (original_count as usize).saturating_sub(updated + deleted);
217    emit_response(&args, "reclassified", updated, merged, namespace, inicio)
218}
219
220// ---------------------------------------------------------------------------
221// Batch mode
222// ---------------------------------------------------------------------------
223
224fn run_batch(
225    args: ReclassifyRelationArgs,
226    inicio: std::time::Instant,
227    namespace: String,
228    conn: &mut rusqlite::Connection,
229) -> Result<(), AppError> {
230    // Build WHERE clause extensions for optional entity-type filters.
231    // The base query joins relationships with source/target entities.
232    let source_filter = args
233        .filter_source_type
234        .map(|t| format!(" AND src.type = '{}'", t.as_str()))
235        .unwrap_or_default();
236    let target_filter = args
237        .filter_target_type
238        .map(|t| format!(" AND tgt.type = '{}'", t.as_str()))
239        .unwrap_or_default();
240    let has_filters = !source_filter.is_empty() || !target_filter.is_empty();
241
242    // Count edges that would be affected (used for both dry-run and confirmation).
243    let original_count: i64 = if has_filters {
244        conn.query_row(
245            &format!(
246                "SELECT COUNT(*) FROM relationships r
247                 JOIN entities src ON src.id = r.source_id
248                 JOIN entities tgt ON tgt.id = r.target_id
249                 WHERE r.relation = ?1 AND r.namespace = ?2{source_filter}{target_filter}"
250            ),
251            params![args.from_relation, namespace],
252            |r| r.get(0),
253        )?
254    } else {
255        conn.query_row(
256            "SELECT COUNT(*) FROM relationships
257             WHERE relation = ?1 AND namespace = ?2",
258            params![args.from_relation, namespace],
259            |r| r.get(0),
260        )?
261    };
262
263    if original_count == 0 {
264        tracing::warn!(
265            from_relation = %args.from_relation,
266            namespace = %namespace,
267            "reclassify-relation batch matched zero edges — verify --from-relation value"
268        );
269    }
270
271    if args.dry_run {
272        emit_response(
273            &args,
274            "dry_run",
275            original_count as usize,
276            0,
277            namespace,
278            inicio,
279        )?;
280        return Ok(());
281    }
282
283    let tx = conn.transaction_with_behavior(rusqlite::TransactionBehavior::Immediate)?;
284
285    let updated = if has_filters {
286        // For filtered batch we need to collect IDs first, then update.
287        let ids: Vec<i64> = {
288            let mut stmt = tx.prepare(&format!(
289                "SELECT r.id FROM relationships r
290                 JOIN entities src ON src.id = r.source_id
291                 JOIN entities tgt ON tgt.id = r.target_id
292                 WHERE r.relation = ?1 AND r.namespace = ?2{source_filter}{target_filter}"
293            ))?;
294            let collected: Vec<i64> = stmt
295                .query_map(params![args.from_relation, namespace], |r| r.get(0))?
296                .collect::<Result<Vec<_>, _>>()?;
297            collected
298        };
299
300        let mut moved: usize = 0;
301        for id in &ids {
302            let n = tx.execute(
303                "UPDATE OR IGNORE relationships
304                 SET relation = ?1
305                 WHERE id = ?2",
306                params![args.to_relation, id],
307            )?;
308            moved += n;
309        }
310        moved
311    } else {
312        tx.execute(
313            "UPDATE OR IGNORE relationships
314             SET relation = ?1
315             WHERE relation = ?2 AND namespace = ?3",
316            params![args.to_relation, args.from_relation, namespace],
317        )?
318    };
319
320    // Remove rows the UPDATE OR IGNORE left behind (UNIQUE collision survivors).
321    let deleted = if has_filters {
322        tx.execute(
323            &format!(
324                "DELETE FROM relationships WHERE id IN (
325                     SELECT r.id FROM relationships r
326                     JOIN entities src ON src.id = r.source_id
327                     JOIN entities tgt ON tgt.id = r.target_id
328                     WHERE r.relation = ?1 AND r.namespace = ?2{source_filter}{target_filter}
329                 )"
330            ),
331            params![args.from_relation, namespace],
332        )?
333    } else {
334        tx.execute(
335            "DELETE FROM relationships WHERE relation = ?1 AND namespace = ?2",
336            params![args.from_relation, namespace],
337        )?
338    };
339
340    tx.commit()?;
341
342    conn.execute_batch("ANALYZE relationships;")?;
343    conn.execute_batch("PRAGMA wal_checkpoint(TRUNCATE);")?;
344
345    let merged = (original_count as usize).saturating_sub(updated + deleted);
346    emit_response(&args, "reclassified", updated, merged, namespace, inicio)
347}
348
349// ---------------------------------------------------------------------------
350// Shared response emitter
351// ---------------------------------------------------------------------------
352
353fn emit_response(
354    args: &ReclassifyRelationArgs,
355    action: &str,
356    count: usize,
357    merged_duplicates: usize,
358    namespace: String,
359    inicio: std::time::Instant,
360) -> Result<(), AppError> {
361    let response = ReclassifyRelationResponse {
362        action: action.to_string(),
363        from_relation: args.from_relation.clone(),
364        to_relation: args.to_relation.clone(),
365        count,
366        merged_duplicates,
367        namespace: namespace.clone(),
368        elapsed_ms: inicio.elapsed().as_millis() as u64,
369    };
370
371    match args.format {
372        OutputFormat::Json => output::emit_json(&response)?,
373        OutputFormat::Text | OutputFormat::Markdown => {
374            output::emit_text(&format!(
375                "{action}: {count} edges '{}' → '{}' [{namespace}] (duplicates merged: {merged_duplicates})",
376                args.from_relation, args.to_relation
377            ));
378        }
379    }
380    Ok(())
381}
382
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a response with fixed relation names (`mentions` → `related`)
    /// in the `global` namespace; only the varying fields are parameters.
    fn make_response(action: &str, count: usize, merged: usize) -> ReclassifyRelationResponse {
        ReclassifyRelationResponse {
            action: String::from(action),
            from_relation: String::from("mentions"),
            to_relation: String::from("related"),
            count,
            merged_duplicates: merged,
            namespace: String::from("global"),
            elapsed_ms: 1,
        }
    }

    /// Serializes a response to a JSON value, panicking on failure.
    fn as_json(resp: &ReclassifyRelationResponse) -> serde_json::Value {
        serde_json::to_value(resp).expect("serialization failed")
    }

    #[test]
    fn response_serializes_all_fields() {
        let json = as_json(&make_response("reclassified", 5, 0));
        assert_eq!(json["action"], "reclassified");
        assert_eq!(json["from_relation"], "mentions");
        assert_eq!(json["to_relation"], "related");
        assert_eq!(json["count"], 5);
        assert_eq!(json["merged_duplicates"], 0);
        assert_eq!(json["namespace"], "global");
        assert!(json["elapsed_ms"].is_number());
    }

    #[test]
    fn response_action_dry_run() {
        let json = as_json(&make_response("dry_run", 10, 0));
        assert_eq!(json["action"], "dry_run");
        assert_eq!(json["count"], 10);
        assert_eq!(json["merged_duplicates"], 0);
    }

    #[test]
    fn response_merged_duplicates_nonzero() {
        // Simulates a case where 3 out of 10 edges collided with existing rows.
        let json = as_json(&make_response("reclassified", 7, 3));
        assert_eq!(json["count"], 7);
        assert_eq!(json["merged_duplicates"], 3);
    }

    #[test]
    fn response_count_zero_when_nothing_matched() {
        let json = as_json(&make_response("reclassified", 0, 0));
        assert_eq!(json["count"], 0);
        assert_eq!(json["merged_duplicates"], 0);
    }

    #[test]
    fn response_action_values_exhaustive() {
        for action in ["reclassified", "dry_run"].iter().copied() {
            let resp = make_response(action, 1, 0);
            let json = serde_json::to_value(&resp).expect("serialization");
            assert_eq!(json["action"], action);
        }
    }

    #[test]
    fn response_from_and_to_relation_present() {
        let custom = ReclassifyRelationResponse {
            action: String::from("reclassified"),
            from_relation: String::from("uses"),
            to_relation: String::from("depends_on"),
            count: 3,
            merged_duplicates: 1,
            namespace: String::from("my-project"),
            elapsed_ms: 5,
        };
        let json = as_json(&custom);
        assert_eq!(json["from_relation"], "uses");
        assert_eq!(json["to_relation"], "depends_on");
    }

    #[test]
    fn same_relation_value_rejected_at_logic_level() {
        // Mirrors the `from == to` condition guarded in run(); run() itself is
        // not called here because it needs a database.
        // NOTE(review): this only compares two identical literals, so it does
        // not exercise the guard — consider an integration test instead.
        let current = "mentions".to_string();
        let requested = "mentions".to_string();
        assert!(
            current == requested,
            "same-value rename must be caught before DB access"
        );
    }
}