Skip to main content

sqlite_graphrag/commands/
remember_batch.rs

1//! Handler for the `remember-batch` CLI subcommand (G08).
2//!
3//! Accepts NDJSON via stdin where each line is a memory to persist.
4//! One CLI invocation, one slot, one DB connection — eliminates N-process
5//! contention from parallel `remember` calls.
6
7use crate::errors::AppError;
8use crate::output;
9use crate::paths::AppPaths;
10use crate::storage::connection::open_rw;
11use crate::storage::{entities, memories, versions};
12use serde::{Deserialize, Serialize};
13use std::io::BufRead;
14
// CLI arguments for the `remember-batch` subcommand.
//
// NOTE: explanatory notes in this struct are plain `//` comments on purpose. With
// `#[derive(clap::Args)]`, `///` doc comments become the command's help text, so editing
// them changes user-visible CLI output.
#[derive(clap::Args)]
#[command(after_long_help = "EXAMPLES:\n  \
    # Pipe NDJSON memories from stdin\n  \
    echo '{\"name\":\"mem-a\",\"type\":\"note\",\"description\":\"a\",\"body\":\"content\"}' | \
    sqlite-graphrag remember-batch --json\n\n  \
    # Atomic batch with --transaction\n  \
    cat memories.ndjson | sqlite-graphrag remember-batch --transaction --json")]
