Skip to main content

shelly_data/
async_repo.rs

1use crate::{
2    adapter::AdapterKind,
3    error::{DataError, DataResult},
4    query::Query,
5    repo::{MemoryRepo, Repo, Row, StoredRow},
6};
7use std::{
8    future::Future,
9    pin::Pin,
10    sync::{
11        atomic::{AtomicBool, Ordering},
12        Arc, Mutex,
13    },
14    time::{Duration, Instant},
15};
16
17pub type AsyncRepoFuture<'a, T> = Pin<Box<dyn Future<Output = DataResult<T>> + Send + 'a>>;
18
/// Shared cancellation flag for async repository calls.
///
/// The flag lives behind an `Arc`, so every clone of a token observes the same
/// state: cancelling one clone cancels them all.
#[derive(Debug, Clone, Default)]
pub struct AsyncCancellationToken {
    cancelled: Arc<AtomicBool>,
}

impl AsyncCancellationToken {
    /// Marks this token (and all of its clones) as cancelled.
    pub fn cancel(&self) {
        self.set_flag(true);
    }

    /// Clears a previous cancellation so the token can be reused.
    pub fn reset(&self) {
        self.set_flag(false);
    }

    /// Returns `true` while the token is in the cancelled state.
    pub fn is_cancelled(&self) -> bool {
        let flag: &AtomicBool = &self.cancelled;
        flag.load(Ordering::SeqCst)
    }

