tap_node/storage/
db.rs

1use sqlx::sqlite::{SqlitePool, SqlitePoolOptions};
2use std::env;
3use std::path::PathBuf;
4use tap_msg::didcomm::PlainMessage;
5use tracing::{debug, info};
6
7use super::error::StorageError;
8use super::models::{Message, MessageDirection, Transaction, TransactionStatus, TransactionType};
9
10/// Storage backend for TAP transactions and message audit trail
11///
12/// This struct provides the main interface for storing and retrieving TAP data
13/// from a SQLite database. It maintains two separate tables:
14/// - `transactions`: For Transfer and Payment messages requiring business logic
15/// - `messages`: For complete audit trail of all messages
16///
17/// It uses sqlx's built-in connection pooling for efficient concurrent access
18/// and provides a native async API.
19///
20/// # Example
21///
22/// ```no_run
23/// use tap_node::storage::{Storage, MessageDirection};
24/// use std::path::PathBuf;
25///
26/// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
27/// // Create storage with default path
28/// let storage = Storage::new(None).await?;
29///
30/// // Create storage with DID-based path
31/// let agent_did = "did:web:example.com";
32/// let storage_with_did = Storage::new_with_did(agent_did, None).await?;
33///
34/// // Create storage with custom TAP root
35/// let custom_root = PathBuf::from("/custom/tap/root");
36/// let storage_custom = Storage::new_with_did(agent_did, Some(custom_root)).await?;
37///
38/// // Query transactions
39/// let transactions = storage.list_transactions(10, 0).await?;
40///
41/// // Query audit trail
42/// let all_messages = storage.list_messages(20, 0, None).await?;
43/// let incoming_only = storage.list_messages(10, 0, Some(MessageDirection::Incoming)).await?;
44/// # Ok(())
45/// # }
46/// ```
47#[derive(Clone)]
48pub struct Storage {
49    pool: SqlitePool,
50}
51
52impl Storage {
53    /// Create a new Storage instance with an agent DID
54    ///
55    /// This will initialize a SQLite database in the TAP directory structure:
56    /// - Default: ~/.tap/{did}/transactions.db
57    /// - Custom root: {tap_root}/{did}/transactions.db
58    ///
59    /// # Arguments
60    ///
61    /// * `agent_did` - The DID of the agent this storage is for
62    /// * `tap_root` - Optional custom root directory (defaults to ~/.tap)
63    ///
64    /// # Errors
65    ///
66    /// Returns `StorageError` if:
67    /// - Database initialization fails
68    /// - Migrations fail to run
69    /// - Connection pool cannot be created
70    pub async fn new_with_did(
71        agent_did: &str,
72        tap_root: Option<PathBuf>,
73    ) -> Result<Self, StorageError> {
74        let root_dir = tap_root.unwrap_or_else(|| {
75            env::var("TAP_ROOT").map(PathBuf::from).unwrap_or_else(|_| {
76                dirs::home_dir()
77                    .expect("Could not find home directory")
78                    .join(".tap")
79            })
80        });
81
82        // Sanitize the DID for use as a directory name
83        let sanitized_did = agent_did.replace(':', "_");
84        let db_path = root_dir.join(&sanitized_did).join("transactions.db");
85
86        Self::new(Some(db_path)).await
87    }
88
89    /// Create a new Storage instance
90    ///
91    /// This will initialize a SQLite database at the specified path (or default location),
92    /// run any pending migrations, and set up a connection pool.
93    ///
94    /// # Arguments
95    ///
96    /// * `path` - Optional path to the database file. If None, uses `TAP_NODE_DB_PATH` env var or defaults to `./tap-node.db`
97    ///
98    /// # Errors
99    ///
100    /// Returns `StorageError` if:
101    /// - Database initialization fails
102    /// - Migrations fail to run
103    /// - Connection pool cannot be created
104    pub async fn new(path: Option<PathBuf>) -> Result<Self, StorageError> {
105        let db_path = path.unwrap_or_else(|| {
106            env::var("TAP_NODE_DB_PATH")
107                .unwrap_or_else(|_| "tap-node.db".to_string())
108                .into()
109        });
110
111        info!("Initializing storage at: {:?}", db_path);
112
113        // Create parent directory if it doesn't exist
114        if let Some(parent) = db_path.parent() {
115            std::fs::create_dir_all(parent)?;
116        }
117
118        // Create connection URL for SQLite with create mode
119        let db_url = format!("sqlite://{}?mode=rwc", db_path.display());
120
121        // Create connection pool with optimizations
122        let pool = SqlitePoolOptions::new()
123            .max_connections(10)
124            .connect(&db_url)
125            .await?;
126
127        // Enable WAL mode and other optimizations
128        sqlx::query("PRAGMA journal_mode = WAL")
129            .execute(&pool)
130            .await?;
131        sqlx::query("PRAGMA synchronous = NORMAL")
132            .execute(&pool)
133            .await?;
134
135        // Run migrations
136        sqlx::migrate!("./migrations")
137            .run(&pool)
138            .await
139            .map_err(|e| StorageError::Migration(e.to_string()))?;
140
141        Ok(Storage { pool })
142    }
143
144    /// Get the default logs directory
145    ///
146    /// Returns the default directory for log files:
147    /// - Default: ~/.tap/logs
148    /// - Custom root: {tap_root}/logs
149    ///
150    /// # Arguments
151    ///
152    /// * `tap_root` - Optional custom root directory (defaults to ~/.tap)
153    pub fn default_logs_dir(tap_root: Option<PathBuf>) -> PathBuf {
154        let root_dir = tap_root.unwrap_or_else(|| {
155            env::var("TAP_ROOT").map(PathBuf::from).unwrap_or_else(|_| {
156                dirs::home_dir()
157                    .expect("Could not find home directory")
158                    .join(".tap")
159            })
160        });
161
162        root_dir.join("logs")
163    }
164
165    /// Update the status of a message in the messages table
166    ///
167    /// # Arguments
168    ///
169    /// * `message_id` - The ID of the message to update
170    /// * `status` - The new status (accepted, rejected, pending)
171    ///
172    /// # Errors
173    ///
174    /// Returns `StorageError` if the database update fails
175    pub async fn update_message_status(
176        &self,
177        message_id: &str,
178        status: &str,
179    ) -> Result<(), StorageError> {
180        debug!("Updating message {} status to {}", message_id, status);
181
182        sqlx::query(
183            r#"
184            UPDATE messages 
185            SET status = ?1 
186            WHERE message_id = ?2
187            "#,
188        )
189        .bind(status)
190        .bind(message_id)
191        .execute(&self.pool)
192        .await?;
193
194        Ok(())
195    }
196
197    /// Update the status of a transaction in the transactions table
198    ///
199    /// # Arguments
200    ///
201    /// * `transaction_id` - The reference ID of the transaction to update
202    /// * `status` - The new status (pending, confirmed, failed, cancelled, reverted)
203    ///
204    /// # Errors
205    ///
206    /// Returns `StorageError` if the database update fails
207    pub async fn update_transaction_status(
208        &self,
209        transaction_id: &str,
210        status: &str,
211    ) -> Result<(), StorageError> {
212        debug!(
213            "Updating transaction {} status to {}",
214            transaction_id, status
215        );
216
217        sqlx::query(
218            r#"
219            UPDATE transactions 
220            SET status = ?1 
221            WHERE reference_id = ?2
222            "#,
223        )
224        .bind(status)
225        .bind(transaction_id)
226        .execute(&self.pool)
227        .await?;
228
229        Ok(())
230    }
231
232    /// Get a transaction by its reference ID
233    ///
234    /// # Arguments
235    ///
236    /// * `reference_id` - The reference ID of the transaction
237    ///
238    /// # Returns
239    ///
240    /// * `Ok(Some(Transaction))` if found
241    /// * `Ok(None)` if not found
242    /// * `Err(StorageError)` on database error
243    pub async fn get_transaction_by_id(
244        &self,
245        reference_id: &str,
246    ) -> Result<Option<Transaction>, StorageError> {
247        let result = sqlx::query_as::<_, (
248            i64,
249            String,
250            String,
251            Option<String>,
252            Option<String>,
253            Option<String>,
254            String,
255            String,
256            serde_json::Value,
257            String,
258            String,
259        )>(
260            r#"
261            SELECT id, type, reference_id, from_did, to_did, thread_id, message_type, status, message_json, created_at, updated_at
262            FROM transactions WHERE reference_id = ?1
263            "#,
264        )
265        .bind(reference_id)
266        .fetch_optional(&self.pool)
267        .await?;
268
269        if let Some((
270            id,
271            tx_type,
272            reference_id,
273            from_did,
274            to_did,
275            thread_id,
276            message_type,
277            status,
278            message_json,
279            created_at,
280            updated_at,
281        )) = result
282        {
283            Ok(Some(Transaction {
284                id,
285                transaction_type: TransactionType::try_from(tx_type.as_str())
286                    .map_err(StorageError::InvalidTransactionType)?,
287                reference_id,
288                from_did,
289                to_did,
290                thread_id,
291                message_type,
292                status: TransactionStatus::try_from(status.as_str())
293                    .map_err(StorageError::InvalidTransactionType)?,
294                message_json,
295                created_at,
296                updated_at,
297            }))
298        } else {
299            Ok(None)
300        }
301    }
302
303    /// Get a transaction by thread ID
304    ///
305    /// # Arguments
306    ///
307    /// * `thread_id` - The thread ID to search for
308    ///
309    /// # Returns
310    ///
311    /// * `Ok(Some(Transaction))` if found
312    /// * `Ok(None)` if not found
313    /// * `Err(StorageError)` on database error
314    pub async fn get_transaction_by_thread_id(
315        &self,
316        thread_id: &str,
317    ) -> Result<Option<Transaction>, StorageError> {
318        let result = sqlx::query_as::<_, (
319            i64,
320            String,
321            String,
322            Option<String>,
323            Option<String>,
324            Option<String>,
325            String,
326            String,
327            serde_json::Value,
328            String,
329            String,
330        )>(
331            r#"
332            SELECT id, type, reference_id, from_did, to_did, thread_id, message_type, status, message_json, created_at, updated_at
333            FROM transactions WHERE thread_id = ?1
334            "#,
335        )
336        .bind(thread_id)
337        .fetch_optional(&self.pool)
338        .await?;
339
340        if let Some((
341            id,
342            tx_type,
343            reference_id,
344            from_did,
345            to_did,
346            thread_id,
347            message_type,
348            status,
349            message_json,
350            created_at,
351            updated_at,
352        )) = result
353        {
354            Ok(Some(Transaction {
355                id,
356                transaction_type: TransactionType::try_from(tx_type.as_str())
357                    .map_err(StorageError::InvalidTransactionType)?,
358                reference_id,
359                from_did,
360                to_did,
361                thread_id,
362                message_type,
363                status: TransactionStatus::try_from(status.as_str())
364                    .map_err(StorageError::InvalidTransactionType)?,
365                message_json,
366                created_at,
367                updated_at,
368            }))
369        } else {
370            Ok(None)
371        }
372    }
373
374    /// Check if an agent is authorized for a transaction
375    ///
376    /// This checks the transaction_agents table to see if the given agent
377    /// is associated with the transaction.
378    ///
379    /// # Arguments
380    ///
381    /// * `transaction_id` - The reference ID of the transaction
382    /// * `agent_did` - The DID of the agent to check
383    ///
384    /// # Returns
385    ///
386    /// * `Ok(true)` if the agent is authorized
387    /// * `Ok(false)` if the agent is not authorized or transaction doesn't exist
388    /// * `Err(StorageError)` on database error
389    pub async fn is_agent_authorized_for_transaction(
390        &self,
391        transaction_id: &str,
392        agent_did: &str,
393    ) -> Result<bool, StorageError> {
394        // First get the transaction's internal ID
395        let tx_result = sqlx::query_scalar::<_, i64>(
396            r#"
397            SELECT id FROM transactions WHERE reference_id = ?1
398            "#,
399        )
400        .bind(transaction_id)
401        .fetch_optional(&self.pool)
402        .await?;
403
404        let tx_internal_id = match tx_result {
405            Some(id) => id,
406            None => return Ok(false), // Transaction doesn't exist
407        };
408
409        // Check if agent is in transaction_agents table
410        let count: i64 = sqlx::query_scalar(
411            r#"
412            SELECT COUNT(*) FROM transaction_agents 
413            WHERE transaction_id = ?1 AND agent_did = ?2
414            "#,
415        )
416        .bind(tx_internal_id)
417        .bind(agent_did)
418        .fetch_one(&self.pool)
419        .await?;
420
421        Ok(count > 0)
422    }
423
424    /// Insert a transaction agent
425    ///
426    /// # Arguments
427    ///
428    /// * `transaction_id` - The reference ID of the transaction
429    /// * `agent_did` - The DID of the agent
430    /// * `agent_role` - The role of the agent (sender, receiver, compliance, other)
431    ///
432    /// # Returns
433    ///
434    /// * `Ok(())` on success
435    /// * `Err(StorageError)` on database error
436    pub async fn insert_transaction_agent(
437        &self,
438        transaction_id: &str,
439        agent_did: &str,
440        agent_role: &str,
441    ) -> Result<(), StorageError> {
442        // First get the transaction's internal ID
443        let tx_result = sqlx::query_scalar::<_, i64>(
444            r#"
445            SELECT id FROM transactions WHERE reference_id = ?1
446            "#,
447        )
448        .bind(transaction_id)
449        .fetch_optional(&self.pool)
450        .await?;
451
452        let tx_internal_id = match tx_result {
453            Some(id) => id,
454            None => {
455                return Err(StorageError::NotFound(format!(
456                    "Transaction {} not found",
457                    transaction_id
458                )))
459            }
460        };
461
462        // Insert the agent
463        sqlx::query(
464            r#"
465            INSERT INTO transaction_agents (transaction_id, agent_did, agent_role, status)
466            VALUES (?1, ?2, ?3, 'pending')
467            ON CONFLICT(transaction_id, agent_did) DO UPDATE SET
468                agent_role = excluded.agent_role,
469                updated_at = strftime('%Y-%m-%dT%H:%M:%SZ', 'now')
470            "#,
471        )
472        .bind(tx_internal_id)
473        .bind(agent_did)
474        .bind(agent_role)
475        .execute(&self.pool)
476        .await?;
477
478        Ok(())
479    }
480
481    /// Update transaction agent status
482    ///
483    /// # Arguments
484    ///
485    /// * `transaction_id` - The reference ID of the transaction
486    /// * `agent_did` - The DID of the agent
487    /// * `status` - The new status (pending, authorized, rejected, cancelled)
488    ///
489    /// # Returns
490    ///
491    /// * `Ok(())` on success
492    /// * `Err(StorageError)` on database error
493    pub async fn update_transaction_agent_status(
494        &self,
495        transaction_id: &str,
496        agent_did: &str,
497        status: &str,
498    ) -> Result<(), StorageError> {
499        // First get the transaction's internal ID
500        let tx_result = sqlx::query_scalar::<_, i64>(
501            r#"
502            SELECT id FROM transactions WHERE reference_id = ?1
503            "#,
504        )
505        .bind(transaction_id)
506        .fetch_optional(&self.pool)
507        .await?;
508
509        let tx_internal_id = match tx_result {
510            Some(id) => id,
511            None => {
512                return Err(StorageError::NotFound(format!(
513                    "Transaction {} not found",
514                    transaction_id
515                )))
516            }
517        };
518
519        // Update the agent status
520        let result = sqlx::query(
521            r#"
522            UPDATE transaction_agents 
523            SET status = ?1 
524            WHERE transaction_id = ?2 AND agent_did = ?3
525            "#,
526        )
527        .bind(status)
528        .bind(tx_internal_id)
529        .bind(agent_did)
530        .execute(&self.pool)
531        .await?;
532
533        if result.rows_affected() == 0 {
534            return Err(StorageError::NotFound(format!(
535                "Agent {} not found for transaction {}",
536                agent_did, transaction_id
537            )));
538        }
539
540        Ok(())
541    }
542
543    /// Get all agents for a transaction
544    ///
545    /// # Arguments
546    ///
547    /// * `transaction_id` - The reference ID of the transaction
548    ///
549    /// # Returns
550    ///
551    /// * `Ok(Vec<(agent_did, agent_role, status)>)` on success
552    /// * `Err(StorageError)` on database error
553    pub async fn get_transaction_agents(
554        &self,
555        transaction_id: &str,
556    ) -> Result<Vec<(String, String, String)>, StorageError> {
557        // First get the transaction's internal ID
558        let tx_result = sqlx::query_scalar::<_, i64>(
559            r#"
560            SELECT id FROM transactions WHERE reference_id = ?1
561            "#,
562        )
563        .bind(transaction_id)
564        .fetch_optional(&self.pool)
565        .await?;
566
567        let tx_internal_id = match tx_result {
568            Some(id) => id,
569            None => {
570                return Err(StorageError::NotFound(format!(
571                    "Transaction {} not found",
572                    transaction_id
573                )))
574            }
575        };
576
577        // Get all agents
578        let agents = sqlx::query_as::<_, (String, String, String)>(
579            r#"
580            SELECT agent_did, agent_role, status
581            FROM transaction_agents
582            WHERE transaction_id = ?1
583            ORDER BY created_at
584            "#,
585        )
586        .bind(tx_internal_id)
587        .fetch_all(&self.pool)
588        .await?;
589
590        Ok(agents)
591    }
592
593    /// Check if all agents have authorized the transaction
594    ///
595    /// # Arguments
596    ///
597    /// * `transaction_id` - The reference ID of the transaction
598    ///
599    /// # Returns
600    ///
601    /// * `Ok(true)` if all agents have authorized
602    /// * `Ok(false)` if any agent hasn't authorized or has rejected/cancelled
603    /// * `Err(StorageError)` on database error
604    pub async fn are_all_agents_authorized(
605        &self,
606        transaction_id: &str,
607    ) -> Result<bool, StorageError> {
608        // First get the transaction's internal ID
609        let tx_result = sqlx::query_scalar::<_, i64>(
610            r#"
611            SELECT id FROM transactions WHERE reference_id = ?1
612            "#,
613        )
614        .bind(transaction_id)
615        .fetch_optional(&self.pool)
616        .await?;
617
618        let tx_internal_id = match tx_result {
619            Some(id) => id,
620            None => return Ok(false), // Transaction doesn't exist
621        };
622
623        // Check if there are any agents not in 'authorized' status
624        let non_authorized_count: i64 = sqlx::query_scalar(
625            r#"
626            SELECT COUNT(*) FROM transaction_agents 
627            WHERE transaction_id = ?1 AND status != 'authorized'
628            "#,
629        )
630        .bind(tx_internal_id)
631        .fetch_one(&self.pool)
632        .await?;
633
634        // If there are no agents, transaction is ready to settle
635        // If there are agents, all must be authorized
636        Ok(non_authorized_count == 0)
637    }
638
639    /// Insert a new transaction from a TAP message
640    ///
641    /// This method extracts transaction details from a Transfer or Payment message
642    /// and stores them in the database with a 'pending' status.
643    ///
644    /// # Arguments
645    ///
646    /// * `message` - The DIDComm PlainMessage containing a Transfer or Payment body
647    ///
648    /// # Errors
649    ///
650    /// Returns `StorageError` if:
651    /// - The message is not a Transfer or Payment type
652    /// - Database insertion fails
653    /// - The transaction already exists (duplicate reference_id)
654    pub async fn insert_transaction(&self, message: &PlainMessage) -> Result<(), StorageError> {
655        let message_type = message.type_.clone();
656        let message_json = serde_json::to_value(message)?;
657
658        // Extract transaction type and use message ID as reference
659        let message_type_lower = message.type_.to_lowercase();
660        let tx_type = if message_type_lower.contains("transfer") {
661            TransactionType::Transfer
662        } else if message_type_lower.contains("payment") {
663            TransactionType::Payment
664        } else {
665            return Err(StorageError::InvalidTransactionType(
666                message_type.to_string(),
667            ));
668        };
669
670        // Use the PlainMessage ID as the reference_id since transaction_id is not serialized
671        let reference_id = message.id.clone();
672        let from_did = message.from.clone();
673        let to_did = message.to.first().cloned();
674        let thread_id = message.thid.clone();
675
676        debug!("Inserting transaction: {} ({})", reference_id, tx_type);
677
678        let result = sqlx::query(
679            r#"
680            INSERT INTO transactions (type, reference_id, from_did, to_did, thread_id, message_type, message_json)
681            VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7)
682            "#,
683        )
684        .bind(tx_type.to_string())
685        .bind(&reference_id)
686        .bind(from_did)
687        .bind(to_did)
688        .bind(thread_id)
689        .bind(message_type.to_string())
690        .bind(sqlx::types::Json(message_json))
691        .execute(&self.pool)
692        .await;
693
694        match result {
695            Ok(_) => {
696                debug!("Successfully inserted transaction: {}", reference_id);
697                Ok(())
698            }
699            Err(sqlx::Error::Database(db_err)) => {
700                if db_err.message().contains("UNIQUE") {
701                    Err(StorageError::DuplicateTransaction(reference_id))
702                } else {
703                    Err(StorageError::Database(sqlx::Error::Database(db_err)))
704                }
705            }
706            Err(e) => Err(StorageError::Database(e)),
707        }
708    }
709
710    /// List transactions with pagination
711    ///
712    /// Retrieves transactions ordered by creation time (newest first).
713    ///
714    /// # Arguments
715    ///
716    /// * `limit` - Maximum number of transactions to return
717    /// * `offset` - Number of transactions to skip (for pagination)
718    ///
719    /// # Returns
720    ///
721    /// A vector of transactions ordered by creation time descending
722    pub async fn list_transactions(
723        &self,
724        limit: u32,
725        offset: u32,
726    ) -> Result<Vec<Transaction>, StorageError> {
727        let rows = sqlx::query_as::<_, (
728            i64,
729            String,
730            String,
731            Option<String>,
732            Option<String>,
733            Option<String>,
734            String,
735            String,
736            serde_json::Value,
737            String,
738            String,
739        )>(
740            r#"
741            SELECT id, type, reference_id, from_did, to_did, thread_id, message_type, status, message_json, created_at, updated_at
742            FROM transactions
743            ORDER BY created_at DESC
744            LIMIT ?1 OFFSET ?2
745            "#,
746        )
747        .bind(limit)
748        .bind(offset)
749        .fetch_all(&self.pool)
750        .await?;
751
752        let mut transactions = Vec::new();
753        for (
754            id,
755            tx_type,
756            reference_id,
757            from_did,
758            to_did,
759            thread_id,
760            message_type,
761            status,
762            message_json,
763            created_at,
764            updated_at,
765        ) in rows
766        {
767            transactions.push(Transaction {
768                id,
769                transaction_type: TransactionType::try_from(tx_type.as_str())
770                    .map_err(StorageError::InvalidTransactionType)?,
771                reference_id,
772                from_did,
773                to_did,
774                thread_id,
775                message_type,
776                status: TransactionStatus::try_from(status.as_str())
777                    .map_err(StorageError::InvalidTransactionType)?,
778                message_json,
779                created_at,
780                updated_at,
781            });
782        }
783
784        Ok(transactions)
785    }
786
787    /// Log an incoming or outgoing message to the audit trail
788    ///
789    /// This method stores any DIDComm message for audit purposes, regardless of type.
790    ///
791    /// # Arguments
792    ///
793    /// * `message` - The DIDComm PlainMessage to log
794    /// * `direction` - Whether the message is incoming or outgoing
795    /// * `raw_message` - Optional raw JWE/JWS message string
796    ///
797    /// # Errors
798    ///
799    /// Returns `StorageError` if:
800    /// - Database insertion fails
801    /// - The message already exists (duplicate message_id)
802    pub async fn log_message(
803        &self,
804        message: &PlainMessage,
805        direction: MessageDirection,
806        raw_message: Option<&str>,
807    ) -> Result<(), StorageError> {
808        let message_json = serde_json::to_value(message)?;
809        let message_id = message.id.clone();
810        let message_type = message.type_.clone();
811        let from_did = message.from.clone();
812        let to_did = message.to.first().cloned();
813        let thread_id = message.thid.clone();
814        let parent_thread_id = message.pthid.clone();
815
816        debug!(
817            "Logging {} message: {} ({})",
818            direction, message_id, message_type
819        );
820
821        let result = sqlx::query(
822            r#"
823            INSERT INTO messages (message_id, message_type, from_did, to_did, thread_id, parent_thread_id, direction, message_json, raw_message)
824            VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9)
825            "#,
826        )
827        .bind(&message_id)
828        .bind(message_type)
829        .bind(from_did)
830        .bind(to_did)
831        .bind(thread_id)
832        .bind(parent_thread_id)
833        .bind(direction.to_string())
834        .bind(sqlx::types::Json(message_json))
835        .bind(raw_message)
836        .execute(&self.pool)
837        .await;
838
839        match result {
840            Ok(_) => {
841                debug!("Successfully logged message: {}", message_id);
842                Ok(())
843            }
844            Err(sqlx::Error::Database(db_err)) => {
845                if db_err.message().contains("UNIQUE") {
846                    // Message already logged, this is fine
847                    debug!("Message already logged: {}", message_id);
848                    Ok(())
849                } else {
850                    Err(StorageError::Database(sqlx::Error::Database(db_err)))
851                }
852            }
853            Err(e) => Err(StorageError::Database(e)),
854        }
855    }
856
857    /// Retrieve a message by its ID
858    ///
859    /// # Arguments
860    ///
861    /// * `message_id` - The unique message ID
862    ///
863    /// # Returns
864    ///
865    /// * `Ok(Some(Message))` if found
866    /// * `Ok(None)` if not found
867    /// * `Err(StorageError)` on database error
868    pub async fn get_message_by_id(
869        &self,
870        message_id: &str,
871    ) -> Result<Option<Message>, StorageError> {
872        let result = sqlx::query_as::<_, (
873            i64,
874            String,
875            String,
876            Option<String>,
877            Option<String>,
878            Option<String>,
879            Option<String>,
880            String,
881            serde_json::Value,
882            String,
883        )>(
884            r#"
885            SELECT id, message_id, message_type, from_did, to_did, thread_id, parent_thread_id, direction, message_json, created_at
886            FROM messages WHERE message_id = ?1
887            "#,
888        )
889        .bind(message_id)
890        .fetch_optional(&self.pool)
891        .await?;
892
893        match result {
894            Some((
895                id,
896                message_id,
897                message_type,
898                from_did,
899                to_did,
900                thread_id,
901                parent_thread_id,
902                direction,
903                message_json,
904                created_at,
905            )) => Ok(Some(Message {
906                id,
907                message_id,
908                message_type,
909                from_did,
910                to_did,
911                thread_id,
912                parent_thread_id,
913                direction: MessageDirection::try_from(direction.as_str())
914                    .map_err(StorageError::InvalidTransactionType)?,
915                message_json,
916                created_at,
917            })),
918            None => Ok(None),
919        }
920    }
921
922    /// List messages with pagination and optional filtering
923    ///
924    /// # Arguments
925    ///
926    /// * `limit` - Maximum number of messages to return
927    /// * `offset` - Number of messages to skip (for pagination)
928    /// * `direction` - Optional filter by message direction
929    ///
930    /// # Returns
931    ///
932    /// A vector of messages ordered by creation time descending
933    pub async fn list_messages(
934        &self,
935        limit: u32,
936        offset: u32,
937        direction: Option<MessageDirection>,
938    ) -> Result<Vec<Message>, StorageError> {
939        let rows = if let Some(dir) = direction {
940            sqlx::query_as::<_, (
941                i64,
942                String,
943                String,
944                Option<String>,
945                Option<String>,
946                Option<String>,
947                Option<String>,
948                String,
949                serde_json::Value,
950                String,
951            )>(
952                r#"
953                SELECT id, message_id, message_type, from_did, to_did, thread_id, parent_thread_id, direction, message_json, created_at
954                FROM messages
955                WHERE direction = ?1
956                ORDER BY created_at DESC
957                LIMIT ?2 OFFSET ?3
958                "#,
959            )
960            .bind(dir.to_string())
961            .bind(limit)
962            .bind(offset)
963            .fetch_all(&self.pool)
964            .await?
965        } else {
966            sqlx::query_as::<_, (
967                i64,
968                String,
969                String,
970                Option<String>,
971                Option<String>,
972                Option<String>,
973                Option<String>,
974                String,
975                serde_json::Value,
976                String,
977            )>(
978                r#"
979                SELECT id, message_id, message_type, from_did, to_did, thread_id, parent_thread_id, direction, message_json, created_at
980                FROM messages
981                ORDER BY created_at DESC
982                LIMIT ?1 OFFSET ?2
983                "#,
984            )
985            .bind(limit)
986            .bind(offset)
987            .fetch_all(&self.pool)
988            .await?
989        };
990
991        let mut messages = Vec::new();
992        for (
993            id,
994            message_id,
995            message_type,
996            from_did,
997            to_did,
998            thread_id,
999            parent_thread_id,
1000            direction,
1001            message_json,
1002            created_at,
1003        ) in rows
1004        {
1005            messages.push(Message {
1006                id,
1007                message_id,
1008                message_type,
1009                from_did,
1010                to_did,
1011                thread_id,
1012                parent_thread_id,
1013                direction: MessageDirection::try_from(direction.as_str())
1014                    .map_err(StorageError::InvalidTransactionType)?,
1015                message_json,
1016                created_at,
1017            });
1018        }
1019
1020        Ok(messages)
1021    }
1022}
1023
#[cfg(test)]
mod tests {
    use super::*;
    use tap_msg::message::transfer::Transfer;
    use tap_msg::message::Party;
    use tempfile::tempdir;

