// shepherd_rs/database/postgresql/mod.rs

//! # PostgreSQL Database
//!
//! This module provides a PostgreSQL implementation of the `Database` trait.
//!
//! ## Overview
//! - **PostgresDatabase**: Manages transformation and consumption data using
//!   PostgreSQL.
//! - **Error Handling**: Defines custom error types for PostgreSQL operations.
//!
//! ## Example
//! ```ignore
//! // `config` is an `Arc<tokio::sync::Mutex<_>>` whose `db.conn_str` entry holds
//! // the JSON-encoded connection string.
//! let mut db = PostgresDatabase::new(config).await?;
//! db.register_transform_request(&request).await?;
//! ```

// Test-only submodule; compiled only when building tests.
#[cfg(test)]
pub mod test;

use std::path::{Path, PathBuf};
use std::sync::Arc;
use std::time::Duration;
use std::{env, fs};

use async_trait::async_trait;
use log::{debug, error, info, trace, warn};
use serde::Serialize;
use serde::de::DeserializeOwned;
use sqlx::migrate::Migrator;
use sqlx::postgres::PgPoolOptions;
use sqlx::{PgPool, Row};
use thiserror::Error;
use tokio::sync::Mutex;

use crate::config::Config;
use crate::consumer::ConsumeAttempt;
use crate::consumer::consumer::ConsumeAttemptResult;
use crate::database::Database;
use crate::transform::{TransformAttempt, TransformRequest};
use crate::worker::worker_manager::WorkerManagerResult;

