1use crate::entity_type::EntityType;
13use crate::errors::AppError;
14use crate::output::{self, OutputFormat};
15use crate::paths::AppPaths;
16use crate::storage::connection::open_rw;
17use rusqlite::params;
18use serde::Serialize;
19
20#[derive(clap::Args)]
21#[command(after_long_help = "EXAMPLES:\n \
22 # Rename a single edge from 'mentions' to 'related'\n \
23 sqlite-graphrag reclassify-relation --source tokio --target axum \\\n \
24 --from-relation mentions --to-relation related\n\n \
25 # Rename every 'mentions' edge in the namespace to 'related'\n \
26 sqlite-graphrag reclassify-relation \\\n \
27 --from-relation mentions --to-relation related --batch\n\n \
28 # Dry-run to preview what would change\n \
29 sqlite-graphrag reclassify-relation \\\n \
30 --from-relation mentions --to-relation related --batch --dry-run\n\n \
31 # Batch rename only edges whose source is a 'tool' entity\n \
32 sqlite-graphrag reclassify-relation \\\n \
33 --from-relation uses --to-relation depends_on --batch \\\n \
34 --filter-source-type tool\n\n\
35NOTE:\n \
36 Single mode requires --source, --target and --from-relation.\n \
37 Batch mode requires --from-relation, --to-relation and --batch.\n \
38 --filter-source-type and --filter-target-type are only effective in batch mode.")]
39pub struct ReclassifyRelationArgs {
40 #[arg(long, conflicts_with = "batch", value_name = "ENTITY")]
42 pub source: Option<String>,
43 #[arg(long, conflicts_with = "batch", value_name = "ENTITY")]
45 pub target: Option<String>,
46 #[arg(long, value_parser = crate::parsers::parse_relation, value_name = "RELATION")]
48 pub from_relation: String,
49 #[arg(long, value_parser = crate::parsers::parse_relation, value_name = "RELATION")]
51 pub to_relation: String,
52 #[arg(long, default_value_t = false)]
54 pub batch: bool,
55 #[arg(long, value_enum, value_name = "TYPE", requires = "batch")]
57 pub filter_source_type: Option<EntityType>,
58 #[arg(long, value_enum, value_name = "TYPE", requires = "batch")]
60 pub filter_target_type: Option<EntityType>,
61 #[arg(long, default_value_t = false)]
63 pub dry_run: bool,
64 #[arg(long)]
65 pub namespace: Option<String>,
66 #[arg(long, value_enum, default_value = "json")]
67 pub format: OutputFormat,
68 #[arg(long, hide = true, help = "No-op; JSON is always emitted on stdout")]
69 pub json: bool,
70 #[arg(long, env = "SQLITE_GRAPHRAG_DB_PATH")]
71 pub db: Option<String>,
72}
73
/// Result payload of a `reclassify-relation` run, serialized as JSON or rendered
/// as a one-line text summary.
///
/// Field order here is the key order of the emitted JSON object.
#[derive(Serialize)]
struct ReclassifyRelationResponse {
    /// `"reclassified"` after a committed rename, `"dry_run"` when nothing was written.
    action: String,
    /// Relation label that was (or would be) renamed.
    from_relation: String,
    /// Relation label the edges are renamed to.
    to_relation: String,
    /// Edges renamed; for a dry run, the number of edges currently matching.
    count: usize,
    /// Intended to count edges deleted instead of renamed because `UPDATE OR IGNORE`
    /// skipped them (presumably a uniqueness conflict with an existing edge carrying
    /// `to_relation` — depends on the table's constraints). Always 0 for a dry run.
    merged_duplicates: usize,
    /// Namespace the command operated in.
    namespace: String,
    /// Milliseconds elapsed since the command started.
    elapsed_ms: u64,
}
87
88pub fn run(args: ReclassifyRelationArgs) -> Result<(), AppError> {
89 let inicio = std::time::Instant::now();
90 let namespace = crate::namespace::resolve_namespace(args.namespace.as_deref())?;
91 let paths = AppPaths::resolve(args.db.as_deref())?;
92
93 crate::storage::connection::ensure_db_ready(&paths)?;
94
95 crate::parsers::warn_if_non_canonical(&args.from_relation);
97 crate::parsers::warn_if_non_canonical(&args.to_relation);
98
99 if args.from_relation == args.to_relation {
101 return Err(AppError::Validation(
102 "--from-relation and --to-relation must be different".to_string(),
103 ));
104 }
105
106 let mut conn = open_rw(&paths.db)?;
107
108 if args.batch {
109 run_batch(args, inicio, namespace, &mut conn)
110 } else {
111 run_single(args, inicio, namespace, &mut conn)
112 }
113}
114
115fn run_single(
120 args: ReclassifyRelationArgs,
121 inicio: std::time::Instant,
122 namespace: String,
123 conn: &mut rusqlite::Connection,
124) -> Result<(), AppError> {
125 let source_name = args.source.as_deref().ok_or_else(|| {
126 AppError::Validation(
127 "--source is required in single mode (omit --batch for single-edge rename)".to_string(),
128 )
129 })?;
130 let target_name = args
131 .target
132 .as_deref()
133 .ok_or_else(|| AppError::Validation("--target is required in single mode".to_string()))?;
134
135 let source_name_norm = crate::parsers::normalize_entity_name(source_name);
138 let target_name_norm = crate::parsers::normalize_entity_name(target_name);
139 let source_id: i64 = conn
140 .query_row(
141 "SELECT id FROM entities WHERE name = ?1 AND namespace = ?2",
142 params![source_name_norm, namespace],
143 |r| r.get(0),
144 )
145 .map_err(|_| {
146 AppError::NotFound(format!(
147 "source entity '{source_name}' not found in namespace '{namespace}'"
148 ))
149 })?;
150
151 let target_id: i64 = conn
152 .query_row(
153 "SELECT id FROM entities WHERE name = ?1 AND namespace = ?2",
154 params![target_name_norm, namespace],
155 |r| r.get(0),
156 )
157 .map_err(|_| {
158 AppError::NotFound(format!(
159 "target entity '{target_name}' not found in namespace '{namespace}'"
160 ))
161 })?;
162
163 let original_count: i64 = conn.query_row(
165 "SELECT COUNT(*) FROM relationships
166 WHERE source_id = ?1 AND target_id = ?2 AND relation = ?3 AND namespace = ?4",
167 params![source_id, target_id, args.from_relation, namespace],
168 |r| r.get(0),
169 )?;
170
171 if original_count == 0 {
172 return Err(AppError::NotFound(format!(
173 "edge '{source_name}' --[{}]--> '{target_name}' not found in namespace '{namespace}'",
174 args.from_relation
175 )));
176 }
177
178 if args.dry_run {
179 emit_response(
180 &args,
181 "dry_run",
182 original_count as usize,
183 0,
184 namespace,
185 inicio,
186 )?;
187 return Ok(());
188 }
189
190 let tx = conn.transaction_with_behavior(rusqlite::TransactionBehavior::Immediate)?;
191
192 let updated = tx.execute(
193 "UPDATE OR IGNORE relationships
194 SET relation = ?1
195 WHERE source_id = ?2 AND target_id = ?3 AND relation = ?4 AND namespace = ?5",
196 params![
197 args.to_relation,
198 source_id,
199 target_id,
200 args.from_relation,
201 namespace
202 ],
203 )?;
204
205 let deleted = tx.execute(
207 "DELETE FROM relationships
208 WHERE source_id = ?1 AND target_id = ?2 AND relation = ?3 AND namespace = ?4",
209 params![source_id, target_id, args.from_relation, namespace],
210 )?;
211
212 tx.commit()?;
213
214 conn.execute_batch("PRAGMA wal_checkpoint(TRUNCATE);")?;
215
216 let merged = (original_count as usize).saturating_sub(updated + deleted);
217 emit_response(&args, "reclassified", updated, merged, namespace, inicio)
218}
219
220fn run_batch(
225 args: ReclassifyRelationArgs,
226 inicio: std::time::Instant,
227 namespace: String,
228 conn: &mut rusqlite::Connection,
229) -> Result<(), AppError> {
230 let source_filter = args
233 .filter_source_type
234 .map(|t| format!(" AND src.type = '{}'", t.as_str()))
235 .unwrap_or_default();
236 let target_filter = args
237 .filter_target_type
238 .map(|t| format!(" AND tgt.type = '{}'", t.as_str()))
239 .unwrap_or_default();
240 let has_filters = !source_filter.is_empty() || !target_filter.is_empty();
241
242 let original_count: i64 = if has_filters {
244 conn.query_row(
245 &format!(
246 "SELECT COUNT(*) FROM relationships r
247 JOIN entities src ON src.id = r.source_id
248 JOIN entities tgt ON tgt.id = r.target_id
249 WHERE r.relation = ?1 AND r.namespace = ?2{source_filter}{target_filter}"
250 ),
251 params![args.from_relation, namespace],
252 |r| r.get(0),
253 )?
254 } else {
255 conn.query_row(
256 "SELECT COUNT(*) FROM relationships
257 WHERE relation = ?1 AND namespace = ?2",
258 params![args.from_relation, namespace],
259 |r| r.get(0),
260 )?
261 };
262
263 if original_count == 0 {
264 tracing::warn!(
265 from_relation = %args.from_relation,
266 namespace = %namespace,
267 "reclassify-relation batch matched zero edges — verify --from-relation value"
268 );
269 }
270
271 if args.dry_run {
272 emit_response(
273 &args,
274 "dry_run",
275 original_count as usize,
276 0,
277 namespace,
278 inicio,
279 )?;
280 return Ok(());
281 }
282
283 let tx = conn.transaction_with_behavior(rusqlite::TransactionBehavior::Immediate)?;
284
285 let updated = if has_filters {
286 let ids: Vec<i64> = {
288 let mut stmt = tx.prepare(&format!(
289 "SELECT r.id FROM relationships r
290 JOIN entities src ON src.id = r.source_id
291 JOIN entities tgt ON tgt.id = r.target_id
292 WHERE r.relation = ?1 AND r.namespace = ?2{source_filter}{target_filter}"
293 ))?;
294 let collected: Vec<i64> = stmt
295 .query_map(params![args.from_relation, namespace], |r| r.get(0))?
296 .collect::<Result<Vec<_>, _>>()?;
297 collected
298 };
299
300 let mut moved: usize = 0;
301 for id in &ids {
302 let n = tx.execute(
303 "UPDATE OR IGNORE relationships
304 SET relation = ?1
305 WHERE id = ?2",
306 params![args.to_relation, id],
307 )?;
308 moved += n;
309 }
310 moved
311 } else {
312 tx.execute(
313 "UPDATE OR IGNORE relationships
314 SET relation = ?1
315 WHERE relation = ?2 AND namespace = ?3",
316 params![args.to_relation, args.from_relation, namespace],
317 )?
318 };
319
320 let deleted = if has_filters {
322 tx.execute(
323 &format!(
324 "DELETE FROM relationships WHERE id IN (
325 SELECT r.id FROM relationships r
326 JOIN entities src ON src.id = r.source_id
327 JOIN entities tgt ON tgt.id = r.target_id
328 WHERE r.relation = ?1 AND r.namespace = ?2{source_filter}{target_filter}
329 )"
330 ),
331 params![args.from_relation, namespace],
332 )?
333 } else {
334 tx.execute(
335 "DELETE FROM relationships WHERE relation = ?1 AND namespace = ?2",
336 params![args.from_relation, namespace],
337 )?
338 };
339
340 tx.commit()?;
341
342 conn.execute_batch("ANALYZE relationships;")?;
343 conn.execute_batch("PRAGMA wal_checkpoint(TRUNCATE);")?;
344
345 let merged = (original_count as usize).saturating_sub(updated + deleted);
346 emit_response(&args, "reclassified", updated, merged, namespace, inicio)
347}
348
349fn emit_response(
354 args: &ReclassifyRelationArgs,
355 action: &str,
356 count: usize,
357 merged_duplicates: usize,
358 namespace: String,
359 inicio: std::time::Instant,
360) -> Result<(), AppError> {
361 let response = ReclassifyRelationResponse {
362 action: action.to_string(),
363 from_relation: args.from_relation.clone(),
364 to_relation: args.to_relation.clone(),
365 count,
366 merged_duplicates,
367 namespace: namespace.clone(),
368 elapsed_ms: inicio.elapsed().as_millis() as u64,
369 };
370
371 match args.format {
372 OutputFormat::Json => output::emit_json(&response)?,
373 OutputFormat::Text | OutputFormat::Markdown => {
374 output::emit_text(&format!(
375 "{action}: {count} edges '{}' → '{}' [{namespace}] (duplicates merged: {merged_duplicates})",
376 args.from_relation, args.to_relation
377 ));
378 }
379 }
380 Ok(())
381}
382
#[cfg(test)]
mod tests {
    use super::*;
    use clap::{Parser, ValueEnum};