    /// DIDComm plain-message content type shared by every fixture below.
    const PLAIN_TYP: &str = "application/didcomm-plain+json";

    /// Builds a single-recipient `PlainMessage` fixture.
    ///
    /// Only the fields these tests vary are parameters; every other header is
    /// left empty (`None` / `Default::default()`).
    fn fixture_message(
        id: &str,
        type_uri: &str,
        body: serde_json::Value,
        sender: &str,
        recipient: &str,
        thread: Option<&str>,
    ) -> PlainMessage {
        PlainMessage {
            id: id.to_string(),
            typ: PLAIN_TYP.to_string(),
            type_: type_uri.to_string(),
            body,
            from: sender.to_string(),
            to: vec![recipient.to_string()],
            thid: thread.map(String::from),
            pthid: None,
            extra_headers: Default::default(),
            attachments: None,
            created_time: None,
            expires_time: None,
            from_prior: None,
        }
    }

    #[tokio::test]
    async fn test_storage_creation() {
        let tmp = tempdir().unwrap();
        let db_file = tmp.path().join("test.db");

        // Successful construction is the whole check here.
        let _storage = Storage::new(Some(db_file)).await.unwrap();
    }

    #[tokio::test]
    async fn test_storage_with_did() {
        let _ = env_logger::builder().is_test(true).try_init();

        let tmp = tempdir().unwrap();
        let root = tmp.path().to_path_buf();
        let did = "did:web:example.com";

        let storage = Storage::new_with_did(did, Some(root.clone()))
            .await
            .unwrap();

        // The per-agent directory is derived from the DID (here `did_web_example.com`).
        let db_file = root.join("did_web_example.com").join("transactions.db");
        assert!(db_file.exists(), "Database file not created at expected path");

        // A freshly created database has an empty audit trail.
        let listed = storage.list_messages(10, 0, None).await.unwrap();
        assert!(listed.is_empty());
    }

