Skip to main content

sqlite_graphrag/commands/
merge_entities.rs

1//! Handler for the `merge-entities` CLI subcommand (GAP-19).
2//!
3//! Merges two or more source entities into a single target entity by:
4//!   1. Retargeting all relationships pointing at any source to the target.
5//!   2. Deduplicating relationships that become identical after the merge
6//!      (same source_id + target_id + relation).
7//!   3. Retargeting memory_entities bindings.
8//!   4. Deleting the now-empty source entity rows.
9
10use crate::errors::AppError;
11use crate::i18n::errors_msg;
12use crate::output::{self, OutputFormat};
13use crate::paths::AppPaths;
14use crate::storage::connection::open_rw;
15use crate::storage::entities;
16use rusqlite::params;
17use serde::Serialize;
18
// Command-line arguments accepted by the `merge-entities` subcommand; consumed by `run`.
//
// NOTE: the comments added to this struct are plain `//` comments on purpose.
// clap's derive turns `///` doc comments into user-visible help text, so a `///`
// here would change the CLI's `--help` output.
#[derive(clap::Args)]
#[command(after_long_help = "EXAMPLES:\n  \
    # Merge two source entities into a target\n  \
    sqlite-graphrag merge-entities --names auth,authentication --into auth-service\n\n  \
    # Merge three sources into one target across a namespace\n  \
    sqlite-graphrag merge-entities --names svc-a,svc-b,old-svc --into canonical-service --namespace my-project\n\n\
NOTE:\n  \
    --names is a comma-separated list of source entity names.\n  \
    --into is the target entity name and must already exist.\n  \
    Source entities are deleted after the merge; the target is preserved.\n  \
    Duplicate relationships (same endpoints + relation) are removed automatically.\n  \
    Run `sqlite-graphrag cleanup-orphans` afterwards if sources had no other links.")]