    /// Writes the shared flag with sequentially consistent ordering.
    fn set_flag(&self, value: bool) {
        self.cancelled.store(value, Ordering::SeqCst);
    }
}
37
38#[derive(Debug, Clone, Default)]
39pub struct AsyncQueryContext {
40    pub deadline: Option<Instant>,
41    pub cancellation: AsyncCancellationToken,
42}
43
44impl AsyncQueryContext {
45    pub fn with_timeout_ms(mut self, timeout_ms: u64) -> Self {
46        self.deadline = Some(Instant::now() + Duration::from_millis(timeout_ms.max(1)));
47        self
48    }
49
50    pub fn with_deadline(mut self, deadline: Instant) -> Self {
51        self.deadline = Some(deadline);
52        self
53    }
54
55    pub fn with_cancellation_token(mut self, token: AsyncCancellationToken) -> Self {
56        self.cancellation = token;
57        self
58    }
59
60    pub fn is_cancelled(&self) -> bool {
61        self.cancellation.is_cancelled()
62    }
63
64    pub fn is_deadline_exceeded(&self) -> bool {
65        self.deadline
66            .map(|deadline| Instant::now() > deadline)
67            .unwrap_or(false)
68    }
69
70    pub fn ensure_active(&self) -> DataResult<()> {
71        if self.is_cancelled() {
72            return Err(DataError::Query(
73                "async query cancelled by cancellation token".to_string(),
74            ));
75        }
76        if self.is_deadline_exceeded() {
77            return Err(DataError::Query(
78                "async query deadline exceeded before completion".to_string(),
79            ));
80        }
81        Ok(())
82    }
83}
84
85pub trait AsyncRepo: Send + Sync {
86    fn adapter_kind(&self) -> AdapterKind;
87
88    fn insert<'a>(
89        &'a self,
90        context: &'a AsyncQueryContext,
91        table: &'a str,
92        data: Row,
93    ) -> AsyncRepoFuture<'a, StoredRow>;
94
95    fn update<'a>(
96        &'a self,
97        context: &'a AsyncQueryContext,
98        table: &'a str,
99        id: u64,
100        data: Row,
101    ) -> AsyncRepoFuture<'a, StoredRow>;
102
103    fn delete<'a>(
104        &'a self,
105        context: &'a AsyncQueryContext,
106        table: &'a str,
107        id: u64,
108    ) -> AsyncRepoFuture<'a, ()>;
109
110    fn find<'a>(
111        &'a self,
112        context: &'a AsyncQueryContext,
113        table: &'a str,
114        id: u64,
115    ) -> AsyncRepoFuture<'a, Option<StoredRow>>;
116
117    fn list<'a>(
118        &'a self,
119        context: &'a AsyncQueryContext,
120        table: &'a str,
121        query: &'a Query,
122    ) -> AsyncRepoFuture<'a, Vec<StoredRow>>;
123}
124
125#[derive(Clone)]
126pub struct AsyncMemoryRepo {
127    inner: Arc<Mutex<MemoryRepo>>,
128    adapter_kind: AdapterKind,
129}
130
131impl AsyncMemoryRepo {
132    pub fn new(repo: MemoryRepo) -> Self {
133        let adapter_kind = repo.adapter_kind();
134        Self {
135            inner: Arc::new(Mutex::new(repo)),
136            adapter_kind,
137        }
138    }
139}
140
141impl AsyncRepo for AsyncMemoryRepo {
142    fn adapter_kind(&self) -> AdapterKind {
143        self.adapter_kind
144    }
145
146    fn insert<'a>(
147        &'a self,
148        context: &'a AsyncQueryContext,
149        table: &'a str,
150        data: Row,
151    ) -> AsyncRepoFuture<'a, StoredRow> {
152        Box::pin(async move {
153            context.ensure_active()?;
154            let result = self
155                .inner
156                .lock()
157                .map_err(|_| DataError::Query("async memory repo lock poisoned".to_string()))?
158                .insert(table, data);
159            context.ensure_active()?;
160            result
161        })
162    }
163
164    fn update<'a>(
165        &'a self,
166        context: &'a AsyncQueryContext,
167        table: &'a str,
168        id: u64,
169        data: Row,
170    ) -> AsyncRepoFuture<'a, StoredRow> {
171        Box::pin(async move {
172            context.ensure_active()?;
173            let result = self
174                .inner
175                .lock()
176                .map_err(|_| DataError::Query("async memory repo lock poisoned".to_string()))?
177                .update(table, id, data);
178            context.ensure_active()?;
179            result
180        })
181    }
182
183    fn delete<'a>(
184        &'a self,
185        context: &'a AsyncQueryContext,
186        table: &'a str,
187        id: u64,
188    ) -> AsyncRepoFuture<'a, ()> {
189        Box::pin(async move {
190            context.ensure_active()?;
191            let result = self
192                .inner
193                .lock()
194                .map_err(|_| DataError::Query("async memory repo lock poisoned".to_string()))?
195                .delete(table, id);
196            context.ensure_active()?;
197            result
198        })
199    }
200
201    fn find<'a>(
202        &'a self,
203        context: &'a AsyncQueryContext,
204        table: &'a str,
205        id: u64,
206    ) -> AsyncRepoFuture<'a, Option<StoredRow>> {
207        Box::pin(async move {
208            context.ensure_active()?;
209            let result = self
210                .inner
211                .lock()
212                .map_err(|_| DataError::Query("async memory repo lock poisoned".to_string()))?
213                .find(table, id);
214            context.ensure_active()?;
215            result
216        })
217    }
218
219    fn list<'a>(
220        &'a self,
221        context: &'a AsyncQueryContext,
222        table: &'a str,
223        query: &'a Query,
224    ) -> AsyncRepoFuture<'a, Vec<StoredRow>> {
225        Box::pin(async move {
226            context.ensure_active()?;
227            let result = self
228                .inner
229                .lock()
230                .map_err(|_| DataError::Query("async memory repo lock poisoned".to_string()))?
231                .list(table, query);
232            context.ensure_active()?;
233            result
234        })
235    }
236}
237
#[cfg(test)]
mod tests {
    use super::{AsyncCancellationToken, AsyncMemoryRepo, AsyncQueryContext, AsyncRepo};
    use crate::{adapter_for, AdapterKind, DatabaseConfig, MemoryRepo, Query, Row};
    use serde_json::json;
    use std::{
        panic::{catch_unwind, AssertUnwindSafe},
        time::{Duration, Instant},
    };

    /// Builds an `AsyncMemoryRepo` over a fresh `MemoryRepo` for `adapter`.
    fn async_repo_for(adapter: AdapterKind) -> AsyncMemoryRepo {
        let driver = adapter_for(&DatabaseConfig {
            adapter,
            url: None,
            url_env: None,
        })
        .expect("driver");
        AsyncMemoryRepo::new(MemoryRepo::new(driver))
    }

