1use crate::{
2 adapter::AdapterKind,
3 error::{DataError, DataResult},
4 query::Query,
5 repo::{MemoryRepo, Repo, Row, StoredRow},
6};
7use std::{
8 future::Future,
9 pin::Pin,
10 sync::{
11 atomic::{AtomicBool, Ordering},
12 Arc, Mutex,
13 },
14 time::{Duration, Instant},
15};
16
17pub type AsyncRepoFuture<'a, T> = Pin<Box<dyn Future<Output = DataResult<T>> + Send + 'a>>;
18
/// Cooperative cancellation flag shared between a caller and in-flight async queries.
///
/// The flag lives behind an `Arc`, so every clone of a token observes the same state:
/// cancelling (or resetting) through any clone is visible through all of them. A
/// default-constructed token starts out not cancelled.
#[derive(Debug, Clone, Default)]
pub struct AsyncCancellationToken {
    cancelled: Arc<AtomicBool>,
}

impl AsyncCancellationToken {
    /// Marks this token, and every clone of it, as cancelled.
    pub fn cancel(&self) {
        self.set_cancelled(true);
    }

    /// Clears the cancelled flag for this token and every clone of it.
    pub fn reset(&self) {
        self.set_cancelled(false);
    }

    /// Returns `true` if [`cancel`](Self::cancel) was called and the flag has not been reset.
    pub fn is_cancelled(&self) -> bool {
        let flag: &AtomicBool = &self.cancelled;
        flag.load(Ordering::SeqCst)
    }

    /// Writes `value` into the shared flag using sequentially consistent ordering.
    fn set_cancelled(&self, value: bool) {
        self.cancelled.store(value, Ordering::SeqCst);
    }
}
37
38#[derive(Debug, Clone, Default)]
39pub struct AsyncQueryContext {
40 pub deadline: Option<Instant>,
41 pub cancellation: AsyncCancellationToken,
42}
43
44impl AsyncQueryContext {
45 pub fn with_timeout_ms(mut self, timeout_ms: u64) -> Self {
46 self.deadline = Some(Instant::now() + Duration::from_millis(timeout_ms.max(1)));
47 self
48 }
49
50 pub fn with_deadline(mut self, deadline: Instant) -> Self {
51 self.deadline = Some(deadline);
52 self
53 }
54
55 pub fn with_cancellation_token(mut self, token: AsyncCancellationToken) -> Self {
56 self.cancellation = token;
57 self
58 }
59
60 pub fn is_cancelled(&self) -> bool {
61 self.cancellation.is_cancelled()
62 }
63
64 pub fn is_deadline_exceeded(&self) -> bool {
65 self.deadline
66 .map(|deadline| Instant::now() > deadline)
67 .unwrap_or(false)
68 }
69
70 pub fn ensure_active(&self) -> DataResult<()> {
71 if self.is_cancelled() {
72 return Err(DataError::Query(
73 "async query cancelled by cancellation token".to_string(),
74 ));
75 }
76 if self.is_deadline_exceeded() {
77 return Err(DataError::Query(
78 "async query deadline exceeded before completion".to_string(),
79 ));
80 }
81 Ok(())
82 }
83}
84
/// Async CRUD interface over a table-oriented store.
///
/// Every operation receives an [`AsyncQueryContext`], which carries an optional deadline
/// and a cancellation token. Each operation returns an [`AsyncRepoFuture`] that may borrow
/// `self`, the context, and any borrowed arguments for `'a`. The methods use only lifetime
/// generics and `&self` receivers, so the trait can be used as `dyn AsyncRepo`.
///
/// NOTE(review): the trait cannot force implementations to honor the context. Implementors
/// are expected to check it, for example via [`AsyncQueryContext::ensure_active`].
///
/// The `Send + Sync` supertraits allow a repo to be shared across tasks and threads.
pub trait AsyncRepo: Send + Sync {
    /// Returns the [`AdapterKind`] this repository reports.
    fn adapter_kind(&self) -> AdapterKind;

