Skip to main content

tuitbot_core/storage/
mutation_audit.rs

1//! DB-backed mutation audit trail for idempotency and incident review.
2//!
3//! Every mutation-capable MCP tool records an entry before executing.
4//! The table serves dual purposes:
5//!   1. **Idempotency** — detect and short-circuit recent identical mutations.
6//!   2. **Audit** — every mutation attempt is traceable via `correlation_id`.
7
8use sha2::{Digest, Sha256};
9
10use super::accounts::DEFAULT_ACCOUNT_ID;
11use super::DbPool;
12use crate::error::StorageError;
13
14/// An entry in the mutation audit trail.
15#[derive(Debug, Clone, sqlx::FromRow, serde::Serialize)]
16pub struct MutationAuditEntry {
17    pub id: i64,
18    pub correlation_id: String,
19    pub idempotency_key: Option<String>,
20    pub tool_name: String,
21    pub params_hash: String,
22    pub params_summary: String,
23    pub status: String,
24    pub result_summary: Option<String>,
25    pub rollback_action: Option<String>,
26    pub error_message: Option<String>,
27    pub elapsed_ms: Option<i64>,
28    pub account_id: String,
29    pub created_at: String,
30    pub completed_at: Option<String>,
31}
32
33/// Compute a SHA-256 hash of the canonical params JSON.
34pub fn compute_params_hash(tool_name: &str, params_json: &str) -> String {
35    let mut hasher = Sha256::new();
36    hasher.update(tool_name.as_bytes());
37    hasher.update(b"|");
38    hasher.update(params_json.as_bytes());
39    format!("{:x}", hasher.finalize())
40}
41
/// Shorten `json` for display, appending `…` when anything was cut.
///
/// `max_len` is a limit in UTF-8 **bytes**, not characters, and does not
/// count the appended ellipsis. Input that already fits is returned
/// unchanged. If `max_len` falls inside a multi-byte character, the cut moves
/// back to the previous character boundary, so the function never panics and
/// never emits a partial character.
pub fn truncate_summary(json: &str, max_len: usize) -> String {
    if json.len() <= max_len {
        return json.to_string();
    }

    // Slicing a `str` off a char boundary panics, so back up to the nearest
    // boundary at or below `max_len`. Index 0 is always a boundary, which
    // guarantees termination without underflow.
    let mut cut = max_len;
    while !json.is_char_boundary(cut) {
        cut -= 1;
    }

    format!("{}…", &json[..cut])
}
50
51/// Check if a recent successful mutation with the same fingerprint exists.
52///
53/// Returns the cached entry if found within `window_seconds` (default 300 = 5 min).
54pub async fn find_recent_duplicate(
55    pool: &DbPool,
56    tool_name: &str,
57    params_hash: &str,
58    window_seconds: u32,
59) -> Result<Option<MutationAuditEntry>, StorageError> {
60    let entry = sqlx::query_as::<_, MutationAuditEntry>(
61        "SELECT * FROM mutation_audit
62         WHERE tool_name = ? AND params_hash = ? AND status = 'success'
63           AND created_at >= strftime('%Y-%m-%dT%H:%M:%fZ', 'now', '-' || ? || ' seconds')
64         ORDER BY created_at DESC LIMIT 1",
65    )
66    .bind(tool_name)
67    .bind(params_hash)
68    .bind(window_seconds)
69    .fetch_optional(pool)
70    .await
71    .map_err(|e| StorageError::Query { source: e })?;
72
73    Ok(entry)
74}
75
76/// Check if a recent mutation with a specific idempotency key exists.
77pub async fn find_by_idempotency_key(
78    pool: &DbPool,
79    key: &str,
80) -> Result<Option<MutationAuditEntry>, StorageError> {
81    let entry = sqlx::query_as::<_, MutationAuditEntry>(
82        "SELECT * FROM mutation_audit
83         WHERE idempotency_key = ?
84         ORDER BY created_at DESC LIMIT 1",
85    )
86    .bind(key)
87    .fetch_optional(pool)
88    .await
89    .map_err(|e| StorageError::Query { source: e })?;
90
91    Ok(entry)
92}
93
94/// Insert a new pending mutation audit record.
95///
96/// Returns the DB row ID.
97pub async fn insert_pending(
98    pool: &DbPool,
99    correlation_id: &str,
100    idempotency_key: Option<&str>,
101    tool_name: &str,
102    params_hash: &str,
103    params_summary: &str,
104) -> Result<i64, StorageError> {
105    insert_pending_for(
106        pool,
107        DEFAULT_ACCOUNT_ID,
108        correlation_id,
109        idempotency_key,
110        tool_name,
111        params_hash,
112        params_summary,
113    )
114    .await
115}
116
117/// Insert a new pending mutation audit record for a specific account.
118pub async fn insert_pending_for(
119    pool: &DbPool,
120    account_id: &str,
121    correlation_id: &str,
122    idempotency_key: Option<&str>,
123    tool_name: &str,
124    params_hash: &str,
125    params_summary: &str,
126) -> Result<i64, StorageError> {
127    let result = sqlx::query(
128        "INSERT INTO mutation_audit
129            (correlation_id, idempotency_key, tool_name, params_hash, params_summary, account_id)
130         VALUES (?, ?, ?, ?, ?, ?)",
131    )
132    .bind(correlation_id)
133    .bind(idempotency_key)
134    .bind(tool_name)
135    .bind(params_hash)
136    .bind(params_summary)
137    .bind(account_id)
138    .execute(pool)
139    .await
140    .map_err(|e| StorageError::Query { source: e })?;
141
142    Ok(result.last_insert_rowid())
143}
144
145/// Mark a mutation as successfully completed.
146pub async fn complete_success(
147    pool: &DbPool,
148    id: i64,
149    result_summary: &str,
150    rollback_action: Option<&str>,
151    elapsed_ms: u64,
152) -> Result<(), StorageError> {
153    sqlx::query(
154        "UPDATE mutation_audit
155         SET status = 'success',
156             result_summary = ?,
157             rollback_action = ?,
158             elapsed_ms = ?,
159             completed_at = strftime('%Y-%m-%dT%H:%M:%fZ', 'now')
160         WHERE id = ?",
161    )
162    .bind(result_summary)
163    .bind(rollback_action)
164    .bind(elapsed_ms as i64)
165    .bind(id)
166    .execute(pool)
167    .await
168    .map_err(|e| StorageError::Query { source: e })?;
169
170    Ok(())
171}
172
173/// Mark a mutation as failed.
174pub async fn complete_failure(
175    pool: &DbPool,
176    id: i64,
177    error_message: &str,
178    elapsed_ms: u64,
179) -> Result<(), StorageError> {
180    sqlx::query(
181        "UPDATE mutation_audit
182         SET status = 'failure',
183             error_message = ?,
184             elapsed_ms = ?,
185             completed_at = strftime('%Y-%m-%dT%H:%M:%fZ', 'now')
186         WHERE id = ?",
187    )
188    .bind(error_message)
189    .bind(elapsed_ms as i64)
190    .bind(id)
191    .execute(pool)
192    .await
193    .map_err(|e| StorageError::Query { source: e })?;
194
195    Ok(())
196}
197
198/// Mark a mutation as a duplicate (idempotency hit).
199pub async fn mark_duplicate(
200    pool: &DbPool,
201    id: i64,
202    original_correlation_id: &str,
203) -> Result<(), StorageError> {
204    sqlx::query(
205        "UPDATE mutation_audit
206         SET status = 'duplicate',
207             result_summary = ?,
208             completed_at = strftime('%Y-%m-%dT%H:%M:%fZ', 'now')
209         WHERE id = ?",
210    )
211    .bind(format!(
212        "{{\"duplicate_of\":\"{original_correlation_id}\"}}"
213    ))
214    .bind(id)
215    .execute(pool)
216    .await
217    .map_err(|e| StorageError::Query { source: e })?;
218
219    Ok(())
220}
221
222/// Get recent mutations, newest first.
223pub async fn get_recent(
224    pool: &DbPool,
225    limit: u32,
226    tool_name: Option<&str>,
227    status: Option<&str>,
228) -> Result<Vec<MutationAuditEntry>, StorageError> {
229    let mut sql = String::from("SELECT * FROM mutation_audit WHERE 1=1");
230    if tool_name.is_some() {
231        sql.push_str(" AND tool_name = ?");
232    }
233    if status.is_some() {
234        sql.push_str(" AND status = ?");
235    }
236    sql.push_str(" ORDER BY created_at DESC LIMIT ?");
237
238    let mut query = sqlx::query_as::<_, MutationAuditEntry>(&sql);
239    if let Some(t) = tool_name {
240        query = query.bind(t);
241    }
242    if let Some(s) = status {
243        query = query.bind(s);
244    }
245    query = query.bind(limit);
246
247    query
248        .fetch_all(pool)
249        .await
250        .map_err(|e| StorageError::Query { source: e })
251}
252
253/// Get a single mutation by correlation ID.
254pub async fn get_by_correlation_id(
255    pool: &DbPool,
256    correlation_id: &str,
257) -> Result<Option<MutationAuditEntry>, StorageError> {
258    sqlx::query_as::<_, MutationAuditEntry>("SELECT * FROM mutation_audit WHERE correlation_id = ?")
259        .bind(correlation_id)
260        .fetch_optional(pool)
261        .await
262        .map_err(|e| StorageError::Query { source: e })
263}
264
265/// Get mutation counts grouped by status within a time window.
266pub async fn get_status_counts(
267    pool: &DbPool,
268    since_hours: u32,
269) -> Result<Vec<(String, i64)>, StorageError> {
270    sqlx::query_as::<_, (String, i64)>(
271        "SELECT status, COUNT(*) FROM mutation_audit
272         WHERE created_at >= strftime('%Y-%m-%dT%H:%M:%fZ', 'now', '-' || ? || ' hours')
273         GROUP BY status ORDER BY COUNT(*) DESC",
274    )
275    .bind(since_hours)
276    .fetch_all(pool)
277    .await
278    .map_err(|e| StorageError::Query { source: e })
279}
280
#[cfg(test)]
mod tests {
    //! Tests for the mutation audit trail.
    //!
    //! DB-backed tests each create their own database via `init_test_db`,
    //! so they do not share rows. Pure helpers are tested synchronously.

