zeph_memory/graph/
entity_lock.rs1use std::time::Duration;
16
17use tokio::time::sleep;
18use zeph_common::SessionId;
19use zeph_db::{DbPool, query, query_scalar, sql};
20
21use crate::error::MemoryError;
22
/// Lease length, in seconds, granted by every successful acquire or refresh.
const LOCK_TTL_SECS: i64 = 2 * 60;

/// How many extra attempts `try_acquire` makes after the first one fails.
const MAX_RETRIES: u32 = 3;

/// Sleep before the first retry, in milliseconds; doubled for each later retry.
const BASE_BACKOFF_MS: u64 = 50;

32pub struct EntityLockManager {
34 pool: DbPool,
35 session_id: SessionId,
36}
37
38impl EntityLockManager {
39 #[must_use]
44 pub fn new(pool: DbPool, session_id: impl Into<SessionId>) -> Self {
45 Self {
46 pool,
47 session_id: session_id.into(),
48 }
49 }
50
51 pub async fn try_acquire(&self, entity_name: &str) -> Result<bool, MemoryError> {
64 for attempt in 0..=MAX_RETRIES {
65 match self.try_acquire_once(entity_name).await? {
66 true => return Ok(true),
67 false if attempt == MAX_RETRIES => return Ok(false),
68 false => {
69 let backoff_ms = BASE_BACKOFF_MS * (1u64 << attempt);
70 sleep(Duration::from_millis(backoff_ms)).await;
71 }
72 }
73 }
74 Ok(false)
75 }
76
77 async fn try_acquire_once(&self, entity_name: &str) -> Result<bool, MemoryError> {
78 let acquired: bool = query_scalar(sql!(
85 "INSERT INTO entity_advisory_locks (entity_name, session_id, acquired_at, expires_at)
86 VALUES (?, ?, datetime('now'), datetime('now', ? || ' seconds'))
87 ON CONFLICT(entity_name) DO UPDATE SET
88 session_id = excluded.session_id,
89 acquired_at = excluded.acquired_at,
90 expires_at = excluded.expires_at
91 WHERE
92 -- reclaim if expired
93 entity_advisory_locks.expires_at < datetime('now')
94 OR
95 -- refresh if same session
96 entity_advisory_locks.session_id = excluded.session_id
97 RETURNING (session_id = ?) AS acquired"
98 ))
99 .bind(entity_name)
100 .bind(self.session_id.as_str())
101 .bind(LOCK_TTL_SECS.to_string())
102 .bind(self.session_id.as_str())
103 .fetch_optional(self.pool())
104 .await?
105 .unwrap_or(false);
106
107 Ok(acquired)
108 }
109
110 pub async fn extend_lock(
121 &self,
122 entity_name: &str,
123 extra_secs: i64,
124 ) -> Result<bool, MemoryError> {
125 let affected = query(sql!(
126 "UPDATE entity_advisory_locks
127 SET expires_at = datetime(expires_at, ? || ' seconds')
128 WHERE entity_name = ? AND session_id = ?"
129 ))
130 .bind(extra_secs.to_string())
131 .bind(entity_name)
132 .bind(self.session_id.as_str())
133 .execute(self.pool())
134 .await?
135 .rows_affected();
136
137 Ok(affected > 0)
138 }
139
140 pub async fn release(&self, entity_name: &str) -> Result<(), MemoryError> {
148 query(sql!(
149 "DELETE FROM entity_advisory_locks
150 WHERE entity_name = ? AND session_id = ?"
151 ))
152 .bind(entity_name)
153 .bind(self.session_id.as_str())
154 .execute(self.pool())
155 .await?;
156
157 Ok(())
158 }
159
160 pub async fn release_all(&self) -> Result<(), MemoryError> {
168 query(sql!(
169 "DELETE FROM entity_advisory_locks WHERE session_id = ?"
170 ))
171 .bind(self.session_id.as_str())
172 .execute(self.pool())
173 .await?;
174
175 Ok(())
176 }
177
178 fn pool(&self) -> &DbPool {
179 &self.pool
180 }
181}
182
#[cfg(test)]
mod tests {
    use super::*;
    use crate::store::DbStore;

    /// Builds a manager for `session_id` on its own single-connection store.
    async fn make_lock_manager(session_id: &str) -> EntityLockManager {
        let store = DbStore::with_pool_size(":memory:", 1)
            .await
            .expect("in-memory store");
        EntityLockManager::new(store.pool().clone(), session_id)
    }

