Skip to main content

stakpak_server/
idempotency.rs

1use serde::{Deserialize, Serialize};
2use std::{
3    collections::HashMap,
4    hash::{Hash, Hasher},
5    sync::Arc,
6    time::{Duration, Instant},
7};
8use tokio::sync::RwLock;
9
10#[derive(Debug, Clone)]
11pub struct IdempotencyRequest {
12    pub method: String,
13    pub path: String,
14    pub key: String,
15    pub body: serde_json::Value,
16}
17
18impl IdempotencyRequest {
19    pub fn new(
20        method: impl Into<String>,
21        path: impl Into<String>,
22        key: impl Into<String>,
23        body: serde_json::Value,
24    ) -> Self {
25        Self {
26            method: method.into(),
27            path: path.into(),
28            key: key.into(),
29            body,
30        }
31    }
32
33    fn storage_key(&self) -> String {
34        format!(
35            "{}:{}:{}",
36            self.method.to_ascii_uppercase(),
37            self.path,
38            self.key
39        )
40    }
41
42    fn body_hash(&self) -> u64 {
43        let mut hasher = std::collections::hash_map::DefaultHasher::new();
44        match serde_json::to_vec(&self.body) {
45            Ok(bytes) => bytes.hash(&mut hasher),
46            Err(_) => self.body.to_string().hash(&mut hasher),
47        }
48        hasher.finish()
49    }
50}
51
52#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
53pub struct StoredResponse {
54    pub status_code: u16,
55    pub body: serde_json::Value,
56}
57
58impl StoredResponse {
59    pub fn new(status_code: u16, body: serde_json::Value) -> Self {
60        Self { status_code, body }
61    }
62}
63
64#[derive(Debug, Clone, PartialEq)]
65pub enum LookupResult {
66    Proceed,
67    Replay(StoredResponse),
68    Conflict,
69}
70
71#[derive(Clone)]
72pub struct IdempotencyStore {
73    retention: Duration,
74    records: Arc<RwLock<HashMap<String, Record>>>,
75}
76
77#[derive(Debug, Clone)]
78struct Record {
79    body_hash: u64,
80    response: StoredResponse,
81    inserted_at: Instant,
82}
83
84impl IdempotencyStore {
85    pub fn new(retention: Duration) -> Self {
86        Self {
87            retention,
88            records: Arc::new(RwLock::new(HashMap::new())),
89        }
90    }
91
92    pub async fn lookup(&self, request: &IdempotencyRequest) -> LookupResult {
93        self.prune_expired().await;
94
95        let key = request.storage_key();
96        let body_hash = request.body_hash();
97
98        let guard = self.records.read().await;
99        match guard.get(&key) {
100            None => LookupResult::Proceed,
101            Some(record) if record.body_hash == body_hash => {
102                LookupResult::Replay(record.response.clone())
103            }
104            Some(_) => LookupResult::Conflict,
105        }
106    }
107
108    pub async fn save(&self, request: &IdempotencyRequest, response: StoredResponse) {
109        self.prune_expired().await;
110
111        let mut guard = self.records.write().await;
112        guard.insert(
113            request.storage_key(),
114            Record {
115                body_hash: request.body_hash(),
116                response,
117                inserted_at: Instant::now(),
118            },
119        );
120    }
121
122    async fn prune_expired(&self) {
123        let mut guard = self.records.write().await;
124        let retention = self.retention;
125
126        guard.retain(|_, record| record.inserted_at.elapsed() <= retention);
127    }
128}
129
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Path shared by most test requests.
    const SESSIONS_PATH: &str = "/v1/sessions";

    /// Store with a retention long enough that nothing expires mid-test.
    fn long_lived_store() -> IdempotencyStore {
        IdempotencyStore::new(Duration::from_secs(60))
    }

    /// `POST` request to `path` with the fixed idempotency key `abc`.
    fn post(path: &str, body: serde_json::Value) -> IdempotencyRequest {
        IdempotencyRequest::new("POST", path, "abc", body)
    }

    /// Canned `200 {"ok": true}` response.
    fn ok_response() -> StoredResponse {
        StoredResponse::new(200, json!({"ok":true}))
    }

    #[tokio::test]
    async fn returns_proceed_for_first_request_then_replay_after_save() {
        let store = long_lived_store();
        let request = post(SESSIONS_PATH, json!({"title":"test"}));

        // Nothing saved yet, so the request is new.
        assert_eq!(store.lookup(&request).await, LookupResult::Proceed);

        let created = StoredResponse::new(201, json!({"session_id":"s_1"}));
        store.save(&request, created.clone()).await;

        // The identical request now gets the saved response back.
        assert_eq!(store.lookup(&request).await, LookupResult::Replay(created));
    }

    #[tokio::test]
    async fn returns_conflict_for_same_key_with_different_body() {
        let store = long_lived_store();
        let original = post(SESSIONS_PATH, json!({"a":1}));
        let altered = post(SESSIONS_PATH, json!({"a":2}));

        store.save(&original, ok_response()).await;

        assert_eq!(store.lookup(&altered).await, LookupResult::Conflict);
    }

    #[tokio::test]
    async fn same_key_on_different_path_is_independent() {
        let store = long_lived_store();
        let create = post(SESSIONS_PATH, json!({"a":1}));
        let cancel = post("/v1/sessions/123/cancel", json!({"run_id":"r1"}));

        store.save(&create, ok_response()).await;

        assert_eq!(store.lookup(&cancel).await, LookupResult::Proceed);
    }

    #[tokio::test]
    async fn records_expire_after_retention_window() {
        let store = IdempotencyStore::new(Duration::from_millis(10));
        let request = post(SESSIONS_PATH, json!({"a":1}));

        store.save(&request, ok_response()).await;

        // Wait past the 10 ms retention window before looking up again.
        tokio::time::sleep(Duration::from_millis(20)).await;

        assert_eq!(store.lookup(&request).await, LookupResult::Proceed);
    }
}