Skip to main content

systemprompt_agent/repository/context/message/mod.rs

1mod parts;
2mod persistence;
3mod queries;
4
5use sqlx::PgPool;
6use std::sync::Arc;
7use systemprompt_database::DbPool;
8use systemprompt_identifiers::{ContextId, SessionId, TaskId, TraceId, UserId};
9use systemprompt_traits::RepositoryError;
10
11use crate::models::a2a::Message;
12
13pub use parts::{get_message_parts, FileUploadContext};
14pub use persistence::{persist_message_sqlx, persist_message_with_tx};
15pub use queries::{
16    get_messages_by_context, get_messages_by_task, get_next_sequence_number,
17    get_next_sequence_number_in_tx, get_next_sequence_number_sqlx,
18};
19
20#[derive(Debug, Clone)]
21pub struct MessageRepository {
22    pool: Arc<PgPool>,
23}
24
25impl MessageRepository {
26    pub fn new(db_pool: DbPool) -> Result<Self, RepositoryError> {
27        let pool = db_pool.as_ref().get_postgres_pool().ok_or_else(|| {
28            RepositoryError::InvalidData("PostgreSQL pool not available".to_string())
29        })?;
30        Ok(Self { pool })
31    }
32
33    pub async fn get_messages_by_task(
34        &self,
35        task_id: &TaskId,
36    ) -> Result<Vec<Message>, RepositoryError> {
37        get_messages_by_task(&self.pool, task_id).await
38    }
39
40    pub async fn get_messages_by_context(
41        &self,
42        context_id: &ContextId,
43    ) -> Result<Vec<Message>, RepositoryError> {
44        get_messages_by_context(&self.pool, context_id).await
45    }
46
47    pub async fn get_next_sequence_number(&self, task_id: &TaskId) -> Result<i32, RepositoryError> {
48        get_next_sequence_number(&self.pool, task_id).await
49    }
50
51    pub async fn persist_message_sqlx(
52        &self,
53        tx: &mut sqlx::Transaction<'_, sqlx::Postgres>,
54        message: &Message,
55        task_id: &TaskId,
56        context_id: &ContextId,
57        sequence_number: i32,
58        user_id: Option<&UserId>,
59        session_id: &SessionId,
60        trace_id: &TraceId,
61        upload_ctx: Option<&FileUploadContext<'_>>,
62    ) -> Result<(), RepositoryError> {
63        persist_message_sqlx(
64            tx,
65            message,
66            task_id,
67            context_id,
68            sequence_number,
69            user_id,
70            session_id,
71            trace_id,
72            upload_ctx,
73        )
74        .await
75    }
76
77    pub async fn persist_message_with_tx(
78        &self,
79        tx: &mut dyn systemprompt_database::DatabaseTransaction,
80        message: &Message,
81        task_id: &TaskId,
82        context_id: &ContextId,
83        sequence_number: i32,
84        user_id: Option<&UserId>,
85        session_id: &SessionId,
86        trace_id: &TraceId,
87    ) -> Result<(), RepositoryError> {
88        persist_message_with_tx(
89            tx,
90            message,
91            task_id,
92            context_id,
93            sequence_number,
94            user_id,
95            session_id,
96            trace_id,
97        )
98        .await
99    }
100}