Skip to main content

sqlite_graphrag/commands/
remember_batch.rs

1//! Handler for the `remember-batch` CLI subcommand (G08).
2//!
3//! Accepts NDJSON via stdin where each line is a memory to persist.
4//! One CLI invocation, one slot, one DB connection — eliminates N-process
5//! contention from parallel `remember` calls.
6
7use crate::errors::AppError;
8use crate::output;
9use crate::paths::AppPaths;
10use crate::storage::connection::open_rw;
11use crate::storage::{entities, memories, versions};
12use serde::{Deserialize, Serialize};
13use std::io::BufRead;
14
// Command-line arguments for the `remember-batch` subcommand, which reads
// NDJSON memories from stdin (one JSON object per line).
//
// NOTE(review): with clap's derive API the `///` comments on the fields below
// double as `--help` text, so rewording them changes CLI output. Maintainer
// notes therefore live in plain `//` comments like this one.
//
// `namespace` and `db` can also be supplied through the environment variables
// named in their `env = "..."` attributes.
#[derive(clap::Args)]
#[command(after_long_help = "EXAMPLES:\n  \
    # Pipe NDJSON memories from stdin\n  \
    echo '{\"name\":\"mem-a\",\"type\":\"note\",\"description\":\"a\",\"body\":\"content\"}' | \
    sqlite-graphrag remember-batch --json\n\n  \
    # Atomic batch with --transaction\n  \
    cat memories.ndjson | sqlite-graphrag remember-batch --transaction --json")]