pub struct MergeEntitiesArgs {
    /// Comma-separated list of source entity names to merge into the target.
    // `value_delimiter = ','` splits one `--names a,b,c` value into separate entries.
    // The flag is not marked required, so `run` validates that the list is non-empty.
    #[arg(long, value_delimiter = ',', value_name = "NAMES")]
    pub names: Vec<String>,
    /// Target entity name. Must already exist. All source relationships are redirected here.
    #[arg(long, value_name = "TARGET")]
    pub into: String,
    // Optional namespace; `run` hands it to `crate::namespace::resolve_namespace`,
    // which decides what is used when the flag is absent.
    #[arg(long)]
    pub namespace: Option<String>,
    // Output format for the result summary; defaults to `json`.
    #[arg(long, value_enum, default_value = "json")]
    pub format: OutputFormat,
    // Hidden flag; `run` never reads it (its help text describes it as a no-op).
    #[arg(long, hide = true, help = "No-op; JSON is always emitted on stdout")]
    pub json: bool,
    // Optional database path, also settable through the `SQLITE_GRAPHRAG_DB_PATH`
    // environment variable; `run` passes it to `AppPaths::resolve`.
    #[arg(long, env = "SQLITE_GRAPHRAG_DB_PATH")]
    pub db: Option<String>,
}
47
/// Summary of a completed merge, emitted by `run` as JSON or condensed into a
/// single text line depending on the requested output format.
///
/// Fields are serialized in declaration order.
#[derive(Serialize)]
struct MergeEntitiesResponse {
    /// Action label; `run` always sets this to `"merged"`.
    action: String,
    /// Names of the source entities reported as merged into the target.
    sources: Vec<String>,
    /// Name of the target entity that received the merge (the `--into` value).
    target: String,
    /// Resolved namespace the merge was performed in.
    namespace: String,
    /// Sum of the row counts reported by the `UPDATE` statements that redirected
    /// relationships from the sources to the target.
    relationships_moved: usize,
    /// Number of rows deleted from the `entities` table.
    entities_removed: usize,
    /// Total execution time in milliseconds from handler start to serialisation.
    elapsed_ms: u64,
}
59
60pub fn run(args: MergeEntitiesArgs) -> Result<(), AppError> {
61    let inicio = std::time::Instant::now();
62
63    if args.names.is_empty() {
64        return Err(AppError::Validation(
65            "--names must contain at least one source entity name".to_string(),
66        ));
67    }
68
69    let namespace = crate::namespace::resolve_namespace(args.namespace.as_deref())?;
70    let paths = AppPaths::resolve(args.db.as_deref())?;
71
72    crate::storage::connection::ensure_db_ready(&paths)?;
73
74    let mut conn = open_rw(&paths.db)?;
75
76    // Resolve target entity ID.
77    let target_id = entities::find_entity_id(&conn, &namespace, &args.into)?
78        .ok_or_else(|| AppError::NotFound(errors_msg::entity_not_found(&args.into, &namespace)))?;
79
80    // Resolve source entity IDs — reject self-referential merge (G21).
81    let mut source_ids: Vec<i64> = Vec::with_capacity(args.names.len());
82    for name in &args.names {
83        if name == &args.into {
84            return Err(AppError::Validation(format!(
85                "source entity '{}' equals target '{}' — self-referential merge is not allowed",
86                name, args.into
87            )));
88        }
89        let id = entities::find_entity_id(&conn, &namespace, name)?
90            .ok_or_else(|| AppError::NotFound(errors_msg::entity_not_found(name, &namespace)))?;
91        if !source_ids.contains(&id) {
92            source_ids.push(id);
93        }
94    }
95
96    if source_ids.is_empty() {
97        return Err(AppError::Validation(
98            "no valid source entities to merge (all names equal the target or were duplicates)"
99                .to_string(),
100        ));
101    }
102
103    let tx = conn.transaction_with_behavior(rusqlite::TransactionBehavior::Immediate)?;
104
105    let mut relationships_moved: usize = 0;
106
107    for &src_id in &source_ids {
108        // Step 1a: redirect source_id, ignoring UNIQUE conflicts.
109        let moved_src = tx.execute(
110            "UPDATE OR IGNORE relationships SET source_id = ?1 WHERE source_id = ?2",
111            params![target_id, src_id],
112        )?;
113        tx.execute(
114            "DELETE FROM relationships WHERE source_id = ?1",
115            params![src_id],
116        )?;
117        // Step 1b: redirect target_id, ignoring UNIQUE conflicts.
118        let moved_tgt = tx.execute(
119            "UPDATE OR IGNORE relationships SET target_id = ?1 WHERE target_id = ?2",
120            params![target_id, src_id],
121        )?;
122        tx.execute(
123            "DELETE FROM relationships WHERE target_id = ?1",
124            params![src_id],
125        )?;
126        relationships_moved += moved_src + moved_tgt;
127    }
128
129    // Step 2: remove self-loops introduced by the redirect (target → target).
130    tx.execute("DELETE FROM relationships WHERE source_id = target_id", [])?;
131
132    // Step 3: deduplicate relationships that now share (source, target, relation).
133    // Safety net — UPDATE OR IGNORE should have handled most duplicates above.
134    tx.execute(
135        "DELETE FROM relationships
136         WHERE id NOT IN (
137             SELECT MIN(id)
138             FROM relationships
139             GROUP BY source_id, target_id, relation
140         )",
141        [],
142    )?;
143
144    // Step 4: retarget memory_entities bindings.
145    // Use UPDATE OR IGNORE to skip conflicts when memory is already bound to
146    // target entity. Then DELETE remaining source rows (the conflicting ones
147    // that UPDATE OR IGNORE skipped). Same pattern as relationships (Step 1).
148    for &src_id in &source_ids {
149        tx.execute(
150            "UPDATE OR IGNORE memory_entities SET entity_id = ?1 WHERE entity_id = ?2",
151            params![target_id, src_id],
152        )?;
153        tx.execute(
154            "DELETE FROM memory_entities WHERE entity_id = ?1",
155            params![src_id],
156        )?;
157    }
158
159    // Step 5: deduplicate memory_entities bindings (same memory + entity).
160    tx.execute(
161        "DELETE FROM memory_entities
162         WHERE rowid NOT IN (
163             SELECT MIN(rowid)
164             FROM memory_entities
165             GROUP BY memory_id, entity_id
166         )",
167        [],
168    )?;
169
170    // Step 6: delete source entities (vec_entities first — no FK CASCADE on vec0).
171    let mut entities_removed: usize = 0;
172    for &src_id in &source_ids {
173        let _ = tx.execute(
174            "DELETE FROM vec_entities WHERE entity_id = ?1",
175            params![src_id],
176        );
177        let removed = tx.execute("DELETE FROM entities WHERE id = ?1", params![src_id])?;
178        entities_removed += removed;
179    }
180
181    // Step 7: recalculate degree for target and all adjacent entities.
182    let adjacent_ids: Vec<i64> = {
183        let mut stmt = tx.prepare(
184            "SELECT DISTINCT CASE WHEN source_id = ?1 THEN target_id ELSE source_id END
185             FROM relationships WHERE source_id = ?1 OR target_id = ?1",
186        )?;
187        let ids: Vec<i64> = stmt
188            .query_map(params![target_id], |r| r.get(0))?
189            .collect::<Result<Vec<_>, _>>()?;
190        ids
191    };
192    entities::recalculate_degree(&tx, target_id)?;
193    for &adj_id in &adjacent_ids {
194        entities::recalculate_degree(&tx, adj_id)?;
195    }
196
197    tx.commit()?;
198
199    conn.execute_batch("PRAGMA wal_checkpoint(TRUNCATE);")?;
200
201    // Build the list of sources that were actually processed (excluding target duplicates).
202    let processed_sources: Vec<String> = args
203        .names
204        .iter()
205        .filter(|n| n.as_str() != args.into.as_str())
206        .cloned()
207        .collect();
208
209    let response = MergeEntitiesResponse {
210        action: "merged".to_string(),
211        sources: processed_sources,
212        target: args.into.clone(),
213        namespace: namespace.clone(),
214        relationships_moved,
215        entities_removed,
216        elapsed_ms: inicio.elapsed().as_millis() as u64,
217    };
218
219    match args.format {
220        OutputFormat::Json => output::emit_json(&response)?,
221        OutputFormat::Text | OutputFormat::Markdown => {
222            output::emit_text(&format!(
223                "merged: {} sources into '{}' (relationships_moved={}, entities_removed={}) [{}]",
224                response.sources.len(),
225                response.target,
226                response.relationships_moved,
227                response.entities_removed,
228                response.namespace
229            ));
230        }
231    }
232
233    Ok(())
234}
235
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a response whose `action` is `"merged"`, taking the remaining
    /// fields from the arguments.
    fn merged(
        sources: &[&str],
        target: &str,
        namespace: &str,
        relationships_moved: usize,
        entities_removed: usize,
        elapsed_ms: u64,
    ) -> MergeEntitiesResponse {
        MergeEntitiesResponse {
            action: "merged".to_string(),
            sources: sources.iter().map(|&s| s.to_owned()).collect(),
            target: target.to_owned(),
            namespace: namespace.to_owned(),
            relationships_moved,
            entities_removed,
            elapsed_ms,
        }
    }

    /// Every field of the response must appear in the serialized JSON.
    #[test]
    fn merge_entities_response_serializes_all_fields() {
        let resp = merged(&["auth", "authentication"], "auth-service", "global", 7, 2, 15);
        let json = serde_json::to_value(&resp).expect("serialization failed");

        assert_eq!(json["action"], "merged");
        assert_eq!(json["target"], "auth-service");
        assert_eq!(json["namespace"], "global");
        assert_eq!(json["relationships_moved"], 7);
        assert_eq!(json["entities_removed"], 2);
        let sources = json["sources"].as_array().expect("must be array");
        assert_eq!(sources.len(), 2);
        assert!(json["elapsed_ms"].is_number());
    }

    /// The action label is carried through unchanged.
    #[test]
    fn merge_entities_response_action_is_merged() {
        let resp = merged(&["src"], "tgt", "ns", 0, 1, 0);
        assert_eq!(resp.action, "merged");
    }

    /// An empty source list serializes as an empty JSON array.
    #[test]
    fn merge_entities_response_empty_sources_serializes() {
        let resp = merged(&[], "target", "global", 0, 0, 1);
        let json = serde_json::to_value(&resp).expect("serialization failed");

        let sources = json["sources"].as_array().expect("must be array");
        assert_eq!(sources.len(), 0);
    }

    /// Zero moved relationships is serialized as the number 0.
    #[test]
    fn merge_entities_response_with_zero_relationships_moved() {
        let resp = merged(&["src-a"], "tgt", "global", 0, 1, 5);
        let json = serde_json::to_value(&resp).expect("serialization failed");

        assert_eq!(json["relationships_moved"], 0);
        assert_eq!(json["entities_removed"], 1);
    }

    /// Several sources are all present in the serialized array.
    #[test]
    fn merge_entities_response_multiple_sources() {
        let resp = merged(&["a", "b", "c"], "canonical", "proj", 12, 3, 42);
        let json = serde_json::to_value(&resp).expect("serialization failed");

        assert_eq!(json["entities_removed"], 3);
        let sources = json["sources"].as_array().unwrap();
        assert_eq!(sources.len(), 3);
    }
}