Skip to main content

sqlite_graphrag/commands/
remember_batch.rs

1//! Handler for the `remember-batch` CLI subcommand (G08).
2//!
3//! Accepts NDJSON via stdin where each line is a memory to persist.
4//! One CLI invocation, one slot, one DB connection — eliminates N-process
5//! contention from parallel `remember` calls.
6
7use crate::errors::AppError;
8use crate::output;
9use crate::paths::AppPaths;
10use crate::storage::connection::open_rw;
11use crate::storage::{entities, memories, versions};
12use serde::{Deserialize, Serialize};
13use std::io::BufRead;
14
// Command-line arguments for the `remember-batch` subcommand.
//
// NOTE: plain `//` comments are used for maintainer notes in this block on
// purpose. clap's derive macros turn `///` doc comments into user-visible help
// text, so the existing `///` lines below are part of the CLI output and must
// only be edited when the help text itself should change.
#[derive(clap::Args)]
#[command(after_long_help = "EXAMPLES:\n  \
    # Pipe NDJSON memories from stdin\n  \
    echo '{\"name\":\"mem-a\",\"type\":\"note\",\"description\":\"a\",\"body\":\"content\"}' | \
    sqlite-graphrag remember-batch --json\n\n  \
    # Atomic batch with --transaction\n  \
    cat memories.ndjson | sqlite-graphrag remember-batch --transaction --json")]