    #[tokio::test]
    async fn async_memory_repo_supports_insert_and_list() {
        let async_repo = async_repo_for(AdapterKind::SingleStore);
        let mut row = Row::new();
        row.insert("account".to_string(), json!("Acme"));
        let context = AsyncQueryContext::default().with_timeout_ms(500);
        async_repo
            .insert(&context, "accounts", row)
            .await
            .expect("insert");
        let rows = async_repo
            .list(&context, "accounts", &Query::new())
            .await
            .expect("list");
        assert_eq!(rows.len(), 1);
    }

    #[tokio::test]
    async fn async_query_context_cancellation_stops_query() {
        let async_repo = async_repo_for(AdapterKind::Postgres);
        let token = AsyncCancellationToken::default();
        token.cancel();
        let context = AsyncQueryContext::default().with_cancellation_token(token);
        let err = async_repo
            .list(&context, "accounts", &Query::new())
            .await
            .expect_err("cancelled context must reject the query");
        // Pin the reason so an unrelated failure (e.g. missing table) can't pass.
        assert!(err.to_string().contains("cancelled"));
    }

    #[tokio::test]
    async fn async_query_context_deadline_exceeded_stops_query() {
        let async_repo = async_repo_for(AdapterKind::MySql);
        let context = AsyncQueryContext::default().with_timeout_ms(1);
        tokio::time::sleep(Duration::from_millis(5)).await;
        let err = async_repo
            .find(&context, "accounts", 1)
            .await
            .expect_err("expired context must reject the query");
        assert!(err.to_string().contains("deadline exceeded"));
    }

    // Purely synchronous: no runtime needed.
    #[test]
    fn async_query_context_reset_and_explicit_deadline_paths_are_exercised() {
        let token = AsyncCancellationToken::default();
        token.cancel();
        assert!(token.is_cancelled());
        token.reset();
        assert!(!token.is_cancelled());

        let past_deadline = Instant::now()
            .checked_sub(Duration::from_millis(1))
            .expect("past instant");
        let context = AsyncQueryContext::default()
            .with_deadline(past_deadline)
            .with_cancellation_token(token);
        // The reset token is not cancelled, so the failure must come from the deadline.
        assert!(!context.is_cancelled());
        let err = context
            .ensure_active()
            .expect_err("past deadline must be rejected");
        assert!(err.to_string().contains("deadline exceeded"));
    }

    #[tokio::test]
    async fn async_memory_repo_supports_update_and_delete() {
        let async_repo = async_repo_for(AdapterKind::ClickHouse);
        let context = AsyncQueryContext::default().with_timeout_ms(500);

        let mut initial = Row::new();
        initial.insert("name".to_string(), json!("Draft"));
        let inserted = async_repo
            .insert(&context, "accounts", initial)
            .await
            .expect("insert");

        let mut updated = Row::new();
        updated.insert("name".to_string(), json!("Published"));
        let updated_row = async_repo
            .update(&context, "accounts", inserted.id, updated)
            .await
            .expect("update");
        assert_eq!(updated_row.data.get("name"), Some(&json!("Published")));

        async_repo
            .delete(&context, "accounts", inserted.id)
            .await
            .expect("delete");
        let found = async_repo
            .find(&context, "accounts", inserted.id)
            .await
            .expect("find after delete");
        assert!(found.is_none());
    }

    #[tokio::test]
    async fn async_memory_repo_reports_lock_poisoned_errors_and_adapter_kind() {
        let async_repo = async_repo_for(AdapterKind::OpenSearch);
        assert_eq!(async_repo.adapter_kind(), AdapterKind::OpenSearch);

        // Panic while holding the lock to poison it.
        let _ = catch_unwind(AssertUnwindSafe(|| {
            let _guard = async_repo.inner.lock().expect("repo lock");
            panic!("poison async memory repo lock");
        }));

        let context = AsyncQueryContext::default().with_timeout_ms(100);
        let err = async_repo.list(&context, "accounts", &Query::new()).await;
        assert!(err
            .unwrap_err()
            .to_string()
            .contains("async memory repo lock poisoned"));
    }
}