Skip to main content

sqlite_graphrag/commands/
remember_batch.rs

1//! Handler for the `remember-batch` CLI subcommand (G08).
2//!
3//! Accepts NDJSON via stdin where each line is a memory to persist.
4//! One CLI invocation, one slot, one DB connection — eliminates N-process
5//! contention from parallel `remember` calls.
6
7use crate::errors::AppError;
8use crate::output;
9use crate::paths::AppPaths;
10use crate::storage::connection::open_rw;
11use crate::storage::{entities, memories, versions};
12use serde::{Deserialize, Serialize};
13use std::io::BufRead;
14
// CLI arguments for the `remember-batch` subcommand.
//
// The `///` comments on the fields are consumed by clap's derive macro as
// `--help` text, so they are part of the user-facing output; explanatory notes
// for maintainers are therefore written as plain `//` comments.
#[derive(clap::Args)]
#[command(after_long_help = "EXAMPLES:\n  \
    # Pipe NDJSON memories from stdin\n  \
    echo '{\"name\":\"mem-a\",\"type\":\"note\",\"description\":\"a\",\"body\":\"content\"}' | \
    sqlite-graphrag remember-batch --json\n\n  \
    # Atomic batch with --transaction\n  \
    cat memories.ndjson | sqlite-graphrag remember-batch --transaction --json")]