pub struct RememberBatchArgs {
    /// Apply all memories in a single transaction (all-or-nothing).
    #[arg(long)]
    pub transaction: bool,
    /// Stop processing on the first failure.
    // Honoured in both the single-transaction and the per-line modes of `run`.
    #[arg(long)]
    pub fail_fast: bool,
    /// Apply force-merge to all memories (update existing by name).
    // Without this flag an input line whose name already exists is rejected.
    #[arg(long)]
    pub force_merge: bool,
    /// Validate inputs and emit preview events without persisting or embedding.
    #[arg(long)]
    pub dry_run: bool,
    /// Namespace override for all memories.
    // Falls back to the SQLITE_GRAPHRAG_NAMESPACE environment variable.
    #[arg(long, env = "SQLITE_GRAPHRAG_NAMESPACE")]
    pub namespace: Option<String>,
    /// Emit NDJSON output.
    // NOTE(review): `run` in this file never reads this flag and always emits
    // through `output::emit_json` — confirm whether it is consumed elsewhere.
    #[arg(long)]
    pub json: bool,
    /// Database path override.
    // Falls back to the SQLITE_GRAPHRAG_DB_PATH environment variable.
    #[arg(long, env = "SQLITE_GRAPHRAG_DB_PATH")]
    pub db: Option<String>,
}
45
/// One NDJSON input line, deserialized into the pieces of a memory to persist.
///
/// Only `name` is required; every other field falls back to a default when it
/// is absent from the JSON object.
#[derive(Deserialize)]
struct BatchInputLine {
    /// Memory name as supplied by the caller. It is normalized with
    /// `crate::parsers::normalize_entity_name` before being looked up or stored.
    name: String,
    /// Memory type (JSON key `type`); defaults to `"note"` via `default_type`.
    #[serde(default = "default_type")]
    r#type: String,
    /// Free-form description; defaults to the empty string.
    #[serde(default)]
    description: String,
    /// Memory body; defaults to the empty string. This is the text that gets
    /// hashed and embedded when the line is persisted.
    #[serde(default)]
    body: String,
    /// Graph entities to upsert and link to the memory; defaults to none.
    #[serde(default)]
    entities: Vec<crate::storage::entities::NewEntity>,
    /// Relationships to create between entities; defaults to none.
    #[serde(default)]
    relationships: Vec<crate::storage::entities::NewRelationship>,
}
60
/// Serde fallback for the `type` field of an input line: a plain `"note"`.
fn default_type() -> String {
    String::from("note")
}
64
/// Per-line result record, serialized as one JSON object per processed line.
#[derive(Serialize)]
struct BatchItemEvent {
    /// Normalized memory name; left empty on failure events built in this file.
    name: String,
    /// Outcome label. Values produced in this file: `"created"`, `"updated"`,
    /// `"failed"`, and the dry-run previews `"would_create"`, `"would_update"`
    /// and `"would_fail_duplicate"`.
    status: String,
    /// Row id of the affected (or, in dry-run, already existing) memory.
    /// Omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    memory_id: Option<i64>,
    /// Human-readable failure reason. Omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    error: Option<String>,
    /// Zero-based position of the line among the non-blank input lines.
    index: usize,
}
75
/// Final record emitted once after all per-line events.
#[derive(Serialize)]
struct BatchSummary {
    /// Always set to `true` where this struct is built in this file; presumably
    /// a marker that lets consumers tell the summary apart from item events.
    summary: bool,
    /// Number of non-blank input lines read from stdin.
    total: usize,
    /// Number of lines that were processed (or previewed) without error.
    succeeded: usize,
    /// Number of lines that produced a `"failed"` event.
    failed: usize,
    /// Wall-clock time since `run` started, in milliseconds.
    elapsed_ms: u64,
}
84
85pub fn run(
86    args: RememberBatchArgs,
87    llm_backend: crate::cli::LlmBackendChoice,
88    embedding_backend: crate::cli::EmbeddingBackendChoice,
89) -> Result<(), AppError> {
90    let start = std::time::Instant::now();
91    let namespace = crate::namespace::resolve_namespace(args.namespace.as_deref())?;
92    let paths = AppPaths::resolve(args.db.as_deref())?;
93    paths.ensure_dirs()?;
94    crate::storage::connection::ensure_db_ready(&paths)?;
95    let mut conn = open_rw(&paths.db)?;
96
97    let stdin = std::io::stdin();
98    let lines: Vec<String> = stdin
99        .lock()
100        .lines()
101        .map_while(Result::ok)
102        .filter(|l| !l.trim().is_empty())
103        .collect();
104
105    let total = lines.len();
106    let mut succeeded = 0usize;
107    let mut failed = 0usize;
108
109    if args.dry_run {
110        for (idx, line) in lines.iter().enumerate() {
111            match serde_json::from_str::<BatchInputLine>(line) {
112                Ok(input) => {
113                    let normalized_name = crate::parsers::normalize_entity_name(&input.name);
114                    if normalized_name.is_empty() {
115                        failed += 1;
116                        output::emit_json(&BatchItemEvent {
117                            name: String::new(),
118                            status: "failed".to_string(),
119                            memory_id: None,
120                            error: Some(format!("line {idx}: name normalizes to empty string")),
121                            index: idx,
122                        })?;
123                        continue;
124                    }
125                    let existing = memories::find_by_name(&conn, &namespace, &normalized_name)?;
126                    let action = if existing.is_some() {
127                        if args.force_merge {
128                            "would_update"
129                        } else {
130                            "would_fail_duplicate"
131                        }
132                    } else {
133                        "would_create"
134                    };
135                    succeeded += 1;
136                    output::emit_json(&BatchItemEvent {
137                        name: normalized_name,
138                        status: action.to_string(),
139                        memory_id: existing.map(|(id, _, _)| id),
140                        error: None,
141                        index: idx,
142                    })?;
143                }
144                Err(e) => {
145                    failed += 1;
146                    output::emit_json(&BatchItemEvent {
147                        name: String::new(),
148                        status: "failed".to_string(),
149                        memory_id: None,
150                        error: Some(format!("line {idx}: invalid JSON: {e}")),
151                        index: idx,
152                    })?;
153                }
154            }
155        }
156
157        output::emit_json(&BatchSummary {
158            summary: true,
159            total,
160            succeeded,
161            failed,
162            elapsed_ms: start.elapsed().as_millis() as u64,
163        })?;
164        return Ok(());
165    }
166
167    if args.transaction {
168        let tx = conn.transaction_with_behavior(rusqlite::TransactionBehavior::Immediate)?;
169        for (idx, line) in lines.iter().enumerate() {
170            match process_line(
171                &tx,
172                &namespace,
173                line,
174                idx,
175                args.force_merge,
176                &paths,
177                llm_backend,
178                embedding_backend,
179            ) {
180                Ok(event) => {
181                    output::emit_json(&event)?;
182                    succeeded += 1;
183                }
184                Err(e) => {
185                    failed += 1;
186                    output::emit_json(&BatchItemEvent {
187                        name: String::new(),
188                        status: "failed".to_string(),
189                        memory_id: None,
190                        error: Some(format!("{e}")),
191                        index: idx,
192                    })?;
193                    if args.fail_fast {
194                        break;
195                    }
196                }
197            }
198        }
199        if failed == 0 || !args.fail_fast {
200            tx.commit()?;
201        }
202    } else {
203        for (idx, line) in lines.iter().enumerate() {
204            let tx = conn.transaction_with_behavior(rusqlite::TransactionBehavior::Immediate)?;
205            match process_line(
206                &tx,
207                &namespace,
208                line,
209                idx,
210                args.force_merge,
211                &paths,
212                llm_backend,
213                embedding_backend,
214            ) {
215                Ok(event) => {
216                    tx.commit()?;
217                    output::emit_json(&event)?;
218                    succeeded += 1;
219                }
220                Err(e) => {
221                    drop(tx);
222                    failed += 1;
223                    output::emit_json(&BatchItemEvent {
224                        name: String::new(),
225                        status: "failed".to_string(),
226                        memory_id: None,
227                        error: Some(format!("{e}")),
228                        index: idx,
229                    })?;
230                    if args.fail_fast {
231                        break;
232                    }
233                }
234            }
235        }
236    }
237
238    output::emit_json(&BatchSummary {
239        summary: true,
240        total,
241        succeeded,
242        failed,
243        elapsed_ms: start.elapsed().as_millis() as u64,
244    })?;
245
246    Ok(())
247}
248
249#[allow(clippy::too_many_arguments)]
250fn process_line(
251    tx: &rusqlite::Transaction<'_>,
252    namespace: &str,
253    line: &str,
254    index: usize,
255    force_merge: bool,
256    paths: &AppPaths,
257    llm_backend: crate::cli::LlmBackendChoice,
258    embedding_backend: crate::cli::EmbeddingBackendChoice,
259) -> Result<BatchItemEvent, AppError> {
260    let input: BatchInputLine = serde_json::from_str(line)
261        .map_err(|e| AppError::Validation(format!("line {index}: invalid JSON: {e}")))?;
262
263    let normalized_name = crate::parsers::normalize_entity_name(&input.name);
264    if normalized_name.is_empty() {
265        return Err(AppError::Validation(format!(
266            "line {index}: name normalizes to empty string"
267        )));
268    }
269
270    let body_hash = blake3::hash(input.body.as_bytes()).to_hex().to_string();
271
272    let existing = memories::find_by_name(tx, namespace, &normalized_name)?;
273
274    let (memory_id, batch_action) = if let Some((existing_id, _updated_at, _version)) = existing {
275        if !force_merge {
276            return Err(AppError::Duplicate(format!(
277                "memory '{normalized_name}' already exists; use --force-merge to update"
278            )));
279        }
280        let snippet: String = input.body.chars().take(200).collect();
281        // Capture old FTS values BEFORE the UPDATE for sync_fts_after_update
282        // (trg_fts_au trigger is absent by design due to sqlite-vec conflict).
283        let (old_fts_name, old_fts_desc, old_fts_body): (String, String, String) = tx.query_row(
284            "SELECT name, description, body FROM memories WHERE id = ?1",
285            rusqlite::params![existing_id],
286            |r| Ok((r.get(0)?, r.get(1)?, r.get(2)?)),
287        )?;
288        memories::update(
289            tx,
290            existing_id,
291            &memories::NewMemory {
292                namespace: namespace.to_string(),
293                name: normalized_name.clone(),
294                memory_type: input.r#type.clone(),
295                description: input.description.clone(),
296                body: input.body.clone(),
297                body_hash,
298                session_id: None,
299                source: "agent".to_string(),
300                metadata: serde_json::json!({}),
301            },
302            None,
303        )?;
304        memories::sync_fts_after_update(
305            tx,
306            existing_id,
307            &old_fts_name,
308            &old_fts_desc,
309            &old_fts_body,
310            &normalized_name,
311            &input.description,
312            &input.body,
313        )?;
314        let next_v = versions::next_version(tx, existing_id)?;
315        versions::insert_version(
316            tx,
317            existing_id,
318            next_v,
319            &normalized_name,
320            &input.r#type,
321            &input.description,
322            &input.body,
323            "{}",
324            None,
325            "edit",
326        )?;
327
328        let skip_embed = crate::embedder::should_skip_embedding_on_failure();
329        match crate::embedder::embed_passage_with_embedding_choice(
330            &paths.models,
331            &input.body,
332            embedding_backend,
333            llm_backend,
334        ) {
335            Ok((embedding, _backend)) => {
336                memories::upsert_vec(
337                    tx,
338                    existing_id,
339                    namespace,
340                    &input.r#type,
341                    &embedding,
342                    &normalized_name,
343                    &snippet,
344                )?;
345            }
346            Err(AppError::Validation(msg)) => return Err(AppError::Validation(msg)),
347            Err(e) if skip_embed => {
348                tracing::warn!(error = %e, "remember-batch: embedding failed; --skip-embedding-on-failure active, persisting without embedding");
349            }
350            Err(e) => return Err(e),
351        }
352        (existing_id, "updated")
353    } else {
354        let new_mem = memories::NewMemory {
355            namespace: namespace.to_string(),
356            name: normalized_name.clone(),
357            memory_type: input.r#type.clone(),
358            description: input.description.clone(),
359            body: input.body.clone(),
360            body_hash,
361            session_id: None,
362            source: "agent".to_string(),
363            metadata: serde_json::json!({}),
364        };
365        let id = memories::insert(tx, &new_mem)?;
366        versions::insert_version(
367            tx,
368            id,
369            1,
370            &normalized_name,
371            &input.r#type,
372            &input.description,
373            &input.body,
374            "{}",
375            None,
376            "create",
377        )?;
378
379        let snippet: String = input.body.chars().take(200).collect();
380        let skip_embed = crate::embedder::should_skip_embedding_on_failure();
381        match crate::embedder::embed_passage_with_embedding_choice(
382            &paths.models,
383            &input.body,
384            embedding_backend,
385            llm_backend,
386        ) {
387            Ok((embedding, _backend)) => {
388                memories::upsert_vec(
389                    tx,
390                    id,
391                    namespace,
392                    &input.r#type,
393                    &embedding,
394                    &normalized_name,
395                    &snippet,
396                )?;
397            }
398            Err(AppError::Validation(msg)) => return Err(AppError::Validation(msg)),
399            Err(e) if skip_embed => {
400                tracing::warn!(error = %e, "remember-batch: embedding failed; --skip-embedding-on-failure active, persisting without embedding");
401            }
402            Err(e) => return Err(e),
403        }
404        (id, "created")
405    };
406
407    // Persist graph entities and relationships if provided
408    for entity in &input.entities {
409        let entity_id = entities::upsert_entity(tx, namespace, entity)?;
410        let entity_text = match &entity.description {
411            Some(desc) => format!("{} {}", entity.name, desc),
412            None => entity.name.clone(),
413        };
414        let skip_embed = crate::embedder::should_skip_embedding_on_failure();
415        match crate::embedder::embed_entity_texts_cached(
416            &paths.models,
417            std::slice::from_ref(&entity_text),
418            1,
419            embedding_backend,
420            llm_backend,
421        ) {
422            Ok((entity_embedding_vec, _stats)) => {
423                if let Some(entity_embedding) = entity_embedding_vec.into_iter().next() {
424                    entities::upsert_entity_vec(
425                        tx,
426                        entity_id,
427                        namespace,
428                        entity.entity_type,
429                        &entity_embedding,
430                        &entity.name,
431                    )?;
432                }
433            }
434            Err(e) if skip_embed => {
435                tracing::warn!(error = %e, "remember-batch: entity embedding failed; --skip-embedding-on-failure active");
436            }
437            Err(e) => return Err(e),
438        }
439        entities::link_memory_entity(tx, memory_id, entity_id)?;
440    }
441
442    for rel in &input.relationships {
443        let src_name = crate::parsers::normalize_entity_name(&rel.source);
444        let tgt_name = crate::parsers::normalize_entity_name(&rel.target);
445        if let (Some(src_id), Some(tgt_id)) = (
446            entities::find_entity_id(tx, namespace, &src_name)?,
447            entities::find_entity_id(tx, namespace, &tgt_name)?,
448        ) {
449            entities::create_or_fetch_relationship(
450                tx,
451                namespace,
452                src_id,
453                tgt_id,
454                &rel.relation,
455                rel.strength,
456                rel.description.as_deref(),
457            )?;
458        }
459    }
460
461    Ok(BatchItemEvent {
462        name: normalized_name,
463        status: batch_action.to_string(),
464        memory_id: Some(memory_id),
465        error: None,
466        index,
467    })
468}