Skip to main content

sqlite_graphrag/commands/
restore.rs

1//! Handler for the `restore` CLI subcommand.
2
3use crate::errors::AppError;
4use crate::i18n::errors_msg;
5use crate::output;
6use crate::output::JsonOutputFormat;
7use crate::paths::AppPaths;
8use crate::storage::connection::open_rw;
9use crate::storage::memories;
10use crate::storage::versions;
11use rusqlite::params;
12use rusqlite::OptionalExtension;
13use serde::Serialize;
14
// Command-line arguments accepted by the `restore` subcommand (consumed by `run`).
//
// NOTE(review): with clap's derive API the `///` doc comments on the fields below are
// presumably picked up as help text, so they are deliberately left untouched here;
// explanatory notes are written as plain `//` comments instead.
#[derive(clap::Args)]
#[command(after_long_help = "EXAMPLES:\n  \
    # Restore the latest non-`restore` version of a memory\n  \
    sqlite-graphrag restore --name onboarding\n\n  \
    # Restore a specific version\n  \
    sqlite-graphrag restore --name onboarding --version 3\n\n  \
    # Restore within a specific namespace\n  \
    sqlite-graphrag restore --name onboarding --namespace my-project")]
pub struct RestoreArgs {
    // Positional form of the memory name. `run` prefers this value and falls back to
    // `name`; `conflicts_with = "name"` stops both from being supplied together.
    /// Memory name as a positional argument. Alternative to `--name`.
    #[arg(
        value_name = "NAME",
        conflicts_with = "name",
        help = "Memory name to restore; alternative to --name"
    )]
    pub name_positional: Option<String>,
    // Flag form of the memory name. `run` returns a validation error when neither this
    // nor the positional form is present.
    /// Memory name to restore (must exist, including soft-deleted/forgotten).
    #[arg(long)]
    pub name: Option<String>,
    /// Version to restore. When omitted, defaults to the latest non-`restore` version
    /// from `memory_versions`. This makes the forget+restore workflow work without
    /// requiring the user to discover the version first.
    #[arg(long)]
    pub version: Option<i64>,
    // Optional namespace override; `run` hands it to `crate::namespace::resolve_namespace`.
    #[arg(
        long,
        help = "Namespace (env: SQLITE_GRAPHRAG_NAMESPACE, default: global)"
    )]
    pub namespace: Option<String>,
    // Parsed by `crate::parsers::parse_expected_updated_at` into an `i64` that `run`
    // compares against the row's `updated_at` column.
    /// Optimistic locking: reject if the current updated_at does not match (exit 3).
    #[arg(
        long,
        value_name = "EPOCH_OR_RFC3339",
        value_parser = crate::parsers::parse_expected_updated_at,
        long_help = "Optimistic lock: reject if updated_at does not match. \
Accepts Unix epoch (e.g. 1700000000) or RFC 3339 (e.g. 2026-04-19T12:00:00Z)."
    )]
    pub expected_updated_at: Option<i64>,
    // Accepted on the command line, but `run` currently reads and discards it
    // (`let _ = args.format;`).
    /// Output format.
    #[arg(long, value_enum, default_value_t = JsonOutputFormat::Json)]
    pub format: JsonOutputFormat,
    // Hidden flag; `run` never reads it.
    #[arg(long, hide = true, help = "No-op; JSON is always emitted on stdout")]
    pub json: bool,
    // Optional database path, also settable through the SQLITE_GRAPHRAG_DB_PATH
    // environment variable; `run` passes it to `AppPaths::resolve`.
    #[arg(long, env = "SQLITE_GRAPHRAG_DB_PATH")]
    pub db: Option<String>,
}
61
/// JSON payload that `run` emits on success via `output::emit_json`.
///
/// Field order is the order in which the keys are serialized.
#[derive(Serialize)]
struct RestoreResponse {
    /// Always `"restored"` — signals the completed action to shell callers and LLM agents.
    action: String,
    /// `id` of the restored row in the `memories` table.
    memory_id: i64,
    /// Name currently stored on the `memories` row (not the name saved in the old version).
    name: String,
    /// Number of the new `memory_versions` entry written by this restore.
    version: i64,
    /// Number of the version whose content was copied back into the memory.
    restored_from: i64,
    /// Total execution time in milliseconds from handler start to serialisation.
    elapsed_ms: u64,
}
73
74pub fn run(
75    args: RestoreArgs,
76    llm_backend: crate::cli::LlmBackendChoice,
77    embedding_backend: crate::cli::EmbeddingBackendChoice,
78) -> Result<(), AppError> {
79    let start = std::time::Instant::now();
80    let _ = args.format;
81    tracing::debug!(target: "restore", name = ?args.name_positional.as_deref().or(args.name.as_deref()), version = ?args.version, "restoring version");
82    let name = args
83        .name_positional
84        .as_deref()
85        .or(args.name.as_deref())
86        .ok_or_else(|| {
87            AppError::Validation(
88                "name required: pass as positional argument or via --name".to_string(),
89            )
90        })?
91        .to_string();
92    let namespace = crate::namespace::resolve_namespace(args.namespace.as_deref())?;
93    let paths = AppPaths::resolve(args.db.as_deref())?;
94    let mut conn = open_rw(&paths.db)?;
95
96    // PRD line 1118: query WITHOUT a deleted_at filter — restore must work on soft-deleted memories
97    let result: Option<(i64, i64)> = conn
98        .query_row(
99            "SELECT id, updated_at FROM memories WHERE namespace = ?1 AND name = ?2",
100            params![namespace, name],
101            |r| Ok((r.get(0)?, r.get(1)?)),
102        )
103        .optional()?;
104    let (memory_id, current_updated_at) = result
105        .ok_or_else(|| AppError::NotFound(errors_msg::memory_not_found(&name, &namespace)))?;
106
107    if let Some(expected) = args.expected_updated_at {
108        if expected != current_updated_at {
109            return Err(AppError::Conflict(errors_msg::optimistic_lock_conflict(
110                expected,
111                current_updated_at,
112            )));
113        }
114    }
115
116    // v1.0.22 P0: resolve optional `--version`. When absent, uses the highest version
117    // whose `change_reason` is not 'restore' (recovers the real state, not meta-restore).
118    // Lets the forget+restore workflow function without manually reading memory_versions.
119    let target_version: i64 = match args.version {
120        Some(v) => v,
121        None => {
122            let last: Option<i64> = conn
123                .query_row(
124                    "SELECT MAX(version) FROM memory_versions
125                     WHERE memory_id = ?1 AND change_reason != 'restore'",
126                    params![memory_id],
127                    |r| r.get(0),
128                )
129                .optional()?
130                .flatten();
131            let v = last.ok_or_else(|| {
132                AppError::NotFound(errors_msg::memory_not_found(&name, &namespace))
133            })?;
134            tracing::info!(target: "restore",
135                "restore --version omitted; using latest non-restore version: {}",
136                v
137            );
138            v
139        }
140    };
141
142    let version_row: (String, String, String, String, String) = {
143        let mut stmt = conn.prepare_cached(
144            "SELECT name, type, description, body, metadata
145             FROM memory_versions
146             WHERE memory_id = ?1 AND version = ?2",
147        )?;
148
149        stmt.query_row(params![memory_id, target_version], |r| {
150            Ok((r.get(0)?, r.get(1)?, r.get(2)?, r.get(3)?, r.get(4)?))
151        })
152        .map_err(|_| AppError::NotFound(errors_msg::version_not_found(target_version, &name)))?
153    };
154
155    let (_old_name, old_type, old_description, old_body, old_metadata) = version_row;
156
157    // Read current FTS-indexed values before the UPDATE so sync_fts_after_update
158    // can issue the correct DELETE command for the external-content FTS5 table.
159    let (cur_name, cur_desc, cur_body): (String, String, String) = conn.query_row(
160        "SELECT name, description, body FROM memories WHERE id = ?1",
161        params![memory_id],
162        |r| Ok((r.get(0)?, r.get(1)?, r.get(2)?)),
163    )?;
164
165    // v1.0.21 P1-D: re-embed restored body to keep `vec_memories` synchronized
166    // with `memories`. Without this, semantic queries used the post-forget version
167    // vector, causing inconsistent recall (vec_memories=2 vs memories=3 after forget+restore).
168    output::emit_progress_i18n(
169        "Re-computing embedding for restored memory...",
170        crate::i18n::validation::runtime_pt::restore_recomputing_embedding(),
171    );
172    let skip_embed = crate::embedder::should_skip_embedding_on_failure();
173    let embedding: Option<Vec<f32>> = match crate::embedder::embed_passage_with_embedding_choice(
174        &paths.models,
175        &old_body,
176        embedding_backend,
177        llm_backend,
178    ) {
179        Ok((emb, _backend)) => Some(emb),
180        Err(AppError::Validation(msg)) => return Err(AppError::Validation(msg)),
181        Err(e) if skip_embed => {
182            tracing::warn!(error = %e, "restore: embedding failed; --skip-embedding-on-failure active, persisting without embedding");
183            None
184        }
185        Err(e) => return Err(e),
186    };
187    let snippet: String = old_body.chars().take(300).collect();
188
189    let tx = conn.transaction_with_behavior(rusqlite::TransactionBehavior::Immediate)?;
190
191    // deleted_at = NULL reactivates soft-deleted memories; no deleted_at filter in the WHERE
192    let affected = if let Some(ts) = args.expected_updated_at {
193        tx.execute(
194            "UPDATE memories SET type=?2, description=?3, body=?4, body_hash=?5, deleted_at=NULL
195             WHERE id=?1 AND updated_at=?6",
196            rusqlite::params![
197                memory_id,
198                old_type,
199                old_description,
200                old_body,
201                blake3::hash(old_body.as_bytes()).to_hex().to_string(),
202                ts
203            ],
204        )?
205    } else {
206        tx.execute(
207            "UPDATE memories SET type=?2, description=?3, body=?4, body_hash=?5, deleted_at=NULL
208             WHERE id=?1",
209            rusqlite::params![
210                memory_id,
211                old_type,
212                old_description,
213                old_body,
214                blake3::hash(old_body.as_bytes()).to_hex().to_string()
215            ],
216        )?
217    };
218
219    if affected == 0 {
220        return Err(AppError::Conflict(errors_msg::concurrent_process_conflict()));
221    }
222
223    let next_v = versions::next_version(&tx, memory_id)?;
224
225    versions::insert_version(
226        &tx,
227        memory_id,
228        next_v,
229        &cur_name,
230        &old_type,
231        &old_description,
232        &old_body,
233        &old_metadata,
234        None,
235        "restore",
236    )?;
237
238    if let Some(ref emb) = embedding {
239        memories::upsert_vec(
240            &tx, memory_id, &namespace, &old_type, emb, &cur_name, &snippet,
241        )?;
242    }
243
244    memories::sync_fts_after_update(
245        &tx,
246        memory_id,
247        &cur_name,
248        &cur_desc,
249        &cur_body,
250        &cur_name,
251        &old_description,
252        &old_body,
253    )?;
254
255    tx.commit()?;
256
257    conn.execute_batch("PRAGMA wal_checkpoint(TRUNCATE);")?;
258
259    output::emit_json(&RestoreResponse {
260        action: "restored".to_string(),
261        memory_id,
262        name: cur_name.clone(),
263        version: next_v,
264        restored_from: target_version,
265        elapsed_ms: start.elapsed().as_millis() as u64,
266    })?;
267
268    Ok(())
269}
270
#[cfg(test)]
mod tests {
    use super::RestoreResponse;
    use crate::errors::AppError;

    /// A `Conflict` error must map to exit code 3 and mention "conflict" when displayed.
    #[test]
    fn optimistic_lock_conflict_returns_exit_3() {
        let message =
            "optimistic lock conflict: expected updated_at=50, but current is 99".to_string();
        let conflict = AppError::Conflict(message);

        assert_eq!(conflict.exit_code(), 3);

        let rendered = conflict.to_string();
        assert!(rendered.contains("conflict"));
    }

    /// The serialized response exposes `action` together with the numeric identifiers.
    #[test]
    fn restore_response_includes_action_field() {
        let payload = RestoreResponse {
            action: "restored".to_string(),
            memory_id: 1,
            name: "test-mem".to_string(),
            version: 3,
            restored_from: 2,
            elapsed_ms: 42,
        };

        let value = serde_json::to_value(&payload).expect("serialization failed");

        assert_eq!(value["action"], "restored");
        // Numeric fields are checked in the same order as the struct declares them.
        for (key, expected) in [("memory_id", 1), ("version", 3), ("restored_from", 2)] {
            assert_eq!(value[key], expected);
        }
    }
}