    /// Throw-away top-level parser, so the clap constraints declared on
    /// `ReclassifyRelationArgs` (`conflicts_with`, `requires`) can be exercised.
    #[derive(Parser)]
    struct TestCli {
        #[command(flatten)]
        args: ReclassifyRelationArgs,
    }

    /// Parses `argv` (without the binary name) into `ReclassifyRelationArgs`.
    fn parse(argv: &[&str]) -> Result<ReclassifyRelationArgs, clap::Error> {
        let full = std::iter::once("reclassify-relation").chain(argv.iter().copied());
        TestCli::try_parse_from(full).map(|cli| cli.args)
    }

    fn make_response(action: &str, count: usize, merged: usize) -> ReclassifyRelationResponse {
        ReclassifyRelationResponse {
            action: action.to_string(),
            from_relation: "mentions".to_string(),
            to_relation: "related".to_string(),
            count,
            merged_duplicates: merged,
            namespace: "global".to_string(),
            elapsed_ms: 1,
        }
    }

    #[test]
    fn response_serializes_all_fields() {
        let resp = make_response("reclassified", 5, 0);
        let json = serde_json::to_value(&resp).expect("serialization failed");
        assert_eq!(json["action"], "reclassified");
        assert_eq!(json["from_relation"], "mentions");
        assert_eq!(json["to_relation"], "related");
        assert_eq!(json["count"], 5);
        assert_eq!(json["merged_duplicates"], 0);
        assert_eq!(json["namespace"], "global");
        assert!(json["elapsed_ms"].is_number());
    }