    /// Inserts `data` as a new row in `table` and resolves to the stored row.
    fn insert<'a>(
        &'a self,
        context: &'a AsyncQueryContext,
        table: &'a str,
        data: Row,
    ) -> AsyncRepoFuture<'a, StoredRow>;

    /// Applies `data` to the row `id` in `table` and resolves to the resulting row.
    ///
    /// NOTE(review): whether this merges into or replaces the existing row, and how a
    /// missing `id` is reported, are defined by the implementation.
    fn update<'a>(
        &'a self,
        context: &'a AsyncQueryContext,
        table: &'a str,
        id: u64,
        data: Row,
    ) -> AsyncRepoFuture<'a, StoredRow>;

    /// Removes the row `id` from `table`.
    fn delete<'a>(
        &'a self,
        context: &'a AsyncQueryContext,
        table: &'a str,
        id: u64,
    ) -> AsyncRepoFuture<'a, ()>;

    /// Looks up the row `id` in `table`.
    ///
    /// Presumably resolves to `Ok(None)` when no such row exists.
    fn find<'a>(
        &'a self,
        context: &'a AsyncQueryContext,
        table: &'a str,
        id: u64,
    ) -> AsyncRepoFuture<'a, Option<StoredRow>>;

    /// Returns the rows of `table` selected by `query`.
    fn list<'a>(
        &'a self,
        context: &'a AsyncQueryContext,
        table: &'a str,
        query: &'a Query,
    ) -> AsyncRepoFuture<'a, Vec<StoredRow>>;
}
124
125#[derive(Clone)]
126pub struct AsyncMemoryRepo {
127 inner: Arc<Mutex<MemoryRepo>>,
128 adapter_kind: AdapterKind,
129}
130
131impl AsyncMemoryRepo {
132 pub fn new(repo: MemoryRepo) -> Self {
133 let adapter_kind = repo.adapter_kind();
134 Self {
135 inner: Arc::new(Mutex::new(repo)),
136 adapter_kind,
137 }
138 }
139}
140
141impl AsyncRepo for AsyncMemoryRepo {
142 fn adapter_kind(&self) -> AdapterKind {
143 self.adapter_kind
144 }
145
146 fn insert<'a>(
147 &'a self,
148 context: &'a AsyncQueryContext,
149 table: &'a str,
150 data: Row,
151 ) -> AsyncRepoFuture<'a, StoredRow> {
152 Box::pin(async move {
153 context.ensure_active()?;
154 let result = self
155 .inner
156 .lock()
157 .map_err(|_| DataError::Query("async memory repo lock poisoned".to_string()))?
158 .insert(table, data);
159 context.ensure_active()?;
160 result
161 })
162 }
163
164 fn update<'a>(
165 &'a self,
166 context: &'a AsyncQueryContext,
167 table: &'a str,
168 id: u64,
169 data: Row,
170 ) -> AsyncRepoFuture<'a, StoredRow> {
171 Box::pin(async move {
172 context.ensure_active()?;
173 let result = self
174 .inner
175 .lock()
176 .map_err(|_| DataError::Query("async memory repo lock poisoned".to_string()))?
177 .update(table, id, data);
178 context.ensure_active()?;
179 result
180 })
181 }
182
183 fn delete<'a>(
184 &'a self,
185 context: &'a AsyncQueryContext,
186 table: &'a str,
187 id: u64,
188 ) -> AsyncRepoFuture<'a, ()> {
189 Box::pin(async move {
190 context.ensure_active()?;
191 let result = self
192 .inner
193 .lock()
194 .map_err(|_| DataError::Query("async memory repo lock poisoned".to_string()))?
195 .delete(table, id);
196 context.ensure_active()?;
197 result
198 })
199 }
200
201 fn find<'a>(
202 &'a self,
203 context: &'a AsyncQueryContext,
204 table: &'a str,
205 id: u64,
206 ) -> AsyncRepoFuture<'a, Option<StoredRow>> {
207 Box::pin(async move {
208 context.ensure_active()?;
209 let result = self
210 .inner
211 .lock()
212 .map_err(|_| DataError::Query("async memory repo lock poisoned".to_string()))?
213 .find(table, id);
214 context.ensure_active()?;
215 result
216 })
217 }
218
219 fn list<'a>(
220 &'a self,
221 context: &'a AsyncQueryContext,
222 table: &'a str,
223 query: &'a Query,
224 ) -> AsyncRepoFuture<'a, Vec<StoredRow>> {
225 Box::pin(async move {
226 context.ensure_active()?;
227 let result = self
228 .inner
229 .lock()
230 .map_err(|_| DataError::Query("async memory repo lock poisoned".to_string()))?
231 .list(table, query);
232 context.ensure_active()?;
233 result
234 })
235 }
236}
237
#[cfg(test)]
mod tests {
    use super::{AsyncCancellationToken, AsyncMemoryRepo, AsyncQueryContext, AsyncRepo};
    use crate::{adapter_for, AdapterKind, DatabaseConfig, MemoryRepo, Query, Row};
    use serde_json::json;
    use std::{
        panic::{catch_unwind, AssertUnwindSafe},
        time::{Duration, Instant},
    };