    #[tokio::test]
    async fn test_default_logs_dir() {
        let tmp = tempdir().unwrap();
        let root = tmp.path().to_path_buf();

        // An explicit TAP root places logs directly under `<root>/logs`.
        assert_eq!(Storage::default_logs_dir(Some(root.clone())), root.join("logs"));

        // Without a root, the path falls back to a `.tap/logs` location
        // (presumably under the home directory).
        let fallback = Storage::default_logs_dir(None);
        assert!(fallback.to_string_lossy().contains(".tap/logs"));
    }

    #[tokio::test]
    async fn test_insert_and_retrieve_transaction() {
        let _ = env_logger::builder().is_test(true).try_init();

        let tmp = tempdir().unwrap();
        let storage = Storage::new(Some(tmp.path().join("test.db"))).await.unwrap();

        let transfer = Transfer {
            transaction_id: "test_transfer_123".to_string(),
            originator: Party::new("did:example:originator"),
            beneficiary: Some(Party::new("did:example:beneficiary")),
            asset: "eip155:1/erc20:0x0000000000000000000000000000000000000000"
                .parse()
                .unwrap(),
            amount: "1000000000000000000".to_string(),
            agents: vec![],
            memo: None,
            settlement_id: None,
            connection_id: None,
            metadata: Default::default(),
        };

        let msg_id = "test_message_123";
        let message = fixture_message(
            msg_id,
            "https://tap-protocol.io/messages/transfer/1.0",
            serde_json::to_value(&transfer).unwrap(),
            "did:example:sender",
            "did:example:receiver",
            None,
        );

        storage.insert_transaction(&message).await.unwrap();

        // The transaction is looked up by the id of the message that created it.
        let tx = storage
            .get_transaction_by_id(msg_id)
            .await
            .unwrap()
            .expect("Transaction not found");

        assert_eq!(tx.reference_id, msg_id);
        assert_eq!(tx.transaction_type, TransactionType::Transfer);
        assert_eq!(tx.status, TransactionStatus::Pending);
    }