pub struct RememberBatchArgs {
    // Read by `run`: when set, every line is processed inside one shared IMMEDIATE
    // transaction instead of one IMMEDIATE transaction per line.
    /// Apply all memories in a single transaction (all-or-nothing).
    #[arg(long)]
    pub transaction: bool,
    // Read by `run`: processing stops after the first failing line. Combined with
    // `transaction`, the shared transaction is then not committed.
    /// Stop processing on the first failure.
    #[arg(long)]
    pub fail_fast: bool,
    // Forwarded to `process_line`: an existing memory with the same normalized name is
    // updated (as a new version) instead of being rejected as a duplicate.
    /// Apply force-merge to all memories (update existing by name).
    #[arg(long)]
    pub force_merge: bool,
    // Passed to `crate::namespace::resolve_namespace`; clap also fills it from the
    // `SQLITE_GRAPHRAG_NAMESPACE` environment variable.
    /// Namespace override for all memories.
    #[arg(long, env = "SQLITE_GRAPHRAG_NAMESPACE")]
    pub namespace: Option<String>,
    // NOTE(review): `run` in this file never reads `json`; output is emitted as JSON
    // lines regardless of this flag. Confirm whether it is consumed elsewhere.
    /// Emit NDJSON output.
    #[arg(long)]
    pub json: bool,
    // Passed to `AppPaths::resolve`; clap also fills it from `SQLITE_GRAPHRAG_DB_PATH`.
    /// Database path override.
    #[arg(long, env = "SQLITE_GRAPHRAG_DB_PATH")]
    pub db: Option<String>,
}
42
/// One memory record, deserialized from a single NDJSON stdin line by `process_line`.
///
/// Only `name` is required. Every other key falls back to a default when absent. Unknown
/// keys are ignored because the struct does not use `deny_unknown_fields`.
#[derive(Deserialize)]
struct BatchInputLine {
    /// Raw memory name; `process_line` normalizes it and rejects names that normalize to "".
    name: String,
    /// Memory type, read from the JSON key `"type"`; defaults to `"note"` via `default_type`.
    #[serde(default = "default_type")]
    r#type: String,
    /// Optional description; empty string when missing.
    #[serde(default)]
    description: String,
    /// Memory body; it is hashed, embedded and versioned. Empty string when missing.
    #[serde(default)]
    body: String,
    /// Graph entities to upsert and link to this memory; empty when missing.
    #[serde(default)]
    entities: Vec<crate::storage::entities::NewEntity>,
    /// Relationships between entities, resolved by normalized entity name; empty when missing.
    #[serde(default)]
    relationships: Vec<crate::storage::entities::NewRelationship>,
}
57
/// Serde default for `BatchInputLine::r#type`: a record without a `"type"` key is a note.
fn default_type() -> String {
    String::from("note")
}
61
/// Per-line NDJSON event emitted by `run` for each processed input line.
///
/// Keys serialize in field declaration order. `memory_id` and `error` are omitted from the
/// JSON when they are `None`.
#[derive(Serialize)]
struct BatchItemEvent {
    /// Memory name the event refers to; may be empty when it could not be determined.
    name: String,
    /// `"indexed"` on success, `"failed"` on error.
    status: String,
    /// Row id of the inserted or updated memory; set only on success.
    #[serde(skip_serializing_if = "Option::is_none")]
    memory_id: Option<i64>,
    /// Display text of the error; set only on failure.
    #[serde(skip_serializing_if = "Option::is_none")]
    error: Option<String>,
    /// 0-based position of the line among the non-blank stdin lines.
    index: usize,
}
72
/// Final NDJSON line emitted by `run` after all item events.
///
/// Keys serialize in field declaration order.
#[derive(Serialize)]
struct BatchSummary {
    /// Always `true` (set by `run`), so consumers can tell this line apart from item events.
    summary: bool,
    /// Number of non-blank stdin lines read, including lines skipped by `--fail-fast`.
    total: usize,
    /// Number of lines reported as succeeded.
    succeeded: usize,
    /// Number of lines that produced a `failed` event.
    failed: usize,
    /// Wall-clock milliseconds elapsed since `run` started.
    elapsed_ms: u64,
}
81
82pub fn run(args: RememberBatchArgs) -> Result<(), AppError> {
83    let start = std::time::Instant::now();
84    let namespace = crate::namespace::resolve_namespace(args.namespace.as_deref())?;
85    let paths = AppPaths::resolve(args.db.as_deref())?;
86    paths.ensure_dirs()?;
87    crate::storage::connection::ensure_db_ready(&paths)?;
88    let mut conn = open_rw(&paths.db)?;
89
90    let stdin = std::io::stdin();
91    let lines: Vec<String> = stdin
92        .lock()
93        .lines()
94        .map_while(Result::ok)
95        .filter(|l| !l.trim().is_empty())
96        .collect();
97
98    let total = lines.len();
99    let mut succeeded = 0usize;
100    let mut failed = 0usize;
101
102    if args.transaction {
103        let tx = conn.transaction_with_behavior(rusqlite::TransactionBehavior::Immediate)?;
104        for (idx, line) in lines.iter().enumerate() {
105            match process_line(&tx, &namespace, line, idx, args.force_merge, &paths) {
106                Ok(event) => {
107                    output::emit_json(&event)?;
108                    succeeded += 1;
109                }
110                Err(e) => {
111                    failed += 1;
112                    output::emit_json(&BatchItemEvent {
113                        name: String::new(),
114                        status: "failed".to_string(),
115                        memory_id: None,
116                        error: Some(format!("{e}")),
117                        index: idx,
118                    })?;
119                    if args.fail_fast {
120                        break;
121                    }
122                }
123            }
124        }
125        if failed == 0 || !args.fail_fast {
126            tx.commit()?;
127        }
128    } else {
129        for (idx, line) in lines.iter().enumerate() {
130            let tx = conn.transaction_with_behavior(rusqlite::TransactionBehavior::Immediate)?;
131            match process_line(&tx, &namespace, line, idx, args.force_merge, &paths) {
132                Ok(event) => {
133                    tx.commit()?;
134                    output::emit_json(&event)?;
135                    succeeded += 1;
136                }
137                Err(e) => {
138                    drop(tx);
139                    failed += 1;
140                    output::emit_json(&BatchItemEvent {
141                        name: String::new(),
142                        status: "failed".to_string(),
143                        memory_id: None,
144                        error: Some(format!("{e}")),
145                        index: idx,
146                    })?;
147                    if args.fail_fast {
148                        break;
149                    }
150                }
151            }
152        }
153    }
154
155    output::emit_json(&BatchSummary {
156        summary: true,
157        total,
158        succeeded,
159        failed,
160        elapsed_ms: start.elapsed().as_millis() as u64,
161    })?;
162
163    Ok(())
164}
165
166fn process_line(
167    tx: &rusqlite::Transaction<'_>,
168    namespace: &str,
169    line: &str,
170    index: usize,
171    force_merge: bool,
172    paths: &AppPaths,
173) -> Result<BatchItemEvent, AppError> {
174    let input: BatchInputLine = serde_json::from_str(line)
175        .map_err(|e| AppError::Validation(format!("line {index}: invalid JSON: {e}")))?;
176
177    let normalized_name = crate::parsers::normalize_entity_name(&input.name);
178    if normalized_name.is_empty() {
179        return Err(AppError::Validation(format!(
180            "line {index}: name normalizes to empty string"
181        )));
182    }
183
184    let body_hash = blake3::hash(input.body.as_bytes()).to_hex().to_string();
185
186    let existing = memories::find_by_name(tx, namespace, &normalized_name)?;
187
188    let memory_id = if let Some((existing_id, _updated_at, _version)) = existing {
189        if !force_merge {
190            return Err(AppError::Duplicate(format!(
191                "memory '{normalized_name}' already exists; use --force-merge to update"
192            )));
193        }
194        let snippet: String = input.body.chars().take(200).collect();
195        memories::update(
196            tx,
197            existing_id,
198            &memories::NewMemory {
199                namespace: namespace.to_string(),
200                name: normalized_name.clone(),
201                memory_type: input.r#type.clone(),
202                description: input.description.clone(),
203                body: input.body.clone(),
204                body_hash,
205                session_id: None,
206                source: "agent".to_string(),
207                metadata: serde_json::json!({}),
208            },
209            None,
210        )?;
211        let next_v = versions::next_version(tx, existing_id)?;
212        versions::insert_version(
213            tx,
214            existing_id,
215            next_v,
216            &normalized_name,
217            &input.r#type,
218            &input.description,
219            &input.body,
220            "{}",
221            None,
222            "edit",
223        )?;
224
225        let embedding = crate::embedder::embed_passage_local(&paths.models, &input.body)?;
226        memories::upsert_vec(
227            tx,
228            existing_id,
229            namespace,
230            &input.r#type,
231            &embedding,
232            &normalized_name,
233            &snippet,
234        )?;
235        existing_id
236    } else {
237        let new_mem = memories::NewMemory {
238            namespace: namespace.to_string(),
239            name: normalized_name.clone(),
240            memory_type: input.r#type.clone(),
241            description: input.description.clone(),
242            body: input.body.clone(),
243            body_hash,
244            session_id: None,
245            source: "agent".to_string(),
246            metadata: serde_json::json!({}),
247        };
248        let id = memories::insert(tx, &new_mem)?;
249        versions::insert_version(
250            tx,
251            id,
252            1,
253            &normalized_name,
254            &input.r#type,
255            &input.description,
256            &input.body,
257            "{}",
258            None,
259            "create",
260        )?;
261
262        let snippet: String = input.body.chars().take(200).collect();
263        let embedding = crate::embedder::embed_passage_local(&paths.models, &input.body)?;
264        memories::upsert_vec(
265            tx,
266            id,
267            namespace,
268            &input.r#type,
269            &embedding,
270            &normalized_name,
271            &snippet,
272        )?;
273        id
274    };
275
276    // Persist graph entities and relationships if provided
277    for entity in &input.entities {
278        let entity_id = entities::upsert_entity(tx, namespace, entity)?;
279        let entity_text = match &entity.description {
280            Some(desc) => format!("{} {}", entity.name, desc),
281            None => entity.name.clone(),
282        };
283        let entity_embedding = crate::embedder::embed_passage_local(&paths.models, &entity_text)?;
284        entities::upsert_entity_vec(
285            tx,
286            entity_id,
287            namespace,
288            entity.entity_type,
289            &entity_embedding,
290            &entity.name,
291        )?;
292        entities::link_memory_entity(tx, memory_id, entity_id)?;
293    }
294
295    for rel in &input.relationships {
296        let src_name = crate::parsers::normalize_entity_name(&rel.source);
297        let tgt_name = crate::parsers::normalize_entity_name(&rel.target);
298        if let (Some(src_id), Some(tgt_id)) = (
299            entities::find_entity_id(tx, namespace, &src_name)?,
300            entities::find_entity_id(tx, namespace, &tgt_name)?,
301        ) {
302            entities::create_or_fetch_relationship(
303                tx,
304                namespace,
305                src_id,
306                tgt_id,
307                &rel.relation,
308                rel.strength,
309                rel.description.as_deref(),
310            )?;
311        }
312    }
313
314    Ok(BatchItemEvent {
315        name: normalized_name,
316        status: "indexed".to_string(),
317        memory_id: Some(memory_id),
318        error: None,
319        index,
320    })
321}