1use serde::{Deserialize, Serialize};
2use std::{
3 collections::HashMap,
4 hash::{Hash, Hasher},
5 sync::Arc,
6 time::{Duration, Instant},
7};
8use tokio::sync::RwLock;
9
10#[derive(Debug, Clone)]
11pub struct IdempotencyRequest {
12 pub method: String,
13 pub path: String,
14 pub key: String,
15 pub body: serde_json::Value,
16}
17
18impl IdempotencyRequest {
19 pub fn new(
20 method: impl Into<String>,
21 path: impl Into<String>,
22 key: impl Into<String>,
23 body: serde_json::Value,
24 ) -> Self {
25 Self {
26 method: method.into(),
27 path: path.into(),
28 key: key.into(),
29 body,
30 }
31 }
32
33 fn storage_key(&self) -> String {
34 format!(
35 "{}:{}:{}",
36 self.method.to_ascii_uppercase(),
37 self.path,
38 self.key
39 )
40 }
41
42 fn body_hash(&self) -> u64 {
43 let mut hasher = std::collections::hash_map::DefaultHasher::new();
44 match serde_json::to_vec(&self.body) {
45 Ok(bytes) => bytes.hash(&mut hasher),
46 Err(_) => self.body.to_string().hash(&mut hasher),
47 }
48 hasher.finish()
49 }
50}
51
52#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
53pub struct StoredResponse {
54 pub status_code: u16,
55 pub body: serde_json::Value,
56}
57
58impl StoredResponse {
59 pub fn new(status_code: u16, body: serde_json::Value) -> Self {
60 Self { status_code, body }
61 }
62}
63
64#[derive(Debug, Clone, PartialEq)]
65pub enum LookupResult {
66 Proceed,
67 Replay(StoredResponse),
68 Conflict,
69}
70
71#[derive(Clone)]
72pub struct IdempotencyStore {
73 retention: Duration,
74 records: Arc<RwLock<HashMap<String, Record>>>,
75}
76
77#[derive(Debug, Clone)]
78struct Record {
79 body_hash: u64,
80 response: StoredResponse,
81 inserted_at: Instant,
82}
83
84impl IdempotencyStore {
85 pub fn new(retention: Duration) -> Self {
86 Self {
87 retention,
88 records: Arc::new(RwLock::new(HashMap::new())),
89 }
90 }
91
92 pub async fn lookup(&self, request: &IdempotencyRequest) -> LookupResult {
93 self.prune_expired().await;
94
95 let key = request.storage_key();
96 let body_hash = request.body_hash();
97
98 let guard = self.records.read().await;
99 match guard.get(&key) {
100 None => LookupResult::Proceed,
101 Some(record) if record.body_hash == body_hash => {
102 LookupResult::Replay(record.response.clone())
103 }
104 Some(_) => LookupResult::Conflict,
105 }
106 }
107
108 pub async fn save(&self, request: &IdempotencyRequest, response: StoredResponse) {
109 self.prune_expired().await;
110
111 let mut guard = self.records.write().await;
112 guard.insert(
113 request.storage_key(),
114 Record {
115 body_hash: request.body_hash(),
116 response,
117 inserted_at: Instant::now(),
118 },
119 );
120 }
121
122 async fn prune_expired(&self) {
123 let mut guard = self.records.write().await;
124 let retention = self.retention;
125
126 guard.retain(|_, record| record.inserted_at.elapsed() <= retention);
127 }
128}
129
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Store whose retention window is far longer than any test runs.
    fn long_lived_store() -> IdempotencyStore {
        IdempotencyStore::new(Duration::from_secs(60))
    }

    #[tokio::test]
    async fn returns_proceed_for_first_request_then_replay_after_save() {
        let store = long_lived_store();
        let request =
            IdempotencyRequest::new("POST", "/v1/sessions", "abc", json!({"title":"test"}));

        let first = store.lookup(&request).await;
        assert_eq!(first, LookupResult::Proceed);

        let response = StoredResponse::new(201, json!({"session_id":"s_1"}));
        store.save(&request, response.clone()).await;

        let second = store.lookup(&request).await;
        assert_eq!(second, LookupResult::Replay(response));
    }

    #[tokio::test]
    async fn returns_conflict_for_same_key_with_different_body() {
        let store = long_lived_store();
        let first = IdempotencyRequest::new("POST", "/v1/sessions", "abc", json!({"a":1}));
        let second = IdempotencyRequest::new("POST", "/v1/sessions", "abc", json!({"a":2}));

        store
            .save(&first, StoredResponse::new(200, json!({"ok":true})))
            .await;

        let lookup = store.lookup(&second).await;
        assert_eq!(lookup, LookupResult::Conflict);
    }

    #[tokio::test]
    async fn same_key_on_different_path_is_independent() {
        let store = long_lived_store();
        let first = IdempotencyRequest::new("POST", "/v1/sessions", "abc", json!({"a":1}));
        let second = IdempotencyRequest::new(
            "POST",
            "/v1/sessions/123/cancel",
            "abc",
            json!({"run_id":"r1"}),
        );

        store
            .save(&first, StoredResponse::new(200, json!({"ok":true})))
            .await;

        let lookup = store.lookup(&second).await;
        assert_eq!(lookup, LookupResult::Proceed);
    }

    #[tokio::test]
    async fn records_expire_after_retention_window() {
        let store = IdempotencyStore::new(Duration::from_millis(10));
        let request = IdempotencyRequest::new("POST", "/v1/sessions", "abc", json!({"a":1}));

        store
            .save(&request, StoredResponse::new(200, json!({"ok":true})))
            .await;

        tokio::time::sleep(Duration::from_millis(20)).await;

        let lookup = store.lookup(&request).await;
        assert_eq!(lookup, LookupResult::Proceed);
    }

    #[tokio::test]
    async fn method_is_compared_case_insensitively() {
        let store = long_lived_store();
        let saved = IdempotencyRequest::new("POST", "/v1/sessions", "abc", json!({"a":1}));
        let retried = IdempotencyRequest::new("post", "/v1/sessions", "abc", json!({"a":1}));
        let response = StoredResponse::new(201, json!({"session_id":"s_1"}));

        store.save(&saved, response.clone()).await;

        assert_eq!(
            store.lookup(&retried).await,
            LookupResult::Replay(response)
        );
    }

    #[tokio::test]
    async fn clones_share_the_same_records() {
        let store = long_lived_store();
        let handle = store.clone();
        let request = IdempotencyRequest::new("POST", "/v1/sessions", "abc", json!({"a":1}));
        let response = StoredResponse::new(200, json!({"ok":true}));

        handle.save(&request, response.clone()).await;

        assert_eq!(store.lookup(&request).await, LookupResult::Replay(response));
    }

    #[tokio::test]
    async fn expired_key_with_different_body_proceeds_instead_of_conflicting() {
        let store = IdempotencyStore::new(Duration::from_millis(10));
        let first = IdempotencyRequest::new("POST", "/v1/sessions", "abc", json!({"a":1}));
        let second = IdempotencyRequest::new("POST", "/v1/sessions", "abc", json!({"a":2}));

        store
            .save(&first, StoredResponse::new(200, json!({"ok":true})))
            .await;

        tokio::time::sleep(Duration::from_millis(20)).await;

        assert_eq!(store.lookup(&second).await, LookupResult::Proceed);
    }
}