Skip to main content

sqlite_graphrag/commands/
remember_batch.rs

1//! Handler for the `remember-batch` CLI subcommand (G08).
2//!
3//! Accepts NDJSON via stdin where each line is a memory to persist.
4//! One CLI invocation, one slot, one DB connection — eliminates N-process
5//! contention from parallel `remember` calls.
6
7use crate::errors::AppError;
8use crate::output;
9use crate::paths::AppPaths;
10use crate::storage::connection::open_rw;
11use crate::storage::{entities, memories, versions};
12use serde::{Deserialize, Serialize};
13use std::io::BufRead;
14
// Command-line arguments for the `remember-batch` subcommand.
//
// NOTE: plain `//` comments are used for reviewer notes in this struct on
// purpose. clap's derive turns `///` doc comments into `--help` text, so the
// existing `///` lines below are user-facing strings and are left untouched.
#[derive(clap::Args)]
#[command(after_long_help = "EXAMPLES:\n  \
    # Pipe NDJSON memories from stdin\n  \
    echo '{\"name\":\"mem-a\",\"type\":\"note\",\"description\":\"a\",\"body\":\"content\"}' | \
    sqlite-graphrag remember-batch --json\n\n  \
    # Atomic batch with --transaction\n  \
    cat memories.ndjson | sqlite-graphrag remember-batch --transaction --json")]
