// shepherd_rs/database/inmem.rs

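//! # In-Memory Database
//!
//! This module provides an in-memory implementation of the `Database` trait.
//!
//! ## Overview
//! - **InMemoryDatabase**: Stores transform requests, transform attempts and
//!   consume attempts in memory; nothing is persisted.
//! - **Error Handling**: Defines the `InMemoryDatabaseError` type returned by
//!   every database operation.
//!
//! ## Example
//! ```rust,ignore
//! use crate::database::Database;
//!
//! // `ctx` is an `Arc<tokio::sync::Mutex<C>>` holding the configuration.
//! let mut db = InMemoryDatabase::<TR, TA, CA, C>::new(ctx).await?;
//! db.register_transform_request(&request).await?;
//! ```
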
16use std::collections::HashMap;
17use std::hash::Hash;
18use std::sync::Arc;
19
20use async_trait::async_trait;
21use thiserror::Error;
22use tokio::sync::Mutex;
23
24use crate::config::Config;
25use crate::consumer::ConsumeAttempt;
26use crate::consumer::consumer::ConsumeAttemptResult;
27use crate::database::Database;
28use crate::transform::{TransformAttempt, TransformRequest};
29use crate::worker::worker_manager::WorkerManagerResult;
30
31#[derive(Debug)]
32pub struct InMemoryDatabase<TR, TA, CA, C>
33where
34    TR: TransformRequest + Send + Sync,
35    TA: TransformAttempt<
36            TransformRequestIdentifier = TR::Identifier,
37            CallArgsType = TR::Input,
38            ReturnType = TR::Output,
39        > + Send
40        + Sync,
41    CA: ConsumeAttempt<
42            TransformRequestIdentifier = TR::Identifier,
43            TransformAttemptIdentifier = TA::Identifier,
44            ConsumeVal = TR::Output,
45        > + Send
46        + Sync,
47    C: Config,
48{
49    transform_requests: HashMap<TR::Identifier, TR>,
50    transform_attempts: HashMap<TR::Identifier, HashMap<TA::Identifier, TA>>,
51    consume_attempts: HashMap<TR::Identifier, HashMap<TA::Identifier, HashMap<CA::Identifier, CA>>>,
52    _marker: std::marker::PhantomData<C>,
53}
54
55#[derive(Debug, Error)]
56pub enum InMemoryDatabaseError {
57    #[error("Database error: {0}")]
58    DatabaseError(String),
59    #[error("Not found: {0}")]
60    NotFound(String),
61}
62
63#[async_trait]
64impl<TR, TA, CA, C> Database for InMemoryDatabase<TR, TA, CA, C>
65where
66    TR: TransformRequest + Send + Sync,
67    TA: TransformAttempt<
68            TransformRequestIdentifier = TR::Identifier,
69            CallArgsType = TR::Input,
70            ReturnType = TR::Output,
71        > + Send
72        + Sync,
73    CA: ConsumeAttempt<
74            TransformRequestIdentifier = TR::Identifier,
75            TransformAttemptIdentifier = TA::Identifier,
76            ConsumeVal = TR::Output,
77        > + Send
78        + Sync,
79    C: Config<KeyType = String, ValueType = Vec<u8>>,
80    TR::Identifier: Hash,
81{
82    type Config = C;
83    type ConsumeAttempt = CA;
84    type DatabaseError = InMemoryDatabaseError;
85    type Input = TR::Input;
86    type Output = TR::Output;
87    type TransformAttempt = TA;
88    type TransformRequest = TR;
89
90    async fn new(_ctx: Arc<Mutex<Self::Config>>) -> Result<Self, Self::DatabaseError>
91    where
92        Self: Sized,
93    {
94        Ok(Self {
95            transform_requests: HashMap::new(),
96            transform_attempts: HashMap::new(),
97            consume_attempts: HashMap::new(),
98            _marker: std::marker::PhantomData,
99        })
100    }
101
102    async fn get_dyn_configs(
103        &mut self,
104    ) -> Result<
105        Vec<(
106            <Self::Config as Config>::KeyType,
107            <Self::Config as Config>::ValueType,
108        )>,
109        Self::DatabaseError,
110    > {
111        // In a real implementation, you would retrieve dynamic configurations.
112        // For this in-memory example, we will return an empty vector.
113        Ok(Vec::new())
114    }
115
116    async fn register_transform_request(
117        &mut self,
118        request: &Self::TransformRequest,
119    ) -> Result<(), Self::DatabaseError> {
120        if self.transform_requests.contains_key(&request.request_id()) {
121            return Err(InMemoryDatabaseError::DatabaseError(
122                "Request already exists".to_string(),
123            ));
124        }
125        self.transform_requests
126            .insert(request.request_id(), request.clone());
127        Ok(())
128    }
129
130    async fn register_transform_attempt(
131        &mut self,
132        attempt: &Self::TransformAttempt,
133    ) -> Result<(), Self::DatabaseError> {
134        let request_id = attempt.request_id();
135        if !self.transform_requests.contains_key(&request_id) {
136            return Err(InMemoryDatabaseError::NotFound(format!(
137                "Request with ID {:?} not found",
138                request_id
139            )));
140        }
141        let attempts = self.transform_attempts.entry(request_id).or_default();
142        if attempts.contains_key(&attempt.attempt_id()) {
143            return Err(InMemoryDatabaseError::DatabaseError(
144                "Attempt already exists".to_string(),
145            ));
146        }
147        attempts.insert(attempt.attempt_id(), attempt.clone());
148        Ok(())
149    }
150
151    async fn update_transform_attempt(
152        &mut self,
153        attempt: &WorkerManagerResult<Self::TransformAttempt>,
154    ) -> Result<(), Self::DatabaseError> {
155        let (attempt_id, return_pkg) = match attempt {
156            WorkerManagerResult::Success(attempt_id, return_pkg) => (attempt_id, return_pkg),
157            WorkerManagerResult::Failure(attempt_id, return_pkg) => (attempt_id, return_pkg),
158        };
159
160        let request_id = attempt_id.clone().into();
161
162        if !self.transform_requests.contains_key(&request_id) {
163            return Err(InMemoryDatabaseError::NotFound(format!(
164                "Request with ID {:?} not found",
165                request_id
166            )));
167        }
168
169        let transform_attempts = self
170            .transform_attempts
171            .get_mut(&request_id)
172            .ok_or_else(|| {
173                InMemoryDatabaseError::NotFound(format!(
174                    "Transform attempts for request {:?} not found",
175                    request_id
176                ))
177            })?;
178
179        let attempt = transform_attempts.get_mut(&attempt_id).ok_or_else(|| {
180            InMemoryDatabaseError::NotFound(format!(
181                "Transform attempt with ID {:?} for request {:?} not found",
182                attempt_id, request_id
183            ))
184        })?;
185
186        attempt.set_return_package(return_pkg.clone());
187        Ok(())
188    }
189
190    async fn register_consume_attempt(
191        &mut self,
192        attempt: &Self::ConsumeAttempt,
193    ) -> Result<(), Self::DatabaseError> {
194        let request_id = attempt.request_id();
195        let attempt_id = attempt.attempt_id();
196        if !self.transform_requests.contains_key(&request_id) {
197            return Err(InMemoryDatabaseError::NotFound(format!(
198                "Request with ID {:?} not found",
199                request_id
200            )));
201        }
202
203        let attempts_entry = self
204            .transform_attempts
205            .get(&request_id)
206            .and_then(|attempts| attempts.get(&attempt_id));
207
208        if attempts_entry.is_none() {
209            return Err(InMemoryDatabaseError::NotFound(format!(
210                "Transform attempt with ID {:?} for request {:?} not found",
211                attempt_id, request_id
212            )));
213        }
214
215        let consume_attempts = self
216            .consume_attempts
217            .entry(request_id)
218            .or_default()
219            .entry(attempt_id)
220            .or_default();
221
222        if consume_attempts.contains_key(&attempt.consume_id()) {
223            return Err(InMemoryDatabaseError::DatabaseError(
224                "Consume attempt already exists".to_string(),
225            ));
226        }
227        consume_attempts.insert(attempt.consume_id(), attempt.clone());
228        Ok(())
229    }
230
231    async fn update_consume_attempt(
232        &mut self,
233        attempt: ConsumeAttemptResult<Self::ConsumeAttempt>,
234    ) -> Result<(), Self::DatabaseError> {
235        let (consume_attempt_id, return_ctx) = match attempt {
236            ConsumeAttemptResult::Success(consume_attempt_id, return_ctx) =>
237                (consume_attempt_id, return_ctx),
238            ConsumeAttemptResult::Failure(consume_attempt_id, return_ctx) =>
239                (consume_attempt_id, return_ctx),
240        };
241
242        let request_id = consume_attempt_id.clone().into();
243        let attempt_id = consume_attempt_id.clone().into();
244        if !self.transform_requests.contains_key(&request_id) {
245            return Err(InMemoryDatabaseError::NotFound(format!(
246                "Request with ID {:?} not found",
247                request_id
248            )));
249        }
250
251        let consume_attempts = self
252            .consume_attempts
253            .get_mut(&request_id)
254            .and_then(|attempts| attempts.get_mut(&attempt_id))
255            .ok_or_else(|| {
256                InMemoryDatabaseError::NotFound(format!(
257                    "Consume attempts for request {:?} and attempt {:?} not found",
258                    request_id, attempt_id
259                ))
260            })?;
261
262        let consume_attempt = consume_attempts
263            .get_mut(&consume_attempt_id)
264            .ok_or_else(|| {
265                InMemoryDatabaseError::NotFound(format!(
266                    "Consume attempt with ID {:?} for request {:?} and attempt {:?} not found",
267                    consume_attempt_id, request_id, attempt_id
268                ))
269            })?;
270
271        consume_attempt.set_return_context(return_ctx);
272
273        Ok(())
274    }
275
276    async fn archive_request_with_id(
277        &mut self,
278        request: &<Self::TransformRequest as TransformRequest>::Identifier,
279    ) -> Result<(), Self::DatabaseError> {
280        if !self.transform_requests.contains_key(request) {
281            return Err(InMemoryDatabaseError::NotFound(format!(
282                "Request with ID {:?} not found",
283                request
284            )));
285        }
286        self.transform_requests.remove(request);
287        self.transform_attempts.remove(request);
288        self.consume_attempts.remove(request);
289        Ok(())
290    }
291}