    #[tokio::test]
    async fn test_log_and_retrieve_messages() {
        let _ = env_logger::builder().is_test(true).try_init();

        let tmp = tempdir().unwrap();
        let storage = Storage::new(Some(tmp.path().join("test.db"))).await.unwrap();

        // Two messages on the same thread, one in each direction.
        let connect = fixture_message(
            "msg_connect_123",
            "https://tap-protocol.io/messages/connect/1.0",
            serde_json::json!({"constraints": ["test"]}),
            "did:example:alice",
            "did:example:bob",
            Some("thread_123"),
        );
        let authorize = fixture_message(
            "msg_auth_123",
            "https://tap-protocol.io/messages/authorize/1.0",
            serde_json::json!({"transaction_id": "test_transfer_123"}),
            "did:example:bob",
            "did:example:alice",
            Some("thread_123"),
        );

        storage
            .log_message(&connect, MessageDirection::Incoming, None)
            .await
            .unwrap();
        storage
            .log_message(&authorize, MessageDirection::Outgoing, None)
            .await
            .unwrap();

        // Point lookup by message id.
        let fetched = storage
            .get_message_by_id("msg_connect_123")
            .await
            .unwrap()
            .unwrap();
        assert_eq!(fetched.message_id, "msg_connect_123");
        assert_eq!(fetched.direction, MessageDirection::Incoming);

        // Unfiltered listing sees both messages.
        let everything = storage.list_messages(10, 0, None).await.unwrap();
        assert_eq!(everything.len(), 2);

        // Direction filter keeps only the incoming one.
        let incoming = storage
            .list_messages(10, 0, Some(MessageDirection::Incoming))
            .await
            .unwrap();
        assert_eq!(incoming.len(), 1);
        assert_eq!(incoming[0].message_id, "msg_connect_123");

        // Re-logging an already stored message must succeed without adding a row.
        storage
            .log_message(&connect, MessageDirection::Incoming, None)
            .await
            .unwrap();
        let everything_again = storage.list_messages(10, 0, None).await.unwrap();
        assert_eq!(everything_again.len(), 2);
    }
}