1#[cfg(test)]
17pub mod test;
18
19use std::sync::Arc;
20use std::time::Duration;
21use std::{env, fs};
22
23use async_trait::async_trait;
24use log::{debug, error, info, trace, warn};
25use serde::Serialize;
26use serde::de::DeserializeOwned;
27use sqlx::migrate::Migrator;
28use sqlx::postgres::PgPoolOptions;
29use sqlx::{PgPool, Row};
30use thiserror::Error;
31use tokio::sync::Mutex;
32
33use crate::config::Config;
34use crate::consumer::ConsumeAttempt;
35use crate::consumer::consumer::ConsumeAttemptResult;
36use crate::database::Database;
37use crate::transform::{TransformAttempt, TransformRequest};
38use crate::worker::worker_manager::WorkerManagerResult;
39
40#[derive(Debug, Error)]
41pub enum PostgresDatabaseError {
42 #[error("Database error: {0}")]
43 Database(#[from] sqlx::Error),
44 #[error("Migration error: {0}")]
45 Migration(#[from] sqlx::migrate::MigrateError),
46 #[error("Serialization error: {0}")]
47 Serialization(#[from] serde_json::Error),
48 #[error("TOML deserialization error: {0}")]
49 TomlDeserialization(#[from] toml::de::Error),
50 #[error("TOML serialization error: {0}")]
51 TomlSerialization(#[from] toml::ser::Error),
52 #[error("Not found: {0}")]
53 NotFound(String),
54 #[error("Attempt already exists: {0}")]
55 Conflict(String),
56 #[error("Migration error occurred")]
57 MigrationError,
58 #[error("Hex decode error: {0}")]
59 HexDecodeError(#[from] hex::FromHexError),
60}
61
62#[derive(Debug, Clone)]
63pub struct PostgresDatabase<TR, TA, CA, C> {
64 pool: PgPool,
65 _marker: std::marker::PhantomData<(TR, TA, CA, C)>,
66}
67
68#[async_trait]
95impl<TR, TA, CA, C> Database for PostgresDatabase<TR, TA, CA, C>
96where
97 TR: TransformRequest + Send + Sync + for<'a> serde::Deserialize<'a> + serde::Serialize,
98 TA: TransformAttempt<
99 TransformRequestIdentifier = TR::Identifier,
100 CallArgsType = TR::Input,
101 ReturnType = TR::Output,
102 > + Send
103 + Sync
104 + DeserializeOwned
105 + serde::Serialize,
106 CA: ConsumeAttempt<
107 TransformRequestIdentifier = TR::Identifier,
108 TransformAttemptIdentifier = TA::Identifier,
109 ConsumeVal = TR::Output,
110 > + Send
111 + Sync
112 + DeserializeOwned
113 + serde::Serialize,
114 C: Config<KeyType = String, ValueType = Vec<u8>>,
115 TR::Input: serde::Serialize + DeserializeOwned,
116 TR::Output: serde::Serialize + DeserializeOwned,
117 CA::Identifier: serde::Serialize + DeserializeOwned,
118 CA::ReturnCtx: serde::Serialize + DeserializeOwned,
119 TR::Identifier: serde::Serialize + DeserializeOwned,
120 TA::Identifier: serde::Serialize + DeserializeOwned,
121 TA::ReturnPackage: serde::Serialize + DeserializeOwned,
122{
123 type Config = C;
124 type ConsumeAttempt = CA;
125 type DatabaseError = PostgresDatabaseError;
126 type Input = TR::Input;
127 type Output = TR::Output;
128 type TransformAttempt = TA;
129 type TransformRequest = TR;
130
131 async fn new(ctx: Arc<Mutex<Self::Config>>) -> Result<Self, Self::DatabaseError> {
132 info!("Initializing PostgresDatabase connection pool");
133 let conn_str_bytes = ctx
134 .lock()
135 .await
136 .get("db.conn_str".to_string())
137 .await
138 .unwrap_or_default();
139
140 let connection_string: toml::Value = serde_json::from_slice(&conn_str_bytes)?;
141
142 let conn_str = connection_string.as_str().unwrap().to_owned();
143
144 let pool = PgPoolOptions::new()
145 .max_connections(20)
146 .acquire_timeout(Duration::from_secs(5))
147 .connect(&conn_str)
148 .await
149 .map_err(|e| {
150 error!("Failed to create database connection pool: {}", e);
151 e
152 })?;
153
154 info!("Database connection pool created successfully");
155
156 let instance = Self {
157 pool,
158 _marker: std::marker::PhantomData,
159 };
160
161 instance.run_migrations().await?;
162 Ok(instance)
163 }
164
165 async fn get_dyn_configs(
166 &mut self,
167 ) -> Result<
168 Vec<(
169 <Self::Config as Config>::KeyType,
170 <Self::Config as Config>::ValueType,
171 )>,
172 Self::DatabaseError,
173 > {
174 debug!("Fetching dynamic configurations from database");
175 let rows = sqlx::query(
176 r#"
177 SELECT key, value FROM dynamic_configs
178 ORDER BY key
179 "#,
180 )
181 .fetch_all(&self.pool)
182 .await
183 .map_err(|e| {
184 error!("Failed to fetch dynamic configs: {}", e);
185 PostgresDatabaseError::Database(e)
186 })?;
187 debug!("Dynamic configurations fetched successfully");
188 let configs = rows
189 .into_iter()
190 .map(|row| {
191 let key: String = row.get("key");
192 let value_hex: String = row.get("value");
193 let value = hex::decode(value_hex).map_err(|e| {
194 error!("Failed to decode hex value for key '{}': {}", key, e);
195 PostgresDatabaseError::HexDecodeError(e)
196 })?;
197 Ok((key, value))
198 })
199 .collect::<Result<Vec<(String, Vec<u8>)>, PostgresDatabaseError>>()?;
200
201 Ok(configs)
202 }
203
204 async fn register_transform_request(
205 &mut self,
206 request: &Self::TransformRequest,
207 ) -> Result<(), Self::DatabaseError> {
208 debug!("Registering new transform request");
209 let request_id = serde_json::to_value(request.request_id())?;
210 let input = serde_json::to_value(request.input())?;
211
212 let dyn_cfgs: Vec<(String, String)> = request
213 .get_dyn_configs()
214 .into_iter()
215 .map(|(key, value)| {
216 let value = hex::encode(value);
217 debug!("Serializing value for key '{}': {}", key, value);
218 (key, value)
219 })
220 .collect();
221
222 let mut tx = self.pool.begin().await?;
223
224 let rows_affected = sqlx::query(
225 r#"
226 INSERT INTO transform_requests (request_id, input)
227 VALUES ($1, $2)
228 ON CONFLICT (request_id) DO NOTHING
229 RETURNING 1
230 "#,
231 )
232 .bind(request_id)
233 .bind(input)
234 .execute(&mut *tx)
235 .await?
236 .rows_affected();
237
238 for (key, value) in dyn_cfgs {
239 sqlx::query(
240 r#"
241 INSERT INTO dynamic_configs (key, value, created_at)
242 VALUES ($1, $2, NOW())
243 ON CONFLICT (key) DO UPDATE SET value = EXCLUDED.value
244 "#,
245 )
246 .bind(key)
247 .bind(value)
248 .execute(&mut *tx)
249 .await?;
250 }
251
252 tx.commit().await?;
253
254 if rows_affected > 0 {
255 debug!("Transform request registered successfully");
256 Ok(())
257 } else {
258 warn!("Transform request already exists");
259 Err(PostgresDatabaseError::Conflict(
260 "Transform request already exists".into(),
261 ))
262 }
263 }
264
265 async fn register_transform_attempt(
266 &mut self,
267 attempt: &Self::TransformAttempt,
268 ) -> Result<(), Self::DatabaseError> {
269 debug!("Registering new transform attempt");
270 let request_id = serde_json::to_value(attempt.request_id())?;
271 let attempt_id = serde_json::to_value(attempt.attempt_id())?;
272
273 let rows_affected = sqlx::query(
274 r#"
275 INSERT INTO transform_attempts
276 (attempt_id, request_id, status)
277 VALUES ($1, $2, 'pending')
278 ON CONFLICT (request_id, attempt_id) DO NOTHING
279 RETURNING 1
280 "#,
281 )
282 .bind(attempt_id)
283 .bind(request_id)
284 .execute(&self.pool)
285 .await?
286 .rows_affected();
287
288 if rows_affected > 0 {
289 debug!("Transform attempt registered successfully");
290 Ok(())
291 } else {
292 warn!("Transform attempt already exists");
293 Err(PostgresDatabaseError::Conflict(
294 "Transform attempt already exists".into(),
295 ))
296 }
297 }
298
299 async fn update_transform_attempt(
300 &mut self,
301 attempt: &WorkerManagerResult<Self::TransformAttempt>,
302 ) -> Result<(), Self::DatabaseError> {
303 debug!("Updating transform attempt status");
304 let (attempt_id, return_pkg, status) = match attempt {
305 WorkerManagerResult::Success(id, pkg) => (id, pkg, "success"),
306 WorkerManagerResult::Failure(id, pkg) => (id, pkg, "failure"),
307 };
308
309 let attempt_id = serde_json::to_value(&attempt_id)?;
310 let return_pkg = serde_json::to_value(&return_pkg)?;
311
312 let rows_affected = sqlx::query(
313 r#"
314 UPDATE transform_attempts
315 SET return_pkg = $1,
316 status = $2::attempt_status,
317 updated_at = NOW()
318 WHERE attempt_id = $3
319 "#,
320 )
321 .bind(return_pkg)
322 .bind(status)
323 .bind(attempt_id)
324 .execute(&self.pool)
325 .await?
326 .rows_affected();
327
328 if rows_affected > 0 {
329 debug!("Transform attempt updated successfully");
330 Ok(())
331 } else {
332 warn!("Transform attempt not found for update");
333 Err(PostgresDatabaseError::NotFound(
334 "Transform attempt not found".into(),
335 ))
336 }
337 }
338
339 async fn register_consume_attempt(
340 &mut self,
341 attempt: &Self::ConsumeAttempt,
342 ) -> Result<(), Self::DatabaseError> {
343 debug!("Registering new consume attempt");
344 let request_id = serde_json::to_value(attempt.request_id())?;
345 let attempt_id = serde_json::to_value(attempt.attempt_id())?;
346 let consume_id = serde_json::to_value(attempt.consume_id())?;
347
348 let dyn_cfgs: Vec<(String, String)> = attempt
349 .get_dyn_configs()
350 .into_iter()
351 .map(|(key, value)| {
352 let value = hex::encode(value);
353 debug!("Serializing value for key '{}': {}", key, value);
354 (key, value)
355 })
356 .collect();
357
358 let mut tx = self.pool.begin().await?;
359
360 let rows_affected = sqlx::query(
361 r#"
362 INSERT INTO consume_attempts
363 (request_id, attempt_id, consume_id, status)
364 VALUES ($1, $2, $3, 'pending')
365 ON CONFLICT (request_id, attempt_id, consume_id) DO NOTHING
366 RETURNING 1
367 "#,
368 )
369 .bind(request_id)
370 .bind(attempt_id)
371 .bind(consume_id)
372 .execute(&mut *tx)
373 .await?
374 .rows_affected();
375
376 for (key, value) in dyn_cfgs {
377 sqlx::query(
378 r#"
379 INSERT INTO dynamic_configs (key, value, created_at)
380 VALUES ($1, $2, NOW())
381 ON CONFLICT (key) DO UPDATE SET value = EXCLUDED.value
382 "#,
383 )
384 .bind(key)
385 .bind(value)
386 .execute(&mut *tx)
387 .await?;
388 }
389
390 tx.commit().await?;
391
392 if rows_affected > 0 {
393 debug!("Consume attempt registered successfully");
394 Ok(())
395 } else {
396 warn!("Consume attempt already exists");
397 Err(PostgresDatabaseError::Conflict(
398 "Consume attempt already exists".into(),
399 ))
400 }
401 }
402
403 async fn update_consume_attempt(
404 &mut self,
405 attempt: ConsumeAttemptResult<Self::ConsumeAttempt>,
406 ) -> Result<(), Self::DatabaseError> {
407 debug!("Updating consume attempt status");
408 let (consume_id, return_ctx, status) = match attempt {
409 ConsumeAttemptResult::Success(id, ctx) => (id, ctx, "success"),
410 ConsumeAttemptResult::Failure(id, ctx) => (id, ctx, "failure"),
411 };
412
413 let consume_id = serde_json::to_value(&consume_id)?;
414 let return_ctx = serde_json::to_value(&return_ctx)?;
415
416 let rows_affected = sqlx::query(
417 r#"
418 UPDATE consume_attempts
419 SET return_ctx = $1,
420 status = $2,
421 updated_at = NOW()
422 WHERE consume_id = $3
423 "#,
424 )
425 .bind(return_ctx)
426 .bind(status)
427 .bind(consume_id)
428 .execute(&self.pool)
429 .await?
430 .rows_affected();
431
432 if rows_affected > 0 {
433 debug!("Consume attempt updated successfully");
434 Ok(())
435 } else {
436 warn!("Consume attempt not found for update");
437 Err(PostgresDatabaseError::NotFound(
438 "Consume attempt not found".into(),
439 ))
440 }
441 }
442
443 async fn archive_request_with_id(
444 &mut self,
445 request_id: &<Self::TransformRequest as TransformRequest>::Identifier,
446 ) -> Result<(), Self::DatabaseError> {
447 debug!("Archiving transform request by request_id");
448 let mut tx = self.pool.begin().await?;
449 let request_id_json = serde_json::to_value(request_id)?;
450
451 sqlx::query(
453 r#"
454 WITH moved_consumes AS (
455 DELETE FROM consume_attempts
456 WHERE request_id = $1
457 RETURNING *
458 )
459 INSERT INTO archive_consume_attempts
460 SELECT * FROM moved_consumes
461 "#,
462 )
463 .bind(&request_id_json)
464 .execute(&mut *tx)
465 .await?;
466
467 sqlx::query(
469 r#"
470 WITH moved_attempts AS (
471 DELETE FROM transform_attempts
472 WHERE request_id = $1
473 RETURNING *
474 )
475 INSERT INTO archive_transform_attempts
476 SELECT * FROM moved_attempts
477 "#,
478 )
479 .bind(&request_id_json)
480 .execute(&mut *tx)
481 .await?;
482
483 sqlx::query(
485 r#"
486 WITH moved_requests AS (
487 DELETE FROM transform_requests
488 WHERE request_id = $1
489 RETURNING *
490 )
491 INSERT INTO archive_transform_requests
492 SELECT * FROM moved_requests
493 "#,
494 )
495 .bind(&request_id_json)
496 .execute(&mut *tx)
497 .await?;
498
499 tx.commit().await?;
500 debug!("Transform request archived successfully");
501 Ok(())
502 }
503}
504
505impl<TR, TA, CA, C> PostgresDatabase<TR, TA, CA, C>
507where
508 TR: TransformRequest + Send + Sync,
509 TR::Identifier: Serialize + DeserializeOwned,
510 TA: TransformAttempt + Send + Sync,
511 TA::Identifier: Serialize + DeserializeOwned,
512 CA: ConsumeAttempt + Send + Sync,
513 CA::Identifier: Serialize + DeserializeOwned,
514 C: Config + Send + Sync,
515{
516 async fn run_migrations(&self) -> Result<(), PostgresDatabaseError> {
517 info!("Running database migrations");
518
519 let temp_dir = env::temp_dir();
521 let timestamp = std::time::SystemTime::now()
522 .duration_since(std::time::UNIX_EPOCH)
523 .map(|d| d.as_secs())
524 .unwrap_or(0);
525 let temp_file_path = temp_dir
526 .join(format!("shepherd_migrations_{}", timestamp))
527 .join("0000_default_schema.up.sql");
528
529 trace!("Using temporary file for migrations: {:?}", temp_file_path);
530
531 if let Some(parent_dir) = temp_file_path.parent() {
533 fs::create_dir_all(parent_dir).map_err(|e| {
534 error!(
535 "Failed to create directories for temporary file path: {}",
536 e
537 );
538 PostgresDatabaseError::MigrationError
539 })?;
540 }
541
542 fs::write(&temp_file_path, MIGRATIONS).map_err(|e| {
543 error!("Failed to write migrations to temporary file: {}", e);
544 PostgresDatabaseError::MigrationError
545 })?;
546
547 let migrator = Migrator::new(temp_file_path.clone().parent().unwrap())
549 .await
550 .map_err(|e| {
551 error!("Failed to initialize migrator: {}", e);
552 e
553 })?;
554
555 migrator.run(&self.pool).await.map_err(|e| {
557 error!("Failed to apply migrations: {}", e);
558 e
559 })?;
560
561 fs::remove_file(&temp_file_path).map_err(|e| {
563 warn!("Failed to delete temporary migrations file: {}", e);
564 PostgresDatabaseError::MigrationError
565 })?;
566
567 info!("Database migrations completed successfully");
568 Ok(())
569 }
570}
571
/// SQL text of the default schema migration, embedded into the binary at compile
/// time from `0000_default_schema.up.sql` located next to this source file.
/// `run_migrations` writes it to a temporary file so the `sqlx` migrator can load it.
const MIGRATIONS: &str = include_str!("./0000_default_schema.up.sql");