pub struct RememberBatchArgs {
    /// Apply all memories in a single transaction (all-or-nothing).
    #[arg(long)]
    pub transaction: bool,
    /// Stop processing on the first failure.
    #[arg(long)]
    pub fail_fast: bool,
    /// Apply force-merge to all memories (update existing by name).
    #[arg(long)]
    pub force_merge: bool,
    /// Validate inputs and emit preview events without persisting or embedding.
    #[arg(long)]
    pub dry_run: bool,
    /// Namespace override for all memories.
    #[arg(long, env = "SQLITE_GRAPHRAG_NAMESPACE")]
    pub namespace: Option<String>,
    /// Emit NDJSON output.
    #[arg(long)]
    pub json: bool,
    /// Database path override.
    #[arg(long, env = "SQLITE_GRAPHRAG_DB_PATH")]
    pub db: Option<String>,
}
45
46#[derive(Deserialize)]
47struct BatchInputLine {
48    name: String,
49    #[serde(default = "default_type")]
50    r#type: String,
51    #[serde(default)]
52    description: String,
53    #[serde(default)]
54    body: String,
55    #[serde(default)]
56    entities: Vec<crate::storage::entities::NewEntity>,
57    #[serde(default)]
58    relationships: Vec<crate::storage::entities::NewRelationship>,
59}
60
/// Serde fallback for `BatchInputLine::type`: lines without a `type` key are
/// treated as `"note"`.
fn default_type() -> String {
    String::from("note")
}
64
/// Per-line result event serialized to the output stream.
///
/// One event is emitted for every processed input line, on success and on
/// failure alike.
#[derive(Serialize)]
struct BatchItemEvent {
    /// Normalized memory name; left empty on failure events.
    name: String,
    /// Outcome label. Values produced in this file: `created`, `updated`,
    /// `failed`, and the dry-run previews `would_create`, `would_update`,
    /// `would_fail_duplicate`.
    status: String,
    /// Row id of the affected (or, in dry-run, already existing) memory.
    /// Omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    memory_id: Option<i64>,
    /// Human-readable failure message. Omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    error: Option<String>,
    /// Zero-based position of the line among the non-blank input lines.
    index: usize,
}
75
/// Final event emitted once per invocation, after all per-line events.
#[derive(Serialize)]
struct BatchSummary {
    /// Marker field; `run` always sets it to `true` when building a summary.
    summary: bool,
    /// Number of non-blank input lines read from stdin.
    total: usize,
    /// Number of lines that were processed without error.
    succeeded: usize,
    /// Number of lines that produced a `failed` event.
    failed: usize,
    /// Wall-clock time since `run` started, in milliseconds.
    elapsed_ms: u64,
}
84
85pub fn run(
86    args: RememberBatchArgs,
87    llm_backend: crate::cli::LlmBackendChoice,
88) -> Result<(), AppError> {
89    let start = std::time::Instant::now();
90    let namespace = crate::namespace::resolve_namespace(args.namespace.as_deref())?;
91    let paths = AppPaths::resolve(args.db.as_deref())?;
92    paths.ensure_dirs()?;
93    crate::storage::connection::ensure_db_ready(&paths)?;
94    let mut conn = open_rw(&paths.db)?;
95
96    let stdin = std::io::stdin();
97    let lines: Vec<String> = stdin
98        .lock()
99        .lines()
100        .map_while(Result::ok)
101        .filter(|l| !l.trim().is_empty())
102        .collect();
103
104    let total = lines.len();
105    let mut succeeded = 0usize;
106    let mut failed = 0usize;
107
108    if args.dry_run {
109        for (idx, line) in lines.iter().enumerate() {
110            match serde_json::from_str::<BatchInputLine>(line) {
111                Ok(input) => {
112                    let normalized_name = crate::parsers::normalize_entity_name(&input.name);
113                    if normalized_name.is_empty() {
114                        failed += 1;
115                        output::emit_json(&BatchItemEvent {
116                            name: String::new(),
117                            status: "failed".to_string(),
118                            memory_id: None,
119                            error: Some(format!("line {idx}: name normalizes to empty string")),
120                            index: idx,
121                        })?;
122                        continue;
123                    }
124                    let existing = memories::find_by_name(&conn, &namespace, &normalized_name)?;
125                    let action = if existing.is_some() {
126                        if args.force_merge {
127                            "would_update"
128                        } else {
129                            "would_fail_duplicate"
130                        }
131                    } else {
132                        "would_create"
133                    };
134                    succeeded += 1;
135                    output::emit_json(&BatchItemEvent {
136                        name: normalized_name,
137                        status: action.to_string(),
138                        memory_id: existing.map(|(id, _, _)| id),
139                        error: None,
140                        index: idx,
141                    })?;
142                }
143                Err(e) => {
144                    failed += 1;
145                    output::emit_json(&BatchItemEvent {
146                        name: String::new(),
147                        status: "failed".to_string(),
148                        memory_id: None,
149                        error: Some(format!("line {idx}: invalid JSON: {e}")),
150                        index: idx,
151                    })?;
152                }
153            }
154        }
155
156        output::emit_json(&BatchSummary {
157            summary: true,
158            total,
159            succeeded,
160            failed,
161            elapsed_ms: start.elapsed().as_millis() as u64,
162        })?;
163        return Ok(());
164    }
165
166    if args.transaction {
167        let tx = conn.transaction_with_behavior(rusqlite::TransactionBehavior::Immediate)?;
168        for (idx, line) in lines.iter().enumerate() {
169            match process_line(
170                &tx,
171                &namespace,
172                line,
173                idx,
174                args.force_merge,
175                &paths,
176                llm_backend,
177            ) {
178                Ok(event) => {
179                    output::emit_json(&event)?;
180                    succeeded += 1;
181                }
182                Err(e) => {
183                    failed += 1;
184                    output::emit_json(&BatchItemEvent {
185                        name: String::new(),
186                        status: "failed".to_string(),
187                        memory_id: None,
188                        error: Some(format!("{e}")),
189                        index: idx,
190                    })?;
191                    if args.fail_fast {
192                        break;
193                    }
194                }
195            }
196        }
197        if failed == 0 || !args.fail_fast {
198            tx.commit()?;
199        }
200    } else {
201        for (idx, line) in lines.iter().enumerate() {
202            let tx = conn.transaction_with_behavior(rusqlite::TransactionBehavior::Immediate)?;
203            match process_line(
204                &tx,
205                &namespace,
206                line,
207                idx,
208                args.force_merge,
209                &paths,
210                llm_backend,
211            ) {
212                Ok(event) => {
213                    tx.commit()?;
214                    output::emit_json(&event)?;
215                    succeeded += 1;
216                }
217                Err(e) => {
218                    drop(tx);
219                    failed += 1;
220                    output::emit_json(&BatchItemEvent {
221                        name: String::new(),
222                        status: "failed".to_string(),
223                        memory_id: None,
224                        error: Some(format!("{e}")),
225                        index: idx,
226                    })?;
227                    if args.fail_fast {
228                        break;
229                    }
230                }
231            }
232        }
233    }
234
235    output::emit_json(&BatchSummary {
236        summary: true,
237        total,
238        succeeded,
239        failed,
240        elapsed_ms: start.elapsed().as_millis() as u64,
241    })?;
242
243    Ok(())
244}
245
246fn process_line(
247    tx: &rusqlite::Transaction<'_>,
248    namespace: &str,
249    line: &str,
250    index: usize,
251    force_merge: bool,
252    paths: &AppPaths,
253    llm_backend: crate::cli::LlmBackendChoice,
254) -> Result<BatchItemEvent, AppError> {
255    let input: BatchInputLine = serde_json::from_str(line)
256        .map_err(|e| AppError::Validation(format!("line {index}: invalid JSON: {e}")))?;
257
258    let normalized_name = crate::parsers::normalize_entity_name(&input.name);
259    if normalized_name.is_empty() {
260        return Err(AppError::Validation(format!(
261            "line {index}: name normalizes to empty string"
262        )));
263    }
264
265    let body_hash = blake3::hash(input.body.as_bytes()).to_hex().to_string();
266
267    let existing = memories::find_by_name(tx, namespace, &normalized_name)?;
268
269    let (memory_id, batch_action) = if let Some((existing_id, _updated_at, _version)) = existing {
270        if !force_merge {
271            return Err(AppError::Duplicate(format!(
272                "memory '{normalized_name}' already exists; use --force-merge to update"
273            )));
274        }
275        let snippet: String = input.body.chars().take(200).collect();
276        // Capture old FTS values BEFORE the UPDATE for sync_fts_after_update
277        // (trg_fts_au trigger is absent by design due to sqlite-vec conflict).
278        let (old_fts_name, old_fts_desc, old_fts_body): (String, String, String) = tx.query_row(
279            "SELECT name, description, body FROM memories WHERE id = ?1",
280            rusqlite::params![existing_id],
281            |r| Ok((r.get(0)?, r.get(1)?, r.get(2)?)),
282        )?;
283        memories::update(
284            tx,
285            existing_id,
286            &memories::NewMemory {
287                namespace: namespace.to_string(),
288                name: normalized_name.clone(),
289                memory_type: input.r#type.clone(),
290                description: input.description.clone(),
291                body: input.body.clone(),
292                body_hash,
293                session_id: None,
294                source: "agent".to_string(),
295                metadata: serde_json::json!({}),
296            },
297            None,
298        )?;
299        memories::sync_fts_after_update(
300            tx,
301            existing_id,
302            &old_fts_name,
303            &old_fts_desc,
304            &old_fts_body,
305            &normalized_name,
306            &input.description,
307            &input.body,
308        )?;
309        let next_v = versions::next_version(tx, existing_id)?;
310        versions::insert_version(
311            tx,
312            existing_id,
313            next_v,
314            &normalized_name,
315            &input.r#type,
316            &input.description,
317            &input.body,
318            "{}",
319            None,
320            "edit",
321        )?;
322
323        let skip_embed = crate::embedder::should_skip_embedding_on_failure();
324        match crate::embedder::embed_passage_with_choice(
325            &paths.models,
326            &input.body,
327            Some(llm_backend),
328        ) {
329            Ok((embedding, _backend)) => {
330                memories::upsert_vec(
331                    tx,
332                    existing_id,
333                    namespace,
334                    &input.r#type,
335                    &embedding,
336                    &normalized_name,
337                    &snippet,
338                )?;
339            }
340            Err(AppError::Validation(msg)) => return Err(AppError::Validation(msg)),
341            Err(e) if skip_embed => {
342                tracing::warn!(error = %e, "remember-batch: embedding failed; --skip-embedding-on-failure active, persisting without embedding");
343            }
344            Err(e) => return Err(e),
345        }
346        (existing_id, "updated")
347    } else {
348        let new_mem = memories::NewMemory {
349            namespace: namespace.to_string(),
350            name: normalized_name.clone(),
351            memory_type: input.r#type.clone(),
352            description: input.description.clone(),
353            body: input.body.clone(),
354            body_hash,
355            session_id: None,
356            source: "agent".to_string(),
357            metadata: serde_json::json!({}),
358        };
359        let id = memories::insert(tx, &new_mem)?;
360        versions::insert_version(
361            tx,
362            id,
363            1,
364            &normalized_name,
365            &input.r#type,
366            &input.description,
367            &input.body,
368            "{}",
369            None,
370            "create",
371        )?;
372
373        let snippet: String = input.body.chars().take(200).collect();
374        let skip_embed = crate::embedder::should_skip_embedding_on_failure();
375        match crate::embedder::embed_passage_with_choice(
376            &paths.models,
377            &input.body,
378            Some(llm_backend),
379        ) {
380            Ok((embedding, _backend)) => {
381                memories::upsert_vec(
382                    tx,
383                    id,
384                    namespace,
385                    &input.r#type,
386                    &embedding,
387                    &normalized_name,
388                    &snippet,
389                )?;
390            }
391            Err(AppError::Validation(msg)) => return Err(AppError::Validation(msg)),
392            Err(e) if skip_embed => {
393                tracing::warn!(error = %e, "remember-batch: embedding failed; --skip-embedding-on-failure active, persisting without embedding");
394            }
395            Err(e) => return Err(e),
396        }
397        (id, "created")
398    };
399
400    // Persist graph entities and relationships if provided
401    for entity in &input.entities {
402        let entity_id = entities::upsert_entity(tx, namespace, entity)?;
403        let entity_text = match &entity.description {
404            Some(desc) => format!("{} {}", entity.name, desc),
405            None => entity.name.clone(),
406        };
407        let skip_embed = crate::embedder::should_skip_embedding_on_failure();
408        match crate::embedder::embed_entity_texts_cached(
409            &paths.models,
410            std::slice::from_ref(&entity_text),
411            1,
412        ) {
413            Ok((entity_embedding_vec, _stats)) => {
414                if let Some(entity_embedding) = entity_embedding_vec.into_iter().next() {
415                    entities::upsert_entity_vec(
416                        tx,
417                        entity_id,
418                        namespace,
419                        entity.entity_type,
420                        &entity_embedding,
421                        &entity.name,
422                    )?;
423                }
424            }
425            Err(e) if skip_embed => {
426                tracing::warn!(error = %e, "remember-batch: entity embedding failed; --skip-embedding-on-failure active");
427            }
428            Err(e) => return Err(e),
429        }
430        entities::link_memory_entity(tx, memory_id, entity_id)?;
431    }
432
433    for rel in &input.relationships {
434        let src_name = crate::parsers::normalize_entity_name(&rel.source);
435        let tgt_name = crate::parsers::normalize_entity_name(&rel.target);
436        if let (Some(src_id), Some(tgt_id)) = (
437            entities::find_entity_id(tx, namespace, &src_name)?,
438            entities::find_entity_id(tx, namespace, &tgt_name)?,
439        ) {
440            entities::create_or_fetch_relationship(
441                tx,
442                namespace,
443                src_id,
444                tgt_id,
445                &rel.relation,
446                rel.strength,
447                rel.description.as_deref(),
448            )?;
449        }
450    }
451
452    Ok(BatchItemEvent {
453        name: normalized_name,
454        status: batch_action.to_string(),
455        memory_id: Some(memory_id),
456        error: None,
457        index,
458    })
459}