    #[test]
    fn response_action_dry_run() {
        let resp = make_response("dry_run", 10, 0);
        let json = serde_json::to_value(&resp).expect("serialization failed");
        assert_eq!(json["action"], "dry_run");
        assert_eq!(json["count"], 10);
        assert_eq!(json["merged_duplicates"], 0);
    }

    #[test]
    fn response_merged_duplicates_nonzero() {
        let resp = make_response("reclassified", 7, 3);
        let json = serde_json::to_value(&resp).expect("serialization failed");
        assert_eq!(json["count"], 7);
        assert_eq!(json["merged_duplicates"], 3);
    }

    #[test]
    fn response_count_zero_when_nothing_matched() {
        let resp = make_response("reclassified", 0, 0);
        let json = serde_json::to_value(&resp).expect("serialization failed");
        assert_eq!(json["count"], 0);
        assert_eq!(json["merged_duplicates"], 0);
    }

    #[test]
    fn response_action_values_exhaustive() {
        for action in &["reclassified", "dry_run"] {
            let resp = make_response(action, 1, 0);
            let json = serde_json::to_value(&resp).expect("serialization");
            assert_eq!(json["action"], *action);
        }
    }

    #[test]
    fn response_from_and_to_relation_present() {
        let resp = ReclassifyRelationResponse {
            action: "reclassified".to_string(),
            from_relation: "uses".to_string(),
            to_relation: "depends_on".to_string(),
            count: 3,
            merged_duplicates: 1,
            namespace: "my-project".to_string(),
            elapsed_ms: 5,
        };
        let json = serde_json::to_value(&resp).expect("serialization failed");
        assert_eq!(json["from_relation"], "uses");
        assert_eq!(json["to_relation"], "depends_on");
    }