pub struct RememberBatchArgs {
    // Selects between one transaction for the whole batch and one transaction
    // per input line in `run`.
    /// Apply all memories in a single transaction (all-or-nothing).
    #[arg(long)]
    pub transaction: bool,
    /// Stop processing on the first failure.
    #[arg(long)]
    pub fail_fast: bool,
    // When unset, a line whose normalized name already exists is rejected as a
    // duplicate by `process_line`.
    /// Apply force-merge to all memories (update existing by name).
    #[arg(long)]
    pub force_merge: bool,
    /// Validate inputs and emit preview events without persisting or embedding.
    #[arg(long)]
    pub dry_run: bool,
    // Falls back to the SQLITE_GRAPHRAG_NAMESPACE environment variable.
    /// Namespace override for all memories.
    #[arg(long, env = "SQLITE_GRAPHRAG_NAMESPACE")]
    pub namespace: Option<String>,
    // NOTE(review): `run` in this file does not read this flag; events are
    // always emitted through `output::emit_json`.
    /// Emit NDJSON output.
    #[arg(long)]
    pub json: bool,
    // Falls back to the SQLITE_GRAPHRAG_DB_PATH environment variable.
    /// Database path override.
    #[arg(long, env = "SQLITE_GRAPHRAG_DB_PATH")]
    pub db: Option<String>,
    // NOTE(review): `run` in this file does not read this value; it is parsed
    // and range-checked (1..=32, default 4) by clap only.
    /// GAP-SG-35: maximum simultaneous LLM embedding subprocesses, accepted for
    /// parity with `remember`/`edit`/`ingest`/`enrich` so agents that append
    /// `--llm-parallelism` to every invocation never hit a clap error. The
    /// batch loop embeds one passage per item serially; this value bounds the
    /// embedding fan-out width where the backend supports it (clamp [1, 32]).
    #[arg(long, default_value_t = 4, value_name = "N",
          value_parser = clap::value_parser!(u64).range(1..=32))]
    pub llm_parallelism: u64,
}
53
/// One NDJSON input line describing a memory to persist.
///
/// Only `name` is mandatory; every other field falls back to a default when
/// it is absent from the JSON object.
#[derive(Deserialize)]
struct BatchInputLine {
    /// Memory name as supplied by the caller; it is normalized before use.
    name: String,
    /// Memory type; defaults to `"note"` (see `default_type`) when omitted.
    #[serde(default = "default_type")]
    r#type: String,
    /// Free-form description; defaults to the empty string.
    #[serde(default)]
    description: String,
    /// Memory body; defaults to the empty string.
    #[serde(default)]
    body: String,
    /// Graph entities to upsert and link to the memory; defaults to none.
    #[serde(default)]
    entities: Vec<crate::storage::entities::NewEntity>,
    /// Relationships to create between entities; defaults to none.
    #[serde(default)]
    relationships: Vec<crate::storage::entities::NewRelationship>,
}
68
/// Serde default for [`BatchInputLine`]'s `type` field: the `"note"` memory type.
fn default_type() -> String {
    String::from("note")
}
72
/// Per-line result event serialized to the output stream.
#[derive(Serialize)]
struct BatchItemEvent {
    /// Normalized memory name; empty when the line failed.
    name: String,
    /// Outcome label. Values produced in this file: `"created"`, `"updated"`,
    /// `"failed"`, and the dry-run previews `"would_create"`,
    /// `"would_update"` and `"would_fail_duplicate"`.
    status: String,
    /// Id of the affected (or, in dry-run, already existing) memory; omitted
    /// from the serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    memory_id: Option<i64>,
    /// Error message for failed lines; omitted from the output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    error: Option<String>,
    /// Zero-based position of the line among the non-blank input lines.
    index: usize,
}
83
/// Final event emitted once per invocation with the aggregated counters.
#[derive(Serialize)]
struct BatchSummary {
    /// Marker distinguishing the summary from item events; `run` always sets
    /// it to `true`.
    summary: bool,
    /// Number of non-blank input lines read from stdin.
    total: usize,
    /// Number of lines that produced a non-failure event.
    succeeded: usize,
    /// Number of lines that produced a `"failed"` event.
    failed: usize,
    /// Wall-clock time since `run` started, in milliseconds.
    elapsed_ms: u64,
}
92
93pub fn run(
94    args: RememberBatchArgs,
95    llm_backend: crate::cli::LlmBackendChoice,
96    embedding_backend: crate::cli::EmbeddingBackendChoice,
97) -> Result<(), AppError> {
98    let start = std::time::Instant::now();
99    let namespace = crate::namespace::resolve_namespace(args.namespace.as_deref())?;
100    let paths = AppPaths::resolve(args.db.as_deref())?;
101    paths.ensure_dirs()?;
102    crate::storage::connection::ensure_db_ready(&paths)?;
103    let mut conn = open_rw(&paths.db)?;
104
105    let stdin = std::io::stdin();
106    let lines: Vec<String> = stdin
107        .lock()
108        .lines()
109        .map_while(Result::ok)
110        .filter(|l| !l.trim().is_empty())
111        .collect();
112
113    let total = lines.len();
114    let mut succeeded = 0usize;
115    let mut failed = 0usize;
116
117    if args.dry_run {
118        for (idx, line) in lines.iter().enumerate() {
119            match serde_json::from_str::<BatchInputLine>(line) {
120                Ok(input) => {
121                    let normalized_name = crate::parsers::normalize_entity_name(&input.name);
122                    if normalized_name.is_empty() {
123                        failed += 1;
124                        output::emit_json(&BatchItemEvent {
125                            name: String::new(),
126                            status: "failed".to_string(),
127                            memory_id: None,
128                            error: Some(format!("line {idx}: name normalizes to empty string")),
129                            index: idx,
130                        })?;
131                        continue;
132                    }
133                    let existing = memories::find_by_name(&conn, &namespace, &normalized_name)?;
134                    let action = if existing.is_some() {
135                        if args.force_merge {
136                            "would_update"
137                        } else {
138                            "would_fail_duplicate"
139                        }
140                    } else {
141                        "would_create"
142                    };
143                    succeeded += 1;
144                    output::emit_json(&BatchItemEvent {
145                        name: normalized_name,
146                        status: action.to_string(),
147                        memory_id: existing.map(|(id, _, _)| id),
148                        error: None,
149                        index: idx,
150                    })?;
151                }
152                Err(e) => {
153                    failed += 1;
154                    output::emit_json(&BatchItemEvent {
155                        name: String::new(),
156                        status: "failed".to_string(),
157                        memory_id: None,
158                        error: Some(format!("line {idx}: invalid JSON: {e}")),
159                        index: idx,
160                    })?;
161                }
162            }
163        }
164
165        output::emit_json(&BatchSummary {
166            summary: true,
167            total,
168            succeeded,
169            failed,
170            elapsed_ms: start.elapsed().as_millis() as u64,
171        })?;
172        return Ok(());
173    }
174
175    if args.transaction {
176        let tx = conn.transaction_with_behavior(rusqlite::TransactionBehavior::Immediate)?;
177        for (idx, line) in lines.iter().enumerate() {
178            match process_line(
179                &tx,
180                &namespace,
181                line,
182                idx,
183                args.force_merge,
184                &paths,
185                llm_backend,
186                embedding_backend,
187            ) {
188                Ok(event) => {
189                    output::emit_json(&event)?;
190                    succeeded += 1;
191                }
192                Err(e) => {
193                    failed += 1;
194                    output::emit_json(&BatchItemEvent {
195                        name: String::new(),
196                        status: "failed".to_string(),
197                        memory_id: None,
198                        error: Some(format!("{e}")),
199                        index: idx,
200                    })?;
201                    if args.fail_fast {
202                        break;
203                    }
204                }
205            }
206        }
207        if failed == 0 || !args.fail_fast {
208            tx.commit()?;
209        }
210    } else {
211        for (idx, line) in lines.iter().enumerate() {
212            let tx = conn.transaction_with_behavior(rusqlite::TransactionBehavior::Immediate)?;
213            match process_line(
214                &tx,
215                &namespace,
216                line,
217                idx,
218                args.force_merge,
219                &paths,
220                llm_backend,
221                embedding_backend,
222            ) {
223                Ok(event) => {
224                    tx.commit()?;
225                    output::emit_json(&event)?;
226                    succeeded += 1;
227                }
228                Err(e) => {
229                    drop(tx);
230                    failed += 1;
231                    output::emit_json(&BatchItemEvent {
232                        name: String::new(),
233                        status: "failed".to_string(),
234                        memory_id: None,
235                        error: Some(format!("{e}")),
236                        index: idx,
237                    })?;
238                    if args.fail_fast {
239                        break;
240                    }
241                }
242            }
243        }
244    }
245
246    output::emit_json(&BatchSummary {
247        summary: true,
248        total,
249        succeeded,
250        failed,
251        elapsed_ms: start.elapsed().as_millis() as u64,
252    })?;
253
254    Ok(())
255}
256
257#[allow(clippy::too_many_arguments)]
258fn process_line(
259    tx: &rusqlite::Transaction<'_>,
260    namespace: &str,
261    line: &str,
262    index: usize,
263    force_merge: bool,
264    paths: &AppPaths,
265    llm_backend: crate::cli::LlmBackendChoice,
266    embedding_backend: crate::cli::EmbeddingBackendChoice,
267) -> Result<BatchItemEvent, AppError> {
268    let input: BatchInputLine = serde_json::from_str(line)
269        .map_err(|e| AppError::Validation(format!("line {index}: invalid JSON: {e}")))?;
270
271    let normalized_name = crate::parsers::normalize_entity_name(&input.name);
272    if normalized_name.is_empty() {
273        return Err(AppError::Validation(format!(
274            "line {index}: name normalizes to empty string"
275        )));
276    }
277
278    let body_hash = blake3::hash(input.body.as_bytes()).to_hex().to_string();
279
280    let existing = memories::find_by_name(tx, namespace, &normalized_name)?;
281
282    let (memory_id, batch_action) = if let Some((existing_id, _updated_at, _version)) = existing {
283        if !force_merge {
284            return Err(AppError::Duplicate(format!(
285                "memory '{normalized_name}' already exists; use --force-merge to update"
286            )));
287        }
288        let snippet: String = input.body.chars().take(200).collect();
289        // Capture old FTS values BEFORE the UPDATE for sync_fts_after_update
290        // (trg_fts_au trigger is absent by design due to sqlite-vec conflict).
291        let (old_fts_name, old_fts_desc, old_fts_body): (String, String, String) = tx.query_row(
292            "SELECT name, description, body FROM memories WHERE id = ?1",
293            rusqlite::params![existing_id],
294            |r| Ok((r.get(0)?, r.get(1)?, r.get(2)?)),
295        )?;
296        memories::update(
297            tx,
298            existing_id,
299            &memories::NewMemory {
300                namespace: namespace.to_string(),
301                name: normalized_name.clone(),
302                memory_type: input.r#type.clone(),
303                description: input.description.clone(),
304                body: input.body.clone(),
305                body_hash,
306                session_id: None,
307                source: "agent".to_string(),
308                metadata: serde_json::json!({}),
309            },
310            None,
311        )?;
312        memories::sync_fts_after_update(
313            tx,
314            existing_id,
315            &old_fts_name,
316            &old_fts_desc,
317            &old_fts_body,
318            &normalized_name,
319            &input.description,
320            &input.body,
321        )?;
322        let next_v = versions::next_version(tx, existing_id)?;
323        versions::insert_version(
324            tx,
325            existing_id,
326            next_v,
327            &normalized_name,
328            &input.r#type,
329            &input.description,
330            &input.body,
331            "{}",
332            None,
333            "edit",
334        )?;
335
336        let skip_embed = crate::embedder::should_skip_embedding_on_failure();
337        match crate::embedder::embed_passage_with_embedding_choice(
338            &paths.models,
339            &input.body,
340            embedding_backend,
341            llm_backend,
342        ) {
343            Ok((embedding, _backend)) => {
344                memories::upsert_vec(
345                    tx,
346                    existing_id,
347                    namespace,
348                    &input.r#type,
349                    &embedding,
350                    &normalized_name,
351                    &snippet,
352                )?;
353            }
354            Err(AppError::Validation(msg)) => return Err(AppError::Validation(msg)),
355            Err(e) if skip_embed => {
356                tracing::warn!(error = %e, "remember-batch: embedding failed; --skip-embedding-on-failure active, persisting without embedding");
357            }
358            Err(e) => return Err(e),
359        }
360        (existing_id, "updated")
361    } else {
362        let new_mem = memories::NewMemory {
363            namespace: namespace.to_string(),
364            name: normalized_name.clone(),
365            memory_type: input.r#type.clone(),
366            description: input.description.clone(),
367            body: input.body.clone(),
368            body_hash,
369            session_id: None,
370            source: "agent".to_string(),
371            metadata: serde_json::json!({}),
372        };
373        let id = memories::insert(tx, &new_mem)?;
374        versions::insert_version(
375            tx,
376            id,
377            1,
378            &normalized_name,
379            &input.r#type,
380            &input.description,
381            &input.body,
382            "{}",
383            None,
384            "create",
385        )?;
386
387        let snippet: String = input.body.chars().take(200).collect();
388        let skip_embed = crate::embedder::should_skip_embedding_on_failure();
389        match crate::embedder::embed_passage_with_embedding_choice(
390            &paths.models,
391            &input.body,
392            embedding_backend,
393            llm_backend,
394        ) {
395            Ok((embedding, _backend)) => {
396                memories::upsert_vec(
397                    tx,
398                    id,
399                    namespace,
400                    &input.r#type,
401                    &embedding,
402                    &normalized_name,
403                    &snippet,
404                )?;
405            }
406            Err(AppError::Validation(msg)) => return Err(AppError::Validation(msg)),
407            Err(e) if skip_embed => {
408                tracing::warn!(error = %e, "remember-batch: embedding failed; --skip-embedding-on-failure active, persisting without embedding");
409            }
410            Err(e) => return Err(e),
411        }
412        (id, "created")
413    };
414
415    // Persist graph entities and relationships if provided
416    for entity in &input.entities {
417        let entity_id = entities::upsert_entity(tx, namespace, entity)?;
418        let entity_text = match &entity.description {
419            Some(desc) => format!("{} {}", entity.name, desc),
420            None => entity.name.clone(),
421        };
422        let skip_embed = crate::embedder::should_skip_embedding_on_failure();
423        match crate::embedder::embed_entity_texts_cached(
424            &paths.models,
425            std::slice::from_ref(&entity_text),
426            1,
427            embedding_backend,
428            llm_backend,
429        ) {
430            Ok((entity_embedding_vec, _stats)) => {
431                if let Some(entity_embedding) = entity_embedding_vec.into_iter().next() {
432                    entities::upsert_entity_vec(
433                        tx,
434                        entity_id,
435                        namespace,
436                        entity.entity_type,
437                        &entity_embedding,
438                        &entity.name,
439                    )?;
440                }
441            }
442            Err(e) if skip_embed => {
443                tracing::warn!(error = %e, "remember-batch: entity embedding failed; --skip-embedding-on-failure active");
444            }
445            Err(e) => return Err(e),
446        }
447        entities::link_memory_entity(tx, memory_id, entity_id)?;
448    }
449
450    for rel in &input.relationships {
451        let src_name = crate::parsers::normalize_entity_name(&rel.source);
452        let tgt_name = crate::parsers::normalize_entity_name(&rel.target);
453        if let (Some(src_id), Some(tgt_id)) = (
454            entities::find_entity_id(tx, namespace, &src_name)?,
455            entities::find_entity_id(tx, namespace, &tgt_name)?,
456        ) {
457            entities::create_or_fetch_relationship(
458                tx,
459                namespace,
460                src_id,
461                tgt_id,
462                &rel.relation,
463                rel.strength,
464                rel.description.as_deref(),
465            )?;
466        }
467    }
468
469    Ok(BatchItemEvent {
470        name: normalized_name,
471        status: batch_action.to_string(),
472        memory_id: Some(memory_id),
473        error: None,
474        index,
475    })
476}