Skip to main content

sqlite_graphrag/commands/
remember_batch.rs

1//! Handler for the `remember-batch` CLI subcommand (G08).
2//!
3//! Accepts NDJSON via stdin where each line is a memory to persist.
4//! One CLI invocation, one slot, one DB connection — eliminates N-process
5//! contention from parallel `remember` calls.
6
7use crate::errors::AppError;
8use crate::output;
9use crate::paths::AppPaths;
10use crate::storage::connection::open_rw;
11use crate::storage::{entities, memories, versions};
12use serde::{Deserialize, Serialize};
13use std::io::BufRead;
14
// CLI arguments for the `remember-batch` subcommand (derived with clap).
// Clap uses the `///` comments on the fields as `--help` text, so reviewer notes in
// this struct are plain `//` comments to leave the help output unchanged.
#[derive(clap::Args)]
#[command(after_long_help = "EXAMPLES:\n  \
    # Pipe NDJSON memories from stdin\n  \
    echo '{\"name\":\"mem-a\",\"type\":\"note\",\"description\":\"a\",\"body\":\"content\"}' | \
    sqlite-graphrag remember-batch --json\n\n  \
    # Atomic batch with --transaction\n  \
    cat memories.ndjson | sqlite-graphrag remember-batch --transaction --json")]