    #[test]
    fn batch_mode_parses_without_source_or_target() {
        let args = parse(&[
            "--from-relation",
            "mentions",
            "--to-relation",
            "related",
            "--batch",
        ])
        .expect("batch invocation should parse");
        assert!(args.batch);
        assert!(!args.dry_run);
        assert!(args.source.is_none());
        assert!(args.target.is_none());
    }

    #[test]
    fn source_conflicts_with_batch() {
        let kind = parse(&[
            "--source",
            "tokio",
            "--from-relation",
            "mentions",
            "--to-relation",
            "related",
            "--batch",
        ])
        .err()
        .map(|e| e.kind());
        assert_eq!(kind, Some(clap::error::ErrorKind::ArgumentConflict));
    }

    #[test]
    fn type_filter_requires_batch() {
        // Use a real variant name so the failure can only come from the missing --batch.
        let any_type = EntityType::value_variants()
            .first()
            .and_then(|v| v.to_possible_value())
            .expect("EntityType has at least one variant")
            .get_name()
            .to_owned();
        let kind = parse(&[
            "--from-relation",
            "mentions",
            "--to-relation",
            "related",
            "--filter-source-type",
            &any_type,
        ])
        .err()
        .map(|e| e.kind());
        assert_eq!(kind, Some(clap::error::ErrorKind::MissingRequiredArgument));
    }
}