40#[derive(Debug, Error)]
41pub enum PostgresDatabaseError {
42    #[error("Database error: {0}")]
43    Database(#[from] sqlx::Error),
44    #[error("Migration error: {0}")]
45    Migration(#[from] sqlx::migrate::MigrateError),
46    #[error("Serialization error: {0}")]
47    Serialization(#[from] serde_json::Error),
48    #[error("TOML deserialization error: {0}")]
49    TomlDeserialization(#[from] toml::de::Error),
50    #[error("TOML serialization error: {0}")]
51    TomlSerialization(#[from] toml::ser::Error),
52    #[error("Not found: {0}")]
53    NotFound(String),
54    #[error("Attempt already exists: {0}")]
55    Conflict(String),
56    #[error("Migration error occurred")]
57    MigrationError,
58    #[error("Hex decode error: {0}")]
59    HexDecodeError(#[from] hex::FromHexError),
60}
61
62#[derive(Debug, Clone)]
63pub struct PostgresDatabase<TR, TA, CA, C> {
64    pool: PgPool,
65    _marker: std::marker::PhantomData<(TR, TA, CA, C)>,
66}
67
68/// To create a custom table based on your requirements
69/// For example: The input struct is
70/// ```rs
71/// struct Person {
72///     name: String,
73///     address: String,
74///     phone: String,
75/// }
76///
77/// pub trait PersonTrait {
78///     fn get_name(&self) -> String;
79///     fn get_address(&self) -> String;
80///     fn get_phone(&self) -> String;
81/// }
82///
83/// impl Database for PostgresDatabase
84/// where
85///     TR::Input: PersonTrait
86///     ...
87/// {
88///     
89/// }
90///
91/// You can do the sane to output, transform attempt, transform request and consume attempt
92/// And, store the relevant things to the database.
93/// ```
94#[async_trait]
95impl<TR, TA, CA, C> Database for PostgresDatabase<TR, TA, CA, C>
96where
97    TR: TransformRequest + Send + Sync + for<'a> serde::Deserialize<'a> + serde::Serialize,
98    TA: TransformAttempt<
99            TransformRequestIdentifier = TR::Identifier,
100            CallArgsType = TR::Input,
101            ReturnType = TR::Output,
102        > + Send
103        + Sync
104        + DeserializeOwned
105        + serde::Serialize,
106    CA: ConsumeAttempt<
107            TransformRequestIdentifier = TR::Identifier,
108            TransformAttemptIdentifier = TA::Identifier,
109            ConsumeVal = TR::Output,
110        > + Send
111        + Sync
112        + DeserializeOwned
113        + serde::Serialize,
114    C: Config<KeyType = String, ValueType = Vec<u8>>,
115    TR::Input: serde::Serialize + DeserializeOwned,
116    TR::Output: serde::Serialize + DeserializeOwned,
117    CA::Identifier: serde::Serialize + DeserializeOwned,
118    CA::ReturnCtx: serde::Serialize + DeserializeOwned,
119    TR::Identifier: serde::Serialize + DeserializeOwned,
120    TA::Identifier: serde::Serialize + DeserializeOwned,
121    TA::ReturnPackage: serde::Serialize + DeserializeOwned,
122{
123    type Config = C;
124    type ConsumeAttempt = CA;
125    type DatabaseError = PostgresDatabaseError;
126    type Input = TR::Input;
127    type Output = TR::Output;
128    type TransformAttempt = TA;
129    type TransformRequest = TR;
130
131    async fn new(ctx: Arc<Mutex<Self::Config>>) -> Result<Self, Self::DatabaseError> {
132        info!("Initializing PostgresDatabase connection pool");
133        let conn_str_bytes = ctx
134            .lock()
135            .await
136            .get("db.conn_str".to_string())
137            .await
138            .unwrap_or_default();
139
140        let connection_string: toml::Value = serde_json::from_slice(&conn_str_bytes)?;
141
142        let conn_str = connection_string.as_str().unwrap().to_owned();
143
144        let pool = PgPoolOptions::new()
145            .max_connections(20)
146            .acquire_timeout(Duration::from_secs(5))
147            .connect(&conn_str)
148            .await
149            .map_err(|e| {
150                error!("Failed to create database connection pool: {}", e);
151                e
152            })?;
153
154        info!("Database connection pool created successfully");
155
156        let instance = Self {
157            pool,
158            _marker: std::marker::PhantomData,
159        };
160
161        instance.run_migrations().await?;
162        Ok(instance)
163    }
164
165    async fn get_dyn_configs(
166        &mut self,
167    ) -> Result<
168        Vec<(
169            <Self::Config as Config>::KeyType,
170            <Self::Config as Config>::ValueType,
171        )>,
172        Self::DatabaseError,
173    > {
174        debug!("Fetching dynamic configurations from database");
175        let rows = sqlx::query(
176            r#"
177            SELECT key, value FROM dynamic_configs
178            ORDER BY key
179            "#,
180        )
181        .fetch_all(&self.pool)
182        .await
183        .map_err(|e| {
184            error!("Failed to fetch dynamic configs: {}", e);
185            PostgresDatabaseError::Database(e)
186        })?;
187        debug!("Dynamic configurations fetched successfully");
188        let configs = rows
189            .into_iter()
190            .map(|row| {
191                let key: String = row.get("key");
192                let value_hex: String = row.get("value");
193                let value = hex::decode(value_hex).map_err(|e| {
194                    error!("Failed to decode hex value for key '{}': {}", key, e);
195                    PostgresDatabaseError::HexDecodeError(e)
196                })?;
197                Ok((key, value))
198            })
199            .collect::<Result<Vec<(String, Vec<u8>)>, PostgresDatabaseError>>()?;
200
201        Ok(configs)
202    }
203
204    async fn register_transform_request(
205        &mut self,
206        request: &Self::TransformRequest,
207    ) -> Result<(), Self::DatabaseError> {
208        debug!("Registering new transform request");
209        let request_id = serde_json::to_value(request.request_id())?;
210        let input = serde_json::to_value(request.input())?;
211
212        let dyn_cfgs: Vec<(String, String)> = request
213            .get_dyn_configs()
214            .into_iter()
215            .map(|(key, value)| {
216                let value = hex::encode(value);
217                debug!("Serializing value for key '{}': {}", key, value);
218                (key, value)
219            })
220            .collect();
221
222        let mut tx = self.pool.begin().await?;
223
224        let rows_affected = sqlx::query(
225            r#"
226            INSERT INTO transform_requests (request_id, input)
227            VALUES ($1, $2)
228            ON CONFLICT (request_id) DO NOTHING
229            RETURNING 1
230            "#,
231        )
232        .bind(request_id)
233        .bind(input)
234        .execute(&mut *tx)
235        .await?
236        .rows_affected();
237
238        for (key, value) in dyn_cfgs {
239            sqlx::query(
240                r#"
241                INSERT INTO dynamic_configs (key, value, created_at)
242                VALUES ($1, $2, NOW())
243                ON CONFLICT (key) DO UPDATE SET value = EXCLUDED.value
244                "#,
245            )
246            .bind(key)
247            .bind(value)
248            .execute(&mut *tx)
249            .await?;
250        }
251
252        tx.commit().await?;
253
254        if rows_affected > 0 {
255            debug!("Transform request registered successfully");
256            Ok(())
257        } else {
258            warn!("Transform request already exists");
259            Err(PostgresDatabaseError::Conflict(
260                "Transform request already exists".into(),
261            ))
262        }
263    }
264
265    async fn register_transform_attempt(
266        &mut self,
267        attempt: &Self::TransformAttempt,
268    ) -> Result<(), Self::DatabaseError> {
269        debug!("Registering new transform attempt");
270        let request_id = serde_json::to_value(attempt.request_id())?;
271        let attempt_id = serde_json::to_value(attempt.attempt_id())?;
272
273        let rows_affected = sqlx::query(
274            r#"
275            INSERT INTO transform_attempts 
276                (attempt_id, request_id, status)
277            VALUES ($1, $2, 'pending')
278            ON CONFLICT (request_id, attempt_id) DO NOTHING
279            RETURNING 1
280            "#,
281        )
282        .bind(attempt_id)
283        .bind(request_id)
284        .execute(&self.pool)
285        .await?
286        .rows_affected();
287
288        if rows_affected > 0 {
289            debug!("Transform attempt registered successfully");
290            Ok(())
291        } else {
292            warn!("Transform attempt already exists");
293            Err(PostgresDatabaseError::Conflict(
294                "Transform attempt already exists".into(),
295            ))
296        }
297    }
298
299    async fn update_transform_attempt(
300        &mut self,
301        attempt: &WorkerManagerResult<Self::TransformAttempt>,
302    ) -> Result<(), Self::DatabaseError> {
303        debug!("Updating transform attempt status");
304        let (attempt_id, return_pkg, status) = match attempt {
305            WorkerManagerResult::Success(id, pkg) => (id, pkg, "success"),
306            WorkerManagerResult::Failure(id, pkg) => (id, pkg, "failure"),
307        };
308
309        let attempt_id = serde_json::to_value(&attempt_id)?;
310        let return_pkg = serde_json::to_value(&return_pkg)?;
311
312        let rows_affected = sqlx::query(
313            r#"
314            UPDATE transform_attempts
315            SET return_pkg = $1,
316                status = $2::attempt_status,
317                updated_at = NOW()
318            WHERE attempt_id = $3
319            "#,
320        )
321        .bind(return_pkg)
322        .bind(status)
323        .bind(attempt_id)
324        .execute(&self.pool)
325        .await?
326        .rows_affected();
327
328        if rows_affected > 0 {
329            debug!("Transform attempt updated successfully");
330            Ok(())
331        } else {
332            warn!("Transform attempt not found for update");
333            Err(PostgresDatabaseError::NotFound(
334                "Transform attempt not found".into(),
335            ))
336        }
337    }
338
339    async fn register_consume_attempt(
340        &mut self,
341        attempt: &Self::ConsumeAttempt,
342    ) -> Result<(), Self::DatabaseError> {
343        debug!("Registering new consume attempt");
344        let request_id = serde_json::to_value(attempt.request_id())?;
345        let attempt_id = serde_json::to_value(attempt.attempt_id())?;
346        let consume_id = serde_json::to_value(attempt.consume_id())?;
347
348        let dyn_cfgs: Vec<(String, String)> = attempt
349            .get_dyn_configs()
350            .into_iter()
351            .map(|(key, value)| {
352                let value = hex::encode(value);
353                debug!("Serializing value for key '{}': {}", key, value);
354                (key, value)
355            })
356            .collect();
357
358        let mut tx = self.pool.begin().await?;
359
360        let rows_affected = sqlx::query(
361            r#"
362            INSERT INTO consume_attempts 
363                (request_id, attempt_id, consume_id, status)
364            VALUES ($1, $2, $3, 'pending')
365            ON CONFLICT (request_id, attempt_id, consume_id) DO NOTHING
366            RETURNING 1
367            "#,
368        )
369        .bind(request_id)
370        .bind(attempt_id)
371        .bind(consume_id)
372        .execute(&mut *tx)
373        .await?
374        .rows_affected();
375
376        for (key, value) in dyn_cfgs {
377            sqlx::query(
378                r#"
379                INSERT INTO dynamic_configs (key, value, created_at)
380                VALUES ($1, $2, NOW())
381                ON CONFLICT (key) DO UPDATE SET value = EXCLUDED.value
382                "#,
383            )
384            .bind(key)
385            .bind(value)
386            .execute(&mut *tx)
387            .await?;
388        }
389
390        tx.commit().await?;
391
392        if rows_affected > 0 {
393            debug!("Consume attempt registered successfully");
394            Ok(())
395        } else {
396            warn!("Consume attempt already exists");
397            Err(PostgresDatabaseError::Conflict(
398                "Consume attempt already exists".into(),
399            ))
400        }
401    }
402
403    async fn update_consume_attempt(
404        &mut self,
405        attempt: ConsumeAttemptResult<Self::ConsumeAttempt>,
406    ) -> Result<(), Self::DatabaseError> {
407        debug!("Updating consume attempt status");
408        let (consume_id, return_ctx, status) = match attempt {
409            ConsumeAttemptResult::Success(id, ctx) => (id, ctx, "success"),
410            ConsumeAttemptResult::Failure(id, ctx) => (id, ctx, "failure"),
411        };
412
413        let consume_id = serde_json::to_value(&consume_id)?;
414        let return_ctx = serde_json::to_value(&return_ctx)?;
415
416        let rows_affected = sqlx::query(
417            r#"
418            UPDATE consume_attempts
419            SET return_ctx = $1,
420                status = $2,
421                updated_at = NOW()
422            WHERE consume_id = $3
423            "#,
424        )
425        .bind(return_ctx)
426        .bind(status)
427        .bind(consume_id)
428        .execute(&self.pool)
429        .await?
430        .rows_affected();
431
432        if rows_affected > 0 {
433            debug!("Consume attempt updated successfully");
434            Ok(())
435        } else {
436            warn!("Consume attempt not found for update");
437            Err(PostgresDatabaseError::NotFound(
438                "Consume attempt not found".into(),
439            ))
440        }
441    }
442
443    async fn archive_request_with_id(
444        &mut self,
445        request_id: &<Self::TransformRequest as TransformRequest>::Identifier,
446    ) -> Result<(), Self::DatabaseError> {
447        debug!("Archiving transform request by request_id");
448        let mut tx = self.pool.begin().await?;
449        let request_id_json = serde_json::to_value(request_id)?;
450
451        // Archive consume_attempts
452        sqlx::query(
453            r#"
454            WITH moved_consumes AS (
455                DELETE FROM consume_attempts
456                WHERE request_id = $1
457                RETURNING *
458            )
459            INSERT INTO archive_consume_attempts
460            SELECT * FROM moved_consumes
461            "#,
462        )
463        .bind(&request_id_json)
464        .execute(&mut *tx)
465        .await?;
466
467        // Archive transform_attempts
468        sqlx::query(
469            r#"
470            WITH moved_attempts AS (
471                DELETE FROM transform_attempts
472                WHERE request_id = $1
473                RETURNING *
474            )
475            INSERT INTO archive_transform_attempts
476            SELECT * FROM moved_attempts
477            "#,
478        )
479        .bind(&request_id_json)
480        .execute(&mut *tx)
481        .await?;
482
483        // Archive transform_requests
484        sqlx::query(
485            r#"
486            WITH moved_requests AS (
487                DELETE FROM transform_requests
488                WHERE request_id = $1
489                RETURNING *
490            )
491            INSERT INTO archive_transform_requests
492            SELECT * FROM moved_requests
493            "#,
494        )
495        .bind(&request_id_json)
496        .execute(&mut *tx)
497        .await?;
498
499        tx.commit().await?;
500        debug!("Transform request archived successfully");
501        Ok(())
502    }
503}
504
505// Additional methods in a separate implementation block
506impl<TR, TA, CA, C> PostgresDatabase<TR, TA, CA, C>
507where
508    TR: TransformRequest + Send + Sync,
509    TR::Identifier: Serialize + DeserializeOwned,
510    TA: TransformAttempt + Send + Sync,
511    TA::Identifier: Serialize + DeserializeOwned,
512    CA: ConsumeAttempt + Send + Sync,
513    CA::Identifier: Serialize + DeserializeOwned,
514    C: Config + Send + Sync,
515{
516    async fn run_migrations(&self) -> Result<(), PostgresDatabaseError> {
517        info!("Running database migrations");
518
519        // Create a temporary file for migrations
520        let temp_dir = env::temp_dir();
521        let timestamp = std::time::SystemTime::now()
522            .duration_since(std::time::UNIX_EPOCH)
523            .map(|d| d.as_secs())
524            .unwrap_or(0);
525        let temp_file_path = temp_dir
526            .join(format!("shepherd_migrations_{}", timestamp))
527            .join("0000_default_schema.up.sql");
528
529        trace!("Using temporary file for migrations: {:?}", temp_file_path);
530
531        // Ensure directories exist for the temporary file path
532        if let Some(parent_dir) = temp_file_path.parent() {
533            fs::create_dir_all(parent_dir).map_err(|e| {
534                error!(
535                    "Failed to create directories for temporary file path: {}",
536                    e
537                );
538                PostgresDatabaseError::MigrationError
539            })?;
540        }
541
542        fs::write(&temp_file_path, MIGRATIONS).map_err(|e| {
543            error!("Failed to write migrations to temporary file: {}", e);
544            PostgresDatabaseError::MigrationError
545        })?;
546
547        // Initialize migrator with the temporary file
548        let migrator = Migrator::new(temp_file_path.clone().parent().unwrap())
549            .await
550            .map_err(|e| {
551                error!("Failed to initialize migrator: {}", e);
552                e
553            })?;
554
555        // Run migrations
556        migrator.run(&self.pool).await.map_err(|e| {
557            error!("Failed to apply migrations: {}", e);
558            e
559        })?;
560
561        // Delete the temporary file
562        fs::remove_file(&temp_file_path).map_err(|e| {
563            warn!("Failed to delete temporary migrations file: {}", e);
564            PostgresDatabaseError::MigrationError
565        })?;
566
567        info!("Database migrations completed successfully");
568        Ok(())
569    }
570}
571
/// Default schema SQL, embedded at compile time from `0000_default_schema.up.sql`
/// next to this file; staged to disk and applied by `run_migrations`.
const MIGRATIONS: &str = include_str!("./0000_default_schema.up.sql");