    use super::*;
    use crate::storage::init_test_db;

    /// A pending row finalized with `complete_success` carries the success
    /// status, rollback action, and elapsed time.
    #[tokio::test]
    async fn insert_and_complete_success() {
        let pool = init_test_db().await.expect("init db");

        let id = insert_pending(
            &pool,
            "corr-001",
            None,
            "x_post_tweet",
            "hash123",
            r#"{"text":"hello"}"#,
        )
        .await
        .expect("insert");

        complete_success(
            &pool,
            id,
            r#"{"tweet_id":"999"}"#,
            Some(r#"{"tool":"x_delete_tweet","params":{"tweet_id":"999"}}"#),
            150,
        )
        .await
        .expect("complete");

        let entry = get_by_correlation_id(&pool, "corr-001")
            .await
            .expect("get")
            .expect("found");
        assert_eq!(entry.status, "success");
        assert_eq!(entry.tool_name, "x_post_tweet");
        assert!(entry.rollback_action.is_some());
        assert_eq!(entry.elapsed_ms, Some(150));
    }

    /// A pending row finalized with `complete_failure` records the error.
    #[tokio::test]
    async fn insert_and_complete_failure() {
        let pool = init_test_db().await.expect("init db");

        let id = insert_pending(&pool, "corr-002", None, "x_like_tweet", "hash456", "{}")
            .await
            .expect("insert");

        complete_failure(&pool, id, "Rate limited", 50)
            .await
            .expect("complete");

        let entry = get_by_correlation_id(&pool, "corr-002")
            .await
            .expect("get")
            .expect("found");
        assert_eq!(entry.status, "failure");
        assert_eq!(entry.error_message.as_deref(), Some("Rate limited"));
    }

    /// A just-completed success is found by its tool name and hash.
    #[tokio::test]
    async fn find_recent_duplicate_within_window() {
        let pool = init_test_db().await.expect("init db");
        let hash = compute_params_hash("x_post_tweet", r#"{"text":"hi"}"#);

        let id = insert_pending(&pool, "corr-003", None, "x_post_tweet", &hash, "{}")
            .await
            .expect("insert");

        complete_success(&pool, id, r#"{"tweet_id":"111"}"#, None, 100)
            .await
            .expect("complete");

        let dup = find_recent_duplicate(&pool, "x_post_tweet", &hash, 300)
            .await
            .expect("find")
            .expect("duplicate should be found inside the window");
        assert_eq!(dup.correlation_id, "corr-003");
    }

    /// The same parameters under a different tool are not a duplicate.
    #[tokio::test]
    async fn no_duplicate_for_different_tool() {
        let pool = init_test_db().await.expect("init db");
        let hash = compute_params_hash("x_post_tweet", r#"{"text":"hi"}"#);

        let id = insert_pending(&pool, "corr-004", None, "x_post_tweet", &hash, "{}")
            .await
            .expect("insert");
        complete_success(&pool, id, "{}", None, 50)
            .await
            .expect("complete");

        let other_hash = compute_params_hash("x_like_tweet", r#"{"text":"hi"}"#);
        let dup = find_recent_duplicate(&pool, "x_like_tweet", &other_hash, 300)
            .await
            .expect("find");
        assert!(dup.is_none());
    }

    /// Rows are retrievable by idempotency key; unknown keys yield `None`.
    #[tokio::test]
    async fn idempotency_key_lookup() {
        let pool = init_test_db().await.expect("init db");

        let id = insert_pending(
            &pool,
            "corr-005",
            Some("user-key-abc"),
            "x_post_tweet",
            "hash789",
            "{}",
        )
        .await
        .expect("insert");
        complete_success(&pool, id, r#"{"tweet_id":"222"}"#, None, 75)
            .await
            .expect("complete");

        let found = find_by_idempotency_key(&pool, "user-key-abc")
            .await
            .expect("find")
            .expect("found");
        assert_eq!(found.correlation_id, "corr-005");

        let not_found = find_by_idempotency_key(&pool, "nonexistent")
            .await
            .expect("find");
        assert!(not_found.is_none());
    }

    /// `get_recent` honors the optional tool-name and status filters.
    #[tokio::test]
    async fn get_recent_with_filters() {
        let pool = init_test_db().await.expect("init db");

        for (tool, status_val) in [
            ("x_post_tweet", "success"),
            ("x_like_tweet", "success"),
            ("x_post_tweet", "failure"),
        ] {
            let id = insert_pending(
                &pool,
                &format!("c-{tool}-{status_val}"),
                None,
                tool,
                "h",
                "{}",
            )
            .await
            .expect("insert");
            if status_val == "success" {
                complete_success(&pool, id, "{}", None, 10)
                    .await
                    .expect("ok");
            } else {
                complete_failure(&pool, id, "err", 10).await.expect("ok");
            }
        }

        let all = get_recent(&pool, 10, None, None).await.expect("all");
        assert_eq!(all.len(), 3);

        let tweets = get_recent(&pool, 10, Some("x_post_tweet"), None)
            .await
            .expect("tweets");
        assert_eq!(tweets.len(), 2);

        let successes = get_recent(&pool, 10, None, Some("success"))
            .await
            .expect("successes");
        assert_eq!(successes.len(), 2);
    }

    /// `mark_duplicate` stores the original correlation ID in the summary.
    #[tokio::test]
    async fn mark_duplicate_records_original() {
        let pool = init_test_db().await.expect("init db");

        let id = insert_pending(&pool, "corr-dup", None, "x_post_tweet", "h", "{}")
            .await
            .expect("insert");

        mark_duplicate(&pool, id, "corr-original")
            .await
            .expect("mark");

        let entry = get_by_correlation_id(&pool, "corr-dup")
            .await
            .expect("get")
            .expect("found");
        assert_eq!(entry.status, "duplicate");
        let summary = entry
            .result_summary
            .as_deref()
            .expect("duplicate rows carry a result summary");
        assert!(summary.contains("corr-original"));
    }

    /// Every status written by this module is counted in its own bucket.
    #[tokio::test]
    async fn status_counts_aggregation() {
        let pool = init_test_db().await.expect("init db");

        for (i, status_val) in ["success", "success", "failure", "duplicate"]
            .iter()
            .enumerate()
        {
            let id = insert_pending(&pool, &format!("c-{i}"), None, "tool", "h", "{}")
                .await
                .expect("insert");
            match *status_val {
                "success" => {
                    complete_success(&pool, id, "{}", None, 10)
                        .await
                        .expect("ok");
                }
                "failure" => {
                    complete_failure(&pool, id, "err", 10).await.expect("ok");
                }
                "duplicate" => {
                    mark_duplicate(&pool, id, "other").await.expect("ok");
                }
                _ => {}
            }
        }

        let counts = get_status_counts(&pool, 24).await.expect("counts");
        let count_of = |wanted: &str| {
            counts
                .iter()
                .find(|(status, _)| status == wanted)
                .map(|(_, count)| *count)
                .unwrap_or(0)
        };
        assert_eq!(count_of("success"), 2);
        assert_eq!(count_of("failure"), 1);
        assert_eq!(count_of("duplicate"), 1);
    }

    /// Hashing is pure, so no async runtime is needed here.
    #[test]
    fn params_hash_deterministic() {
        let h1 = compute_params_hash("x_post_tweet", r#"{"text":"hello"}"#);
        let h2 = compute_params_hash("x_post_tweet", r#"{"text":"hello"}"#);
        assert_eq!(h1, h2);
        assert_eq!(h1.len(), 64); // SHA-256 hex

        let h3 = compute_params_hash("x_post_tweet", r#"{"text":"world"}"#);
        assert_ne!(h1, h3);
    }

    /// The tool name is part of the fingerprint.
    #[test]
    fn params_hash_depends_on_tool_name() {
        let params = r#"{"text":"hello"}"#;
        assert_ne!(
            compute_params_hash("x_post_tweet", params),
            compute_params_hash("x_like_tweet", params),
        );
    }

    /// Input within the limit is returned unchanged.
    #[test]
    fn truncate_summary_keeps_short_input() {
        assert_eq!(truncate_summary("{}", 500), "{}");
        assert_eq!(truncate_summary("abc", 3), "abc");
    }

    /// Over-long ASCII input is cut at the limit and marked with an ellipsis.
    #[test]
    fn truncate_summary_cuts_long_ascii_input() {
        assert_eq!(truncate_summary("abcdef", 3), "abc…");
    }
}