    /// Builds an empty [`AsyncMemoryRepo`] backed by the driver for `adapter`.
    fn async_repo_for(adapter: AdapterKind) -> AsyncMemoryRepo {
        let driver = adapter_for(&DatabaseConfig {
            adapter,
            url: None,
            url_env: None,
        })
        .expect("driver");
        AsyncMemoryRepo::new(MemoryRepo::new(driver))
    }

    #[tokio::test]
    async fn async_memory_repo_supports_insert_and_list() {
        let async_repo = async_repo_for(AdapterKind::SingleStore);
        let mut row = Row::new();
        row.insert("account".to_string(), json!("Acme"));
        let context = AsyncQueryContext::default().with_timeout_ms(500);
        async_repo
            .insert(&context, "accounts", row)
            .await
            .expect("insert");
        let rows = async_repo
            .list(&context, "accounts", &Query::new())
            .await
            .expect("list");
        assert_eq!(rows.len(), 1);
    }

    #[tokio::test]
    async fn async_query_context_cancellation_stops_query() {
        let async_repo = async_repo_for(AdapterKind::Postgres);
        let token = AsyncCancellationToken::default();
        token.cancel();
        let context = AsyncQueryContext::default().with_cancellation_token(token);
        let err = async_repo
            .list(&context, "accounts", &Query::new())
            .await
            .unwrap_err();
        assert!(err.to_string().contains("cancelled by cancellation token"));
    }

    #[tokio::test]
    async fn async_query_context_deadline_exceeded_stops_query() {
        let async_repo = async_repo_for(AdapterKind::MySql);
        let context = AsyncQueryContext::default().with_timeout_ms(1);
        tokio::time::sleep(Duration::from_millis(5)).await;
        let err = async_repo.find(&context, "accounts", 1).await.unwrap_err();
        assert!(err.to_string().contains("deadline exceeded"));
    }

    #[test]
    fn async_query_context_default_has_no_deadline_and_is_active() {
        let context = AsyncQueryContext::default();
        assert!(context.deadline.is_none());
        assert!(!context.is_cancelled());
        assert!(!context.is_deadline_exceeded());
        assert!(context.ensure_active().is_ok());
    }

    #[test]
    fn async_query_context_reset_and_explicit_deadline_paths_are_exercised() {
        let token = AsyncCancellationToken::default();
        token.cancel();
        assert!(token.is_cancelled());
        token.reset();
        assert!(!token.is_cancelled());

        let past_deadline = Instant::now()
            .checked_sub(Duration::from_millis(1))
            .expect("past instant");
        let context = AsyncQueryContext::default()
            .with_deadline(past_deadline)
            .with_cancellation_token(token);
        // The token was reset, so the failure must come from the deadline; asserting only
        // `is_err()` would also pass if `reset` were broken.
        assert!(!context.is_cancelled());
        let err = context.ensure_active().unwrap_err();
        assert!(err.to_string().contains("deadline exceeded"));
    }

    #[tokio::test]
    async fn async_memory_repo_supports_update_and_delete() {
        let async_repo = async_repo_for(AdapterKind::ClickHouse);
        let context = AsyncQueryContext::default().with_timeout_ms(500);

        let mut initial = Row::new();
        initial.insert("name".to_string(), json!("Draft"));
        let inserted = async_repo
            .insert(&context, "accounts", initial)
            .await
            .expect("insert");

        let mut updated = Row::new();
        updated.insert("name".to_string(), json!("Published"));
        let updated_row = async_repo
            .update(&context, "accounts", inserted.id, updated)
            .await
            .expect("update");
        assert_eq!(updated_row.data.get("name"), Some(&json!("Published")));

        async_repo
            .delete(&context, "accounts", inserted.id)
            .await
            .expect("delete");
        let found = async_repo
            .find(&context, "accounts", inserted.id)
            .await
            .expect("find after delete");
        assert!(found.is_none());
    }

    #[tokio::test]
    async fn async_memory_repo_reports_lock_poisoned_errors_and_adapter_kind() {
        let async_repo = async_repo_for(AdapterKind::OpenSearch);
        assert_eq!(async_repo.adapter_kind(), AdapterKind::OpenSearch);

        // Panic while holding the guard so the mutex becomes poisoned.
        let _ = catch_unwind(AssertUnwindSafe(|| {
            let _guard = async_repo.inner.lock().expect("repo lock");
            panic!("poison async memory repo lock");
        }));

        let context = AsyncQueryContext::default().with_timeout_ms(100);
        let err = async_repo.list(&context, "accounts", &Query::new()).await;
        assert!(err
            .unwrap_err()
            .to_string()
            .contains("async memory repo lock poisoned"));
    }
}