pub struct RememberBatchArgs {
    /// Apply all memories in a single transaction (all-or-nothing).
    #[arg(long)]
    pub transaction: bool,
    /// Stop processing on the first failure.
    #[arg(long)]
    pub fail_fast: bool,
    /// Apply force-merge to all memories (update existing by name).
    #[arg(long)]
    pub force_merge: bool,
    // When absent, `run` passes `None` to `crate::namespace::resolve_namespace`,
    // which decides the effective namespace.
    /// Namespace override for all memories.
    #[arg(long, env = "SQLITE_GRAPHRAG_NAMESPACE")]
    pub namespace: Option<String>,
    // NOTE(review): `run` in this file never reads this flag; per-line events and
    // the summary are always written through `output::emit_json`.
    /// Emit NDJSON output.
    #[arg(long)]
    pub json: bool,
    // Forwarded to `AppPaths::resolve` as an optional path string.
    /// Database path override.
    #[arg(long, env = "SQLITE_GRAPHRAG_DB_PATH")]
    pub db: Option<String>,
}
42
/// One NDJSON input line, deserialized from stdin by `process_line`.
///
/// Only `name` is mandatory; every other key falls back to a default when it
/// is missing from the JSON object.
#[derive(Deserialize)]
struct BatchInputLine {
    /// Memory name as supplied by the caller; normalized before any lookup or write.
    name: String,
    /// Memory type; falls back to `default_type()` (`"note"`) when omitted.
    #[serde(default = "default_type")]
    r#type: String,
    /// Description text; empty string when omitted.
    #[serde(default)]
    description: String,
    /// Memory body; empty string when omitted. Hashed and embedded by `process_line`.
    #[serde(default)]
    body: String,
    /// Graph entities to upsert and link to the memory; empty when omitted.
    #[serde(default)]
    entities: Vec<crate::storage::entities::NewEntity>,
    /// Relationships between entities, referenced by entity name; empty when omitted.
    #[serde(default)]
    relationships: Vec<crate::storage::entities::NewRelationship>,
}
57
/// Serde fallback for the `type` key of a batch input line: the `"note"` memory type.
fn default_type() -> String {
    String::from("note")
}
61
/// Per-line result record emitted as one JSON object for every processed input line.
#[derive(Serialize)]
struct BatchItemEvent {
    /// Normalized memory name; left empty on failure events built by `run`.
    name: String,
    /// Outcome of the line: `"created"`, `"updated"` or `"failed"` in this file.
    status: String,
    /// Row id of the stored memory; omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    memory_id: Option<i64>,
    /// Error message for failed lines; omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    error: Option<String>,
    /// Zero-based position of the line among the non-blank stdin lines.
    index: usize,
}
72
/// Final record emitted once by `run` after all lines have been handled.
#[derive(Serialize)]
struct BatchSummary {
    /// Marker distinguishing this record from per-line events; `run` always sets it to `true`.
    summary: bool,
    /// Number of non-blank input lines read from stdin.
    total: usize,
    /// Number of lines for which `process_line` returned `Ok`.
    succeeded: usize,
    /// Number of lines for which `process_line` returned `Err`.
    failed: usize,
    /// Wall-clock time of the whole invocation, in milliseconds.
    elapsed_ms: u64,
}
81
82pub fn run(args: RememberBatchArgs, llm_backend: crate::cli::LlmBackendChoice) -> Result<(), AppError> {
83    let start = std::time::Instant::now();
84    let namespace = crate::namespace::resolve_namespace(args.namespace.as_deref())?;
85    let paths = AppPaths::resolve(args.db.as_deref())?;
86    paths.ensure_dirs()?;
87    crate::storage::connection::ensure_db_ready(&paths)?;
88    let mut conn = open_rw(&paths.db)?;
89
90    let stdin = std::io::stdin();
91    let lines: Vec<String> = stdin
92        .lock()
93        .lines()
94        .map_while(Result::ok)
95        .filter(|l| !l.trim().is_empty())
96        .collect();
97
98    let total = lines.len();
99    let mut succeeded = 0usize;
100    let mut failed = 0usize;
101
102    if args.transaction {
103        let tx = conn.transaction_with_behavior(rusqlite::TransactionBehavior::Immediate)?;
104        for (idx, line) in lines.iter().enumerate() {
105            match process_line(&tx, &namespace, line, idx, args.force_merge, &paths, llm_backend) {
106                Ok(event) => {
107                    output::emit_json(&event)?;
108                    succeeded += 1;
109                }
110                Err(e) => {
111                    failed += 1;
112                    output::emit_json(&BatchItemEvent {
113                        name: String::new(),
114                        status: "failed".to_string(),
115                        memory_id: None,
116                        error: Some(format!("{e}")),
117                        index: idx,
118                    })?;
119                    if args.fail_fast {
120                        break;
121                    }
122                }
123            }
124        }
125        if failed == 0 || !args.fail_fast {
126            tx.commit()?;
127        }
128    } else {
129        for (idx, line) in lines.iter().enumerate() {
130            let tx = conn.transaction_with_behavior(rusqlite::TransactionBehavior::Immediate)?;
131            match process_line(&tx, &namespace, line, idx, args.force_merge, &paths, llm_backend) {
132                Ok(event) => {
133                    tx.commit()?;
134                    output::emit_json(&event)?;
135                    succeeded += 1;
136                }
137                Err(e) => {
138                    drop(tx);
139                    failed += 1;
140                    output::emit_json(&BatchItemEvent {
141                        name: String::new(),
142                        status: "failed".to_string(),
143                        memory_id: None,
144                        error: Some(format!("{e}")),
145                        index: idx,
146                    })?;
147                    if args.fail_fast {
148                        break;
149                    }
150                }
151            }
152        }
153    }
154
155    output::emit_json(&BatchSummary {
156        summary: true,
157        total,
158        succeeded,
159        failed,
160        elapsed_ms: start.elapsed().as_millis() as u64,
161    })?;
162
163    Ok(())
164}
165
166fn process_line(
167    tx: &rusqlite::Transaction<'_>,
168    namespace: &str,
169    line: &str,
170    index: usize,
171    force_merge: bool,
172    paths: &AppPaths,
173    llm_backend: crate::cli::LlmBackendChoice,
174) -> Result<BatchItemEvent, AppError> {
175    let input: BatchInputLine = serde_json::from_str(line)
176        .map_err(|e| AppError::Validation(format!("line {index}: invalid JSON: {e}")))?;
177
178    let normalized_name = crate::parsers::normalize_entity_name(&input.name);
179    if normalized_name.is_empty() {
180        return Err(AppError::Validation(format!(
181            "line {index}: name normalizes to empty string"
182        )));
183    }
184
185    let body_hash = blake3::hash(input.body.as_bytes()).to_hex().to_string();
186
187    let existing = memories::find_by_name(tx, namespace, &normalized_name)?;
188
189    let (memory_id, batch_action) = if let Some((existing_id, _updated_at, _version)) = existing {
190        if !force_merge {
191            return Err(AppError::Duplicate(format!(
192                "memory '{normalized_name}' already exists; use --force-merge to update"
193            )));
194        }
195        let snippet: String = input.body.chars().take(200).collect();
196        // Capture old FTS values BEFORE the UPDATE for sync_fts_after_update
197        // (trg_fts_au trigger is absent by design due to sqlite-vec conflict).
198        let (old_fts_name, old_fts_desc, old_fts_body): (String, String, String) = tx
199            .query_row(
200                "SELECT name, description, body FROM memories WHERE id = ?1",
201                rusqlite::params![existing_id],
202                |r| Ok((r.get(0)?, r.get(1)?, r.get(2)?)),
203            )?;
204        memories::update(
205            tx,
206            existing_id,
207            &memories::NewMemory {
208                namespace: namespace.to_string(),
209                name: normalized_name.clone(),
210                memory_type: input.r#type.clone(),
211                description: input.description.clone(),
212                body: input.body.clone(),
213                body_hash,
214                session_id: None,
215                source: "agent".to_string(),
216                metadata: serde_json::json!({}),
217            },
218            None,
219        )?;
220        memories::sync_fts_after_update(
221            tx,
222            existing_id,
223            &old_fts_name,
224            &old_fts_desc,
225            &old_fts_body,
226            &normalized_name,
227            &input.description,
228            &input.body,
229        )?;
230        let next_v = versions::next_version(tx, existing_id)?;
231        versions::insert_version(
232            tx,
233            existing_id,
234            next_v,
235            &normalized_name,
236            &input.r#type,
237            &input.description,
238            &input.body,
239            "{}",
240            None,
241            "edit",
242        )?;
243
244        let skip_embed = crate::embedder::should_skip_embedding_on_failure();
245        match crate::embedder::embed_passage_with_choice(&paths.models, &input.body, Some(llm_backend)) {
246            Ok((embedding, _backend)) => {
247                memories::upsert_vec(tx, existing_id, namespace, &input.r#type, &embedding, &normalized_name, &snippet)?;
248            }
249            Err(AppError::Validation(msg)) => return Err(AppError::Validation(msg)),
250            Err(e) if skip_embed => {
251                tracing::warn!(error = %e, "remember-batch: embedding failed; --skip-embedding-on-failure active, persisting without embedding");
252            }
253            Err(e) => return Err(e),
254        }
255        (existing_id, "updated")
256    } else {
257        let new_mem = memories::NewMemory {
258            namespace: namespace.to_string(),
259            name: normalized_name.clone(),
260            memory_type: input.r#type.clone(),
261            description: input.description.clone(),
262            body: input.body.clone(),
263            body_hash,
264            session_id: None,
265            source: "agent".to_string(),
266            metadata: serde_json::json!({}),
267        };
268        let id = memories::insert(tx, &new_mem)?;
269        versions::insert_version(
270            tx,
271            id,
272            1,
273            &normalized_name,
274            &input.r#type,
275            &input.description,
276            &input.body,
277            "{}",
278            None,
279            "create",
280        )?;
281
282        let snippet: String = input.body.chars().take(200).collect();
283        let skip_embed = crate::embedder::should_skip_embedding_on_failure();
284        match crate::embedder::embed_passage_with_choice(&paths.models, &input.body, Some(llm_backend)) {
285            Ok((embedding, _backend)) => {
286                memories::upsert_vec(tx, id, namespace, &input.r#type, &embedding, &normalized_name, &snippet)?;
287            }
288            Err(AppError::Validation(msg)) => return Err(AppError::Validation(msg)),
289            Err(e) if skip_embed => {
290                tracing::warn!(error = %e, "remember-batch: embedding failed; --skip-embedding-on-failure active, persisting without embedding");
291            }
292            Err(e) => return Err(e),
293        }
294        (id, "created")
295    };
296
297    // Persist graph entities and relationships if provided
298    for entity in &input.entities {
299        let entity_id = entities::upsert_entity(tx, namespace, entity)?;
300        let entity_text = match &entity.description {
301            Some(desc) => format!("{} {}", entity.name, desc),
302            None => entity.name.clone(),
303        };
304        let skip_embed = crate::embedder::should_skip_embedding_on_failure();
305        match crate::embedder::embed_entity_texts_cached(
306            &paths.models,
307            std::slice::from_ref(&entity_text),
308            1,
309        ) {
310            Ok((entity_embedding_vec, _stats)) => {
311                if let Some(entity_embedding) = entity_embedding_vec.into_iter().next() {
312                    entities::upsert_entity_vec(
313                        tx, entity_id, namespace, entity.entity_type, &entity_embedding, &entity.name,
314                    )?;
315                }
316            }
317            Err(e) if skip_embed => {
318                tracing::warn!(error = %e, "remember-batch: entity embedding failed; --skip-embedding-on-failure active");
319            }
320            Err(e) => return Err(e),
321        }
322        entities::link_memory_entity(tx, memory_id, entity_id)?;
323    }
324
325    for rel in &input.relationships {
326        let src_name = crate::parsers::normalize_entity_name(&rel.source);
327        let tgt_name = crate::parsers::normalize_entity_name(&rel.target);
328        if let (Some(src_id), Some(tgt_id)) = (
329            entities::find_entity_id(tx, namespace, &src_name)?,
330            entities::find_entity_id(tx, namespace, &tgt_name)?,
331        ) {
332            entities::create_or_fetch_relationship(
333                tx,
334                namespace,
335                src_id,
336                tgt_id,
337                &rel.relation,
338                rel.strength,
339                rel.description.as_deref(),
340            )?;
341        }
342    }
343
344    Ok(BatchItemEvent {
345        name: normalized_name,
346        status: batch_action.to_string(),
347        memory_id: Some(memory_id),
348        error: None,
349        index,
350    })
351}