    /// Builds two managers for different sessions that share one pool.
    async fn make_shared_managers(
        session_a: &str,
        session_b: &str,
    ) -> (EntityLockManager, EntityLockManager) {
        let store = DbStore::with_pool_size(":memory:", 2)
            .await
            .expect("in-memory store");
        let pool = store.pool().clone();
        (
            EntityLockManager::new(pool.clone(), session_a),
            EntityLockManager::new(pool, session_b),
        )
    }

    #[tokio::test]
    async fn try_acquire_succeeds_on_first_call() {
        let mgr = make_lock_manager("session-a").await;
        let acquired = mgr.try_acquire("entity::Foo").await.expect("try_acquire");
        assert!(acquired);
    }

    #[tokio::test]
    async fn try_acquire_same_session_refresh_succeeds() {
        let mgr = make_lock_manager("session-a").await;
        assert!(mgr.try_acquire("entity::Foo").await.expect("first"));
        assert!(mgr.try_acquire("entity::Foo").await.expect("second"));
    }

    #[tokio::test]
    async fn try_acquire_fails_when_held_by_different_session() {
        let (a, b) = make_shared_managers("session-a", "session-b").await;
        assert!(a.try_acquire("entity::Foo").await.expect("a acquires"));
        let acquired = b.try_acquire("entity::Foo").await.expect("b tries");
        assert!(
            !acquired,
            "session-b should not acquire a lock held by session-a"
        );
    }

    #[tokio::test]
    async fn expired_lock_is_reclaimed_by_new_session() {
        let store = DbStore::with_pool_size(":memory:", 2)
            .await
            .expect("in-memory store");
        let pool = store.pool().clone();
        let b = EntityLockManager::new(pool.clone(), "session-b");
        zeph_db::query(zeph_db::sql!(
            "INSERT INTO entity_advisory_locks (entity_name, session_id, acquired_at, expires_at)
             VALUES ('entity::Bar', 'session-a', datetime('now', '-200 seconds'), datetime('now', '-80 seconds'))"
        ))
        .execute(&pool)
        .await
        .expect("insert expired lock");

        let acquired = b.try_acquire("entity::Bar").await.expect("try_acquire");
        assert!(acquired, "session-b should reclaim an expired lock");
    }

    #[tokio::test]
    async fn release_clears_the_lock() {
        let (a, b) = make_shared_managers("session-a", "session-b").await;
        assert!(a.try_acquire("entity::Baz").await.expect("acquire"));
        a.release("entity::Baz").await.expect("release");

        let acquired = b.try_acquire("entity::Baz").await.expect("b reacquire");
        assert!(acquired);
    }

    #[tokio::test]
    async fn release_is_noop_for_wrong_session() {
        let (a, b) = make_shared_managers("session-a", "session-b").await;
        assert!(a.try_acquire("entity::Qux").await.expect("a acquires"));
        b.release("entity::Qux").await.expect("release noop");
        let acquired = b.try_acquire("entity::Qux").await.expect("b tries");
        assert!(!acquired);
    }

    #[tokio::test]
    async fn release_all_removes_all_session_locks() {
        let (a, b) = make_shared_managers("session-a", "session-b").await;
        assert!(a.try_acquire("entity::One").await.expect("one"));
        assert!(a.try_acquire("entity::Two").await.expect("two"));
        a.release_all().await.expect("release_all");

        // Re-acquire from a *different* session: session-a itself would
        // succeed via the same-session refresh path even if nothing had been
        // released, so that would not prove the rows were deleted.
        assert!(b.try_acquire("entity::One").await.expect("b re-one"));
        assert!(b.try_acquire("entity::Two").await.expect("b re-two"));
    }

    #[tokio::test]
    async fn release_all_keeps_other_sessions_locks() {
        let (a, b) = make_shared_managers("session-a", "session-b").await;
        assert!(a.try_acquire("entity::Mine").await.expect("a acquires"));
        b.release_all().await.expect("b release_all");

        let acquired = b.try_acquire("entity::Mine").await.expect("b tries");
        assert!(
            !acquired,
            "release_all must only delete the calling session's locks"
        );
    }

    #[tokio::test]
    async fn extend_lock_returns_true_for_owner() {
        let mgr = make_lock_manager("session-a").await;
        assert!(mgr.try_acquire("entity::Ext").await.expect("acquire"));
        let extended = mgr.extend_lock("entity::Ext", 60).await.expect("extend");
        assert!(extended);
    }

    #[tokio::test]
    async fn extend_lock_returns_false_for_non_owner() {
        let (a, b) = make_shared_managers("session-a", "session-b").await;
        assert!(a.try_acquire("entity::Ext2").await.expect("a acquires"));
        let extended = b.extend_lock("entity::Ext2", 60).await.expect("b extend");
        assert!(
            !extended,
            "non-owner session should not be able to extend lock"
        );
    }
}