pub struct RememberBatchArgs {
    /// Apply all memories in a single transaction (all-or-nothing).
    #[arg(long)]
    pub transaction: bool,
    /// Stop processing on the first failure.
    #[arg(long)]
    pub fail_fast: bool,
    // Forwarded to `process_line`; without it, a name that already exists is a
    // per-line `Duplicate` failure.
    /// Apply force-merge to all memories (update existing by name).
    #[arg(long)]
    pub force_merge: bool,
    // Falls back to the SQLITE_GRAPHRAG_NAMESPACE environment variable; the final
    // value is resolved by `crate::namespace::resolve_namespace`.
    /// Namespace override for all memories.
    #[arg(long, env = "SQLITE_GRAPHRAG_NAMESPACE")]
    pub namespace: Option<String>,
    // NOTE(review): `run` in this file never reads this flag — output is always
    // emitted through `output::emit_json`. Confirm whether it is consumed elsewhere.
    /// Emit NDJSON output.
    #[arg(long)]
    pub json: bool,
    // Falls back to the SQLITE_GRAPHRAG_DB_PATH environment variable; passed to
    // `AppPaths::resolve`.
    /// Database path override.
    #[arg(long, env = "SQLITE_GRAPHRAG_DB_PATH")]
    pub db: Option<String>,
    // NOTE(review): not read by `run` in this file; presumably consumed by the
    // daemon/embedding layer — confirm.
    #[command(flatten)]
    pub daemon: crate::cli::DaemonOpts,
}
44
/// One NDJSON input line, deserialized by `process_line`.
///
/// Only `name` is required; every other field falls back to its serde default.
/// Unknown keys are ignored (the struct does not use `deny_unknown_fields`).
#[derive(Deserialize)]
struct BatchInputLine {
    /// Memory name; normalized with `crate::parsers::normalize_entity_name` before use.
    name: String,
    /// Memory type; `"note"` when omitted (see `default_type`).
    #[serde(default = "default_type")]
    r#type: String,
    /// Free-form description; empty when omitted.
    #[serde(default)]
    description: String,
    /// Memory content; empty when omitted. It is hashed, embedded and versioned.
    #[serde(default)]
    body: String,
    /// Graph entities to upsert and link to this memory; empty when omitted.
    #[serde(default)]
    entities: Vec<crate::storage::entities::NewEntity>,
    /// Relationships between entities, resolved by normalized source/target name;
    /// empty when omitted.
    #[serde(default)]
    relationships: Vec<crate::storage::entities::NewRelationship>,
}
59
/// Serde default for the `type` field of `BatchInputLine`: lines that omit
/// `"type"` are stored as notes.
fn default_type() -> String {
    String::from("note")
}
63
/// Per-line result record, emitted as one JSON object per processed input line.
///
/// Fields serialize in declaration order; `memory_id` and `error` are omitted
/// from the output when they are `None`.
#[derive(Serialize)]
struct BatchItemEvent {
    /// Normalized memory name on success; for failures, whatever name the emitter
    /// could determine (possibly empty).
    name: String,
    /// `"indexed"` on success, `"failed"` on failure.
    status: String,
    /// Id of the inserted or updated memory; set on success.
    #[serde(skip_serializing_if = "Option::is_none")]
    memory_id: Option<i64>,
    /// Display text of the error; set on failure.
    #[serde(skip_serializing_if = "Option::is_none")]
    error: Option<String>,
    /// Zero-based position of the line among the non-blank input lines.
    index: usize,
}
74
/// Final record, emitted once after the per-line events.
#[derive(Serialize)]
struct BatchSummary {
    /// Always `true`; marks this object as the summary rather than a per-line event.
    summary: bool,
    /// Number of non-blank input lines read from stdin.
    total: usize,
    /// Number of lines counted as succeeded.
    succeeded: usize,
    /// Number of lines that failed. `succeeded + failed` can be smaller than
    /// `total`, e.g. when `--fail-fast` stops processing early.
    failed: usize,
    /// Wall-clock time since the command started, in milliseconds.
    elapsed_ms: u64,
}
83
84pub fn run(args: RememberBatchArgs) -> Result<(), AppError> {
85    let start = std::time::Instant::now();
86    let namespace = crate::namespace::resolve_namespace(args.namespace.as_deref())?;
87    let paths = AppPaths::resolve(args.db.as_deref())?;
88    paths.ensure_dirs()?;
89    crate::storage::connection::ensure_db_ready(&paths)?;
90    let mut conn = open_rw(&paths.db)?;
91
92    let stdin = std::io::stdin();
93    let lines: Vec<String> = stdin
94        .lock()
95        .lines()
96        .map_while(Result::ok)
97        .filter(|l| !l.trim().is_empty())
98        .collect();
99
100    let total = lines.len();
101    let mut succeeded = 0usize;
102    let mut failed = 0usize;
103
104    if args.transaction {
105        let tx = conn.transaction_with_behavior(rusqlite::TransactionBehavior::Immediate)?;
106        for (idx, line) in lines.iter().enumerate() {
107            match process_line(&tx, &namespace, line, idx, args.force_merge, &paths) {
108                Ok(event) => {
109                    output::emit_json(&event)?;
110                    succeeded += 1;
111                }
112                Err(e) => {
113                    failed += 1;
114                    output::emit_json(&BatchItemEvent {
115                        name: String::new(),
116                        status: "failed".to_string(),
117                        memory_id: None,
118                        error: Some(format!("{e}")),
119                        index: idx,
120                    })?;
121                    if args.fail_fast {
122                        break;
123                    }
124                }
125            }
126        }
127        if failed == 0 || !args.fail_fast {
128            tx.commit()?;
129        }
130    } else {
131        for (idx, line) in lines.iter().enumerate() {
132            let tx = conn.transaction_with_behavior(rusqlite::TransactionBehavior::Immediate)?;
133            match process_line(&tx, &namespace, line, idx, args.force_merge, &paths) {
134                Ok(event) => {
135                    tx.commit()?;
136                    output::emit_json(&event)?;
137                    succeeded += 1;
138                }
139                Err(e) => {
140                    drop(tx);
141                    failed += 1;
142                    output::emit_json(&BatchItemEvent {
143                        name: String::new(),
144                        status: "failed".to_string(),
145                        memory_id: None,
146                        error: Some(format!("{e}")),
147                        index: idx,
148                    })?;
149                    if args.fail_fast {
150                        break;
151                    }
152                }
153            }
154        }
155    }
156
157    output::emit_json(&BatchSummary {
158        summary: true,
159        total,
160        succeeded,
161        failed,
162        elapsed_ms: start.elapsed().as_millis() as u64,
163    })?;
164
165    Ok(())
166}
167
168fn process_line(
169    tx: &rusqlite::Transaction<'_>,
170    namespace: &str,
171    line: &str,
172    index: usize,
173    force_merge: bool,
174    paths: &AppPaths,
175) -> Result<BatchItemEvent, AppError> {
176    let input: BatchInputLine = serde_json::from_str(line)
177        .map_err(|e| AppError::Validation(format!("line {index}: invalid JSON: {e}")))?;
178
179    let normalized_name = crate::parsers::normalize_entity_name(&input.name);
180    if normalized_name.is_empty() {
181        return Err(AppError::Validation(format!(
182            "line {index}: name normalizes to empty string"
183        )));
184    }
185
186    let body_hash = blake3::hash(input.body.as_bytes()).to_hex().to_string();
187
188    let existing = memories::find_by_name(tx, namespace, &normalized_name)?;
189
190    let memory_id = if let Some((existing_id, _updated_at, _version)) = existing {
191        if !force_merge {
192            return Err(AppError::Duplicate(format!(
193                "memory '{normalized_name}' already exists; use --force-merge to update"
194            )));
195        }
196        let snippet: String = input.body.chars().take(200).collect();
197        memories::update(
198            tx,
199            existing_id,
200            &memories::NewMemory {
201                namespace: namespace.to_string(),
202                name: normalized_name.clone(),
203                memory_type: input.r#type.clone(),
204                description: input.description.clone(),
205                body: input.body.clone(),
206                body_hash,
207                session_id: None,
208                source: "agent".to_string(),
209                metadata: serde_json::json!({}),
210            },
211            None,
212        )?;
213        let next_v = versions::next_version(tx, existing_id)?;
214        versions::insert_version(
215            tx,
216            existing_id,
217            next_v,
218            &normalized_name,
219            &input.r#type,
220            &input.description,
221            &input.body,
222            "{}",
223            None,
224            "edit",
225        )?;
226
227        let embedding = crate::daemon::embed_passage_or_local(&paths.models, &input.body)?;
228        memories::upsert_vec(
229            tx,
230            existing_id,
231            namespace,
232            &input.r#type,
233            &embedding,
234            &normalized_name,
235            &snippet,
236        )?;
237        existing_id
238    } else {
239        let new_mem = memories::NewMemory {
240            namespace: namespace.to_string(),
241            name: normalized_name.clone(),
242            memory_type: input.r#type.clone(),
243            description: input.description.clone(),
244            body: input.body.clone(),
245            body_hash,
246            session_id: None,
247            source: "agent".to_string(),
248            metadata: serde_json::json!({}),
249        };
250        let id = memories::insert(tx, &new_mem)?;
251        versions::insert_version(
252            tx,
253            id,
254            1,
255            &normalized_name,
256            &input.r#type,
257            &input.description,
258            &input.body,
259            "{}",
260            None,
261            "create",
262        )?;
263
264        let snippet: String = input.body.chars().take(200).collect();
265        let embedding = crate::daemon::embed_passage_or_local(&paths.models, &input.body)?;
266        memories::upsert_vec(
267            tx,
268            id,
269            namespace,
270            &input.r#type,
271            &embedding,
272            &normalized_name,
273            &snippet,
274        )?;
275        id
276    };
277
278    // Persist graph entities and relationships if provided
279    for entity in &input.entities {
280        let entity_id = entities::upsert_entity(tx, namespace, entity)?;
281        let entity_text = match &entity.description {
282            Some(desc) => format!("{} {}", entity.name, desc),
283            None => entity.name.clone(),
284        };
285        let entity_embedding = crate::daemon::embed_passage_or_local(&paths.models, &entity_text)?;
286        entities::upsert_entity_vec(
287            tx,
288            entity_id,
289            namespace,
290            entity.entity_type,
291            &entity_embedding,
292            &entity.name,
293        )?;
294        entities::link_memory_entity(tx, memory_id, entity_id)?;
295    }
296
297    for rel in &input.relationships {
298        let src_name = crate::parsers::normalize_entity_name(&rel.source);
299        let tgt_name = crate::parsers::normalize_entity_name(&rel.target);
300        if let (Some(src_id), Some(tgt_id)) = (
301            entities::find_entity_id(tx, namespace, &src_name)?,
302            entities::find_entity_id(tx, namespace, &tgt_name)?,
303        ) {
304            entities::create_or_fetch_relationship(
305                tx,
306                namespace,
307                src_id,
308                tgt_id,
309                &rel.relation,
310                rel.strength,
311                rel.description.as_deref(),
312            )?;
313        }
314    }
315
316    Ok(BatchItemEvent {
317        name: normalized_name,
318        status: "indexed".to_string(),
319        memory_id: Some(memory_id),
320        error: None,
321        index,
322    })
323}